From 62302259d093517341fc833bab6e6dcd9cb9bc9e Mon Sep 17 00:00:00 2001 From: ahmedalbahnasawy <50653875+ahmedalbahnasawy@users.noreply.github.com> Date: Mon, 14 Nov 2022 20:11:42 +0400 Subject: [PATCH 001/174] add kaldifeat (#680) --- docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile b/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile index 524303fb8..3637d2f11 100644 --- a/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile +++ b/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile @@ -68,6 +68,7 @@ RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \ cd /workspace/icefall && \ pip install -r requirements.txt +RUN pip install kaldifeat ENV PYTHONPATH /workspace/icefall:$PYTHONPATH -WORKDIR /workspace/icefall \ No newline at end of file +WORKDIR /workspace/icefall From 952a7b3fcc7dc1ad0f87f431edecdcb4f2c6fd3b Mon Sep 17 00:00:00 2001 From: Tiance Wang Date: Tue, 15 Nov 2022 10:45:48 +0800 Subject: [PATCH 002/174] Fix typo (#681) * Update add_alignment_librispeech.py * Update scaling_converter.py --- egs/librispeech/ASR/local/add_alignment_librispeech.py | 2 +- .../ASR/pruned_transducer_stateless3/scaling_converter.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/local/add_alignment_librispeech.py b/egs/librispeech/ASR/local/add_alignment_librispeech.py index cd1bcea67..fe6a26c51 100755 --- a/egs/librispeech/ASR/local/add_alignment_librispeech.py +++ b/egs/librispeech/ASR/local/add_alignment_librispeech.py @@ -171,7 +171,7 @@ def add_alignment( ali = alignments[origin_id] else: logging.info( - f"Warning: {origin_id} does not has alignment." + f"Warning: {origin_id} does not have alignment." ) ali = [] subcut.alignment = {"word": ali} diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py index 1e7e808c7..1e6022b57 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py @@ -87,7 +87,7 @@ def scaled_linear_to_linear(scaled_linear: ScaledLinear) -> nn.Linear: in_features=scaled_linear.in_features, out_features=scaled_linear.out_features, bias=True, # otherwise, it throws errors when converting to PNNX format - # device=weight.device, # Pytorch version before v1.9.0 does not has + # device=weight.device, # Pytorch version before v1.9.0 does not have # this argument. Comment out for now, we will # see if it will raise error for versions # after v1.9.0 From 855c76655b49de61a5c6d054a7ff2158a639e6f7 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 15 Nov 2022 16:56:05 +0800 Subject: [PATCH 003/174] Add zipformer from Dan using multi-dataset setup (#675) * Bug fix * Change subsamplling factor from 1 to 2 * Implement AttentionCombine as replacement for RandomCombine * Decrease random_prob from 0.5 to 0.333 * Add print statement * Apply single_prob mask, so sometimes we just get one layer as output. * Introduce feature mask per frame * Include changes from Liyong about padding conformer module. * Reduce single_prob from 0.5 to 0.25 * Reduce feature_mask_dropout_prob from 0.25 to 0.15. * Remove dropout from inside ConformerEncoderLayer, for adding to residuals * Increase feature_mask_dropout_prob from 0.15 to 0.2. * Swap random_prob and single_prob, to reduce prob of being randomized. * Decrease feature_mask_dropout_prob back from 0.2 to 0.15, i.e. revert the 43->48 change. * Randomize order of some modules * Bug fix * Stop backprop bug * Introduce a scale dependent on the masking value * Implement efficient layer dropout * Simplify the learned scaling factor on the modules * Compute valid loss on batch 0. * Make the scaling factors more global and the randomness of dropout more random * Bug fix * Introduce offset in layerdrop_scaleS * Remove final combination; implement layer drop that drops the final layers. * Bug fices * Fix bug RE self.training * Fix bug setting layerdrop mask * Fix eigs call * Add debug info * Remove warmup * Remove layer dropout and model-level warmup * Don't always apply the frame mask * Slight code cleanup/simplification * Various fixes, finish implementating frame masking * Remove debug info * Don't compute validation if printing diagnostics. * Apply layer bypass during warmup in a new way, including 2s and 4s of layers. * Update checkpoint.py to deal with int params * Revert initial_scale to previous values. * Remove the feature where it was bypassing groups of layers. * Implement layer dropout with probability 0.075 * Fix issue with warmup in test time * Add warmup schedule where dropout disappears from earlier layers first. * Have warmup that gradually removes dropout from layers; multiply initialization scales by 0.1. * Do dropout a different way * Fix bug in warmup * Remove debug print * Make the warmup mask per frame. * Implement layer dropout (in a relatively efficient way) * Decrease initial keep_prob to 0.25. * Make it start warming up from the very start, and increase warmup_batches to 6k * Change warmup schedule and increase warmup_batches from 4k to 6k * Make the bypass scale trainable. * Change the initial keep-prob back from 0.25 to 0.5 * Bug fix * Limit bypass scale to >= 0.1 * Revert "Change warmup schedule and increase warmup_batches from 4k to 6k" This reverts commit 86845bd5d859ceb6f83cd83f3719c3e6641de987. * Do warmup by dropping out whole layers. * Decrease frequency of logging variance_proportion * Make layerdrop different in different processes. * For speed, drop the same num layers per job. * Decrease initial_layerdrop_prob from 0.75 to 0.5 * Revert also the changes in scaled_adam_exp85 regarding warmup schedule * Remove unused code LearnedScale. * Reintroduce batching to the optimizer * Various fixes from debugging with nvtx, but removed the NVTX annotations. * Only apply ActivationBalancer with prob 0.25. * Fix s -> scaling for import. * Increase final layerdrop prob from 0.05 to 0.075 * Fix bug where fewer layers were dropped than should be; remove unnecesary print statement. * Fix bug in choosing layers to drop * Refactor RelPosMultiheadAttention to have 2nd forward function and introduce more modules in conformer encoder layer * Reduce final layerdrop_prob from 0.075 to 0.05. * Fix issue with diagnostics if stats is None * Remove persistent attention scores. * Make ActivationBalancer and MaxEig more efficient. * Cosmetic improvements * Change scale_factor_scale from 0.5 to 0.8 * Make the ActivationBalancer regress to the data mean, not zero, when enforcing abs constraint. * Remove unused config value * Fix bug when channel_dim < 0 * Fix bug when channel_dim < 0 * Simplify how the positional-embedding scores work in attention (thanks to Zengwei for this concept) * Revert dropout on attention scores to 0.0. * This should just be a cosmetic change, regularizing how we get the warmup times from the layers. * Reduce beta from 0.75 to 0.0. * Reduce stats period from 10 to 4. * Reworking of ActivationBalancer code to hopefully balance speed and effectiveness. * Add debug code for attention weihts and eigs * Remove debug statement * Add different debug info. * Penalize attention-weight entropies above a limit. * Remove debug statements * use larger delta but only penalize if small grad norm * Bug fixes; change debug freq * Change cutoff for small_grad_norm * Implement whitening of values in conformer. * Also whiten the keys in conformer. * Fix an issue with scaling of grad. * Decrease whitening limit from 2.0 to 1.1. * Fix debug stats. * Reorganize Whiten() code; configs are not the same as before. Also remove MaxEig for self_attn module * Bug fix RE float16 * Revert whitening_limit from 1.1 to 2.2. * Replace MaxEig with Whiten with limit=5.0, and move it to end of ConformerEncoderLayer * Change LR schedule to start off higher * Simplify the dropout mask, no non-dropped-out sequences * Make attention dims configurable, not embed_dim//2, trying 256. * Reduce attention_dim to 192; cherry-pick scaled_adam_exp130 which is linear_pos interacting with query * Use half the dim for values, vs. keys and queries. * Increase initial-lr from 0.04 to 0.05, plus changes for diagnostics * Cosmetic changes * Changes to avoid bug in backward hooks, affecting diagnostics. * Random clip attention scores to -5..5. * Add some random clamping in model.py * Add reflect=0.1 to invocations of random_clamp() * Remove in_balancer. * Revert model.py so there are no constraints on the output. * Implement randomized backprop for softmax. * Reduce min_abs from 1e-03 to 1e-04 * Add RandomGrad with min_abs=1.0e-04 * Use full precision to do softmax and store ans. * Fix bug in backprop of random_clamp() * Get the randomized backprop for softmax in autocast mode working. * Remove debug print * Reduce min_abs from 1.0e-04 to 5.0e-06 * Add hard limit of attention weights to +- 50 * Use normal implementation of softmax. * Remove use of RandomGrad * Remove the use of random_clamp in conformer.py. * Reduce the limit on attention weights from 50 to 25. * Reduce min_prob of ActivationBalancer from 0.1 to 0.05. * Penalize too large weights in softmax of AttentionDownsample() * Also apply limit on logit in SimpleCombiner * Increase limit on logit for SimpleCombiner to 25.0 * Add more diagnostics to debug gradient scale problems * Changes to grad scale logging; increase grad scale more frequently if less than one. * Add logging * Remove comparison diagnostics, which were not that useful. * Configuration changes: scores limit 5->10, min_prob 0.05->0.1, cur_grad_scale more aggressive increase * Reset optimizer state when we change loss function definition. * Make warmup period decrease scale on simple loss, leaving pruned loss scale constant. * Cosmetic change * Increase initial-lr from 0.05 to 0.06. * Increase initial-lr from 0.06 to 0.075 and decrease lr-epochs from 3.5 to 3. * Fixes to logging statements. * Introduce warmup schedule in optimizer * Increase grad_scale to Whiten module * Add inf check hooks * Renaming in optim.py; remove step() from scan_pessimistic_batches_for_oom in train.py * Change base lr to 0.1, also rename from initial lr in train.py * Adding activation balancers after simple_am_prob and simple_lm_prob * Reduce max_abs on am_balancer * Increase max_factor in final lm_balancer and am_balancer * Use penalize_abs_values_gt, not ActivationBalancer. * Trying to reduce grad_scale of Whiten() from 0.02 to 0.01. * Add hooks.py, had negleted to git add it. * don't do penalize_values_gt on simple_lm_proj and simple_am_proj; reduce --base-lr from 0.1 to 0.075 * Increase probs of activation balancer and make it decay slower. * Dont print out full non-finite tensor * Increase default max_factor for ActivationBalancer from 0.02 to 0.04; decrease max_abs in ConvolutionModule.deriv_balancer2 from 100.0 to 20.0 * reduce initial scale in GradScaler * Increase max_abs in ActivationBalancer of conv module from 20 to 50 * --base-lr0.075->0.5; --lr-epochs 3->3.5 * Revert 179->180 change, i.e. change max_abs for deriv_balancer2 back from 50.0 20.0 * Save some memory in the autograd of DoubleSwish. * Change the discretization of the sigmoid to be expectation preserving. * Fix randn to rand * Try a more exact way to round to uint8 that should prevent ever wrapping around to zero * Make it use float16 if in amp but use clamp to avoid wrapping error * Store only half precision output for softmax. * More memory efficient backprop for DoubleSwish. * Change to warmup schedule. * Changes to more accurately estimate OOM conditions * Reduce cutoff from 100 to 5 for estimating OOM with warmup * Make 20 the limit for warmup_count * Cast to float16 in DoubleSwish forward * Hopefully make penalize_abs_values_gt more memory efficient. * Add logging about memory used. * Change scalar_max in optim.py from 2.0 to 5.0 * Regularize how we apply the min and max to the eps of BasicNorm * Fix clamping of bypass scale; remove a couple unused variables. * Increase floor on bypass_scale from 0.1 to 0.2. * Increase bypass_scale from 0.2 to 0.4. * Increase bypass_scale min from 0.4 to 0.5 * Rename conformer.py to zipformer.py * Rename Conformer to Zipformer * Update decode.py by copying from pruned_transducer_stateless5 and changing directory name * Remove some unused variables. * Fix clamping of epsilon * Refactor zipformer for more flexibility so we can change number of encoder layers. * Have a 3rd encoder, at downsampling factor of 8. * Refactor how the downsampling is done so that it happens later, but the 1st encoder stack still operates after a subsampling of 2. * Fix bug RE seq lengths * Have 4 encoder stacks * Have 6 different encoder stacks, U-shaped network. * Reduce dim of linear positional encoding in attention layers. * Reduce min of bypass_scale from 0.5 to 0.3, and make it not applied in test mode. * Tuning change to num encoder layers, inspired by relative param importance. * Make decoder group size equal to 4. * Add skip connections as in normal U-net * Avoid falling off the loop for weird inputs * Apply layer-skip dropout prob * Have warmup schedule for layer-skipping * Rework how warmup count is produced; should not affect results. * Add warmup schedule for zipformer encoder layer, from 1.0 -> 0.2. * Reduce initial clamp_min for bypass_scale from 1.0 to 0.5. * Restore the changes from scaled_adam_219 and scaled_adam_exp220, accidentally lost, re layer skipping * Change to schedule of bypass_scale min: make it larger, decrease slower. * Change schedule after initial loss not promising * Implement pooling module, add it after initial feedforward. * Bug fix * Introduce dropout rate to dynamic submodules of conformer. * Introduce minimum probs in the SimpleCombiner * Add bias in weight module * Remove dynamic weights in SimpleCombine * Remove the 5th of 6 encoder stacks * Fix some typos * small fixes * small fixes * Copy files * Update decode.py * Add changes from the master * Add changes from the master * update results * Add CI * Small fixes * Small fixes Co-authored-by: Daniel Povey --- ...pruned-transducer-stateless7-2022-11-11.sh | 1 + ...pruned-transducer-stateless8-2022-11-14.sh | 116 ++ .../run-librispeech-2022-11-14-stateless8.yml | 155 ++ egs/librispeech/ASR/README.md | 1 + egs/librispeech/ASR/RESULTS.md | 58 + .../jit_pretrained.py | 1 + .../pruned_transducer_stateless8/__init__.py | 0 .../asr_datamodule.py | 1 + .../beam_search.py | 1 + .../pruned_transducer_stateless8/decode.py | 863 +++++++++++ .../pruned_transducer_stateless8/decoder.py | 1 + .../encoder_interface.py | 1 + .../pruned_transducer_stateless8/export.py | 334 ++++ .../gigaspeech.py | 1 + .../jit_pretrained.py | 275 ++++ .../pruned_transducer_stateless8/joiner.py | 1 + .../librispeech.py | 1 + .../ASR/pruned_transducer_stateless8/model.py | 222 +++ .../ASR/pruned_transducer_stateless8/optim.py | 1 + .../pretrained.py | 363 +++++ .../pruned_transducer_stateless8/scaling.py | 1 + .../scaling_converter.py | 1 + .../ASR/pruned_transducer_stateless8/train.py | 1367 +++++++++++++++++ .../pruned_transducer_stateless8/zipformer.py | 1 + 24 files changed, 3767 insertions(+) create mode 100755 .github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh create mode 100644 .github/workflows/run-librispeech-2022-11-14-stateless8.yml create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless8/__init__.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless8/asr_datamodule.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless8/beam_search.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless8/decode.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless8/decoder.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless8/encoder_interface.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless8/export.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless8/gigaspeech.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless8/joiner.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless8/librispeech.py create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless8/model.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless8/optim.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless8/scaling.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless8/scaling_converter.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless8/train.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless8/zipformer.py diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh index 75861bbc7..8e485d2e6 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh @@ -33,6 +33,7 @@ popd log "Export to torchscript model" ./pruned_transducer_stateless7/export.py \ --exp-dir $repo/exp \ + --use-averaged-model false \ --bpe-model $repo/data/lang_bpe_500/bpe.model \ --epoch 99 \ --avg 1 \ diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh new file mode 100755 index 000000000..e782b8425 --- /dev/null +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh @@ -0,0 +1,116 @@ +#!/usr/bin/env bash + +set -e + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless8-2022-11-14 + +log "Downloading pre-trained model from $repo_url" +git lfs install +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +log "Display test files" +tree $repo/ +soxi $repo/test_wavs/*.wav +ls -lh $repo/test_wavs/*.wav + +pushd $repo/exp +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/cpu_jit.pt" +git lfs pull --include "exp/pretrained.pt" +ln -s pretrained.pt epoch-99.pt +ls -lh *.pt +popd + +log "Decode with models exported by torch.jit.script()" + +./pruned_transducer_stateless8/jit_pretrained.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --nn-model-filename $repo/exp/cpu_jit.pt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + +log "Export to torchscript model" +./pruned_transducer_stateless8/export.py \ + --exp-dir $repo/exp \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model false \ + --epoch 99 \ + --avg 1 \ + --jit 1 + +ls -lh $repo/exp/*.pt + +log "Decode with models exported by torch.jit.script()" + +./pruned_transducer_stateless8/jit_pretrained.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --nn-model-filename $repo/exp/cpu_jit.pt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + +for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./pruned_transducer_stateless8/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done + +for method in modified_beam_search beam_search fast_beam_search; do + log "$method" + + ./pruned_transducer_stateless8/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done + +echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" +echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" +if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then + mkdir -p pruned_transducer_stateless8/exp + ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless8/exp/epoch-999.pt + ln -s $PWD/$repo/data/lang_bpe_500 data/ + + ls -lh data + ls -lh pruned_transducer_stateless8/exp + + log "Decoding test-clean and test-other" + + # use a small value for decoding with CPU + max_duration=100 + + for method in greedy_search fast_beam_search modified_beam_search; do + log "Decoding with $method" + + ./pruned_transducer_stateless8/decode.py \ + --decoding-method $method \ + --epoch 999 \ + --avg 1 \ + --use-averaged-model 0 \ + --max-duration $max_duration \ + --exp-dir pruned_transducer_stateless8/exp + done + + rm pruned_transducer_stateless8/exp/*.pt +fi diff --git a/.github/workflows/run-librispeech-2022-11-14-stateless8.yml b/.github/workflows/run-librispeech-2022-11-14-stateless8.yml new file mode 100644 index 000000000..eaab35189 --- /dev/null +++ b/.github/workflows/run-librispeech-2022-11-14-stateless8.yml @@ -0,0 +1,155 @@ +# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com) + +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: run-librispeech-2022-11-14-stateless8 +# zipformer + +on: + push: + branches: + - master + pull_request: + types: [labeled] + + schedule: + # minute (0-59) + # hour (0-23) + # day of the month (1-31) + # month (1-12) + # day of the week (0-6) + # nightly build at 15:50 UTC time every day + - cron: "50 15 * * *" + +jobs: + run_librispeech_2022_11_14_zipformer_stateless8: + if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python-version: [3.8] + + fail-fast: false + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + cache-dependency-path: '**/requirements-ci.txt' + + - name: Install Python dependencies + run: | + grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install + pip uninstall -y protobuf + pip install --no-binary protobuf protobuf + + - name: Cache kaldifeat + id: my-cache + uses: actions/cache@v2 + with: + path: | + ~/tmp/kaldifeat + key: cache-tmp-${{ matrix.python-version }}-2022-09-25 + + - name: Install kaldifeat + if: steps.my-cache.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/install-kaldifeat.sh + + - name: Cache LibriSpeech test-clean and test-other datasets + id: libri-test-clean-and-test-other-data + uses: actions/cache@v2 + with: + path: | + ~/tmp/download + key: cache-libri-test-clean-and-test-other + + - name: Download LibriSpeech test-clean and test-other + if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh + + - name: Prepare manifests for LibriSpeech test-clean and test-other + shell: bash + run: | + .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh + + - name: Cache LibriSpeech test-clean and test-other fbank features + id: libri-test-clean-and-test-other-fbank + uses: actions/cache@v2 + with: + path: | + ~/tmp/fbank-libri + key: cache-libri-fbank-test-clean-and-test-other-v2 + + - name: Compute fbank for LibriSpeech test-clean and test-other + if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh + + - name: Inference with pre-trained model + shell: bash + env: + GITHUB_EVENT_NAME: ${{ github.event_name }} + GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} + run: | + mkdir -p egs/librispeech/ASR/data + ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank + ls -lh egs/librispeech/ASR/data/* + + sudo apt-get -qq install git-lfs tree sox + export PYTHONPATH=$PWD:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH + + .github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh + + - name: Display decoding results for librispeech pruned_transducer_stateless8 + if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' + shell: bash + run: | + cd egs/librispeech/ASR/ + tree ./pruned_transducer_stateless8/exp + + cd pruned_transducer_stateless8 + echo "results for pruned_transducer_stateless8" + echo "===greedy search===" + find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===fast_beam_search===" + find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===modified beam search===" + find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + - name: Upload decoding results for librispeech pruned_transducer_stateless8 + uses: actions/upload-artifact@v2 + if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' + with: + name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless8-2022-11-14 + path: egs/librispeech/ASR/pruned_transducer_stateless8/exp/ diff --git a/egs/librispeech/ASR/README.md b/egs/librispeech/ASR/README.md index c366650bb..e737d68bd 100644 --- a/egs/librispeech/ASR/README.md +++ b/egs/librispeech/ASR/README.md @@ -23,6 +23,7 @@ The following table lists the differences among them. | `pruned_transducer_stateless5` | Conformer(modified) | Embedding + Conv1d | same as pruned_transducer_stateless4 + more layers + random combiner| | `pruned_transducer_stateless6` | Conformer(modified) | Embedding + Conv1d | same as pruned_transducer_stateless4 + distillation with hubert| | `pruned_transducer_stateless7` | Zipformer | Embedding + Conv1d | First experiment with Zipformer from Dan| +| `pruned_transducer_stateless8` | Zipformer | Embedding + Conv1d | Same as pruned_transducer_stateless7, but using extra data from GigaSpeech| | `pruned_stateless_emformer_rnnt2` | Emformer(from torchaudio) | Embedding + Conv1d | Using Emformer from torchaudio for streaming ASR| | `conv_emformer_transducer_stateless` | ConvEmformer | Embedding + Conv1d | Using ConvEmformer for streaming ASR + mechanisms in reworked model | | `conv_emformer_transducer_stateless2` | ConvEmformer | Embedding + Conv1d | Using ConvEmformer with simplified memory for streaming ASR + mechanisms in reworked model | diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index 43cd67c85..030e47b86 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -1,5 +1,63 @@ ## Results +### pruned_transducer_stateless8 (zipformer + multidataset) + +See for more details. + +[pruned_transducer_stateless8](./pruned_transducer_stateless8) + +The tensorboard log can be found at + + +You can find a pretrained model, training logs, decoding logs, and decoding +results at: + + +You can use to deploy it. + +Number of model parameters: 70369391, i.e., 70.37 M + +| | test-clean | test-other | comment | +|----------------------|------------|-------------|----------------------------------------| +| greedy search | 1.87 | 4.38 | --epoch 16 --avg 2 --max-duration 600 | +| modified beam search | 1.81 | 4.34 | --epoch 16 --avg 2 --max-duration 600 | +| fast beam search | 1.91 | 4.33 | --epoch 16 --avg 2 --max-duration 600 | + +The training commands are: +```bash +export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" + +./pruned_transducer_stateless8/train.py \ + --world-size 8 \ + --num-epochs 20 \ + --full-libri 1 \ + --use-fp16 1 \ + --max-duration 750 \ + --exp-dir pruned_transducer_stateless8/exp \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --master-port 12535 \ + --giga-prob 0.9 +``` + +The decoding commands are: +```bash +for m in greedy_search fast_beam_search modified_beam_search ; do + for epoch in 16; do + for avg in 2; do + ./pruned_transducer_stateless8/decode.py \ + --epoch $epoch \ + --avg $avg \ + --use-averaged-model 1 \ + --exp-dir ./pruned_transducer_stateless8/exp \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --max-duration 600 \ + --decoding-method $m + done + done +done +``` + + ### pruned_transducer_stateless7 (zipformer) See for more details. diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py index 81b0deba3..e2405d5ef 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py @@ -30,6 +30,7 @@ Usage of this script: ./pruned_transducer_stateless7/jit_pretrained.py \ --nn-model-filename ./pruned_transducer_stateless7/exp/cpu_jit.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ /path/to/foo.wav \ /path/to/bar.wav """ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/__init__.py b/egs/librispeech/ASR/pruned_transducer_stateless8/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/asr_datamodule.py b/egs/librispeech/ASR/pruned_transducer_stateless8/asr_datamodule.py new file mode 120000 index 000000000..3ba9ada4f --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/asr_datamodule.py @@ -0,0 +1 @@ +../pruned_transducer_stateless3/asr_datamodule.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless8/beam_search.py new file mode 120000 index 000000000..8554e44cc --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/beam_search.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py new file mode 100755 index 000000000..9d7335e77 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py @@ -0,0 +1,863 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./pruned_transducer_stateless8/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless8/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./pruned_transducer_stateless8/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless8/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./pruned_transducer_stateless8/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless8/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./pruned_transducer_stateless8/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless8/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./pruned_transducer_stateless8/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless8/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./pruned_transducer_stateless8/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless8/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./pruned_transducer_stateless8/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless8/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import AsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from librispeech import LibriSpeech +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless8/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--simulate-streaming", + type=str2bool, + default=False, + help="""Whether to simulate streaming in decoding, this is a good way to + test a streaming model. + """, + ) + + parser.add_argument( + "--decode-chunk-size", + type=int, + default=16, + help="The chunk size for decoding (in frames after subsampling)", + ) + + parser.add_argument( + "--left-context", + type=int, + default=64, + help="left context can be seen during decoding (in frames after subsampling)", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.simulate_streaming: + feature_lens += params.left_context + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.left_context), + value=LOG_EPS, + ) + encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( + x=feature, + x_lens=feature_lens, + chunk_size=params.decode_chunk_size, + left_context=params.left_context, + simulate_streaming=True, + ) + else: + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + AsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.simulate_streaming: + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}" + params.suffix += f"-left-context-{params.left_context}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + if params.simulate_streaming: + assert ( + params.causal_convolution + ), "Decoding in streaming requires causal convolution" + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params, enable_giga=False) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), + strict=False, + ) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + else: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ), + strict=False, + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ), + strict=False, + ) + + model.to(device) + model.eval() + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) + else: + decoding_graph = None + word_table = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + asr_datamodule = AsrDataModule(args) + librispeech = LibriSpeech(manifest_dir=args.manifest_dir) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = asr_datamodule.test_dataloaders(test_clean_cuts) + test_other_dl = asr_datamodule.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless8/decoder.py new file mode 120000 index 000000000..33944d0d2 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/decoder.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/decoder.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/encoder_interface.py b/egs/librispeech/ASR/pruned_transducer_stateless8/encoder_interface.py new file mode 120000 index 000000000..b9aa0ae08 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/encoder_interface.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/encoder_interface.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/export.py b/egs/librispeech/ASR/pruned_transducer_stateless8/export.py new file mode 100755 index 000000000..49f469e29 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/export.py @@ -0,0 +1,334 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" + +Usage: + +(1) Export to torchscript model using torch.jit.script() + +./pruned_transducer_stateless8/export.py \ + --exp-dir ./pruned_transducer_stateless8/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 9 \ + --jit 1 + +It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later +load it by `torch.jit.load("cpu_jit.pt")`. + +Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python +are on CPU. You can use `to("cuda")` to move them to a CUDA device. + +Check +https://github.com/k2-fsa/sherpa +for how to use the exported models outside of icefall. + +(2) Export `model.state_dict()` + +./pruned_transducer_stateless8/export.py \ + --exp-dir ./pruned_transducer_stateless8/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 + +It will generate a file `pretrained.pt` in the given `exp_dir`. You can later +load it by `icefall.checkpoint.load_checkpoint()`. + +To use the generated file with `pruned_transducer_stateless8/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + ./pruned_transducer_stateless8/decode.py \ + --exp-dir ./pruned_transducer_stateless8/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bpe_500/bpe.model + +Check ./pretrained.py for its usage. + +Note: If you don't want to train a model from scratch, we have +provided one for you. You can get it at + +https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless8-2022-11-14 + +with the following commands: + + sudo apt-get install git-lfs + git lfs install + git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless8-2022-11-14 + # You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless8-2022-11-14/exp +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +import torch +import torch.nn as nn +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless8/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + It will generate a file named cpu_jit.pt + + Check ./jit_pretrained.py for how to use it. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params, enable_giga=False) + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), + strict=False, + ) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), + strict=False, + ) + else: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ), + strict=False, + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ), + strict=False, + ) + + model.to("cpu") + model.eval() + + if params.jit is True: + convert_scaled_to_non_scaled(model, inplace=True) + logging.info("Using torch.jit.script()") + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. + model.__class__.forward = torch.jit.ignore(model.__class__.forward) + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + filename = params.exp_dir / "cpu_jit.pt" + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torchscript. Export model.state_dict()") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/gigaspeech.py b/egs/librispeech/ASR/pruned_transducer_stateless8/gigaspeech.py new file mode 120000 index 000000000..5242c652a --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/gigaspeech.py @@ -0,0 +1 @@ +../pruned_transducer_stateless3/gigaspeech.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py new file mode 100755 index 000000000..e79a3a3aa --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py @@ -0,0 +1,275 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models, exported by `torch.jit.script()` +and uses them to decode waves. +You can use the following command to get the exported models: + +./pruned_transducer_stateless8/export.py \ + --exp-dir ./pruned_transducer_stateless8/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 \ + --jit 1 + +Usage of this script: + +./pruned_transducer_stateless8/jit_pretrained.py \ + --nn-model-filename ./pruned_transducer_stateless8/exp/cpu_jit.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model-filename", + type=str, + required=True, + help="Path to the torchscript model cpu_jit.pt", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float = 16000 +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +def greedy_search( + model: torch.jit.ScriptModule, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, +) -> List[List[int]]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + model: + The transducer model. + encoder_out: + A 3-D tensor of shape (N, T, C) + encoder_out_lens: + A 1-D tensor of shape (N,). + Returns: + Return the decoded results for each utterance. + """ + assert encoder_out.ndim == 3 + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + device = encoder_out.device + blank_id = 0 # hard-code to 0 + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + context_size = model.decoder.context_size + hyps = [[blank_id] * context_size for _ in range(N)] + + decoder_input = torch.tensor( + hyps, + device=device, + dtype=torch.int64, + ) # (N, context_size) + + decoder_out = model.decoder( + decoder_input, + need_pad=torch.tensor([False]), + ).squeeze(1) + + offset = 0 + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = packed_encoder_out.data[start:end] + current_encoder_out = current_encoder_out + # current_encoder_out's shape: (batch_size, encoder_out_dim) + offset = end + + decoder_out = decoder_out[:batch_size] + + logits = model.joiner( + current_encoder_out, + decoder_out, + ) + # logits'shape (batch_size, vocab_size) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + hyps[i].append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = [h[-context_size:] for h in hyps[:batch_size]] + decoder_input = torch.tensor( + decoder_input, + device=device, + dtype=torch.int64, + ) + decoder_out = model.decoder( + decoder_input, + need_pad=torch.tensor([False]), + ) + decoder_out = decoder_out.squeeze(1) + + sorted_ans = [h[context_size:] for h in hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + model = torch.jit.load(args.nn_model_filename) + + model.eval() + + model.to(device) + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder( + x=features, + x_lens=feature_lengths, + ) + + hyps = greedy_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = sp.decode(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless8/joiner.py new file mode 120000 index 000000000..ecfb6dd8a --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/joiner.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/joiner.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/librispeech.py b/egs/librispeech/ASR/pruned_transducer_stateless8/librispeech.py new file mode 120000 index 000000000..b76723bf5 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/librispeech.py @@ -0,0 +1 @@ +../pruned_transducer_stateless3/librispeech.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/model.py b/egs/librispeech/ASR/pruned_transducer_stateless8/model.py new file mode 100644 index 000000000..497b89136 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/model.py @@ -0,0 +1,222 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import random +from typing import Optional, Tuple + +import k2 +import torch +import torch.nn as nn +from encoder_interface import EncoderInterface +from scaling import penalize_abs_values_gt + +from icefall.utils import add_sos + + +class Transducer(nn.Module): + """It implements https://arxiv.org/pdf/1211.3711.pdf + "Sequence Transduction with Recurrent Neural Networks" + """ + + def __init__( + self, + encoder: EncoderInterface, + decoder: nn.Module, + joiner: nn.Module, + encoder_dim: int, + decoder_dim: int, + joiner_dim: int, + vocab_size: int, + decoder_giga: Optional[nn.Module] = None, + joiner_giga: Optional[nn.Module] = None, + ): + """ + Args: + encoder: + It is the transcription network in the paper. Its accepts + two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). + It returns two tensors: `logits` of shape (N, T, encoder_dm) and + `logit_lens` of shape (N,). + decoder: + It is the prediction network in the paper. Its input shape + is (N, U) and its output shape is (N, U, decoder_dim). + It should contain one attribute: `blank_id`. + joiner: + It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim). + Its output shape is (N, T, U, vocab_size). Note that its output contains + unnormalized probs, i.e., not processed by log-softmax. + """ + super().__init__() + assert isinstance(encoder, EncoderInterface), type(encoder) + assert hasattr(decoder, "blank_id") + + self.encoder = encoder + self.decoder = decoder + self.joiner = joiner + + self.decoder_giga = decoder_giga + self.joiner_giga = joiner_giga + + self.simple_am_proj = nn.Linear( + encoder_dim, + vocab_size, + ) + self.simple_lm_proj = nn.Linear(decoder_dim, vocab_size) + + if decoder_giga is not None: + self.simple_am_proj_giga = nn.Linear(encoder_dim, vocab_size) + self.simple_lm_proj_giga = nn.Linear(decoder_dim, vocab_size) + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + y: k2.RaggedTensor, + libri: bool = True, + prune_range: int = 5, + am_scale: float = 0.0, + lm_scale: float = 0.0, + ) -> torch.Tensor: + """ + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + y: + A ragged tensor with 2 axes [utt][label]. It contains labels of each + utterance. + libri: + True to use the decoder and joiner for the LibriSpeech dataset. + False to use the decoder and joiner for the GigaSpeech dataset. + prune_range: + The prune range for rnnt loss, it means how many symbols(context) + we are considering for each frame to compute the loss. + am_scale: + The scale to smooth the loss with am (output of encoder network) + part + lm_scale: + The scale to smooth the loss with lm (output of predictor network) + part + Returns: + Return the transducer loss. + + Note: + Regarding am_scale & lm_scale, it will make the loss-function one of + the form: + lm_scale * lm_probs + am_scale * am_probs + + (1-lm_scale-am_scale) * combined_probs + """ + assert x.ndim == 3, x.shape + assert x_lens.ndim == 1, x_lens.shape + assert y.num_axes == 2, y.num_axes + + assert x.size(0) == x_lens.size(0) == y.dim0 + + encoder_out, x_lens = self.encoder(x, x_lens) + assert torch.all(x_lens > 0) + + if libri: + decoder = self.decoder + simple_lm_proj = self.simple_lm_proj + simple_am_proj = self.simple_am_proj + joiner = self.joiner + else: + decoder = self.decoder_giga + simple_lm_proj = self.simple_lm_proj_giga + simple_am_proj = self.simple_am_proj_giga + joiner = self.joiner_giga + + # Now for the decoder, i.e., the prediction network + row_splits = y.shape.row_splits(1) + y_lens = row_splits[1:] - row_splits[:-1] + + blank_id = decoder.blank_id + sos_y = add_sos(y, sos_id=blank_id) + + # sos_y_padded: [B, S + 1], start with SOS. + sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) + + # decoder_out: [B, S + 1, decoder_dim] + decoder_out = decoder(sos_y_padded) + + # Note: y does not start with SOS + # y_padded : [B, S] + y_padded = y.pad(mode="constant", padding_value=0) + + y_padded = y_padded.to(torch.int64) + boundary = torch.zeros( + (x.size(0), 4), dtype=torch.int64, device=x.device + ) + boundary[:, 2] = y_lens + boundary[:, 3] = x_lens + + lm = simple_lm_proj(decoder_out) + am = simple_am_proj(encoder_out) + + # if self.training and random.random() < 0.25: + # lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04) + # if self.training and random.random() < 0.25: + # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) + + with torch.cuda.amp.autocast(enabled=False): + simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( + lm=lm.float(), + am=am.float(), + symbols=y_padded, + termination_symbol=blank_id, + lm_only_scale=lm_scale, + am_only_scale=am_scale, + boundary=boundary, + reduction="sum", + return_grad=True, + ) + + # ranges : [B, T, prune_range] + ranges = k2.get_rnnt_prune_ranges( + px_grad=px_grad, + py_grad=py_grad, + boundary=boundary, + s_range=prune_range, + ) + + # am_pruned : [B, T, prune_range, encoder_dim] + # lm_pruned : [B, T, prune_range, decoder_dim] + am_pruned, lm_pruned = k2.do_rnnt_pruning( + am=joiner.encoder_proj(encoder_out), + lm=joiner.decoder_proj(decoder_out), + ranges=ranges, + ) + + # logits : [B, T, prune_range, vocab_size] + + # project_input=False since we applied the decoder's input projections + # prior to do_rnnt_pruning (this is an optimization for speed). + logits = joiner(am_pruned, lm_pruned, project_input=False) + + with torch.cuda.amp.autocast(enabled=False): + pruned_loss = k2.rnnt_loss_pruned( + logits=logits.float(), + symbols=y_padded, + ranges=ranges, + termination_symbol=blank_id, + boundary=boundary, + reduction="sum", + ) + + return (simple_loss, pruned_loss) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless8/optim.py new file mode 120000 index 000000000..81ac4a89a --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/optim.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/optim.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py new file mode 100755 index 000000000..373a48fc1 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py @@ -0,0 +1,363 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads a checkpoint and uses it to decode waves. +You can generate the checkpoint with the following command: + +./pruned_transducer_stateless8/export.py \ + --exp-dir ./pruned_transducer_stateless8/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 + +Usage of this script: + +(1) greedy search +./pruned_transducer_stateless8/pretrained.py \ + --checkpoint ./pruned_transducer_stateless8/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) beam search +./pruned_transducer_stateless8/pretrained.py \ + --checkpoint ./pruned_transducer_stateless8/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) modified beam search +./pruned_transducer_stateless8/pretrained.py \ + --checkpoint ./pruned_transducer_stateless8/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method modified_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(4) fast beam search +./pruned_transducer_stateless8/pretrained.py \ + --checkpoint ./pruned_transducer_stateless8/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method fast_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +You can also use `./pruned_transducer_stateless8/exp/epoch-xx.pt`. + +Note: ./pruned_transducer_stateless8/exp/pretrained.pt is generated by +./pruned_transducer_stateless8/export.py +""" + + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from beam_search import ( + beam_search, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "--method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. Used only when + --method is greedy_search. + """, + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + + params.update(vars(args)) + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = get_transducer_model(params, enable_giga=False) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + model.device = device + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder( + x=features, x_lens=feature_lengths + ) + + num_waves = encoder_out.size(0) + hyps = [] + msg = f"Using {params.method}" + if params.method == "beam_search": + msg += f" with beam size {params.beam_size}" + logging.info(msg) + + if params.method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + for i in range(num_waves): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError(f"Unsupported method: {params.method}") + + hyps.append(sp.decode(hyp).split()) + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless8/scaling.py new file mode 120000 index 000000000..2428b74b9 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/scaling.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/scaling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless8/scaling_converter.py new file mode 120000 index 000000000..b8b8ba432 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/scaling_converter.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/scaling_converter.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py new file mode 100755 index 000000000..b4177d3f0 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py @@ -0,0 +1,1367 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +cd egs/librispeech/ASR/ +./prepare.sh +./prepare_giga_speech.sh + +./pruned_transducer_stateless8/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless8/exp \ + --full-libri 1 \ + --max-duration 300 + +# For mix precision training: + +./pruned_transducer_stateless8/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless8/exp \ + --full-libri 1 \ + --max-duration 550 + +""" + + +import argparse +import copy +import logging +import random +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import AsrDataModule +from decoder import Decoder +from gigaspeech import GigaSpeech +from joiner import Joiner +from lhotse import CutSet, load_manifest +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from librispeech import LibriSpeech +from model import Transducer +from optim import Eden, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,4,3,2,4", + help="Number of zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--feedforward-dims", + type=str, + default="1024,1024,2048,2048,1024", + help="Feedforward dimension of the zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--nhead", + type=str, + default="8,8,8,8,8", + help="Number of attention heads in the zipformer encoder layers.", + ) + + parser.add_argument( + "--encoder-dims", + type=str, + default="384,384,384,384,384", + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", + ) + + parser.add_argument( + "--attention-dims", + type=str, + default="192,192,192,192,192", + help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; + not the same as embedding dimension.""", + ) + + parser.add_argument( + "--encoder-unmasked-dims", + type=str, + default="256,256,256,256,256", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " + " worse.", + ) + + parser.add_argument( + "--zipformer-downsampling-factors", + type=str, + default="1,2,4,8,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--cnn-module-kernels", + type=str, + default="31,31,31,31,31", + help="Sizes of kernels in convolution modules", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--full-libri", + type=str2bool, + default=True, + help="When enabled, use 960h LibriSpeech. " + "Otherwise, use 100h subset.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless8/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.05, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" + "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=2000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--giga-prob", + type=float, + default=0.5, + help="The probability to select a batch from the GigaSpeech dataset", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Zipformer and Transformer + def to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + encoder = Zipformer( + num_features=params.feature_dim, + output_downsampling_factor=2, + zipformer_downsampling_factors=to_int_tuple( + params.zipformer_downsampling_factors + ), + encoder_dims=to_int_tuple(params.encoder_dims), + attention_dim=to_int_tuple(params.attention_dims), + encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), + nhead=to_int_tuple(params.nhead), + feedforward_dim=to_int_tuple(params.feedforward_dims), + cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), + num_encoder_layers=to_int_tuple(params.num_encoder_layers), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model( + params: AttributeDict, + enable_giga: bool = True, +) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + if enable_giga: + logging.info("Use giga") + decoder_giga = get_decoder_model(params) + joiner_giga = get_joiner_model(params) + else: + logging.info("Disable giga") + decoder_giga = None + joiner_giga = None + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + decoder_giga=decoder_giga, + joiner_giga=joiner_giga, + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def is_libri(c: Cut) -> bool: + """Return True if this cut is from the LibriSpeech dataset. + + Note: + During data preparation, we set the custom field in + the supervision segment of GigaSpeech to dict(origin='giga') + See ../local/preprocess_gigaspeech.py. + """ + return c.supervisions[0].custom is None + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + libri = is_libri(supervisions["cut"][0]) + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + libri=libri, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + giga_train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + rng: random.Random, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + giga_train_dl: + Dataloader for the GigaSpeech training dataset. + valid_dl: + Dataloader for the validation dataset. + rng: + For selecting which dataset to use. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + libri_tot_loss = MetricsTracker() + giga_tot_loss = MetricsTracker() + tot_loss = MetricsTracker() + + # index 0: for LibriSpeech + # index 1: for GigaSpeech + # This sets the probabilities for choosing which datasets + dl_weights = [1 - params.giga_prob, params.giga_prob] + + iter_libri = iter(train_dl) + iter_giga = iter(giga_train_dl) + + batch_idx = 0 + + while True: + idx = rng.choices((0, 1), weights=dl_weights, k=1)[0] + dl = iter_libri if idx == 0 else iter_giga + + try: + batch = next(dl) + except StopIteration: + name = "libri" if idx == 0 else "giga" + logging.info(f"{name} reaches end of dataloader") + break + + batch_idx += 1 + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + libri = is_libri(batch["supervisions"]["cut"][0]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + if libri: + libri_tot_loss = ( + libri_tot_loss * (1 - 1 / params.reset_interval) + ) + loss_info + prefix = "libri" # for logging only + else: + giga_tot_loss = ( + giga_tot_loss * (1 - 1 / params.reset_interval) + ) + loss_info + prefix = "giga" + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or ( + cur_grad_scale < 8.0 and batch_idx % 400 == 0 + ): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, {prefix}_loss[{loss_info}], " + f"tot_loss[{tot_loss}], " + f"libri_tot_loss[{libri_tot_loss}], " + f"giga_tot_loss[{giga_tot_loss}], " + f"batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + ( + f"grad_scale: {scaler._scale.item()}" + if params.use_fp16 + else "" + ) + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, + f"train/current_{prefix}_", + params.batch_idx_train, + ) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) + libri_tot_loss.write_summary( + tb_writer, "train/libri_tot_", params.batch_idx_train + ) + giga_tot_loss.write_summary( + tb_writer, "train/giga_tot_", params.batch_idx_train + ) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if ( + batch_idx % params.valid_interval == 0 + and not params.print_diagnostics + ): + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def filter_short_and_long_utterances(cuts: CutSet) -> CutSet: + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + return 1.0 <= c.duration <= 20.0 + + cuts = cuts.filter(remove_short_and_long_utt) + + return cuts + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + if params.full_libri is False: + params.valid_interval = 1600 + + fix_random_seed(params.seed) + rng = random.Random(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params, enable_giga=True) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = ScaledAdam( + model.parameters(), lr=params.base_lr, clipping_scale=2.0 + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2 ** 22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + librispeech = LibriSpeech(manifest_dir=args.manifest_dir) + + train_cuts = librispeech.train_clean_100_cuts() + if params.full_libri: + train_cuts += librispeech.train_clean_360_cuts() + train_cuts += librispeech.train_other_500_cuts() + + train_cuts = filter_short_and_long_utterances(train_cuts) + + gigaspeech = GigaSpeech(manifest_dir=args.manifest_dir) + # XL 10k hours + # L 2.5k hours + # M 1k hours + # S 250 hours + # XS 10 hours + # DEV 12 hours + # Test 40 hours + if params.full_libri: + logging.info("Using the XL subset of GigaSpeech (10k hours)") + train_giga_cuts = gigaspeech.train_XL_cuts() + else: + logging.info("Using the S subset of GigaSpeech (250 hours)") + train_giga_cuts = gigaspeech.train_S_cuts() + + train_giga_cuts = filter_short_and_long_utterances(train_giga_cuts) + train_giga_cuts = train_giga_cuts.repeat(times=None) + + if args.enable_musan: + cuts_musan = load_manifest( + Path(args.manifest_dir) / "musan_cuts.jsonl.gz" + ) + else: + cuts_musan = None + + asr_datamodule = AsrDataModule(args) + + train_dl = asr_datamodule.train_dataloaders( + train_cuts, + on_the_fly_feats=False, + cuts_musan=cuts_musan, + ) + + giga_train_dl = asr_datamodule.train_dataloaders( + train_giga_cuts, + on_the_fly_feats=False, + cuts_musan=cuts_musan, + ) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = asr_datamodule.valid_dataloaders(valid_cuts) + + if False and not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + giga_train_dl=giga_train_dl, + valid_dl=valid_dl, + rng=rng, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + AsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + assert 0 <= args.giga_prob < 1, args.giga_prob + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless8/zipformer.py new file mode 120000 index 000000000..79b076556 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/zipformer.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/zipformer.py \ No newline at end of file From c8ce243255d7f18dc4485c3367ef470234670e92 Mon Sep 17 00:00:00 2001 From: Desh Raj Date: Tue, 15 Nov 2022 22:29:45 -0500 Subject: [PATCH 004/174] Zipformer output length (#686) * add assertion for output length * add comment in filter_cuts * add length filter to Zipformer recipes --- egs/librispeech/ASR/local/filter_cuts.py | 3 + .../ASR/pruned_transducer_stateless7/train.py | 119 ++++++++++++------ .../pruned_transducer_stateless7/zipformer.py | 1 + .../ASR/pruned_transducer_stateless8/train.py | 44 +++++-- 4 files changed, 116 insertions(+), 51 deletions(-) diff --git a/egs/librispeech/ASR/local/filter_cuts.py b/egs/librispeech/ASR/local/filter_cuts.py index 53dbb8211..dff98a954 100644 --- a/egs/librispeech/ASR/local/filter_cuts.py +++ b/egs/librispeech/ASR/local/filter_cuts.py @@ -101,6 +101,9 @@ def filter_cuts(cut_set: CutSet, sp: spm.SentencePieceProcessor): # Note: for ./lstm_transducer_stateless/lstm.py, the formula is # T = ((num_frames - 3) // 2 - 1) // 2 + # Note: for ./pruned_transducer_stateless7/zipformer.py, the formula is + # T = ((num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) if T < len(tokens): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index 8927be227..3f27736b3 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -59,7 +59,6 @@ import torch import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule -from zipformer import Zipformer from decoder import Decoder from joiner import Joiner from lhotse.cut import Cut @@ -71,6 +70,7 @@ from torch import Tensor from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer from icefall import diagnostics from icefall.checkpoint import load_checkpoint, remove_checkpoints @@ -79,9 +79,9 @@ from icefall.checkpoint import ( save_checkpoint_with_global_batch_idx, update_averaged_model, ) -from icefall.hooks import register_inf_check_hooks from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool LRSchedulerType = Union[ @@ -89,14 +89,12 @@ LRSchedulerType = Union[ ] -def set_batch_count( - model: Union[nn.Module, DDP], batch_count: float -) -> None: +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: if isinstance(model, DDP): # get underlying nn.Module model = model.module for module in model.modules(): - if hasattr(module, 'batch_count'): + if hasattr(module, "batch_count"): module.batch_count = batch_count @@ -126,7 +124,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): "--encoder-dims", type=str, default="384,384,384,384,384", - help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated" + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", ) parser.add_argument( @@ -134,7 +132,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): type=str, default="192,192,192,192,192", help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; - not the same as embedding dimension.""" + not the same as embedding dimension.""", ) parser.add_argument( @@ -143,7 +141,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): default="256,256,256,256,256", help="Unmasked dimensions in the encoders, relates to augmentation during training. " "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " - " worse." + " worse.", ) parser.add_argument( @@ -248,10 +246,7 @@ def get_parser(): ) parser.add_argument( - "--base-lr", - type=float, - default=0.05, - help="The base learning rate." + "--base-lr", type=float, default=0.05, help="The base learning rate." ) parser.add_argument( @@ -451,11 +446,14 @@ def get_params() -> AttributeDict: def get_encoder_model(params: AttributeDict) -> nn.Module: # TODO: We can add an option to switch between Zipformer and Transformer def to_int_tuple(s: str): - return tuple(map(int, s.split(','))) + return tuple(map(int, s.split(","))) + encoder = Zipformer( num_features=params.feature_dim, output_downsampling_factor=2, - zipformer_downsampling_factors=to_int_tuple(params.zipformer_downsampling_factors), + zipformer_downsampling_factors=to_int_tuple( + params.zipformer_downsampling_factors + ), encoder_dims=to_int_tuple(params.encoder_dims), attention_dim=to_int_tuple(params.attention_dims), encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), @@ -479,7 +477,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module: def get_joiner_model(params: AttributeDict) -> nn.Module: joiner = Joiner( - encoder_dim=int(params.encoder_dims.split(',')[-1]), + encoder_dim=int(params.encoder_dims.split(",")[-1]), decoder_dim=params.decoder_dim, joiner_dim=params.joiner_dim, vocab_size=params.vocab_size, @@ -496,7 +494,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module: encoder=encoder, decoder=decoder, joiner=joiner, - encoder_dim=int(params.encoder_dims.split(',')[-1]), + encoder_dim=int(params.encoder_dims.split(",")[-1]), decoder_dim=params.decoder_dim, joiner_dim=params.joiner_dim, vocab_size=params.vocab_size, @@ -682,18 +680,17 @@ def compute_loss( # take down the scale on the simple loss from 1.0 at the start # to params.simple_loss scale by warm_step. simple_loss_scale = ( - s if batch_idx_train >= warm_step + s + if batch_idx_train >= warm_step else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) ) pruned_loss_scale = ( - 1.0 if batch_idx_train >= warm_step + 1.0 + if batch_idx_train >= warm_step else 0.1 + 0.9 * (batch_idx_train / warm_step) ) - loss = ( - simple_loss_scale * simple_loss + - pruned_loss_scale * pruned_loss - ) + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -873,12 +870,16 @@ def train_one_epoch( # of the grad scaler is configurable, but we can't configure it to have different # behavior depending on the current grad scale. cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + if cur_grad_scale < 1.0 or ( + cur_grad_scale < 8.0 and batch_idx % 400 == 0 + ): scaler.update(cur_grad_scale * 2.0) if cur_grad_scale < 0.01: logging.warning(f"Grad scale is small: {cur_grad_scale}") if cur_grad_scale < 1.0e-05: - raise RuntimeError(f"grad_scale is too small, exiting: {cur_grad_scale}") + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) if batch_idx % params.log_interval == 0: cur_lr = scheduler.get_last_lr()[0] @@ -888,8 +889,12 @@ def train_one_epoch( f"Epoch {params.cur_epoch}, " f"batch {batch_idx}, loss[{loss_info}], " f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " + - (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + f"lr: {cur_lr:.2e}, " + + ( + f"grad_scale: {scaler._scale.item()}" + if params.use_fp16 + else "" + ) ) if tb_writer is not None: @@ -905,12 +910,15 @@ def train_one_epoch( ) if params.use_fp16: tb_writer.add_scalar( - "train/grad_scale", cur_grad_scale, params.batch_idx_train + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, ) - - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + if ( + batch_idx % params.valid_interval == 0 + and not params.print_diagnostics + ): logging.info("Computing validation loss") valid_info = compute_validation_loss( params=params, @@ -921,7 +929,9 @@ def train_one_epoch( ) model.train() logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info(f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) if tb_writer is not None: valid_info.write_summary( tb_writer, "train/valid_", params.batch_idx_train @@ -997,12 +1007,11 @@ def run(rank, world_size, args): model.to(device) if world_size > 1: logging.info("Using DDP") - model = DDP(model, device_ids=[rank], - find_unused_parameters=True) + model = DDP(model, device_ids=[rank], find_unused_parameters=True) - optimizer = ScaledAdam(model.parameters(), - lr=params.base_lr, - clipping_scale=2.0) + optimizer = ScaledAdam( + model.parameters(), lr=params.base_lr, clipping_scale=2.0 + ) scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) @@ -1043,7 +1052,34 @@ def run(rank, world_size, args): # You should use ../local/display_manifest_statistics.py to get # an utterance duration distribution for your dataset to select # the threshold - return 1.0 <= c.duration <= 20.0 + if c.duration < 1.0 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True train_cuts = train_cuts.filter(remove_short_and_long_utt) @@ -1071,8 +1107,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, - init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1193,7 +1228,9 @@ def scan_pessimistic_batches_for_oom( ) display_and_save_batch(batch, params=params, sp=sp) raise - logging.info(f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) def main(): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index c14066d38..023dec97d 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -1828,6 +1828,7 @@ def _test_zipformer_main(): torch.randn(batch_size, seq_len, feature_dim), torch.full((batch_size,), seq_len, dtype=torch.int64), ) + assert ((seq_len - 7) // 2 + 1) // 2 == f[0].shape[1], (seq_len, f.shape[1]) f[0].sum().backward() c.eval() f = c( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py index b4177d3f0..2603bb854 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py @@ -90,12 +90,7 @@ from icefall.checkpoint import ( from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.hooks import register_inf_check_hooks -from icefall.utils import ( - AttributeDict, - MetricsTracker, - setup_logger, - str2bool, -) +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool LRSchedulerType = Union[ torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler @@ -1045,7 +1040,9 @@ def train_one_epoch( params.best_train_loss = params.train_loss -def filter_short_and_long_utterances(cuts: CutSet) -> CutSet: +def filter_short_and_long_utterances( + cuts: CutSet, sp: spm.SentencePieceProcessor +) -> CutSet: def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds # @@ -1055,7 +1052,34 @@ def filter_short_and_long_utterances(cuts: CutSet) -> CutSet: # You should use ../local/display_manifest_statistics.py to get # an utterance duration distribution for your dataset to select # the threshold - return 1.0 <= c.duration <= 20.0 + if c.duration < 1.0 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True cuts = cuts.filter(remove_short_and_long_utt) @@ -1162,7 +1186,7 @@ def run(rank, world_size, args): train_cuts += librispeech.train_clean_360_cuts() train_cuts += librispeech.train_other_500_cuts() - train_cuts = filter_short_and_long_utterances(train_cuts) + train_cuts = filter_short_and_long_utterances(train_cuts, sp) gigaspeech = GigaSpeech(manifest_dir=args.manifest_dir) # XL 10k hours @@ -1179,7 +1203,7 @@ def run(rank, world_size, args): logging.info("Using the S subset of GigaSpeech (250 hours)") train_giga_cuts = gigaspeech.train_S_cuts() - train_giga_cuts = filter_short_and_long_utterances(train_giga_cuts) + train_giga_cuts = filter_short_and_long_utterances(train_giga_cuts, sp) train_giga_cuts = train_giga_cuts.repeat(times=None) if args.enable_musan: From aa7bae1ecd520e06804721c3b13a8c3c2eb06bcc Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 16 Nov 2022 19:58:28 +0800 Subject: [PATCH 005/174] fix decode.py for conformer_ctc in gigaspeech (#688) --- egs/gigaspeech/ASR/conformer_ctc/decode.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/gigaspeech/ASR/conformer_ctc/decode.py b/egs/gigaspeech/ASR/conformer_ctc/decode.py index 51406667e..9c1418baa 100755 --- a/egs/gigaspeech/ASR/conformer_ctc/decode.py +++ b/egs/gigaspeech/ASR/conformer_ctc/decode.py @@ -481,9 +481,9 @@ def decode_dataset( ), "It should not decode to empty in the first batch!" this_batch = [] hyp_words = [] - for ref_text in texts: + for cut_id, ref_text in zip(cut_ids, texts): ref_words = ref_text.split() - this_batch.append((ref_words, hyp_words)) + this_batch.append((cut_id, ref_words, hyp_words)) for lm_scale in results.keys(): results[lm_scale].extend(this_batch) From d110b04ad389134c82fa314e3aafc7b40043efb0 Mon Sep 17 00:00:00 2001 From: Desh Raj Date: Wed, 16 Nov 2022 13:06:43 -0500 Subject: [PATCH 006/174] apply new black formatting to all files --- .github/workflows/style_check.yml | 11 +- .pre-commit-config.yaml | 26 +- docker/README.md | 24 +- .../Dockerfile | 14 +- .../Dockerfile | 17 +- .../images/k2-gt-v1.9-blueviolet.svg | 2 +- .../images/python-gt-v3.6-blue.svg | 2 +- .../images/torch-gt-v1.6.0-green.svg | 2 +- docs/source/recipes/aishell/index.rst | 1 - docs/source/recipes/timit/index.rst | 1 - docs/source/recipes/timit/tdnn_ligru_ctc.rst | 28 +- docs/source/recipes/timit/tdnn_lstm_ctc.rst | 24 +- .../local/compute_fbank_aidatatang_200zh.py | 8 +- .../ASR/local/prepare_char.py | 8 +- .../ASR/local/prepare_lang.py | 4 +- .../ASR/local/test_prepare_lang.py | 4 +- egs/aidatatang_200zh/ASR/local/text2token.py | 21 +- egs/aidatatang_200zh/ASR/prepare.sh | 3 +- .../asr_datamodule.py | 110 +- .../pruned_transducer_stateless2/decode.py | 50 +- .../pruned_transducer_stateless2/export.py | 20 +- .../pretrained.py | 41 +- .../ASR/pruned_transducer_stateless2/train.py | 50 +- egs/aishell/ASR/conformer_ctc/conformer.py | 70 +- egs/aishell/ASR/conformer_ctc/decode.py | 29 +- egs/aishell/ASR/conformer_ctc/export.py | 17 +- egs/aishell/ASR/conformer_ctc/pretrained.py | 39 +- egs/aishell/ASR/conformer_ctc/subsampling.py | 16 +- .../ASR/conformer_ctc/test_subsampling.py | 3 +- egs/aishell/ASR/conformer_ctc/train.py | 12 +- egs/aishell/ASR/conformer_ctc/transformer.py | 44 +- egs/aishell/ASR/conformer_mmi/conformer.py | 70 +- egs/aishell/ASR/conformer_mmi/decode.py | 33 +- egs/aishell/ASR/conformer_mmi/subsampling.py | 16 +- egs/aishell/ASR/conformer_mmi/train.py | 8 +- egs/aishell/ASR/conformer_mmi/transformer.py | 44 +- .../local/compute_fbank_aidatatang_200zh.py | 8 +- .../ASR/local/compute_fbank_aishell.py | 8 +- egs/aishell/ASR/local/prepare_char.py | 8 +- egs/aishell/ASR/local/prepare_lang.py | 4 +- egs/aishell/ASR/local/test_prepare_lang.py | 4 +- .../pruned_transducer_stateless2/decode.py | 50 +- .../pruned_transducer_stateless2/export.py | 31 +- .../pretrained.py | 50 +- .../ASR/pruned_transducer_stateless2/train.py | 64 +- .../pruned_transducer_stateless3/decode.py | 73 +- .../pruned_transducer_stateless3/export.py | 54 +- .../ASR/pruned_transducer_stateless3/model.py | 8 +- .../pretrained.py | 50 +- .../ASR/pruned_transducer_stateless3/train.py | 79 +- .../ASR/tdnn_lstm_ctc/asr_datamodule.py | 118 +- egs/aishell/ASR/tdnn_lstm_ctc/decode.py | 33 +- egs/aishell/ASR/tdnn_lstm_ctc/model.py | 5 +- egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py | 37 +- egs/aishell/ASR/tdnn_lstm_ctc/train.py | 7 +- .../ASR/transducer_stateless/beam_search.py | 22 +- .../ASR/transducer_stateless/conformer.py | 70 +- .../ASR/transducer_stateless/decode.py | 39 +- .../ASR/transducer_stateless/decoder.py | 4 +- .../ASR/transducer_stateless/export.py | 20 +- egs/aishell/ASR/transducer_stateless/model.py | 4 +- .../ASR/transducer_stateless/pretrained.py | 36 +- egs/aishell/ASR/transducer_stateless/train.py | 15 +- .../ASR/transducer_stateless/transformer.py | 4 +- .../asr_datamodule.py | 85 +- .../transducer_stateless_modified-2/decode.py | 46 +- .../transducer_stateless_modified-2/export.py | 20 +- .../pretrained.py | 50 +- .../transducer_stateless_modified-2/train.py | 22 +- .../transducer_stateless_modified/decode.py | 46 +- .../transducer_stateless_modified/export.py | 20 +- .../pretrained.py | 50 +- .../transducer_stateless_modified/train.py | 15 +- egs/aishell2/ASR/local/__init__.py | 0 .../ASR/local/compute_fbank_aishell2.py | 8 +- .../pruned_transducer_stateless5/__init__.py | 0 .../asr_datamodule.py | 114 +- .../pruned_transducer_stateless5/decode.py | 67 +- .../pruned_transducer_stateless5/export.py | 47 +- .../pretrained.py | 40 +- .../ASR/pruned_transducer_stateless5/train.py | 67 +- .../ASR/local/compute_fbank_aishell4.py | 8 +- egs/aishell4/ASR/local/prepare_char.py | 8 +- egs/aishell4/ASR/local/prepare_lang.py | 4 +- egs/aishell4/ASR/local/test_prepare_lang.py | 4 +- egs/aishell4/ASR/local/text2token.py | 21 +- .../asr_datamodule.py | 110 +- .../pruned_transducer_stateless5/decode.py | 69 +- .../pruned_transducer_stateless5/export.py | 47 +- .../pretrained.py | 45 +- .../ASR/pruned_transducer_stateless5/train.py | 59 +- .../ASR/local/compute_fbank_alimeeting.py | 8 +- egs/alimeeting/ASR/local/prepare_char.py | 8 +- egs/alimeeting/ASR/local/prepare_lang.py | 4 +- egs/alimeeting/ASR/local/test_prepare_lang.py | 4 +- egs/alimeeting/ASR/local/text2segments.py | 2 +- egs/alimeeting/ASR/local/text2token.py | 21 +- .../asr_datamodule.py | 110 +- .../pruned_transducer_stateless2/decode.py | 60 +- .../pruned_transducer_stateless2/export.py | 20 +- .../pretrained.py | 41 +- .../ASR/pruned_transducer_stateless2/train.py | 50 +- egs/csj/ASR/.gitignore | 2 +- egs/csj/ASR/local/compute_fbank_csj.py | 38 +- egs/csj/ASR/local/compute_fbank_musan.py | 17 +- egs/csj/ASR/local/conf/disfluent.ini | 55 +- egs/csj/ASR/local/conf/fluent.ini | 55 +- egs/csj/ASR/local/conf/number.ini | 55 +- egs/csj/ASR/local/conf/symbol.ini | 55 +- .../ASR/local/display_manifest_statistics.py | 4 +- egs/csj/ASR/local/prepare_lang_char.py | 17 +- egs/csj/ASR/local/validate_manifest.py | 7 +- .../ASR/conformer_ctc/asr_datamodule.py | 117 +- egs/gigaspeech/ASR/conformer_ctc/conformer.py | 66 +- egs/gigaspeech/ASR/conformer_ctc/decode.py | 29 +- .../ASR/conformer_ctc/gigaspeech_scoring.py | 3 +- .../ASR/conformer_ctc/label_smoothing.py | 7 +- .../ASR/conformer_ctc/subsampling.py | 16 +- egs/gigaspeech/ASR/conformer_ctc/train.py | 12 +- .../ASR/conformer_ctc/transformer.py | 49 +- .../compute_fbank_gigaspeech_dev_test.py | 4 +- .../local/compute_fbank_gigaspeech_splits.py | 10 +- .../ASR/local/preprocess_gigaspeech.py | 10 +- .../asr_datamodule.py | 117 +- .../pruned_transducer_stateless2/decode.py | 42 +- .../pruned_transducer_stateless2/export.py | 24 +- .../ASR/pruned_transducer_stateless2/train.py | 48 +- egs/librispeech/ASR/conformer_ctc/ali.py | 25 +- .../ASR/conformer_ctc/conformer.py | 66 +- egs/librispeech/ASR/conformer_ctc/decode.py | 29 +- egs/librispeech/ASR/conformer_ctc/export.py | 17 +- .../ASR/conformer_ctc/label_smoothing.py | 7 +- .../ASR/conformer_ctc/pretrained.py | 33 +- .../ASR/conformer_ctc/subsampling.py | 16 +- egs/librispeech/ASR/conformer_ctc/train.py | 22 +- .../ASR/conformer_ctc/transformer.py | 49 +- .../ASR/conformer_ctc2/attention.py | 19 +- .../ASR/conformer_ctc2/conformer.py | 65 +- egs/librispeech/ASR/conformer_ctc2/decode.py | 56 +- egs/librispeech/ASR/conformer_ctc2/export.py | 49 +- egs/librispeech/ASR/conformer_ctc2/train.py | 39 +- .../ASR/conformer_ctc2/transformer.py | 50 +- .../ASR/conformer_mmi/conformer.py | 70 +- egs/librispeech/ASR/conformer_mmi/decode.py | 29 +- .../ASR/conformer_mmi/subsampling.py | 16 +- .../ASR/conformer_mmi/test_subsampling.py | 3 +- .../ASR/conformer_mmi/test_transformer.py | 9 +- .../ASR/conformer_mmi/train-with-attention.py | 27 +- egs/librispeech/ASR/conformer_mmi/train.py | 27 +- .../ASR/conformer_mmi/transformer.py | 28 +- .../decode.py | 69 +- .../emformer.py | 119 +- .../export.py | 47 +- .../stream.py | 8 +- .../streaming_decode.py | 75 +- .../train.py | 56 +- .../decode.py | 69 +- .../emformer.py | 108 +- .../export.py | 47 +- .../streaming_decode.py | 75 +- .../train.py | 56 +- .../ASR/local/add_alignment_librispeech.py | 12 +- egs/librispeech/ASR/local/compile_hlg.py | 4 +- egs/librispeech/ASR/local/compile_lg.py | 4 +- .../compute_fbank_gigaspeech_dev_test.py | 4 +- .../local/compute_fbank_gigaspeech_splits.py | 10 +- .../ASR/local/compute_fbank_librispeech.py | 8 +- .../ASR/local/compute_fbank_musan.py | 8 +- .../convert_transcript_words_to_tokens.py | 16 +- egs/librispeech/ASR/local/download_lm.py | 4 +- egs/librispeech/ASR/local/filter_cuts.py | 10 +- .../ASR/local/generate_unique_lexicon.py | 4 +- egs/librispeech/ASR/local/prepare_lang_bpe.py | 4 +- .../ASR/local/prepare_lm_training_data.py | 11 +- .../ASR/local/preprocess_gigaspeech.py | 4 +- .../ASR/local/test_prepare_lang.py | 4 +- .../ASR/local/validate_manifest.py | 7 +- .../ASR/lstm_transducer_stateless/decode.py | 818 ------------ .../ASR/lstm_transducer_stateless/export.py | 388 ------ .../jit_pretrained.py | 322 ----- .../ASR/lstm_transducer_stateless/lstm.py | 871 ------------- .../ASR/lstm_transducer_stateless/model.py | 210 --- .../lstm_transducer_stateless/pretrained.py | 352 ----- .../ASR/lstm_transducer_stateless/stream.py | 148 --- .../streaming_decode.py | 968 -------------- .../ASR/lstm_transducer_stateless/train.py | 1157 ----------------- .../ASR/lstm_transducer_stateless2/decode.py | 67 +- .../ASR/lstm_transducer_stateless2/export.py | 59 +- .../jit_pretrained.py | 21 +- .../ASR/lstm_transducer_stateless2/model.py | 8 +- .../lstm_transducer_stateless2/ncnn-decode.py | 15 +- .../lstm_transducer_stateless2/pretrained.py | 40 +- .../streaming-ncnn-decode.py | 27 +- .../streaming-onnx-decode.py | 45 +- .../ASR/lstm_transducer_stateless2/train.py | 68 +- .../ASR/lstm_transducer_stateless3/decode.py | 79 +- .../ASR/lstm_transducer_stateless3/export.py | 47 +- .../jit_pretrained.py | 21 +- .../ASR/lstm_transducer_stateless3/lstm.py | 14 +- .../lstm_transducer_stateless3/pretrained.py | 40 +- .../streaming_decode.py | 74 +- .../ASR/lstm_transducer_stateless3/train.py | 66 +- .../ASR/pruned2_knowledge/asr_datamodule.py | 125 +- .../ASR/pruned2_knowledge/beam_search.py | 18 +- .../ASR/pruned2_knowledge/conformer.py | 90 +- .../ASR/pruned2_knowledge/decode.py | 44 +- .../ASR/pruned2_knowledge/decoder.py | 4 +- .../ASR/pruned2_knowledge/decoder2.py | 84 +- .../ASR/pruned2_knowledge/export.py | 20 +- .../ASR/pruned2_knowledge/joiner.py | 4 +- .../ASR/pruned2_knowledge/model.py | 8 +- .../ASR/pruned2_knowledge/optim.py | 35 +- .../ASR/pruned2_knowledge/sampling.py | 184 +-- .../ASR/pruned2_knowledge/scaling.py | 51 +- .../ASR/pruned2_knowledge/scaling_tmp.py | 355 +++-- .../ASR/pruned2_knowledge/train.py | 50 +- .../pruned_stateless_emformer_rnnt2/decode.py | 69 +- .../emformer.py | 8 +- .../pruned_stateless_emformer_rnnt2/export.py | 47 +- .../pruned_stateless_emformer_rnnt2/model.py | 4 +- .../pruned_stateless_emformer_rnnt2/train.py | 44 +- .../beam_search.py | 26 +- .../ASR/pruned_transducer_stateless/decode.py | 44 +- .../decode_stream.py | 19 +- .../pruned_transducer_stateless/decoder.py | 4 +- .../ASR/pruned_transducer_stateless/export.py | 20 +- .../ASR/pruned_transducer_stateless/model.py | 4 +- .../pruned_transducer_stateless/pretrained.py | 36 +- .../streaming_beam_search.py | 8 +- .../streaming_decode.py | 39 +- .../ASR/pruned_transducer_stateless/train.py | 46 +- .../beam_search.py | 51 +- .../pruned_transducer_stateless2/conformer.py | 97 +- .../pruned_transducer_stateless2/decode.py | 50 +- .../pruned_transducer_stateless2/decoder.py | 8 +- .../pruned_transducer_stateless2/export.py | 24 +- .../pruned_transducer_stateless2/joiner.py | 4 +- .../ASR/pruned_transducer_stateless2/model.py | 8 +- .../ASR/pruned_transducer_stateless2/optim.py | 35 +- .../pretrained.py | 36 +- .../pruned_transducer_stateless2/scaling.py | 56 +- .../streaming_beam_search.py | 12 +- .../streaming_decode.py | 39 +- .../ASR/pruned_transducer_stateless2/train.py | 58 +- .../asr_datamodule.py | 85 +- .../decode-giga.py | 54 +- .../pruned_transducer_stateless3/decode.py | 74 +- .../pruned_transducer_stateless3/export.py | 32 +- .../gigaspeech.py | 8 +- .../jit_pretrained.py | 21 +- .../ASR/pruned_transducer_stateless3/model.py | 8 +- .../onnx_check.py | 24 +- .../onnx_pretrained.py | 27 +- .../pretrained.py | 36 +- .../scaling_converter.py | 10 +- .../streaming_decode.py | 39 +- .../pruned_transducer_stateless3/test_onnx.py | 24 +- .../ASR/pruned_transducer_stateless3/train.py | 65 +- .../pruned_transducer_stateless4/decode.py | 79 +- .../pruned_transducer_stateless4/export.py | 47 +- .../streaming_decode.py | 62 +- .../ASR/pruned_transducer_stateless4/train.py | 61 +- .../pruned_transducer_stateless5/conformer.py | 118 +- .../pruned_transducer_stateless5/decode.py | 67 +- .../pruned_transducer_stateless5/export.py | 47 +- .../pretrained.py | 40 +- .../streaming_decode.py | 62 +- .../ASR/pruned_transducer_stateless5/train.py | 66 +- .../pruned_transducer_stateless6/conformer.py | 67 +- .../pruned_transducer_stateless6/decode.py | 69 +- .../pruned_transducer_stateless6/export.py | 24 +- .../extract_codebook_index.py | 3 +- .../hubert_decode.py | 17 +- .../hubert_xlarge.py | 22 +- .../ASR/pruned_transducer_stateless6/model.py | 12 +- .../ASR/pruned_transducer_stateless6/train.py | 65 +- .../pruned_transducer_stateless6/vq_utils.py | 31 +- .../pruned_transducer_stateless7/decode.py | 67 +- .../pruned_transducer_stateless7/decoder.py | 6 +- .../pruned_transducer_stateless7/export.py | 47 +- .../jit_pretrained.py | 21 +- .../pruned_transducer_stateless7/joiner.py | 4 +- .../ASR/pruned_transducer_stateless7/model.py | 16 +- .../ASR/pruned_transducer_stateless7/optim.py | 439 ++++--- .../pretrained.py | 40 +- .../pruned_transducer_stateless7/scaling.py | 487 +++---- .../scaling_converter.py | 12 +- .../ASR/pruned_transducer_stateless7/train.py | 88 +- .../pruned_transducer_stateless7/zipformer.py | 660 +++++----- .../pruned_transducer_stateless8/decode.py | 67 +- .../pruned_transducer_stateless8/export.py | 47 +- .../jit_pretrained.py | 21 +- .../ASR/pruned_transducer_stateless8/model.py | 4 +- .../pretrained.py | 40 +- .../ASR/pruned_transducer_stateless8/train.py | 99 +- .../ASR/streaming_conformer_ctc/README.md | 16 +- .../ASR/streaming_conformer_ctc/conformer.py | 116 +- .../streaming_decode.py | 68 +- .../ASR/streaming_conformer_ctc/train.py | 16 +- .../streaming_conformer_ctc/transformer.py | 40 +- .../ASR/tdnn_lstm_ctc/asr_datamodule.py | 113 +- egs/librispeech/ASR/tdnn_lstm_ctc/decode.py | 29 +- egs/librispeech/ASR/tdnn_lstm_ctc/model.py | 5 +- .../ASR/tdnn_lstm_ctc/pretrained.py | 43 +- egs/librispeech/ASR/tdnn_lstm_ctc/train.py | 8 +- egs/librispeech/ASR/transducer/beam_search.py | 14 +- egs/librispeech/ASR/transducer/decode.py | 28 +- egs/librispeech/ASR/transducer/export.py | 17 +- egs/librispeech/ASR/transducer/pretrained.py | 33 +- egs/librispeech/ASR/transducer/rnn.py | 24 +- egs/librispeech/ASR/transducer/test_rnn.py | 16 +- egs/librispeech/ASR/transducer/train.py | 12 +- .../ASR/transducer_lstm/beam_search.py | 14 +- egs/librispeech/ASR/transducer_lstm/decode.py | 28 +- .../ASR/transducer_lstm/encoder.py | 4 +- egs/librispeech/ASR/transducer_lstm/train.py | 12 +- .../ASR/transducer_stateless/alignment.py | 4 +- .../ASR/transducer_stateless/beam_search.py | 28 +- .../ASR/transducer_stateless/compute_ali.py | 24 +- .../ASR/transducer_stateless/conformer.py | 107 +- .../ASR/transducer_stateless/decode.py | 42 +- .../ASR/transducer_stateless/decoder.py | 4 +- .../ASR/transducer_stateless/export.py | 20 +- .../ASR/transducer_stateless/joiner.py | 8 +- .../ASR/transducer_stateless/pretrained.py | 36 +- .../transducer_stateless/test_compute_ali.py | 11 +- .../transducer_stateless/test_conformer.py | 4 +- .../ASR/transducer_stateless/train.py | 23 +- .../ASR/transducer_stateless/transformer.py | 4 +- .../ASR/transducer_stateless2/decode.py | 42 +- .../ASR/transducer_stateless2/export.py | 20 +- .../ASR/transducer_stateless2/pretrained.py | 36 +- .../ASR/transducer_stateless2/train.py | 23 +- .../decode.py | 42 +- .../export.py | 20 +- .../pretrained.py | 36 +- .../test_asr_datamodule.py | 4 +- .../train.py | 22 +- egs/ptb/LM/local/sort_lm_training_data.py | 4 +- .../LM/local/test_prepare_lm_training_data.py | 4 +- .../ASR/local/compute_fbank_musan.py | 8 +- .../ASR/local/compute_fbank_spgispeech.py | 14 +- egs/spgispeech/ASR/local/prepare_splits.py | 8 +- .../asr_datamodule.py | 100 +- .../pruned_transducer_stateless2/decode.py | 66 +- .../pruned_transducer_stateless2/export.py | 26 +- .../ASR/pruned_transducer_stateless2/train.py | 51 +- .../ASR/local/compute_fbank_tal_csasr.py | 8 +- egs/tal_csasr/ASR/local/prepare_char.py | 4 +- egs/tal_csasr/ASR/local/prepare_lang.py | 4 +- egs/tal_csasr/ASR/local/test_prepare_lang.py | 4 +- egs/tal_csasr/ASR/local/text2token.py | 21 +- .../asr_datamodule.py | 110 +- .../pruned_transducer_stateless5/decode.py | 77 +- .../pruned_transducer_stateless5/export.py | 47 +- .../pretrained.py | 40 +- .../ASR/pruned_transducer_stateless5/train.py | 59 +- .../ASR/local/compute_fbank_tedlium.py | 8 +- .../convert_transcript_words_to_bpe_ids.py | 4 +- egs/tedlium3/ASR/local/prepare_lexicon.py | 11 +- egs/tedlium3/ASR/local/prepare_transcripts.py | 11 +- .../ASR/pruned_transducer_stateless/decode.py | 38 +- .../ASR/pruned_transducer_stateless/export.py | 20 +- .../pruned_transducer_stateless/pretrained.py | 41 +- .../ASR/pruned_transducer_stateless/train.py | 35 +- .../transducer_stateless/asr_datamodule.py | 118 +- .../ASR/transducer_stateless/beam_search.py | 30 +- .../ASR/transducer_stateless/decode.py | 31 +- .../ASR/transducer_stateless/decoder.py | 4 +- .../ASR/transducer_stateless/export.py | 20 +- .../ASR/transducer_stateless/pretrained.py | 36 +- .../ASR/transducer_stateless/train.py | 11 +- egs/timit/ASR/RESULTS.md | 2 +- egs/timit/ASR/local/compile_hlg.py | 4 +- egs/timit/ASR/local/compute_fbank_timit.py | 8 +- egs/timit/ASR/local/prepare_lexicon.py | 8 +- egs/timit/ASR/prepare.sh | 4 +- egs/timit/ASR/tdnn_ligru_ctc/decode.py | 29 +- egs/timit/ASR/tdnn_ligru_ctc/model.py | 12 +- egs/timit/ASR/tdnn_ligru_ctc/pretrained.py | 43 +- egs/timit/ASR/tdnn_ligru_ctc/train.py | 4 +- egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py | 104 +- egs/timit/ASR/tdnn_lstm_ctc/decode.py | 29 +- egs/timit/ASR/tdnn_lstm_ctc/model.py | 5 +- egs/timit/ASR/tdnn_lstm_ctc/pretrained.py | 43 +- egs/timit/ASR/tdnn_lstm_ctc/train.py | 4 +- .../compute_fbank_wenetspeech_dev_test.py | 11 +- .../local/compute_fbank_wenetspeech_splits.py | 10 +- egs/wenetspeech/ASR/local/prepare_char.py | 8 +- .../ASR/local/preprocess_wenetspeech.py | 6 +- egs/wenetspeech/ASR/local/text2token.py | 21 +- egs/wenetspeech/ASR/prepare.sh | 2 +- .../asr_datamodule.py | 121 +- .../pruned_transducer_stateless2/decode.py | 64 +- .../pruned_transducer_stateless2/export.py | 28 +- .../jit_pretrained.py | 21 +- .../onnx_check.py | 24 +- .../onnx_pretrained.py | 27 +- .../pretrained.py | 41 +- .../ASR/pruned_transducer_stateless2/train.py | 50 +- .../pruned_transducer_stateless5/conformer.py | 97 +- .../pruned_transducer_stateless5/decode.py | 75 +- .../decode_stream.py | 19 +- .../pruned_transducer_stateless5/export.py | 20 +- .../pretrained.py | 41 +- .../streaming_beam_search.py | 8 +- .../streaming_decode.py | 62 +- .../ASR/pruned_transducer_stateless5/train.py | 67 +- egs/yesno/ASR/local/compile_hlg.py | 4 +- egs/yesno/ASR/local/compute_fbank_yesno.py | 12 +- egs/yesno/ASR/tdnn/asr_datamodule.py | 74 +- egs/yesno/ASR/tdnn/decode.py | 29 +- egs/yesno/ASR/tdnn/pretrained.py | 37 +- egs/yesno/ASR/tdnn/train.py | 4 +- egs/yesno/ASR/transducer/decode.py | 25 +- egs/yesno/ASR/transducer/train.py | 4 +- icefall/char_graph_compiler.py | 8 +- icefall/checkpoint.py | 12 +- icefall/decode.py | 36 +- icefall/diagnostics.py | 80 +- icefall/dist.py | 4 +- icefall/env.py | 4 +- icefall/graph_compiler.py | 4 +- icefall/hooks.py | 19 +- icefall/lexicon.py | 16 +- icefall/mmi.py | 29 +- icefall/mmi_graph_compiler.py | 8 +- icefall/rnn_lm/compute_perplexity.py | 15 +- icefall/rnn_lm/dataset.py | 8 +- icefall/rnn_lm/export.py | 17 +- icefall/rnn_lm/model.py | 28 +- icefall/rnn_lm/train.py | 11 +- icefall/shared/make_kn_lm.py | 184 ++- icefall/utils.py | 64 +- pyproject.toml | 2 +- setup.py | 3 +- test/test_checkpoint.py | 6 +- test/test_decode.py | 1 + test/test_graph_compiler.py | 4 +- test/test_utils.py | 4 +- 440 files changed, 6789 insertions(+), 14532 deletions(-) mode change 100755 => 100644 egs/aishell2/ASR/local/__init__.py mode change 100755 => 100644 egs/aishell2/ASR/pruned_transducer_stateless5/__init__.py mode change 100755 => 100644 egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py mode change 100755 => 100644 egs/librispeech/ASR/lstm_transducer_stateless/decode.py mode change 100755 => 100644 egs/librispeech/ASR/lstm_transducer_stateless/export.py mode change 100755 => 100644 egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py mode change 100755 => 100644 egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py mode change 100755 => 100644 egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py mode change 100755 => 100644 egs/librispeech/ASR/lstm_transducer_stateless/train.py diff --git a/.github/workflows/style_check.yml b/.github/workflows/style_check.yml index 90459bc1c..45d261ccc 100644 --- a/.github/workflows/style_check.yml +++ b/.github/workflows/style_check.yml @@ -45,17 +45,18 @@ jobs: - name: Install Python dependencies run: | - python3 -m pip install --upgrade pip black==21.6b0 flake8==3.9.2 click==8.0.4 - # See https://github.com/psf/black/issues/2964 - # The version of click should be selected from 8.0.0, 8.0.1, 8.0.2, 8.0.3, and 8.0.4 + python3 -m pip install --upgrade pip black==22.3.0 flake8==5.0.4 click==8.1.0 + # Click issue fixed in https://github.com/psf/black/pull/2966 - name: Run flake8 shell: bash working-directory: ${{github.workspace}} run: | # stop the build if there are Python syntax errors or undefined names - flake8 . --count --show-source --statistics - flake8 . + flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics + # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide + flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 \ + --statistics --extend-ignore=E203,E266,E501,F401,E402,F403,F841,W503 - name: Run black shell: bash diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 446ba0fe7..e2055801b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,26 +1,38 @@ repos: - repo: https://github.com/psf/black - rev: 21.6b0 + rev: 22.3.0 hooks: - id: black - args: [--line-length=80] + args: ["--line-length=88"] additional_dependencies: ['click==8.0.1'] exclude: icefall\/__init__\.py - repo: https://github.com/PyCQA/flake8 - rev: 3.9.2 + rev: 5.0.4 hooks: - id: flake8 - args: [--max-line-length=80] + args: ["--max-line-length=88", "--extend-ignore=E203,E266,E501,F401,E402,F403,F841,W503"] + + # What are we ignoring here? + # E203: whitespace before ':' + # E266: too many leading '#' for block comment + # E501: line too long + # F401: module imported but unused + # E402: module level import not at top of file + # F403: 'from module import *' used; unable to detect undefined names + # F841: local variable is assigned to but never used + # W503: line break before binary operator + # In addition, the default ignore list is: + # E121,E123,E126,E226,E24,E704,W503,W504 - repo: https://github.com/pycqa/isort - rev: 5.9.2 + rev: 5.10.1 hooks: - id: isort - args: [--profile=black, --line-length=80] + args: ["--profile=black"] - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.0.1 + rev: v4.2.0 hooks: - id: check-executables-have-shebangs - id: end-of-file-fixer diff --git a/docker/README.md b/docker/README.md index 6f2314e96..c14b9bf75 100644 --- a/docker/README.md +++ b/docker/README.md @@ -2,7 +2,7 @@ 2 sets of configuration are provided - (a) Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8, and (b) Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8. -If your NVIDIA driver supports CUDA Version: 11.3, please go for case (a) Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8. +If your NVIDIA driver supports CUDA Version: 11.3, please go for case (a) Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8. Otherwise, since the older PyTorch images are not updated with the [apt-key rotation by NVIDIA](https://developer.nvidia.com/blog/updating-the-cuda-linux-gpg-repository-key), you have to go for case (b) Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8. Ensure that your NVDIA driver supports at least CUDA 11.0. @@ -10,7 +10,7 @@ You can check the highest CUDA version within your NVIDIA driver's support with ```bash $ nvidia-smi -Tue Sep 20 00:26:13 2022 +Tue Sep 20 00:26:13 2022 +-----------------------------------------------------------------------------+ | NVIDIA-SMI 450.119.03 Driver Version: 450.119.03 CUDA Version: 11.0 | |-------------------------------+----------------------+----------------------+ @@ -26,7 +26,7 @@ Tue Sep 20 00:26:13 2022 | 41% 30C P8 11W / 280W | 6MiB / 24220MiB | 0% Default | | | | N/A | +-------------------------------+----------------------+----------------------+ - + +-----------------------------------------------------------------------------+ | Processes: | | GPU GI CI PID Type Process name GPU Memory | @@ -40,15 +40,15 @@ Tue Sep 20 00:26:13 2022 ``` ## Building images locally -If your environment requires a proxy to access the Internet, remember to add those information into the Dockerfile directly. -For most cases, you can uncomment these lines in the Dockerfile and add in your proxy details. +If your environment requires a proxy to access the Internet, remember to add those information into the Dockerfile directly. +For most cases, you can uncomment these lines in the Dockerfile and add in your proxy details. ```dockerfile ENV http_proxy=http://aaa.bb.cc.net:8080 \ https_proxy=http://aaa.bb.cc.net:8080 ``` -Then, proceed with these commands. +Then, proceed with these commands. ### If you are case (a), i.e. your NVIDIA driver supports CUDA version >= 11.3: @@ -72,11 +72,11 @@ docker run -it --runtime=nvidia --shm-size=2gb --name=icefall --gpus all icefall ``` ### Tips: -1. Since your data and models most probably won't be in the docker, you must use the -v flag to access the host machine. Do this by specifying `-v {/path/in/host/machine}:{/path/in/docker}`. +1. Since your data and models most probably won't be in the docker, you must use the -v flag to access the host machine. Do this by specifying `-v {/path/in/host/machine}:{/path/in/docker}`. 2. Also, if your environment requires a proxy, this would be a good time to add it in too: `-e http_proxy=http://aaa.bb.cc.net:8080 -e https_proxy=http://aaa.bb.cc.net:8080`. -Overall, your docker run command should look like this. +Overall, your docker run command should look like this. ```bash docker run -it --runtime=nvidia --shm-size=2gb --name=icefall --gpus all -v {/path/in/host/machine}:{/path/in/docker} -e http_proxy=http://aaa.bb.cc.net:8080 -e https_proxy=http://aaa.bb.cc.net:8080 icefall/pytorch1.12.1 @@ -86,9 +86,9 @@ You can explore more docker run options [here](https://docs.docker.com/engine/re ### Linking to icefall in your host machine -If you already have icefall downloaded onto your host machine, you can use that repository instead so that changes in your code are visible inside and outside of the container. +If you already have icefall downloaded onto your host machine, you can use that repository instead so that changes in your code are visible inside and outside of the container. -Note: Remember to set the -v flag above during the first run of the container, as that is the only way for your container to access your host machine. +Note: Remember to set the -v flag above during the first run of the container, as that is the only way for your container to access your host machine. Warning: Check that the icefall in your host machine is visible from within your container before proceeding to the commands below. Use these commands once you are inside the container. @@ -103,7 +103,7 @@ ln -s {/path/in/docker/to/icefall} /workspace/icefall docker exec -it icefall /bin/bash ``` -## Restarting a killed container that has been run before. +## Restarting a killed container that has been run before. ```bash docker start -ai icefall ``` @@ -111,4 +111,4 @@ docker start -ai icefall ## Sample usage of the CPU based images: ```bash docker run -it icefall /bin/bash -``` +``` diff --git a/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile b/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile index 3637d2f11..ff9e40604 100644 --- a/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile +++ b/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile @@ -1,7 +1,7 @@ FROM pytorch/pytorch:1.12.1-cuda11.3-cudnn8-devel # ENV http_proxy=http://aaa.bbb.cc.net:8080 \ -# https_proxy=http://aaa.bbb.cc.net:8080 +# https_proxy=http://aaa.bbb.cc.net:8080 # install normal source RUN apt-get update && \ @@ -38,10 +38,10 @@ RUN wget -P /opt https://cmake.org/files/v3.18/cmake-3.18.0.tar.gz && \ rm -rf cmake-3.18.0.tar.gz && \ find /opt/cmake-3.18.0 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \ cd - - -# flac + +# flac RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz && \ - cd /opt && \ + cd /opt && \ xz -d flac-1.3.2.tar.xz && \ tar -xvf flac-1.3.2.tar && \ cd flac-1.3.2 && \ @@ -49,11 +49,11 @@ RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz && make && make install && \ rm -rf flac-1.3.2.tar && \ find /opt/flac-1.3.2 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \ - cd - + cd - RUN conda install -y -c pytorch torchaudio=0.12 && \ pip install graphviz - + #install k2 from source RUN git clone https://github.com/k2-fsa/k2.git /opt/k2 && \ @@ -68,7 +68,7 @@ RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \ cd /workspace/icefall && \ pip install -r requirements.txt -RUN pip install kaldifeat +RUN pip install kaldifeat ENV PYTHONPATH /workspace/icefall:$PYTHONPATH WORKDIR /workspace/icefall diff --git a/docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile b/docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile index 17a8215f9..5c7423fa5 100644 --- a/docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile +++ b/docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile @@ -1,12 +1,12 @@ FROM pytorch/pytorch:1.7.1-cuda11.0-cudnn8-devel # ENV http_proxy=http://aaa.bbb.cc.net:8080 \ -# https_proxy=http://aaa.bbb.cc.net:8080 +# https_proxy=http://aaa.bbb.cc.net:8080 RUN rm /etc/apt/sources.list.d/cuda.list && \ rm /etc/apt/sources.list.d/nvidia-ml.list && \ apt-key del 7fa2af80 - + # install normal source RUN apt-get update && \ apt-get install -y --no-install-recommends \ @@ -36,7 +36,7 @@ RUN curl -fsSL https://developer.download.nvidia.com/compute/cuda/repos/ubuntu18 curl -fsSL https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub | apt-key add - && \ echo "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64 /" > /etc/apt/sources.list.d/cuda.list && \ echo "deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64 /" > /etc/apt/sources.list.d/nvidia-ml.list && \ - rm -rf /var/lib/apt/lists/* && \ + rm -rf /var/lib/apt/lists/* && \ mv /opt/conda/lib/libcufft.so.10 /opt/libcufft.so.10.bak && \ mv /opt/conda/lib/libcurand.so.10 /opt/libcurand.so.10.bak && \ mv /opt/conda/lib/libcublas.so.11 /opt/libcublas.so.11.bak && \ @@ -56,10 +56,10 @@ RUN wget -P /opt https://cmake.org/files/v3.18/cmake-3.18.0.tar.gz && \ rm -rf cmake-3.18.0.tar.gz && \ find /opt/cmake-3.18.0 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \ cd - - -# flac + +# flac RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz && \ - cd /opt && \ + cd /opt && \ xz -d flac-1.3.2.tar.xz && \ tar -xvf flac-1.3.2.tar && \ cd flac-1.3.2 && \ @@ -67,7 +67,7 @@ RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz && make && make install && \ rm -rf flac-1.3.2.tar && \ find /opt/flac-1.3.2 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \ - cd - + cd - RUN conda install -y -c pytorch torchaudio=0.7.1 && \ pip install graphviz @@ -79,7 +79,7 @@ RUN git clone https://github.com/k2-fsa/k2.git /opt/k2 && \ cd - # install lhotse -RUN pip install git+https://github.com/lhotse-speech/lhotse +RUN pip install git+https://github.com/lhotse-speech/lhotse RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \ cd /workspace/icefall && \ @@ -88,4 +88,3 @@ RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \ ENV PYTHONPATH /workspace/icefall:$PYTHONPATH WORKDIR /workspace/icefall - diff --git a/docs/source/installation/images/k2-gt-v1.9-blueviolet.svg b/docs/source/installation/images/k2-gt-v1.9-blueviolet.svg index 534b2e534..3019ff03d 100644 --- a/docs/source/installation/images/k2-gt-v1.9-blueviolet.svg +++ b/docs/source/installation/images/k2-gt-v1.9-blueviolet.svg @@ -1 +1 @@ -k2: >= v1.9k2>= v1.9 \ No newline at end of file +k2: >= v1.9k2>= v1.9 diff --git a/docs/source/installation/images/python-gt-v3.6-blue.svg b/docs/source/installation/images/python-gt-v3.6-blue.svg index 4254dc58a..df677ad09 100644 --- a/docs/source/installation/images/python-gt-v3.6-blue.svg +++ b/docs/source/installation/images/python-gt-v3.6-blue.svg @@ -1 +1 @@ -python: >= 3.6python>= 3.6 \ No newline at end of file +python: >= 3.6python>= 3.6 diff --git a/docs/source/installation/images/torch-gt-v1.6.0-green.svg b/docs/source/installation/images/torch-gt-v1.6.0-green.svg index d3ece9a17..d7007d742 100644 --- a/docs/source/installation/images/torch-gt-v1.6.0-green.svg +++ b/docs/source/installation/images/torch-gt-v1.6.0-green.svg @@ -1 +1 @@ -torch: >= 1.6.0torch>= 1.6.0 \ No newline at end of file +torch: >= 1.6.0torch>= 1.6.0 diff --git a/docs/source/recipes/aishell/index.rst b/docs/source/recipes/aishell/index.rst index d072d6e9c..b77d59bca 100644 --- a/docs/source/recipes/aishell/index.rst +++ b/docs/source/recipes/aishell/index.rst @@ -19,4 +19,3 @@ It can be downloaded from ``_ tdnn_lstm_ctc conformer_ctc stateless_transducer - diff --git a/docs/source/recipes/timit/index.rst b/docs/source/recipes/timit/index.rst index 17f40cdb7..5ee147be7 100644 --- a/docs/source/recipes/timit/index.rst +++ b/docs/source/recipes/timit/index.rst @@ -6,4 +6,3 @@ TIMIT tdnn_ligru_ctc tdnn_lstm_ctc - diff --git a/docs/source/recipes/timit/tdnn_ligru_ctc.rst b/docs/source/recipes/timit/tdnn_ligru_ctc.rst index 186420ee7..3d7aefe02 100644 --- a/docs/source/recipes/timit/tdnn_ligru_ctc.rst +++ b/docs/source/recipes/timit/tdnn_ligru_ctc.rst @@ -148,10 +148,10 @@ Some commonly used options are: $ ./tdnn_ligru_ctc/decode.py --epoch 25 --avg 17 - uses the average of ``epoch-9.pt``, ``epoch-10.pt``, ``epoch-11.pt``, - ``epoch-12.pt``, ``epoch-13.pt``, ``epoch-14.pt``, ``epoch-15.pt``, - ``epoch-16.pt``, ``epoch-17.pt``, ``epoch-18.pt``, ``epoch-19.pt``, - ``epoch-20.pt``, ``epoch-21.pt``, ``epoch-22.pt``, ``epoch-23.pt``, + uses the average of ``epoch-9.pt``, ``epoch-10.pt``, ``epoch-11.pt``, + ``epoch-12.pt``, ``epoch-13.pt``, ``epoch-14.pt``, ``epoch-15.pt``, + ``epoch-16.pt``, ``epoch-17.pt``, ``epoch-18.pt``, ``epoch-19.pt``, + ``epoch-20.pt``, ``epoch-21.pt``, ``epoch-22.pt``, ``epoch-23.pt``, ``epoch-24.pt`` and ``epoch-25.pt`` for decoding. @@ -317,13 +317,13 @@ To decode with ``1best`` method, we can use: .. code-block:: bash - ./tdnn_ligru_ctc/pretrained.py + ./tdnn_ligru_ctc/pretrained.py --method 1best - --checkpoint ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/exp/pretrained_average_9_25.pt - --words-file ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/words.txt - --HLG ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/HLG.pt - ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV - ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV + --checkpoint ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/exp/pretrained_average_9_25.pt + --words-file ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/words.txt + --HLG ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/HLG.pt + ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV + ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FMGD0_SI1564.WAV The output is: @@ -337,7 +337,7 @@ The output is: 2021-11-08 20:41:38,697 INFO [pretrained.py:210] Reading sound files: ['./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV', './tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV', './tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FMGD0_SI1564.WAV'] 2021-11-08 20:41:38,704 INFO [pretrained.py:216] Decoding started 2021-11-08 20:41:39,819 INFO [pretrained.py:246] Use HLG decoding - 2021-11-08 20:41:39,829 INFO [pretrained.py:267] + 2021-11-08 20:41:39,829 INFO [pretrained.py:267] ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV: sil dh ih sh uw ah l iy v iy z ih sil p r aa sil k s ih m ey dx ih sil d w uh dx ih w ih s f iy l ih ng w ih th ih n ih m s eh l f sil jh @@ -362,8 +362,8 @@ To decode with ``whole-lattice-rescoring`` methond, you can use --HLG ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/HLG.pt \ --G ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lm/G_4_gram.pt \ --ngram-lm-scale 0.1 \ - ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV - ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV + ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV + ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FMGD0_SI1564.WAV The decoding output is: @@ -378,7 +378,7 @@ The decoding output is: 2021-11-08 20:37:54,715 INFO [pretrained.py:210] Reading sound files: ['./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV', './tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV', './tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FMGD0_SI1564.WAV'] 2021-11-08 20:37:54,720 INFO [pretrained.py:216] Decoding started 2021-11-08 20:37:55,808 INFO [pretrained.py:251] Use HLG decoding + LM rescoring - 2021-11-08 20:37:56,348 INFO [pretrained.py:267] + 2021-11-08 20:37:56,348 INFO [pretrained.py:267] ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV: sil dh ih sh uw ah l iy v iy z ah sil p r aa sil k s ih m ey dx ih sil d w uh dx iy w ih s f iy l iy ng w ih th ih n ih m s eh l f sil jh diff --git a/docs/source/recipes/timit/tdnn_lstm_ctc.rst b/docs/source/recipes/timit/tdnn_lstm_ctc.rst index 6f760a9ce..ee67a6edc 100644 --- a/docs/source/recipes/timit/tdnn_lstm_ctc.rst +++ b/docs/source/recipes/timit/tdnn_lstm_ctc.rst @@ -148,8 +148,8 @@ Some commonly used options are: $ ./tdnn_lstm_ctc/decode.py --epoch 25 --avg 10 - uses the average of ``epoch-16.pt``, ``epoch-17.pt``, ``epoch-18.pt``, - ``epoch-19.pt``, ``epoch-20.pt``, ``epoch-21.pt``, ``epoch-22.pt``, + uses the average of ``epoch-16.pt``, ``epoch-17.pt``, ``epoch-18.pt``, + ``epoch-19.pt``, ``epoch-20.pt``, ``epoch-21.pt``, ``epoch-22.pt``, ``epoch-23.pt``, ``epoch-24.pt`` and ``epoch-25.pt`` for decoding. @@ -315,13 +315,13 @@ To decode with ``1best`` method, we can use: .. code-block:: bash - ./tdnn_lstm_ctc/pretrained.py + ./tdnn_lstm_ctc/pretrained.py --method 1best - --checkpoint ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/exp/pretrained_average_16_25.pt - --words-file ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/words.txt - --HLG ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/HLG.pt - ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV - ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV + --checkpoint ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/exp/pretrained_average_16_25.pt + --words-file ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/words.txt + --HLG ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/HLG.pt + ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV + ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV The output is: @@ -335,7 +335,7 @@ The output is: 2021-11-08 21:02:53,827 INFO [pretrained.py:210] Reading sound files: ['./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV', './tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV', './tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV'] 2021-11-08 21:02:53,831 INFO [pretrained.py:216] Decoding started 2021-11-08 21:02:54,380 INFO [pretrained.py:246] Use HLG decoding - 2021-11-08 21:02:54,387 INFO [pretrained.py:267] + 2021-11-08 21:02:54,387 INFO [pretrained.py:267] ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV: sil dh ih sh uw ah l iy v iy z ih sil p r aa sil k s ih m ey dx ih sil d w uh dx iy w ih s f iy l iy w ih th ih n ih m s eh l f sil jh @@ -360,8 +360,8 @@ To decode with ``whole-lattice-rescoring`` methond, you can use --HLG ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/HLG.pt \ --G ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lm/G_4_gram.pt \ --ngram-lm-scale 0.08 \ - ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV - ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV + ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV + ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV The decoding output is: @@ -376,7 +376,7 @@ The decoding output is: 2021-11-08 20:05:26,978 INFO [pretrained.py:210] Reading sound files: ['./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV', './tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV', './tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV'] 2021-11-08 20:05:26,981 INFO [pretrained.py:216] Decoding started 2021-11-08 20:05:27,519 INFO [pretrained.py:251] Use HLG decoding + LM rescoring - 2021-11-08 20:05:27,878 INFO [pretrained.py:267] + 2021-11-08 20:05:27,878 INFO [pretrained.py:267] ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV: sil dh ih sh uw l iy v iy z ih sil p r aa sil k s ah m ey dx ih sil w uh dx iy w ih s f iy l ih ng w ih th ih n ih m s eh l f sil jh diff --git a/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py b/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py index fb2751c0f..387c14acf 100755 --- a/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py +++ b/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py @@ -87,9 +87,7 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80): ) if "train" in partition: cut_set = ( - cut_set - + cut_set.perturb_speed(0.9) - + cut_set.perturb_speed(1.1) + cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) ) cut_set = cut_set.compute_and_store_features( extractor=extractor, @@ -116,9 +114,7 @@ def get_args(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/aidatatang_200zh/ASR/local/prepare_char.py b/egs/aidatatang_200zh/ASR/local/prepare_char.py index d9e47d17a..6b440dfb3 100755 --- a/egs/aidatatang_200zh/ASR/local/prepare_char.py +++ b/egs/aidatatang_200zh/ASR/local/prepare_char.py @@ -86,9 +86,7 @@ def lexicon_to_fst_no_sil( cur_state = loop_state word = word2id[word] - pieces = [ - token2id[i] if i in token2id else token2id[""] for i in pieces - ] + pieces = [token2id[i] if i in token2id else token2id[""] for i in pieces] for i in range(len(pieces) - 1): w = word if i == 0 else eps @@ -142,9 +140,7 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool: return False -def generate_lexicon( - token_sym_table: Dict[str, int], words: List[str] -) -> Lexicon: +def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon: """Generate a lexicon from a word list and token_sym_table. Args: diff --git a/egs/aidatatang_200zh/ASR/local/prepare_lang.py b/egs/aidatatang_200zh/ASR/local/prepare_lang.py index e5ae89ec4..c8cf9b881 100755 --- a/egs/aidatatang_200zh/ASR/local/prepare_lang.py +++ b/egs/aidatatang_200zh/ASR/local/prepare_lang.py @@ -317,9 +317,7 @@ def lexicon_to_fst( def get_args(): parser = argparse.ArgumentParser() - parser.add_argument( - "--lang-dir", type=str, help="The lang dir, data/lang_phone" - ) + parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone") return parser.parse_args() diff --git a/egs/aidatatang_200zh/ASR/local/test_prepare_lang.py b/egs/aidatatang_200zh/ASR/local/test_prepare_lang.py index d4cf62bba..74e025ad7 100755 --- a/egs/aidatatang_200zh/ASR/local/test_prepare_lang.py +++ b/egs/aidatatang_200zh/ASR/local/test_prepare_lang.py @@ -88,9 +88,7 @@ def test_read_lexicon(filename: str): fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa.draw("L.pdf", title="L") - fsa_disambig = lexicon_to_fst( - lexicon_disambig, phone2id=phone2id, word2id=word2id - ) + fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id) fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt") fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa_disambig.draw("L_disambig.pdf", title="L_disambig") diff --git a/egs/aidatatang_200zh/ASR/local/text2token.py b/egs/aidatatang_200zh/ASR/local/text2token.py index 71be2a613..2be639b7a 100755 --- a/egs/aidatatang_200zh/ASR/local/text2token.py +++ b/egs/aidatatang_200zh/ASR/local/text2token.py @@ -50,15 +50,15 @@ def get_parser(): "-n", default=1, type=int, - help="number of characters to split, i.e., \ - aabb -> a a b b with -n 1 and aa bb with -n 2", + help=( + "number of characters to split, i.e., aabb -> a a b" + " b with -n 1 and aa bb with -n 2" + ), ) parser.add_argument( "--skip-ncols", "-s", default=0, type=int, help="skip first n columns" ) - parser.add_argument( - "--space", default="", type=str, help="space symbol" - ) + parser.add_argument("--space", default="", type=str, help="space symbol") parser.add_argument( "--non-lang-syms", "-l", @@ -66,9 +66,7 @@ def get_parser(): type=str, help="list of non-linguistic symobles, e.g., etc.", ) - parser.add_argument( - "text", type=str, default=False, nargs="?", help="input text" - ) + parser.add_argument("text", type=str, default=False, nargs="?", help="input text") parser.add_argument( "--trans_type", "-t", @@ -108,8 +106,7 @@ def token2id( if token_type == "lazy_pinyin": text = lazy_pinyin(chars_list) sub_ids = [ - token_table[txt] if txt in token_table else oov_id - for txt in text + token_table[txt] if txt in token_table else oov_id for txt in text ] ids.append(sub_ids) else: # token_type = "pinyin" @@ -135,9 +132,7 @@ def main(): if args.text: f = codecs.open(args.text, encoding="utf-8") else: - f = codecs.getreader("utf-8")( - sys.stdin if is_python2 else sys.stdin.buffer - ) + f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer) sys.stdout = codecs.getwriter("utf-8")( sys.stdout if is_python2 else sys.stdout.buffer diff --git a/egs/aidatatang_200zh/ASR/prepare.sh b/egs/aidatatang_200zh/ASR/prepare.sh index 039951354..4749e1b7f 100755 --- a/egs/aidatatang_200zh/ASR/prepare.sh +++ b/egs/aidatatang_200zh/ASR/prepare.sh @@ -106,11 +106,10 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then if [ ! -f $lang_char_dir/words.txt ]; then ./local/prepare_words.py \ --input-file $lang_char_dir/words_no_ids.txt \ - --output-file $lang_char_dir/words.txt + --output-file $lang_char_dir/words.txt fi if [ ! -f $lang_char_dir/L_disambig.pt ]; then ./local/prepare_char.py fi fi - diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py index 6a5b57e24..8c94f5bea 100644 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -81,10 +81,12 @@ class Aidatatang_200zhAsrDataModule: def add_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group( title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", + description=( + "These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc." + ), ) group.add_argument( "--manifest-dir", @@ -96,75 +98,91 @@ class Aidatatang_200zhAsrDataModule: "--max-duration", type=int, default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", + help=( + "Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM." + ), ) group.add_argument( "--bucketing-sampler", type=str2bool, default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", + help=( + "When enabled, the batches will come from buckets of " + "similar duration (saves padding frames)." + ), ) group.add_argument( "--num-buckets", type=int, default=300, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", + help=( + "The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets)." + ), ) group.add_argument( "--concatenate-cuts", type=str2bool, default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", + help=( + "When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding." + ), ) group.add_argument( "--duration-factor", type=float, default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", + help=( + "Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch." + ), ) group.add_argument( "--gap", type=float, default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", + help=( + "The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used." + ), ) group.add_argument( "--on-the-fly-feats", type=str2bool, default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", + help=( + "When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available." + ), ) group.add_argument( "--shuffle", type=str2bool, default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", + help=( + "When enabled (=default), the examples will be shuffled for each epoch." + ), ) group.add_argument( "--return-cuts", type=str2bool, default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", + help=( + "When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it." + ), ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that " - "collect the batches.", + help="The number of training dataloader workers that collect the batches.", ) group.add_argument( @@ -178,18 +196,22 @@ class Aidatatang_200zhAsrDataModule: "--spec-aug-time-warp-factor", type=int, default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", + help=( + "Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp." + ), ) group.add_argument( "--enable-musan", type=str2bool, default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. ", + help=( + "When enabled, select noise from MUSAN and mix it" + "with training dataset. " + ), ) def train_dataloaders( @@ -205,24 +227,20 @@ class Aidatatang_200zhAsrDataModule: The state dict for the training sampler. """ logging.info("About to get Musan cuts") - cuts_musan = load_manifest( - self.args.manifest_dir / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms = [] if self.args.enable_musan: logging.info("Enable MUSAN") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") if self.args.concatenate_cuts: logging.info( - f"Using cut concatenation with duration factor " + "Using cut concatenation with duration factor " f"{self.args.duration_factor} and gap {self.args.gap}." ) # Cut concatenation should be the first transform in the list, @@ -237,9 +255,7 @@ class Aidatatang_200zhAsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -282,9 +298,7 @@ class Aidatatang_200zhAsrDataModule: # Drop feats to be on the safe side. train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -340,9 +354,7 @@ class Aidatatang_200zhAsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py index f0407f429..3f582ef04 100755 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py @@ -69,11 +69,7 @@ from beam_search import ( ) from train import get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, @@ -92,25 +88,30 @@ def get_parser(): "--epoch", type=int, default=28, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--batch", type=int, default=None, - help="It specifies the batch checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the batch checkpoint to use for decoding." + "Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -192,8 +193,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -249,9 +249,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -266,10 +264,7 @@ def decode_one_batch( ) for i in range(encoder_out.size(0)): hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -315,11 +310,7 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps + f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -390,9 +381,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -425,8 +414,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py index 00b54c39f..34f4d3ddf 100644 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py @@ -62,17 +62,20 @@ def get_parser(): "--epoch", type=int, default=28, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -103,8 +106,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) return parser @@ -173,9 +175,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py index eb5e6b0d4..3c96ed07b 100644 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py @@ -85,9 +85,11 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( @@ -112,10 +114,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -162,8 +166,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -193,10 +196,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -257,9 +259,7 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -284,10 +284,7 @@ def main(): ) for i in range(encoder_out.size(0)): hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -339,9 +336,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py index d46838b68..c7b1a4266 100644 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py @@ -81,9 +81,7 @@ from icefall.env import get_env_info from icefall.lexicon import Lexicon from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] os.environ["CUDA_LAUNCH_BLOCKING"] = "1" @@ -187,42 +185,45 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." + ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -542,22 +543,15 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -711,9 +705,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -813,7 +805,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/aishell/ASR/conformer_ctc/conformer.py b/egs/aishell/ASR/conformer_ctc/conformer.py index cb7205e51..f5b5873b4 100644 --- a/egs/aishell/ASR/conformer_ctc/conformer.py +++ b/egs/aishell/ASR/conformer_ctc/conformer.py @@ -157,9 +157,7 @@ class ConformerEncoderLayer(nn.Module): normalize_before: bool = True, ) -> None: super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), @@ -177,18 +175,14 @@ class ConformerEncoderLayer(nn.Module): self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) - self.norm_ff_macaron = nn.LayerNorm( - d_model - ) # for the macaron style FNN module + self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.ff_scale = 0.5 self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm( - d_model - ) # for the final output of the block + self.norm_final = nn.LayerNorm(d_model) # for the final output of the block self.dropout = nn.Dropout(dropout) @@ -222,9 +216,7 @@ class ConformerEncoderLayer(nn.Module): residual = src if self.normalize_before: src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout( - self.feed_forward_macaron(src) - ) + src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) if not self.normalize_before: src = self.norm_ff_macaron(src) @@ -343,9 +335,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -361,9 +351,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -633,9 +621,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -703,33 +691,25 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor" + " instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -766,9 +746,7 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -780,9 +758,7 @@ class RelPositionMultiheadAttention(nn.Module): matrix_ac + matrix_bd ) * scaling # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -816,13 +792,9 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -845,9 +817,7 @@ class ConvolutionModule(nn.Module): """ - def __init__( - self, channels: int, kernel_size: int, bias: bool = True - ) -> None: + def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding diff --git a/egs/aishell/ASR/conformer_ctc/decode.py b/egs/aishell/ASR/conformer_ctc/decode.py index 751b7d5b5..a30fa52df 100755 --- a/egs/aishell/ASR/conformer_ctc/decode.py +++ b/egs/aishell/ASR/conformer_ctc/decode.py @@ -58,16 +58,19 @@ def get_parser(): "--epoch", type=int, default=49, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=20, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -401,9 +404,7 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -431,9 +432,7 @@ def save_results( # we compute CER for aishell dataset. results_char = [] for res in results: - results_char.append( - (res[0], list("".join(res[1])), list("".join(res[2]))) - ) + results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results_char, enable_log=enable_log @@ -441,9 +440,7 @@ def save_results( test_set_wers[key] = wer if enable_log: - logging.info( - "Wrote detailed error stats to {}".format(errs_filename) - ) + logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"cer-summary-{test_set_name}.txt" @@ -562,9 +559,7 @@ def main(): eos_id=eos_id, ) - save_results( - params=params, test_set_name=test_set, results_dict=results_dict - ) + save_results(params=params, test_set_name=test_set, results_dict=results_dict) logging.info("Done!") diff --git a/egs/aishell/ASR/conformer_ctc/export.py b/egs/aishell/ASR/conformer_ctc/export.py index 42b8c29e7..9ee405e8b 100644 --- a/egs/aishell/ASR/conformer_ctc/export.py +++ b/egs/aishell/ASR/conformer_ctc/export.py @@ -40,17 +40,20 @@ def get_parser(): "--epoch", type=int, default=84, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=25, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -157,9 +160,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/conformer_ctc/pretrained.py b/egs/aishell/ASR/conformer_ctc/pretrained.py index 27776bc24..e3d5a20e3 100755 --- a/egs/aishell/ASR/conformer_ctc/pretrained.py +++ b/egs/aishell/ASR/conformer_ctc/pretrained.py @@ -46,27 +46,29 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( "--tokens-file", type=str, - help="Path to tokens.txt" "Used only when method is ctc-decoding", + help="Path to tokens.txtUsed only when method is ctc-decoding", ) parser.add_argument( "--words-file", type=str, - help="Path to words.txt" "Used when method is NOT ctc-decoding", + help="Path to words.txtUsed when method is NOT ctc-decoding", ) parser.add_argument( "--HLG", type=str, - help="Path to HLG.pt." "Used when method is NOT ctc-decoding", + help="Path to HLG.pt.Used when method is NOT ctc-decoding", ) parser.add_argument( @@ -163,10 +165,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) return parser @@ -210,10 +214,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -274,9 +277,7 @@ def main(): logging.info("Decoding started") features = fbank(waves) - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) # Note: We don't use key padding mask for attention during decoding with torch.no_grad(): @@ -371,9 +372,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/conformer_ctc/subsampling.py b/egs/aishell/ASR/conformer_ctc/subsampling.py index 542fb0364..8e0f73d05 100644 --- a/egs/aishell/ASR/conformer_ctc/subsampling.py +++ b/egs/aishell/ASR/conformer_ctc/subsampling.py @@ -42,13 +42,9 @@ class Conv2dSubsampling(nn.Module): assert idim >= 7 super().__init__() self.conv = nn.Sequential( - nn.Conv2d( - in_channels=1, out_channels=odim, kernel_size=3, stride=2 - ), + nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2), nn.ReLU(), - nn.Conv2d( - in_channels=odim, out_channels=odim, kernel_size=3, stride=2 - ), + nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2), nn.ReLU(), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) @@ -132,17 +128,13 @@ class VggSubsampling(nn.Module): ) ) layers.append( - torch.nn.MaxPool2d( - kernel_size=2, stride=2, padding=0, ceil_mode=True - ) + torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True) ) cur_channels = block_dim self.layers = nn.Sequential(*layers) - self.out = nn.Linear( - block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim - ) + self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim) def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. diff --git a/egs/aishell/ASR/conformer_ctc/test_subsampling.py b/egs/aishell/ASR/conformer_ctc/test_subsampling.py index e3361d0c9..81fa234dd 100755 --- a/egs/aishell/ASR/conformer_ctc/test_subsampling.py +++ b/egs/aishell/ASR/conformer_ctc/test_subsampling.py @@ -16,9 +16,8 @@ # limitations under the License. -from subsampling import Conv2dSubsampling -from subsampling import VggSubsampling import torch +from subsampling import Conv2dSubsampling, VggSubsampling def test_conv2d_subsampling(): diff --git a/egs/aishell/ASR/conformer_ctc/train.py b/egs/aishell/ASR/conformer_ctc/train.py index a228cc1fe..c2cbe6e3b 100755 --- a/egs/aishell/ASR/conformer_ctc/train.py +++ b/egs/aishell/ASR/conformer_ctc/train.py @@ -382,9 +382,7 @@ def compute_loss( # # See https://github.com/k2-fsa/icefall/issues/97 # for more details - unsorted_token_ids = graph_compiler.texts_to_ids( - supervisions["text"] - ) + unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) att_loss = mmodel.decoder_forward( encoder_memory, memory_mask, @@ -520,9 +518,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -630,9 +626,7 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/aishell/ASR/conformer_ctc/transformer.py b/egs/aishell/ASR/conformer_ctc/transformer.py index f93914aaa..a3e50e385 100644 --- a/egs/aishell/ASR/conformer_ctc/transformer.py +++ b/egs/aishell/ASR/conformer_ctc/transformer.py @@ -149,9 +149,7 @@ class Transformer(nn.Module): norm=decoder_norm, ) - self.decoder_output_layer = torch.nn.Linear( - d_model, self.decoder_num_class - ) + self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class) self.decoder_criterion = LabelSmoothingLoss() else: @@ -183,9 +181,7 @@ class Transformer(nn.Module): x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) x = self.feat_batchnorm(x) x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) - encoder_memory, memory_key_padding_mask = self.run_encoder( - x, supervision - ) + encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision) x = self.ctc_output(encoder_memory) return x, encoder_memory, memory_key_padding_mask @@ -266,23 +262,17 @@ class Transformer(nn.Module): """ ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, padding_value=float(eos_id) - ) + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) device = memory.device ys_in_pad = ys_in_pad.to(device) ys_out_pad = ys_out_pad.to(device) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -343,23 +333,17 @@ class Transformer(nn.Module): ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, padding_value=float(eos_id) - ) + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) device = memory.device ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -632,9 +616,7 @@ def _get_activation_fn(activation: str): elif activation == "gelu": return nn.functional.gelu - raise RuntimeError( - "activation should be relu/gelu, not {}".format(activation) - ) + raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) class PositionalEncoding(nn.Module): @@ -836,9 +818,7 @@ def encoder_padding_mask( 1, ).to(torch.int32) - lengths = [ - 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) - ] + lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] for idx in range(supervision_segments.size(0)): # Note: TorchScript doesn't allow to unpack tensors as tuples sequence_idx = supervision_segments[idx, 0].item() @@ -859,9 +839,7 @@ def encoder_padding_mask( return mask -def decoder_padding_mask( - ys_pad: torch.Tensor, ignore_id: int = -1 -) -> torch.Tensor: +def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: """Generate a length mask for input. The masked position are filled with True, diff --git a/egs/aishell/ASR/conformer_mmi/conformer.py b/egs/aishell/ASR/conformer_mmi/conformer.py index cb7205e51..f5b5873b4 100644 --- a/egs/aishell/ASR/conformer_mmi/conformer.py +++ b/egs/aishell/ASR/conformer_mmi/conformer.py @@ -157,9 +157,7 @@ class ConformerEncoderLayer(nn.Module): normalize_before: bool = True, ) -> None: super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), @@ -177,18 +175,14 @@ class ConformerEncoderLayer(nn.Module): self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) - self.norm_ff_macaron = nn.LayerNorm( - d_model - ) # for the macaron style FNN module + self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.ff_scale = 0.5 self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm( - d_model - ) # for the final output of the block + self.norm_final = nn.LayerNorm(d_model) # for the final output of the block self.dropout = nn.Dropout(dropout) @@ -222,9 +216,7 @@ class ConformerEncoderLayer(nn.Module): residual = src if self.normalize_before: src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout( - self.feed_forward_macaron(src) - ) + src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) if not self.normalize_before: src = self.norm_ff_macaron(src) @@ -343,9 +335,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -361,9 +351,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -633,9 +621,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -703,33 +691,25 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor" + " instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -766,9 +746,7 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -780,9 +758,7 @@ class RelPositionMultiheadAttention(nn.Module): matrix_ac + matrix_bd ) * scaling # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -816,13 +792,9 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -845,9 +817,7 @@ class ConvolutionModule(nn.Module): """ - def __init__( - self, channels: int, kernel_size: int, bias: bool = True - ) -> None: + def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding diff --git a/egs/aishell/ASR/conformer_mmi/decode.py b/egs/aishell/ASR/conformer_mmi/decode.py index 4db367e36..a43183063 100755 --- a/egs/aishell/ASR/conformer_mmi/decode.py +++ b/egs/aishell/ASR/conformer_mmi/decode.py @@ -59,16 +59,19 @@ def get_parser(): "--epoch", type=int, default=49, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=20, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -413,9 +416,7 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -443,9 +444,7 @@ def save_results( # we compute CER for aishell dataset. results_char = [] for res in results: - results_char.append( - (res[0], list("".join(res[1])), list("".join(res[2]))) - ) + results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results_char, enable_log=enable_log @@ -453,9 +452,7 @@ def save_results( test_set_wers[key] = wer if enable_log: - logging.info( - "Wrote detailed error stats to {}".format(errs_filename) - ) + logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"cer-summary-{test_set_name}.txt" @@ -550,9 +547,7 @@ def main(): if params.export: logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") - torch.save( - {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" - ) + torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt") return model.to(device) @@ -581,9 +576,7 @@ def main(): eos_id=eos_id, ) - save_results( - params=params, test_set_name=test_set, results_dict=results_dict - ) + save_results(params=params, test_set_name=test_set, results_dict=results_dict) logging.info("Done!") diff --git a/egs/aishell/ASR/conformer_mmi/subsampling.py b/egs/aishell/ASR/conformer_mmi/subsampling.py index 720ed6c22..398837a46 100644 --- a/egs/aishell/ASR/conformer_mmi/subsampling.py +++ b/egs/aishell/ASR/conformer_mmi/subsampling.py @@ -42,13 +42,9 @@ class Conv2dSubsampling(nn.Module): assert idim >= 7 super().__init__() self.conv = nn.Sequential( - nn.Conv2d( - in_channels=1, out_channels=odim, kernel_size=3, stride=2 - ), + nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2), nn.ReLU(), - nn.Conv2d( - in_channels=odim, out_channels=odim, kernel_size=3, stride=2 - ), + nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2), nn.ReLU(), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) @@ -132,17 +128,13 @@ class VggSubsampling(nn.Module): ) ) layers.append( - torch.nn.MaxPool2d( - kernel_size=2, stride=2, padding=0, ceil_mode=True - ) + torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True) ) cur_channels = block_dim self.layers = nn.Sequential(*layers) - self.out = nn.Linear( - block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim - ) + self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim) def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. diff --git a/egs/aishell/ASR/conformer_mmi/train.py b/egs/aishell/ASR/conformer_mmi/train.py index 685831d09..09cd6e60c 100755 --- a/egs/aishell/ASR/conformer_mmi/train.py +++ b/egs/aishell/ASR/conformer_mmi/train.py @@ -511,9 +511,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -625,9 +623,7 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/aishell/ASR/conformer_mmi/transformer.py b/egs/aishell/ASR/conformer_mmi/transformer.py index f93914aaa..a3e50e385 100644 --- a/egs/aishell/ASR/conformer_mmi/transformer.py +++ b/egs/aishell/ASR/conformer_mmi/transformer.py @@ -149,9 +149,7 @@ class Transformer(nn.Module): norm=decoder_norm, ) - self.decoder_output_layer = torch.nn.Linear( - d_model, self.decoder_num_class - ) + self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class) self.decoder_criterion = LabelSmoothingLoss() else: @@ -183,9 +181,7 @@ class Transformer(nn.Module): x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) x = self.feat_batchnorm(x) x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) - encoder_memory, memory_key_padding_mask = self.run_encoder( - x, supervision - ) + encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision) x = self.ctc_output(encoder_memory) return x, encoder_memory, memory_key_padding_mask @@ -266,23 +262,17 @@ class Transformer(nn.Module): """ ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, padding_value=float(eos_id) - ) + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) device = memory.device ys_in_pad = ys_in_pad.to(device) ys_out_pad = ys_out_pad.to(device) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -343,23 +333,17 @@ class Transformer(nn.Module): ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, padding_value=float(eos_id) - ) + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) device = memory.device ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -632,9 +616,7 @@ def _get_activation_fn(activation: str): elif activation == "gelu": return nn.functional.gelu - raise RuntimeError( - "activation should be relu/gelu, not {}".format(activation) - ) + raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) class PositionalEncoding(nn.Module): @@ -836,9 +818,7 @@ def encoder_padding_mask( 1, ).to(torch.int32) - lengths = [ - 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) - ] + lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] for idx in range(supervision_segments.size(0)): # Note: TorchScript doesn't allow to unpack tensors as tuples sequence_idx = supervision_segments[idx, 0].item() @@ -859,9 +839,7 @@ def encoder_padding_mask( return mask -def decoder_padding_mask( - ys_pad: torch.Tensor, ignore_id: int = -1 -) -> torch.Tensor: +def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: """Generate a length mask for input. The masked position are filled with True, diff --git a/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py b/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py index 42700a972..037971927 100755 --- a/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py +++ b/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py @@ -87,9 +87,7 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80): ) if "train" in partition: cut_set = ( - cut_set - + cut_set.perturb_speed(0.9) - + cut_set.perturb_speed(1.1) + cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) ) cut_set = cut_set.compute_and_store_features( extractor=extractor, @@ -116,9 +114,7 @@ def get_args(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/aishell/ASR/local/compute_fbank_aishell.py b/egs/aishell/ASR/local/compute_fbank_aishell.py index deab6c809..115ca1031 100755 --- a/egs/aishell/ASR/local/compute_fbank_aishell.py +++ b/egs/aishell/ASR/local/compute_fbank_aishell.py @@ -83,9 +83,7 @@ def compute_fbank_aishell(num_mel_bins: int = 80): ) if "train" in partition: cut_set = ( - cut_set - + cut_set.perturb_speed(0.9) - + cut_set.perturb_speed(1.1) + cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) ) cut_set = cut_set.compute_and_store_features( extractor=extractor, @@ -111,9 +109,7 @@ def get_args(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/aishell/ASR/local/prepare_char.py b/egs/aishell/ASR/local/prepare_char.py index d9e47d17a..6b440dfb3 100755 --- a/egs/aishell/ASR/local/prepare_char.py +++ b/egs/aishell/ASR/local/prepare_char.py @@ -86,9 +86,7 @@ def lexicon_to_fst_no_sil( cur_state = loop_state word = word2id[word] - pieces = [ - token2id[i] if i in token2id else token2id[""] for i in pieces - ] + pieces = [token2id[i] if i in token2id else token2id[""] for i in pieces] for i in range(len(pieces) - 1): w = word if i == 0 else eps @@ -142,9 +140,7 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool: return False -def generate_lexicon( - token_sym_table: Dict[str, int], words: List[str] -) -> Lexicon: +def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon: """Generate a lexicon from a word list and token_sym_table. Args: diff --git a/egs/aishell/ASR/local/prepare_lang.py b/egs/aishell/ASR/local/prepare_lang.py index e5ae89ec4..c8cf9b881 100755 --- a/egs/aishell/ASR/local/prepare_lang.py +++ b/egs/aishell/ASR/local/prepare_lang.py @@ -317,9 +317,7 @@ def lexicon_to_fst( def get_args(): parser = argparse.ArgumentParser() - parser.add_argument( - "--lang-dir", type=str, help="The lang dir, data/lang_phone" - ) + parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone") return parser.parse_args() diff --git a/egs/aishell/ASR/local/test_prepare_lang.py b/egs/aishell/ASR/local/test_prepare_lang.py index d4cf62bba..74e025ad7 100755 --- a/egs/aishell/ASR/local/test_prepare_lang.py +++ b/egs/aishell/ASR/local/test_prepare_lang.py @@ -88,9 +88,7 @@ def test_read_lexicon(filename: str): fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa.draw("L.pdf", title="L") - fsa_disambig = lexicon_to_fst( - lexicon_disambig, phone2id=phone2id, word2id=word2id - ) + fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id) fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt") fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa_disambig.draw("L_disambig.pdf", title="L_disambig") diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/decode.py b/egs/aishell/ASR/pruned_transducer_stateless2/decode.py index a12934d55..ae926ec66 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/aishell/ASR/pruned_transducer_stateless2/decode.py @@ -76,11 +76,7 @@ from beam_search import ( ) from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, @@ -118,9 +114,11 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( @@ -188,8 +186,7 @@ def get_parser(): "--context-size", type=int, default=1, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -249,9 +246,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) if params.decoding_method == "fast_beam_search": hyp_tokens = fast_beam_search_one_best( @@ -263,10 +258,7 @@ def decode_one_batch( max_contexts=params.max_contexts, max_states=params.max_states, ) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -310,11 +302,7 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps + f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -387,9 +375,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -415,9 +401,7 @@ def save_results( # we compute CER for aishell dataset. results_char = [] for res in results: - results_char.append( - (res[0], list("".join(res[1])), list("".join(res[2]))) - ) + results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results_char, enable_log=True @@ -428,8 +412,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -473,9 +456,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -504,8 +485,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/export.py b/egs/aishell/ASR/pruned_transducer_stateless2/export.py index feababdd2..5f6888db4 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless2/export.py +++ b/egs/aishell/ASR/pruned_transducer_stateless2/export.py @@ -50,11 +50,7 @@ from pathlib import Path import torch from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.lexicon import Lexicon from icefall.utils import str2bool @@ -87,9 +83,11 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( @@ -120,8 +118,7 @@ def get_parser(): "--context-size", type=int, default=1, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) add_model_arguments(parser) @@ -157,8 +154,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -191,9 +187,7 @@ def main(): model.__class__.forward = torch.jit.ignore(model.__class__.forward) logging.info("Using torch.jit.script") model = torch.jit.script(model) - filename = ( - params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt" - ) + filename = params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt" model.save(str(filename)) logging.info(f"Saved to {filename}") else: @@ -201,17 +195,14 @@ def main(): # Save it using a format so that it can be loaded # by :func:`load_checkpoint` filename = ( - params.exp_dir - / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt" + params.exp_dir / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt" ) torch.save({"model": model.state_dict()}, str(filename)) logging.info(f"Saved to {filename}") if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py index 3c38e5db7..f754a7b9e 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py +++ b/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py @@ -87,9 +87,11 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( @@ -115,10 +117,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -165,15 +169,16 @@ def get_parser(): "--context-size", type=int, default=1, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", type=int, default=1, - help="Maximum number of symbols per frame. " - "Use only when --method is greedy_search", + help=( + "Maximum number of symbols per frame. " + "Use only when --method is greedy_search" + ), ) add_model_arguments(parser) @@ -196,10 +201,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -256,13 +260,9 @@ def main(): feature_lens = [f.size(0) for f in features] feature_lens = torch.tensor(feature_lens, device=device) - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens) num_waves = encoder_out.size(0) hyp_list = [] @@ -310,9 +310,7 @@ def main(): beam=params.beam_size, ) else: - raise ValueError( - f"Unsupported decoding method: {params.method}" - ) + raise ValueError(f"Unsupported decoding method: {params.method}") hyp_list.append(hyp) hyps = [] @@ -329,9 +327,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/train.py b/egs/aishell/ASR/pruned_transducer_stateless2/train.py index 97d892754..66ca23035 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless2/train.py +++ b/egs/aishell/ASR/pruned_transducer_stateless2/train.py @@ -49,7 +49,6 @@ import optim import torch import torch.multiprocessing as mp import torch.nn as nn - from asr_datamodule import AishellAsrDataModule from conformer import Conformer from decoder import Decoder @@ -75,9 +74,7 @@ from icefall.env import get_env_info from icefall.lexicon import Lexicon from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -203,8 +200,7 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need " - "to be changed.", + help="The initial learning rate. This value should not need to be changed.", ) parser.add_argument( @@ -227,42 +223,45 @@ def get_parser(): "--context-size", type=int, default=1, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." + ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -561,11 +560,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -593,23 +588,16 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -725,9 +713,7 @@ def train_one_epoch( scaler.update() optimizer.zero_grad() except: # noqa - display_and_save_batch( - batch, params=params, graph_compiler=graph_compiler - ) + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) raise if params.print_diagnostics and batch_idx == 5: @@ -891,7 +877,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -1029,9 +1015,7 @@ def scan_pessimistic_batches_for_oom( f"Failing criterion: {criterion} " f"(={crit_values[criterion]}) ..." ) - display_and_save_batch( - batch, params=params, graph_compiler=graph_compiler - ) + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) raise diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/decode.py b/egs/aishell/ASR/pruned_transducer_stateless3/decode.py index d159e420b..6c505940d 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/decode.py @@ -121,20 +121,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=False, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -202,8 +206,7 @@ def get_parser(): "--context-size", type=int, default=1, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -263,9 +266,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) if params.decoding_method == "fast_beam_search": hyp_tokens = fast_beam_search_one_best( @@ -277,10 +278,7 @@ def decode_one_batch( max_contexts=params.max_contexts, max_states=params.max_states, ) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -324,11 +322,7 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps + f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -401,9 +395,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -429,9 +421,7 @@ def save_results( # we compute CER for aishell dataset. results_char = [] for res in results: - results_char.append( - (res[0], list("".join(res[1])), list("".join(res[2]))) - ) + results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results_char, enable_log=True @@ -442,8 +432,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tCER", file=f) @@ -488,9 +477,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -518,13 +505,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -551,13 +537,12 @@ def main(): ) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -586,7 +571,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/export.py b/egs/aishell/ASR/pruned_transducer_stateless3/export.py index 566902a85..e5a5d7c77 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless3/export.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/export.py @@ -88,20 +88,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -132,8 +136,7 @@ def get_parser(): "--context-size", type=int, default=1, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) add_model_arguments(parser) @@ -166,13 +169,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -195,13 +197,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -229,7 +230,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -252,9 +253,7 @@ def main(): model.__class__.forward = torch.jit.ignore(model.__class__.forward) logging.info("Using torch.jit.script") model = torch.jit.script(model) - filename = ( - params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt" - ) + filename = params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt" model.save(str(filename)) logging.info(f"Saved to {filename}") else: @@ -262,17 +261,14 @@ def main(): # Save it using a format so that it can be loaded # by :func:`load_checkpoint` filename = ( - params.exp_dir - / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt" + params.exp_dir / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt" ) torch.save({"model": model.state_dict()}, str(filename)) logging.info(f"Saved to {filename}") if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/model.py b/egs/aishell/ASR/pruned_transducer_stateless3/model.py index e150e8230..a4dda0d6d 100644 --- a/egs/aishell/ASR/pruned_transducer_stateless3/model.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/model.py @@ -84,9 +84,7 @@ class Transducer(nn.Module): self.decoder_datatang = decoder_datatang self.joiner_datatang = joiner_datatang - self.simple_am_proj = ScaledLinear( - encoder_dim, vocab_size, initial_speed=0.5 - ) + self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) if decoder_datatang is not None: @@ -179,9 +177,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = encoder_out_lens diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py index 04a0a882a..109879952 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py @@ -87,9 +87,11 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( @@ -115,10 +117,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -165,15 +169,16 @@ def get_parser(): "--context-size", type=int, default=1, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", type=int, default=1, - help="Maximum number of symbols per frame. " - "Use only when --method is greedy_search", + help=( + "Maximum number of symbols per frame. " + "Use only when --method is greedy_search" + ), ) add_model_arguments(parser) @@ -196,10 +201,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -257,13 +261,9 @@ def main(): feature_lens = [f.size(0) for f in features] feature_lens = torch.tensor(feature_lens, device=device) - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens) num_waves = encoder_out.size(0) hyp_list = [] @@ -311,9 +311,7 @@ def main(): beam=params.beam_size, ) else: - raise ValueError( - f"Unsupported decoding method: {params.method}" - ) + raise ValueError(f"Unsupported decoding method: {params.method}") hyp_list.append(hyp) hyps = [] @@ -330,9 +328,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/train.py b/egs/aishell/ASR/pruned_transducer_stateless3/train.py index feaef5cf6..b24f533ff 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless3/train.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/train.py @@ -96,9 +96,7 @@ from icefall.env import get_env_info from icefall.lexicon import Lexicon from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -224,8 +222,7 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need " - "to be changed.", + help="The initial learning rate. This value should not need to be changed.", ) parser.add_argument( @@ -248,42 +245,45 @@ def get_parser(): "--context-size", type=int, default=1, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." + ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -635,11 +635,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -670,23 +666,16 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -824,9 +813,7 @@ def train_one_epoch( ) # summary stats if datatang_train_dl is not None: - tot_loss = ( - tot_loss * (1 - 1 / params.reset_interval) - ) + loss_info + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info if aishell: aishell_tot_loss = ( @@ -847,9 +834,7 @@ def train_one_epoch( scaler.update() optimizer.zero_grad() except: # noqa - display_and_save_batch( - batch, params=params, graph_compiler=graph_compiler - ) + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) raise if params.print_diagnostics and batch_idx == 5: @@ -892,9 +877,7 @@ def train_one_epoch( cur_lr = scheduler.get_last_lr()[0] if datatang_train_dl is not None: datatang_str = f"datatang_tot_loss[{datatang_tot_loss}], " - tot_loss_str = ( - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - ) + tot_loss_str = f"tot_loss[{tot_loss}], batch size: {batch_size}, " else: tot_loss_str = "" datatang_str = "" @@ -1067,7 +1050,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -1076,9 +1059,7 @@ def run(rank, world_size, args): train_cuts = filter_short_and_long_utterances(train_cuts) if args.enable_musan: - cuts_musan = load_manifest( - Path(args.manifest_dir) / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz") else: cuts_musan = None @@ -1093,9 +1074,7 @@ def run(rank, world_size, args): if params.datatang_prob > 0: datatang = AIDatatang200zh(manifest_dir=args.manifest_dir) train_datatang_cuts = datatang.train_cuts() - train_datatang_cuts = filter_short_and_long_utterances( - train_datatang_cuts - ) + train_datatang_cuts = filter_short_and_long_utterances(train_datatang_cuts) train_datatang_cuts = train_datatang_cuts.repeat(times=None) datatang_train_dl = asr_datamodule.train_dataloaders( train_datatang_cuts, @@ -1249,9 +1228,7 @@ def scan_pessimistic_batches_for_oom( f"Failing criterion: {criterion} " f"(={crit_values[criterion]}) ..." ) - display_and_save_batch( - batch, params=params, graph_compiler=graph_compiler - ) + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) raise diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py index d24ba6bb7..12ae6e7d4 100644 --- a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -64,10 +64,12 @@ class AishellAsrDataModule: def add_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group( title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", + description=( + "These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc." + ), ) group.add_argument( "--manifest-dir", @@ -79,59 +81,74 @@ class AishellAsrDataModule: "--max-duration", type=int, default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", + help=( + "Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM." + ), ) group.add_argument( "--bucketing-sampler", type=str2bool, default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", + help=( + "When enabled, the batches will come from buckets of " + "similar duration (saves padding frames)." + ), ) group.add_argument( "--num-buckets", type=int, default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", + help=( + "The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets)." + ), ) group.add_argument( "--concatenate-cuts", type=str2bool, default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", + help=( + "When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding." + ), ) group.add_argument( "--duration-factor", type=float, default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", + help=( + "Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch." + ), ) group.add_argument( "--gap", type=float, default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", + help=( + "The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used." + ), ) group.add_argument( "--on-the-fly-feats", type=str2bool, default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", + help=( + "When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available." + ), ) group.add_argument( "--shuffle", type=str2bool, default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", + help=( + "When enabled (=default), the examples will be shuffled for each epoch." + ), ) group.add_argument( "--drop-last", @@ -143,17 +160,18 @@ class AishellAsrDataModule: "--return-cuts", type=str2bool, default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", + help=( + "When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it." + ), ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that " - "collect the batches.", + help="The number of training dataloader workers that collect the batches.", ) group.add_argument( @@ -167,40 +185,40 @@ class AishellAsrDataModule: "--spec-aug-time-warp-factor", type=int, default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", + help=( + "Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp." + ), ) group.add_argument( "--enable-musan", type=str2bool, default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. ", + help=( + "When enabled, select noise from MUSAN and mix it" + "with training dataset. " + ), ) def train_dataloaders(self, cuts_train: CutSet) -> DataLoader: logging.info("About to get Musan cuts") - cuts_musan = load_manifest( - self.args.manifest_dir / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms = [] if self.args.enable_musan: logging.info("Enable MUSAN") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") if self.args.concatenate_cuts: logging.info( - f"Using cut concatenation with duration factor " + "Using cut concatenation with duration factor " f"{self.args.duration_factor} and gap {self.args.gap}." ) # Cut concatenation should be the first transform in the list, @@ -215,9 +233,7 @@ class AishellAsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -260,9 +276,7 @@ class AishellAsrDataModule: # Drop feats to be on the safe side. train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -308,9 +322,7 @@ class AishellAsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: @@ -366,13 +378,9 @@ class AishellAsrDataModule: @lru_cache() def valid_cuts(self) -> CutSet: logging.info("About to get dev cuts") - return load_manifest_lazy( - self.args.manifest_dir / "aishell_cuts_dev.jsonl.gz" - ) + return load_manifest_lazy(self.args.manifest_dir / "aishell_cuts_dev.jsonl.gz") @lru_cache() def test_cuts(self) -> List[CutSet]: logging.info("About to get test cuts") - return load_manifest_lazy( - self.args.manifest_dir / "aishell_cuts_test.jsonl.gz" - ) + return load_manifest_lazy(self.args.manifest_dir / "aishell_cuts_test.jsonl.gz") diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/decode.py b/egs/aishell/ASR/tdnn_lstm_ctc/decode.py index 66b734fc4..8ef247438 100755 --- a/egs/aishell/ASR/tdnn_lstm_ctc/decode.py +++ b/egs/aishell/ASR/tdnn_lstm_ctc/decode.py @@ -49,16 +49,19 @@ def get_parser(): "--epoch", type=int, default=19, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=5, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( "--method", @@ -265,9 +268,7 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -289,9 +290,7 @@ def save_results( # We compute CER for aishell dataset. results_char = [] for res in results: - results_char.append( - (res[0], list("".join(res[1])), list("".join(res[2]))) - ) + results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) with open(errs_filename, "w") as f: wer = write_error_stats(f, f"{test_set_name}-{key}", results_char) test_set_wers[key] = wer @@ -335,9 +334,7 @@ def main(): logging.info(f"device: {device}") - HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu") - ) + HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")) HLG = HLG.to(device) assert HLG.requires_grad is False @@ -362,9 +359,7 @@ def main(): if params.export: logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") - torch.save( - {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" - ) + torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt") model.to(device) model.eval() @@ -392,9 +387,7 @@ def main(): lexicon=lexicon, ) - save_results( - params=params, test_set_name=test_set, results_dict=results_dict - ) + save_results(params=params, test_set_name=test_set, results_dict=results_dict) logging.info("Done!") diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/model.py b/egs/aishell/ASR/tdnn_lstm_ctc/model.py index 5e04c11b4..1731e1ebe 100644 --- a/egs/aishell/ASR/tdnn_lstm_ctc/model.py +++ b/egs/aishell/ASR/tdnn_lstm_ctc/model.py @@ -66,10 +66,7 @@ class TdnnLstm(nn.Module): nn.BatchNorm1d(num_features=500, affine=False), ) self.lstms = nn.ModuleList( - [ - nn.LSTM(input_size=500, hidden_size=500, num_layers=1) - for _ in range(5) - ] + [nn.LSTM(input_size=500, hidden_size=500, num_layers=1) for _ in range(5)] ) self.lstm_bnorms = nn.ModuleList( [nn.BatchNorm1d(num_features=500, affine=False) for _ in range(5)] diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py b/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py index 9bd810809..52f9410cf 100644 --- a/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py +++ b/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py @@ -41,9 +41,11 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( @@ -53,9 +55,7 @@ def get_parser(): help="Path to words.txt", ) - parser.add_argument( - "--HLG", type=str, required=True, help="Path to HLG.pt." - ) + parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.") parser.add_argument( "--method", @@ -71,10 +71,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) return parser @@ -112,10 +114,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -173,9 +174,7 @@ def main(): logging.info("Decoding started") features = fbank(waves) - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) features = features.permute(0, 2, 1) # now features is [N, C, T] with torch.no_grad(): @@ -219,9 +218,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/train.py b/egs/aishell/ASR/tdnn_lstm_ctc/train.py index 7619b0551..e574cf89b 100755 --- a/egs/aishell/ASR/tdnn_lstm_ctc/train.py +++ b/egs/aishell/ASR/tdnn_lstm_ctc/train.py @@ -49,12 +49,7 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.dist import cleanup_dist, setup_dist from icefall.graph_compiler import CtcTrainingGraphCompiler from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - encode_supervisions, - setup_logger, - str2bool, -) +from icefall.utils import AttributeDict, encode_supervisions, setup_logger, str2bool def get_parser(): diff --git a/egs/aishell/ASR/transducer_stateless/beam_search.py b/egs/aishell/ASR/transducer_stateless/beam_search.py index 9ed9b2ad1..de0a8d0f5 100644 --- a/egs/aishell/ASR/transducer_stateless/beam_search.py +++ b/egs/aishell/ASR/transducer_stateless/beam_search.py @@ -47,9 +47,9 @@ def greedy_search( device = model.device - decoder_input = torch.tensor( - [blank_id] * context_size, device=device - ).reshape(1, context_size) + decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape( + 1, context_size + ) decoder_out = model.decoder(decoder_input, need_pad=False) @@ -81,9 +81,9 @@ def greedy_search( y = logits.argmax().item() if y != blank_id: hyp.append(y) - decoder_input = torch.tensor( - [hyp[-context_size:]], device=device - ).reshape(1, context_size) + decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape( + 1, context_size + ) decoder_out = model.decoder(decoder_input, need_pad=False) @@ -157,9 +157,7 @@ class HypothesisList(object): """ if length_norm: - return max( - self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys) - ) + return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)) else: return max(self._data.values(), key=lambda hyp: hyp.log_prob) @@ -246,9 +244,9 @@ def beam_search( device = model.device - decoder_input = torch.tensor( - [blank_id] * context_size, device=device - ).reshape(1, context_size) + decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape( + 1, context_size + ) decoder_out = model.decoder(decoder_input, need_pad=False) diff --git a/egs/aishell/ASR/transducer_stateless/conformer.py b/egs/aishell/ASR/transducer_stateless/conformer.py index 64114253d..e26c6c385 100644 --- a/egs/aishell/ASR/transducer_stateless/conformer.py +++ b/egs/aishell/ASR/transducer_stateless/conformer.py @@ -155,9 +155,7 @@ class ConformerEncoderLayer(nn.Module): normalize_before: bool = True, ) -> None: super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), @@ -175,18 +173,14 @@ class ConformerEncoderLayer(nn.Module): self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) - self.norm_ff_macaron = nn.LayerNorm( - d_model - ) # for the macaron style FNN module + self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.ff_scale = 0.5 self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm( - d_model - ) # for the final output of the block + self.norm_final = nn.LayerNorm(d_model) # for the final output of the block self.dropout = nn.Dropout(dropout) @@ -220,9 +214,7 @@ class ConformerEncoderLayer(nn.Module): residual = src if self.normalize_before: src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout( - self.feed_forward_macaron(src) - ) + src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) if not self.normalize_before: src = self.norm_ff_macaron(src) @@ -341,9 +333,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -359,9 +349,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -631,9 +619,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -701,33 +689,25 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor" + " instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -764,9 +744,7 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -778,9 +756,7 @@ class RelPositionMultiheadAttention(nn.Module): matrix_ac + matrix_bd ) * scaling # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -814,13 +790,9 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -843,9 +815,7 @@ class ConvolutionModule(nn.Module): """ - def __init__( - self, channels: int, kernel_size: int, bias: bool = True - ) -> None: + def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding diff --git a/egs/aishell/ASR/transducer_stateless/decode.py b/egs/aishell/ASR/transducer_stateless/decode.py index 780b0c4bb..1f7bb14e1 100755 --- a/egs/aishell/ASR/transducer_stateless/decode.py +++ b/egs/aishell/ASR/transducer_stateless/decode.py @@ -52,16 +52,19 @@ def get_parser(): "--epoch", type=int, default=30, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=10, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -99,8 +102,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -227,9 +229,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] batch_size = encoder_out.size(0) @@ -248,9 +248,7 @@ def decode_one_batch( model=model, encoder_out=encoder_out_i, beam=params.beam_size ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") hyps.append([lexicon.token_table[i] for i in hyp]) if params.decoding_method == "greedy_search": @@ -319,9 +317,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -346,9 +342,7 @@ def save_results( # we compute CER for aishell dataset. results_char = [] for res in results: - results_char.append( - (res[0], list("".join(res[1])), list("".join(res[2]))) - ) + results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results_char, enable_log=True @@ -359,8 +353,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tCER", file=f) @@ -430,9 +423,7 @@ def main(): if params.export: logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") - torch.save( - {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" - ) + torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt") return model.to(device) diff --git a/egs/aishell/ASR/transducer_stateless/decoder.py b/egs/aishell/ASR/transducer_stateless/decoder.py index c2c6552a9..70e9e6c96 100644 --- a/egs/aishell/ASR/transducer_stateless/decoder.py +++ b/egs/aishell/ASR/transducer_stateless/decoder.py @@ -86,9 +86,7 @@ class Decoder(nn.Module): if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad( - embedding_out, pad=(self.context_size - 1, 0) - ) + embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/aishell/ASR/transducer_stateless/export.py b/egs/aishell/ASR/transducer_stateless/export.py index 4c6519b96..e35b26fe0 100755 --- a/egs/aishell/ASR/transducer_stateless/export.py +++ b/egs/aishell/ASR/transducer_stateless/export.py @@ -69,17 +69,20 @@ def get_parser(): "--epoch", type=int, default=20, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=10, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -110,8 +113,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) return parser @@ -243,9 +245,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/transducer_stateless/model.py b/egs/aishell/ASR/transducer_stateless/model.py index 994305fc1..591bbe44f 100644 --- a/egs/aishell/ASR/transducer_stateless/model.py +++ b/egs/aishell/ASR/transducer_stateless/model.py @@ -103,9 +103,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/aishell/ASR/transducer_stateless/pretrained.py b/egs/aishell/ASR/transducer_stateless/pretrained.py index db89c4d67..8effc9815 100755 --- a/egs/aishell/ASR/transducer_stateless/pretrained.py +++ b/egs/aishell/ASR/transducer_stateless/pretrained.py @@ -73,9 +73,11 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( @@ -100,10 +102,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -117,8 +121,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -211,10 +214,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -273,9 +275,7 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -319,9 +319,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/transducer_stateless/train.py b/egs/aishell/ASR/transducer_stateless/train.py index d54157709..62ffff473 100755 --- a/egs/aishell/ASR/transducer_stateless/train.py +++ b/egs/aishell/ASR/transducer_stateless/train.py @@ -126,8 +126,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -389,9 +388,7 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -504,9 +501,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -625,9 +620,7 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/aishell/ASR/transducer_stateless/transformer.py b/egs/aishell/ASR/transducer_stateless/transformer.py index e851dcc32..b3ff153c1 100644 --- a/egs/aishell/ASR/transducer_stateless/transformer.py +++ b/egs/aishell/ASR/transducer_stateless/transformer.py @@ -250,9 +250,7 @@ def _get_activation_fn(activation: str): elif activation == "gelu": return nn.functional.gelu - raise RuntimeError( - "activation should be relu/gelu, not {}".format(activation) - ) + raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) class PositionalEncoding(nn.Module): diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py b/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py index 838e53658..76e209f06 100644 --- a/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py +++ b/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py @@ -29,10 +29,7 @@ from lhotse.dataset import ( K2SpeechRecognitionDataset, SpecAugment, ) -from lhotse.dataset.input_strategies import ( - OnTheFlyFeatures, - PrecomputedFeatures, -) +from lhotse.dataset.input_strategies import OnTheFlyFeatures, PrecomputedFeatures from torch.utils.data import DataLoader from icefall.utils import str2bool @@ -46,59 +43,69 @@ class AsrDataModule: def add_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group( title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", + description=( + "These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc." + ), ) group.add_argument( "--max-duration", type=int, default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", + help=( + "Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM." + ), ) group.add_argument( "--bucketing-sampler", type=str2bool, default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", + help=( + "When enabled, the batches will come from buckets of " + "similar duration (saves padding frames)." + ), ) group.add_argument( "--num-buckets", type=int, default=30, - help="The number of buckets for the DynamicBucketingSampler " - "(you might want to increase it for larger datasets).", + help=( + "The number of buckets for the DynamicBucketingSampler " + "(you might want to increase it for larger datasets)." + ), ) group.add_argument( "--shuffle", type=str2bool, default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", + help=( + "When enabled (=default), the examples will be shuffled for each epoch." + ), ) group.add_argument( "--return-cuts", type=str2bool, default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", + help=( + "When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it." + ), ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that " - "collect the batches.", + help="The number of training dataloader workers that collect the batches.", ) group.add_argument( @@ -112,18 +119,22 @@ class AsrDataModule: "--spec-aug-time-warp-factor", type=int, default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", + help=( + "Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp." + ), ) group.add_argument( "--enable-musan", type=str2bool, default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. ", + help=( + "When enabled, select noise from MUSAN and mix it" + "with training dataset. " + ), ) group.add_argument( @@ -137,9 +148,11 @@ class AsrDataModule: "--on-the-fly-feats", type=str2bool, default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available. Used only in dev/test CutSet", + help=( + "When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available. Used only in dev/test CutSet" + ), ) def train_dataloaders( @@ -162,9 +175,7 @@ class AsrDataModule: if cuts_musan is not None: logging.info("Enable MUSAN") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") @@ -173,9 +184,7 @@ class AsrDataModule: if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -252,9 +261,7 @@ class AsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/decode.py b/egs/aishell/ASR/transducer_stateless_modified-2/decode.py index ea3f94fd8..fd4cb8385 100755 --- a/egs/aishell/ASR/transducer_stateless_modified-2/decode.py +++ b/egs/aishell/ASR/transducer_stateless_modified-2/decode.py @@ -93,16 +93,19 @@ def get_parser(): "--epoch", type=int, default=30, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=10, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -170,8 +173,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -227,9 +229,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) if params.decoding_method == "fast_beam_search": hyp_tokens = fast_beam_search_one_best( @@ -241,10 +241,7 @@ def decode_one_batch( max_contexts=params.max_contexts, max_states=params.max_states, ) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -288,11 +285,7 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps + f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -365,9 +358,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -393,9 +384,7 @@ def save_results( # we compute CER for aishell dataset. results_char = [] for res in results: - results_char.append( - (res[0], list("".join(res[1])), list("".join(res[2]))) - ) + results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results_char, enable_log=True @@ -406,8 +395,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tCER", file=f) @@ -448,9 +436,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/export.py b/egs/aishell/ASR/transducer_stateless_modified-2/export.py index 3bd2ceb11..32481829c 100755 --- a/egs/aishell/ASR/transducer_stateless_modified-2/export.py +++ b/egs/aishell/ASR/transducer_stateless_modified-2/export.py @@ -68,17 +68,20 @@ def get_parser(): "--epoch", type=int, default=20, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=10, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -109,8 +112,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) return parser @@ -241,9 +243,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py b/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py index a95a4bc52..55701a007 100755 --- a/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py +++ b/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py @@ -87,9 +87,11 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( @@ -115,10 +117,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -165,15 +169,16 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", type=int, default=1, - help="Maximum number of symbols per frame. " - "Use only when --method is greedy_search", + help=( + "Maximum number of symbols per frame. " + "Use only when --method is greedy_search" + ), ) return parser @@ -194,10 +199,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -254,13 +258,9 @@ def main(): feature_lens = [f.size(0) for f in features] feature_lens = torch.tensor(feature_lens, device=device) - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens) num_waves = encoder_out.size(0) hyp_list = [] @@ -308,9 +308,7 @@ def main(): beam=params.beam_size, ) else: - raise ValueError( - f"Unsupported decoding method: {params.method}" - ) + raise ValueError(f"Unsupported decoding method: {params.method}") hyp_list.append(hyp) hyps = [] @@ -327,9 +325,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/train.py b/egs/aishell/ASR/transducer_stateless_modified-2/train.py index 225d0d709..8fb7d1e49 100755 --- a/egs/aishell/ASR/transducer_stateless_modified-2/train.py +++ b/egs/aishell/ASR/transducer_stateless_modified-2/train.py @@ -149,8 +149,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -168,8 +167,7 @@ def get_parser(): "--datatang-prob", type=float, default=0.2, - help="The probability to select a batch from the " - "aidatatang_200zh dataset", + help="The probability to select a batch from the aidatatang_200zh dataset", ) return parser @@ -449,9 +447,7 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -605,9 +601,7 @@ def train_one_epoch( f"train/current_{prefix}_", params.batch_idx_train, ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) aishell_tot_loss.write_summary( tb_writer, "train/aishell_tot_", params.batch_idx_train ) @@ -735,9 +729,7 @@ def run(rank, world_size, args): train_datatang_cuts = train_datatang_cuts.repeat(times=None) if args.enable_musan: - cuts_musan = load_manifest( - Path(args.manifest_dir) / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz") else: cuts_musan = None @@ -776,9 +768,7 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/aishell/ASR/transducer_stateless_modified/decode.py b/egs/aishell/ASR/transducer_stateless_modified/decode.py index 65fcda873..1e41942da 100755 --- a/egs/aishell/ASR/transducer_stateless_modified/decode.py +++ b/egs/aishell/ASR/transducer_stateless_modified/decode.py @@ -94,16 +94,19 @@ def get_parser(): "--epoch", type=int, default=30, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=10, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -171,8 +174,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -231,9 +233,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) if params.decoding_method == "fast_beam_search": hyp_tokens = fast_beam_search_one_best( @@ -245,10 +245,7 @@ def decode_one_batch( max_contexts=params.max_contexts, max_states=params.max_states, ) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -292,11 +289,7 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps + f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -369,9 +362,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -397,9 +388,7 @@ def save_results( # we compute CER for aishell dataset. results_char = [] for res in results: - results_char.append( - (res[0], list("".join(res[1])), list("".join(res[2]))) - ) + results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results_char, enable_log=True @@ -410,8 +399,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tCER", file=f) @@ -452,9 +440,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" diff --git a/egs/aishell/ASR/transducer_stateless_modified/export.py b/egs/aishell/ASR/transducer_stateless_modified/export.py index 11335a834..ca1d4bd4a 100755 --- a/egs/aishell/ASR/transducer_stateless_modified/export.py +++ b/egs/aishell/ASR/transducer_stateless_modified/export.py @@ -68,17 +68,20 @@ def get_parser(): "--epoch", type=int, default=20, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=10, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -109,8 +112,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) return parser @@ -241,9 +243,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/transducer_stateless_modified/pretrained.py b/egs/aishell/ASR/transducer_stateless_modified/pretrained.py index 262e822c2..038090461 100755 --- a/egs/aishell/ASR/transducer_stateless_modified/pretrained.py +++ b/egs/aishell/ASR/transducer_stateless_modified/pretrained.py @@ -87,9 +87,11 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( @@ -115,10 +117,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -165,15 +169,16 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", type=int, default=1, - help="Maximum number of symbols per frame. " - "Use only when --method is greedy_search", + help=( + "Maximum number of symbols per frame. " + "Use only when --method is greedy_search" + ), ) return parser @@ -194,10 +199,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -254,13 +258,9 @@ def main(): feature_lens = [f.size(0) for f in features] feature_lens = torch.tensor(feature_lens, device=device) - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens) num_waves = encoder_out.size(0) hyp_list = [] @@ -308,9 +308,7 @@ def main(): beam=params.beam_size, ) else: - raise ValueError( - f"Unsupported decoding method: {params.method}" - ) + raise ValueError(f"Unsupported decoding method: {params.method}") hyp_list.append(hyp) hyps = [] @@ -327,9 +325,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/transducer_stateless_modified/train.py b/egs/aishell/ASR/transducer_stateless_modified/train.py index d3ffccafa..5f116f2bd 100755 --- a/egs/aishell/ASR/transducer_stateless_modified/train.py +++ b/egs/aishell/ASR/transducer_stateless_modified/train.py @@ -142,8 +142,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -414,9 +413,7 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -529,9 +526,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -657,9 +652,7 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/aishell2/ASR/local/__init__.py b/egs/aishell2/ASR/local/__init__.py old mode 100755 new mode 100644 diff --git a/egs/aishell2/ASR/local/compute_fbank_aishell2.py b/egs/aishell2/ASR/local/compute_fbank_aishell2.py index d8d3622bd..ec0c584ca 100755 --- a/egs/aishell2/ASR/local/compute_fbank_aishell2.py +++ b/egs/aishell2/ASR/local/compute_fbank_aishell2.py @@ -83,9 +83,7 @@ def compute_fbank_aishell2(num_mel_bins: int = 80): ) if "train" in partition: cut_set = ( - cut_set - + cut_set.perturb_speed(0.9) - + cut_set.perturb_speed(1.1) + cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) ) cut_set = cut_set.compute_and_store_features( extractor=extractor, @@ -111,9 +109,7 @@ def get_args(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/__init__.py b/egs/aishell2/ASR/pruned_transducer_stateless5/__init__.py old mode 100755 new mode 100644 diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py old mode 100755 new mode 100644 index b7a21f579..e8966b554 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py @@ -76,10 +76,12 @@ class AiShell2AsrDataModule: def add_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group( title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", + description=( + "These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc." + ), ) group.add_argument( "--manifest-dir", @@ -91,59 +93,74 @@ class AiShell2AsrDataModule: "--max-duration", type=int, default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", + help=( + "Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM." + ), ) group.add_argument( "--bucketing-sampler", type=str2bool, default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", + help=( + "When enabled, the batches will come from buckets of " + "similar duration (saves padding frames)." + ), ) group.add_argument( "--num-buckets", type=int, default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", + help=( + "The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets)." + ), ) group.add_argument( "--concatenate-cuts", type=str2bool, default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", + help=( + "When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding." + ), ) group.add_argument( "--duration-factor", type=float, default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", + help=( + "Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch." + ), ) group.add_argument( "--gap", type=float, default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", + help=( + "The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used." + ), ) group.add_argument( "--on-the-fly-feats", type=str2bool, default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", + help=( + "When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available." + ), ) group.add_argument( "--shuffle", type=str2bool, default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", + help=( + "When enabled (=default), the examples will be shuffled for each epoch." + ), ) group.add_argument( "--drop-last", @@ -155,17 +172,18 @@ class AiShell2AsrDataModule: "--return-cuts", type=str2bool, default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", + help=( + "When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it." + ), ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that " - "collect the batches.", + help="The number of training dataloader workers that collect the batches.", ) group.add_argument( @@ -179,18 +197,22 @@ class AiShell2AsrDataModule: "--spec-aug-time-warp-factor", type=int, default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", + help=( + "Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp." + ), ) group.add_argument( "--enable-musan", type=str2bool, default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. ", + help=( + "When enabled, select noise from MUSAN and mix it" + "with training dataset. " + ), ) group.add_argument( @@ -216,20 +238,16 @@ class AiShell2AsrDataModule: if self.args.enable_musan: logging.info("Enable MUSAN") logging.info("About to get Musan cuts") - cuts_musan = load_manifest( - self.args.manifest_dir / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") if self.args.concatenate_cuts: logging.info( - f"Using cut concatenation with duration factor " + "Using cut concatenation with duration factor " f"{self.args.duration_factor} and gap {self.args.gap}." ) # Cut concatenation should be the first transform in the list, @@ -244,9 +262,7 @@ class AiShell2AsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -290,9 +306,7 @@ class AiShell2AsrDataModule: # Drop feats to be on the safe side. train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -348,9 +362,7 @@ class AiShell2AsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: @@ -406,9 +418,7 @@ class AiShell2AsrDataModule: @lru_cache() def valid_cuts(self) -> CutSet: logging.info("About to gen cuts from aishell2_cuts_dev.jsonl.gz") - return load_manifest_lazy( - self.args.manifest_dir / "aishell2_cuts_dev.jsonl.gz" - ) + return load_manifest_lazy(self.args.manifest_dir / "aishell2_cuts_dev.jsonl.gz") @lru_cache() def test_cuts(self) -> CutSet: diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py b/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py index 915737f4a..64b64d1b1 100755 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py @@ -168,20 +168,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -269,8 +273,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -348,9 +351,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -409,10 +410,7 @@ def decode_one_batch( ) for i in range(encoder_out.size(0)): hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -538,9 +536,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -573,8 +569,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -625,9 +620,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -661,13 +654,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -690,13 +682,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -724,7 +715,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -749,9 +740,7 @@ def main(): ) decoding_graph.scores *= params.ngram_lm_scale else: - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/export.py b/egs/aishell2/ASR/pruned_transducer_stateless5/export.py index bc7bd71cb..547ce2069 100755 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/export.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/export.py @@ -89,20 +89,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=False, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -133,8 +137,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) add_model_arguments(parser) @@ -167,13 +170,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -196,13 +198,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -230,7 +231,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -266,9 +267,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py b/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py index 09de1bece..4b16511e8 100755 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py @@ -81,9 +81,11 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( @@ -109,10 +111,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -159,8 +163,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -191,10 +194,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -254,15 +256,11 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lengths - ) + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) num_waves = encoder_out.size(0) hyps = [] @@ -334,9 +332,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/train.py b/egs/aishell2/ASR/pruned_transducer_stateless5/train.py index 838a0497f..d37e7bdca 100755 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/train.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/train.py @@ -92,9 +92,7 @@ from icefall.env import get_env_info from icefall.lexicon import Lexicon from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -220,8 +218,7 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need " - "to be changed.", + help="The initial learning rate. This value should not need to be changed.", ) parser.add_argument( @@ -244,42 +241,45 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." + ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -603,11 +603,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -636,23 +632,16 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -771,9 +760,7 @@ def train_one_epoch( scaler.update() optimizer.zero_grad() except: # noqa - display_and_save_batch( - batch, params=params, graph_compiler=graph_compiler - ) + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) raise if params.print_diagnostics and batch_idx == 5: @@ -829,9 +816,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -939,7 +924,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -1104,9 +1089,7 @@ def scan_pessimistic_batches_for_oom( f"Failing criterion: {criterion} " f"(={crit_values[criterion]}) ..." ) - display_and_save_batch( - batch, params=params, graph_compiler=graph_compiler - ) + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) raise diff --git a/egs/aishell4/ASR/local/compute_fbank_aishell4.py b/egs/aishell4/ASR/local/compute_fbank_aishell4.py index 3f50d9e3e..400c406f0 100755 --- a/egs/aishell4/ASR/local/compute_fbank_aishell4.py +++ b/egs/aishell4/ASR/local/compute_fbank_aishell4.py @@ -85,9 +85,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80): ) if "train" in partition: cut_set = ( - cut_set - + cut_set.perturb_speed(0.9) - + cut_set.perturb_speed(1.1) + cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) ) cut_set = cut_set.compute_and_store_features( extractor=extractor, @@ -120,9 +118,7 @@ def get_args(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/aishell4/ASR/local/prepare_char.py b/egs/aishell4/ASR/local/prepare_char.py index d9e47d17a..6b440dfb3 100755 --- a/egs/aishell4/ASR/local/prepare_char.py +++ b/egs/aishell4/ASR/local/prepare_char.py @@ -86,9 +86,7 @@ def lexicon_to_fst_no_sil( cur_state = loop_state word = word2id[word] - pieces = [ - token2id[i] if i in token2id else token2id[""] for i in pieces - ] + pieces = [token2id[i] if i in token2id else token2id[""] for i in pieces] for i in range(len(pieces) - 1): w = word if i == 0 else eps @@ -142,9 +140,7 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool: return False -def generate_lexicon( - token_sym_table: Dict[str, int], words: List[str] -) -> Lexicon: +def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon: """Generate a lexicon from a word list and token_sym_table. Args: diff --git a/egs/aishell4/ASR/local/prepare_lang.py b/egs/aishell4/ASR/local/prepare_lang.py index e5ae89ec4..c8cf9b881 100755 --- a/egs/aishell4/ASR/local/prepare_lang.py +++ b/egs/aishell4/ASR/local/prepare_lang.py @@ -317,9 +317,7 @@ def lexicon_to_fst( def get_args(): parser = argparse.ArgumentParser() - parser.add_argument( - "--lang-dir", type=str, help="The lang dir, data/lang_phone" - ) + parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone") return parser.parse_args() diff --git a/egs/aishell4/ASR/local/test_prepare_lang.py b/egs/aishell4/ASR/local/test_prepare_lang.py index d4cf62bba..74e025ad7 100755 --- a/egs/aishell4/ASR/local/test_prepare_lang.py +++ b/egs/aishell4/ASR/local/test_prepare_lang.py @@ -88,9 +88,7 @@ def test_read_lexicon(filename: str): fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa.draw("L.pdf", title="L") - fsa_disambig = lexicon_to_fst( - lexicon_disambig, phone2id=phone2id, word2id=word2id - ) + fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id) fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt") fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa_disambig.draw("L_disambig.pdf", title="L_disambig") diff --git a/egs/aishell4/ASR/local/text2token.py b/egs/aishell4/ASR/local/text2token.py index 71be2a613..2be639b7a 100755 --- a/egs/aishell4/ASR/local/text2token.py +++ b/egs/aishell4/ASR/local/text2token.py @@ -50,15 +50,15 @@ def get_parser(): "-n", default=1, type=int, - help="number of characters to split, i.e., \ - aabb -> a a b b with -n 1 and aa bb with -n 2", + help=( + "number of characters to split, i.e., aabb -> a a b" + " b with -n 1 and aa bb with -n 2" + ), ) parser.add_argument( "--skip-ncols", "-s", default=0, type=int, help="skip first n columns" ) - parser.add_argument( - "--space", default="", type=str, help="space symbol" - ) + parser.add_argument("--space", default="", type=str, help="space symbol") parser.add_argument( "--non-lang-syms", "-l", @@ -66,9 +66,7 @@ def get_parser(): type=str, help="list of non-linguistic symobles, e.g., etc.", ) - parser.add_argument( - "text", type=str, default=False, nargs="?", help="input text" - ) + parser.add_argument("text", type=str, default=False, nargs="?", help="input text") parser.add_argument( "--trans_type", "-t", @@ -108,8 +106,7 @@ def token2id( if token_type == "lazy_pinyin": text = lazy_pinyin(chars_list) sub_ids = [ - token_table[txt] if txt in token_table else oov_id - for txt in text + token_table[txt] if txt in token_table else oov_id for txt in text ] ids.append(sub_ids) else: # token_type = "pinyin" @@ -135,9 +132,7 @@ def main(): if args.text: f = codecs.open(args.text, encoding="utf-8") else: - f = codecs.getreader("utf-8")( - sys.stdin if is_python2 else sys.stdin.buffer - ) + f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer) sys.stdout = codecs.getwriter("utf-8")( sys.stdout if is_python2 else sys.stdout.buffer diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py index 7aa53ddda..84c7f0443 100644 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py @@ -74,10 +74,12 @@ class Aishell4AsrDataModule: def add_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group( title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", + description=( + "These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc." + ), ) group.add_argument( @@ -91,66 +93,81 @@ class Aishell4AsrDataModule: "--max-duration", type=int, default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", + help=( + "Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM." + ), ) group.add_argument( "--bucketing-sampler", type=str2bool, default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", + help=( + "When enabled, the batches will come from buckets of " + "similar duration (saves padding frames)." + ), ) group.add_argument( "--num-buckets", type=int, default=300, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", + help=( + "The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets)." + ), ) group.add_argument( "--concatenate-cuts", type=str2bool, default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", + help=( + "When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding." + ), ) group.add_argument( "--duration-factor", type=float, default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", + help=( + "Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch." + ), ) group.add_argument( "--gap", type=float, default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", + help=( + "The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used." + ), ) group.add_argument( "--on-the-fly-feats", type=str2bool, default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", + help=( + "When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available." + ), ) group.add_argument( "--shuffle", type=str2bool, default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", + help=( + "When enabled (=default), the examples will be shuffled for each epoch." + ), ) group.add_argument( @@ -164,17 +181,18 @@ class Aishell4AsrDataModule: "--return-cuts", type=str2bool, default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", + help=( + "When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it." + ), ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that " - "collect the batches.", + help="The number of training dataloader workers that collect the batches.", ) group.add_argument( @@ -188,18 +206,22 @@ class Aishell4AsrDataModule: "--spec-aug-time-warp-factor", type=int, default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", + help=( + "Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp." + ), ) group.add_argument( "--enable-musan", type=str2bool, default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. ", + help=( + "When enabled, select noise from MUSAN and mix it" + "with training dataset. " + ), ) group.add_argument( @@ -222,24 +244,20 @@ class Aishell4AsrDataModule: The state dict for the training sampler. """ logging.info("About to get Musan cuts") - cuts_musan = load_manifest( - self.args.manifest_dir / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms = [] if self.args.enable_musan: logging.info("Enable MUSAN") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") if self.args.concatenate_cuts: logging.info( - f"Using cut concatenation with duration factor " + "Using cut concatenation with duration factor " f"{self.args.duration_factor} and gap {self.args.gap}." ) # Cut concatenation should be the first transform in the list, @@ -254,9 +272,7 @@ class Aishell4AsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -300,9 +316,7 @@ class Aishell4AsrDataModule: # Drop feats to be on the safe side. train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -359,9 +373,7 @@ class Aishell4AsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py b/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py index 14e44c7d9..616a88937 100755 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py @@ -117,20 +117,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=False, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -201,8 +205,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -260,9 +263,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -277,10 +278,7 @@ def decode_one_batch( ) for i in range(encoder_out.size(0)): hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -326,11 +324,7 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps + f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -401,9 +395,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -436,8 +428,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -480,9 +471,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -510,13 +499,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -543,13 +531,12 @@ def main(): ) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -578,7 +565,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/export.py b/egs/aishell4/ASR/pruned_transducer_stateless5/export.py index 993341131..3c580ff7b 100755 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/export.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/export.py @@ -89,20 +89,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=False, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -136,8 +140,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) add_model_arguments(parser) @@ -169,13 +172,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -202,13 +204,12 @@ def main(): ) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -237,7 +238,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -276,9 +277,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py b/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py index 1fa893637..8151442af 100755 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py @@ -94,9 +94,11 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( @@ -122,10 +124,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -172,8 +176,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -204,10 +207,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -266,15 +268,11 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lengths - ) + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) num_waves = encoder_out.size(0) hyps = [] @@ -306,10 +304,7 @@ def main(): for i in range(encoder_out.size(0)): hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -350,9 +345,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/train.py b/egs/aishell4/ASR/pruned_transducer_stateless5/train.py index 0a48b9059..aacd23ecd 100755 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/train.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/train.py @@ -85,9 +85,7 @@ from icefall.env import get_env_info from icefall.lexicon import Lexicon from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -213,8 +211,7 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need " - "to be changed.", + help="The initial learning rate. This value should not need to be changed.", ) parser.add_argument( @@ -237,42 +234,45 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." + ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -599,11 +599,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -633,22 +629,15 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -827,9 +816,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -937,7 +924,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py b/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py index af926aa53..96115a230 100755 --- a/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py +++ b/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py @@ -84,9 +84,7 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80): ) if "train" in partition: cut_set = ( - cut_set - + cut_set.perturb_speed(0.9) - + cut_set.perturb_speed(1.1) + cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) ) cur_num_jobs = num_jobs if ex is None else 80 cur_num_jobs = min(cur_num_jobs, len(cut_set)) @@ -121,9 +119,7 @@ def get_args(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/alimeeting/ASR/local/prepare_char.py b/egs/alimeeting/ASR/local/prepare_char.py index d9e47d17a..6b440dfb3 100755 --- a/egs/alimeeting/ASR/local/prepare_char.py +++ b/egs/alimeeting/ASR/local/prepare_char.py @@ -86,9 +86,7 @@ def lexicon_to_fst_no_sil( cur_state = loop_state word = word2id[word] - pieces = [ - token2id[i] if i in token2id else token2id[""] for i in pieces - ] + pieces = [token2id[i] if i in token2id else token2id[""] for i in pieces] for i in range(len(pieces) - 1): w = word if i == 0 else eps @@ -142,9 +140,7 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool: return False -def generate_lexicon( - token_sym_table: Dict[str, int], words: List[str] -) -> Lexicon: +def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon: """Generate a lexicon from a word list and token_sym_table. Args: diff --git a/egs/alimeeting/ASR/local/prepare_lang.py b/egs/alimeeting/ASR/local/prepare_lang.py index e5ae89ec4..c8cf9b881 100755 --- a/egs/alimeeting/ASR/local/prepare_lang.py +++ b/egs/alimeeting/ASR/local/prepare_lang.py @@ -317,9 +317,7 @@ def lexicon_to_fst( def get_args(): parser = argparse.ArgumentParser() - parser.add_argument( - "--lang-dir", type=str, help="The lang dir, data/lang_phone" - ) + parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone") return parser.parse_args() diff --git a/egs/alimeeting/ASR/local/test_prepare_lang.py b/egs/alimeeting/ASR/local/test_prepare_lang.py index d4cf62bba..74e025ad7 100755 --- a/egs/alimeeting/ASR/local/test_prepare_lang.py +++ b/egs/alimeeting/ASR/local/test_prepare_lang.py @@ -88,9 +88,7 @@ def test_read_lexicon(filename: str): fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa.draw("L.pdf", title="L") - fsa_disambig = lexicon_to_fst( - lexicon_disambig, phone2id=phone2id, word2id=word2id - ) + fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id) fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt") fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa_disambig.draw("L_disambig.pdf", title="L_disambig") diff --git a/egs/alimeeting/ASR/local/text2segments.py b/egs/alimeeting/ASR/local/text2segments.py index 7c1019aa8..27b904fc8 100644 --- a/egs/alimeeting/ASR/local/text2segments.py +++ b/egs/alimeeting/ASR/local/text2segments.py @@ -30,8 +30,8 @@ with word segmenting: import argparse -import paddle import jieba +import paddle from tqdm import tqdm paddle.enable_static() diff --git a/egs/alimeeting/ASR/local/text2token.py b/egs/alimeeting/ASR/local/text2token.py index 71be2a613..2be639b7a 100755 --- a/egs/alimeeting/ASR/local/text2token.py +++ b/egs/alimeeting/ASR/local/text2token.py @@ -50,15 +50,15 @@ def get_parser(): "-n", default=1, type=int, - help="number of characters to split, i.e., \ - aabb -> a a b b with -n 1 and aa bb with -n 2", + help=( + "number of characters to split, i.e., aabb -> a a b" + " b with -n 1 and aa bb with -n 2" + ), ) parser.add_argument( "--skip-ncols", "-s", default=0, type=int, help="skip first n columns" ) - parser.add_argument( - "--space", default="", type=str, help="space symbol" - ) + parser.add_argument("--space", default="", type=str, help="space symbol") parser.add_argument( "--non-lang-syms", "-l", @@ -66,9 +66,7 @@ def get_parser(): type=str, help="list of non-linguistic symobles, e.g., etc.", ) - parser.add_argument( - "text", type=str, default=False, nargs="?", help="input text" - ) + parser.add_argument("text", type=str, default=False, nargs="?", help="input text") parser.add_argument( "--trans_type", "-t", @@ -108,8 +106,7 @@ def token2id( if token_type == "lazy_pinyin": text = lazy_pinyin(chars_list) sub_ids = [ - token_table[txt] if txt in token_table else oov_id - for txt in text + token_table[txt] if txt in token_table else oov_id for txt in text ] ids.append(sub_ids) else: # token_type = "pinyin" @@ -135,9 +132,7 @@ def main(): if args.text: f = codecs.open(args.text, encoding="utf-8") else: - f = codecs.getreader("utf-8")( - sys.stdin if is_python2 else sys.stdin.buffer - ) + f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer) sys.stdout = codecs.getwriter("utf-8")( sys.stdout if is_python2 else sys.stdout.buffer diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py index bf6faad7a..d0467a29e 100644 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -81,10 +81,12 @@ class AlimeetingAsrDataModule: def add_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group( title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", + description=( + "These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc." + ), ) group.add_argument( "--manifest-dir", @@ -96,75 +98,91 @@ class AlimeetingAsrDataModule: "--max-duration", type=int, default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", + help=( + "Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM." + ), ) group.add_argument( "--bucketing-sampler", type=str2bool, default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", + help=( + "When enabled, the batches will come from buckets of " + "similar duration (saves padding frames)." + ), ) group.add_argument( "--num-buckets", type=int, default=300, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", + help=( + "The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets)." + ), ) group.add_argument( "--concatenate-cuts", type=str2bool, default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", + help=( + "When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding." + ), ) group.add_argument( "--duration-factor", type=float, default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", + help=( + "Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch." + ), ) group.add_argument( "--gap", type=float, default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", + help=( + "The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used." + ), ) group.add_argument( "--on-the-fly-feats", type=str2bool, default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", + help=( + "When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available." + ), ) group.add_argument( "--shuffle", type=str2bool, default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", + help=( + "When enabled (=default), the examples will be shuffled for each epoch." + ), ) group.add_argument( "--return-cuts", type=str2bool, default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", + help=( + "When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it." + ), ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that " - "collect the batches.", + help="The number of training dataloader workers that collect the batches.", ) group.add_argument( @@ -178,18 +196,22 @@ class AlimeetingAsrDataModule: "--spec-aug-time-warp-factor", type=int, default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", + help=( + "Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp." + ), ) group.add_argument( "--enable-musan", type=str2bool, default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. ", + help=( + "When enabled, select noise from MUSAN and mix it" + "with training dataset. " + ), ) def train_dataloaders( @@ -205,24 +227,20 @@ class AlimeetingAsrDataModule: The state dict for the training sampler. """ logging.info("About to get Musan cuts") - cuts_musan = load_manifest( - self.args.manifest_dir / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms = [] if self.args.enable_musan: logging.info("Enable MUSAN") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") if self.args.concatenate_cuts: logging.info( - f"Using cut concatenation with duration factor " + "Using cut concatenation with duration factor " f"{self.args.duration_factor} and gap {self.args.gap}." ) # Cut concatenation should be the first transform in the list, @@ -237,9 +255,7 @@ class AlimeetingAsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -282,9 +298,7 @@ class AlimeetingAsrDataModule: # Drop feats to be on the safe side. train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -341,9 +355,7 @@ class AlimeetingAsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py index 6358fe970..ffaca1021 100755 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py @@ -70,11 +70,7 @@ from beam_search import ( from lhotse.cut import Cut from train import get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, @@ -93,25 +89,30 @@ def get_parser(): "--epoch", type=int, default=28, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--batch", type=int, default=None, - help="It specifies the batch checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the batch checkpoint to use for decoding." + "Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -193,8 +194,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -249,9 +249,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -266,10 +264,7 @@ def decode_one_batch( ) for i in range(encoder_out.size(0)): hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -315,11 +310,7 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps + f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -390,9 +381,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -425,8 +414,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -563,8 +551,7 @@ def main(): ) dev_shards = [ - str(path) - for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar"))) + str(path) for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar"))) ] cuts_dev_webdataset = CutSet.from_webdataset( dev_shards, @@ -574,8 +561,7 @@ def main(): ) test_shards = [ - str(path) - for path in sorted(glob.glob(os.path.join(test, "shared-*.tar"))) + str(path) for path in sorted(glob.glob(os.path.join(test, "shared-*.tar"))) ] cuts_test_webdataset = CutSet.from_webdataset( test_shards, @@ -588,9 +574,7 @@ def main(): return 1.0 <= c.duration cuts_dev_webdataset = cuts_dev_webdataset.filter(remove_short_and_long_utt) - cuts_test_webdataset = cuts_test_webdataset.filter( - remove_short_and_long_utt - ) + cuts_test_webdataset = cuts_test_webdataset.filter(remove_short_and_long_utt) dev_dl = alimeeting.valid_dataloaders(cuts_dev_webdataset) test_dl = alimeeting.test_dataloaders(cuts_test_webdataset) diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py index 8beec1b8a..482e52d83 100644 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py +++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py @@ -62,17 +62,20 @@ def get_parser(): "--epoch", type=int, default=28, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -103,8 +106,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) return parser @@ -173,9 +175,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py index 93b1e1f57..afbf0960a 100644 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py +++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py @@ -85,9 +85,11 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( @@ -112,10 +114,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -162,8 +166,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -193,10 +196,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -257,9 +259,7 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -284,10 +284,7 @@ def main(): ) for i in range(encoder_out.size(0)): hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -339,9 +336,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py index 81a0ede7f..158ea9c1b 100644 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py +++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py @@ -81,9 +81,7 @@ from icefall.env import get_env_info from icefall.lexicon import Lexicon from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] os.environ["CUDA_LAUNCH_BLOCKING"] = "1" @@ -187,42 +185,45 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." + ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -542,22 +543,15 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -711,9 +705,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -813,7 +805,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/csj/ASR/.gitignore b/egs/csj/ASR/.gitignore index 5d965832e..cd0e20c4c 100644 --- a/egs/csj/ASR/.gitignore +++ b/egs/csj/ASR/.gitignore @@ -5,4 +5,4 @@ notify_tg.py finetune_* misc.ini .vscode/* -offline/* \ No newline at end of file +offline/* diff --git a/egs/csj/ASR/local/compute_fbank_csj.py b/egs/csj/ASR/local/compute_fbank_csj.py index 994dedbdd..036ce925f 100644 --- a/egs/csj/ASR/local/compute_fbank_csj.py +++ b/egs/csj/ASR/local/compute_fbank_csj.py @@ -25,15 +25,10 @@ from random import Random from typing import List, Tuple import torch -from lhotse import ( +from lhotse import ( # fmt: off; See the following for why LilcomChunkyWriter is preferred; https://github.com/k2-fsa/icefall/pull/404; https://github.com/lhotse-speech/lhotse/pull/527; fmt: on CutSet, Fbank, FbankConfig, - # fmt: off - # See the following for why LilcomChunkyWriter is preferred - # https://github.com/k2-fsa/icefall/pull/404 - # https://github.com/lhotse-speech/lhotse/pull/527 - # fmt: on LilcomChunkyWriter, RecordingSet, SupervisionSet, @@ -81,17 +76,13 @@ def make_cutset_blueprints( cut_sets.append((f"eval{i}", cut_set)) # Create train and valid cuts - logging.info( - "Loading, trimming, and shuffling the remaining core+noncore cuts." - ) + logging.info("Loading, trimming, and shuffling the remaining core+noncore cuts.") recording_set = RecordingSet.from_file( manifest_dir / "csj_recordings_core.jsonl.gz" ) + RecordingSet.from_file(manifest_dir / "csj_recordings_noncore.jsonl.gz") supervision_set = SupervisionSet.from_file( manifest_dir / "csj_supervisions_core.jsonl.gz" - ) + SupervisionSet.from_file( - manifest_dir / "csj_supervisions_noncore.jsonl.gz" - ) + ) + SupervisionSet.from_file(manifest_dir / "csj_supervisions_noncore.jsonl.gz") cut_set = CutSet.from_manifests( recordings=recording_set, @@ -101,15 +92,12 @@ def make_cutset_blueprints( cut_set = cut_set.shuffle(Random(RNG_SEED)) logging.info( - "Creating valid and train cuts from core and noncore," - f"split at {split}." + f"Creating valid and train cuts from core and noncore,split at {split}." ) valid_set = CutSet.from_cuts(islice(cut_set, 0, split)) train_set = CutSet.from_cuts(islice(cut_set, split, None)) - train_set = ( - train_set + train_set.perturb_speed(0.9) + train_set.perturb_speed(1.1) - ) + train_set = train_set + train_set.perturb_speed(0.9) + train_set.perturb_speed(1.1) cut_sets.extend([("valid", valid_set), ("train", train_set)]) @@ -122,15 +110,9 @@ def get_args(): formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) - parser.add_argument( - "--manifest-dir", type=Path, help="Path to save manifests" - ) - parser.add_argument( - "--fbank-dir", type=Path, help="Path to save fbank features" - ) - parser.add_argument( - "--split", type=int, default=4000, help="Split at this index" - ) + parser.add_argument("--manifest-dir", type=Path, help="Path to save manifests") + parser.add_argument("--fbank-dir", type=Path, help="Path to save fbank features") + parser.add_argument("--split", type=int, default=4000, help="Split at this index") return parser.parse_args() @@ -141,9 +123,7 @@ def main(): extractor = Fbank(FbankConfig(num_mel_bins=80)) num_jobs = min(16, os.cpu_count()) - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/csj/ASR/local/compute_fbank_musan.py b/egs/csj/ASR/local/compute_fbank_musan.py index 44a33c4eb..f60e62c85 100644 --- a/egs/csj/ASR/local/compute_fbank_musan.py +++ b/egs/csj/ASR/local/compute_fbank_musan.py @@ -26,7 +26,6 @@ from lhotse.recipes.utils import read_manifests_if_cached from icefall.utils import get_executor - ARGPARSE_DESCRIPTION = """ This file computes fbank features of the musan dataset. It looks for manifests in the directory data/manifests. @@ -84,9 +83,7 @@ def compute_fbank_musan(manifest_dir: Path, fbank_dir: Path): # create chunks of Musan with duration 5 - 10 seconds musan_cuts = ( CutSet.from_manifests( - recordings=combine( - part["recordings"] for part in manifests.values() - ) + recordings=combine(part["recordings"] for part in manifests.values()) ) .cut_into_windows(10.0) .filter(lambda c: c.duration > 5) @@ -107,21 +104,15 @@ def get_args(): formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) - parser.add_argument( - "--manifest-dir", type=Path, help="Path to save manifests" - ) - parser.add_argument( - "--fbank-dir", type=Path, help="Path to save fbank features" - ) + parser.add_argument("--manifest-dir", type=Path, help="Path to save manifests") + parser.add_argument("--fbank-dir", type=Path, help="Path to save fbank features") return parser.parse_args() if __name__ == "__main__": args = get_args() - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) compute_fbank_musan(args.manifest_dir, args.fbank_dir) diff --git a/egs/csj/ASR/local/conf/disfluent.ini b/egs/csj/ASR/local/conf/disfluent.ini index eb70673de..c987e72c5 100644 --- a/egs/csj/ASR/local/conf/disfluent.ini +++ b/egs/csj/ASR/local/conf/disfluent.ini @@ -1,17 +1,17 @@ ; # This section is ignored if this file is not supplied as the first config file to -; # lhotse prepare csj +; # lhotse prepare csj [SEGMENTS] ; # Allowed period of nonverbal noise. If exceeded, a new segment is created. gap = 0.5 ; # Maximum length of segment (s). maxlen = 10 -; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. +; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. minlen = 0.02 -; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. -; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. -; # If you intend to use a multicharacter string for gap_sym, remember to register the -; # multicharacter string as part of userdef-string in prepare_lang_char.py. -gap_sym = +; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. +; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. +; # If you intend to use a multicharacter string for gap_sym, remember to register the +; # multicharacter string as part of userdef-string in prepare_lang_char.py. +gap_sym = [CONSTANTS] ; # Name of this mode @@ -115,59 +115,59 @@ B^ = 0 ; # 0 to remain, 1 to delete ; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)' 笑 = 0 -; # Example: 'コク(笑 サイ+(D オン))', +; # Example: 'コク(笑 サイ+(D オン))', 笑^ = 0 ; # 泣きながら発話 ; # 0 to remain, 1 to delete -; # Example: '(泣 ドンナニ)' +; # Example: '(泣 ドンナニ)' 泣 = 0 泣^ = 0 ; # 咳をしながら発話 ; # 0 to remain, 1 to delete -; # Example: 'シャ(咳 リン) ノ' +; # Example: 'シャ(咳 リン) ノ' 咳 = 0 ; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)' 咳^ = 0 ; # ささやき声や独り言などの小さな声 ; # 0 to remain, 1 to delete -; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' +; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' L = 0 ; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト' L^ = 0 [REPLACEMENTS] ; # ボーカルフライなどで母音が同定できない場合 - = + = ; # 「うん/うーん/ふーん」の音の特定が困難な場合 - = + = ; # 非語彙的な母音の引き延ばし - = + = ; # 非語彙的な子音の引き延ばし - = + = ; # 言語音と独立に講演者の笑いが生じている場合 -<笑> = +<笑> = ; # 言語音と独立に講演者の咳が生じている場合 -<咳> = +<咳> = ; # 言語音と独立に講演者の息が生じている場合 -<息> = +<息> = ; # 講演者の泣き声 -<泣> = +<泣> = ; # 聴衆(司会者なども含む)の発話 -<フロア発話> = +<フロア発話> = ; # 聴衆の笑い -<フロア笑> = +<フロア笑> = ; # 聴衆の拍手 -<拍手> = +<拍手> = ; # 講演者が発表中に用いたデモンストレーションの音声 -<デモ> = +<デモ> = ; # 学会講演に発表時間を知らせるためにならすベルの音 -<ベル> = +<ベル> = ; # 転記単位全体が再度読み直された場合 -<朗読間違い> = +<朗読間違い> = ; # 上記以外の音で特に目立った音 -<雑音> = +<雑音> = ; # 0.2秒以上のポーズ -

= +

= ; # Redacted information, for R ; # It is \x00D7 multiplication sign, not your normal 'x' × = × @@ -318,4 +318,3 @@ spk_id = 2 ャ = ǐa ュ = ǐu ョ = ǐo - diff --git a/egs/csj/ASR/local/conf/fluent.ini b/egs/csj/ASR/local/conf/fluent.ini index 5d22f9eb8..f7f27f5bc 100644 --- a/egs/csj/ASR/local/conf/fluent.ini +++ b/egs/csj/ASR/local/conf/fluent.ini @@ -1,17 +1,17 @@ ; # This section is ignored if this file is not supplied as the first config file to -; # lhotse prepare csj +; # lhotse prepare csj [SEGMENTS] ; # Allowed period of nonverbal noise. If exceeded, a new segment is created. gap = 0.5 ; # Maximum length of segment (s). maxlen = 10 -; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. +; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. minlen = 0.02 -; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. -; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. -; # If you intend to use a multicharacter string for gap_sym, remember to register the -; # multicharacter string as part of userdef-string in prepare_lang_char.py. -gap_sym = +; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. +; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. +; # If you intend to use a multicharacter string for gap_sym, remember to register the +; # multicharacter string as part of userdef-string in prepare_lang_char.py. +gap_sym = [CONSTANTS] ; # Name of this mode @@ -115,59 +115,59 @@ B^ = 0 ; # 0 to remain, 1 to delete ; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)' 笑 = 0 -; # Example: 'コク(笑 サイ+(D オン))', +; # Example: 'コク(笑 サイ+(D オン))', 笑^ = 0 ; # 泣きながら発話 ; # 0 to remain, 1 to delete -; # Example: '(泣 ドンナニ)' +; # Example: '(泣 ドンナニ)' 泣 = 0 泣^ = 0 ; # 咳をしながら発話 ; # 0 to remain, 1 to delete -; # Example: 'シャ(咳 リン) ノ' +; # Example: 'シャ(咳 リン) ノ' 咳 = 0 ; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)' 咳^ = 0 ; # ささやき声や独り言などの小さな声 ; # 0 to remain, 1 to delete -; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' +; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' L = 0 ; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト' L^ = 0 [REPLACEMENTS] ; # ボーカルフライなどで母音が同定できない場合 - = + = ; # 「うん/うーん/ふーん」の音の特定が困難な場合 - = + = ; # 非語彙的な母音の引き延ばし - = + = ; # 非語彙的な子音の引き延ばし - = + = ; # 言語音と独立に講演者の笑いが生じている場合 -<笑> = +<笑> = ; # 言語音と独立に講演者の咳が生じている場合 -<咳> = +<咳> = ; # 言語音と独立に講演者の息が生じている場合 -<息> = +<息> = ; # 講演者の泣き声 -<泣> = +<泣> = ; # 聴衆(司会者なども含む)の発話 -<フロア発話> = +<フロア発話> = ; # 聴衆の笑い -<フロア笑> = +<フロア笑> = ; # 聴衆の拍手 -<拍手> = +<拍手> = ; # 講演者が発表中に用いたデモンストレーションの音声 -<デモ> = +<デモ> = ; # 学会講演に発表時間を知らせるためにならすベルの音 -<ベル> = +<ベル> = ; # 転記単位全体が再度読み直された場合 -<朗読間違い> = +<朗読間違い> = ; # 上記以外の音で特に目立った音 -<雑音> = +<雑音> = ; # 0.2秒以上のポーズ -

= +

= ; # Redacted information, for R ; # It is \x00D7 multiplication sign, not your normal 'x' × = × @@ -318,4 +318,3 @@ spk_id = 2 ャ = ǐa ュ = ǐu ョ = ǐo - diff --git a/egs/csj/ASR/local/conf/number.ini b/egs/csj/ASR/local/conf/number.ini index 2613c3409..cf9038f62 100644 --- a/egs/csj/ASR/local/conf/number.ini +++ b/egs/csj/ASR/local/conf/number.ini @@ -1,17 +1,17 @@ ; # This section is ignored if this file is not supplied as the first config file to -; # lhotse prepare csj +; # lhotse prepare csj [SEGMENTS] ; # Allowed period of nonverbal noise. If exceeded, a new segment is created. gap = 0.5 ; # Maximum length of segment (s). maxlen = 10 -; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. +; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. minlen = 0.02 -; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. -; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. -; # If you intend to use a multicharacter string for gap_sym, remember to register the -; # multicharacter string as part of userdef-string in prepare_lang_char.py. -gap_sym = +; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. +; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. +; # If you intend to use a multicharacter string for gap_sym, remember to register the +; # multicharacter string as part of userdef-string in prepare_lang_char.py. +gap_sym = [CONSTANTS] ; # Name of this mode @@ -115,59 +115,59 @@ B^ = 0 ; # 0 to remain, 1 to delete ; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)' 笑 = 0 -; # Example: 'コク(笑 サイ+(D オン))', +; # Example: 'コク(笑 サイ+(D オン))', 笑^ = 0 ; # 泣きながら発話 ; # 0 to remain, 1 to delete -; # Example: '(泣 ドンナニ)' +; # Example: '(泣 ドンナニ)' 泣 = 0 泣^ = 0 ; # 咳をしながら発話 ; # 0 to remain, 1 to delete -; # Example: 'シャ(咳 リン) ノ' +; # Example: 'シャ(咳 リン) ノ' 咳 = 0 ; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)' 咳^ = 0 ; # ささやき声や独り言などの小さな声 ; # 0 to remain, 1 to delete -; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' +; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' L = 0 ; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト' L^ = 0 [REPLACEMENTS] ; # ボーカルフライなどで母音が同定できない場合 - = + = ; # 「うん/うーん/ふーん」の音の特定が困難な場合 - = + = ; # 非語彙的な母音の引き延ばし - = + = ; # 非語彙的な子音の引き延ばし - = + = ; # 言語音と独立に講演者の笑いが生じている場合 -<笑> = +<笑> = ; # 言語音と独立に講演者の咳が生じている場合 -<咳> = +<咳> = ; # 言語音と独立に講演者の息が生じている場合 -<息> = +<息> = ; # 講演者の泣き声 -<泣> = +<泣> = ; # 聴衆(司会者なども含む)の発話 -<フロア発話> = +<フロア発話> = ; # 聴衆の笑い -<フロア笑> = +<フロア笑> = ; # 聴衆の拍手 -<拍手> = +<拍手> = ; # 講演者が発表中に用いたデモンストレーションの音声 -<デモ> = +<デモ> = ; # 学会講演に発表時間を知らせるためにならすベルの音 -<ベル> = +<ベル> = ; # 転記単位全体が再度読み直された場合 -<朗読間違い> = +<朗読間違い> = ; # 上記以外の音で特に目立った音 -<雑音> = +<雑音> = ; # 0.2秒以上のポーズ -

= +

= ; # Redacted information, for R ; # It is \x00D7 multiplication sign, not your normal 'x' × = × @@ -318,4 +318,3 @@ spk_id = 2 ャ = ǐa ュ = ǐu ョ = ǐo - diff --git a/egs/csj/ASR/local/conf/symbol.ini b/egs/csj/ASR/local/conf/symbol.ini index 8ba451dd5..f9801284b 100644 --- a/egs/csj/ASR/local/conf/symbol.ini +++ b/egs/csj/ASR/local/conf/symbol.ini @@ -1,17 +1,17 @@ ; # This section is ignored if this file is not supplied as the first config file to -; # lhotse prepare csj +; # lhotse prepare csj [SEGMENTS] ; # Allowed period of nonverbal noise. If exceeded, a new segment is created. gap = 0.5 ; # Maximum length of segment (s). maxlen = 10 -; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. +; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. minlen = 0.02 -; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. -; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. -; # If you intend to use a multicharacter string for gap_sym, remember to register the -; # multicharacter string as part of userdef-string in prepare_lang_char.py. -gap_sym = +; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. +; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. +; # If you intend to use a multicharacter string for gap_sym, remember to register the +; # multicharacter string as part of userdef-string in prepare_lang_char.py. +gap_sym = [CONSTANTS] ; # Name of this mode @@ -116,59 +116,59 @@ B^ = 0 ; # 0 to remain, 1 to delete ; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)' 笑 = 0 -; # Example: 'コク(笑 サイ+(D オン))', +; # Example: 'コク(笑 サイ+(D オン))', 笑^ = 0 ; # 泣きながら発話 ; # 0 to remain, 1 to delete -; # Example: '(泣 ドンナニ)' +; # Example: '(泣 ドンナニ)' 泣 = 0 泣^ = 0 ; # 咳をしながら発話 ; # 0 to remain, 1 to delete -; # Example: 'シャ(咳 リン) ノ' +; # Example: 'シャ(咳 リン) ノ' 咳 = 0 ; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)' 咳^ = 0 ; # ささやき声や独り言などの小さな声 ; # 0 to remain, 1 to delete -; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' +; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' L = 0 ; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト' L^ = 0 [REPLACEMENTS] ; # ボーカルフライなどで母音が同定できない場合 - = + = ; # 「うん/うーん/ふーん」の音の特定が困難な場合 - = + = ; # 非語彙的な母音の引き延ばし - = + = ; # 非語彙的な子音の引き延ばし - = + = ; # 言語音と独立に講演者の笑いが生じている場合 -<笑> = +<笑> = ; # 言語音と独立に講演者の咳が生じている場合 -<咳> = +<咳> = ; # 言語音と独立に講演者の息が生じている場合 -<息> = +<息> = ; # 講演者の泣き声 -<泣> = +<泣> = ; # 聴衆(司会者なども含む)の発話 -<フロア発話> = +<フロア発話> = ; # 聴衆の笑い -<フロア笑> = +<フロア笑> = ; # 聴衆の拍手 -<拍手> = +<拍手> = ; # 講演者が発表中に用いたデモンストレーションの音声 -<デモ> = +<デモ> = ; # 学会講演に発表時間を知らせるためにならすベルの音 -<ベル> = +<ベル> = ; # 転記単位全体が再度読み直された場合 -<朗読間違い> = +<朗読間違い> = ; # 上記以外の音で特に目立った音 -<雑音> = +<雑音> = ; # 0.2秒以上のポーズ -

= +

= ; # Redacted information, for R ; # It is \x00D7 multiplication sign, not your normal 'x' × = × @@ -319,4 +319,3 @@ spk_id = 2 ャ = ǐa ュ = ǐu ョ = ǐo - diff --git a/egs/csj/ASR/local/display_manifest_statistics.py b/egs/csj/ASR/local/display_manifest_statistics.py index c9de21073..c043cf853 100644 --- a/egs/csj/ASR/local/display_manifest_statistics.py +++ b/egs/csj/ASR/local/display_manifest_statistics.py @@ -37,9 +37,7 @@ def get_parser(): formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) - parser.add_argument( - "--manifest-dir", type=Path, help="Path to cutset manifests" - ) + parser.add_argument("--manifest-dir", type=Path, help="Path to cutset manifests") return parser.parse_args() diff --git a/egs/csj/ASR/local/prepare_lang_char.py b/egs/csj/ASR/local/prepare_lang_char.py index e4d996871..f0078421b 100644 --- a/egs/csj/ASR/local/prepare_lang_char.py +++ b/egs/csj/ASR/local/prepare_lang_char.py @@ -68,8 +68,7 @@ def get_args(): type=Path, default=None, help=( - "Name of lang dir. " - "If not set, this will default to lang_char_{trans-mode}" + "Name of lang dir. If not set, this will default to lang_char_{trans-mode}" ), ) @@ -87,9 +86,7 @@ def main(): args = get_args() logging.basicConfig( - format=( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] " "%(message)s" - ), + format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s", level=logging.INFO, ) @@ -111,8 +108,7 @@ def main(): words = set() logging.info( - f"Creating vocabulary from {args.train_cut.name}" - f" at {args.trans_mode} mode." + f"Creating vocabulary from {args.train_cut.name} at {args.trans_mode} mode." ) for cut in train_set: try: @@ -123,8 +119,7 @@ def main(): ) except KeyError: raise KeyError( - f"Could not find {args.trans_mode} in " - f"{cut.supervisions[0].custom}" + f"Could not find {args.trans_mode} in {cut.supervisions[0].custom}" ) for t in text.split(): if t in args.userdef_string: @@ -143,9 +138,7 @@ def main(): (args.lang_dir / "words_len").write_text(f"{len(words)}") - (args.lang_dir / "userdef_string").write_text( - "\n".join(args.userdef_string) - ) + (args.lang_dir / "userdef_string").write_text("\n".join(args.userdef_string)) (args.lang_dir / "trans_mode").write_text(args.trans_mode) logging.info("Done.") diff --git a/egs/csj/ASR/local/validate_manifest.py b/egs/csj/ASR/local/validate_manifest.py index 0c4c6c1ea..89448a49c 100644 --- a/egs/csj/ASR/local/validate_manifest.py +++ b/egs/csj/ASR/local/validate_manifest.py @@ -68,8 +68,7 @@ def validate_supervision_and_cut_time_bounds(c: Cut): if s.end > c.end: raise ValueError( - f"{c.id}: Supervision end time {s.end} is larger " - f"than cut end time {c.end}" + f"{c.id}: Supervision end time {s.end} is larger than cut end time {c.end}" ) @@ -89,9 +88,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py b/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py index d78e26240..c3e3e84bf 100644 --- a/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py +++ b/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py @@ -61,10 +61,12 @@ class GigaSpeechAsrDataModule: def add_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group( title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", + description=( + "These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc." + ), ) group.add_argument( "--manifest-dir", @@ -76,75 +78,91 @@ class GigaSpeechAsrDataModule: "--max-duration", type=int, default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", + help=( + "Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM." + ), ) group.add_argument( "--bucketing-sampler", type=str2bool, default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", + help=( + "When enabled, the batches will come from buckets of " + "similar duration (saves padding frames)." + ), ) group.add_argument( "--num-buckets", type=int, default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", + help=( + "The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets)." + ), ) group.add_argument( "--concatenate-cuts", type=str2bool, default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", + help=( + "When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding." + ), ) group.add_argument( "--duration-factor", type=float, default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", + help=( + "Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch." + ), ) group.add_argument( "--gap", type=float, default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", + help=( + "The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used." + ), ) group.add_argument( "--on-the-fly-feats", type=str2bool, default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", + help=( + "When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available." + ), ) group.add_argument( "--shuffle", type=str2bool, default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", + help=( + "When enabled (=default), the examples will be shuffled for each epoch." + ), ) group.add_argument( "--return-cuts", type=str2bool, default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", + help=( + "When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it." + ), ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that " - "collect the batches.", + help="The number of training dataloader workers that collect the batches.", ) group.add_argument( @@ -158,18 +176,22 @@ class GigaSpeechAsrDataModule: "--spec-aug-time-warp-factor", type=int, default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", + help=( + "Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp." + ), ) group.add_argument( "--enable-musan", type=str2bool, default=True, - help="When enabled, select noise from MUSAN and mix it " - "with training dataset. ", + help=( + "When enabled, select noise from MUSAN and mix it " + "with training dataset. " + ), ) # GigaSpeech specific arguments @@ -183,30 +205,25 @@ class GigaSpeechAsrDataModule: "--small-dev", type=str2bool, default=False, - help="Should we use only 1000 utterances for dev " - "(speeds up training)", + help="Should we use only 1000 utterances for dev (speeds up training)", ) def train_dataloaders(self, cuts_train: CutSet) -> DataLoader: logging.info("About to get Musan cuts") - cuts_musan = load_manifest( - self.args.manifest_dir / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms = [] if self.args.enable_musan: logging.info("Enable MUSAN") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") if self.args.concatenate_cuts: logging.info( - f"Using cut concatenation with duration factor " + "Using cut concatenation with duration factor " f"{self.args.duration_factor} and gap {self.args.gap}." ) # Cut concatenation should be the first transform in the list, @@ -221,9 +238,7 @@ class GigaSpeechAsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") input_transforms.append( SpecAugment( time_warp_factor=self.args.spec_aug_time_warp_factor, @@ -256,9 +271,7 @@ class GigaSpeechAsrDataModule: # Drop feats to be on the safe side. train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -304,9 +317,7 @@ class GigaSpeechAsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: @@ -362,9 +373,7 @@ class GigaSpeechAsrDataModule: @lru_cache() def dev_cuts(self) -> CutSet: logging.info("About to get dev cuts") - cuts_valid = load_manifest_lazy( - self.args.manifest_dir / "cuts_DEV.jsonl.gz" - ) + cuts_valid = load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz") if self.args.small_dev: return cuts_valid.subset(first=1000) else: diff --git a/egs/gigaspeech/ASR/conformer_ctc/conformer.py b/egs/gigaspeech/ASR/conformer_ctc/conformer.py index 6fac07f93..1153a814c 100644 --- a/egs/gigaspeech/ASR/conformer_ctc/conformer.py +++ b/egs/gigaspeech/ASR/conformer_ctc/conformer.py @@ -160,9 +160,7 @@ class ConformerEncoderLayer(nn.Module): use_conv_batchnorm: bool = False, ) -> None: super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), @@ -182,18 +180,14 @@ class ConformerEncoderLayer(nn.Module): d_model, cnn_module_kernel, use_batchnorm=use_conv_batchnorm ) - self.norm_ff_macaron = nn.LayerNorm( - d_model - ) # for the macaron style FNN module + self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.ff_scale = 0.5 self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm( - d_model - ) # for the final output of the block + self.norm_final = nn.LayerNorm(d_model) # for the final output of the block self.dropout = nn.Dropout(dropout) @@ -227,9 +221,7 @@ class ConformerEncoderLayer(nn.Module): residual = src if self.normalize_before: src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout( - self.feed_forward_macaron(src) - ) + src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) if not self.normalize_before: src = self.norm_ff_macaron(src) @@ -348,9 +340,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -366,9 +356,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -638,9 +626,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -708,33 +696,25 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor" + " instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -771,9 +751,7 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -785,9 +763,7 @@ class RelPositionMultiheadAttention(nn.Module): matrix_ac + matrix_bd ) * scaling # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -821,13 +797,9 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads diff --git a/egs/gigaspeech/ASR/conformer_ctc/decode.py b/egs/gigaspeech/ASR/conformer_ctc/decode.py index 9c1418baa..b38ae9c8c 100755 --- a/egs/gigaspeech/ASR/conformer_ctc/decode.py +++ b/egs/gigaspeech/ASR/conformer_ctc/decode.py @@ -62,16 +62,19 @@ def get_parser(): "--epoch", type=int, default=0, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=1, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -476,9 +479,7 @@ def decode_dataset( results[lm_scale].extend(this_batch) else: - assert ( - len(results) > 0 - ), "It should not decode to empty in the first batch!" + assert len(results) > 0, "It should not decode to empty in the first batch!" this_batch = [] hyp_words = [] for cut_id, ref_text in zip(cut_ids, texts): @@ -493,9 +494,7 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -528,9 +527,7 @@ def save_results( test_set_wers[key] = wer if enable_log: - logging.info( - "Wrote detailed error stats to {}".format(errs_filename) - ) + logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" @@ -705,9 +702,7 @@ def main(): eos_id=eos_id, ) - save_results( - params=params, test_set_name=test_set, results_dict=results_dict - ) + save_results(params=params, test_set_name=test_set, results_dict=results_dict) logging.info("Done!") diff --git a/egs/gigaspeech/ASR/conformer_ctc/gigaspeech_scoring.py b/egs/gigaspeech/ASR/conformer_ctc/gigaspeech_scoring.py index ef53b77f8..880aa76e2 100755 --- a/egs/gigaspeech/ASR/conformer_ctc/gigaspeech_scoring.py +++ b/egs/gigaspeech/ASR/conformer_ctc/gigaspeech_scoring.py @@ -73,8 +73,7 @@ def asr_text_post_processing(text: str) -> str: if __name__ == "__main__": parser = argparse.ArgumentParser( - description="This script evaluates GigaSpeech ASR result via" - "SCTK's tool sclite" + description="This script evaluates GigaSpeech ASR result viaSCTK's tool sclite" ) parser.add_argument( "ref", diff --git a/egs/gigaspeech/ASR/conformer_ctc/label_smoothing.py b/egs/gigaspeech/ASR/conformer_ctc/label_smoothing.py index cdc85ce9a..3b94f0c4b 100644 --- a/egs/gigaspeech/ASR/conformer_ctc/label_smoothing.py +++ b/egs/gigaspeech/ASR/conformer_ctc/label_smoothing.py @@ -78,13 +78,10 @@ class LabelSmoothingLoss(torch.nn.Module): ignored = target == self.ignore_index target[ignored] = 0 - true_dist = torch.nn.functional.one_hot( - target, num_classes=num_classes - ).to(x) + true_dist = torch.nn.functional.one_hot(target, num_classes=num_classes).to(x) true_dist = ( - true_dist * (1 - self.label_smoothing) - + self.label_smoothing / num_classes + true_dist * (1 - self.label_smoothing) + self.label_smoothing / num_classes ) # Set the value of ignored indexes to 0 true_dist[ignored] = 0 diff --git a/egs/gigaspeech/ASR/conformer_ctc/subsampling.py b/egs/gigaspeech/ASR/conformer_ctc/subsampling.py index 542fb0364..8e0f73d05 100644 --- a/egs/gigaspeech/ASR/conformer_ctc/subsampling.py +++ b/egs/gigaspeech/ASR/conformer_ctc/subsampling.py @@ -42,13 +42,9 @@ class Conv2dSubsampling(nn.Module): assert idim >= 7 super().__init__() self.conv = nn.Sequential( - nn.Conv2d( - in_channels=1, out_channels=odim, kernel_size=3, stride=2 - ), + nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2), nn.ReLU(), - nn.Conv2d( - in_channels=odim, out_channels=odim, kernel_size=3, stride=2 - ), + nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2), nn.ReLU(), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) @@ -132,17 +128,13 @@ class VggSubsampling(nn.Module): ) ) layers.append( - torch.nn.MaxPool2d( - kernel_size=2, stride=2, padding=0, ceil_mode=True - ) + torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True) ) cur_channels = block_dim self.layers = nn.Sequential(*layers) - self.out = nn.Linear( - block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim - ) + self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim) def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. diff --git a/egs/gigaspeech/ASR/conformer_ctc/train.py b/egs/gigaspeech/ASR/conformer_ctc/train.py index 2965cde18..4883d04d8 100755 --- a/egs/gigaspeech/ASR/conformer_ctc/train.py +++ b/egs/gigaspeech/ASR/conformer_ctc/train.py @@ -386,9 +386,7 @@ def compute_loss( # # See https://github.com/k2-fsa/icefall/issues/97 # for more details - unsorted_token_ids = graph_compiler.texts_to_ids( - supervisions["text"] - ) + unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) att_loss = mmodel.decoder_forward( encoder_memory, memory_mask, @@ -521,9 +519,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -641,9 +637,7 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/gigaspeech/ASR/conformer_ctc/transformer.py b/egs/gigaspeech/ASR/conformer_ctc/transformer.py index 00ca027a7..0566cfc81 100644 --- a/egs/gigaspeech/ASR/conformer_ctc/transformer.py +++ b/egs/gigaspeech/ASR/conformer_ctc/transformer.py @@ -151,9 +151,7 @@ class Transformer(nn.Module): norm=decoder_norm, ) - self.decoder_output_layer = torch.nn.Linear( - d_model, self.decoder_num_class - ) + self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class) self.decoder_criterion = LabelSmoothingLoss() else: @@ -181,18 +179,13 @@ class Transformer(nn.Module): memory_key_padding_mask for the decoder. Its shape is (N, T). It is None if `supervision` is None. """ - if ( - isinstance(self.use_feat_batchnorm, bool) - and self.use_feat_batchnorm - ): + if isinstance(self.use_feat_batchnorm, bool) and self.use_feat_batchnorm: x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) x = self.feat_batchnorm(x) x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) if isinstance(self.use_feat_batchnorm, float): x *= self.use_feat_batchnorm - encoder_memory, memory_key_padding_mask = self.run_encoder( - x, supervision - ) + encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision) x = self.ctc_output(encoder_memory) return x, encoder_memory, memory_key_padding_mask @@ -273,23 +266,17 @@ class Transformer(nn.Module): """ ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, padding_value=float(eos_id) - ) + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) device = memory.device ys_in_pad = ys_in_pad.to(device) ys_out_pad = ys_out_pad.to(device) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -350,23 +337,17 @@ class Transformer(nn.Module): ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, padding_value=float(eos_id) - ) + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) device = memory.device ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -639,9 +620,7 @@ def _get_activation_fn(activation: str): elif activation == "gelu": return nn.functional.gelu - raise RuntimeError( - "activation should be relu/gelu, not {}".format(activation) - ) + raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) class PositionalEncoding(nn.Module): @@ -843,9 +822,7 @@ def encoder_padding_mask( 1, ).to(torch.int32) - lengths = [ - 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) - ] + lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] for idx in range(supervision_segments.size(0)): # Note: TorchScript doesn't allow to unpack tensors as tuples sequence_idx = supervision_segments[idx, 0].item() @@ -866,9 +843,7 @@ def encoder_padding_mask( return mask -def decoder_padding_mask( - ys_pad: torch.Tensor, ignore_id: int = -1 -) -> torch.Tensor: +def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: """Generate a length mask for input. The masked position are filled with True, diff --git a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py index 8209ee3ec..07beeb1f0 100755 --- a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py +++ b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py @@ -77,9 +77,7 @@ def compute_fbank_gigaspeech_dev_test(): def main(): - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) compute_fbank_gigaspeech_dev_test() diff --git a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py index 6410249db..0ee845ec8 100755 --- a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py +++ b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py @@ -47,8 +47,10 @@ def get_parser(): "--batch-duration", type=float, default=600.0, - help="The maximum number of audio seconds in a batch." - "Determines batch size dynamically.", + help=( + "The maximum number of audio seconds in a batch." + "Determines batch size dynamically." + ), ) parser.add_argument( @@ -134,9 +136,7 @@ def main(): date_time = now.strftime("%Y-%m-%d-%H-%M-%S") log_filename = "log-compute_fbank_gigaspeech_splits" - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" log_filename = f"{log_filename}-{date_time}" logging.basicConfig( diff --git a/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py b/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py index 48d10a157..31abe7fff 100755 --- a/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py +++ b/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py @@ -98,19 +98,13 @@ def preprocess_giga_speech(): f"Speed perturb for {partition} with factors 0.9 and 1.1 " "(Perturbing may take 8 minutes and saving may take 20 minutes)" ) - cut_set = ( - cut_set - + cut_set.perturb_speed(0.9) - + cut_set.perturb_speed(1.1) - ) + cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) logging.info(f"Saving to {raw_cuts_path}") cut_set.to_file(raw_cuts_path) def main(): - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) preprocess_giga_speech() diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py index c87686e1e..9ae3f071e 100644 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -73,10 +73,12 @@ class GigaSpeechAsrDataModule: def add_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group( title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", + description=( + "These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc." + ), ) group.add_argument( "--manifest-dir", @@ -88,75 +90,91 @@ class GigaSpeechAsrDataModule: "--max-duration", type=int, default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", + help=( + "Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM." + ), ) group.add_argument( "--bucketing-sampler", type=str2bool, default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", + help=( + "When enabled, the batches will come from buckets of " + "similar duration (saves padding frames)." + ), ) group.add_argument( "--num-buckets", type=int, default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", + help=( + "The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets)." + ), ) group.add_argument( "--concatenate-cuts", type=str2bool, default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", + help=( + "When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding." + ), ) group.add_argument( "--duration-factor", type=float, default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", + help=( + "Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch." + ), ) group.add_argument( "--gap", type=float, default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", + help=( + "The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used." + ), ) group.add_argument( "--on-the-fly-feats", type=str2bool, default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", + help=( + "When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available." + ), ) group.add_argument( "--shuffle", type=str2bool, default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", + help=( + "When enabled (=default), the examples will be shuffled for each epoch." + ), ) group.add_argument( "--return-cuts", type=str2bool, default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", + help=( + "When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it." + ), ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that " - "collect the batches.", + help="The number of training dataloader workers that collect the batches.", ) group.add_argument( @@ -170,18 +188,22 @@ class GigaSpeechAsrDataModule: "--spec-aug-time-warp-factor", type=int, default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", + help=( + "Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp." + ), ) group.add_argument( "--enable-musan", type=str2bool, default=True, - help="When enabled, select noise from MUSAN and mix it " - "with training dataset. ", + help=( + "When enabled, select noise from MUSAN and mix it " + "with training dataset. " + ), ) # GigaSpeech specific arguments @@ -195,8 +217,7 @@ class GigaSpeechAsrDataModule: "--small-dev", type=str2bool, default=False, - help="Should we use only 1000 utterances for dev " - "(speeds up training)", + help="Should we use only 1000 utterances for dev (speeds up training)", ) def train_dataloaders( @@ -216,20 +237,16 @@ class GigaSpeechAsrDataModule: if self.args.enable_musan: logging.info("Enable MUSAN") logging.info("About to get Musan cuts") - cuts_musan = load_manifest( - self.args.manifest_dir / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") if self.args.concatenate_cuts: logging.info( - f"Using cut concatenation with duration factor " + "Using cut concatenation with duration factor " f"{self.args.duration_factor} and gap {self.args.gap}." ) # Cut concatenation should be the first transform in the list, @@ -244,9 +261,7 @@ class GigaSpeechAsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -289,9 +304,7 @@ class GigaSpeechAsrDataModule: # Drop feats to be on the safe side. train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -347,9 +360,7 @@ class GigaSpeechAsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: @@ -405,9 +416,7 @@ class GigaSpeechAsrDataModule: @lru_cache() def dev_cuts(self) -> CutSet: logging.info("About to get dev cuts") - cuts_valid = load_manifest_lazy( - self.args.manifest_dir / "cuts_DEV.jsonl.gz" - ) + cuts_valid = load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz") if self.args.small_dev: return cuts_valid.subset(first=1000) else: diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py index 5849a3471..9f5d4711b 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py @@ -77,11 +77,7 @@ from beam_search import ( from gigaspeech_scoring import asr_text_post_processing from train import get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import ( AttributeDict, setup_logger, @@ -118,9 +114,11 @@ def get_parser(): "--avg", type=int, default=8, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( @@ -188,8 +186,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -258,9 +255,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -275,10 +270,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -324,11 +316,7 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps + f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -398,9 +386,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -434,8 +420,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -511,8 +496,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py index cff9c7377..17f8614dc 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py @@ -51,11 +51,7 @@ import sentencepiece as spm import torch from train import get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import str2bool @@ -87,9 +83,11 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( @@ -120,8 +118,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) return parser @@ -160,8 +157,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -209,9 +205,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py index 83ae25561..4d1a2356d 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py @@ -77,9 +77,7 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def get_parser(): @@ -178,42 +176,45 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." + ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -553,23 +554,16 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -732,9 +726,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") diff --git a/egs/librispeech/ASR/conformer_ctc/ali.py b/egs/librispeech/ASR/conformer_ctc/ali.py index 2828e309e..0169d0f82 100755 --- a/egs/librispeech/ASR/conformer_ctc/ali.py +++ b/egs/librispeech/ASR/conformer_ctc/ali.py @@ -61,16 +61,19 @@ def get_parser(): "--epoch", type=int, default=34, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=20, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -231,9 +234,7 @@ def compute_alignments( labels_ali = get_alignments(best_path, kind="labels") aux_labels_ali = get_alignments(best_path, kind="aux_labels") assert len(labels_ali) == len(aux_labels_ali) == len(cut_list) - for cut, labels, aux_labels in zip( - cut_list, labels_ali, aux_labels_ali - ): + for cut, labels, aux_labels in zip(cut_list, labels_ali, aux_labels_ali): cut.labels_alignment = labels_writer.store_array( key=cut.id, value=np.asarray(labels, dtype=np.int32), @@ -258,9 +259,7 @@ def compute_alignments( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return CutSet.from_cuts(cuts) @@ -289,9 +288,7 @@ def main(): out_labels_ali_filename = out_dir / f"labels_{params.dataset}.h5" out_aux_labels_ali_filename = out_dir / f"aux_labels_{params.dataset}.h5" - out_manifest_filename = ( - out_dir / f"librispeech_cuts_{params.dataset}.jsonl.gz" - ) + out_manifest_filename = out_dir / f"librispeech_cuts_{params.dataset}.jsonl.gz" for f in ( out_labels_ali_filename, diff --git a/egs/librispeech/ASR/conformer_ctc/conformer.py b/egs/librispeech/ASR/conformer_ctc/conformer.py index 6fac07f93..1153a814c 100644 --- a/egs/librispeech/ASR/conformer_ctc/conformer.py +++ b/egs/librispeech/ASR/conformer_ctc/conformer.py @@ -160,9 +160,7 @@ class ConformerEncoderLayer(nn.Module): use_conv_batchnorm: bool = False, ) -> None: super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), @@ -182,18 +180,14 @@ class ConformerEncoderLayer(nn.Module): d_model, cnn_module_kernel, use_batchnorm=use_conv_batchnorm ) - self.norm_ff_macaron = nn.LayerNorm( - d_model - ) # for the macaron style FNN module + self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.ff_scale = 0.5 self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm( - d_model - ) # for the final output of the block + self.norm_final = nn.LayerNorm(d_model) # for the final output of the block self.dropout = nn.Dropout(dropout) @@ -227,9 +221,7 @@ class ConformerEncoderLayer(nn.Module): residual = src if self.normalize_before: src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout( - self.feed_forward_macaron(src) - ) + src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) if not self.normalize_before: src = self.norm_ff_macaron(src) @@ -348,9 +340,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -366,9 +356,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -638,9 +626,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -708,33 +696,25 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor" + " instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -771,9 +751,7 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -785,9 +763,7 @@ class RelPositionMultiheadAttention(nn.Module): matrix_ac + matrix_bd ) * scaling # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -821,13 +797,9 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py index 3f3b1acda..66fdf82d9 100755 --- a/egs/librispeech/ASR/conformer_ctc/decode.py +++ b/egs/librispeech/ASR/conformer_ctc/decode.py @@ -64,16 +64,19 @@ def get_parser(): "--epoch", type=int, default=77, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=55, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -551,9 +554,7 @@ def decode_dataset( results[lm_scale].extend(this_batch) else: - assert ( - len(results) > 0 - ), "It should not decode to empty in the first batch!" + assert len(results) > 0, "It should not decode to empty in the first batch!" this_batch = [] hyp_words = [] for ref_text in texts: @@ -568,9 +569,7 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -602,9 +601,7 @@ def save_results( test_set_wers[key] = wer if enable_log: - logging.info( - "Wrote detailed error stats to {}".format(errs_filename) - ) + logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" @@ -809,9 +806,7 @@ def main(): eos_id=eos_id, ) - save_results( - params=params, test_set_name=test_set, results_dict=results_dict - ) + save_results(params=params, test_set_name=test_set, results_dict=results_dict) logging.info("Done!") diff --git a/egs/librispeech/ASR/conformer_ctc/export.py b/egs/librispeech/ASR/conformer_ctc/export.py index 28c28df01..bdb8a85e5 100755 --- a/egs/librispeech/ASR/conformer_ctc/export.py +++ b/egs/librispeech/ASR/conformer_ctc/export.py @@ -40,17 +40,20 @@ def get_parser(): "--epoch", type=int, default=34, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=20, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -157,9 +160,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/conformer_ctc/label_smoothing.py b/egs/librispeech/ASR/conformer_ctc/label_smoothing.py index 1f2f3b137..cb0d6e04d 100644 --- a/egs/librispeech/ASR/conformer_ctc/label_smoothing.py +++ b/egs/librispeech/ASR/conformer_ctc/label_smoothing.py @@ -82,13 +82,10 @@ class LabelSmoothingLoss(torch.nn.Module): # for why we don't use target[ignored] = 0 here target = torch.where(ignored, torch.zeros_like(target), target) - true_dist = torch.nn.functional.one_hot( - target, num_classes=num_classes - ).to(x) + true_dist = torch.nn.functional.one_hot(target, num_classes=num_classes).to(x) true_dist = ( - true_dist * (1 - self.label_smoothing) - + self.label_smoothing / num_classes + true_dist * (1 - self.label_smoothing) + self.label_smoothing / num_classes ) # Set the value of ignored indexes to 0 diff --git a/egs/librispeech/ASR/conformer_ctc/pretrained.py b/egs/librispeech/ASR/conformer_ctc/pretrained.py index a2c0a5486..8cabf1a53 100755 --- a/egs/librispeech/ASR/conformer_ctc/pretrained.py +++ b/egs/librispeech/ASR/conformer_ctc/pretrained.py @@ -48,9 +48,11 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( @@ -189,10 +191,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) return parser @@ -236,10 +240,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -300,9 +303,7 @@ def main(): logging.info("Decoding started") features = fbank(waves) - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) # Note: We don't use key padding mask for attention during decoding with torch.no_grad(): @@ -427,9 +428,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 542fb0364..8e0f73d05 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -42,13 +42,9 @@ class Conv2dSubsampling(nn.Module): assert idim >= 7 super().__init__() self.conv = nn.Sequential( - nn.Conv2d( - in_channels=1, out_channels=odim, kernel_size=3, stride=2 - ), + nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2), nn.ReLU(), - nn.Conv2d( - in_channels=odim, out_channels=odim, kernel_size=3, stride=2 - ), + nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2), nn.ReLU(), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) @@ -132,17 +128,13 @@ class VggSubsampling(nn.Module): ) ) layers.append( - torch.nn.MaxPool2d( - kernel_size=2, stride=2, padding=0, ceil_mode=True - ) + torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True) ) cur_channels = block_dim self.layers = nn.Sequential(*layers) - self.out = nn.Linear( - block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim - ) + self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim) def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py index 6419f6816..1a1c2f4c5 100755 --- a/egs/librispeech/ASR/conformer_ctc/train.py +++ b/egs/librispeech/ASR/conformer_ctc/train.py @@ -393,9 +393,7 @@ def compute_loss( # Works with a phone lexicon decoding_graph = graph_compiler.compile(texts) else: - raise ValueError( - f"Unsupported type of graph compiler: {type(graph_compiler)}" - ) + raise ValueError(f"Unsupported type of graph compiler: {type(graph_compiler)}") dense_fsa_vec = k2.DenseFsaVec( nnet_output, @@ -422,9 +420,7 @@ def compute_loss( # # See https://github.com/k2-fsa/icefall/issues/97 # for more details - unsorted_token_ids = graph_compiler.texts_to_ids( - supervisions["text"] - ) + unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) att_loss = mmodel.decoder_forward( encoder_memory, memory_mask, @@ -453,9 +449,7 @@ def compute_loss( info["utt_duration"] = supervisions["num_frames"].sum().item() # averaged padding proportion over utterances info["utt_pad_proportion"] = ( - ((feature.size(1) - supervisions["num_frames"]) / feature.size(1)) - .sum() - .item() + ((feature.size(1) - supervisions["num_frames"]) / feature.size(1)).sum().item() ) return loss, info @@ -568,9 +562,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -660,7 +652,7 @@ def run(rank, world_size, args): graph_compiler.eos_id = 1 else: raise ValueError( - f"Unsupported type of lang dir (we expected it to have " + "Unsupported type of lang dir (we expected it to have " f"'lang_bpe' or 'lang_phone' in its name): {params.lang_dir}" ) @@ -733,9 +725,7 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/librispeech/ASR/conformer_ctc/transformer.py b/egs/librispeech/ASR/conformer_ctc/transformer.py index 00ca027a7..0566cfc81 100644 --- a/egs/librispeech/ASR/conformer_ctc/transformer.py +++ b/egs/librispeech/ASR/conformer_ctc/transformer.py @@ -151,9 +151,7 @@ class Transformer(nn.Module): norm=decoder_norm, ) - self.decoder_output_layer = torch.nn.Linear( - d_model, self.decoder_num_class - ) + self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class) self.decoder_criterion = LabelSmoothingLoss() else: @@ -181,18 +179,13 @@ class Transformer(nn.Module): memory_key_padding_mask for the decoder. Its shape is (N, T). It is None if `supervision` is None. """ - if ( - isinstance(self.use_feat_batchnorm, bool) - and self.use_feat_batchnorm - ): + if isinstance(self.use_feat_batchnorm, bool) and self.use_feat_batchnorm: x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) x = self.feat_batchnorm(x) x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) if isinstance(self.use_feat_batchnorm, float): x *= self.use_feat_batchnorm - encoder_memory, memory_key_padding_mask = self.run_encoder( - x, supervision - ) + encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision) x = self.ctc_output(encoder_memory) return x, encoder_memory, memory_key_padding_mask @@ -273,23 +266,17 @@ class Transformer(nn.Module): """ ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, padding_value=float(eos_id) - ) + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) device = memory.device ys_in_pad = ys_in_pad.to(device) ys_out_pad = ys_out_pad.to(device) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -350,23 +337,17 @@ class Transformer(nn.Module): ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, padding_value=float(eos_id) - ) + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) device = memory.device ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -639,9 +620,7 @@ def _get_activation_fn(activation: str): elif activation == "gelu": return nn.functional.gelu - raise RuntimeError( - "activation should be relu/gelu, not {}".format(activation) - ) + raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) class PositionalEncoding(nn.Module): @@ -843,9 +822,7 @@ def encoder_padding_mask( 1, ).to(torch.int32) - lengths = [ - 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) - ] + lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] for idx in range(supervision_segments.size(0)): # Note: TorchScript doesn't allow to unpack tensors as tuples sequence_idx = supervision_segments[idx, 0].item() @@ -866,9 +843,7 @@ def encoder_padding_mask( return mask -def decoder_padding_mask( - ys_pad: torch.Tensor, ignore_id: int = -1 -) -> torch.Tensor: +def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: """Generate a length mask for input. The masked position are filled with True, diff --git a/egs/librispeech/ASR/conformer_ctc2/attention.py b/egs/librispeech/ASR/conformer_ctc2/attention.py index 1375d7245..356d3f21b 100644 --- a/egs/librispeech/ASR/conformer_ctc2/attention.py +++ b/egs/librispeech/ASR/conformer_ctc2/attention.py @@ -18,11 +18,10 @@ from typing import Optional, Tuple import torch import torch.nn as nn +from scaling import ScaledLinear from torch import Tensor from torch.nn.init import xavier_normal_ -from scaling import ScaledLinear - class MultiheadAttention(nn.Module): r"""Allows the model to jointly attend to information @@ -76,9 +75,7 @@ class MultiheadAttention(nn.Module): self.embed_dim = embed_dim self.kdim = kdim if kdim is not None else embed_dim self.vdim = vdim if vdim is not None else embed_dim - self._qkv_same_embed_dim = ( - self.kdim == embed_dim and self.vdim == embed_dim - ) + self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim self.num_heads = num_heads self.dropout = dropout @@ -94,9 +91,7 @@ class MultiheadAttention(nn.Module): self.v_proj_weight = ScaledLinear(self.vdim, embed_dim, bias=bias) self.register_parameter("in_proj_weight", None) else: - self.in_proj_weight = ScaledLinear( - embed_dim, 3 * embed_dim, bias=bias - ) + self.in_proj_weight = ScaledLinear(embed_dim, 3 * embed_dim, bias=bias) self.register_parameter("q_proj_weight", None) self.register_parameter("k_proj_weight", None) self.register_parameter("v_proj_weight", None) @@ -107,12 +102,8 @@ class MultiheadAttention(nn.Module): self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=bias) if add_bias_kv: - self.bias_k = nn.Parameter( - torch.empty((1, 1, embed_dim), **factory_kwargs) - ) - self.bias_v = nn.Parameter( - torch.empty((1, 1, embed_dim), **factory_kwargs) - ) + self.bias_k = nn.Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) + self.bias_v = nn.Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) else: self.bias_k = self.bias_v = None diff --git a/egs/librispeech/ASR/conformer_ctc2/conformer.py b/egs/librispeech/ASR/conformer_ctc2/conformer.py index b906d2650..a6f1679ef 100644 --- a/egs/librispeech/ASR/conformer_ctc2/conformer.py +++ b/egs/librispeech/ASR/conformer_ctc2/conformer.py @@ -29,9 +29,8 @@ from scaling import ( ScaledConv1d, ScaledLinear, ) -from torch import Tensor, nn from subsampling import Conv2dSubsampling - +from torch import Tensor, nn from transformer import Supervisions, Transformer, encoder_padding_mask @@ -182,9 +181,7 @@ class ConformerEncoderLayer(nn.Module): self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), @@ -356,9 +353,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -373,9 +368,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vecotr and `j` means the @@ -650,9 +643,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -721,33 +714,25 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor" + " instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -784,9 +769,7 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -794,13 +777,9 @@ class RelPositionMultiheadAttention(nn.Module): ) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd) - attn_output_weights = ( - matrix_ac + matrix_bd - ) # (batch, head, time1, time2) + attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -834,13 +813,9 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -863,9 +838,7 @@ class ConvolutionModule(nn.Module): """ - def __init__( - self, channels: int, kernel_size: int, bias: bool = True - ) -> None: + def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding diff --git a/egs/librispeech/ASR/conformer_ctc2/decode.py b/egs/librispeech/ASR/conformer_ctc2/decode.py index 97f2f2d39..934177b1f 100755 --- a/egs/librispeech/ASR/conformer_ctc2/decode.py +++ b/egs/librispeech/ASR/conformer_ctc2/decode.py @@ -90,9 +90,11 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( @@ -130,11 +132,13 @@ def get_parser(): "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -658,9 +662,7 @@ def decode_dataset( results[lm_scale].extend(this_batch) else: - assert ( - len(results) > 0 - ), "It should not decode to empty in the first batch!" + assert len(results) > 0, "It should not decode to empty in the first batch!" this_batch = [] hyp_words = [] for ref_text in texts: @@ -675,9 +677,7 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -709,9 +709,7 @@ def save_results( test_set_wers[key] = wer if enable_log: - logging.info( - "Wrote detailed error stats to {}".format(errs_filename) - ) + logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" @@ -852,13 +850,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -881,13 +878,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -915,7 +911,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -985,9 +981,7 @@ def main(): eos_id=eos_id, ) - save_results( - params=params, test_set_name=test_set, results_dict=results_dict - ) + save_results(params=params, test_set_name=test_set, results_dict=results_dict) logging.info("Done!") diff --git a/egs/librispeech/ASR/conformer_ctc2/export.py b/egs/librispeech/ASR/conformer_ctc2/export.py index 584b3c3fc..0e1841d8d 100755 --- a/egs/librispeech/ASR/conformer_ctc2/export.py +++ b/egs/librispeech/ASR/conformer_ctc2/export.py @@ -47,6 +47,7 @@ import logging from pathlib import Path import torch +from conformer import Conformer from decode import get_params from icefall.checkpoint import ( @@ -55,10 +56,8 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from conformer import Conformer - -from icefall.utils import str2bool from icefall.lexicon import Lexicon +from icefall.utils import str2bool def get_parser(): @@ -89,20 +88,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -177,13 +180,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -206,13 +208,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -240,7 +241,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -273,9 +274,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/conformer_ctc2/train.py b/egs/librispeech/ASR/conformer_ctc2/train.py index 9d9c2af1f..63534b76b 100755 --- a/egs/librispeech/ASR/conformer_ctc2/train.py +++ b/egs/librispeech/ASR/conformer_ctc2/train.py @@ -69,8 +69,8 @@ from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter -from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler from icefall import diagnostics +from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler from icefall.checkpoint import load_checkpoint, remove_checkpoints from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.checkpoint import ( @@ -89,9 +89,7 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def get_parser(): @@ -505,11 +503,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -546,9 +540,7 @@ def compute_loss( # Works with a phone lexicon decoding_graph = graph_compiler.compile(texts) else: - raise ValueError( - f"Unsupported type of graph compiler: {type(graph_compiler)}" - ) + raise ValueError(f"Unsupported type of graph compiler: {type(graph_compiler)}") dense_fsa_vec = k2.DenseFsaVec( nnet_output, @@ -575,9 +567,7 @@ def compute_loss( # # See https://github.com/k2-fsa/icefall/issues/97 # for more details - unsorted_token_ids = graph_compiler.texts_to_ids( - supervisions["text"] - ) + unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) att_loss = mmodel.decoder_forward( encoder_memory, memory_mask, @@ -595,9 +585,7 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() info["ctc_loss"] = ctc_loss.detach().cpu().item() if params.att_rate != 0.0: info["att_loss"] = att_loss.detach().cpu().item() @@ -735,8 +723,7 @@ def train_one_epoch( except RuntimeError as e: if "CUDA out of memory" in str(e): logging.error( - f"failing batch size:{batch_size} " - f"failing batch names {batch_name}" + f"failing batch size:{batch_size} failing batch names {batch_name}" ) raise @@ -791,9 +778,9 @@ def train_one_epoch( f"tot_loss[{tot_loss}], batch size: {batch_size}, " f"lr: {cur_lr:.2e}" ) - if loss_info["ctc_loss"] == float("inf") or loss_info[ - "att_loss" - ] == float("inf"): + if loss_info["ctc_loss"] == float("inf") or loss_info["att_loss"] == float( + "inf" + ): logging.error( "Your loss contains inf, something goes wrong" f"failing batch names {batch_name}" @@ -806,9 +793,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -900,7 +885,7 @@ def run(rank, world_size, args): graph_compiler.eos_id = 1 else: raise ValueError( - f"Unsupported type of lang dir (we expected it to have " + "Unsupported type of lang dir (we expected it to have " f"'lang_bpe' or 'lang_phone' in its name): {params.lang_dir}" ) diff --git a/egs/librispeech/ASR/conformer_ctc2/transformer.py b/egs/librispeech/ASR/conformer_ctc2/transformer.py index fa179acc0..8f0c7dcde 100644 --- a/egs/librispeech/ASR/conformer_ctc2/transformer.py +++ b/egs/librispeech/ASR/conformer_ctc2/transformer.py @@ -21,19 +21,17 @@ from typing import Dict, List, Optional, Tuple import torch import torch.nn as nn -from label_smoothing import LabelSmoothingLoss -from subsampling import Conv2dSubsampling from attention import MultiheadAttention -from torch.nn.utils.rnn import pad_sequence - +from label_smoothing import LabelSmoothingLoss from scaling import ( ActivationBalancer, BasicNorm, DoubleSwish, - ScaledLinear, ScaledEmbedding, + ScaledLinear, ) - +from subsampling import Conv2dSubsampling +from torch.nn.utils.rnn import pad_sequence # Note: TorchScript requires Dict/List/etc. to be fully typed. Supervisions = Dict[str, torch.Tensor] @@ -210,9 +208,7 @@ class Transformer(nn.Module): x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) mask = encoder_padding_mask(x.size(0), supervisions) mask = mask.to(x.device) if mask is not None else None - x = self.encoder( - x, src_key_padding_mask=mask, warmup=warmup - ) # (T, N, C) + x = self.encoder(x, src_key_padding_mask=mask, warmup=warmup) # (T, N, C) return x, mask @@ -261,23 +257,17 @@ class Transformer(nn.Module): """ ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, padding_value=float(eos_id) - ) + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) device = memory.device ys_in_pad = ys_in_pad.to(device) ys_out_pad = ys_out_pad.to(device) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -338,23 +328,17 @@ class Transformer(nn.Module): ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, padding_value=float(eos_id) - ) + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) device = memory.device ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -659,9 +643,7 @@ def _get_activation_fn(activation: str): elif activation == "gelu": return nn.functional.gelu - raise RuntimeError( - "activation should be relu/gelu, not {}".format(activation) - ) + raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) class TransformerEncoder(nn.Module): @@ -982,9 +964,7 @@ def encoder_padding_mask( 1, ).to(torch.int32) - lengths = [ - 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) - ] + lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] for idx in range(supervision_segments.size(0)): # Note: TorchScript doesn't allow to unpack tensors as tuples sequence_idx = supervision_segments[idx, 0].item() @@ -1005,9 +985,7 @@ def encoder_padding_mask( return mask -def decoder_padding_mask( - ys_pad: torch.Tensor, ignore_id: int = -1 -) -> torch.Tensor: +def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: """Generate a length mask for input. The masked position are filled with True, diff --git a/egs/librispeech/ASR/conformer_mmi/conformer.py b/egs/librispeech/ASR/conformer_mmi/conformer.py index 97c8d83a2..4d9ddaea9 100644 --- a/egs/librispeech/ASR/conformer_mmi/conformer.py +++ b/egs/librispeech/ASR/conformer_mmi/conformer.py @@ -156,9 +156,7 @@ class ConformerEncoderLayer(nn.Module): normalize_before: bool = True, ) -> None: super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), @@ -176,18 +174,14 @@ class ConformerEncoderLayer(nn.Module): self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) - self.norm_ff_macaron = nn.LayerNorm( - d_model - ) # for the macaron style FNN module + self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.ff_scale = 0.5 self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm( - d_model - ) # for the final output of the block + self.norm_final = nn.LayerNorm(d_model) # for the final output of the block self.dropout = nn.Dropout(dropout) @@ -221,9 +215,7 @@ class ConformerEncoderLayer(nn.Module): residual = src if self.normalize_before: src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout( - self.feed_forward_macaron(src) - ) + src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) if not self.normalize_before: src = self.norm_ff_macaron(src) @@ -342,9 +334,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -360,9 +350,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -632,9 +620,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -702,33 +690,25 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor" + " instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -765,9 +745,7 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -779,9 +757,7 @@ class RelPositionMultiheadAttention(nn.Module): matrix_ac + matrix_bd ) * scaling # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -815,13 +791,9 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -844,9 +816,7 @@ class ConvolutionModule(nn.Module): """ - def __init__( - self, channels: int, kernel_size: int, bias: bool = True - ) -> None: + def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding diff --git a/egs/librispeech/ASR/conformer_mmi/decode.py b/egs/librispeech/ASR/conformer_mmi/decode.py index fc9861489..e8390ded9 100755 --- a/egs/librispeech/ASR/conformer_mmi/decode.py +++ b/egs/librispeech/ASR/conformer_mmi/decode.py @@ -60,16 +60,19 @@ def get_parser(): "--epoch", type=int, default=34, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=20, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -478,9 +481,7 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -512,9 +513,7 @@ def save_results( test_set_wers[key] = wer if enable_log: - logging.info( - "Wrote detailed error stats to {}".format(errs_filename) - ) + logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" @@ -653,9 +652,7 @@ def main(): if params.export: logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") - torch.save( - {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" - ) + torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt") return model.to(device) @@ -687,9 +684,7 @@ def main(): eos_id=eos_id, ) - save_results( - params=params, test_set_name=test_set, results_dict=results_dict - ) + save_results(params=params, test_set_name=test_set, results_dict=results_dict) logging.info("Done!") diff --git a/egs/librispeech/ASR/conformer_mmi/subsampling.py b/egs/librispeech/ASR/conformer_mmi/subsampling.py index 5c3e1222e..ad9415987 100644 --- a/egs/librispeech/ASR/conformer_mmi/subsampling.py +++ b/egs/librispeech/ASR/conformer_mmi/subsampling.py @@ -25,13 +25,9 @@ class Conv2dSubsampling(nn.Module): assert idim >= 7 super().__init__() self.conv = nn.Sequential( - nn.Conv2d( - in_channels=1, out_channels=odim, kernel_size=3, stride=2 - ), + nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2), nn.ReLU(), - nn.Conv2d( - in_channels=odim, out_channels=odim, kernel_size=3, stride=2 - ), + nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2), nn.ReLU(), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) @@ -115,17 +111,13 @@ class VggSubsampling(nn.Module): ) ) layers.append( - torch.nn.MaxPool2d( - kernel_size=2, stride=2, padding=0, ceil_mode=True - ) + torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True) ) cur_channels = block_dim self.layers = nn.Sequential(*layers) - self.out = nn.Linear( - block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim - ) + self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim) def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. diff --git a/egs/librispeech/ASR/conformer_mmi/test_subsampling.py b/egs/librispeech/ASR/conformer_mmi/test_subsampling.py index 937845d77..d0bb017dd 100755 --- a/egs/librispeech/ASR/conformer_mmi/test_subsampling.py +++ b/egs/librispeech/ASR/conformer_mmi/test_subsampling.py @@ -1,8 +1,7 @@ #!/usr/bin/env python3 -from subsampling import Conv2dSubsampling -from subsampling import VggSubsampling import torch +from subsampling import Conv2dSubsampling, VggSubsampling def test_conv2d_subsampling(): diff --git a/egs/librispeech/ASR/conformer_mmi/test_transformer.py b/egs/librispeech/ASR/conformer_mmi/test_transformer.py index 08e680607..25d18076d 100644 --- a/egs/librispeech/ASR/conformer_mmi/test_transformer.py +++ b/egs/librispeech/ASR/conformer_mmi/test_transformer.py @@ -1,17 +1,16 @@ #!/usr/bin/env python3 import torch +from torch.nn.utils.rnn import pad_sequence from transformer import ( Transformer, + add_eos, + add_sos, + decoder_padding_mask, encoder_padding_mask, generate_square_subsequent_mask, - decoder_padding_mask, - add_sos, - add_eos, ) -from torch.nn.utils.rnn import pad_sequence - def test_encoder_padding_mask(): supervisions = { diff --git a/egs/librispeech/ASR/conformer_mmi/train-with-attention.py b/egs/librispeech/ASR/conformer_mmi/train-with-attention.py index 011dadd73..f8c94cff9 100755 --- a/egs/librispeech/ASR/conformer_mmi/train-with-attention.py +++ b/egs/librispeech/ASR/conformer_mmi/train-with-attention.py @@ -36,23 +36,14 @@ from torch.nn.utils import clip_grad_norm_ from torch.utils.tensorboard import SummaryWriter from transformer import Noam -from icefall.ali import ( - convert_alignments_to_tensor, - load_alignments, - lookup_alignments, -) +from icefall.ali import convert_alignments_to_tensor, load_alignments, lookup_alignments from icefall.checkpoint import load_checkpoint from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.dist import cleanup_dist, setup_dist from icefall.lexicon import Lexicon from icefall.mmi import LFMMILoss from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler -from icefall.utils import ( - AttributeDict, - encode_supervisions, - setup_logger, - str2bool, -) +from icefall.utils import AttributeDict, encode_supervisions, setup_logger, str2bool def get_parser(): @@ -370,10 +361,7 @@ def compute_loss( nnet_output = nnet_output.clone() nnet_output[:, :min_len, :] += ali_scale * mask[:, :min_len, :] - if ( - params.batch_idx_train > params.use_ali_until - and params.beam_size < 8 - ): + if params.batch_idx_train > params.use_ali_until and params.beam_size < 8: # logging.info("Change beam size to 8") params.beam_size = 8 else: @@ -762,19 +750,14 @@ def run(rank, world_size, args): for epoch in range(params.start_epoch, params.num_epochs): train_dl.sampler.set_epoch(epoch) - if ( - params.batch_idx_train >= params.use_ali_until - and train_ali is not None - ): + if params.batch_idx_train >= params.use_ali_until and train_ali is not None: # Delete the alignments to save memory train_ali = None valid_ali = None cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/librispeech/ASR/conformer_mmi/train.py b/egs/librispeech/ASR/conformer_mmi/train.py index 9a5bdcce2..5cfb2bfc7 100755 --- a/egs/librispeech/ASR/conformer_mmi/train.py +++ b/egs/librispeech/ASR/conformer_mmi/train.py @@ -36,23 +36,14 @@ from torch.nn.utils import clip_grad_norm_ from torch.utils.tensorboard import SummaryWriter from transformer import Noam -from icefall.ali import ( - convert_alignments_to_tensor, - load_alignments, - lookup_alignments, -) +from icefall.ali import convert_alignments_to_tensor, load_alignments, lookup_alignments from icefall.checkpoint import load_checkpoint from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.dist import cleanup_dist, setup_dist from icefall.lexicon import Lexicon from icefall.mmi import LFMMILoss from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler -from icefall.utils import ( - AttributeDict, - encode_supervisions, - setup_logger, - str2bool, -) +from icefall.utils import AttributeDict, encode_supervisions, setup_logger, str2bool def get_parser(): @@ -377,10 +368,7 @@ def compute_loss( nnet_output = nnet_output.clone() nnet_output[:, :min_len, :] += ali_scale * mask[:, :min_len, :] - if ( - params.batch_idx_train > params.use_ali_until - and params.beam_size < 8 - ): + if params.batch_idx_train > params.use_ali_until and params.beam_size < 8: logging.info("Change beam size to 8") params.beam_size = 8 else: @@ -770,19 +758,14 @@ def run(rank, world_size, args): for epoch in range(params.start_epoch, params.num_epochs): fix_random_seed(params.seed + epoch) train_dl.sampler.set_epoch(epoch) - if ( - params.batch_idx_train >= params.use_ali_until - and train_ali is not None - ): + if params.batch_idx_train >= params.use_ali_until and train_ali is not None: # Delete the alignments to save memory train_ali = None valid_ali = None cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/librispeech/ASR/conformer_mmi/transformer.py b/egs/librispeech/ASR/conformer_mmi/transformer.py index 68a4ff65c..2542d9abe 100644 --- a/egs/librispeech/ASR/conformer_mmi/transformer.py +++ b/egs/librispeech/ASR/conformer_mmi/transformer.py @@ -148,9 +148,7 @@ class Transformer(nn.Module): norm=decoder_norm, ) - self.decoder_output_layer = torch.nn.Linear( - d_model, self.decoder_num_class - ) + self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class) self.decoder_criterion = LabelSmoothingLoss(self.decoder_num_class) else: @@ -182,9 +180,7 @@ class Transformer(nn.Module): x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) x = self.feat_batchnorm(x) x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) - encoder_memory, memory_key_padding_mask = self.run_encoder( - x, supervision - ) + encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision) x = self.ctc_output(encoder_memory) return x, encoder_memory, memory_key_padding_mask @@ -274,9 +270,7 @@ class Transformer(nn.Module): ys_in_pad = ys_in_pad.to(device) ys_out_pad = ys_out_pad.to(device) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -341,9 +335,7 @@ class Transformer(nn.Module): ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -616,9 +608,7 @@ def _get_activation_fn(activation: str): elif activation == "gelu": return nn.functional.gelu - raise RuntimeError( - "activation should be relu/gelu, not {}".format(activation) - ) + raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) class PositionalEncoding(nn.Module): @@ -887,9 +877,7 @@ def encoder_padding_mask( 1, ).to(torch.int32) - lengths = [ - 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) - ] + lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] for idx in range(supervision_segments.size(0)): # Note: TorchScript doesn't allow to unpack tensors as tuples sequence_idx = supervision_segments[idx, 0].item() @@ -910,9 +898,7 @@ def encoder_padding_mask( return mask -def decoder_padding_mask( - ys_pad: torch.Tensor, ignore_id: int = -1 -) -> torch.Tensor: +def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: """Generate a length mask for input. The masked position are filled with True, diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py index 620d69a19..a1c43f7f5 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py @@ -135,20 +135,24 @@ def get_parser(): "--avg", type=int, default=10, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -215,8 +219,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -284,9 +287,7 @@ def decode_one_batch( value=LOG_EPS, ) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -301,10 +302,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -350,11 +348,7 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps + f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -427,9 +421,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -462,8 +454,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -506,9 +497,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -540,13 +529,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -569,13 +557,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -603,7 +590,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py index 8ca7d5568..0639ba746 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py @@ -35,7 +35,6 @@ from scaling import ( from icefall.utils import make_pad_mask - LOG_EPSILON = math.log(1e-10) @@ -127,9 +126,7 @@ def stack_states( for si, s in enumerate(layer): attn_caches[li][si].append(s) if b == batch_size - 1: - attn_caches[li][si] = torch.stack( - attn_caches[li][si], dim=1 - ) + attn_caches[li][si] = torch.stack(attn_caches[li][si], dim=1) conv_caches = [] for layer in state_list[0][1]: @@ -268,9 +265,7 @@ class ConvolutionModule(nn.Module): intervals = torch.arange( 0, self.chunk_length * (num_chunks - 1), self.chunk_length ) - first = torch.arange( - self.chunk_length, self.chunk_length + self.cache_size - ) + first = torch.arange(self.chunk_length, self.chunk_length + self.cache_size) indexes = intervals.unsqueeze(1) + first.unsqueeze(0) indexes = torch.cat( [indexes, torch.arange(U_ - self.cache_size, U_).unsqueeze(0)] @@ -284,9 +279,7 @@ class ConvolutionModule(nn.Module): # (num_chunks * B, cache_size + right_context_length, D) return pad_right_context.permute(0, 2, 1) - def _merge_right_context( - self, right_context: torch.Tensor, B: int - ) -> torch.Tensor: + def _merge_right_context(self, right_context: torch.Tensor, B: int) -> torch.Tensor: """ Args: right_context: @@ -337,12 +330,8 @@ class ConvolutionModule(nn.Module): right_context = x[:, :, :R] # (B, D, R) # make causal convolution - cache = torch.zeros( - B, D, self.cache_size, device=x.device, dtype=x.dtype - ) - pad_utterance = torch.cat( - [cache, utterance], dim=2 - ) # (B, D, cache + U) + cache = torch.zeros(B, D, self.cache_size, device=x.device, dtype=x.dtype) + pad_utterance = torch.cat([cache, utterance], dim=2) # (B, D, cache + U) # depth-wise conv on utterance utterance = self.depthwise_conv(pad_utterance) # (B, D, U) @@ -355,9 +344,7 @@ class ConvolutionModule(nn.Module): right_context = self.depthwise_conv( pad_right_context ) # (num_segs * B, D, right_context_length) - right_context = self._merge_right_context( - right_context, B - ) # (B, D, R) + right_context = self._merge_right_context(right_context, B) # (B, D, R) x = torch.cat([right_context, utterance], dim=2) # (B, D, R + U) x = self.deriv_balancer2(x) @@ -458,8 +445,7 @@ class EmformerAttention(nn.Module): if embed_dim % nhead != 0: raise ValueError( - f"embed_dim ({embed_dim}) is not a multiple of" - f"nhead ({nhead})." + f"embed_dim ({embed_dim}) is not a multiple ofnhead ({nhead})." ) self.embed_dim = embed_dim @@ -469,9 +455,7 @@ class EmformerAttention(nn.Module): self.head_dim = embed_dim // nhead self.dropout = dropout - self.emb_to_key_value = ScaledLinear( - embed_dim, 2 * embed_dim, bias=True - ) + self.emb_to_key_value = ScaledLinear(embed_dim, 2 * embed_dim, bias=True) self.emb_to_query = ScaledLinear(embed_dim, embed_dim, bias=True) self.out_proj = ScaledLinear( embed_dim, embed_dim, bias=True, initial_scale=0.25 @@ -513,9 +497,7 @@ class EmformerAttention(nn.Module): if padding_mask is not None: Q = attention_weights.size(1) B = attention_weights.size(0) // self.nhead - attention_weights_float = attention_weights_float.view( - B, self.nhead, Q, -1 - ) + attention_weights_float = attention_weights_float.view(B, self.nhead, Q, -1) attention_weights_float = attention_weights_float.masked_fill( padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), self.negative_inf, @@ -551,9 +533,7 @@ class EmformerAttention(nn.Module): scaling = float(self.head_dim) ** -0.5 # compute query with [right_context, utterance, summary]. - query = self.emb_to_query( - torch.cat([right_context, utterance, summary]) - ) + query = self.emb_to_query(torch.cat([right_context, utterance, summary])) # compute key and value with [memory, right_context, utterance]. key, value = self.emb_to_key_value( torch.cat([memory, right_context, utterance]) @@ -564,16 +544,12 @@ class EmformerAttention(nn.Module): # [memory, right context, left context, uttrance] # this is used in inference mode key = torch.cat([key[: M + R], left_context_key, key[M + R :]]) - value = torch.cat( - [value[: M + R], left_context_val, value[M + R :]] - ) + value = torch.cat([value[: M + R], left_context_val, value[M + R :]]) Q = query.size(0) # KV = key.size(0) reshaped_query, reshaped_key, reshaped_value = [ - tensor.contiguous() - .view(-1, B * self.nhead, self.head_dim) - .transpose(0, 1) + tensor.contiguous().view(-1, B * self.nhead, self.head_dim).transpose(0, 1) for tensor in [query, key, value] ] # (B * nhead, Q or KV, head_dim) attention_weights = torch.bmm( @@ -588,9 +564,7 @@ class EmformerAttention(nn.Module): # compute attention outputs attention = torch.bmm(attention_probs, reshaped_value) assert attention.shape == (B * self.nhead, Q, self.head_dim) - attention = ( - attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim) - ) + attention = attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim) # apply output projection outputs = self.out_proj(attention) @@ -672,12 +646,7 @@ class EmformerAttention(nn.Module): - output of right context and utterance, with shape (R + U, B, D). - memory output, with shape (M, B, D), where M = S - 1 or M = 0. """ - ( - output_right_context_utterance, - output_memory, - _, - _, - ) = self._forward_impl( + (output_right_context_utterance, output_memory, _, _,) = self._forward_impl( utterance, right_context, summary, @@ -947,13 +916,9 @@ class EmformerEncoderLayer(nn.Module): right_context = right_context_utterance[:R] if self.use_memory: - summary = self.summary_op(utterance.permute(1, 2, 0)).permute( - 2, 0, 1 - ) + summary = self.summary_op(utterance.permute(1, 2, 0)).permute(2, 0, 1) else: - summary = torch.empty(0).to( - dtype=utterance.dtype, device=utterance.device - ) + summary = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device) output_right_context_utterance, output_memory = self.attention( utterance=utterance, right_context=right_context, @@ -992,14 +957,10 @@ class EmformerEncoderLayer(nn.Module): left_context_val = attn_cache[2] if self.use_memory: - summary = self.summary_op(utterance.permute(1, 2, 0)).permute( - 2, 0, 1 - ) + summary = self.summary_op(utterance.permute(1, 2, 0)).permute(2, 0, 1) summary = summary[:1] else: - summary = torch.empty(0).to( - dtype=utterance.dtype, device=utterance.device - ) + summary = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device) ( output_right_context_utterance, output_memory, @@ -1014,9 +975,7 @@ class EmformerEncoderLayer(nn.Module): left_context_val=left_context_val, padding_mask=padding_mask, ) - attn_cache = self._update_attn_cache( - next_key, next_val, memory, attn_cache - ) + attn_cache = self._update_attn_cache(next_key, next_val, memory, attn_cache) return output_right_context_utterance, output_memory, attn_cache def forward( @@ -1151,11 +1110,7 @@ class EmformerEncoderLayer(nn.Module): src = src + self.dropout(self.feed_forward_macaron(src)) # emformer attention module - ( - src_att, - output_memory, - attn_cache, - ) = self._apply_attention_module_infer( + (src_att, output_memory, attn_cache,) = self._apply_attention_module_infer( src, R, memory, attn_cache, padding_mask=padding_mask ) src = src + self.dropout(src_att) @@ -1295,9 +1250,7 @@ class EmformerEncoder(nn.Module): def _gen_right_context(self, x: torch.Tensor) -> torch.Tensor: """Hard copy each chunk's right context and concat them.""" T = x.shape[0] - num_chunks = math.ceil( - (T - self.right_context_length) / self.chunk_length - ) + num_chunks = math.ceil((T - self.right_context_length) / self.chunk_length) # first (num_chunks - 1) right context block intervals = torch.arange( 0, self.chunk_length * (num_chunks - 1), self.chunk_length @@ -1316,9 +1269,7 @@ class EmformerEncoder(nn.Module): right_context_blocks = x[indexes.reshape(-1)] return right_context_blocks - def _gen_attention_mask_col_widths( - self, chunk_idx: int, U: int - ) -> List[int]: + def _gen_attention_mask_col_widths(self, chunk_idx: int, U: int) -> List[int]: """Calculate column widths (key, value) in attention mask for the chunk_idx chunk.""" num_chunks = math.ceil(U / self.chunk_length) @@ -1479,9 +1430,7 @@ class EmformerEncoder(nn.Module): output_lengths = torch.clamp(lengths - self.right_context_length, min=0) attention_mask = self._gen_attention_mask(utterance) memory = ( - self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[ - :-1 - ] + self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[:-1] if self.use_memory else torch.empty(0).to(dtype=x.dtype, device=x.device) ) @@ -1643,12 +1592,8 @@ class EmformerEncoder(nn.Module): attn_caches = [ [ torch.zeros(self.memory_size, self.d_model, device=device), - torch.zeros( - self.left_context_length, self.d_model, device=device - ), - torch.zeros( - self.left_context_length, self.d_model, device=device - ), + torch.zeros(self.left_context_length, self.d_model, device=device), + torch.zeros(self.left_context_length, self.d_model, device=device), ] for _ in range(self.num_encoder_layers) ] @@ -1693,17 +1638,11 @@ class Emformer(EncoderInterface): raise NotImplementedError( "chunk_length must be a mutiple of subsampling_factor." ) - if ( - left_context_length != 0 - and left_context_length % subsampling_factor != 0 - ): + if left_context_length != 0 and left_context_length % subsampling_factor != 0: raise NotImplementedError( "left_context_length must be 0 or a mutiple of subsampling_factor." # noqa ) - if ( - right_context_length != 0 - and right_context_length % subsampling_factor != 0 - ): + if right_context_length != 0 and right_context_length % subsampling_factor != 0: raise NotImplementedError( "right_context_length must be 0 or a mutiple of subsampling_factor." # noqa ) @@ -1766,9 +1705,7 @@ class Emformer(EncoderInterface): x_lens = (((x_lens - 1) >> 1) - 1) >> 1 assert x.size(0) == x_lens.max().item() - output, output_lengths = self.encoder( - x, x_lens, warmup=warmup - ) # (T, N, C) + output, output_lengths = self.encoder(x, x_lens, warmup=warmup) # (T, N, C) output = output.permute(1, 0, 2) # (T, N, C) -> (N, T, C) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py index 4930881ea..59105e286 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py @@ -103,9 +103,11 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( @@ -136,19 +138,20 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) add_model_arguments(parser) @@ -181,13 +184,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -210,13 +212,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -244,7 +245,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -279,9 +280,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py index 9494e1fc1..c211b215e 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py @@ -68,14 +68,12 @@ class Stream(object): elif params.decoding_method == "fast_beam_search": # feature_len is needed to get partial results. # The rnnt_decoding_stream for fast_beam_search. - self.rnnt_decoding_stream: k2.RnntDecodingStream = ( - k2.RnntDecodingStream(decoding_graph) + self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream( + decoding_graph ) self.hyp: Optional[List[int]] = None else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") self.ground_truth: str = "" diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py index 61dbe8658..abe83732a 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py @@ -113,8 +113,9 @@ def get_parser(): "--epoch", type=int, default=28, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( @@ -131,20 +132,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=False, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -211,8 +216,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -371,9 +375,7 @@ def modified_beam_search( index=hyps_shape.row_ids(1).to(torch.int64), ) # (num_hyps, encoder_out_dim) - logits = model.joiner( - current_encoder_out, decoder_out, project_input=False - ) + logits = model.joiner(current_encoder_out, decoder_out, project_input=False) # logits is of shape (num_hyps, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) @@ -390,9 +392,7 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -551,14 +551,10 @@ def decode_one_chunk( feature_list, batch_first=True, padding_value=LOG_EPSILON ).to(device) feature_lens = torch.tensor(feature_len_list, device=device) - num_processed_frames = torch.tensor( - num_processed_frames_list, device=device - ) + num_processed_frames = torch.tensor(num_processed_frames_list, device=device) # Make sure it has at least 1 frame after subsampling, first-and-last-frame cutting, and right context cutting # noqa - tail_length = ( - 3 * params.subsampling_factor + params.right_context_length + 3 - ) + tail_length = 3 * params.subsampling_factor + params.right_context_length + 3 if features.size(1) < tail_length: pad_length = tail_length - features.size(1) feature_lens += pad_length @@ -605,9 +601,7 @@ def decode_one_chunk( max_states=params.max_states, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") # Update cached states of each stream state_list = unstack_states(states) @@ -782,8 +776,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -831,9 +824,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -867,13 +858,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -896,13 +886,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -930,7 +919,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py index c07d8f76b..a76417e5f 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py @@ -95,9 +95,7 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -265,42 +263,45 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." + ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -636,11 +637,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -668,23 +665,16 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -871,9 +861,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -981,7 +969,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py index 98b8290b5..9cb4a5afc 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py @@ -135,20 +135,24 @@ def get_parser(): "--avg", type=int, default=10, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -215,8 +219,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -284,9 +287,7 @@ def decode_one_batch( value=LOG_EPS, ) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -301,10 +302,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -350,11 +348,7 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps + f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -427,9 +421,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -462,8 +454,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -506,9 +497,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -540,13 +529,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -569,13 +557,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -603,7 +590,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py index f16f5acc7..09200f2e1 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py @@ -35,7 +35,6 @@ from scaling import ( from icefall.utils import make_pad_mask - LOG_EPSILON = math.log(1e-10) @@ -127,9 +126,7 @@ def stack_states( for si, s in enumerate(layer): attn_caches[li][si].append(s) if b == batch_size - 1: - attn_caches[li][si] = torch.stack( - attn_caches[li][si], dim=1 - ) + attn_caches[li][si] = torch.stack(attn_caches[li][si], dim=1) conv_caches = [] for layer in state_list[0][1]: @@ -268,9 +265,7 @@ class ConvolutionModule(nn.Module): intervals = torch.arange( 0, self.chunk_length * (num_chunks - 1), self.chunk_length ) - first = torch.arange( - self.chunk_length, self.chunk_length + self.cache_size - ) + first = torch.arange(self.chunk_length, self.chunk_length + self.cache_size) indexes = intervals.unsqueeze(1) + first.unsqueeze(0) indexes = torch.cat( [indexes, torch.arange(U_ - self.cache_size, U_).unsqueeze(0)] @@ -284,9 +279,7 @@ class ConvolutionModule(nn.Module): # (num_chunks * B, cache_size + right_context_length, D) return pad_right_context.permute(0, 2, 1) - def _merge_right_context( - self, right_context: torch.Tensor, B: int - ) -> torch.Tensor: + def _merge_right_context(self, right_context: torch.Tensor, B: int) -> torch.Tensor: """ Args: right_context: @@ -337,12 +330,8 @@ class ConvolutionModule(nn.Module): right_context = x[:, :, :R] # (B, D, R) # make causal convolution - cache = torch.zeros( - B, D, self.cache_size, device=x.device, dtype=x.dtype - ) - pad_utterance = torch.cat( - [cache, utterance], dim=2 - ) # (B, D, cache + U) + cache = torch.zeros(B, D, self.cache_size, device=x.device, dtype=x.dtype) + pad_utterance = torch.cat([cache, utterance], dim=2) # (B, D, cache + U) # depth-wise conv on utterance utterance = self.depthwise_conv(pad_utterance) # (B, D, U) @@ -355,9 +344,7 @@ class ConvolutionModule(nn.Module): right_context = self.depthwise_conv( pad_right_context ) # (num_segs * B, D, right_context_length) - right_context = self._merge_right_context( - right_context, B - ) # (B, D, R) + right_context = self._merge_right_context(right_context, B) # (B, D, R) x = torch.cat([right_context, utterance], dim=2) # (B, D, R + U) x = self.deriv_balancer2(x) @@ -458,8 +445,7 @@ class EmformerAttention(nn.Module): if embed_dim % nhead != 0: raise ValueError( - f"embed_dim ({embed_dim}) is not a multiple of" - f"nhead ({nhead})." + f"embed_dim ({embed_dim}) is not a multiple ofnhead ({nhead})." ) self.embed_dim = embed_dim @@ -469,9 +455,7 @@ class EmformerAttention(nn.Module): self.head_dim = embed_dim // nhead self.dropout = dropout - self.emb_to_key_value = ScaledLinear( - embed_dim, 2 * embed_dim, bias=True - ) + self.emb_to_key_value = ScaledLinear(embed_dim, 2 * embed_dim, bias=True) self.emb_to_query = ScaledLinear(embed_dim, embed_dim, bias=True) self.out_proj = ScaledLinear( embed_dim, embed_dim, bias=True, initial_scale=0.25 @@ -513,9 +497,7 @@ class EmformerAttention(nn.Module): if padding_mask is not None: Q = attention_weights.size(1) B = attention_weights.size(0) // self.nhead - attention_weights_float = attention_weights_float.view( - B, self.nhead, Q, -1 - ) + attention_weights_float = attention_weights_float.view(B, self.nhead, Q, -1) attention_weights_float = attention_weights_float.masked_fill( padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), self.negative_inf, @@ -561,16 +543,12 @@ class EmformerAttention(nn.Module): # [memory, right context, left context, uttrance] # this is used in inference mode key = torch.cat([key[: M + R], left_context_key, key[M + R :]]) - value = torch.cat( - [value[: M + R], left_context_val, value[M + R :]] - ) + value = torch.cat([value[: M + R], left_context_val, value[M + R :]]) Q = query.size(0) # KV = key.size(0) reshaped_query, reshaped_key, reshaped_value = [ - tensor.contiguous() - .view(-1, B * self.nhead, self.head_dim) - .transpose(0, 1) + tensor.contiguous().view(-1, B * self.nhead, self.head_dim).transpose(0, 1) for tensor in [query, key, value] ] # (B * nhead, Q or KV, head_dim) attention_weights = torch.bmm( @@ -585,9 +563,7 @@ class EmformerAttention(nn.Module): # compute attention outputs attention = torch.bmm(attention_probs, reshaped_value) assert attention.shape == (B * self.nhead, Q, self.head_dim) - attention = ( - attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim) - ) + attention = attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim) # apply output projection output_right_context_utterance = self.out_proj(attention) @@ -905,13 +881,11 @@ class EmformerEncoderLayer(nn.Module): right_context = right_context_utterance[:R] if self.use_memory: - memory = self.summary_op(utterance.permute(1, 2, 0)).permute( - 2, 0, 1 - )[:-1, :, :] + memory = self.summary_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[ + :-1, :, : + ] else: - memory = torch.empty(0).to( - dtype=utterance.dtype, device=utterance.device - ) + memory = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device) output_right_context_utterance = self.attention( utterance=utterance, right_context=right_context, @@ -948,18 +922,12 @@ class EmformerEncoderLayer(nn.Module): left_context_val = attn_cache[2] if self.use_memory: - memory = self.summary_op(utterance.permute(1, 2, 0)).permute( - 2, 0, 1 - )[:1, :, :] + memory = self.summary_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[ + :1, :, : + ] else: - memory = torch.empty(0).to( - dtype=utterance.dtype, device=utterance.device - ) - ( - output_right_context_utterance, - next_key, - next_val, - ) = self.attention.infer( + memory = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device) + (output_right_context_utterance, next_key, next_val,) = self.attention.infer( utterance=utterance, right_context=right_context, memory=pre_memory, @@ -967,9 +935,7 @@ class EmformerEncoderLayer(nn.Module): left_context_val=left_context_val, padding_mask=padding_mask, ) - attn_cache = self._update_attn_cache( - next_key, next_val, memory, attn_cache - ) + attn_cache = self._update_attn_cache(next_key, next_val, memory, attn_cache) return output_right_context_utterance, attn_cache def forward( @@ -1226,9 +1192,7 @@ class EmformerEncoder(nn.Module): def _gen_right_context(self, x: torch.Tensor) -> torch.Tensor: """Hard copy each chunk's right context and concat them.""" T = x.shape[0] - num_chunks = math.ceil( - (T - self.right_context_length) / self.chunk_length - ) + num_chunks = math.ceil((T - self.right_context_length) / self.chunk_length) # first (num_chunks - 1) right context block intervals = torch.arange( 0, self.chunk_length * (num_chunks - 1), self.chunk_length @@ -1247,9 +1211,7 @@ class EmformerEncoder(nn.Module): right_context_blocks = x[indexes.reshape(-1)] return right_context_blocks - def _gen_attention_mask_col_widths( - self, chunk_idx: int, U: int - ) -> List[int]: + def _gen_attention_mask_col_widths(self, chunk_idx: int, U: int) -> List[int]: """Calculate column widths (key, value) in attention mask for the chunk_idx chunk.""" num_chunks = math.ceil(U / self.chunk_length) @@ -1549,12 +1511,8 @@ class EmformerEncoder(nn.Module): attn_caches = [ [ torch.zeros(self.memory_size, self.d_model, device=device), - torch.zeros( - self.left_context_length, self.d_model, device=device - ), - torch.zeros( - self.left_context_length, self.d_model, device=device - ), + torch.zeros(self.left_context_length, self.d_model, device=device), + torch.zeros(self.left_context_length, self.d_model, device=device), ] for _ in range(self.num_encoder_layers) ] @@ -1599,17 +1557,11 @@ class Emformer(EncoderInterface): raise NotImplementedError( "chunk_length must be a mutiple of subsampling_factor." ) - if ( - left_context_length != 0 - and left_context_length % subsampling_factor != 0 - ): + if left_context_length != 0 and left_context_length % subsampling_factor != 0: raise NotImplementedError( "left_context_length must be 0 or a mutiple of subsampling_factor." # noqa ) - if ( - right_context_length != 0 - and right_context_length % subsampling_factor != 0 - ): + if right_context_length != 0 and right_context_length % subsampling_factor != 0: raise NotImplementedError( "right_context_length must be 0 or a mutiple of subsampling_factor." # noqa ) @@ -1672,9 +1624,7 @@ class Emformer(EncoderInterface): x_lens = (((x_lens - 1) >> 1) - 1) >> 1 assert x.size(0) == x_lens.max().item() - output, output_lengths = self.encoder( - x, x_lens, warmup=warmup - ) # (T, N, C) + output, output_lengths = self.encoder(x, x_lens, warmup=warmup) # (T, N, C) output = output.permute(1, 0, 2) # (T, N, C) -> (N, T, C) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py index ab15e0241..4d05b367c 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py @@ -103,9 +103,11 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( @@ -136,19 +138,20 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) add_model_arguments(parser) @@ -181,13 +184,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -210,13 +212,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -244,7 +245,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -279,9 +280,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py index 71150392d..0486ac2eb 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py @@ -113,8 +113,9 @@ def get_parser(): "--epoch", type=int, default=28, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( @@ -131,20 +132,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=False, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -211,8 +216,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -371,9 +375,7 @@ def modified_beam_search( index=hyps_shape.row_ids(1).to(torch.int64), ) # (num_hyps, encoder_out_dim) - logits = model.joiner( - current_encoder_out, decoder_out, project_input=False - ) + logits = model.joiner(current_encoder_out, decoder_out, project_input=False) # logits is of shape (num_hyps, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) @@ -390,9 +392,7 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -551,14 +551,10 @@ def decode_one_chunk( feature_list, batch_first=True, padding_value=LOG_EPSILON ).to(device) feature_lens = torch.tensor(feature_len_list, device=device) - num_processed_frames = torch.tensor( - num_processed_frames_list, device=device - ) + num_processed_frames = torch.tensor(num_processed_frames_list, device=device) # Make sure it has at least 1 frame after subsampling, first-and-last-frame cutting, and right context cutting # noqa - tail_length = ( - 3 * params.subsampling_factor + params.right_context_length + 3 - ) + tail_length = 3 * params.subsampling_factor + params.right_context_length + 3 if features.size(1) < tail_length: pad_length = tail_length - features.size(1) feature_lens += pad_length @@ -605,9 +601,7 @@ def decode_one_chunk( max_states=params.max_states, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") # Update cached states of each stream state_list = unstack_states(states) @@ -782,8 +776,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -831,9 +824,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -867,13 +858,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -896,13 +886,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -930,7 +919,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py index 2bbc45d78..2c2593b56 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py @@ -95,9 +95,7 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -265,42 +263,45 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." + ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -636,11 +637,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -668,23 +665,16 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -871,9 +861,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -981,7 +969,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/local/add_alignment_librispeech.py b/egs/librispeech/ASR/local/add_alignment_librispeech.py index fe6a26c51..cc34a72d8 100755 --- a/egs/librispeech/ASR/local/add_alignment_librispeech.py +++ b/egs/librispeech/ASR/local/add_alignment_librispeech.py @@ -157,9 +157,7 @@ def add_alignment( for ali_path in part_ali_dir.rglob("*.alignment.txt"): ali = parse_alignments(ali_path) alignments.update(ali) - logging.info( - f"{part} has {len(alignments.keys())} cuts with alignments." - ) + logging.info(f"{part} has {len(alignments.keys())} cuts with alignments.") # add alignment attribute and write out cuts_in = load_manifest_lazy(cuts_in_path) @@ -170,18 +168,14 @@ def add_alignment( if origin_id in alignments: ali = alignments[origin_id] else: - logging.info( - f"Warning: {origin_id} does not have alignment." - ) + logging.info(f"Warning: {origin_id} does not have alignment.") ali = [] subcut.alignment = {"word": ali} writer.write(cut, flush=True) def main(): - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) parser = get_parser() diff --git a/egs/librispeech/ASR/local/compile_hlg.py b/egs/librispeech/ASR/local/compile_hlg.py index 9a35750e0..295156ed5 100755 --- a/egs/librispeech/ASR/local/compile_hlg.py +++ b/egs/librispeech/ASR/local/compile_hlg.py @@ -150,9 +150,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/compile_lg.py b/egs/librispeech/ASR/local/compile_lg.py index 45c4b7f5f..19bf3bff4 100755 --- a/egs/librispeech/ASR/local/compile_lg.py +++ b/egs/librispeech/ASR/local/compile_lg.py @@ -132,9 +132,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/compute_fbank_gigaspeech_dev_test.py b/egs/librispeech/ASR/local/compute_fbank_gigaspeech_dev_test.py index c0c7ef8c5..97750f3ea 100644 --- a/egs/librispeech/ASR/local/compute_fbank_gigaspeech_dev_test.py +++ b/egs/librispeech/ASR/local/compute_fbank_gigaspeech_dev_test.py @@ -80,9 +80,7 @@ def compute_fbank_gigaspeech_dev_test(): def main(): - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) compute_fbank_gigaspeech_dev_test() diff --git a/egs/librispeech/ASR/local/compute_fbank_gigaspeech_splits.py b/egs/librispeech/ASR/local/compute_fbank_gigaspeech_splits.py index 5587106e5..37fce11f4 100644 --- a/egs/librispeech/ASR/local/compute_fbank_gigaspeech_splits.py +++ b/egs/librispeech/ASR/local/compute_fbank_gigaspeech_splits.py @@ -48,8 +48,10 @@ def get_parser(): "--batch-duration", type=float, default=600.0, - help="The maximum number of audio seconds in a batch." - "Determines batch size dynamically.", + help=( + "The maximum number of audio seconds in a batch." + "Determines batch size dynamically." + ), ) parser.add_argument( @@ -144,9 +146,7 @@ def main(): date_time = now.strftime("%Y-%m-%d-%H-%M-%S") log_filename = "log-compute_fbank_gigaspeech_splits" - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" log_filename = f"{log_filename}-{date_time}" logging.basicConfig( diff --git a/egs/librispeech/ASR/local/compute_fbank_librispeech.py b/egs/librispeech/ASR/local/compute_fbank_librispeech.py index ce7d087f0..9f8503814 100755 --- a/egs/librispeech/ASR/local/compute_fbank_librispeech.py +++ b/egs/librispeech/ASR/local/compute_fbank_librispeech.py @@ -112,9 +112,7 @@ def compute_fbank_librispeech(bpe_model: Optional[str] = None): if "train" in partition: cut_set = ( - cut_set - + cut_set.perturb_speed(0.9) - + cut_set.perturb_speed(1.1) + cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) ) cut_set = cut_set.compute_and_store_features( extractor=extractor, @@ -128,9 +126,7 @@ def compute_fbank_librispeech(bpe_model: Optional[str] = None): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) args = get_args() diff --git a/egs/librispeech/ASR/local/compute_fbank_musan.py b/egs/librispeech/ASR/local/compute_fbank_musan.py index 056da29e5..4a4093ae4 100755 --- a/egs/librispeech/ASR/local/compute_fbank_musan.py +++ b/egs/librispeech/ASR/local/compute_fbank_musan.py @@ -83,9 +83,7 @@ def compute_fbank_musan(): # create chunks of Musan with duration 5 - 10 seconds musan_cuts = ( CutSet.from_manifests( - recordings=combine( - part["recordings"] for part in manifests.values() - ) + recordings=combine(part["recordings"] for part in manifests.values()) ) .cut_into_windows(10.0) .filter(lambda c: c.duration > 5) @@ -101,9 +99,7 @@ def compute_fbank_musan(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) compute_fbank_musan() diff --git a/egs/librispeech/ASR/local/convert_transcript_words_to_tokens.py b/egs/librispeech/ASR/local/convert_transcript_words_to_tokens.py index 133499c8b..f149b7871 100755 --- a/egs/librispeech/ASR/local/convert_transcript_words_to_tokens.py +++ b/egs/librispeech/ASR/local/convert_transcript_words_to_tokens.py @@ -46,21 +46,19 @@ def get_args(): parser.add_argument( "--transcript", type=str, - help="The input transcript file." - "We assume that the transcript file consists of " - "lines. Each line consists of space separated words.", + help=( + "The input transcript file." + "We assume that the transcript file consists of " + "lines. Each line consists of space separated words." + ), ) parser.add_argument("--lexicon", type=str, help="The input lexicon file.") - parser.add_argument( - "--oov", type=str, default="", help="The OOV word." - ) + parser.add_argument("--oov", type=str, default="", help="The OOV word.") return parser.parse_args() -def process_line( - lexicon: Dict[str, List[str]], line: str, oov_token: str -) -> None: +def process_line(lexicon: Dict[str, List[str]], line: str, oov_token: str) -> None: """ Args: lexicon: diff --git a/egs/librispeech/ASR/local/download_lm.py b/egs/librispeech/ASR/local/download_lm.py index 030122aa7..3518db524 100755 --- a/egs/librispeech/ASR/local/download_lm.py +++ b/egs/librispeech/ASR/local/download_lm.py @@ -87,9 +87,7 @@ def main(out_dir: str): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/filter_cuts.py b/egs/librispeech/ASR/local/filter_cuts.py index dff98a954..fbcc9e24a 100644 --- a/egs/librispeech/ASR/local/filter_cuts.py +++ b/egs/librispeech/ASR/local/filter_cuts.py @@ -79,8 +79,7 @@ def filter_cuts(cut_set: CutSet, sp: spm.SentencePieceProcessor): total += 1 if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" ) removed += 1 return False @@ -125,8 +124,7 @@ def filter_cuts(cut_set: CutSet, sp: spm.SentencePieceProcessor): ans = cut_set.filter(remove_short_and_long_utterances).to_eager() ratio = removed / total * 100 logging.info( - f"Removed {removed} cuts from {total} cuts. " - f"{ratio:.3f}% data is removed." + f"Removed {removed} cuts from {total} cuts. {ratio:.3f}% data is removed." ) return ans @@ -155,9 +153,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/generate_unique_lexicon.py b/egs/librispeech/ASR/local/generate_unique_lexicon.py index 566c0743d..3459c2f5a 100755 --- a/egs/librispeech/ASR/local/generate_unique_lexicon.py +++ b/egs/librispeech/ASR/local/generate_unique_lexicon.py @@ -91,9 +91,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/prepare_lang_bpe.py b/egs/librispeech/ASR/local/prepare_lang_bpe.py index dec8a7442..e121aefa9 100755 --- a/egs/librispeech/ASR/local/prepare_lang_bpe.py +++ b/egs/librispeech/ASR/local/prepare_lang_bpe.py @@ -150,9 +150,7 @@ def generate_lexicon( words_pieces_ids: List[List[int]] = sp.encode(words, out_type=int) # Now convert word piece IDs back to word piece strings. - words_pieces: List[List[str]] = [ - sp.id_to_piece(ids) for ids in words_pieces_ids - ] + words_pieces: List[List[str]] = [sp.id_to_piece(ids) for ids in words_pieces_ids] lexicon = [] for word, pieces in zip(words, words_pieces): diff --git a/egs/librispeech/ASR/local/prepare_lm_training_data.py b/egs/librispeech/ASR/local/prepare_lm_training_data.py index 5070341f1..70343fef7 100755 --- a/egs/librispeech/ASR/local/prepare_lm_training_data.py +++ b/egs/librispeech/ASR/local/prepare_lm_training_data.py @@ -137,8 +137,7 @@ def main(): for i in range(num_sentences): if step and i % step == 0: logging.info( - f"Processed number of lines: {i} " - f"({i/num_sentences*100: .3f}%)" + f"Processed number of lines: {i} ({i/num_sentences*100: .3f}%)" ) word_ids = sentences[i] @@ -154,18 +153,14 @@ def main(): sentence_lengths[i] = token_ids.numel() - output["sentence_lengths"] = torch.tensor( - sentence_lengths, dtype=torch.int32 - ) + output["sentence_lengths"] = torch.tensor(sentence_lengths, dtype=torch.int32) torch.save(output, args.lm_archive) logging.info(f"Saved to {args.lm_archive}") if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/preprocess_gigaspeech.py b/egs/librispeech/ASR/local/preprocess_gigaspeech.py index 077f23039..8aa5e461d 100644 --- a/egs/librispeech/ASR/local/preprocess_gigaspeech.py +++ b/egs/librispeech/ASR/local/preprocess_gigaspeech.py @@ -119,9 +119,7 @@ def preprocess_giga_speech(): def main(): - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) preprocess_giga_speech() diff --git a/egs/librispeech/ASR/local/test_prepare_lang.py b/egs/librispeech/ASR/local/test_prepare_lang.py index d4cf62bba..74e025ad7 100755 --- a/egs/librispeech/ASR/local/test_prepare_lang.py +++ b/egs/librispeech/ASR/local/test_prepare_lang.py @@ -88,9 +88,7 @@ def test_read_lexicon(filename: str): fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa.draw("L.pdf", title="L") - fsa_disambig = lexicon_to_fst( - lexicon_disambig, phone2id=phone2id, word2id=word2id - ) + fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id) fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt") fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa_disambig.draw("L_disambig.pdf", title="L_disambig") diff --git a/egs/librispeech/ASR/local/validate_manifest.py b/egs/librispeech/ASR/local/validate_manifest.py index 7c57d629a..807aaf891 100755 --- a/egs/librispeech/ASR/local/validate_manifest.py +++ b/egs/librispeech/ASR/local/validate_manifest.py @@ -64,8 +64,7 @@ def validate_supervision_and_cut_time_bounds(c: Cut): if s.end > c.end: raise ValueError( - f"{c.id}: Supervision end time {s.end} is larger " - f"than cut end time {c.end}" + f"{c.id}: Supervision end time {s.end} is larger than cut end time {c.end}" ) @@ -85,9 +84,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless/decode.py old mode 100755 new mode 100644 index 27414d717..e69de29bb --- a/egs/librispeech/ASR/lstm_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/decode.py @@ -1,818 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: -(1) greedy search -./lstm_transducer_stateless/decode.py \ - --epoch 35 \ - --avg 15 \ - --exp-dir ./lstm_transducer_stateless/exp \ - --max-duration 600 \ - --decoding-method greedy_search - -(2) beam search (not recommended) -./lstm_transducer_stateless/decode.py \ - --epoch 35 \ - --avg 15 \ - --exp-dir ./lstm_transducer_stateless/exp \ - --max-duration 600 \ - --decoding-method beam_search \ - --beam-size 4 - -(3) modified beam search -./lstm_transducer_stateless/decode.py \ - --epoch 35 \ - --avg 15 \ - --exp-dir ./lstm_transducer_stateless/exp \ - --max-duration 600 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -(4) fast beam search (one best) -./lstm_transducer_stateless/decode.py \ - --epoch 35 \ - --avg 15 \ - --exp-dir ./lstm_transducer_stateless/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 - -(5) fast beam search (nbest) -./lstm_transducer_stateless/decode.py \ - --epoch 30 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless3/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(6) fast beam search (nbest oracle WER) -./lstm_transducer_stateless/decode.py \ - --epoch 35 \ - --avg 15 \ - --exp-dir ./lstm_transducer_stateless/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_oracle \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(7) fast beam search (with LG) -./lstm_transducer_stateless/decode.py \ - --epoch 35 \ - --avg 15 \ - --exp-dir ./lstm_transducer_stateless/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_LG \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 -""" - - -import argparse -import logging -import math -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import LibriSpeechAsrDataModule -from beam_search import ( - beam_search, - fast_beam_search_nbest, - fast_beam_search_nbest_LG, - fast_beam_search_nbest_oracle, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="lstm_transducer_stateless/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default="data/lang_bpe_500", - help="The lang dir containing word table and LG graph", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - - fast_beam_search_nbest - - fast_beam_search_nbest_oracle - - fast_beam_search_nbest_LG - If you use fast_beam_search_nbest_LG, you have to specify - `--lang-dir`, which should contain `LG.pt`. - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=20.0, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search, - fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle - """, - ) - - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=0.01, - help=""" - Used only when --decoding_method is fast_beam_search_nbest_LG. - It specifies the scale for n-gram LM scores. - """, - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=64, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", - ) - - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - parser.add_argument( - "--num-paths", - type=int, - default=200, - help="""Number of paths for nbest decoding. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help="""Scale applied to lattice scores when computing nbest paths. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - add_model_arguments(parser) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - batch: dict, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or LG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return the decoding result. See above description for the format of - the returned dict. - """ - device = next(model.parameters()).device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - # tail padding here to alleviate the tail deletion problem - num_tail_padded_frames = 35 - feature = torch.nn.functional.pad( - feature, - (0, 0, 0, num_tail_padded_frames), - mode="constant", - value=LOG_EPS, - ) - feature_lens += num_tail_padded_frames - - encoder_out, encoder_out_lens, _ = model.encoder( - x=feature, x_lens=feature_lens - ) - - hyps = [] - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "fast_beam_search_nbest_LG": - hyp_tokens = fast_beam_search_nbest_LG( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for hyp in hyp_tokens: - hyps.append([word_table[i] for i in hyp]) - elif params.decoding_method == "fast_beam_search_nbest": - hyp_tokens = fast_beam_search_nbest( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "fast_beam_search_nbest_oracle": - hyp_tokens = fast_beam_search_nbest_oracle( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - ref_texts=sp.encode(supervisions["text"]), - nbest_scale=params.nbest_scale, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append(sp.decode(hyp).split()) - - if params.decoding_method == "greedy_search": - return {"greedy_search": hyps} - elif "fast_beam_search" in params.decoding_method: - key = f"beam_{params.beam}_" - key += f"max_contexts_{params.max_contexts}_" - key += f"max_states_{params.max_states}" - if "nbest" in params.decoding_method: - key += f"_num_paths_{params.num_paths}_" - key += f"nbest_scale_{params.nbest_scale}" - if "LG" in params.decoding_method: - key += f"_ngram_lm_scale_{params.ngram_lm_scale}" - - return {key: hyps} - else: - return {f"beam_size_{params.beam_size}": hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or LG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - if params.decoding_method == "greedy_search": - log_interval = 50 - else: - log_interval = 20 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - word_table=word_table, - batch=batch, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - LibriSpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "fast_beam_search", - "fast_beam_search_nbest", - "fast_beam_search_nbest_LG", - "fast_beam_search_nbest_oracle", - "modified_beam_search", - ) - params.res_dir = params.exp_dir / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - if "nbest" in params.decoding_method: - params.suffix += f"-nbest-scale-{params.nbest_scale}" - params.suffix += f"-num-paths-{params.num_paths}" - if "LG" in params.decoding_method: - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" - elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # and are defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - - if "fast_beam_search" in params.decoding_method: - if params.decoding_method == "fast_beam_search_nbest_LG": - lexicon = Lexicon(params.lang_dir) - word_table = lexicon.word_table - lg_filename = params.lang_dir / "LG.pt" - logging.info(f"Loading {lg_filename}") - decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) - ) - decoding_graph.scores *= params.ngram_lm_scale - else: - word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) - else: - decoding_graph = None - word_table = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. - args.return_cuts = True - librispeech = LibriSpeechAsrDataModule(args) - - test_clean_cuts = librispeech.test_clean_cuts() - test_other_cuts = librispeech.test_other_cuts() - - test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) - test_other_dl = librispeech.test_dataloaders(test_other_cuts) - - test_sets = ["test-clean", "test-other"] - test_dl = [test_clean_dl, test_other_dl] - - for test_set, test_dl in zip(test_sets, test_dl): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - sp=sp, - word_table=word_table, - decoding_graph=decoding_graph, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/export.py b/egs/librispeech/ASR/lstm_transducer_stateless/export.py old mode 100755 new mode 100644 index 13dac6009..e69de29bb --- a/egs/librispeech/ASR/lstm_transducer_stateless/export.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/export.py @@ -1,388 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script converts several saved checkpoints -# to a single one using model averaging. -""" - -Usage: - -(1) Export to torchscript model using torch.jit.trace() - -./lstm_transducer_stateless/export.py \ - --exp-dir ./lstm_transducer_stateless/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --epoch 35 \ - --avg 10 \ - --jit-trace 1 - -It will generate 3 files: `encoder_jit_trace.pt`, -`decoder_jit_trace.pt`, and `joiner_jit_trace.pt`. - -(2) Export `model.state_dict()` - -./lstm_transducer_stateless/export.py \ - --exp-dir ./lstm_transducer_stateless/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --epoch 35 \ - --avg 10 - -It will generate a file `pretrained.pt` in the given `exp_dir`. You can later -load it by `icefall.checkpoint.load_checkpoint()`. - -To use the generated file with `lstm_transducer_stateless/decode.py`, -you can do: - - cd /path/to/exp_dir - ln -s pretrained.pt epoch-9999.pt - - cd /path/to/egs/librispeech/ASR - ./lstm_transducer_stateless/decode.py \ - --exp-dir ./lstm_transducer_stateless/exp \ - --epoch 9999 \ - --avg 1 \ - --max-duration 600 \ - --decoding-method greedy_search \ - --bpe-model data/lang_bpe_500/bpe.model - -Check ./pretrained.py for its usage. - -Note: If you don't want to train a model from scratch, we have -provided one for you. You can get it at - -https://huggingface.co/Zengwei/icefall-asr-librispeech-lstm-transducer-stateless-2022-08-18 - -with the following commands: - - sudo apt-get install git-lfs - git lfs install - git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-lstm-transducer-stateless-2022-08-18 - # You will find the pre-trained model in icefall-asr-librispeech-lstm-transducer-stateless-2022-08-18/exp -""" - -import argparse -import logging -from pathlib import Path - -import sentencepiece as spm -import torch -import torch.nn as nn -from scaling_converter import convert_scaled_to_non_scaled -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="""It specifies the checkpoint to use for averaging. - Note: Epoch counts from 0. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless3/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--jit-trace", - type=str2bool, - default=False, - help="""True to save a model after applying torch.jit.trace. - It will generate 3 files: - - encoder_jit_trace.pt - - decoder_jit_trace.pt - - joiner_jit_trace.pt - - Check ./jit_pretrained.py for how to use them. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", - ) - - add_model_arguments(parser) - - return parser - - -def export_encoder_model_jit_trace( - encoder_model: nn.Module, - encoder_filename: str, -) -> None: - """Export the given encoder model with torch.jit.trace() - - Note: The warmup argument is fixed to 1. - - Args: - encoder_model: - The input encoder model - encoder_filename: - The filename to save the exported model. - """ - x = torch.zeros(1, 100, 80, dtype=torch.float32) - x_lens = torch.tensor([100], dtype=torch.int64) - states = encoder_model.get_init_states() - - traced_model = torch.jit.trace(encoder_model, (x, x_lens, states)) - traced_model.save(encoder_filename) - logging.info(f"Saved to {encoder_filename}") - - -def export_decoder_model_jit_trace( - decoder_model: nn.Module, - decoder_filename: str, -) -> None: - """Export the given decoder model with torch.jit.trace() - - Note: The argument need_pad is fixed to False. - - Args: - decoder_model: - The input decoder model - decoder_filename: - The filename to save the exported model. - """ - y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) - need_pad = torch.tensor([False]) - - traced_model = torch.jit.trace(decoder_model, (y, need_pad)) - traced_model.save(decoder_filename) - logging.info(f"Saved to {decoder_filename}") - - -def export_joiner_model_jit_trace( - joiner_model: nn.Module, - joiner_filename: str, -) -> None: - """Export the given joiner model with torch.jit.trace() - - Note: The argument project_input is fixed to True. A user should not - project the encoder_out/decoder_out by himself/herself. The exported joiner - will do that for the user. - - Args: - joiner_model: - The input joiner model - joiner_filename: - The filename to save the exported model. - - """ - encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] - decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] - encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) - decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) - - traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out)) - traced_model.save(joiner_filename) - logging.info(f"Saved to {joiner_filename}") - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to("cpu") - model.eval() - - if params.jit_trace is True: - convert_scaled_to_non_scaled(model, inplace=True) - logging.info("Using torch.jit.trace()") - encoder_filename = params.exp_dir / "encoder_jit_trace.pt" - export_encoder_model_jit_trace(model.encoder, encoder_filename) - - decoder_filename = params.exp_dir / "decoder_jit_trace.pt" - export_decoder_model_jit_trace(model.decoder, decoder_filename) - - joiner_filename = params.exp_dir / "joiner_jit_trace.pt" - export_joiner_model_jit_trace(model.joiner, joiner_filename) - else: - logging.info("Not using torchscript") - # Save it using a format so that it can be loaded - # by :func:`load_checkpoint` - filename = params.exp_dir / "pretrained.pt" - torch.save({"model": model.state_dict()}, str(filename)) - logging.info(f"Saved to {filename}") - - -if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py old mode 100755 new mode 100644 index 594c33e4f..e69de29bb --- a/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py @@ -1,322 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads torchscript models, either exported by `torch.jit.trace()` -or by `torch.jit.script()`, and uses them to decode waves. -You can use the following command to get the exported models: - -./lstm_transducer_stateless/export.py \ - --exp-dir ./lstm_transducer_stateless/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --epoch 20 \ - --avg 10 \ - --jit-trace 1 - -Usage of this script: - -./lstm_transducer_stateless/jit_pretrained.py \ - --encoder-model-filename ./lstm_transducer_stateless/exp/encoder_jit_trace.pt \ - --decoder-model-filename ./lstm_transducer_stateless/exp/decoder_jit_trace.pt \ - --joiner-model-filename ./lstm_transducer_stateless/exp/joiner_jit_trace.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - /path/to/foo.wav \ - /path/to/bar.wav -""" - -import argparse -import logging -import math -from typing import List - -import kaldifeat -import sentencepiece as spm -import torch -import torchaudio -from torch.nn.utils.rnn import pad_sequence - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--encoder-model-filename", - type=str, - required=True, - help="Path to the encoder torchscript model. ", - ) - - parser.add_argument( - "--decoder-model-filename", - type=str, - required=True, - help="Path to the decoder torchscript model. ", - ) - - parser.add_argument( - "--joiner-model-filename", - type=str, - required=True, - help="Path to the joiner torchscript model. ", - ) - - parser.add_argument( - "--bpe-model", - type=str, - help="""Path to bpe.model.""", - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="Context size of the decoder model", - ) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) - # We use only the first channel - ans.append(wave[0]) - return ans - - -def greedy_search( - decoder: torch.jit.ScriptModule, - joiner: torch.jit.ScriptModule, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - context_size: int, -) -> List[List[int]]: - """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. - Args: - decoder: - The decoder model. - joiner: - The joiner model. - encoder_out: - A 3-D tensor of shape (N, T, C) - encoder_out_lens: - A 1-D tensor of shape (N,). - context_size: - The context size of the decoder model. - Returns: - Return the decoded results for each utterance. - """ - assert encoder_out.ndim == 3 - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - device = encoder_out.device - blank_id = 0 # hard-code to 0 - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - hyps = [[blank_id] * context_size for _ in range(N)] - - decoder_input = torch.tensor( - hyps, - device=device, - dtype=torch.int64, - ) # (N, context_size) - - decoder_out = decoder( - decoder_input, - need_pad=torch.tensor([False]), - ).squeeze(1) - - offset = 0 - for batch_size in batch_size_list: - start = offset - end = offset + batch_size - current_encoder_out = packed_encoder_out.data[start:end] - current_encoder_out = current_encoder_out - # current_encoder_out's shape: (batch_size, encoder_out_dim) - offset = end - - decoder_out = decoder_out[:batch_size] - - logits = joiner( - current_encoder_out, - decoder_out, - ) - # logits'shape (batch_size, vocab_size) - - assert logits.ndim == 2, logits.shape - y = logits.argmax(dim=1).tolist() - emitted = False - for i, v in enumerate(y): - if v != blank_id: - hyps[i].append(v) - emitted = True - if emitted: - # update decoder output - decoder_input = [h[-context_size:] for h in hyps[:batch_size]] - decoder_input = torch.tensor( - decoder_input, - device=device, - dtype=torch.int64, - ) - decoder_out = decoder( - decoder_input, - need_pad=torch.tensor([False]), - ) - decoder_out = decoder_out.squeeze(1) - - sorted_ans = [h[context_size:] for h in hyps] - ans = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - logging.info(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - encoder = torch.jit.load(args.encoder_model_filename) - decoder = torch.jit.load(args.decoder_model_filename) - joiner = torch.jit.load(args.joiner_model_filename) - - encoder.eval() - decoder.eval() - joiner.eval() - - encoder.to(device) - decoder.to(device) - joiner.to(device) - - sp = spm.SentencePieceProcessor() - sp.load(args.bpe_model) - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = args.sample_rate - opts.mel_opts.num_bins = 80 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {args.sound_files}") - waves = read_sound_files( - filenames=args.sound_files, - expected_sample_rate=args.sample_rate, - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence( - features, - batch_first=True, - padding_value=math.log(1e-10), - ) - - feature_lengths = torch.tensor(feature_lengths, device=device) - - states = encoder.get_init_states(batch_size=features.size(0), device=device) - - encoder_out, encoder_out_lens, _ = encoder( - x=features, - x_lens=feature_lengths, - states=states, - ) - - hyps = greedy_search( - decoder=decoder, - joiner=joiner, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - context_size=args.context_size, - ) - s = "\n" - for filename, hyp in zip(args.sound_files, hyps): - words = sp.decode(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py b/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py index c54a4c478..e69de29bb 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py @@ -1,871 +0,0 @@ -# Copyright 2022 Xiaomi Corp. (authors: Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import copy -import math -from typing import List, Optional, Tuple - -import torch -from encoder_interface import EncoderInterface -from scaling import ( - ActivationBalancer, - BasicNorm, - DoubleSwish, - ScaledConv2d, - ScaledLinear, - ScaledLSTM, -) -from torch import nn - -LOG_EPSILON = math.log(1e-10) - - -def unstack_states( - states: Tuple[torch.Tensor, torch.Tensor] -) -> List[Tuple[torch.Tensor, torch.Tensor]]: - """ - Unstack the lstm states corresponding to a batch of utterances into a list - of states, where the i-th entry is the state from the i-th utterance. - - Args: - states: - A tuple of 2 elements. - ``states[0]`` is the lstm hidden states, of a batch of utterance. - ``states[1]`` is the lstm cell states, of a batch of utterances. - - Returns: - A list of states. - ``states[i]`` is a tuple of 2 elememts of i-th utterance. - ``states[i][0]`` is the lstm hidden states of i-th utterance. - ``states[i][1]`` is the lstm cell states of i-th utterance. - """ - hidden_states, cell_states = states - - list_hidden_states = hidden_states.unbind(dim=1) - list_cell_states = cell_states.unbind(dim=1) - - ans = [ - (h.unsqueeze(1), c.unsqueeze(1)) - for (h, c) in zip(list_hidden_states, list_cell_states) - ] - return ans - - -def stack_states( - states_list: List[Tuple[torch.Tensor, torch.Tensor]] -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Stack list of lstm states corresponding to separate utterances into a single - lstm state so that it can be used as an input for lstm when those utterances - are formed into a batch. - - Args: - state_list: - Each element in state_list corresponds to the lstm state for a single - utterance. - ``states[i]`` is a tuple of 2 elememts of i-th utterance. - ``states[i][0]`` is the lstm hidden states of i-th utterance. - ``states[i][1]`` is the lstm cell states of i-th utterance. - - - Returns: - A new state corresponding to a batch of utterances. - It is a tuple of 2 elements. - ``states[0]`` is the lstm hidden states, of a batch of utterance. - ``states[1]`` is the lstm cell states, of a batch of utterances. - """ - hidden_states = torch.cat([s[0] for s in states_list], dim=1) - cell_states = torch.cat([s[1] for s in states_list], dim=1) - ans = (hidden_states, cell_states) - return ans - - -class RNN(EncoderInterface): - """ - Args: - num_features (int): - Number of input features. - subsampling_factor (int): - Subsampling factor of encoder (convolution layers before lstm layers) (default=4). # noqa - d_model (int): - Output dimension (default=512). - dim_feedforward (int): - Feedforward dimension (default=2048). - rnn_hidden_size (int): - Hidden dimension for lstm layers (default=1024). - num_encoder_layers (int): - Number of encoder layers (default=12). - dropout (float): - Dropout rate (default=0.1). - layer_dropout (float): - Dropout value for model-level warmup (default=0.075). - aux_layer_period (int): - Period of auxiliary layers used for random combiner during training. - If set to 0, will not use the random combiner (Default). - You can set a positive integer to use the random combiner, e.g., 3. - is_pnnx: - True to make this class exportable via PNNX. - """ - - def __init__( - self, - num_features: int, - subsampling_factor: int = 4, - d_model: int = 512, - dim_feedforward: int = 2048, - rnn_hidden_size: int = 1024, - num_encoder_layers: int = 12, - dropout: float = 0.1, - layer_dropout: float = 0.075, - aux_layer_period: int = 0, - is_pnnx: bool = False, - ) -> None: - super(RNN, self).__init__() - - self.num_features = num_features - self.subsampling_factor = subsampling_factor - if subsampling_factor != 4: - raise NotImplementedError("Support only 'subsampling_factor=4'.") - - # self.encoder_embed converts the input of shape (N, T, num_features) - # to the shape (N, T//subsampling_factor, d_model). - # That is, it does two things simultaneously: - # (1) subsampling: T -> T//subsampling_factor - # (2) embedding: num_features -> d_model - self.encoder_embed = Conv2dSubsampling( - num_features, - d_model, - is_pnnx=is_pnnx, - ) - - self.is_pnnx = is_pnnx - - self.num_encoder_layers = num_encoder_layers - self.d_model = d_model - self.rnn_hidden_size = rnn_hidden_size - - encoder_layer = RNNEncoderLayer( - d_model=d_model, - dim_feedforward=dim_feedforward, - rnn_hidden_size=rnn_hidden_size, - dropout=dropout, - layer_dropout=layer_dropout, - ) - self.encoder = RNNEncoder( - encoder_layer, - num_encoder_layers, - aux_layers=list( - range( - num_encoder_layers // 3, - num_encoder_layers - 1, - aux_layer_period, - ) - ) - if aux_layer_period > 0 - else None, - ) - - def forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - warmup: float = 1.0, - ) -> Tuple[torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - """ - Args: - x: - The input tensor. Its shape is (N, T, C), where N is the batch size, - T is the sequence length, C is the feature dimension. - x_lens: - A tensor of shape (N,), containing the number of frames in `x` - before padding. - states: - A tuple of 2 tensors (optional). It is for streaming inference. - states[0] is the hidden states of all layers, - with shape of (num_layers, N, d_model); - states[1] is the cell states of all layers, - with shape of (num_layers, N, rnn_hidden_size). - warmup: - A floating point value that gradually increases from 0 throughout - training; when it is >= 1.0 we are "fully warmed up". It is used - to turn modules on sequentially. - - Returns: - A tuple of 3 tensors: - - embeddings: its shape is (N, T', d_model), where T' is the output - sequence lengths. - - lengths: a tensor of shape (batch_size,) containing the number of - frames in `embeddings` before padding. - - updated states, whose shape is the same as the input states. - """ - x = self.encoder_embed(x) - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - - # lengths = ((x_lens - 3) // 2 - 1) // 2 # issue an warning - # - # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0 - if not self.is_pnnx: - lengths = (((x_lens - 3) >> 1) - 1) >> 1 - else: - lengths1 = torch.floor((x_lens - 3) / 2) - lengths = torch.floor((lengths1 - 1) / 2) - lengths = lengths.to(x_lens) - - if not torch.jit.is_tracing(): - assert x.size(0) == lengths.max().item() - - if states is None: - x = self.encoder(x, warmup=warmup)[0] - # torch.jit.trace requires returned types to be the same as annotated # noqa - new_states = (torch.empty(0), torch.empty(0)) - else: - assert not self.training - assert len(states) == 2 - if not torch.jit.is_tracing(): - # for hidden state - assert states[0].shape == ( - self.num_encoder_layers, - x.size(1), - self.d_model, - ) - # for cell state - assert states[1].shape == ( - self.num_encoder_layers, - x.size(1), - self.rnn_hidden_size, - ) - x, new_states = self.encoder(x, states) - - x = x.permute(1, 0, 2) # (T, N, C) -> (N, T, C) - return x, lengths, new_states - - @torch.jit.export - def get_init_states( - self, batch_size: int = 1, device: torch.device = torch.device("cpu") - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Get model initial states.""" - # for rnn hidden states - hidden_states = torch.zeros( - (self.num_encoder_layers, batch_size, self.d_model), device=device - ) - cell_states = torch.zeros( - (self.num_encoder_layers, batch_size, self.rnn_hidden_size), - device=device, - ) - return (hidden_states, cell_states) - - -class RNNEncoderLayer(nn.Module): - """ - RNNEncoderLayer is made up of lstm and feedforward networks. - - Args: - d_model: - The number of expected features in the input (required). - dim_feedforward: - The dimension of feedforward network model (default=2048). - rnn_hidden_size: - The hidden dimension of rnn layer. - dropout: - The dropout value (default=0.1). - layer_dropout: - The dropout value for model-level warmup (default=0.075). - """ - - def __init__( - self, - d_model: int, - dim_feedforward: int, - rnn_hidden_size: int, - dropout: float = 0.1, - layer_dropout: float = 0.075, - ) -> None: - super(RNNEncoderLayer, self).__init__() - self.layer_dropout = layer_dropout - self.d_model = d_model - self.rnn_hidden_size = rnn_hidden_size - - assert rnn_hidden_size >= d_model, (rnn_hidden_size, d_model) - self.lstm = ScaledLSTM( - input_size=d_model, - hidden_size=rnn_hidden_size, - proj_size=d_model if rnn_hidden_size > d_model else 0, - num_layers=1, - dropout=0.0, - ) - self.feed_forward = nn.Sequential( - ScaledLinear(d_model, dim_feedforward), - ActivationBalancer(channel_dim=-1), - DoubleSwish(), - nn.Dropout(dropout), - ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), - ) - self.norm_final = BasicNorm(d_model) - - # try to ensure the output is close to zero-mean (or at least, zero-median). # noqa - self.balancer = ActivationBalancer( - channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0 - ) - self.dropout = nn.Dropout(dropout) - - def forward( - self, - src: torch.Tensor, - states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - warmup: float = 1.0, - ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - """ - Pass the input through the encoder layer. - - Args: - src: - The sequence to the encoder layer (required). - Its shape is (S, N, E), where S is the sequence length, - N is the batch size, and E is the feature number. - states: - A tuple of 2 tensors (optional). It is for streaming inference. - states[0] is the hidden states of all layers, - with shape of (1, N, d_model); - states[1] is the cell states of all layers, - with shape of (1, N, rnn_hidden_size). - warmup: - It controls selective bypass of of layers; if < 1.0, we will - bypass layers more frequently. - """ - src_orig = src - - warmup_scale = min(0.1 + warmup, 1.0) - # alpha = 1.0 means fully use this encoder layer, 0.0 would mean - # completely bypass it. - if self.training: - alpha = ( - warmup_scale - if torch.rand(()).item() <= (1.0 - self.layer_dropout) - else 0.1 - ) - else: - alpha = 1.0 - - # lstm module - if states is None: - src_lstm = self.lstm(src)[0] - # torch.jit.trace requires returned types be the same as annotated - new_states = (torch.empty(0), torch.empty(0)) - else: - assert not self.training - assert len(states) == 2 - if not torch.jit.is_tracing(): - # for hidden state - assert states[0].shape == (1, src.size(1), self.d_model) - # for cell state - assert states[1].shape == (1, src.size(1), self.rnn_hidden_size) - src_lstm, new_states = self.lstm(src, states) - src = self.dropout(src_lstm) + src - - # feed forward module - src = src + self.dropout(self.feed_forward(src)) - - src = self.norm_final(self.balancer(src)) - - if alpha != 1.0: - src = alpha * src + (1 - alpha) * src_orig - - return src, new_states - - -class RNNEncoder(nn.Module): - """ - RNNEncoder is a stack of N encoder layers. - - Args: - encoder_layer: - An instance of the RNNEncoderLayer() class (required). - num_layers: - The number of sub-encoder-layers in the encoder (required). - """ - - def __init__( - self, - encoder_layer: nn.Module, - num_layers: int, - aux_layers: Optional[List[int]] = None, - ) -> None: - super(RNNEncoder, self).__init__() - self.layers = nn.ModuleList( - [copy.deepcopy(encoder_layer) for i in range(num_layers)] - ) - self.num_layers = num_layers - self.d_model = encoder_layer.d_model - self.rnn_hidden_size = encoder_layer.rnn_hidden_size - - self.aux_layers: List[int] = [] - self.combiner: Optional[nn.Module] = None - if aux_layers is not None: - assert len(set(aux_layers)) == len(aux_layers) - assert num_layers - 1 not in aux_layers - self.aux_layers = aux_layers + [num_layers - 1] - self.combiner = RandomCombine( - num_inputs=len(self.aux_layers), - final_weight=0.5, - pure_prob=0.333, - stddev=2.0, - ) - - def forward( - self, - src: torch.Tensor, - states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - warmup: float = 1.0, - ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - """ - Pass the input through the encoder layer in turn. - - Args: - src: - The sequence to the encoder layer (required). - Its shape is (S, N, E), where S is the sequence length, - N is the batch size, and E is the feature number. - states: - A tuple of 2 tensors (optional). It is for streaming inference. - states[0] is the hidden states of all layers, - with shape of (num_layers, N, d_model); - states[1] is the cell states of all layers, - with shape of (num_layers, N, rnn_hidden_size). - warmup: - It controls selective bypass of of layers; if < 1.0, we will - bypass layers more frequently. - """ - if states is not None: - assert not self.training - assert len(states) == 2 - if not torch.jit.is_tracing(): - # for hidden state - assert states[0].shape == ( - self.num_layers, - src.size(1), - self.d_model, - ) - # for cell state - assert states[1].shape == ( - self.num_layers, - src.size(1), - self.rnn_hidden_size, - ) - - output = src - - outputs = [] - - new_hidden_states = [] - new_cell_states = [] - - for i, mod in enumerate(self.layers): - if states is None: - output = mod(output, warmup=warmup)[0] - else: - layer_state = ( - states[0][i : i + 1, :, :], # h: (1, N, d_model) - states[1][i : i + 1, :, :], # c: (1, N, rnn_hidden_size) - ) - output, (h, c) = mod(output, layer_state) - new_hidden_states.append(h) - new_cell_states.append(c) - - if self.combiner is not None and i in self.aux_layers: - outputs.append(output) - - if self.combiner is not None: - output = self.combiner(outputs) - - if states is None: - new_states = (torch.empty(0), torch.empty(0)) - else: - new_states = ( - torch.cat(new_hidden_states, dim=0), - torch.cat(new_cell_states, dim=0), - ) - - return output, new_states - - -class Conv2dSubsampling(nn.Module): - """Convolutional 2D subsampling (to 1/4 length). - - Convert an input of shape (N, T, idim) to an output - with shape (N, T', odim), where - T' = ((T-3)//2-1)//2, which approximates T' == T//4 - - It is based on - https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - layer1_channels: int = 8, - layer2_channels: int = 32, - layer3_channels: int = 128, - is_pnnx: bool = False, - ) -> None: - """ - Args: - in_channels: - Number of channels in. The input shape is (N, T, in_channels). - Caution: It requires: T >= 9, in_channels >= 9. - out_channels - Output dim. The output shape is (N, ((T-3)//2-1)//2, out_channels) - layer1_channels: - Number of channels in layer1 - layer1_channels: - Number of channels in layer2 - is_pnnx: - True if we are converting the model to PNNX format. - False otherwise. - """ - assert in_channels >= 9 - super().__init__() - - self.conv = nn.Sequential( - ScaledConv2d( - in_channels=1, - out_channels=layer1_channels, - kernel_size=3, - padding=0, - ), - ActivationBalancer(channel_dim=1), - DoubleSwish(), - ScaledConv2d( - in_channels=layer1_channels, - out_channels=layer2_channels, - kernel_size=3, - stride=2, - ), - ActivationBalancer(channel_dim=1), - DoubleSwish(), - ScaledConv2d( - in_channels=layer2_channels, - out_channels=layer3_channels, - kernel_size=3, - stride=2, - ), - ActivationBalancer(channel_dim=1), - DoubleSwish(), - ) - self.out = ScaledLinear( - layer3_channels * (((in_channels - 3) // 2 - 1) // 2), out_channels - ) - # set learn_eps=False because out_norm is preceded by `out`, and `out` - # itself has learned scale, so the extra degree of freedom is not - # needed. - self.out_norm = BasicNorm(out_channels, learn_eps=False) - # constrain median of output to be close to zero. - self.out_balancer = ActivationBalancer( - channel_dim=-1, min_positive=0.45, max_positive=0.55 - ) - - # ncnn supports only batch size == 1 - self.is_pnnx = is_pnnx - self.conv_out_dim = self.out.weight.shape[1] - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Subsample x. - - Args: - x: - Its shape is (N, T, idim). - - Returns: - Return a tensor of shape (N, ((T-3)//2-1)//2, odim) - """ - # On entry, x is (N, T, idim) - x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) - x = self.conv(x) - - if torch.jit.is_tracing() and self.is_pnnx: - x = x.permute(0, 2, 1, 3).reshape(1, -1, self.conv_out_dim) - x = self.out(x) - else: - # Now x is of shape (N, odim, ((T-3)//2-1)//2, ((idim-3)//2-1)//2) - b, c, t, f = x.size() - x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) - - # Now x is of shape (N, ((T-3)//2-1))//2, odim) - x = self.out_norm(x) - x = self.out_balancer(x) - return x - - -class RandomCombine(nn.Module): - """ - This module combines a list of Tensors, all with the same shape, to - produce a single output of that same shape which, in training time, - is a random combination of all the inputs; but which in test time - will be just the last input. - - The idea is that the list of Tensors will be a list of outputs of multiple - conformer layers. This has a similar effect as iterated loss. (See: - DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER - NETWORKS). - """ - - def __init__( - self, - num_inputs: int, - final_weight: float = 0.5, - pure_prob: float = 0.5, - stddev: float = 2.0, - ) -> None: - """ - Args: - num_inputs: - The number of tensor inputs, which equals the number of layers' - outputs that are fed into this module. E.g. in an 18-layer neural - net if we output layers 16, 12, 18, num_inputs would be 3. - final_weight: - The amount of weight or probability we assign to the - final layer when randomly choosing layers or when choosing - continuous layer weights. - pure_prob: - The probability, on each frame, with which we choose - only a single layer to output (rather than an interpolation) - stddev: - A standard deviation that we add to log-probs for computing - randomized weights. - - The method of choosing which layers, or combinations of layers, to use, - is conceptually as follows:: - - With probability `pure_prob`:: - With probability `final_weight`: choose final layer, - Else: choose random non-final layer. - Else:: - Choose initial log-weights that correspond to assigning - weight `final_weight` to the final layer and equal - weights to other layers; then add Gaussian noise - with variance `stddev` to these log-weights, and normalize - to weights (note: the average weight assigned to the - final layer here will not be `final_weight` if stddev>0). - """ - super().__init__() - assert 0 <= pure_prob <= 1, pure_prob - assert 0 < final_weight < 1, final_weight - assert num_inputs >= 1 - - self.num_inputs = num_inputs - self.final_weight = final_weight - self.pure_prob = pure_prob - self.stddev = stddev - - self.final_log_weight = ( - torch.tensor( - (final_weight / (1 - final_weight)) * (self.num_inputs - 1) - ) - .log() - .item() - ) - - def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor: - """Forward function. - Args: - inputs: - A list of Tensor, e.g. from various layers of a transformer. - All must be the same shape, of (*, num_channels) - Returns: - A Tensor of shape (*, num_channels). In test mode - this is just the final input. - """ - num_inputs = self.num_inputs - assert len(inputs) == num_inputs - if not self.training or torch.jit.is_scripting(): - return inputs[-1] - - # Shape of weights: (*, num_inputs) - num_channels = inputs[0].shape[-1] - num_frames = inputs[0].numel() // num_channels - - ndim = inputs[0].ndim - # stacked_inputs: (num_frames, num_channels, num_inputs) - stacked_inputs = torch.stack(inputs, dim=ndim).reshape( - (num_frames, num_channels, num_inputs) - ) - - # weights: (num_frames, num_inputs) - weights = self._get_random_weights( - inputs[0].dtype, inputs[0].device, num_frames - ) - - weights = weights.reshape(num_frames, num_inputs, 1) - # ans: (num_frames, num_channels, 1) - ans = torch.matmul(stacked_inputs, weights) - # ans: (*, num_channels) - - ans = ans.reshape(inputs[0].shape[:-1] + (num_channels,)) - - # The following if causes errors for torch script in torch 1.6.0 - # if __name__ == "__main__": - # # for testing only... - # print("Weights = ", weights.reshape(num_frames, num_inputs)) - return ans - - def _get_random_weights( - self, dtype: torch.dtype, device: torch.device, num_frames: int - ) -> torch.Tensor: - """Return a tensor of random weights, of shape - `(num_frames, self.num_inputs)`, - Args: - dtype: - The data-type desired for the answer, e.g. float, double. - device: - The device needed for the answer. - num_frames: - The number of sets of weights desired - Returns: - A tensor of shape (num_frames, self.num_inputs), such that - `ans.sum(dim=1)` is all ones. - """ - pure_prob = self.pure_prob - if pure_prob == 0.0: - return self._get_random_mixed_weights(dtype, device, num_frames) - elif pure_prob == 1.0: - return self._get_random_pure_weights(dtype, device, num_frames) - else: - p = self._get_random_pure_weights(dtype, device, num_frames) - m = self._get_random_mixed_weights(dtype, device, num_frames) - return torch.where( - torch.rand(num_frames, 1, device=device) < self.pure_prob, p, m - ) - - def _get_random_pure_weights( - self, dtype: torch.dtype, device: torch.device, num_frames: int - ): - """Return a tensor of random one-hot weights, of shape - `(num_frames, self.num_inputs)`, - Args: - dtype: - The data-type desired for the answer, e.g. float, double. - device: - The device needed for the answer. - num_frames: - The number of sets of weights desired. - Returns: - A one-hot tensor of shape `(num_frames, self.num_inputs)`, with - exactly one weight equal to 1.0 on each frame. - """ - final_prob = self.final_weight - - # final contains self.num_inputs - 1 in all elements - final = torch.full((num_frames,), self.num_inputs - 1, device=device) - # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights. # noqa - nonfinal = torch.randint( - self.num_inputs - 1, (num_frames,), device=device - ) - - indexes = torch.where( - torch.rand(num_frames, device=device) < final_prob, final, nonfinal - ) - ans = torch.nn.functional.one_hot( - indexes, num_classes=self.num_inputs - ).to(dtype=dtype) - return ans - - def _get_random_mixed_weights( - self, dtype: torch.dtype, device: torch.device, num_frames: int - ): - """Return a tensor of random one-hot weights, of shape - `(num_frames, self.num_inputs)`, - Args: - dtype: - The data-type desired for the answer, e.g. float, double. - device: - The device needed for the answer. - num_frames: - The number of sets of weights desired. - Returns: - A tensor of shape (num_frames, self.num_inputs), which elements - in [0..1] that sum to one over the second axis, i.e. - `ans.sum(dim=1)` is all ones. - """ - logprobs = ( - torch.randn(num_frames, self.num_inputs, dtype=dtype, device=device) - * self.stddev # noqa - ) - logprobs[:, -1] += self.final_log_weight - return logprobs.softmax(dim=1) - - -def _test_random_combine(final_weight: float, pure_prob: float, stddev: float): - print( - f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}, stddev={stddev}" # noqa - ) - num_inputs = 3 - num_channels = 50 - m = RandomCombine( - num_inputs=num_inputs, - final_weight=final_weight, - pure_prob=pure_prob, - stddev=stddev, - ) - - x = [torch.ones(3, 4, num_channels) for _ in range(num_inputs)] - - y = m(x) - assert y.shape == x[0].shape - assert torch.allclose(y, x[0]) # .. since actually all ones. - - -def _test_random_combine_main(): - _test_random_combine(0.999, 0, 0.0) - _test_random_combine(0.5, 0, 0.0) - _test_random_combine(0.999, 0, 0.0) - _test_random_combine(0.5, 0, 0.3) - _test_random_combine(0.5, 1, 0.3) - _test_random_combine(0.5, 0.5, 0.3) - - feature_dim = 50 - c = RNN(num_features=feature_dim, d_model=128) - batch_size = 5 - seq_len = 20 - # Just make sure the forward pass runs. - f = c( - torch.randn(batch_size, seq_len, feature_dim), - torch.full((batch_size,), seq_len, dtype=torch.int64), - ) - f # to remove flake8 warnings - - -if __name__ == "__main__": - feature_dim = 80 - m = RNN( - num_features=feature_dim, - d_model=512, - rnn_hidden_size=1024, - dim_feedforward=2048, - num_encoder_layers=12, - ) - batch_size = 5 - seq_len = 20 - # Just make sure the forward pass runs. - f = m( - torch.randn(batch_size, seq_len, feature_dim), - torch.full((batch_size,), seq_len, dtype=torch.int64), - warmup=0.5, - ) - num_param = sum([p.numel() for p in m.parameters()]) - print(f"Number of model parameters: {num_param}") - - _test_random_combine_main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/model.py b/egs/librispeech/ASR/lstm_transducer_stateless/model.py index d71132b4a..e69de29bb 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless/model.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/model.py @@ -1,210 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from typing import Tuple - -import k2 -import torch -import torch.nn as nn -from encoder_interface import EncoderInterface -from scaling import ScaledLinear - -from icefall.utils import add_sos - - -class Transducer(nn.Module): - """It implements https://arxiv.org/pdf/1211.3711.pdf - "Sequence Transduction with Recurrent Neural Networks" - """ - - def __init__( - self, - encoder: EncoderInterface, - decoder: nn.Module, - joiner: nn.Module, - encoder_dim: int, - decoder_dim: int, - joiner_dim: int, - vocab_size: int, - ): - """ - Args: - encoder: - It is the transcription network in the paper. Its accepts - two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). - It returns two tensors: `logits` of shape (N, T, encoder_dm) and - `logit_lens` of shape (N,). - decoder: - It is the prediction network in the paper. Its input shape - is (N, U) and its output shape is (N, U, decoder_dim). - It should contain one attribute: `blank_id`. - joiner: - It has two inputs with shapes: (N, T, encoder_dim) and - (N, U, decoder_dim). - Its output shape is (N, T, U, vocab_size). Note that its output - contains unnormalized probs, i.e., not processed by log-softmax. - """ - super().__init__() - assert isinstance(encoder, EncoderInterface), type(encoder) - assert hasattr(decoder, "blank_id") - - self.encoder = encoder - self.decoder = decoder - self.joiner = joiner - - self.simple_am_proj = ScaledLinear( - encoder_dim, vocab_size, initial_speed=0.5 - ) - self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) - - def forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - y: k2.RaggedTensor, - prune_range: int = 5, - am_scale: float = 0.0, - lm_scale: float = 0.0, - warmup: float = 1.0, - reduction: str = "sum", - delay_penalty: float = 0.0, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - x: - A 3-D tensor of shape (N, T, C). - x_lens: - A 1-D tensor of shape (N,). It contains the number of frames in `x` - before padding. - y: - A ragged tensor with 2 axes [utt][label]. It contains labels of each - utterance. - prune_range: - The prune range for rnnt loss, it means how many symbols(context) - we are considering for each frame to compute the loss. - am_scale: - The scale to smooth the loss with am (output of encoder network) - part - lm_scale: - The scale to smooth the loss with lm (output of predictor network) - part - warmup: - A value warmup >= 0 that determines which modules are active, values - warmup > 1 "are fully warmed up" and all modules will be active. - reduction: - "sum" to sum the losses over all utterances in the batch. - "none" to return the loss in a 1-D tensor for each utterance - in the batch. - delay_penalty: - A constant value used to penalize symbol delay, to encourage - streaming models to emit symbols earlier. - See https://github.com/k2-fsa/k2/issues/955 and - https://arxiv.org/pdf/2211.00490.pdf for more details. - Returns: - Return the transducer loss. - - Note: - Regarding am_scale & lm_scale, it will make the loss-function one of - the form: - lm_scale * lm_probs + am_scale * am_probs + - (1-lm_scale-am_scale) * combined_probs - """ - assert reduction in ("sum", "none"), reduction - assert x.ndim == 3, x.shape - assert x_lens.ndim == 1, x_lens.shape - assert y.num_axes == 2, y.num_axes - - assert x.size(0) == x_lens.size(0) == y.dim0 - - encoder_out, x_lens, _ = self.encoder(x, x_lens, warmup=warmup) - assert torch.all(x_lens > 0) - - # Now for the decoder, i.e., the prediction network - row_splits = y.shape.row_splits(1) - y_lens = row_splits[1:] - row_splits[:-1] - - blank_id = self.decoder.blank_id - sos_y = add_sos(y, sos_id=blank_id) - - # sos_y_padded: [B, S + 1], start with SOS. - sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) - - # decoder_out: [B, S + 1, decoder_dim] - decoder_out = self.decoder(sos_y_padded) - - # Note: y does not start with SOS - # y_padded : [B, S] - y_padded = y.pad(mode="constant", padding_value=0) - - y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) - boundary[:, 2] = y_lens - boundary[:, 3] = x_lens - - lm = self.simple_lm_proj(decoder_out) - am = self.simple_am_proj(encoder_out) - - with torch.cuda.amp.autocast(enabled=False): - simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( - lm=lm.float(), - am=am.float(), - symbols=y_padded, - termination_symbol=blank_id, - lm_only_scale=lm_scale, - am_only_scale=am_scale, - boundary=boundary, - reduction=reduction, - delay_penalty=delay_penalty, - return_grad=True, - ) - - # ranges : [B, T, prune_range] - ranges = k2.get_rnnt_prune_ranges( - px_grad=px_grad, - py_grad=py_grad, - boundary=boundary, - s_range=prune_range, - ) - - # am_pruned : [B, T, prune_range, encoder_dim] - # lm_pruned : [B, T, prune_range, decoder_dim] - am_pruned, lm_pruned = k2.do_rnnt_pruning( - am=self.joiner.encoder_proj(encoder_out), - lm=self.joiner.decoder_proj(decoder_out), - ranges=ranges, - ) - - # logits : [B, T, prune_range, vocab_size] - - # project_input=False since we applied the decoder's input projections - # prior to do_rnnt_pruning (this is an optimization for speed). - logits = self.joiner(am_pruned, lm_pruned, project_input=False) - - with torch.cuda.amp.autocast(enabled=False): - pruned_loss = k2.rnnt_loss_pruned( - logits=logits.float(), - symbols=y_padded, - ranges=ranges, - termination_symbol=blank_id, - boundary=boundary, - delay_penalty=delay_penalty, - reduction=reduction, - ) - - return (simple_loss, pruned_loss) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py old mode 100755 new mode 100644 index 2a6e2adc6..e69de29bb --- a/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py @@ -1,352 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: - -(1) greedy search -./lstm_transducer_stateless/pretrained.py \ - --checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method greedy_search \ - /path/to/foo.wav \ - /path/to/bar.wav - -(2) beam search -./lstm_transducer_stateless/pretrained.py \ - --checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(3) modified beam search -./lstm_transducer_stateless/pretrained.py \ - --checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method modified_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(4) fast beam search -./lstm_transducer_stateless/pretrained.py \ - --checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method fast_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -You can also use `./lstm_transducer_stateless/exp/epoch-xx.pt`. - -Note: ./lstm_transducer_stateless/exp/pretrained.pt is generated by -./lstm_transducer_stateless/export.py -""" - - -import argparse -import logging -import math -from typing import List - -import k2 -import kaldifeat -import sentencepiece as spm -import torch -import torchaudio -from beam_search import ( - beam_search, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_params, get_transducer_model - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", - ) - - parser.add_argument( - "--bpe-model", - type=str, - help="""Path to bpe.model.""", - ) - - parser.add_argument( - "--method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. Used only when - --method is greedy_search. - """, - ) - - add_model_arguments(parser) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) - # We use only the first channel - ans.append(wave[0]) - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - - params.update(vars(args)) - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - logging.info("Creating model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"], strict=False) - model.to(device) - model.eval() - model.device = device - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) - - feature_lengths = torch.tensor(feature_lengths, device=device) - - encoder_out, encoder_out_lens, _ = model.encoder( - x=features, x_lens=feature_lengths - ) - - num_waves = encoder_out.size(0) - hyps = [] - msg = f"Using {params.method}" - if params.method == "beam_search": - msg += f" with beam size {params.beam_size}" - logging.info(msg) - - if params.method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - else: - for i in range(num_waves): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError(f"Unsupported method: {params.method}") - - hyps.append(sp.decode(hyp).split()) - - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/stream.py b/egs/librispeech/ASR/lstm_transducer_stateless/stream.py index 97d890c82..e69de29bb 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless/stream.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/stream.py @@ -1,148 +0,0 @@ -# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -from typing import List, Optional, Tuple - -import k2 -import torch -from beam_search import Hypothesis, HypothesisList - -from icefall.utils import AttributeDict - - -class Stream(object): - def __init__( - self, - params: AttributeDict, - cut_id: str, - decoding_graph: Optional[k2.Fsa] = None, - device: torch.device = torch.device("cpu"), - LOG_EPS: float = math.log(1e-10), - ) -> None: - """ - Args: - params: - It's the return value of :func:`get_params`. - cut_id: - The cut id of the current stream. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - device: - The device to run this stream. - LOG_EPS: - A float value used for padding. - """ - self.LOG_EPS = LOG_EPS - self.cut_id = cut_id - - # Containing attention caches and convolution caches - self.states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None - - # It uses different attributes for different decoding methods. - self.context_size = params.context_size - self.decoding_method = params.decoding_method - if params.decoding_method == "greedy_search": - self.hyp = [params.blank_id] * params.context_size - elif params.decoding_method == "modified_beam_search": - self.hyps = HypothesisList() - self.hyps.add( - Hypothesis( - ys=[params.blank_id] * params.context_size, - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - ) - ) - elif params.decoding_method == "fast_beam_search": - # feature_len is needed to get partial results. - # The rnnt_decoding_stream for fast_beam_search. - self.rnnt_decoding_stream: k2.RnntDecodingStream = ( - k2.RnntDecodingStream(decoding_graph) - ) - self.hyp: Optional[List[int]] = None - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - - self.ground_truth: str = "" - - self.feature: Optional[torch.Tensor] = None - # Make sure all feature frames can be used. - # We aim to obtain 1 frame after subsampling. - self.chunk_length = params.subsampling_factor - self.pad_length = 5 - self.num_frames = 0 - self.num_processed_frames = 0 - - # After all feature frames are processed, we set this flag to True - self._done = False - - def set_feature(self, feature: torch.Tensor) -> None: - assert feature.dim() == 2, feature.dim() - # tail padding here to alleviate the tail deletion problem - num_tail_padded_frames = 35 - self.num_frames = feature.size(0) + num_tail_padded_frames - self.feature = torch.nn.functional.pad( - feature, - (0, 0, 0, self.pad_length + num_tail_padded_frames), - mode="constant", - value=self.LOG_EPS, - ) - - def get_feature_chunk(self) -> torch.Tensor: - """Get a chunk of feature frames. - - Returns: - A tensor of shape (ret_length, feature_dim). - """ - update_length = min( - self.num_frames - self.num_processed_frames, self.chunk_length - ) - ret_length = update_length + self.pad_length - - ret_feature = self.feature[ - self.num_processed_frames : self.num_processed_frames + ret_length - ] - # Cut off used frames. - # self.feature = self.feature[update_length:] - - self.num_processed_frames += update_length - if self.num_processed_frames >= self.num_frames: - self._done = True - - return ret_feature - - @property - def id(self) -> str: - return self.cut_id - - @property - def done(self) -> bool: - """Return True if all feature frames are processed.""" - return self._done - - def decoding_result(self) -> List[int]: - """Obtain current decoding result.""" - if self.decoding_method == "greedy_search": - return self.hyp[self.context_size :] - elif self.decoding_method == "modified_beam_search": - best_hyp = self.hyps.get_most_probable(length_norm=True) - return best_hyp.ys[self.context_size :] - else: - assert self.decoding_method == "fast_beam_search" - return self.hyp diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py old mode 100755 new mode 100644 index d6376bdc0..e69de29bb --- a/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py @@ -1,968 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: -(1) greedy search -./lstm_transducer_stateless/streaming_decode.py \ - --epoch 35 \ - --avg 10 \ - --exp-dir lstm_transducer_stateless/exp \ - --num-decode-streams 2000 \ - --num-encoder-layers 12 \ - --rnn-hidden-size 1024 \ - --decoding-method greedy_search \ - --use-averaged-model True - -(2) modified beam search -./lstm_transducer_stateless/streaming_decode.py \ - --epoch 35 \ - --avg 10 \ - --exp-dir lstm_transducer_stateless/exp \ - --num-decode-streams 2000 \ - --num-encoder-layers 12 \ - --rnn-hidden-size 1024 \ - --decoding-method modified_beam_search \ - --use-averaged-model True \ - --beam-size 4 - -(3) fast beam search -./lstm_transducer_stateless/streaming_decode.py \ - --epoch 35 \ - --avg 10 \ - --exp-dir lstm_transducer_stateless/exp \ - --num-decode-streams 2000 \ - --num-encoder-layers 12 \ - --rnn-hidden-size 1024 \ - --decoding-method fast_beam_search \ - --use-averaged-model True \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 -""" -import argparse -import logging -import warnings -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import numpy as np -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import LibriSpeechAsrDataModule -from beam_search import Hypothesis, HypothesisList, get_hyps_shape -from kaldifeat import Fbank, FbankOptions -from lhotse import CutSet -from lstm import LOG_EPSILON, stack_states, unstack_states -from stream import Stream -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.decode import one_best_decoding -from icefall.utils import ( - AttributeDict, - get_texts, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=False, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="transducer_emformer/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An interger indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=20.0, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=64, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - parser.add_argument( - "--sampling-rate", - type=float, - default=16000, - help="Sample rate of the audio", - ) - - parser.add_argument( - "--num-decode-streams", - type=int, - default=2000, - help="The number of streams that can be decoded in parallel", - ) - - add_model_arguments(parser) - - return parser - - -def greedy_search( - model: nn.Module, - encoder_out: torch.Tensor, - streams: List[Stream], -) -> None: - """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. - - Args: - model: - The transducer model. - encoder_out: - Output from the encoder. Its shape is (N, T, C), where N >= 1. - streams: - A list of Stream objects. - """ - assert len(streams) == encoder_out.size(0) - assert encoder_out.ndim == 3 - - blank_id = model.decoder.blank_id - context_size = model.decoder.context_size - device = next(model.parameters()).device - T = encoder_out.size(1) - - encoder_out = model.joiner.encoder_proj(encoder_out) - - decoder_input = torch.tensor( - [stream.hyp[-context_size:] for stream in streams], - device=device, - dtype=torch.int64, - ) - # decoder_out is of shape (batch_size, 1, decoder_out_dim) - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - - for t in range(T): - # current_encoder_out's shape: (batch_size, 1, encoder_out_dim) - current_encoder_out = encoder_out[:, t : t + 1, :] # noqa - - logits = model.joiner( - current_encoder_out.unsqueeze(2), - decoder_out.unsqueeze(1), - project_input=False, - ) - # logits'shape (batch_size, vocab_size) - logits = logits.squeeze(1).squeeze(1) - - assert logits.ndim == 2, logits.shape - y = logits.argmax(dim=1).tolist() - emitted = False - for i, v in enumerate(y): - if v != blank_id: - streams[i].hyp.append(v) - emitted = True - if emitted: - # update decoder output - decoder_input = torch.tensor( - [stream.hyp[-context_size:] for stream in streams], - device=device, - dtype=torch.int64, - ) - decoder_out = model.decoder( - decoder_input, - need_pad=False, - ) - decoder_out = model.joiner.decoder_proj(decoder_out) - - -def modified_beam_search( - model: nn.Module, - encoder_out: torch.Tensor, - streams: List[Stream], - beam: int = 4, -): - """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. - - Args: - model: - The RNN-T model. - encoder_out: - A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of - the encoder model. - streams: - A list of stream objects. - beam: - Number of active paths during the beam search. - """ - assert encoder_out.ndim == 3, encoder_out.shape - assert len(streams) == encoder_out.size(0) - - blank_id = model.decoder.blank_id - context_size = model.decoder.context_size - device = next(model.parameters()).device - batch_size = len(streams) - T = encoder_out.size(1) - - B = [stream.hyps for stream in streams] - - encoder_out = model.joiner.encoder_proj(encoder_out) - - for t in range(T): - current_encoder_out = encoder_out[:, t].unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim) - - hyps_shape = get_hyps_shape(B).to(device) - - A = [list(b) for b in B] - B = [HypothesisList() for _ in range(batch_size)] - - ys_log_probs = torch.stack( - [hyp.log_prob.reshape(1) for hyps in A for hyp in hyps], dim=0 - ) # (num_hyps, 1) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyps in A for hyp in hyps], - device=device, - dtype=torch.int64, - ) # (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_out is of shape (num_hyps, 1, 1, decoder_output_dim) - - # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor - # as index, so we use `to(torch.int64)` below. - current_encoder_out = torch.index_select( - current_encoder_out, - dim=0, - index=hyps_shape.row_ids(1).to(torch.int64), - ) # (num_hyps, encoder_out_dim) - - logits = model.joiner( - current_encoder_out, decoder_out, project_input=False - ) - # logits is of shape (num_hyps, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) - - log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) - - log_probs.add_(ys_log_probs) - - vocab_size = log_probs.size(-1) - - log_probs = log_probs.reshape(-1) - - row_splits = hyps_shape.row_splits(1) * vocab_size - log_probs_shape = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=log_probs.numel() - ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) - - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - new_ys = hyp.ys[:] - new_token = topk_token_indexes[k] - if new_token != blank_id: - new_ys.append(new_token) - - new_log_prob = topk_log_probs[k] - new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) - B[i].add(new_hyp) - - for i in range(batch_size): - streams[i].hyps = B[i] - - -def fast_beam_search_one_best( - model: nn.Module, - streams: List[Stream], - encoder_out: torch.Tensor, - processed_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, -) -> None: - """It limits the maximum number of symbols per frame to 1. - - A lattice is first obtained using modified beam search, and then - the shortest path within the lattice is used as the final output. - - Args: - model: - An instance of `Transducer`. - streams: - A list of stream objects. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - processed_lens: - A tensor of shape (N,) containing the number of processed frames - in `encoder_out` before padding. - beam: - Beam value, similar to the beam used in Kaldi.. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - """ - assert encoder_out.ndim == 3 - - context_size = model.decoder.context_size - vocab_size = model.decoder.vocab_size - - B, T, C = encoder_out.shape - assert B == len(streams) - - config = k2.RnntDecodingConfig( - vocab_size=vocab_size, - decoder_history_len=context_size, - beam=beam, - max_contexts=max_contexts, - max_states=max_states, - ) - individual_streams = [] - for i in range(B): - individual_streams.append(streams[i].rnnt_decoding_stream) - decoding_streams = k2.RnntDecodingStreams(individual_streams, config) - - encoder_out = model.joiner.encoder_proj(encoder_out) - - for t in range(T): - # shape is a RaggedShape of shape (B, context) - # contexts is a Tensor of shape (shape.NumElements(), context_size) - shape, contexts = decoding_streams.get_contexts() - # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64 - contexts = contexts.to(torch.int64) - # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim) - decoder_out = model.decoder(contexts, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - # current_encoder_out is of shape - # (shape.NumElements(), 1, joiner_dim) - # fmt: off - current_encoder_out = torch.index_select( - encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64) - ) - # fmt: on - logits = model.joiner( - current_encoder_out.unsqueeze(2), - decoder_out.unsqueeze(1), - project_input=False, - ) - logits = logits.squeeze(1).squeeze(1) - log_probs = logits.log_softmax(dim=-1) - decoding_streams.advance(log_probs) - - decoding_streams.terminate_and_flush_to_streams() - - lattice = decoding_streams.format_output(processed_lens.tolist()) - - best_path = one_best_decoding(lattice) - hyps = get_texts(best_path) - - for i in range(B): - streams[i].hyp = hyps[i] - - -def decode_one_chunk( - model: nn.Module, - streams: List[Stream], - params: AttributeDict, - decoding_graph: Optional[k2.Fsa] = None, -) -> List[int]: - """ - Args: - model: - The Transducer model. - streams: - A list of Stream objects. - params: - It is returned by :func:`get_params`. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or LG, Used - only when --decoding_method is fast_beam_search. - - Returns: - A list of indexes indicating the finished streams. - """ - device = next(model.parameters()).device - - feature_list = [] - feature_len_list = [] - state_list = [] - num_processed_frames_list = [] - - for stream in streams: - # We should first get `stream.num_processed_frames` - # before calling `stream.get_feature_chunk()` - # since `stream.num_processed_frames` would be updated - num_processed_frames_list.append(stream.num_processed_frames) - feature = stream.get_feature_chunk() - feature_len = feature.size(0) - feature_list.append(feature) - feature_len_list.append(feature_len) - state_list.append(stream.states) - - features = pad_sequence( - feature_list, batch_first=True, padding_value=LOG_EPSILON - ).to(device) - feature_lens = torch.tensor(feature_len_list, device=device) - num_processed_frames = torch.tensor( - num_processed_frames_list, device=device - ) - - # Make sure it has at least 1 frame after subsampling - tail_length = params.subsampling_factor + 5 - if features.size(1) < tail_length: - pad_length = tail_length - features.size(1) - feature_lens += pad_length - features = torch.nn.functional.pad( - features, - (0, 0, 0, pad_length), - mode="constant", - value=LOG_EPSILON, - ) - - # Stack states of all streams - states = stack_states(state_list) - - encoder_out, encoder_out_lens, states = model.encoder( - x=features, - x_lens=feature_lens, - states=states, - ) - - if params.decoding_method == "greedy_search": - greedy_search( - model=model, - streams=streams, - encoder_out=encoder_out, - ) - elif params.decoding_method == "modified_beam_search": - modified_beam_search( - model=model, - streams=streams, - encoder_out=encoder_out, - beam=params.beam_size, - ) - elif params.decoding_method == "fast_beam_search": - # feature_len is needed to get partial results. - # The rnnt_decoding_stream for fast_beam_search. - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - processed_lens = ( - num_processed_frames // params.subsampling_factor - + encoder_out_lens - ) - fast_beam_search_one_best( - model=model, - streams=streams, - encoder_out=encoder_out, - processed_lens=processed_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - - # Update cached states of each stream - state_list = unstack_states(states) - for i, s in enumerate(state_list): - streams[i].states = s - - finished_streams = [i for i, stream in enumerate(streams) if stream.done] - return finished_streams - - -def create_streaming_feature_extractor() -> Fbank: - """Create a CPU streaming feature extractor. - - At present, we assume it returns a fbank feature extractor with - fixed options. In the future, we will support passing in the options - from outside. - - Returns: - Return a CPU streaming feature extractor. - """ - opts = FbankOptions() - opts.device = "cpu" - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = 16000 - opts.mel_opts.num_bins = 80 - return Fbank(opts) - - -def decode_dataset( - cuts: CutSet, - model: nn.Module, - params: AttributeDict, - sp: spm.SentencePieceProcessor, - decoding_graph: Optional[k2.Fsa] = None, -): - """Decode dataset. - - Args: - cuts: - Lhotse Cutset containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The Transducer model. - sp: - The BPE model. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or LG, Used - only when --decoding_method is fast_beam_search. - - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - device = next(model.parameters()).device - - log_interval = 300 - - fbank = create_streaming_feature_extractor() - - decode_results = [] - streams = [] - for num, cut in enumerate(cuts): - # Each utterance has a Stream. - stream = Stream( - params=params, - cut_id=cut.id, - decoding_graph=decoding_graph, - device=device, - LOG_EPS=LOG_EPSILON, - ) - - stream.states = model.encoder.get_init_states(device=device) - - audio: np.ndarray = cut.load_audio() - # audio.shape: (1, num_samples) - assert len(audio.shape) == 2 - assert audio.shape[0] == 1, "Should be single channel" - assert audio.dtype == np.float32, audio.dtype - # The trained model is using normalized samples - assert audio.max() <= 1, "Should be normalized to [-1, 1])" - - samples = torch.from_numpy(audio).squeeze(0) - feature = fbank(samples) - stream.set_feature(feature) - stream.ground_truth = cut.supervisions[0].text - - streams.append(stream) - - while len(streams) >= params.num_decode_streams: - finished_streams = decode_one_chunk( - model=model, - streams=streams, - params=params, - decoding_graph=decoding_graph, - ) - - for i in sorted(finished_streams, reverse=True): - decode_results.append( - ( - streams[i].id, - streams[i].ground_truth.split(), - sp.decode(streams[i].decoding_result()).split(), - ) - ) - del streams[i] - - if num % log_interval == 0: - logging.info(f"Cuts processed until now is {num}.") - - while len(streams) > 0: - finished_streams = decode_one_chunk( - model=model, - streams=streams, - params=params, - decoding_graph=decoding_graph, - ) - - for i in sorted(finished_streams, reverse=True): - decode_results.append( - ( - streams[i].id, - streams[i].ground_truth.split(), - sp.decode(streams[i].decoding_result()).split(), - ) - ) - del streams[i] - - if params.decoding_method == "greedy_search": - key = "greedy_search" - elif params.decoding_method == "fast_beam_search": - key = ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ) - else: - key = f"beam_size_{params.beam_size}" - - return {key: decode_results} - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) - store_transcripts(filename=recog_path, texts=sorted(results)) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - LibriSpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "fast_beam_search", - "modified_beam_search", - ) - params.res_dir = params.exp_dir / "streaming" / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-streaming-decode") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # and are defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - params.device = device - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.eval() - - if params.decoding_method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - librispeech = LibriSpeechAsrDataModule(args) - - test_clean_cuts = librispeech.test_clean_cuts() - test_other_cuts = librispeech.test_other_cuts() - - test_sets = ["test-clean", "test-other"] - test_cuts = [test_clean_cuts, test_other_cuts] - - for test_set, test_cut in zip(test_sets, test_cuts): - results_dict = decode_dataset( - cuts=test_cut, - model=model, - params=params, - sp=sp, - decoding_graph=decoding_graph, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - torch.manual_seed(20220810) - main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/train.py b/egs/librispeech/ASR/lstm_transducer_stateless/train.py old mode 100755 new mode 100644 index d30fc260a..e69de29bb --- a/egs/librispeech/ASR/lstm_transducer_stateless/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/train.py @@ -1,1157 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo,) -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -./lstm_transducer_stateless/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --exp-dir lstm_transducer_stateless/exp \ - --full-libri 1 \ - --max-duration 300 - -# For mix precision training: - -./lstm_transducer_stateless/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir lstm_transducer_stateless/exp \ - --full-libri 1 \ - --max-duration 550 -""" - -import argparse -import copy -import logging -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import sentencepiece as spm -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import LibriSpeechAsrDataModule -from decoder import Decoder -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from lstm import RNN -from model import Transducer -from optim import Eden, Eve -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter - -from icefall import diagnostics -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.utils import ( - AttributeDict, - MetricsTracker, - display_and_save_batch, - setup_logger, - str2bool, -) - -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=int, - default=12, - help="Number of RNN encoder layers..", - ) - - parser.add_argument( - "--encoder-dim", - type=int, - default=512, - help="Encoder output dimesion.", - ) - - parser.add_argument( - "--rnn-hidden-size", - type=int, - default=1024, - help="Hidden dim for LSTM layers.", - ) - - parser.add_argument( - "--aux-layer-period", - type=int, - default=0, - help="""Peroid of auxiliary layers used for randomly combined during training. - If set to 0, will not use the random combiner (Default). - You can set a positive integer to use the random combiner, e.g., 3. - """, - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=35, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="lstm_transducer_stateless/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--initial-lr", - type=float, - default=0.003, - help="""The initial learning rate. This value should not need to be - changed.""", - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=5000, - help="""Number of steps that affects how rapidly the learning rate decreases. - We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=10, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=4000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=20, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=100, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - parser.add_argument( - "--delay-penalty", - type=float, - default=0.0, - help="""A constant value used to penalize symbol delay, - to encourage streaming models to emit symbols earlier. - See https://github.com/k2-fsa/k2/issues/955 and - https://arxiv.org/pdf/2211.00490.pdf for more details.""", - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warm_step for Noam optimizer. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 3000, # For the 100h subset, use 800 - # parameters for conformer - "feature_dim": 80, - "subsampling_factor": 4, - "dim_feedforward": 2048, - # parameters for decoder - "decoder_dim": 512, - # parameters for joiner - "joiner_dim": 512, - # parameters for Noam - "model_warm_step": 3000, # arg given to model, not for lrate - "env_info": get_env_info(), - } - ) - - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - encoder = RNN( - num_features=params.feature_dim, - subsampling_factor=params.subsampling_factor, - d_model=params.encoder_dim, - rnn_hidden_size=params.rnn_hidden_size, - dim_feedforward=params.dim_feedforward, - num_encoder_layers=params.num_encoder_layers, - aux_layer_period=params.aux_layer_period, - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=params.encoder_dim, - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=params.encoder_dim, - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - batch: dict, - is_training: bool, - warmup: float = 1.0, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute RNN-T loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Conformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. - """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - texts = batch["supervisions"]["text"] - y = sp.encode(texts, out_type=int) - y = k2.RaggedTensor(y).to(device) - - with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - warmup=warmup, - reduction="none", - delay_penalty=params.delay_penalty if warmup >= 2.0 else 0, - ) - simple_loss_is_finite = torch.isfinite(simple_loss) - pruned_loss_is_finite = torch.isfinite(pruned_loss) - is_finite = simple_loss_is_finite & pruned_loss_is_finite - if not torch.all(is_finite): - logging.info( - "Not all losses are finite!\n" - f"simple_loss: {simple_loss}\n" - f"pruned_loss: {pruned_loss}" - ) - display_and_save_batch(batch, params=params, sp=sp) - simple_loss = simple_loss[simple_loss_is_finite] - pruned_loss = pruned_loss[pruned_loss_is_finite] - - # If either all simple_loss or pruned_loss is inf or nan, - # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): - raise ValueError( - "There are too many utterances in this batch " - "leading to inf or nan losses." - ) - - simple_loss = simple_loss.sum() - pruned_loss = pruned_loss.sum() - # after the main warmup step, we keep pruned_loss_scale small - # for the same amount of time (model_warm_step), to avoid - # overwhelming the simple_loss and causing it to diverge, - # in case it had not fully learned the alignment yet. - pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss - ) - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - # info["frames"] is an approximate number for two reasons: - # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 - # (2) If some utterances in the batch lead to inf/nan loss, they - # are filtered out. - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) - - # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa - info["utterances"] = feature.size(0) - # averaged input duration in frames over utterances - info["utt_duration"] = feature_lens.sum().item() - # averaged padding proportion over utterances - info["utt_pad_proportion"] = ( - ((feature.size(1) - feature_lens) / feature.size(1)).sum().item() - ) - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - sp: spm.SentencePieceProcessor, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - warmup=(params.batch_idx_train / params.model_warm_step), - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - scaler.scale(loss).backward() - scheduler.step_batch(params.batch_idx_train) - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - display_and_save_batch(batch, params=params, sp=sp) - raise - - if params.print_diagnostics and batch_idx == 30: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % params.log_interval == 0: - cur_lr = scheduler.get_last_lr()[0] - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}" - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) - - if batch_idx > 0 and batch_idx % params.valid_interval == 0: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - sp=sp, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. - args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - if params.full_libri is False: - params.valid_interval = 800 - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank]) - - optimizer = Eve(model.parameters(), lr=params.initial_lr) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - # # overwrite it - # scheduler.base_lrs = [params.initial_lr for _ in scheduler.base_lrs] - # print(scheduler.base_lrs) - - if params.print_diagnostics: - diagnostic = diagnostics.attach_diagnostics(model) - - librispeech = LibriSpeechAsrDataModule(args) - - train_cuts = librispeech.train_clean_100_cuts() - if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 20 seconds - # - # Caution: There is a reason to select 20.0 here. Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - if c.duration < 1.0 or c.duration > 20.0: - logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" - ) - return False - - # In pruned RNN-T, we require that T >= S - # where T is the number of feature frames after subsampling - # and S is the number of tokens in the utterance - - # In ./lstm.py, the conv module uses the following expression - # for subsampling - T = ((c.num_frames - 3) // 2 - 1) // 2 - tokens = sp.encode(c.supervisions[0].text, out_type=str) - - if T < len(tokens): - logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Number of frames (before subsampling): {c.num_frames}. " - f"Number of frames (after subsampling): {T}. " - f"Text: {c.supervisions[0].text}. " - f"Tokens: {tokens}. " - f"Number of tokens: {len(tokens)}" - ) - return False - - return True - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = librispeech.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - valid_cuts = librispeech.dev_clean_cuts() - valid_cuts += librispeech.dev_other_cuts() - valid_dl = librispeech.valid_dataloaders(valid_cuts) - - if not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - sp=sp, - params=params, - warmup=0.0 if params.start_epoch == 1 else 1.0, - ) - - scaler = GradScaler(enabled=params.use_fp16) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sp=sp, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - sp: spm.SentencePieceProcessor, - params: AttributeDict, - warmup: float, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." - ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - warmup=warmup, - ) - loss.backward() - optimizer.step() - optimizer.zero_grad() - except RuntimeError as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - raise - - -def main(): - parser = get_parser() - LibriSpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py index bad4e243e..f7e1b5a54 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py @@ -185,20 +185,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -295,8 +299,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -474,9 +477,7 @@ def decode_one_batch( ) feature_lens += num_tail_padded_frames - encoder_out, encoder_out_lens, _ = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens, _ = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -535,10 +536,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -700,9 +698,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -735,8 +731,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -789,9 +784,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -826,13 +819,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -860,13 +852,12 @@ def main(): ) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -895,7 +886,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -961,9 +952,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export.py index 190673638..0ad00cda3 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export.py @@ -146,20 +146,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -225,8 +229,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) add_model_arguments(parser) @@ -342,9 +345,7 @@ def export_encoder_model_onnx( x = torch.zeros(N, 9, 80, dtype=torch.float32) x_lens = torch.tensor([9], dtype=torch.int64) h = torch.rand(encoder_model.num_encoder_layers, N, encoder_model.d_model) - c = torch.rand( - encoder_model.num_encoder_layers, N, encoder_model.rnn_hidden_size - ) + c = torch.rand(encoder_model.num_encoder_layers, N, encoder_model.rnn_hidden_size) warmup = 1.0 torch.onnx.export( @@ -445,13 +446,9 @@ def export_joiner_model_onnx( - projected_decoder_out: a tensor of shape (N, joiner_dim) """ - encoder_proj_filename = str(joiner_filename).replace( - ".onnx", "_encoder_proj.onnx" - ) + encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx") - decoder_proj_filename = str(joiner_filename).replace( - ".onnx", "_decoder_proj.onnx" - ) + decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx") encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] @@ -550,13 +547,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -585,13 +581,12 @@ def main(): ) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -620,7 +615,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -694,9 +689,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py index da184b76f..5a8efd718 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py @@ -86,10 +86,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -124,10 +126,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -315,9 +316,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/model.py b/egs/librispeech/ASR/lstm_transducer_stateless2/model.py index fadeb4ac2..4957d14b1 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/model.py @@ -84,9 +84,7 @@ class Transducer(nn.Module): self.decoder_giga = decoder_giga self.joiner_giga = joiner_giga - self.simple_am_proj = ScaledLinear( - encoder_dim, vocab_size, initial_speed=0.5 - ) + self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) if decoder_giga is not None: @@ -190,9 +188,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py index 410de8d3d..3b471fa85 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py @@ -156,9 +156,7 @@ class Model: assert ret == 0, ret encoder_out = torch.from_numpy(ncnn_out0.numpy()).clone() - encoder_out_lens = torch.from_numpy(ncnn_out1.numpy()).to( - torch.int32 - ) + encoder_out_lens = torch.from_numpy(ncnn_out1.numpy()).to(torch.int32) hx = torch.from_numpy(ncnn_out2.numpy()).clone() cx = torch.from_numpy(ncnn_out3.numpy()).clone() return encoder_out, encoder_out_lens, hx, cx @@ -200,10 +198,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -286,9 +283,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py index bef0ad760..7d931a286 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py @@ -92,9 +92,11 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( @@ -119,10 +121,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -169,8 +173,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -201,10 +204,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -267,15 +269,11 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens, _ = model.encoder( - x=features, x_lens=feature_lengths - ) + encoder_out, encoder_out_lens, _ = model.encoder(x=features, x_lens=feature_lengths) num_waves = encoder_out.size(0) hyps = [] @@ -347,9 +345,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py index e47a05a9e..baff15ea6 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py @@ -144,9 +144,7 @@ class Model: assert ret == 0, ret encoder_out = torch.from_numpy(ncnn_out0.numpy()).clone() - encoder_out_lens = torch.from_numpy(ncnn_out1.numpy()).to( - torch.int32 - ) + encoder_out_lens = torch.from_numpy(ncnn_out1.numpy()).to(torch.int32) hx = torch.from_numpy(ncnn_out2.numpy()).clone() cx = torch.from_numpy(ncnn_out3.numpy()).clone() return encoder_out, encoder_out_lens, hx, cx @@ -188,10 +186,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -229,9 +226,7 @@ def greedy_search( if decoder_out is None: assert hyp is None, hyp hyp = [blank_id] * context_size - decoder_input = torch.tensor( - hyp, dtype=torch.int32 - ) # (1, context_size) + decoder_input = torch.tensor(hyp, dtype=torch.int32) # (1, context_size) decoder_out = model.run_decoder(decoder_input).squeeze(0) else: assert decoder_out.ndim == 1 @@ -310,9 +305,7 @@ def main(): frames.append(online_fbank.get_frame(num_processed_frames + i)) num_processed_frames += offset frames = torch.cat(frames, dim=0) - encoder_out, encoder_out_lens, hx, cx = model.run_encoder( - frames, states - ) + encoder_out, encoder_out_lens, hx, cx = model.run_encoder(frames, states) states = (hx, cx) hyp, decoder_out = greedy_search( model, encoder_out.squeeze(0), decoder_out, hyp @@ -328,9 +321,7 @@ def main(): frames.append(online_fbank.get_frame(num_processed_frames + i)) num_processed_frames += offset frames = torch.cat(frames, dim=0) - encoder_out, encoder_out_lens, hx, cx = model.run_encoder( - frames, states - ) + encoder_out, encoder_out_lens, hx, cx = model.run_encoder(frames, states) states = (hx, cx) hyp, decoder_out = greedy_search( model, encoder_out.squeeze(0), decoder_out, hyp @@ -343,9 +334,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py index 232d3dd18..b31fefa0a 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py @@ -109,10 +109,12 @@ def get_args(): parser.add_argument( "sound_filename", type=str, - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -147,10 +149,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -199,9 +200,7 @@ class Model: sess_options=self.session_opts, ) - def run_encoder( - self, x, h0, c0 - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def run_encoder(self, x, h0, c0) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Args: x: @@ -258,9 +257,7 @@ class Model: }, )[0] - return self.run_joiner_decoder_proj( - torch.from_numpy(decoder_out).squeeze(1) - ) + return self.run_joiner_decoder_proj(torch.from_numpy(decoder_out).squeeze(1)) def run_joiner( self, @@ -303,11 +300,7 @@ class Model: projected_encoder_out = self.joiner_encoder_proj.run( [self.joiner_encoder_proj.get_outputs()[0].name], - { - self.joiner_encoder_proj.get_inputs()[ - 0 - ].name: encoder_out.numpy() - }, + {self.joiner_encoder_proj.get_inputs()[0].name: encoder_out.numpy()}, )[0] return torch.from_numpy(projected_encoder_out) @@ -326,11 +319,7 @@ class Model: projected_decoder_out = self.joiner_decoder_proj.run( [self.joiner_decoder_proj.get_outputs()[0].name], - { - self.joiner_decoder_proj.get_inputs()[ - 0 - ].name: decoder_out.numpy() - }, + {self.joiner_decoder_proj.get_inputs()[0].name: decoder_out.numpy()}, )[0] return torch.from_numpy(projected_decoder_out) @@ -369,9 +358,7 @@ def greedy_search( if decoder_out is None: assert hyp is None, hyp hyp = [blank_id] * context_size - decoder_input = torch.tensor( - [hyp], dtype=torch.int64 - ) # (1, context_size) + decoder_input = torch.tensor([hyp], dtype=torch.int64) # (1, context_size) decoder_out = model.run_decoder(decoder_input) else: assert decoder_out.shape[0] == 1 @@ -474,9 +461,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py index 5eaaf321f..08a895a75 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py @@ -95,9 +95,7 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -163,8 +161,7 @@ def get_parser(): "--full-libri", type=str2bool, default=True, - help="When enabled, use 960h LibriSpeech. " - "Otherwise, use 100h subset.", + help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.", ) parser.add_argument( @@ -238,42 +235,45 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." + ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -645,11 +645,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -692,9 +688,7 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -707,14 +701,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -725,9 +714,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -958,9 +945,7 @@ def train_one_epoch( f"train/current_{prefix}_", params.batch_idx_train, ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) libri_tot_loss.write_summary( tb_writer, "train/libri_tot_", params.batch_idx_train ) @@ -1006,8 +991,7 @@ def filter_short_and_long_utterances( # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" ) return False @@ -1155,9 +1139,7 @@ def run(rank, world_size, args): train_giga_cuts = train_giga_cuts.repeat(times=None) if args.enable_musan: - cuts_musan = load_manifest( - Path(args.manifest_dir) / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz") else: cuts_musan = None diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py index 9eee19379..a8d5605fb 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py @@ -182,20 +182,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -290,8 +294,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -386,9 +389,7 @@ def decode_one_batch( ) feature_lens += num_tail_padded_frames - encoder_out, encoder_out_lens, _ = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens, _ = model.encoder(x=feature, x_lens=feature_lens) if params.decoding_method == "fast_beam_search": res = fast_beam_search_one_best( @@ -441,10 +442,7 @@ def decode_one_batch( nbest_scale=params.nbest_scale, return_timestamps=True, ) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: res = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -522,9 +520,7 @@ def decode_dataset( sp: spm.SentencePieceProcessor, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[ - str, List[Tuple[str, List[str], List[str], List[float], List[float]]] -]: +) -> Dict[str, List[Tuple[str, List[str], List[str], List[float], List[float]]]]: """Decode dataset. Args: @@ -599,9 +595,7 @@ def decode_dataset( cut_ids, hyps, texts, timestamps_hyp, timestamps_ref ): ref_words = ref_text.split() - this_batch.append( - (cut_id, ref_words, hyp_words, time_ref, time_hyp) - ) + this_batch.append((cut_id, ref_words, hyp_words, time_ref, time_hyp)) results[name].extend(this_batch) @@ -610,9 +604,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -650,8 +642,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -678,9 +669,7 @@ def save_results( note = "" logging.info(s) - s = "\nFor {}, symbol-delay of different settings are:\n".format( - test_set_name - ) + s = "\nFor {}, symbol-delay of different settings are:\n".format(test_set_name) note = "\tbest for {}".format(test_set_name) for key, val in test_set_delays: s += "{}\tmean: {}s, variance: {}{}\n".format(key, val[0], val[1], note) @@ -724,9 +713,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -758,13 +745,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -787,13 +773,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -821,7 +806,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -848,9 +833,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/export.py b/egs/librispeech/ASR/lstm_transducer_stateless3/export.py index 212c7bad6..51238f768 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/export.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/export.py @@ -122,20 +122,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -172,8 +176,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) add_model_arguments(parser) @@ -281,13 +284,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -310,13 +312,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -344,7 +345,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -380,9 +381,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py index a3443cf0a..180ba8c72 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py @@ -85,10 +85,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -123,10 +125,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -314,9 +315,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py b/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py index 90bc351f4..6e51b85e4 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py @@ -661,9 +661,7 @@ class RandomCombine(nn.Module): self.stddev = stddev self.final_log_weight = ( - torch.tensor( - (final_weight / (1 - final_weight)) * (self.num_inputs - 1) - ) + torch.tensor((final_weight / (1 - final_weight)) * (self.num_inputs - 1)) .log() .item() ) @@ -760,16 +758,14 @@ class RandomCombine(nn.Module): # final contains self.num_inputs - 1 in all elements final = torch.full((num_frames,), self.num_inputs - 1, device=device) # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights. # noqa - nonfinal = torch.randint( - self.num_inputs - 1, (num_frames,), device=device - ) + nonfinal = torch.randint(self.num_inputs - 1, (num_frames,), device=device) indexes = torch.where( torch.rand(num_frames, device=device) < final_prob, final, nonfinal ) - ans = torch.nn.functional.one_hot( - indexes, num_classes=self.num_inputs - ).to(dtype=dtype) + ans = torch.nn.functional.one_hot(indexes, num_classes=self.num_inputs).to( + dtype=dtype + ) return ans def _get_random_mixed_weights( diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py index 0e48fef04..4f8049245 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py @@ -89,9 +89,11 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( @@ -116,10 +118,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -166,8 +170,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -198,10 +201,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -264,15 +266,11 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens, _ = model.encoder( - x=features, x_lens=feature_lengths - ) + encoder_out, encoder_out_lens, _ = model.encoder(x=features, x_lens=feature_lengths) num_waves = encoder_out.size(0) hyps = [] @@ -344,9 +342,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py b/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py index cfa918ed5..4e9063a40 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py @@ -101,8 +101,9 @@ def get_parser(): "--epoch", type=int, default=40, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( @@ -119,20 +120,24 @@ def get_parser(): "--avg", type=int, default=20, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=False, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -199,8 +204,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -359,9 +363,7 @@ def modified_beam_search( index=hyps_shape.row_ids(1).to(torch.int64), ) # (num_hyps, encoder_out_dim) - logits = model.joiner( - current_encoder_out, decoder_out, project_input=False - ) + logits = model.joiner(current_encoder_out, decoder_out, project_input=False) # logits is of shape (num_hyps, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) @@ -378,9 +380,7 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -539,9 +539,7 @@ def decode_one_chunk( feature_list, batch_first=True, padding_value=LOG_EPSILON ).to(device) feature_lens = torch.tensor(feature_len_list, device=device) - num_processed_frames = torch.tensor( - num_processed_frames_list, device=device - ) + num_processed_frames = torch.tensor(num_processed_frames_list, device=device) # Make sure it has at least 1 frame after subsampling tail_length = params.subsampling_factor + 5 @@ -583,8 +581,7 @@ def decode_one_chunk( with warnings.catch_warnings(): warnings.simplefilter("ignore") processed_lens = ( - num_processed_frames // params.subsampling_factor - + encoder_out_lens + num_processed_frames // params.subsampling_factor + encoder_out_lens ) fast_beam_search_one_best( model=model, @@ -596,9 +593,7 @@ def decode_one_chunk( max_states=params.max_states, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") # Update cached states of each stream state_list = unstack_states(states) @@ -773,8 +768,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -816,9 +810,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -852,13 +844,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -881,13 +872,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -915,7 +905,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py index 60a5a2be7..a1d19fb73 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py @@ -87,9 +87,7 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -232,42 +230,45 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." + ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -606,11 +607,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -650,9 +647,7 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -665,14 +660,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -683,9 +673,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -852,10 +840,7 @@ def train_one_epoch( rank=rank, ) - if ( - batch_idx % params.log_interval == 0 - and not params.print_diagnostics - ): + if batch_idx % params.log_interval == 0 and not params.print_diagnostics: cur_lr = scheduler.get_last_lr()[0] logging.info( f"Epoch {params.cur_epoch}, " @@ -872,9 +857,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if ( batch_idx > 0 @@ -1009,8 +992,7 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py b/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py index 8dd1459ca..fd2a5354a 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py +++ b/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py @@ -74,17 +74,18 @@ class LibriSpeechAsrDataModule: def add_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group( title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", + description=( + "These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc." + ), ) group.add_argument( "--full-libri", type=str2bool, default=True, - help="When enabled, use 960h LibriSpeech. " - "Otherwise, use 100h subset.", + help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.", ) group.add_argument( "--manifest-dir", @@ -96,75 +97,91 @@ class LibriSpeechAsrDataModule: "--max-duration", type=int, default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", + help=( + "Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM." + ), ) group.add_argument( "--bucketing-sampler", type=str2bool, default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", + help=( + "When enabled, the batches will come from buckets of " + "similar duration (saves padding frames)." + ), ) group.add_argument( "--num-buckets", type=int, default=30, - help="The number of buckets for the BucketingSampler" - "(you might want to increase it for larger datasets).", + help=( + "The number of buckets for the BucketingSampler" + "(you might want to increase it for larger datasets)." + ), ) group.add_argument( "--concatenate-cuts", type=str2bool, default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", + help=( + "When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding." + ), ) group.add_argument( "--duration-factor", type=float, default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", + help=( + "Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch." + ), ) group.add_argument( "--gap", type=float, default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", + help=( + "The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used." + ), ) group.add_argument( "--on-the-fly-feats", type=str2bool, default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", + help=( + "When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available." + ), ) group.add_argument( "--shuffle", type=str2bool, default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", + help=( + "When enabled (=default), the examples will be shuffled for each epoch." + ), ) group.add_argument( "--return-cuts", type=str2bool, default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", + help=( + "When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it." + ), ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that " - "collect the batches.", + help="The number of training dataloader workers that collect the batches.", ) group.add_argument( @@ -178,18 +195,22 @@ class LibriSpeechAsrDataModule: "--spec-aug-time-warp-factor", type=int, default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", + help=( + "Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp." + ), ) group.add_argument( "--enable-musan", type=str2bool, default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. ", + help=( + "When enabled, select noise from MUSAN and mix it" + "with training dataset. " + ), ) def train_dataloaders( @@ -208,20 +229,16 @@ class LibriSpeechAsrDataModule: if self.args.enable_musan: logging.info("Enable MUSAN") logging.info("About to get Musan cuts") - cuts_musan = load_manifest( - self.args.manifest_dir / "cuts_musan.json.gz" - ) + cuts_musan = load_manifest(self.args.manifest_dir / "cuts_musan.json.gz") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") if self.args.concatenate_cuts: logging.info( - f"Using cut concatenation with duration factor " + "Using cut concatenation with duration factor " f"{self.args.duration_factor} and gap {self.args.gap}." ) # Cut concatenation should be the first transform in the list, @@ -236,9 +253,7 @@ class LibriSpeechAsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -281,9 +296,7 @@ class LibriSpeechAsrDataModule: # Drop feats to be on the safe side. train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -340,9 +353,7 @@ class LibriSpeechAsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: @@ -389,23 +400,17 @@ class LibriSpeechAsrDataModule: @lru_cache() def train_clean_100_cuts(self) -> CutSet: logging.info("About to get train-clean-100 cuts") - return load_manifest( - self.args.manifest_dir / "cuts_train-clean-100.json.gz" - ) + return load_manifest(self.args.manifest_dir / "cuts_train-clean-100.json.gz") @lru_cache() def train_clean_360_cuts(self) -> CutSet: logging.info("About to get train-clean-360 cuts") - return load_manifest( - self.args.manifest_dir / "cuts_train-clean-360.json.gz" - ) + return load_manifest(self.args.manifest_dir / "cuts_train-clean-360.json.gz") @lru_cache() def train_other_500_cuts(self) -> CutSet: logging.info("About to get train-other-500 cuts") - return load_manifest( - self.args.manifest_dir / "cuts_train-other-500.json.gz" - ) + return load_manifest(self.args.manifest_dir / "cuts_train-other-500.json.gz") @lru_cache() def dev_clean_cuts(self) -> CutSet: diff --git a/egs/librispeech/ASR/pruned2_knowledge/beam_search.py b/egs/librispeech/ASR/pruned2_knowledge/beam_search.py index 2e9bf3e0b..785a8f097 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/beam_search.py +++ b/egs/librispeech/ASR/pruned2_knowledge/beam_search.py @@ -172,9 +172,9 @@ def greedy_search( y = logits.argmax().item() if y != blank_id: hyp.append(y) - decoder_input = torch.tensor( - [hyp[-context_size:]], device=device - ).reshape(1, context_size) + decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape( + 1, context_size + ) decoder_out = model.decoder(decoder_input, need_pad=False) decoder_out = model.joiner.decoder_proj(decoder_out) @@ -302,9 +302,7 @@ class HypothesisList(object): key = hyp.key if key in self: old_hyp = self._data[key] # shallow copy - torch.logaddexp( - old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob - ) + torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob) else: self._data[key] = hyp @@ -320,9 +318,7 @@ class HypothesisList(object): Return the hypothesis that has the largest `log_prob`. """ if length_norm: - return max( - self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys) - ) + return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)) else: return max(self._data.values(), key=lambda hyp: hyp.log_prob) @@ -496,9 +492,7 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) diff --git a/egs/librispeech/ASR/pruned2_knowledge/conformer.py b/egs/librispeech/ASR/pruned2_knowledge/conformer.py index 295a35204..3b6d0549d 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/conformer.py +++ b/egs/librispeech/ASR/pruned2_knowledge/conformer.py @@ -18,10 +18,10 @@ import math import warnings from typing import Optional, Tuple -from sampling import create_knowledge_base, KnowledgeBaseLookup import torch from encoder_interface import EncoderInterface +from sampling import KnowledgeBaseLookup, create_knowledge_base from scaling import ( ActivationBalancer, BasicNorm, @@ -73,9 +73,9 @@ class Conformer(EncoderInterface): if subsampling_factor != 4: raise NotImplementedError("Support only 'subsampling_factor=4'.") - - self.knowledge_base = create_knowledge_base(knowledge_M, knowledge_N, - knowledge_D) + self.knowledge_base = create_knowledge_base( + knowledge_M, knowledge_N, knowledge_D + ) # self.encoder_embed converts the input of shape (N, T, num_features) # to the shape (N, T//subsampling_factor, d_model). @@ -89,7 +89,7 @@ class Conformer(EncoderInterface): # Pass in a lambda that creates a new ConformerEncoderLayer with these # args. Don't use deepcopy because we need the knowledge_base # to be shared. - encoder_layer_fn = lambda: ConformerEncoderLayer( + encoder_layer_fn = lambda: ConformerEncoderLayer( # noqa: E731 self.knowledge_base, d_model, nhead, @@ -100,7 +100,7 @@ class Conformer(EncoderInterface): knowledge_M, knowledge_N, knowledge_D, - knowledge_K + knowledge_K, ) self.encoder = ConformerEncoder(encoder_layer_fn, num_encoder_layers) @@ -187,9 +187,7 @@ class ConformerEncoderLayer(nn.Module): self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), @@ -209,10 +207,14 @@ class ConformerEncoderLayer(nn.Module): self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) - self.lookup = KnowledgeBaseLookup(knowledge_M, knowledge_N, - knowledge_D, knowledge_K, - d_model, - knowledge_base) + self.lookup = KnowledgeBaseLookup( + knowledge_M, + knowledge_N, + knowledge_D, + knowledge_K, + d_model, + knowledge_base, + ) self.norm_final = BasicNorm(d_model) @@ -311,9 +313,7 @@ class ConformerEncoder(nn.Module): def __init__(self, encoder_layer_fn, num_layers: int) -> None: super().__init__() - self.layers = nn.ModuleList( - [encoder_layer_fn() for i in range(num_layers)] - ) + self.layers = nn.ModuleList([encoder_layer_fn() for i in range(num_layers)]) self.num_layers = num_layers def forward( @@ -367,9 +367,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -384,9 +382,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vecotr and `j` means the @@ -661,9 +657,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -732,33 +728,25 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor" + " instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -795,9 +783,7 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -805,13 +791,9 @@ class RelPositionMultiheadAttention(nn.Module): ) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd) - attn_output_weights = ( - matrix_ac + matrix_bd - ) # (batch, head, time1, time2) + attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -845,13 +827,9 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -874,9 +852,7 @@ class ConvolutionModule(nn.Module): """ - def __init__( - self, channels: int, kernel_size: int, bias: bool = True - ) -> None: + def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding diff --git a/egs/librispeech/ASR/pruned2_knowledge/decode.py b/egs/librispeech/ASR/pruned2_knowledge/decode.py index b4a9af55a..65da19f27 100755 --- a/egs/librispeech/ASR/pruned2_knowledge/decode.py +++ b/egs/librispeech/ASR/pruned2_knowledge/decode.py @@ -76,11 +76,7 @@ from beam_search import ( ) from train import get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import ( AttributeDict, setup_logger, @@ -98,16 +94,19 @@ def get_parser(): "--epoch", type=int, default=28, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -186,8 +185,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -245,9 +243,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -262,10 +258,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -309,11 +302,7 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps + f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -385,9 +374,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -419,8 +406,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/librispeech/ASR/pruned2_knowledge/decoder.py b/egs/librispeech/ASR/pruned2_knowledge/decoder.py index b6d94aaf1..0b9c886c7 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/decoder.py +++ b/egs/librispeech/ASR/pruned2_knowledge/decoder.py @@ -90,9 +90,7 @@ class Decoder(nn.Module): if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad( - embedding_out, pad=(self.context_size - 1, 0) - ) + embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/librispeech/ASR/pruned2_knowledge/decoder2.py b/egs/librispeech/ASR/pruned2_knowledge/decoder2.py index db51fb1cd..2ca76a30c 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/decoder2.py +++ b/egs/librispeech/ASR/pruned2_knowledge/decoder2.py @@ -14,12 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional + import torch import torch.nn as nn import torch.nn.functional as F -from torch import Tensor -from typing import Optional from subsampling import ScaledConv1d +from torch import Tensor class Decoder(nn.Module): @@ -90,9 +91,7 @@ class Decoder(nn.Module): if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad( - embedding_out, pad=(self.context_size - 1, 0) - ) + embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) else: # During inference time, there is no need to do extra padding # as we only need one output @@ -102,7 +101,6 @@ class Decoder(nn.Module): return embedding_out - class ScaledEmbedding(nn.Module): r"""A simple lookup table that stores embeddings of a fixed dictionary and size. @@ -171,8 +169,13 @@ class ScaledEmbedding(nn.Module): [ 0.0000, 0.0000, 0.0000], [-0.1655, 0.9897, 0.0635]]]) """ - __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', - 'scale_grad_by_freq', 'sparse'] + __constants__ = [ + "num_embeddings", + "embedding_dim", + "padding_idx", + "scale_grad_by_freq", + "sparse", + ] num_embeddings: int embedding_dim: int @@ -181,34 +184,41 @@ class ScaledEmbedding(nn.Module): weight: Tensor sparse: bool - def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, - scale_grad_by_freq: bool = False, - sparse: bool = False, - scale_speed: float = 5.0) -> None: + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + scale_grad_by_freq: bool = False, + sparse: bool = False, + scale_speed: float = 5.0, + ) -> None: super(ScaledEmbedding, self).__init__() self.num_embeddings = num_embeddings self.embedding_dim = embedding_dim if padding_idx is not None: if padding_idx > 0: - assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings' + assert ( + padding_idx < self.num_embeddings + ), "Padding_idx must be within num_embeddings" elif padding_idx < 0: - assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings' + assert ( + padding_idx >= -self.num_embeddings + ), "Padding_idx must be within num_embeddings" padding_idx = self.num_embeddings + padding_idx self.padding_idx = padding_idx self.scale_grad_by_freq = scale_grad_by_freq self.scale_speed = scale_speed - self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() + self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() self.sparse = sparse self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim)) self.reset_parameters() - - def reset_parameters(self) -> None: nn.init.normal_(self.weight, std=0.05) - nn.init.constant_(self.scale, torch.tensor(1.0/0.05).log() / self.scale_speed) + nn.init.constant_(self.scale, torch.tensor(1.0 / 0.05).log() / self.scale_speed) if self.padding_idx is not None: with torch.no_grad(): @@ -217,22 +227,38 @@ class ScaledEmbedding(nn.Module): def forward(self, input: Tensor) -> Tensor: scale = (self.scale * self.scale_speed).exp() if input.numel() < self.num_embeddings: - return F.embedding( - input, self.weight, self.padding_idx, - None, 2.0, # None, 2.0 relate to normalization - self.scale_grad_by_freq, self.sparse) * scale + return ( + F.embedding( + input, + self.weight, + self.padding_idx, + None, + 2.0, # None, 2.0 relate to normalization + self.scale_grad_by_freq, + self.sparse, + ) + * scale + ) else: return F.embedding( - input, self.weight * scale, self.padding_idx, - None, 2.0, # None, 2.0 relates to normalization - self.scale_grad_by_freq, self.sparse) + input, + self.weight * scale, + self.padding_idx, + None, + 2.0, # None, 2.0 relates to normalization + self.scale_grad_by_freq, + self.sparse, + ) def extra_repr(self) -> str: - s = '{num_embeddings}, {embedding_dim}, scale_speed={scale_speed}, scale={scale}' + s = ( + "{num_embeddings}, {embedding_dim}, scale_speed={scale_speed}," + " scale={scale}" + ) if self.padding_idx is not None: - s += ', padding_idx={padding_idx}' + s += ", padding_idx={padding_idx}" if self.scale_grad_by_freq is not False: - s += ', scale_grad_by_freq={scale_grad_by_freq}' + s += ", scale_grad_by_freq={scale_grad_by_freq}" if self.sparse is not False: - s += ', sparse=True' + s += ", sparse=True" return s.format(**self.__dict__) diff --git a/egs/librispeech/ASR/pruned2_knowledge/export.py b/egs/librispeech/ASR/pruned2_knowledge/export.py index 96d1a30fb..1af05d9c8 100755 --- a/egs/librispeech/ASR/pruned2_knowledge/export.py +++ b/egs/librispeech/ASR/pruned2_knowledge/export.py @@ -64,17 +64,20 @@ def get_parser(): "--epoch", type=int, default=28, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -105,8 +108,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) return parser @@ -174,9 +176,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned2_knowledge/joiner.py b/egs/librispeech/ASR/pruned2_knowledge/joiner.py index 35f75ed2a..68c663b66 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/joiner.py +++ b/egs/librispeech/ASR/pruned2_knowledge/joiner.py @@ -56,9 +56,7 @@ class Joiner(nn.Module): assert encoder_out.shape[:-1] == decoder_out.shape[:-1] if project_input: - logit = self.encoder_proj(encoder_out) + self.decoder_proj( - decoder_out - ) + logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) else: logit = encoder_out + decoder_out diff --git a/egs/librispeech/ASR/pruned2_knowledge/model.py b/egs/librispeech/ASR/pruned2_knowledge/model.py index 599bf2506..ca8c28af1 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/model.py +++ b/egs/librispeech/ASR/pruned2_knowledge/model.py @@ -63,9 +63,7 @@ class Transducer(nn.Module): self.decoder = decoder self.joiner = joiner - self.simple_am_proj = ScaledLinear( - encoder_dim, vocab_size, initial_speed=0.5 - ) + self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) def forward( @@ -136,9 +134,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/pruned2_knowledge/optim.py b/egs/librispeech/ASR/pruned2_knowledge/optim.py index 432bf8220..76cd4e11e 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/optim.py +++ b/egs/librispeech/ASR/pruned2_knowledge/optim.py @@ -72,17 +72,11 @@ class Eve(Optimizer): if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if not 0.0 <= betas[0] < 1.0: - raise ValueError( - "Invalid beta parameter at index 0: {}".format(betas[0]) - ) + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) if not 0.0 <= betas[1] < 1.0: - raise ValueError( - "Invalid beta parameter at index 1: {}".format(betas[1]) - ) + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) if not 0 <= weight_decay <= 0.1: - raise ValueError( - "Invalid weight_decay value: {}".format(weight_decay) - ) + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) if not 0 < target_rms <= 10.0: raise ValueError("Invalid target_rms value: {}".format(target_rms)) defaults = dict( @@ -118,9 +112,7 @@ class Eve(Optimizer): # Perform optimization step grad = p.grad if grad.is_sparse: - raise RuntimeError( - "AdamW does not support sparse gradients" - ) + raise RuntimeError("AdamW does not support sparse gradients") state = self.state[p] @@ -147,7 +139,7 @@ class Eve(Optimizer): # Decay the first and second moment running average coefficient exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_( + denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_( group["eps"] ) @@ -158,9 +150,7 @@ class Eve(Optimizer): if p.numel() > 1: # avoid applying this weight-decay on "scaling factors" # (which are scalar). - is_above_target_rms = p.norm() > ( - target_rms * (p.numel() ** 0.5) - ) + is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5)) p.mul_(1 - (weight_decay * is_above_target_rms)) p.addcdiv_(exp_avg, denom, value=-step_size) @@ -176,18 +166,14 @@ class LRScheduler(object): def __init__(self, optimizer: Optimizer, verbose: bool = False): # Attach optimizer if not isinstance(optimizer, Optimizer): - raise TypeError( - "{} is not an Optimizer".format(type(optimizer).__name__) - ) + raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) self.optimizer = optimizer self.verbose = verbose for group in optimizer.param_groups: group.setdefault("initial_lr", group["lr"]) - self.base_lrs = [ - group["initial_lr"] for group in optimizer.param_groups - ] + self.base_lrs = [group["initial_lr"] for group in optimizer.param_groups] self.epoch = 0 self.batch = 0 @@ -295,10 +281,9 @@ class Eden(LRScheduler): def get_lr(self): factor = ( - (self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2 + (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 ) ** -0.25 * ( - ((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2) - ** -0.25 + ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25 ) return [x * factor for x in self.base_lrs] diff --git a/egs/librispeech/ASR/pruned2_knowledge/sampling.py b/egs/librispeech/ASR/pruned2_knowledge/sampling.py index 7b05e2f00..8cc930927 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/sampling.py +++ b/egs/librispeech/ASR/pruned2_knowledge/sampling.py @@ -3,32 +3,29 @@ # This was copied from /ceph-dan/torch-sampling/torch_sampling/sampling_ref.py, # its git history is there. -import timeit -import torch -from torch import Tensor -from torch import nn -from torch.cuda.amp import GradScaler, custom_fwd, custom_bwd -from typing import Tuple, Optional -from scaling import ScaledLinear import random +import timeit +from typing import Optional, Tuple + +import torch +from scaling import ScaledLinear +from torch import Tensor, nn +from torch.cuda.amp import GradScaler, custom_bwd, custom_fwd from torch_scheduled_sampling import sample_combined # The main exports of this file are the module KnowledgeBaseLookup and the # function create_knowledge_base. - - - - def create_knowledge_base(M: int, N: int, D: int) -> nn.Parameter: std = 0.1 - a = (3 ** 0.5) * std # this sqrt(3) thing is intended to get variance of - # 0.1 from uniform distribution - ans = nn.Parameter(torch.ones(M ** N, D)) + a = (3**0.5) * std # this sqrt(3) thing is intended to get variance of + # 0.1 from uniform distribution + ans = nn.Parameter(torch.ones(M**N, D)) nn.init.uniform_(ans, -a, a) return ans + def join_indexes(indexes: Tensor, M: int) -> Tensor: """ Combines N-tuples of indexes into single indexes that can be used for @@ -47,9 +44,9 @@ def join_indexes(indexes: Tensor, M: int) -> Tensor: # Note, we don't use this, we -def weighted_matrix_lookup(weights: Tensor, - indexes: Tensor, - knowledge_base: Tensor) -> Tensor: +def weighted_matrix_lookup( + weights: Tensor, indexes: Tensor, knowledge_base: Tensor +) -> Tensor: """ Weighted combination of specified rows of a matrix. weights: Tensor of shape (*, K), can contain any value but probably in [0..1]. @@ -65,9 +62,9 @@ def weighted_matrix_lookup(weights: Tensor, # simpler but less memory-efficient implementation lookup = torch.index_select(knowledge_base, dim=0, index=indexes.flatten()) D = knowledge_base.shape[-1] - weights = weights.unsqueeze(-2) # (*, 1, K) - lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) - ans = torch.matmul(weights, lookup) # ans: (*, 1, D) + weights = weights.unsqueeze(-2) # (*, 1, K) + lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) + ans = torch.matmul(weights, lookup) # ans: (*, 1, D) ans = ans.squeeze(-2) assert list(ans.shape) == list(weights.shape[:-2]) + [D] return ans @@ -76,7 +73,9 @@ def weighted_matrix_lookup(weights: Tensor, class WeightedMatrixLookupFunction(torch.autograd.Function): @staticmethod @custom_fwd - def forward(ctx, weights: Tensor, indexes: Tensor, knowledge_base: Tensor) -> Tensor: + def forward( + ctx, weights: Tensor, indexes: Tensor, knowledge_base: Tensor + ) -> Tensor: """ Weighted combination of specified rows of a matrix. weights: Tensor of shape (*, K), can contain any value but probably in [0..1]. @@ -88,15 +87,16 @@ class WeightedMatrixLookupFunction(torch.autograd.Function): """ if random.random() < 0.001: print("dtype[1] = ", weights.dtype) - ctx.save_for_backward(weights.detach(), indexes.detach(), - knowledge_base.detach()) + ctx.save_for_backward( + weights.detach(), indexes.detach(), knowledge_base.detach() + ) with torch.no_grad(): lookup = torch.index_select(knowledge_base, dim=0, index=indexes.flatten()) D = knowledge_base.shape[-1] - weights = weights.unsqueeze(-2) # (*, 1, K) - lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) - ans = torch.matmul(weights, lookup) # ans: (*, 1, D) - ans = ans.squeeze(-2) #(*, D) + weights = weights.unsqueeze(-2) # (*, 1, K) + lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) + ans = torch.matmul(weights, lookup) # ans: (*, 1, D) + ans = ans.squeeze(-2) # (*, D) return ans @staticmethod @@ -107,7 +107,7 @@ class WeightedMatrixLookupFunction(torch.autograd.Function): knowledge_base.requires_grad = True dtype = ans_grad.dtype ans_grad = ans_grad.to(weights.dtype) - assert weights.requires_grad == False + assert weights.requires_grad is False D = knowledge_base.shape[-1] with torch.enable_grad(): # we'll use torch's autograd to differentiate this operation, which @@ -115,16 +115,19 @@ class WeightedMatrixLookupFunction(torch.autograd.Function): # We don't save `lookup` because it's large, that is the reason # we override Torch autograd. lookup = torch.index_select(knowledge_base, dim=0, index=indexes.flatten()) - lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) - weights = weights.unsqueeze(-1) # (*, K, 1) + lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) + weights = weights.unsqueeze(-1) # (*, K, 1) # forward pass: was: ## ans = torch.matmul(weights, lookup) ## ans: (*, 1, D) ## ans = ans.squeeze(-2) # ans, ans_grad: (*, D) - weights_grad = torch.matmul(lookup, # (*, K, D) - ans_grad.unsqueeze(-1)) # (*, D, 1) - weights_grad = weights_grad.squeeze(-1) # (*, K, 1) -> (*, K) - lookup_grad = weights * ans_grad.unsqueeze(-2) # (*, K, 1) * (*, 1, D) = (*, K, D) + weights_grad = torch.matmul( + lookup, ans_grad.unsqueeze(-1) # (*, K, D) + ) # (*, D, 1) + weights_grad = weights_grad.squeeze(-1) # (*, K, 1) -> (*, K) + lookup_grad = weights * ans_grad.unsqueeze( + -2 + ) # (*, K, 1) * (*, 1, D) = (*, K, D) lookup.backward(gradient=lookup_grad) return weights_grad.to(dtype), None, knowledge_base.grad.to(dtype) @@ -146,6 +149,7 @@ class PenalizeNegentropyFunction(torch.autograd.Function): Returns: logprobs """ + @staticmethod def forward(ctx, logprobs: Tensor, alpha: float): ctx.save_for_backward(logprobs.detach()) @@ -154,18 +158,23 @@ class PenalizeNegentropyFunction(torch.autograd.Function): @staticmethod def backward(ctx, logprobs_grad: Tensor) -> Tuple[Tensor, None]: - logprobs, = ctx.saved_tensors + (logprobs,) = ctx.saved_tensors with torch.enable_grad(): logprobs.requires_grad = True # `negentropy` is the negative entropy of the average distribution. # distributions. It will be <= 0. - l = logprobs.reshape(-1, logprobs.shape[-1]) + l = logprobs.reshape(-1, logprobs.shape[-1]) # noqa: E741 scale = ctx.alpha * l.shape[0] avg_dist = l.exp().mean(dim=0) negentropy = (avg_dist * (avg_dist + 1.0e-20).log()).sum() if random.random() < 0.0005: negentropy_individual = (l * l.exp()).sum(dim=-1).mean() - print("Negentropy[individual,combined] = ", negentropy_individual.item(), ", ", negentropy.item()) + print( + "Negentropy[individual,combined] = ", + negentropy_individual.item(), + ", ", + negentropy.item(), + ) loss = negentropy * scale loss.backward() return logprobs_grad + logprobs.grad, None @@ -183,18 +192,23 @@ class KnowledgeBaseLookup(nn.Module): embedding_dim: the dimension to project from and to, e.g. the d_model of the conformer. """ - def __init__(self, M: int, N: int, D: int, - K: int, embedding_dim: int, - knowledge_base: nn.Parameter, - negentropy_penalty: float = 0.001): + + def __init__( + self, + M: int, + N: int, + D: int, + K: int, + embedding_dim: int, + knowledge_base: nn.Parameter, + negentropy_penalty: float = 0.001, + ): super(KnowledgeBaseLookup, self).__init__() self.knowledge_base = knowledge_base # shared! - self.in_proj = ScaledLinear(embedding_dim, M * N, - initial_scale=1.0) + self.in_proj = ScaledLinear(embedding_dim, M * N, initial_scale=1.0) # initial_scale = 4.0 because the knowlege_base activations are # quite small -- if we use our optimizer they'll have stddev <= 0.1. - self.out_proj = ScaledLinear(D, embedding_dim, - initial_scale = 4.0) + self.out_proj = ScaledLinear(D, embedding_dim, initial_scale=4.0) self.M = M self.N = N self.K = K @@ -210,14 +224,14 @@ class KnowledgeBaseLookup(nn.Module): # TODO: later we can try multiplying by a projection of x or something like that. """ - x = self.in_proj(x) # now (*, M*N) - x = x.reshape(*x.shape[:-1], self.N, self.M) # now (*, N, M) - x = x.log_softmax(dim=-1) # now normalized logprobs, dim= (*, N, M) + x = self.in_proj(x) # now (*, M*N) + x = x.reshape(*x.shape[:-1], self.N, self.M) # now (*, N, M) + x = x.log_softmax(dim=-1) # now normalized logprobs, dim= (*, N, M) x = PenalizeNegentropyFunction.apply(x, self.negentropy_penalty) _, indexes, weights = sample_combined(x, self.K, input_is_log=True) - x = weighted_matrix_lookup(weights, indexes, self.knowledge_base) # now (*, D) - x = self.out_proj(x) # now (*, self.embedding_dim) + x = weighted_matrix_lookup(weights, indexes, self.knowledge_base) # now (*, D) + x = self.out_proj(x) # now (*, self.embedding_dim) return x @@ -237,38 +251,44 @@ def _test_knowledge_base_lookup(): x.requires_grad = True y = m(x) assert y.shape == x.shape - y.sum().backward() # make sure backward doesn't crash.. + y.sum().backward() # make sure backward doesn't crash.. print("y = ", y) print("x.grad = ", x.grad) print("knowlege_base.grad norm = ", knowledge_base.grad.norm()) dtype = torch.float32 - device = torch.device('cuda') - train_pairs = [ (torch.randn(B, T, E, device=device, dtype=dtype), torch.randn(B, T, E, device=device, dtype=dtype)) for _ in range(10) ] + device = torch.device("cuda") + train_pairs = [ + ( + torch.randn(B, T, E, device=device, dtype=dtype), + torch.randn(B, T, E, device=device, dtype=dtype), + ) + for _ in range(10) + ] from optim import Eve + optimizer = Eve(m.parameters(), lr=0.005, eps=1.0e-04) m = m.to(device).to(dtype) - start = timeit.default_timer() -# Epoch 0, batch 0, loss 1.0109944343566895 -# Epoch 10, batch 0, loss 1.0146660804748535 -# Epoch 20, batch 0, loss 1.0119813680648804 -# Epoch 30, batch 0, loss 1.0105408430099487 -# Epoch 40, batch 0, loss 1.0077732801437378 -# Epoch 50, batch 0, loss 1.0050103664398193 -# Epoch 60, batch 0, loss 1.0033129453659058 -# Epoch 70, batch 0, loss 1.0014232397079468 -# Epoch 80, batch 0, loss 0.9977912306785583 -# Epoch 90, batch 0, loss 0.8274348974227905 -# Epoch 100, batch 0, loss 0.3368612825870514 -# Epoch 110, batch 0, loss 0.11323091387748718 -# Time taken: 17.591704960912466 + # Epoch 0, batch 0, loss 1.0109944343566895 + # Epoch 10, batch 0, loss 1.0146660804748535 + # Epoch 20, batch 0, loss 1.0119813680648804 + # Epoch 30, batch 0, loss 1.0105408430099487 + # Epoch 40, batch 0, loss 1.0077732801437378 + # Epoch 50, batch 0, loss 1.0050103664398193 + # Epoch 60, batch 0, loss 1.0033129453659058 + # Epoch 70, batch 0, loss 1.0014232397079468 + # Epoch 80, batch 0, loss 0.9977912306785583 + # Epoch 90, batch 0, loss 0.8274348974227905 + # Epoch 100, batch 0, loss 0.3368612825870514 + # Epoch 110, batch 0, loss 0.11323091387748718 + # Time taken: 17.591704960912466 for epoch in range(150): - for n, (x,y) in enumerate(train_pairs): + for n, (x, y) in enumerate(train_pairs): y_out = m(x) - loss = ((y_out - y)**2).mean() * 100.0 + loss = ((y_out - y) ** 2).mean() * 100.0 if n % 10 == 0 and epoch % 10 == 0: print(f"Epoch {epoch}, batch {n}, loss {loss.item()}") loss.backward() @@ -276,7 +296,8 @@ def _test_knowledge_base_lookup(): optimizer.zero_grad() stop = timeit.default_timer() - print('Time taken: ', stop - start) + print("Time taken: ", stop - start) + def _test_knowledge_base_lookup_autocast(): K = 16 @@ -294,14 +315,21 @@ def _test_knowledge_base_lookup_autocast(): x.requires_grad = True y = m(x) assert y.shape == x.shape - y.sum().backward() # make sure backward doesn't crash.. + y.sum().backward() # make sure backward doesn't crash.. print("y = ", y) print("x.grad = ", x.grad) print("knowlege_base.grad norm = ", knowledge_base.grad.norm()) - device = torch.device('cuda') - train_pairs = [ (torch.randn(B, T, E, device=device), torch.randn(B, T, E, device=device)) for _ in range(10) ] + device = torch.device("cuda") + train_pairs = [ + ( + torch.randn(B, T, E, device=device), + torch.randn(B, T, E, device=device), + ) + for _ in range(10) + ] from optim import Eve + optimizer = Eve(m.parameters(), lr=0.005, eps=1.0e-04) m = m.to(device) @@ -309,12 +337,11 @@ def _test_knowledge_base_lookup_autocast(): start = timeit.default_timer() - for epoch in range(150): - for n, (x,y) in enumerate(train_pairs): + for n, (x, y) in enumerate(train_pairs): y_out = m(x) with torch.cuda.amp.autocast(enabled=True): - loss = ((y_out - y)**2).mean() * 100.0 + loss = ((y_out - y) ** 2).mean() * 100.0 if n % 10 == 0 and epoch % 10 == 0: print(f"Epoch {epoch}, batch {n}, loss {loss.item()}") scaler.scale(loss).backward() @@ -323,10 +350,9 @@ def _test_knowledge_base_lookup_autocast(): optimizer.zero_grad() stop = timeit.default_timer() - print('Time taken: ', stop - start) + print("Time taken: ", stop - start) - -if __name__ == '__main__': +if __name__ == "__main__": _test_knowledge_base_lookup() _test_knowledge_base_lookup_autocast() diff --git a/egs/librispeech/ASR/pruned2_knowledge/scaling.py b/egs/librispeech/ASR/pruned2_knowledge/scaling.py index f726c2583..527c735eb 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/scaling.py +++ b/egs/librispeech/ASR/pruned2_knowledge/scaling.py @@ -18,11 +18,11 @@ import collections from itertools import repeat from typing import Optional, Tuple -from torch.cuda.amp import custom_fwd, custom_bwd import torch import torch.nn as nn from torch import Tensor +from torch.cuda.amp import custom_bwd, custom_fwd def _ntuple(n): @@ -79,9 +79,7 @@ class ActivationBalancerFunction(torch.autograd.Function): below_threshold = mean_abs < min_abs above_threshold = mean_abs > max_abs - ctx.save_for_backward( - factor, xgt0, below_threshold, above_threshold - ) + ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold) ctx.max_factor = max_factor ctx.sum_dims = sum_dims return x @@ -149,8 +147,7 @@ class BasicNorm(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: assert x.shape[self.channel_dim] == self.num_channels scales = ( - torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) - + self.eps.exp() + torch.mean(x**2, dim=self.channel_dim, keepdim=True) + self.eps.exp() ) ** -0.5 return x * scales @@ -182,11 +179,7 @@ class ScaledLinear(nn.Linear): """ def __init__( - self, - *args, - initial_scale: float = 1.0, - initial_speed: float = 1.0, - **kwargs + self, *args, initial_scale: float = 1.0, initial_speed: float = 1.0, **kwargs ): super(ScaledLinear, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() @@ -202,12 +195,12 @@ class ScaledLinear(nn.Linear): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -218,19 +211,13 @@ class ScaledLinear(nn.Linear): return None if self.bias is None else self.bias * self.bias_scale.exp() def forward(self, input: Tensor) -> Tensor: - return torch.nn.functional.linear( - input, self.get_weight(), self.get_bias() - ) + return torch.nn.functional.linear(input, self.get_weight(), self.get_bias()) class ScaledConv1d(nn.Conv1d): # See docs for ScaledLinear def __init__( - self, - *args, - initial_scale: float = 1.0, - initial_speed: float = 1.0, - **kwargs + self, *args, initial_scale: float = 1.0, initial_speed: float = 1.0, **kwargs ): super(ScaledConv1d, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() @@ -245,12 +232,12 @@ class ScaledConv1d(nn.Conv1d): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -290,11 +277,7 @@ class ScaledConv1d(nn.Conv1d): class ScaledConv2d(nn.Conv2d): # See docs for ScaledLinear def __init__( - self, - *args, - initial_scale: float = 1.0, - initial_speed: float = 1.0, - **kwargs + self, *args, initial_scale: float = 1.0, initial_speed: float = 1.0, **kwargs ): super(ScaledConv2d, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() @@ -309,12 +292,12 @@ class ScaledConv2d(nn.Conv2d): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -653,9 +636,7 @@ def _test_activation_balancer_sign(): def _test_activation_balancer_magnitude(): magnitudes = torch.arange(0, 1, 0.01) N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze( - -1 - ) + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) x = x.detach() x.requires_grad = True m = ActivationBalancer( @@ -685,8 +666,8 @@ def _test_basic_norm(): y = m(x) assert y.shape == x.shape - x_rms = (x ** 2).mean().sqrt() - y_rms = (y ** 2).mean().sqrt() + x_rms = (x**2).mean().sqrt() + y_rms = (y**2).mean().sqrt() print("x rms = ", x_rms) print("y rms = ", y_rms) assert y_rms < x_rms diff --git a/egs/librispeech/ASR/pruned2_knowledge/scaling_tmp.py b/egs/librispeech/ASR/pruned2_knowledge/scaling_tmp.py index 6293e081a..3f21133a0 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/scaling_tmp.py +++ b/egs/librispeech/ASR/pruned2_knowledge/scaling_tmp.py @@ -15,21 +15,23 @@ # limitations under the License. +from typing import Optional, Tuple + import torch import torch.nn as nn from torch import Tensor -from typing import Tuple, Optional - -def _activation_balancer_loss(mean_pos: Tensor, - mean_neg: Tensor, - min_positive: float, # e.g. 0.05 - max_positive: float, # e.g. 0.95 - max_factor: float, # e.g. 0.01 - min_abs: float, # e.g. 0.2 - max_abs: float, # e.g. 100.0 - eps: float = 1.0e-10): +def _activation_balancer_loss( + mean_pos: Tensor, + mean_neg: Tensor, + min_positive: float, # e.g. 0.05 + max_positive: float, # e.g. 0.95 + max_factor: float, # e.g. 0.01 + min_abs: float, # e.g. 0.2 + max_abs: float, # e.g. 100.0 + eps: float = 1.0e-10, +): """ Returns a loss-function for the ActivationBalancer module. This loss function is not exposed to the user but is used internally, and eventually @@ -50,28 +52,32 @@ def _activation_balancer_loss(mean_pos: Tensor, """ loss_parts = [] - x_mean = mean_positive - mean_negative - x_mean_abs = (mean_positive + mean_negative + eps).detach() - x_rel_mean= x_mean / x_mean_abs + x_mean = mean_pos - mean_neg + x_mean_abs = (mean_pos + mean_neg + eps).detach() + x_rel_mean = x_mean / x_mean_abs if min_positive != 0.0: # e.g. x_mean_floor = -0.95 + 0.05 = -0.9 - x_rel_mean_floor = (-(1-min_positive) + min_positive) - min_positive_loss = (x_rel_mean_floor - x_rel_mean).relu().sum() * (1.0 / (2*min_positive)) + x_rel_mean_floor = -(1 - min_positive) + min_positive + min_positive_loss = (x_rel_mean_floor - x_rel_mean).relu().sum() * ( + 1.0 / (2 * min_positive) + ) # this part of the loss would be 1.0 * num_channels if all these constraints were # 100% violated. loss_parts.append(min_positive_loss) if max_positive != 1.0: # e.g. x_mean_floor = -0.05 + 0.95 = 0.8 - x_rel_mean_ceil = - (1.0-max_positive) + max_positive - max_positive_loss = (x_rel_mean - x_rel_mean_ceil).relu().sum() * (1.0 / (1 - x_rel_mean_ceil)) + x_rel_mean_ceil = -(1.0 - max_positive) + max_positive + max_positive_loss = (x_rel_mean - x_rel_mean_ceil).relu().sum() * ( + 1.0 / (1 - x_rel_mean_ceil) + ) # this part of the loss would be 1.0 * num_channels if all these constraints were # 100% violated. loss_parts.append(max_positive_loss) if min_abs != 0.0: - min_abs_loss = min_abs - x_mean_abs).relu().sum() / min_abs + min_abs_loss = (min_abs - x_mean_abs).relu().sum() / min_abs # this part of the loss would be 1.0 * num_channels if all these constraints were # 100% violated. loss_parts.append(min_abs_loss) @@ -82,43 +88,53 @@ def _activation_balancer_loss(mean_pos: Tensor, # 100% violated. loss_parts.append(max_abs_loss) - # the min_positive and 1 - max_positive are "ballast" added to the denom = mean_pos + mean_neg + (min_positive + (1 - max_positive)) - num + # num if min_positive != 0.0: - - + pass class ActivationBalancerFunction(torch.autograd.Function): @staticmethod - def forward(ctx, x: Tensor, - channel_dim: int, - min_positive: float, # e.g. 0.05 - max_positive: float, # e.g. 0.95 - max_factor: float, # e.g. 0.01 - min_abs: float, # e.g. 0.2 - max_abs: float, # e.g. 100.0 + def forward( + ctx, + x: Tensor, + channel_dim: int, + min_positive: float, # e.g. 0.05 + max_positive: float, # e.g. 0.95 + max_factor: float, # e.g. 0.01 + min_abs: float, # e.g. 0.2 + max_abs: float, # e.g. 100.0 ) -> Tensor: if x.requires_grad: if channel_dim < 0: channel_dim += x.ndim sum_dims = [d for d in range(x.ndim) if d != channel_dim] xgt0 = x > 0 - proportion_positive = torch.mean(xgt0.to(x.dtype), dim=sum_dims, keepdim=True) - factor1 = ((min_positive - proportion_positive).relu() * (max_factor / min_positive) - if min_positive != 0.0 else 0.0) - factor2 = ((proportion_positive - max_positive).relu() * (max_factor / (max_positive - 1.0)) - if max_positive != 1.0 else 0.0) + proportion_positive = torch.mean( + xgt0.to(x.dtype), dim=sum_dims, keepdim=True + ) + factor1 = ( + (min_positive - proportion_positive).relu() + * (max_factor / min_positive) + if min_positive != 0.0 + else 0.0 + ) + factor2 = ( + (proportion_positive - max_positive).relu() + * (max_factor / (max_positive - 1.0)) + if max_positive != 1.0 + else 0.0 + ) factor = factor1 + factor2 if isinstance(factor, float): factor = torch.zeros_like(proportion_positive) mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True) - below_threshold = (mean_abs < min_abs) - above_threshold = (mean_abs > max_abs) + below_threshold = mean_abs < min_abs + above_threshold = mean_abs > max_abs ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold) ctx.max_factor = max_factor @@ -126,11 +142,16 @@ class ActivationBalancerFunction(torch.autograd.Function): return x @staticmethod - def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None, None]: + def backward( + ctx, x_grad: Tensor + ) -> Tuple[Tensor, None, None, None, None, None, None]: factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors dtype = x_grad.dtype - scale_factor = ((below_threshold.to(dtype) - above_threshold.to(dtype)) * - (xgt0.to(dtype) - 0.5) * (ctx.max_factor * 2.0)) + scale_factor = ( + (below_threshold.to(dtype) - above_threshold.to(dtype)) + * (xgt0.to(dtype) - 0.5) + * (ctx.max_factor * 2.0) + ) neg_delta_grad = x_grad.abs() * (factor + scale_factor) return x_grad - neg_delta_grad, None, None, None, None, None, None @@ -163,29 +184,30 @@ class BasicNorm(torch.nn.Module): learn_eps: if true, we learn epsilon; if false, we keep it at the initial value. """ - def __init__(self, - num_channels: int, - channel_dim: int = -1, # CAUTION: see documentation. - eps: float = 0.25, - learn_eps: bool = True) -> None: + + def __init__( + self, + num_channels: int, + channel_dim: int = -1, # CAUTION: see documentation. + eps: float = 0.25, + learn_eps: bool = True, + ) -> None: super(BasicNorm, self).__init__() self.num_channels = num_channels self.channel_dim = channel_dim if learn_eps: self.eps = nn.Parameter(torch.tensor(eps).log().detach()) else: - self.register_buffer('eps', torch.tensor(eps).log().detach()) - + self.register_buffer("eps", torch.tensor(eps).log().detach()) def forward(self, x: Tensor) -> Tensor: assert x.shape[self.channel_dim] == self.num_channels - scales = (torch.mean(x**2, dim=self.channel_dim, keepdim=True) + - self.eps.exp()) ** -0.5 + scales = ( + torch.mean(x**2, dim=self.channel_dim, keepdim=True) + self.eps.exp() + ) ** -0.5 return x * scales - - class ScaledLinear(nn.Linear): """ A modified version of nn.Linear where the parameters are scaled before @@ -207,27 +229,26 @@ class ScaledLinear(nn.Linear): inherited from nn.Linear. For modules with small fan-in, this may be larger than optimal. """ - def __init__(self, *args, - initial_scale: float = 1.0, - **kwargs): + + def __init__(self, *args, initial_scale: float = 1.0, **kwargs): super(ScaledLinear, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() self.weight_scale = nn.Parameter(initial_scale.clone().detach()) if self.bias is not None: self.bias_scale = nn.Parameter(initial_scale.clone().detach()) else: - self.register_parameter('bias_scale', None) + self.register_parameter("bias_scale", None) self._reset_parameters() # Overrides the reset_parameters in nn.Linear def _reset_parameters(self): std = 0.01 - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() if self.bias is not None: @@ -237,56 +258,67 @@ class ScaledLinear(nn.Linear): return self.weight * self.weight_scale.exp() def get_bias(self): - return (None if self.bias is None else - self.bias * self.bias_scale.exp()) + return None if self.bias is None else self.bias * self.bias_scale.exp() def forward(self, input: Tensor) -> Tensor: - return torch.nn.functional.linear(input, self.get_weight(), - self.get_bias()) + return torch.nn.functional.linear(input, self.get_weight(), self.get_bias()) class ScaledConv1d(nn.Conv1d): - def __init__(self, *args, - initial_scale=1.0, **kwargs): + def __init__(self, *args, initial_scale=1.0, **kwargs): super(ScaledConv1d, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() self.weight_scale = nn.Parameter(initial_scale.clone().detach()) if self.bias is not None: self.bias_scale = nn.Parameter(initial_scale.clone().detach()) else: - self.register_parameter('bias_scale', None) + self.register_parameter("bias_scale", None) self._reset_parameters() # Overrides the reset_parameters in base class def _reset_parameters(self): std = 0.01 - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() if self.bias is not None: self.bias_scale += torch.tensor(scale / std).log() - def get_weight(self): return self.weight * self.weight_scale.exp() def get_bias(self): - return (None if self.bias is None else - self.bias * self.bias_scale.exp()) + return None if self.bias is None else self.bias * self.bias_scale.exp() def forward(self, input: Tensor) -> Tensor: F = torch.nn.functional - if self.padding_mode != 'zeros': - return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), - self.get_weight(), self.get_bias(), self.stride, - _single(0), self.dilation, self.groups) - return F.conv1d(input, self.get_weight(), self.get_bias(), self.stride, - self.padding, self.dilation, self.groups) - + if self.padding_mode != "zeros": + return F.conv1d( + F.pad( + input, + self._reversed_padding_repeated_twice, + mode=self.padding_mode, + ), + self.get_weight(), + self.get_bias(), + self.stride, + _single(0), # noqa: F821 + self.dilation, + self.groups, + ) + return F.conv1d( + input, + self.get_weight(), + self.get_bias(), + self.stride, + self.padding, + self.dilation, + self.groups, + ) class ScaledConv2d(nn.Conv2d): @@ -297,45 +329,58 @@ class ScaledConv2d(nn.Conv2d): if self.bias is not None: self.bias_scale = nn.Parameter(initial_scale.clone().detach()) else: - self.register_parameter('bias_scale', None) + self.register_parameter("bias_scale", None) self._reset_parameters() # Overrides the reset_parameters in base class def _reset_parameters(self): std = 0.01 - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() if self.bias is not None: self.bias_scale += torch.tensor(scale / std).log() - def get_weight(self): return self.weight * self.weight_scale.exp() def get_bias(self): - return (None if self.bias is None else - self.bias * self.bias_scale.exp()) + return None if self.bias is None else self.bias * self.bias_scale.exp() def _conv_forward(self, input, weight): F = torch.nn.functional - if self.padding_mode != 'zeros': - return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), - weight, self.get_bias(), self.stride, - _pair(0), self.dilation, self.groups) - return F.conv2d(input, weight, self.get_bias(), self.stride, - self.padding, self.dilation, self.groups) + if self.padding_mode != "zeros": + return F.conv2d( + F.pad( + input, + self._reversed_padding_repeated_twice, + mode=self.padding_mode, + ), + weight, + self.get_bias(), + self.stride, + _pair(0), # noqa: F821 + self.dilation, + self.groups, + ) + return F.conv2d( + input, + weight, + self.get_bias(), + self.stride, + self.padding, + self.dilation, + self.groups, + ) def forward(self, input: Tensor) -> Tensor: return self._conv_forward(input, self.get_weight()) - - class ActivationBalancer(torch.nn.Module): """ Modifies the backpropped derivatives of a function to try to encourage, for @@ -364,12 +409,16 @@ class ActivationBalancer(torch.nn.Module): we allow, before we start to modify the derivatives to prevent this. """ - def __init__(self, channel_dim: int, - min_positive: float = 0.05, - max_positive: float = 0.95, - max_factor: float = 0.01, - min_abs: float = 0.2, - max_abs: float = 100.0): + + def __init__( + self, + channel_dim: int, + min_positive: float = 0.05, + max_positive: float = 0.95, + max_factor: float = 0.01, + min_abs: float = 0.2, + max_abs: float = 100.0, + ): super(ActivationBalancer, self).__init__() self.channel_dim = channel_dim self.min_positive = min_positive @@ -379,10 +428,15 @@ class ActivationBalancer(torch.nn.Module): self.max_abs = max_abs def forward(self, x: Tensor) -> Tensor: - return ActivationBalancerFunction.apply(x, self.channel_dim, - self.min_positive, self.max_positive, - self.max_factor, self.min_abs, - self.max_abs) + return ActivationBalancerFunction.apply( + x, + self.channel_dim, + self.min_positive, + self.max_positive, + self.max_factor, + self.min_abs, + self.max_abs, + ) class DoubleSwishFunction(torch.autograd.Function): @@ -400,6 +454,7 @@ class DoubleSwishFunction(torch.autograd.Function): = double_swish(x) * (1-s(x)) + s(x) ... so we just need to remember s(x) but not x itself. """ + @staticmethod def forward(ctx, x: Tensor) -> Tensor: x = x.detach() @@ -411,18 +466,17 @@ class DoubleSwishFunction(torch.autograd.Function): @staticmethod def backward(ctx, y_grad: Tensor) -> Tensor: s, y = ctx.saved_tensors - return (y * (1-s) + s) * y_grad + return (y * (1 - s) + s) * y_grad + class DoubleSwish(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: """Return double-swish activation function which is an approximation to Swish(Swish(x)), - that we approximate closely with x * sigmoid(x-1). + that we approximate closely with x * sigmoid(x-1). """ return DoubleSwishFunction.apply(x) - - class ScaledEmbedding(nn.Module): r"""A simple lookup table that stores embeddings of a fixed dictionary and size. @@ -491,8 +545,13 @@ class ScaledEmbedding(nn.Module): [ 0.0000, 0.0000, 0.0000], [-0.1655, 0.9897, 0.0635]]]) """ - __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', - 'scale_grad_by_freq', 'sparse'] + __constants__ = [ + "num_embeddings", + "embedding_dim", + "padding_idx", + "scale_grad_by_freq", + "sparse", + ] num_embeddings: int embedding_dim: int @@ -501,33 +560,40 @@ class ScaledEmbedding(nn.Module): weight: Tensor sparse: bool - def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, - scale_grad_by_freq: bool = False, - sparse: bool = False) -> None: + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + scale_grad_by_freq: bool = False, + sparse: bool = False, + ) -> None: super(ScaledEmbedding, self).__init__() self.num_embeddings = num_embeddings self.embedding_dim = embedding_dim if padding_idx is not None: if padding_idx > 0: - assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings' + assert ( + padding_idx < self.num_embeddings + ), "Padding_idx must be within num_embeddings" elif padding_idx < 0: - assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings' + assert ( + padding_idx >= -self.num_embeddings + ), "Padding_idx must be within num_embeddings" padding_idx = self.num_embeddings + padding_idx self.padding_idx = padding_idx self.scale_grad_by_freq = scale_grad_by_freq - self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() + self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() self.sparse = sparse self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim)) self.reset_parameters() - - def reset_parameters(self) -> None: std = 0.01 nn.init.normal_(self.weight, std=std) - nn.init.constant_(self.scale, torch.tensor(1.0/std).log()) + nn.init.constant_(self.scale, torch.tensor(1.0 / std).log()) if self.padding_idx is not None: with torch.no_grad(): @@ -537,24 +603,37 @@ class ScaledEmbedding(nn.Module): F = torch.nn.functional scale = self.scale.exp() if input.numel() < self.num_embeddings: - return F.embedding( - input, self.weight, self.padding_idx, - None, 2.0, # None, 2.0 relate to normalization - self.scale_grad_by_freq, self.sparse) * scale + return ( + F.embedding( + input, + self.weight, + self.padding_idx, + None, + 2.0, # None, 2.0 relate to normalization + self.scale_grad_by_freq, + self.sparse, + ) + * scale + ) else: return F.embedding( - input, self.weight * scale, self.padding_idx, - None, 2.0, # None, 2.0 relates to normalization - self.scale_grad_by_freq, self.sparse) + input, + self.weight * scale, + self.padding_idx, + None, + 2.0, # None, 2.0 relates to normalization + self.scale_grad_by_freq, + self.sparse, + ) def extra_repr(self) -> str: - s = '{num_embeddings}, {embedding_dim}, scale={scale}' + s = "{num_embeddings}, {embedding_dim}, scale={scale}" if self.padding_idx is not None: - s += ', padding_idx={padding_idx}' + s += ", padding_idx={padding_idx}" if self.scale_grad_by_freq is not False: - s += ', scale_grad_by_freq={scale_grad_by_freq}' + s += ", scale_grad_by_freq={scale_grad_by_freq}" if self.sparse is not False: - s += ', sparse=True' + s += ", sparse=True" return s.format(**self.__dict__) @@ -565,8 +644,13 @@ def _test_activation_balancer_sign(): x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1)) x = x.detach() x.requires_grad = True - m = ActivationBalancer(channel_dim=0, min_positive=0.05, max_positive=0.95, - max_factor=0.2, min_abs=0.0) + m = ActivationBalancer( + channel_dim=0, + min_positive=0.05, + max_positive=0.95, + max_factor=0.2, + min_abs=0.0, + ) y_grad = torch.sign(torch.randn(probs.numel(), N)) @@ -576,17 +660,22 @@ def _test_activation_balancer_sign(): print("_test_activation_balancer_sign: y grad = ", y_grad) print("_test_activation_balancer_sign: x grad = ", x.grad) + def _test_activation_balancer_magnitude(): channel_dim = 0 magnitudes = torch.arange(0, 1, 0.01) N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) x = x.detach() x.requires_grad = True - m = ActivationBalancer(channel_dim=0, - min_positive=0.0, max_positive=1.0, - max_factor=0.2, - min_abs=0.2, max_abs=0.8) + m = ActivationBalancer( + channel_dim=0, + min_positive=0.0, + max_positive=1.0, + max_factor=0.2, + min_abs=0.2, + max_abs=0.8, + ) y_grad = torch.sign(torch.randn(magnitudes.numel(), N)) @@ -621,7 +710,7 @@ def _test_double_swish_deriv(): torch.autograd.gradcheck(m, x) -if __name__ == '__main__': +if __name__ == "__main__": _test_activation_balancer_sign() _test_activation_balancer_magnitude() _test_basic_norm() diff --git a/egs/librispeech/ASR/pruned2_knowledge/train.py b/egs/librispeech/ASR/pruned2_knowledge/train.py index 2f6840166..a60d15c3b 100755 --- a/egs/librispeech/ASR/pruned2_knowledge/train.py +++ b/egs/librispeech/ASR/pruned2_knowledge/train.py @@ -78,9 +78,7 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def get_parser(): @@ -179,42 +177,45 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." + ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -554,23 +555,16 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -733,9 +727,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -835,7 +827,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py index 2d5724d30..1df1650f3 100755 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py @@ -123,20 +123,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=False, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -204,8 +208,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -272,9 +275,7 @@ def decode_one_batch( value=LOG_EPS, ) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -289,10 +290,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -338,11 +336,7 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps + f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -415,9 +409,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -450,8 +442,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -494,9 +485,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -528,13 +517,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -557,13 +545,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -591,7 +578,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/emformer.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/emformer.py index 318cd5094..008f40fb1 100644 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/emformer.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/emformer.py @@ -272,13 +272,9 @@ class Emformer(EncoderInterface): # Caution: We assume the subsampling factor is 4! x_lens = (((x_lens - 1) >> 1) - 1) >> 1 - emformer_out, emformer_out_lens, states = self.model.infer( - x, x_lens, states - ) + emformer_out, emformer_out_lens, states = self.model.infer(x, x_lens, states) - if x.size(1) != ( - self.model.segment_length + self.model.right_context_length - ): + if x.size(1) != (self.model.segment_length + self.model.right_context_length): raise ValueError( "Incorrect input shape." f"{x.size(1)} vs {self.model.segment_length} + " diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py index 2375f5001..81afb523d 100755 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py @@ -89,20 +89,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=False, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -133,8 +137,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) add_model_arguments(parser) @@ -170,13 +173,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -199,13 +201,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -233,7 +234,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -273,9 +274,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/model.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/model.py index 2f019bcdb..ed6848879 100644 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/model.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/model.py @@ -122,9 +122,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py index fed814f19..6b30d3be8 100755 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py @@ -209,42 +209,45 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." + ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -566,11 +569,7 @@ def compute_loss( function enables autograd during computation; when it is False, it disables autograd. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -599,9 +598,7 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -782,9 +779,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -908,8 +903,7 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py index 7af9cc3d7..830b37cfb 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py @@ -509,9 +509,9 @@ def greedy_search( y = logits.argmax().item() if y not in (blank_id, unk_id): hyp.append(y) - decoder_input = torch.tensor( - [hyp[-context_size:]], device=device - ).reshape(1, context_size) + decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape( + 1, context_size + ) decoder_out = model.decoder(decoder_input, need_pad=False) @@ -670,9 +670,7 @@ class HypothesisList(object): if use_max: old_hyp.log_prob = max(old_hyp.log_prob, hyp.log_prob) else: - torch.logaddexp( - old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob - ) + torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob) else: self._data[key] = hyp @@ -688,9 +686,7 @@ class HypothesisList(object): Return the hypothesis that has the largest `log_prob`. """ if length_norm: - return max( - self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys) - ) + return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)) else: return max(self._data.values(), key=lambda hyp: hyp.log_prob) @@ -892,9 +888,7 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -1088,9 +1082,7 @@ def beam_search( t = 0 B = HypothesisList() - B.add( - Hypothesis(ys=[blank_id] * context_size, log_prob=0.0), use_max=use_max - ) + B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0), use_max=use_max) max_sym_per_utt = 20000 @@ -1130,9 +1122,7 @@ def beam_search( cached_key += f"-t-{t}" if cached_key not in joint_cache: - logits = model.joiner( - current_encoder_out, decoder_out.unsqueeze(1) - ) + logits = model.joiner(current_encoder_out, decoder_out.unsqueeze(1)) # TODO(fangjun): Scale the blank posterior diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py index 7b6338948..03ad45f49 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py @@ -128,11 +128,7 @@ from beam_search import ( ) from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, @@ -171,9 +167,11 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( @@ -269,8 +267,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -383,9 +380,7 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if ( @@ -450,10 +445,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -584,9 +576,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -619,8 +609,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -678,9 +667,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -718,8 +705,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -757,9 +743,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py index 386248554..e522943c0 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py @@ -75,9 +75,7 @@ class DecodeStream(object): # encoder.streaming_forward self.done_frames: int = 0 - self.pad_length = ( - params.right_context + 2 - ) * params.subsampling_factor + 3 + self.pad_length = (params.right_context + 2) * params.subsampling_factor + 3 if params.decoding_method == "greedy_search": self.hyp = [params.blank_id] * params.context_size @@ -91,13 +89,11 @@ class DecodeStream(object): ) elif params.decoding_method == "fast_beam_search": # The rnnt_decoding_stream for fast_beam_search. - self.rnnt_decoding_stream: k2.RnntDecodingStream = ( - k2.RnntDecodingStream(decoding_graph) + self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream( + decoding_graph ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") @property def done(self) -> bool: @@ -126,13 +122,10 @@ class DecodeStream(object): """Consume chunk_size frames of features""" chunk_length = chunk_size + self.pad_length - ret_length = min( - self.num_frames - self.num_processed_frames, chunk_length - ) + ret_length = min(self.num_frames - self.num_processed_frames, chunk_length) ret_features = self.features[ - self.num_processed_frames : self.num_processed_frames # noqa - + ret_length + self.num_processed_frames : self.num_processed_frames + ret_length # noqa ] self.num_processed_frames += chunk_size diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py index f4355e8a0..72593173c 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py @@ -92,9 +92,7 @@ class Decoder(nn.Module): if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad( - embedding_out, pad=(self.context_size - 1, 0) - ) + embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/export.py b/egs/librispeech/ASR/pruned_transducer_stateless/export.py index b5a151878..64708e524 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/export.py @@ -64,17 +64,20 @@ def get_parser(): "--epoch", type=int, default=28, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -105,8 +108,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -192,9 +194,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/model.py b/egs/librispeech/ASR/pruned_transducer_stateless/model.py index 73b651b3f..2cca7fa27 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/model.py @@ -130,9 +130,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py index eb95827af..a42b63b9c 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py @@ -91,9 +91,11 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( @@ -118,10 +120,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -168,8 +172,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -221,10 +224,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -292,9 +294,7 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -381,9 +381,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_beam_search.py index dcf6dc42f..9e09200a1 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_beam_search.py @@ -166,14 +166,10 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk( - num_active_paths - ) + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(num_active_paths) with warnings.catch_warnings(): warnings.simplefilter("ignore") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py index d2cae4f9f..a50b4d4f0 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py @@ -51,11 +51,7 @@ from streaming_beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import ( AttributeDict, setup_logger, @@ -94,9 +90,11 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( @@ -162,8 +160,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -269,9 +266,7 @@ def decode_one_chunk( ) if params.decoding_method == "greedy_search": - greedy_search( - model=model, encoder_out=encoder_out, streams=decode_streams - ) + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( @@ -291,9 +286,7 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -349,9 +342,7 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. decode_streams = [] - initial_states = model.encoder.get_init_state( - params.left_context, device=device - ) + initial_states = model.encoder.get_init_state(params.left_context, device=device) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. decode_stream = DecodeStream( @@ -422,9 +413,7 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") return {key: decode_results} @@ -460,8 +449,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -533,8 +521,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/train.py b/egs/librispeech/ASR/pruned_transducer_stateless/train.py index 399b11a29..dd0331a60 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/train.py @@ -203,42 +203,45 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." + ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -562,9 +565,7 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -584,9 +585,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -777,9 +776,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -897,8 +894,7 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" ) return False @@ -956,9 +952,7 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index b7c2010f7..5e9428b60 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -580,9 +580,9 @@ def greedy_search( if y not in (blank_id, unk_id): hyp.append(y) timestamp.append(t) - decoder_input = torch.tensor( - [hyp[-context_size:]], device=device - ).reshape(1, context_size) + decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape( + 1, context_size + ) decoder_out = model.decoder(decoder_input, need_pad=False) decoder_out = model.joiner.decoder_proj(decoder_out) @@ -775,9 +775,7 @@ class HypothesisList(object): key = hyp.key if key in self: old_hyp = self._data[key] # shallow copy - torch.logaddexp( - old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob - ) + torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob) else: self._data[key] = hyp @@ -793,9 +791,7 @@ class HypothesisList(object): Return the hypothesis that has the largest `log_prob`. """ if length_norm: - return max( - self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys) - ) + return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)) else: return max(self._data.values(), key=lambda hyp: hyp.log_prob) @@ -990,9 +986,7 @@ def modified_beam_search( logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - log_probs = (logits / temperature).log_softmax( - dim=-1 - ) # (num_hyps, vocab_size) + log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) log_probs.add_(ys_log_probs) @@ -1004,9 +998,7 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -1676,9 +1668,7 @@ def fast_beam_search_with_nbest_rnn_rescoring( for rnn_scale in rnn_lm_scale_list: key = f"ngram_lm_scale_{n_scale}_rnn_lm_scale_{rnn_scale}" tot_scores = ( - am_scores.values - + n_scale * ngram_lm_scores - + rnn_scale * rnn_lm_scores + am_scores.values + n_scale * ngram_lm_scores + rnn_scale * rnn_lm_scores ) ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) max_indexes = ragged_tot_scores.argmax() @@ -1804,9 +1794,7 @@ def modified_beam_search_ngram_rescoring( logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - log_probs = (logits / temperature).log_softmax( - dim=-1 - ) # (num_hyps, vocab_size) + log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) log_probs.add_(ys_log_probs) vocab_size = log_probs.size(-1) @@ -1816,9 +1804,7 @@ def modified_beam_search_ngram_rescoring( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -1841,9 +1827,7 @@ def modified_beam_search_ngram_rescoring( state_cost = hyp.state_cost # We only keep AM scores in new_hyp.log_prob - new_log_prob = ( - topk_log_probs[k] - hyp.state_cost.lm_score * lm_scale - ) + new_log_prob = topk_log_probs[k] - hyp.state_cost.lm_score * lm_scale new_hyp = Hypothesis( ys=new_ys, log_prob=new_log_prob, state_cost=state_cost @@ -1995,9 +1979,7 @@ def modified_beam_search_rnnlm_shallow_fusion( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) """ for all hyps with a non-blank new token, score this token. It is a little confusing here because this for-loop @@ -2032,10 +2014,7 @@ def modified_beam_search_rnnlm_shallow_fusion( # forward RNNLM to get new states and scores if len(token_list) != 0: tokens_to_score = ( - torch.tensor(token_list) - .to(torch.int64) - .to(device) - .reshape(-1, 1) + torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1) ) hs = torch.cat(hs, dim=1).to(device) @@ -2067,9 +2046,7 @@ def modified_beam_search_rnnlm_shallow_fusion( ys.append(new_token) new_timestamp.append(t) - hyp_log_prob += ( - lm_score[new_token] * lm_scale - ) # add the lm score + hyp_log_prob += lm_score[new_token] * lm_scale # add the lm score lm_score = scores[count] state = ( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index bc273d33b..34ff0d7e2 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -214,10 +214,7 @@ class Conformer(EncoderInterface): NOTE: the returned tensors are on the given device. """ - if ( - len(self._init_state) == 2 - and self._init_state[0].size(1) == left_context - ): + if len(self._init_state) == 2 and self._init_state[0].size(1) == left_context: # Note: It is OK to share the init state as it is # not going to be modified by the model return self._init_state @@ -439,9 +436,7 @@ class ConformerEncoderLayer(nn.Module): self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), @@ -459,9 +454,7 @@ class ConformerEncoderLayer(nn.Module): ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), ) - self.conv_module = ConvolutionModule( - d_model, cnn_module_kernel, causal=causal - ) + self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal) self.norm_final = BasicNorm(d_model) @@ -527,9 +520,7 @@ class ConformerEncoderLayer(nn.Module): src = src + self.dropout(src_att) # convolution module - conv, _ = self.conv_module( - src, src_key_padding_mask=src_key_padding_mask - ) + conv, _ = self.conv_module(src, src_key_padding_mask=src_key_padding_mask) src = src + self.dropout(conv) # feed forward module @@ -785,9 +776,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() if is_jit_tracing(): @@ -811,9 +800,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x_size_1 * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -1127,9 +1114,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -1198,33 +1185,25 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor" + " instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -1264,23 +1243,15 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d - matrix_bd = torch.matmul( - q_with_bias_v, p - ) # (batch, head, time1, 2*time1-1) + matrix_bd = torch.matmul(q_with_bias_v, p) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd, left_context) - attn_output_weights = ( - matrix_ac + matrix_bd - ) # (batch, head, time1, time2) + attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) if not is_jit_tracing(): assert list(attn_output_weights.size()) == [ @@ -1322,21 +1293,17 @@ class RelPositionMultiheadAttention(nn.Module): ): if attn_mask.size(0) != 1: attn_mask = attn_mask.view(bsz, num_heads, tgt_len, src_len) - combined_mask = attn_mask | key_padding_mask.unsqueeze( - 1 - ).unsqueeze(2) + combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2) else: # attn_mask.shape == (1, tgt_len, src_len) - combined_mask = attn_mask.unsqueeze( - 0 - ) | key_padding_mask.unsqueeze(1).unsqueeze(2) + combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze( + 1 + ).unsqueeze(2) attn_output_weights = attn_output_weights.view( bsz, num_heads, tgt_len, src_len ) - attn_output_weights = attn_output_weights.masked_fill( - combined_mask, 0.0 - ) + attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0) attn_output_weights = attn_output_weights.view( bsz * num_heads, tgt_len, src_len ) @@ -1355,13 +1322,9 @@ class RelPositionMultiheadAttention(nn.Module): ] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -1498,16 +1461,12 @@ class ConvolutionModule(nn.Module): # manualy padding self.lorder zeros to the left x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) else: - assert ( - not self.training - ), "Cache should be None in training time" + assert not self.training, "Cache should be None in training time" assert cache.size(0) == self.lorder x = torch.cat([cache.permute(1, 2, 0), x], dim=2) if right_context > 0: cache = x.permute(2, 0, 1)[ - -(self.lorder + right_context) : ( # noqa - -right_context - ), + -(self.lorder + right_context) : (-right_context), # noqa ..., ] else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py index 979a0e02e..32cd53be3 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py @@ -132,11 +132,7 @@ from beam_search import ( ) from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, @@ -177,9 +173,11 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( @@ -275,8 +273,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -397,9 +394,7 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -465,10 +460,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -514,11 +506,7 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps + f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps } elif "fast_beam_search" in params.decoding_method: key = f"beam_{params.beam}_" @@ -608,9 +596,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -643,8 +629,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -700,9 +685,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -740,8 +723,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -779,9 +761,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py index ba91302ce..b59928103 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py @@ -107,15 +107,11 @@ class Decoder(nn.Module): # This is for exporting to PNNX via ONNX embedding_out = self.embedding(y) else: - embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze( - -1 - ) + embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1) if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad: - embedding_out = F.pad( - embedding_out, pad=(self.context_size - 1, 0) - ) + embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py index f1a8ea589..90367bd03 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py @@ -51,11 +51,7 @@ import sentencepiece as spm import torch from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import str2bool @@ -87,9 +83,11 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( @@ -120,8 +118,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -173,8 +170,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -222,9 +218,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py index 6a9d08033..1954f4724 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py @@ -60,9 +60,7 @@ class Joiner(nn.Module): assert encoder_out.shape == decoder_out.shape if project_input: - logit = self.encoder_proj(encoder_out) + self.decoder_proj( - decoder_out - ) + logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) else: logit = encoder_out + decoder_out diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py index 417c391d9..272d06c37 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py @@ -66,9 +66,7 @@ class Transducer(nn.Module): self.decoder = decoder self.joiner = joiner - self.simple_am_proj = ScaledLinear( - encoder_dim, vocab_size, initial_speed=0.5 - ) + self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) def forward( @@ -152,9 +150,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py index 041a81f45..2d7f557ad 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py @@ -72,17 +72,11 @@ class Eve(Optimizer): if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if not 0.0 <= betas[0] < 1.0: - raise ValueError( - "Invalid beta parameter at index 0: {}".format(betas[0]) - ) + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) if not 0.0 <= betas[1] < 1.0: - raise ValueError( - "Invalid beta parameter at index 1: {}".format(betas[1]) - ) + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) if not 0 <= weight_decay <= 0.1: - raise ValueError( - "Invalid weight_decay value: {}".format(weight_decay) - ) + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) if not 0 < target_rms <= 10.0: raise ValueError("Invalid target_rms value: {}".format(target_rms)) defaults = dict( @@ -118,9 +112,7 @@ class Eve(Optimizer): # Perform optimization step grad = p.grad if grad.is_sparse: - raise RuntimeError( - "AdamW does not support sparse gradients" - ) + raise RuntimeError("AdamW does not support sparse gradients") state = self.state[p] @@ -147,7 +139,7 @@ class Eve(Optimizer): # Decay the first and second moment running average coefficient exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_( + denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_( group["eps"] ) @@ -158,9 +150,7 @@ class Eve(Optimizer): if p.numel() > 1: # avoid applying this weight-decay on "scaling factors" # (which are scalar). - is_above_target_rms = p.norm() > ( - target_rms * (p.numel() ** 0.5) - ) + is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5)) p.mul_(1 - (weight_decay * is_above_target_rms)) p.addcdiv_(exp_avg, denom, value=-step_size) @@ -180,18 +170,14 @@ class LRScheduler(object): def __init__(self, optimizer: Optimizer, verbose: bool = False): # Attach optimizer if not isinstance(optimizer, Optimizer): - raise TypeError( - "{} is not an Optimizer".format(type(optimizer).__name__) - ) + raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) self.optimizer = optimizer self.verbose = verbose for group in optimizer.param_groups: group.setdefault("initial_lr", group["lr"]) - self.base_lrs = [ - group["initial_lr"] for group in optimizer.param_groups - ] + self.base_lrs = [group["initial_lr"] for group in optimizer.param_groups] self.epoch = 0 self.batch = 0 @@ -299,10 +285,9 @@ class Eden(LRScheduler): def get_lr(self): factor = ( - (self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2 + (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 ) ** -0.25 * ( - ((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2) - ** -0.25 + ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25 ) return [x * factor for x in self.base_lrs] diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py index f52cb22ab..58de6875f 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py @@ -91,9 +91,11 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( @@ -118,10 +120,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -168,8 +172,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -222,10 +225,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -293,9 +295,7 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -382,9 +382,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index 8c572a9ef..f671e97b1 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -89,9 +89,7 @@ class ActivationBalancerFunction(torch.autograd.Function): below_threshold = mean_abs < min_abs above_threshold = mean_abs > max_abs - ctx.save_for_backward( - factor, xgt0, below_threshold, above_threshold - ) + ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold) ctx.max_factor = max_factor ctx.sum_dims = sum_dims return x @@ -137,7 +135,7 @@ class GradientFilterFunction(torch.autograd.Function): eps = 1.0e-20 dim = ctx.batch_dim norm_dims = [d for d in range(x_grad.ndim) if d != dim] - norm_of_batch = (x_grad ** 2).mean(dim=norm_dims, keepdim=True).sqrt() + norm_of_batch = (x_grad**2).mean(dim=norm_dims, keepdim=True).sqrt() median_norm = norm_of_batch.median() cutoff = median_norm * ctx.threshold @@ -229,8 +227,7 @@ class BasicNorm(torch.nn.Module): if not is_jit_tracing(): assert x.shape[self.channel_dim] == self.num_channels scales = ( - torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) - + self.eps.exp() + torch.mean(x**2, dim=self.channel_dim, keepdim=True) + self.eps.exp() ) ** -0.5 return x * scales @@ -282,12 +279,12 @@ class ScaledLinear(nn.Linear): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -301,9 +298,7 @@ class ScaledLinear(nn.Linear): return self.bias * self.bias_scale.exp() def forward(self, input: Tensor) -> Tensor: - return torch.nn.functional.linear( - input, self.get_weight(), self.get_bias() - ) + return torch.nn.functional.linear(input, self.get_weight(), self.get_bias()) class ScaledConv1d(nn.Conv1d): @@ -331,12 +326,12 @@ class ScaledConv1d(nn.Conv1d): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -400,12 +395,12 @@ class ScaledConv2d(nn.Conv2d): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -476,9 +471,7 @@ class ScaledLSTM(nn.LSTM): setattr(self, scale_name, param) self._scales.append(param) - self.grad_filter = GradientFilter( - batch_dim=1, threshold=grad_norm_threshold - ) + self.grad_filter = GradientFilter(batch_dim=1, threshold=grad_norm_threshold) self._reset_parameters( initial_speed @@ -486,8 +479,8 @@ class ScaledLSTM(nn.LSTM): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3 ** 0.5) * std - scale = self.hidden_size ** -0.5 + a = (3**0.5) * std + scale = self.hidden_size**-0.5 v = scale / std for idx, name in enumerate(self._flat_weights_names): if "weight" in name: @@ -559,15 +552,11 @@ class ScaledLSTM(nn.LSTM): """Get scaled weights, and resets their data pointer.""" flat_weights = [] for idx in range(len(self._flat_weights_names)): - flat_weights.append( - self._flat_weights[idx] * self._scales[idx].exp() - ) + flat_weights.append(self._flat_weights[idx] * self._scales[idx].exp()) self._flatten_parameters(flat_weights) return flat_weights - def forward( - self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None - ): + def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None): # This function is modified from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py # noqa # The change for calling `_VF.lstm()` is: # self._flat_weights -> self._get_flat_weights() @@ -915,9 +904,7 @@ def _test_activation_balancer_sign(): def _test_activation_balancer_magnitude(): magnitudes = torch.arange(0, 1, 0.01) N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze( - -1 - ) + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) x = x.detach() x.requires_grad = True m = ActivationBalancer( @@ -947,8 +934,8 @@ def _test_basic_norm(): y = m(x) assert y.shape == x.shape - x_rms = (x ** 2).mean().sqrt() - y_rms = (y ** 2).mean().sqrt() + x_rms = (x**2).mean().sqrt() + y_rms = (y**2).mean().sqrt() print("x rms = ", x_rms) print("y rms = ", y_rms) assert y_rms < x_rms @@ -1001,17 +988,18 @@ def _test_grad_filter(): ) print( - "_test_grad_filter: for gradient norms, the first element > median * threshold ", # noqa + "_test_grad_filter: for gradient norms, the first element > median *" + " threshold ", # noqa i % 2 == 1, ) print( "_test_grad_filter: x_out_grad norm = ", - (x_out_grad ** 2).mean(dim=(0, 2)).sqrt(), + (x_out_grad**2).mean(dim=(0, 2)).sqrt(), ) print( "_test_grad_filter: x.grad norm = ", - (x.grad ** 2).mean(dim=(0, 2)).sqrt(), + (x.grad**2).mean(dim=(0, 2)).sqrt(), ) print("_test_grad_filter: w_out_grad = ", w_out_grad) print("_test_grad_filter: w.grad = ", w.grad) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py index 9bcd2f9f9..e6e0fb1c8 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py @@ -153,9 +153,7 @@ def modified_beam_search( index=hyps_shape.row_ids(1).to(torch.int64), ) # (num_hyps, encoder_out_dim) - logits = model.joiner( - current_encoder_out, decoder_out, project_input=False - ) + logits = model.joiner(current_encoder_out, decoder_out, project_input=False) # logits is of shape (num_hyps, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) @@ -172,14 +170,10 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk( - num_active_paths - ) + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(num_active_paths) with warnings.catch_warnings(): warnings.simplefilter("ignore") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py index d76a03946..0139863a1 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py @@ -51,11 +51,7 @@ from streaming_beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import ( AttributeDict, setup_logger, @@ -94,9 +90,11 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( @@ -162,8 +160,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -271,9 +268,7 @@ def decode_one_chunk( encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": - greedy_search( - model=model, encoder_out=encoder_out, streams=decode_streams - ) + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( @@ -293,9 +288,7 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -351,9 +344,7 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. decode_streams = [] - initial_states = model.encoder.get_init_state( - params.left_context, device=device - ) + initial_states = model.encoder.get_init_state(params.left_context, device=device) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. decode_stream = DecodeStream( @@ -425,9 +416,7 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") return {key: decode_results} @@ -462,8 +451,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -536,8 +524,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 1947834bf..623bdd51a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -96,9 +96,7 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -210,8 +208,7 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need to " - "be changed.", + help="The initial learning rate. This value should not need to be changed.", ) parser.add_argument( @@ -234,42 +231,45 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." + ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -634,9 +634,7 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -649,14 +647,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -667,9 +660,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -837,9 +828,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -963,8 +952,7 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py b/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py index 1df7f9ee5..5e81aef07 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py @@ -27,10 +27,7 @@ from lhotse.dataset import ( K2SpeechRecognitionDataset, SpecAugment, ) -from lhotse.dataset.input_strategies import ( - OnTheFlyFeatures, - PrecomputedFeatures, -) +from lhotse.dataset.input_strategies import OnTheFlyFeatures, PrecomputedFeatures from torch.utils.data import DataLoader from icefall.utils import str2bool @@ -44,59 +41,69 @@ class AsrDataModule: def add_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group( title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", + description=( + "These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc." + ), ) group.add_argument( "--max-duration", type=int, default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", + help=( + "Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM." + ), ) group.add_argument( "--bucketing-sampler", type=str2bool, default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", + help=( + "When enabled, the batches will come from buckets of " + "similar duration (saves padding frames)." + ), ) group.add_argument( "--num-buckets", type=int, default=30, - help="The number of buckets for the DynamicBucketingSampler. " - "(you might want to increase it for larger datasets).", + help=( + "The number of buckets for the DynamicBucketingSampler. " + "(you might want to increase it for larger datasets)." + ), ) group.add_argument( "--shuffle", type=str2bool, default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", + help=( + "When enabled (=default), the examples will be shuffled for each epoch." + ), ) group.add_argument( "--return-cuts", type=str2bool, default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", + help=( + "When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it." + ), ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that " - "collect the batches.", + help="The number of training dataloader workers that collect the batches.", ) group.add_argument( @@ -117,18 +124,22 @@ class AsrDataModule: "--spec-aug-time-warp-factor", type=int, default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", + help=( + "Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp." + ), ) group.add_argument( "--enable-musan", type=str2bool, default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. ", + help=( + "When enabled, select noise from MUSAN and mix it" + "with training dataset. " + ), ) group.add_argument( @@ -142,9 +153,11 @@ class AsrDataModule: "--on-the-fly-feats", type=str2bool, default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available. Used only in dev/test CutSet", + help=( + "When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available. Used only in dev/test CutSet" + ), ) def train_dataloaders( @@ -167,9 +180,7 @@ class AsrDataModule: if cuts_musan is not None: logging.info("Enable MUSAN") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") @@ -178,9 +189,7 @@ class AsrDataModule: if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") input_transforms.append( SpecAugment( time_warp_factor=self.args.spec_aug_time_warp_factor, @@ -250,9 +259,7 @@ class AsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py index 5784a78ba..66c8e30ba 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py @@ -79,11 +79,7 @@ from gigaspeech import GigaSpeech from gigaspeech_scoring import asr_text_post_processing from train import get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import ( AttributeDict, setup_logger, @@ -120,9 +116,11 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( @@ -192,8 +190,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -280,9 +277,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -312,10 +307,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -359,21 +351,11 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps + f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps } elif params.decoding_method == "fast_beam_search_nbest_oracle": return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}_" - f"num_paths_{params.num_paths}_" - f"nbest_scale_{params.nbest_scale}" - ): hyps + f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}_num_paths_{params.num_paths}_nbest_scale_{params.nbest_scale}": hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -446,9 +428,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -481,8 +461,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -532,9 +511,7 @@ def main(): params.suffix += f"-num-paths-{params.num_paths}" params.suffix += f"-nbest-scale-{params.nbest_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -567,8 +544,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py index 8025d6be1..d90497e26 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py @@ -120,11 +120,7 @@ from beam_search import ( from librispeech import LibriSpeech from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.lexicon import Lexicon from icefall.rnn_lm.model import RnnLmModel from icefall.utils import ( @@ -167,9 +163,11 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( @@ -265,8 +263,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -478,9 +475,7 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -550,10 +545,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -646,21 +638,11 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - f"temperature_{params.temperature}" - ): hyps + f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}temperature_{params.temperature}": hyps } elif params.decoding_method == "fast_beam_search": return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - f"temperature_{params.temperature}" - ): hyps + f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}temperature_{params.temperature}": hyps } elif params.decoding_method in [ "fast_beam_search_with_nbest_rescoring", @@ -690,12 +672,7 @@ def decode_one_batch( key += f"_ngram_lm_scale_{params.ngram_lm_scale}" return {key: hyps} else: - return { - ( - f"beam_size_{params.beam_size}_" - f"temperature_{params.temperature}" - ): hyps - } + return {f"beam_size_{params.beam_size}_temperature_{params.temperature}": hyps} def decode_dataset( @@ -779,9 +756,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -814,8 +789,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -939,9 +913,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" params.suffix += f"-temperature-{params.temperature}" else: params.suffix += f"-context-{params.context_size}" @@ -981,8 +953,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -1032,15 +1003,10 @@ def main(): word_table=word_table, device=device, ) - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) logging.info(f"G properties_str: {G.properties_str}") rnn_lm_model = None - if ( - params.decoding_method - == "fast_beam_search_with_nbest_rnn_rescoring" - ): + if params.decoding_method == "fast_beam_search_with_nbest_rnn_rescoring": rnn_lm_model = RnnLmModel( vocab_size=params.vocab_size, embedding_dim=params.rnn_lm_embedding_dim, @@ -1065,9 +1031,7 @@ def main(): rnn_lm_model.eval() else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) rnn_lm_model = None else: decoding_graph = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py index 47217ba05..dcf65e937 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py @@ -128,11 +128,7 @@ import torch.nn as nn from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import str2bool @@ -164,9 +160,11 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( @@ -235,8 +233,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -509,13 +506,9 @@ def export_joiner_model_onnx( - projected_decoder_out: a tensor of shape (N, joiner_dim) """ - encoder_proj_filename = str(joiner_filename).replace( - ".onnx", "_encoder_proj.onnx" - ) + encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx") - decoder_proj_filename = str(joiner_filename).replace( - ".onnx", "_decoder_proj.onnx" - ) + decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx") encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] @@ -616,8 +609,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -715,9 +707,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/gigaspeech.py b/egs/librispeech/ASR/pruned_transducer_stateless3/gigaspeech.py index 36f32c6b3..598434f54 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/gigaspeech.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/gigaspeech.py @@ -52,18 +52,14 @@ class GigaSpeech: ) pattern = re.compile(r"gigaspeech_cuts_XL.([0-9]+).jsonl.gz") - idx_filenames = [ - (int(pattern.search(f).group(1)), f) for f in filenames - ] + idx_filenames = [(int(pattern.search(f).group(1)), f) for f in filenames] idx_filenames = sorted(idx_filenames, key=lambda x: x[0]) sorted_filenames = [f[1] for f in idx_filenames] logging.info(f"Loading {len(sorted_filenames)} splits") - return lhotse.combine( - lhotse.load_manifest_lazy(p) for p in sorted_filenames - ) + return lhotse.combine(lhotse.load_manifest_lazy(p) for p in sorted_filenames) def train_L_cuts(self) -> CutSet: f = self.manifest_dir / "gigaspeech_cuts_L_raw.jsonl.gz" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py index 162f8c7db..108915389 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py @@ -104,10 +104,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -142,10 +144,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -330,9 +331,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/model.py b/egs/librispeech/ASR/pruned_transducer_stateless3/model.py index 7852f84e9..d45f6dadc 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/model.py @@ -84,9 +84,7 @@ class Transducer(nn.Module): self.decoder_giga = decoder_giga self.joiner_giga = joiner_giga - self.simple_am_proj = ScaledLinear( - encoder_dim, vocab_size, initial_speed=0.5 - ) + self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) if decoder_giga is not None: @@ -190,9 +188,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = encoder_out_lens diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py index d03d1d7ef..163d737e3 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py @@ -203,9 +203,7 @@ def test_joiner( ) # Now test encoder_proj - joiner_encoder_proj_inputs = { - encoder_proj_input_name: encoder_out.numpy() - } + joiner_encoder_proj_inputs = {encoder_proj_input_name: encoder_out.numpy()} joiner_encoder_proj_out = joiner_encoder_proj_session.run( [encoder_proj_output_name], joiner_encoder_proj_inputs )[0] @@ -214,16 +212,10 @@ def test_joiner( torch_joiner_encoder_proj_out = model.joiner.encoder_proj(encoder_out) assert torch.allclose( joiner_encoder_proj_out, torch_joiner_encoder_proj_out, atol=1e-5 - ), ( - (joiner_encoder_proj_out - torch_joiner_encoder_proj_out) - .abs() - .max() - ) + ), ((joiner_encoder_proj_out - torch_joiner_encoder_proj_out).abs().max()) # Now test decoder_proj - joiner_decoder_proj_inputs = { - decoder_proj_input_name: decoder_out.numpy() - } + joiner_decoder_proj_inputs = {decoder_proj_input_name: decoder_out.numpy()} joiner_decoder_proj_out = joiner_decoder_proj_session.run( [decoder_proj_output_name], joiner_decoder_proj_inputs )[0] @@ -232,11 +224,7 @@ def test_joiner( torch_joiner_decoder_proj_out = model.joiner.decoder_proj(decoder_out) assert torch.allclose( joiner_decoder_proj_out, torch_joiner_decoder_proj_out, atol=1e-5 - ), ( - (joiner_decoder_proj_out - torch_joiner_decoder_proj_out) - .abs() - .max() - ) + ), ((joiner_decoder_proj_out - torch_joiner_decoder_proj_out).abs().max()) @torch.no_grad() @@ -288,9 +276,7 @@ def main(): if __name__ == "__main__": torch.manual_seed(20220727) - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py index ea5d4e674..11597aa49 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py @@ -102,10 +102,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -140,10 +142,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -191,11 +192,7 @@ def greedy_search( projected_encoder_out = joiner_encoder_proj.run( [joiner_encoder_proj.get_outputs()[0].name], - { - joiner_encoder_proj.get_inputs()[ - 0 - ].name: packed_encoder_out.data.numpy() - }, + {joiner_encoder_proj.get_inputs()[0].name: packed_encoder_out.data.numpy()}, )[0] blank_id = 0 # hard-code to 0 @@ -382,9 +379,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py index 19b636a23..849d6cf4e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py @@ -100,9 +100,11 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( @@ -127,10 +129,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -177,8 +181,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -231,10 +234,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -302,9 +304,7 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -391,9 +391,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py index 1e6022b57..85d87f8f2 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py @@ -234,9 +234,7 @@ def scaled_lstm_to_lstm(scaled_lstm: ScaledLSTM) -> nn.LSTM: assert lstm._flat_weights_names == scaled_lstm._flat_weights_names for idx in range(len(scaled_lstm._flat_weights_names)): - scaled_weight = ( - scaled_lstm._flat_weights[idx] * scaled_lstm._scales[idx].exp() - ) + scaled_weight = scaled_lstm._flat_weights[idx] * scaled_lstm._scales[idx].exp() lstm._flat_weights[idx].data.copy_(scaled_weight) return lstm @@ -251,12 +249,10 @@ def get_submodule(model, target): mod: torch.nn.Module = model for item in atoms: if not hasattr(mod, item): - raise AttributeError( - mod._get_name() + " has no " "attribute `" + item + "`" - ) + raise AttributeError(mod._get_name() + " has no attribute `" + item + "`") mod = getattr(mod, item) if not isinstance(mod, torch.nn.Module): - raise AttributeError("`" + item + "` is not " "an nn.Module") + raise AttributeError("`" + item + "` is not an nn.Module") return mod diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py index 10bb44e00..41a712498 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py @@ -52,11 +52,7 @@ from streaming_beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import ( AttributeDict, setup_logger, @@ -95,9 +91,11 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( @@ -163,8 +161,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -272,9 +269,7 @@ def decode_one_chunk( encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": - greedy_search( - model=model, encoder_out=encoder_out, streams=decode_streams - ) + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( @@ -294,9 +289,7 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -352,9 +345,7 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. decode_streams = [] - initial_states = model.encoder.get_init_state( - params.left_context, device=device - ) + initial_states = model.encoder.get_init_state(params.left_context, device=device) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. decode_stream = DecodeStream( @@ -426,9 +417,7 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") return {key: decode_results} @@ -461,8 +450,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -535,8 +523,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py index 66ffbd3ec..598fcf344 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py @@ -90,9 +90,7 @@ def test_conv2d_subsampling(): onnx_y = torch.from_numpy(onnx_y) torch_y = jit_model(x) - assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( - (onnx_y - torch_y).abs().max() - ) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() os.remove(filename) @@ -147,9 +145,7 @@ def test_rel_pos(): onnx_pos_emb = torch.from_numpy(onnx_pos_emb) torch_y, torch_pos_emb = jit_model(x) - assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( - (onnx_y - torch_y).abs().max() - ) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() assert torch.allclose(onnx_pos_emb, torch_pos_emb, atol=1e-05), ( (onnx_pos_emb - torch_pos_emb).abs().max() @@ -197,9 +193,7 @@ def test_conformer_encoder_layer(): encoder_layer.eval() encoder_layer = convert_scaled_to_non_scaled(encoder_layer, inplace=True) - jit_model = torch.jit.trace( - encoder_layer, (x, pos_emb, src_key_padding_mask) - ) + jit_model = torch.jit.trace(encoder_layer, (x, pos_emb, src_key_padding_mask)) torch.onnx.export( encoder_layer, @@ -236,9 +230,7 @@ def test_conformer_encoder_layer(): onnx_y = torch.from_numpy(onnx_y) torch_y = jit_model(x, pos_emb, src_key_padding_mask) - assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( - (onnx_y - torch_y).abs().max() - ) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() print(onnx_y.abs().sum(), torch_y.abs().sum(), onnx_y.shape, torch_y.shape) @@ -322,9 +314,7 @@ def test_conformer_encoder(): onnx_y = torch.from_numpy(onnx_y) torch_y = jit_model(x, pos_emb, src_key_padding_mask) - assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( - (onnx_y - torch_y).abs().max() - ) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() print(onnx_y.abs().sum(), torch_y.abs().sum(), onnx_y.shape, torch_y.shape) @@ -379,9 +369,7 @@ def test_conformer(): onnx_y_lens = torch.from_numpy(onnx_y_lens) torch_y, torch_y_lens = jit_model(x, x_lens) - assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( - (onnx_y - torch_y).abs().max() - ) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() assert torch.allclose(onnx_y_lens, torch_y_lens, atol=1e-05), ( (onnx_y_lens - torch_y_lens).abs().max() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py index 44e96644a..6724343dd 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py @@ -92,9 +92,7 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -163,8 +161,7 @@ def get_parser(): "--full-libri", type=str2bool, default=True, - help="When enabled, use 960h LibriSpeech. " - "Otherwise, use 100h subset.", + help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.", ) parser.add_argument( @@ -214,8 +211,7 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need " - "to be changed.", + help="The initial learning rate. This value should not need to be changed.", ) parser.add_argument( @@ -238,42 +234,45 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." + ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -672,9 +671,7 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -687,14 +684,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -705,9 +697,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -919,9 +909,7 @@ def train_one_epoch( f"train/current_{prefix}_", params.batch_idx_train, ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) libri_tot_loss.write_summary( tb_writer, "train/libri_tot_", params.batch_idx_train ) @@ -967,8 +955,7 @@ def filter_short_and_long_utterances( # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" ) return False @@ -1109,9 +1096,7 @@ def run(rank, world_size, args): train_giga_cuts = train_giga_cuts.repeat(times=None) if args.enable_musan: - cuts_musan = load_manifest( - Path(args.manifest_dir) / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz") else: cuts_musan = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py index 4f043e5a6..69cfcd298 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py @@ -197,20 +197,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -306,8 +310,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -427,9 +430,7 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) if ( params.decoding_method == "fast_beam_search" @@ -485,10 +486,7 @@ def decode_one_batch( nbest_scale=params.nbest_scale, return_timestamps=True, ) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: res = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -566,9 +564,7 @@ def decode_dataset( sp: spm.SentencePieceProcessor, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[ - str, List[Tuple[str, List[str], List[str], List[float], List[float]]] -]: +) -> Dict[str, List[Tuple[str, List[str], List[str], List[float], List[float]]]]: """Decode dataset. Args: @@ -643,9 +639,7 @@ def decode_dataset( cut_ids, hyps, texts, timestamps_hyp, timestamps_ref ): ref_words = ref_text.split() - this_batch.append( - (cut_id, ref_words, hyp_words, time_ref, time_hyp) - ) + this_batch.append((cut_id, ref_words, hyp_words, time_ref, time_hyp)) results[name].extend(this_batch) @@ -654,9 +648,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -694,8 +686,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -722,9 +713,7 @@ def save_results( note = "" logging.info(s) - s = "\nFor {}, symbol-delay of different settings are:\n".format( - test_set_name - ) + s = "\nFor {}, symbol-delay of different settings are:\n".format(test_set_name) note = "\tbest for {}".format(test_set_name) for key, val in test_set_delays: s += "{}\tmean: {}s, variance: {}{}\n".format(key, val[0], val[1], note) @@ -773,9 +762,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -812,13 +799,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -841,13 +827,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -875,7 +860,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -902,9 +887,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/export.py b/egs/librispeech/ASR/pruned_transducer_stateless4/export.py index ce7518ceb..bd5801a78 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/export.py @@ -89,20 +89,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -133,8 +137,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -183,13 +186,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -212,13 +214,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -246,7 +247,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -282,9 +283,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py index 7af9ea9b8..a28e52c78 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py @@ -96,20 +96,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -175,8 +179,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -284,9 +287,7 @@ def decode_one_chunk( encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": - greedy_search( - model=model, encoder_out=encoder_out, streams=decode_streams - ) + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( @@ -306,9 +307,7 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -364,9 +363,7 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. decode_streams = [] - initial_states = model.encoder.get_init_state( - params.left_context, device=device - ) + initial_states = model.encoder.get_init_state(params.left_context, device=device) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. decode_stream = DecodeStream( @@ -438,9 +435,7 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") return {key: decode_results} @@ -473,8 +468,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -547,13 +541,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -576,13 +569,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -610,7 +602,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py index cf32e565b..76785a845 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py @@ -101,9 +101,7 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -239,42 +237,45 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." + ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -621,11 +622,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -665,9 +662,7 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -680,14 +675,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -698,9 +688,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -879,9 +867,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -1013,8 +999,7 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py index 427b06294..8499651d7 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py @@ -214,10 +214,7 @@ class Conformer(EncoderInterface): (num_encoder_layers, cnn_module_kernel - 1, encoder_dim). NOTE: the returned tensors are on the given device. """ - if ( - len(self._init_state) == 2 - and self._init_state[0].size(1) == left_context - ): + if len(self._init_state) == 2 and self._init_state[0].size(1) == left_context: # Note: It is OK to share the init state as it is # not going to be modified by the model return self._init_state @@ -439,9 +436,7 @@ class ConformerEncoderLayer(nn.Module): self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), @@ -459,9 +454,7 @@ class ConformerEncoderLayer(nn.Module): ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), ) - self.conv_module = ConvolutionModule( - d_model, cnn_module_kernel, causal=causal - ) + self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal) self.norm_final = BasicNorm(d_model) @@ -527,9 +520,7 @@ class ConformerEncoderLayer(nn.Module): src = src + self.dropout(src_att) # convolution module - conv, _ = self.conv_module( - src, src_key_padding_mask=src_key_padding_mask - ) + conv, _ = self.conv_module(src, src_key_padding_mask=src_key_padding_mask) src = src + self.dropout(conv) # feed forward module @@ -802,9 +793,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -820,9 +809,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x_size_1 * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -848,9 +835,7 @@ class RelPositionalEncoding(torch.nn.Module): pe = torch.cat([pe_positive, pe_negative], dim=1) self.pe = pe.to(device=x.device, dtype=x.dtype) - def forward( - self, x: torch.Tensor, left_context: int = 0 - ) -> Tuple[Tensor, Tensor]: + def forward(self, x: torch.Tensor, left_context: int = 0) -> Tuple[Tensor, Tensor]: """Add positional encoding. Args: @@ -1118,9 +1103,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -1189,33 +1174,25 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor" + " instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -1253,23 +1230,15 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d - matrix_bd = torch.matmul( - q_with_bias_v, p - ) # (batch, head, time1, 2*time1-1) + matrix_bd = torch.matmul(q_with_bias_v, p) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd, left_context) - attn_output_weights = ( - matrix_ac + matrix_bd - ) # (batch, head, time1, time2) + attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -1310,21 +1279,17 @@ class RelPositionMultiheadAttention(nn.Module): ): if attn_mask.size(0) != 1: attn_mask = attn_mask.view(bsz, num_heads, tgt_len, src_len) - combined_mask = attn_mask | key_padding_mask.unsqueeze( - 1 - ).unsqueeze(2) + combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2) else: # attn_mask.shape == (1, tgt_len, src_len) - combined_mask = attn_mask.unsqueeze( - 0 - ) | key_padding_mask.unsqueeze(1).unsqueeze(2) + combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze( + 1 + ).unsqueeze(2) attn_output_weights = attn_output_weights.view( bsz, num_heads, tgt_len, src_len ) - attn_output_weights = attn_output_weights.masked_fill( - combined_mask, 0.0 - ) + attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0) attn_output_weights = attn_output_weights.view( bsz * num_heads, tgt_len, src_len ) @@ -1336,13 +1301,9 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -1481,16 +1442,12 @@ class ConvolutionModule(nn.Module): # manualy padding self.lorder zeros to the left x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) else: - assert ( - not self.training - ), "Cache should be None in training time" + assert not self.training, "Cache should be None in training time" assert cache.size(0) == self.lorder x = torch.cat([cache.permute(1, 2, 0), x], dim=2) if right_context > 0: cache = x.permute(2, 0, 1)[ - -(self.lorder + right_context) : ( # noqa - -right_context - ), + -(self.lorder + right_context) : (-right_context), # noqa ..., ] else: @@ -1666,9 +1623,7 @@ class RandomCombine(nn.Module): self.stddev = stddev self.final_log_weight = ( - torch.tensor( - (final_weight / (1 - final_weight)) * (self.num_inputs - 1) - ) + torch.tensor((final_weight / (1 - final_weight)) * (self.num_inputs - 1)) .log() .item() ) @@ -1765,16 +1720,14 @@ class RandomCombine(nn.Module): # final contains self.num_inputs - 1 in all elements final = torch.full((num_frames,), self.num_inputs - 1, device=device) # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights. - nonfinal = torch.randint( - self.num_inputs - 1, (num_frames,), device=device - ) + nonfinal = torch.randint(self.num_inputs - 1, (num_frames,), device=device) indexes = torch.where( torch.rand(num_frames, device=device) < final_prob, final, nonfinal ) - ans = torch.nn.functional.one_hot( - indexes, num_classes=self.num_inputs - ).to(dtype=dtype) + ans = torch.nn.functional.one_hot(indexes, num_classes=self.num_inputs).to( + dtype=dtype + ) return ans def _get_random_mixed_weights( @@ -1804,7 +1757,8 @@ class RandomCombine(nn.Module): def _test_random_combine(final_weight: float, pure_prob: float, stddev: float): print( - f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}, stddev={stddev}" + f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}," + f" stddev={stddev}" ) num_inputs = 3 num_channels = 50 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py index 22bcdd88e..f462cc42f 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py @@ -179,20 +179,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -303,8 +307,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -477,9 +480,7 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -545,10 +546,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -696,9 +694,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -731,8 +727,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -787,9 +782,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -828,13 +821,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -857,13 +849,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -891,7 +882,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -937,9 +928,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/export.py b/egs/librispeech/ASR/pruned_transducer_stateless5/export.py index b2e5b430e..a739c17bc 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/export.py @@ -89,20 +89,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -133,8 +137,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -181,13 +184,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -210,13 +212,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -244,7 +245,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -280,9 +281,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py index 1e100fcbd..e2da0da4c 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py @@ -89,9 +89,11 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( @@ -116,10 +118,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -166,8 +170,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -198,10 +201,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -264,15 +266,11 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lengths - ) + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) num_waves = encoder_out.size(0) hyps = [] @@ -344,9 +342,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py index 6fee9483e..59a0e8fa2 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py @@ -96,20 +96,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -175,8 +179,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -284,9 +287,7 @@ def decode_one_chunk( encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": - greedy_search( - model=model, encoder_out=encoder_out, streams=decode_streams - ) + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( @@ -306,9 +307,7 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -364,9 +363,7 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. decode_streams = [] - initial_states = model.encoder.get_init_state( - params.left_context, device=device - ) + initial_states = model.encoder.get_init_state(params.left_context, device=device) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. decode_stream = DecodeStream( @@ -438,9 +435,7 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") return {key: decode_results} @@ -473,8 +468,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -547,13 +541,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -576,13 +569,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -610,7 +602,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py index 179d9372e..75696d61b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py @@ -89,9 +89,7 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -248,8 +246,7 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need " - "to be changed.", + help="The initial learning rate. This value should not need to be changed.", ) parser.add_argument( @@ -272,42 +269,45 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." + ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -645,11 +645,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -690,9 +686,7 @@ def compute_loss( # If the batch contains more than 10 utterances AND # if either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -705,14 +699,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -723,9 +712,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -908,9 +895,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -1023,7 +1008,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -1045,8 +1030,7 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py index 53788b3f7..40ad61fd4 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py @@ -90,10 +90,7 @@ class Conformer(EncoderInterface): output_layers = [] if middle_output_layer is not None: - assert ( - middle_output_layer >= 0 - and middle_output_layer < num_encoder_layers - ) + assert middle_output_layer >= 0 and middle_output_layer < num_encoder_layers output_layers.append(middle_output_layer) # The last layer is always needed. @@ -178,9 +175,7 @@ class ConformerEncoderLayer(nn.Module): self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), @@ -362,9 +357,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -379,9 +372,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -656,9 +647,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -727,33 +718,25 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor" + " instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -790,9 +773,7 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -800,13 +781,9 @@ class RelPositionMultiheadAttention(nn.Module): ) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd) - attn_output_weights = ( - matrix_ac + matrix_bd - ) # (batch, head, time1, time2) + attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -840,13 +817,9 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -869,9 +842,7 @@ class ConvolutionModule(nn.Module): """ - def __init__( - self, channels: int, kernel_size: int, bias: bool = True - ) -> None: + def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py index 74df04006..600aa9b39 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py @@ -120,20 +120,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -208,8 +212,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -267,9 +270,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - layer_results, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + layer_results, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) encoder_out = layer_results[-1] hyps = [] @@ -285,10 +286,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -334,11 +332,7 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps + f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -411,9 +405,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -446,8 +438,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -490,9 +481,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -524,13 +513,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -553,13 +541,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -587,7 +574,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/export.py b/egs/librispeech/ASR/pruned_transducer_stateless6/export.py index cff9c7377..17f8614dc 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/export.py @@ -51,11 +51,7 @@ import sentencepiece as spm import torch from train import get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import str2bool @@ -87,9 +83,11 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( @@ -120,8 +118,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) return parser @@ -160,8 +157,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -209,9 +205,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/extract_codebook_index.py b/egs/librispeech/ASR/pruned_transducer_stateless6/extract_codebook_index.py index 21409287c..86cf34877 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/extract_codebook_index.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/extract_codebook_index.py @@ -21,9 +21,10 @@ import os from pathlib import Path import torch -from vq_utils import CodebookIndexExtractor from asr_datamodule import LibriSpeechAsrDataModule from hubert_xlarge import HubertXlargeFineTuned +from vq_utils import CodebookIndexExtractor + from icefall.utils import AttributeDict, str2bool diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_decode.py index 49b557814..b8440f90a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_decode.py @@ -23,7 +23,6 @@ from pathlib import Path from typing import Dict, List, Tuple import torch - from asr_datamodule import LibriSpeechAsrDataModule from hubert_xlarge import HubertXlargeFineTuned @@ -99,9 +98,7 @@ def decode_dataset( if batch_idx % 20 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -124,9 +121,7 @@ def save_results( ) test_set_wers[key] = wer - logging.info( - "Wrote detailed error stats to {}".format(errs_filename) - ) + logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.res_dir / f"wer-summary-{test_set_name}.txt" @@ -155,9 +150,7 @@ def main(): # reset some parameters needed by hubert. params.update(HubertXlargeFineTuned.get_params()) - params.res_dir = ( - params.exp_dir / f"ctc_greedy_search-{params.teacher_model_id}" - ) + params.res_dir = params.exp_dir / f"ctc_greedy_search-{params.teacher_model_id}" setup_logger(f"{params.res_dir}/log/log-ctc_greedy_search") logging.info("Decoding started") @@ -190,9 +183,7 @@ def main(): params=params, ) - save_results( - params=params, test_set_name=test_set, results_dict=results_dict - ) + save_results(params=params, test_set_name=test_set, results_dict=results_dict) logging.info("Done!") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_xlarge.py b/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_xlarge.py index 55ce7b00d..4f9417c9f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_xlarge.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_xlarge.py @@ -22,11 +22,7 @@ from pathlib import Path from typing import Dict, List, Tuple import torch -from fairseq import ( - checkpoint_utils, - tasks, - utils, -) +from fairseq import checkpoint_utils, tasks, utils from fairseq.data.data_utils import post_process from omegaconf import OmegaConf @@ -51,9 +47,7 @@ def _load_hubert_model(params: AttributeDict): "data": str(params.hubert_model_dir), } ) - model_path = Path(params.hubert_model_dir) / ( - params.teacher_model_id + ".pt" - ) + model_path = Path(params.hubert_model_dir) / (params.teacher_model_id + ".pt") task = tasks.setup_task(cfg_task) processor = task.target_dictionary models, saved_cfg = checkpoint_utils.load_model_ensemble( @@ -151,9 +145,7 @@ class HubertXlargeFineTuned: supervisions = batch["supervisions"] num_samples = supervisions["num_samples"] B, T = features.shape - padding_mask = torch.arange(0, T).expand(B, T) > num_samples.reshape( - [-1, 1] - ) + padding_mask = torch.arange(0, T).expand(B, T) > num_samples.reshape([-1, 1]) padding_mask = padding_mask.to(self.params.device) features = features.to(self.params.device) @@ -163,9 +155,7 @@ class HubertXlargeFineTuned: features = features.transpose(1, 2) features = self.w2v_model.layer_norm(features) - padding_mask = self.w2v_model.forward_padding_mask( - features, padding_mask - ) + padding_mask = self.w2v_model.forward_padding_mask(features, padding_mask) if self.w2v_model.post_extract_proj is not None: features = self.w2v_model.post_extract_proj(features) @@ -212,9 +202,7 @@ class HubertXlargeFineTuned: toks = encoder_out.argmax(dim=-1) blank = 0 toks = [tok.unique_consecutive() for tok in toks] - hyps = [ - self.processor.string(tok[tok != blank].int().cpu()) for tok in toks - ] + hyps = [self.processor.string(tok[tok != blank].int().cpu()) for tok in toks] hyps = [post_process(hyp, "letter") for hyp in hyps] return hyps diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/model.py b/egs/librispeech/ASR/pruned_transducer_stateless6/model.py index 7716d19cf..daadb70c9 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/model.py @@ -69,9 +69,7 @@ class Transducer(nn.Module): self.decoder = decoder self.joiner = joiner - self.simple_am_proj = ScaledLinear( - encoder_dim, vocab_size, initial_speed=0.5 - ) + self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) from icefall import is_module_available @@ -180,9 +178,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens @@ -237,9 +233,7 @@ class Transducer(nn.Module): return (simple_loss, pruned_loss, codebook_loss) @staticmethod - def concat_successive_codebook_indexes( - middle_layer_output, codebook_indexes - ): + def concat_successive_codebook_indexes(middle_layer_output, codebook_indexes): # Output rate of hubert is 50 frames per second, # while that of current encoder is 25. # Following code handling two issues: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py index f717d85fb..be54ff0ce 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py @@ -101,9 +101,7 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def get_parser(): @@ -203,42 +201,45 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." + ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -569,9 +570,7 @@ def save_checkpoint( def extract_codebook_indexes(batch): cuts = batch["supervisions"]["cut"] # -100 is identical to ignore_value in CE loss computation. - cuts_pre_mixed = [ - c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts - ] + cuts_pre_mixed = [c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts] codebook_indexes, codebook_indexes_lens = collate_custom_field( cuts_pre_mixed, "codebook_indexes", pad_value=-100 ) @@ -604,11 +603,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -655,9 +650,7 @@ def compute_loss( # If the batch contains more than 10 utterances AND # if either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -670,14 +663,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss if is_training and params.enable_distillation: assert codebook_loss is not None loss += params.codebook_loss_scale * codebook_loss @@ -690,9 +678,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -873,9 +859,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -1007,8 +991,7 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py index 47cf2b14b..40f97f662 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py @@ -68,9 +68,7 @@ class CodebookIndexExtractor: def init_dirs(self): # vq_dir is the root dir for quantization, containing: # training data, trained quantizer, and extracted codebook indexes - self.vq_dir = ( - self.params.exp_dir / f"vq/{self.params.teacher_model_id}/" - ) + self.vq_dir = self.params.exp_dir / f"vq/{self.params.teacher_model_id}/" self.vq_dir.mkdir(parents=True, exist_ok=True) # manifest_dir contains: @@ -208,9 +206,7 @@ class CodebookIndexExtractor: start = cur_offset % (data.shape[0] + 1 - B) end = start + B cur_offset += B - yield data[start:end, :].to(self.params.device).to( - dtype=torch.float - ) + yield data[start:end, :].to(self.params.device).to(dtype=torch.float) for x in minibatch_generator(train, repeat=True): trainer.step(x) @@ -227,10 +223,11 @@ class CodebookIndexExtractor: """ for subset in self.params.subsets: logging.info(f"About to split {subset}.") - ori_manifest = ( - f"./data/fbank/librispeech_cuts_train-{subset}.jsonl.gz" + ori_manifest = f"./data/fbank/librispeech_cuts_train-{subset}.jsonl.gz" + split_cmd = ( + "lhotse split" + f" {self.params.world_size} {ori_manifest} {self.manifest_dir}" ) - split_cmd = f"lhotse split {self.params.world_size} {ori_manifest} {self.manifest_dir}" os.system(f"{split_cmd}") def join_manifests(self): @@ -240,16 +237,13 @@ class CodebookIndexExtractor: logging.info("Start to join manifest files.") for subset in self.params.subsets: vq_manifest_path = ( - self.dst_manifest_dir - / f"librispeech_cuts_train-{subset}-vq.jsonl.gz" + self.dst_manifest_dir / f"librispeech_cuts_train-{subset}-vq.jsonl.gz" ) ori_manifest_path = ( - self.ori_manifest_dir - / f"librispeech_cuts_train-{subset}.jsonl.gz" + self.ori_manifest_dir / f"librispeech_cuts_train-{subset}.jsonl.gz" ) dst_vq_manifest_path = ( - self.dst_manifest_dir - / f"librispeech_cuts_train-{subset}.jsonl.gz" + self.dst_manifest_dir / f"librispeech_cuts_train-{subset}.jsonl.gz" ) cuts_vq = load_manifest(vq_manifest_path) cuts_ori = load_manifest(ori_manifest_path) @@ -269,8 +263,7 @@ class CodebookIndexExtractor: for subset in self.params.subsets: vq_manifests = f"{self.manifest_dir}/with_codebook_indexes-librispeech-cuts_train-{subset}*.jsonl.gz" dst_vq_manifest = ( - self.dst_manifest_dir - / f"librispeech_cuts_train-{subset}-vq.jsonl.gz" + self.dst_manifest_dir / f"librispeech_cuts_train-{subset}-vq.jsonl.gz" ) if 1 == self.params.world_size: merge_cmd = f"cp {vq_manifests} {dst_vq_manifest}" @@ -330,9 +323,7 @@ class CodebookIndexExtractor: def load_ori_dl(self, subset): if self.params.world_size == 1: - ori_manifest_path = ( - f"./data/fbank/librispeech_cuts_train-{subset}.jsonl.gz" - ) + ori_manifest_path = f"./data/fbank/librispeech_cuts_train-{subset}.jsonl.gz" else: ori_manifest_path = ( self.manifest_dir diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py index 06c5863f1..fa8144935 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py @@ -164,20 +164,24 @@ def get_parser(): "--avg", type=int, default=9, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -272,8 +276,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -393,9 +396,7 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -454,10 +455,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -588,9 +586,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -623,8 +619,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -679,9 +674,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -718,13 +711,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -747,13 +739,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -781,7 +772,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -808,9 +799,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py index 712dc8ce1..5f90e6375 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py @@ -69,7 +69,7 @@ class Decoder(nn.Module): out_channels=decoder_dim, kernel_size=context_size, padding=0, - groups=decoder_dim//4, # group size == 4 + groups=decoder_dim // 4, # group size == 4 bias=False, ) @@ -91,9 +91,7 @@ class Decoder(nn.Module): if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad( - embedding_out, pad=(self.context_size - 1, 0) - ) + embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7/export.py index 5744ea3ea..43ac658e5 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/export.py @@ -129,20 +129,24 @@ def get_parser(): "--avg", type=int, default=9, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -176,8 +180,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) add_model_arguments(parser) @@ -215,13 +218,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -244,13 +246,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -278,7 +279,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -316,9 +317,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py index e2405d5ef..c94a34d58 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py @@ -69,10 +69,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) return parser @@ -93,10 +95,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -267,9 +268,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py index 7d8de5afe..3ddac2cf2 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py @@ -56,9 +56,7 @@ class Joiner(nn.Module): assert encoder_out.shape[:-1] == decoder_out.shape[:-1] if project_input: - logit = self.encoder_proj(encoder_out) + self.decoder_proj( - decoder_out - ) + logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) else: logit = encoder_out + decoder_out diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py index 53cde6c6f..0e59b0f2f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py @@ -15,14 +15,15 @@ # limitations under the License. +import random + import k2 import torch import torch.nn as nn -import random from encoder_interface import EncoderInterface +from scaling import penalize_abs_values_gt from icefall.utils import add_sos -from scaling import penalize_abs_values_gt class Transducer(nn.Module): @@ -65,7 +66,8 @@ class Transducer(nn.Module): self.joiner = joiner self.simple_am_proj = nn.Linear( - encoder_dim, vocab_size, + encoder_dim, + vocab_size, ) self.simple_lm_proj = nn.Linear(decoder_dim, vocab_size) @@ -133,18 +135,16 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens lm = self.simple_lm_proj(decoder_out) am = self.simple_am_proj(encoder_out) - #if self.training and random.random() < 0.25: + # if self.training and random.random() < 0.25: # lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04) - #if self.training and random.random() < 0.25: + # if self.training and random.random() < 0.25: # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) with torch.cuda.amp.autocast(enabled=False): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py index bb8b0a0e3..460ac2c3e 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py @@ -14,17 +14,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections import defaultdict -from typing import List, Optional, Union, Tuple, List -from lhotse.utils import fix_random_seed -import torch -from scaling import ActivationBalancer +import contextlib +import logging import random +from collections import defaultdict +from typing import List, Optional, Tuple, Union + +import torch +from lhotse.utils import fix_random_seed +from scaling import ActivationBalancer from torch import Tensor from torch.optim import Optimizer -import logging -import contextlib - class BatchedOptimizer(Optimizer): @@ -37,11 +37,10 @@ class BatchedOptimizer(Optimizer): Args: params: """ + def __init__(self, params, defaults): super(BatchedOptimizer, self).__init__(params, defaults) - - @contextlib.contextmanager def batched_params(self, param_group): """ @@ -73,7 +72,9 @@ class BatchedOptimizer(Optimizer): group: a parameter group, which is a list of parameters; should be one of self.groups. """ - batches = defaultdict(list) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter + batches = defaultdict( + list + ) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter for p in param_group: key = (str(p.dtype), *p.shape) @@ -82,7 +83,7 @@ class BatchedOptimizer(Optimizer): stacked_params_dict = dict() # turn batches into a list, in deterministic order. - batches = [ batches[key] for key in sorted(batches.keys()) ] + batches = [batches[key] for key in sorted(batches.keys())] # pairs will contain pairs of (stacked_param, state), one for each batch # in `batches`. pairs = [] @@ -94,76 +95,77 @@ class BatchedOptimizer(Optimizer): # group. class Optimizer will take care of saving/loading state. state = self.state[p] p_stacked = torch.stack(batch) - grad = torch.stack([torch.zeros_like(p) if p.grad is None else p.grad for p in batch ]) + grad = torch.stack( + [torch.zeros_like(p) if p.grad is None else p.grad for p in batch] + ) p_stacked.grad = grad stacked_params_dict[key] = p_stacked pairs.append((p_stacked, state)) - yield pairs # <-- calling code will do the actual optimization here! + yield pairs # <-- calling code will do the actual optimization here! for ((stacked_params, _state), batch) in zip(pairs, batches): for i, p in enumerate(batch): # batch is list of Parameter p.copy_(stacked_params[i]) - class ScaledAdam(BatchedOptimizer): """ - Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update - proportional to the norm of that parameter; and also learn the scale of the parameter, - in log space, subject to upper and lower limits (as if we had factored each parameter as - param = underlying_param * log_scale.exp()) + Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update + proportional to the norm of that parameter; and also learn the scale of the parameter, + in log space, subject to upper and lower limits (as if we had factored each parameter as + param = underlying_param * log_scale.exp()) - Args: - params: The parameters or param_groups to optimize (like other Optimizer subclasses) - lr: The learning rate. We will typically use a learning rate schedule that starts - at 0.03 and decreases over time, i.e. much higher than other common - optimizers. - clipping_scale: (e.g. 2.0) - A scale for gradient-clipping: if specified, the normalized gradients - over the whole model will be clipped to have 2-norm equal to - `clipping_scale` times the median 2-norm over the most recent period - of `clipping_update_period` minibatches. By "normalized gradients", - we mean after multiplying by the rms parameter value for this tensor - [for non-scalars]; this is appropriate because our update is scaled - by this quantity. - betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad. - Must satisfy 0 < beta <= beta2 < 1. - scalar_lr_scale: A scaling factor on the learning rate, that we use to update the - scale of each parameter tensor and scalar parameters of the mode.. - If each parameter were decomposed - as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale - would be a the scaling factor on the learning rate of p_scale. - eps: A general-purpose epsilon to prevent division by zero - param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of - learning the scale on the parameters (we'll constrain the rms of each non-scalar - parameter tensor to be >= this value) - param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of - learning the scale on the parameters (we'll constrain the rms of each non-scalar - parameter tensor to be <= this value) - scalar_max: Maximum absolute value for scalar parameters (applicable if your - model has any parameters with numel() == 1). - size_update_period: The periodicity, in steps, with which we update the size (scale) - of the parameter tensor. This is provided to save a little time - in the update. - clipping_update_period: if clipping_scale is specified, this is the period + Args: + params: The parameters or param_groups to optimize (like other Optimizer subclasses) + lr: The learning rate. We will typically use a learning rate schedule that starts + at 0.03 and decreases over time, i.e. much higher than other common + optimizers. + clipping_scale: (e.g. 2.0) + A scale for gradient-clipping: if specified, the normalized gradients + over the whole model will be clipped to have 2-norm equal to + `clipping_scale` times the median 2-norm over the most recent period + of `clipping_update_period` minibatches. By "normalized gradients", + we mean after multiplying by the rms parameter value for this tensor + [for non-scalars]; this is appropriate because our update is scaled + by this quantity. + betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad. + Must satisfy 0 < beta <= beta2 < 1. + scalar_lr_scale: A scaling factor on the learning rate, that we use to update the + scale of each parameter tensor and scalar parameters of the mode.. + If each parameter were decomposed + as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale + would be a the scaling factor on the learning rate of p_scale. + eps: A general-purpose epsilon to prevent division by zero + param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of + learning the scale on the parameters (we'll constrain the rms of each non-scalar + parameter tensor to be >= this value) + param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of + learning the scale on the parameters (we'll constrain the rms of each non-scalar + parameter tensor to be <= this value) + scalar_max: Maximum absolute value for scalar parameters (applicable if your + model has any parameters with numel() == 1). + size_update_period: The periodicity, in steps, with which we update the size (scale) + of the parameter tensor. This is provided to save a little time + in the update. + clipping_update_period: if clipping_scale is specified, this is the period """ - def __init__( - self, - params, - lr=3e-02, - clipping_scale=None, - betas=(0.9, 0.98), - scalar_lr_scale=0.1, - eps=1.0e-08, - param_min_rms=1.0e-05, - param_max_rms=3.0, - scalar_max=10.0, - size_update_period=4, - clipping_update_period=100, - ): + def __init__( + self, + params, + lr=3e-02, + clipping_scale=None, + betas=(0.9, 0.98), + scalar_lr_scale=0.1, + eps=1.0e-08, + param_min_rms=1.0e-05, + param_max_rms=3.0, + scalar_max=10.0, + size_update_period=4, + clipping_update_period=100, + ): defaults = dict( lr=lr, @@ -183,7 +185,6 @@ class ScaledAdam(BatchedOptimizer): def __setstate__(self, state): super(ScaledAdam, self).__setstate__(state) - @torch.no_grad() def step(self, closure=None): """Performs a single optimization step. @@ -206,7 +207,9 @@ class ScaledAdam(BatchedOptimizer): # a regular parameter, and will have a .grad, but the 1st dim corresponds to # a stacking dim, it is not a real dim. - if len(batches[0][1]) == 0: # if len(first state) == 0: not yet initialized + if ( + len(batches[0][1]) == 0 + ): # if len(first state) == 0: not yet initialized clipping_scale = 1 else: clipping_scale = self._get_clipping_scale(group, batches) @@ -225,13 +228,9 @@ class ScaledAdam(BatchedOptimizer): self._step_one_batch(group, p, state, clipping_scale) - return loss - def _init_state(self, - group: dict, - p: Tensor, - state: dict): + def _init_state(self, group: dict, p: Tensor, state: dict): """ Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p is actually the batch dimension, corresponding to batched-together @@ -247,7 +246,7 @@ class ScaledAdam(BatchedOptimizer): state["step"] = 0 - kwargs = {'device':p.device, 'dtype':p.dtype} + kwargs = {"device": p.device, "dtype": p.dtype} # 'delta' implements conventional momentum. There are # several different kinds of update going on, so rather than @@ -255,36 +254,30 @@ class ScaledAdam(BatchedOptimizer): # parameter-change "delta", which combines all forms of # update. this is equivalent to how it's done in Adam, # except for the first few steps. - state["delta"] = torch.zeros_like( - p, memory_format=torch.preserve_format - ) + state["delta"] = torch.zeros_like(p, memory_format=torch.preserve_format) batch_size = p.shape[0] numel = p.numel() // batch_size numel = p.numel() - if numel > 1: # "param_rms" just periodically records the scalar root-mean-square value of # the parameter tensor. # it has a shape like (batch_size, 1, 1, 1, 1) - param_rms = (p**2).mean(dim=list(range(1, p.ndim)), - keepdim=True).sqrt() + param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() state["param_rms"] = param_rms state["scale_exp_avg_sq"] = torch.zeros_like(param_rms) - state["scale_grads"] = torch.zeros(size_update_period, *param_rms.shape, - **kwargs) - + state["scale_grads"] = torch.zeros( + size_update_period, *param_rms.shape, **kwargs + ) # exp_avg_sq is the weighted sum of scaled gradients. as in Adam. - state["exp_avg_sq"] = torch.zeros_like( - p, memory_format=torch.preserve_format - ) + state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format) - def _get_clipping_scale(self, - group: dict, - pairs: List[Tuple[Tensor, dict]]) -> float: + def _get_clipping_scale( + self, group: dict, pairs: List[Tuple[Tensor, dict]] + ) -> float: """ Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients by this amount before applying the rest of the update. @@ -314,57 +307,67 @@ class ScaledAdam(BatchedOptimizer): if p.numel() == p.shape[0]: # a batch of scalars tot_sumsq += (grad**2).sum() # sum() to change shape [1] to [] else: - tot_sumsq += ((grad * state["param_rms"])**2).sum() + tot_sumsq += ((grad * state["param_rms"]) ** 2).sum() tot_norm = tot_sumsq.sqrt() - if not "model_norms" in first_state: - first_state["model_norms"] = torch.zeros(clipping_update_period, - device=p.device) + if "model_norms" not in first_state: + first_state["model_norms"] = torch.zeros( + clipping_update_period, device=p.device + ) first_state["model_norms"][step % clipping_update_period] = tot_norm if step % clipping_update_period == 0: # Print some stats. # We don't reach here if step == 0 because we would have returned # above. - sorted_norms = first_state["model_norms"].sort()[0].to('cpu') + sorted_norms = first_state["model_norms"].sort()[0].to("cpu") quartiles = [] for n in range(0, 5): - index = min(clipping_update_period - 1, - (clipping_update_period // 4) * n) + index = min( + clipping_update_period - 1, + (clipping_update_period // 4) * n, + ) quartiles.append(sorted_norms[index].item()) median = quartiles[2] threshold = clipping_scale * median first_state["model_norm_threshold"] = threshold - percent_clipped = (first_state["num_clipped"] * 100.0 / clipping_update_period - if "num_clipped" in first_state else 0.0) + percent_clipped = ( + first_state["num_clipped"] * 100.0 / clipping_update_period + if "num_clipped" in first_state + else 0.0 + ) first_state["num_clipped"] = 0 - quartiles = ' '.join([ '%.3e' % x for x in quartiles ]) - logging.info(f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, " - f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}") + quartiles = " ".join(["%.3e" % x for x in quartiles]) + logging.info( + f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, " + f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}" + ) if step < clipping_update_period: return 1.0 # We have not yet estimated a norm to clip to. else: try: model_norm_threshold = first_state["model_norm_threshold"] - except: - logging.info("Warning: model_norm_threshold not in state: possibly " - "you changed config when restarting, adding clipping_scale option?") + except KeyError: + logging.info( + "Warning: model_norm_threshold not in state: possibly " + "you changed config when restarting, adding clipping_scale option?" + ) return 1.0 - ans = min(1.0,(model_norm_threshold / (tot_norm + 1.0e-20)).item()) + ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item()) if ans < 1.0: first_state["num_clipped"] += 1 if ans < 0.1: - logging.warn(f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}") + logging.warn( + f"Scaling gradients by {ans}," + f" model_norm_threshold={model_norm_threshold}" + ) return ans - - def _step_one_batch(self, - group: dict, - p: Tensor, - state: dict, - clipping_scale: float): + def _step_one_batch( + self, group: dict, p: Tensor, state: dict, clipping_scale: float + ): """ Do the step for one parameter, which is actually going to be a batch of `real` parameters, with dim 0 as the batch dim. @@ -391,17 +394,18 @@ class ScaledAdam(BatchedOptimizer): # Update the size/scale of p, and set param_rms scale_grads = state["scale_grads"] scale_grads[step % size_update_period] = (p * grad).sum( - dim=list(range(1, p.ndim)), keepdim=True) + dim=list(range(1, p.ndim)), keepdim=True + ) if step % size_update_period == size_update_period - 1: param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..) - param_rms.copy_((p**2).mean(dim=list(range(1, p.ndim)), - keepdim=True).sqrt()) + param_rms.copy_( + (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() + ) if step > 0: # self._size_update() learns the overall scale on the # parameter, by shrinking or expanding it. self._size_update(group, scale_grads, p, state) - if numel == 1: # For parameters with 1 element we just use regular Adam. # Updates delta. @@ -411,24 +415,21 @@ class ScaledAdam(BatchedOptimizer): state["step"] = step + 1 - - def _size_update(self, - group: dict, - scale_grads: Tensor, - p: Tensor, - state: dict) -> None: + def _size_update( + self, group: dict, scale_grads: Tensor, p: Tensor, state: dict + ) -> None: """ - Called only where p.numel() > 1, this updates the scale of the parameter. - If we imagine: p = underlying_param * scale.exp(), and we are doing - gradient descent on underlying param and on scale, this function does the update - on `scale`. + Called only where p.numel() > 1, this updates the scale of the parameter. + If we imagine: p = underlying_param * scale.exp(), and we are doing + gradient descent on underlying param and on scale, this function does the update + on `scale`. - Args: - group: dict to look up configuration values - scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing - grads w.r.t. the scales. - p: The parameter to update - state: The state-dict of p + Args: + group: dict to look up configuration values + scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing + grads w.r.t. the scales. + p: The parameter to update + state: The state-dict of p """ param_rms = state["param_rms"] @@ -443,25 +444,28 @@ class ScaledAdam(BatchedOptimizer): size_update_period = scale_grads.shape[0] # correct beta2 for the size update period: we will have # faster decay at this level. - beta2_corr = beta2 ** size_update_period + beta2_corr = beta2**size_update_period scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..) scale_exp_avg_sq.mul_(beta2_corr).add_( - (scale_grads ** 2).mean(dim=0), # mean over dim `size_update_period` - alpha=1-beta2_corr) # shape is (batch_size, 1, 1, ...) + (scale_grads**2).mean(dim=0), # mean over dim `size_update_period` + alpha=1 - beta2_corr, + ) # shape is (batch_size, 1, 1, ...) # The 1st time we reach here is when size_step == 1. size_step = (step + 1) // size_update_period - bias_correction2 = 1 - beta2_corr ** size_step + bias_correction2 = 1 - beta2_corr**size_step # we don't bother with bias_correction1; this will help prevent divergence # at the start of training. denom = scale_exp_avg_sq.sqrt() + eps - scale_step = -size_lr * (bias_correction2 ** 0.5) * scale_grads.sum(dim=0) / denom + scale_step = ( + -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom + ) - is_too_small = (param_rms < param_min_rms) - is_too_large = (param_rms > param_max_rms) + is_too_small = param_rms < param_min_rms + is_too_large = param_rms > param_max_rms # when the param gets too small, just don't shrink it any further. scale_step.masked_fill_(is_too_small, 0.0) @@ -469,13 +473,9 @@ class ScaledAdam(BatchedOptimizer): scale_step.masked_fill_(is_too_large, -size_lr * size_update_period) delta = state["delta"] # the factor of (1-beta1) relates to momentum. - delta.add_(p * scale_step, alpha=(1-beta1)) + delta.add_(p * scale_step, alpha=(1 - beta1)) - - def _step(self, - group: dict, - p: Tensor, - state: dict): + def _step(self, group: dict, p: Tensor, state: dict): """ This function does the core update of self.step(), in the case where the members of the batch have more than 1 element. @@ -496,8 +496,7 @@ class ScaledAdam(BatchedOptimizer): step = state["step"] exp_avg_sq = state["exp_avg_sq"] - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, - value=(1-beta2)) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2)) this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0) bias_correction2 = 1 - beta2 ** (this_step + 1) @@ -509,17 +508,13 @@ class ScaledAdam(BatchedOptimizer): denom += eps grad = grad / denom - alpha = -lr * (1-beta1) * state["param_rms"].clamp(min=param_min_rms) + alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms) delta = state["delta"] delta.add_(grad * alpha) p.add_(delta) - - def _step_scalar(self, - group: dict, - p: Tensor, - state: dict): + def _step_scalar(self, group: dict, p: Tensor, state: dict): """ A simplified form of the core update for scalar tensors, where we cannot get a good estimate of the parameter rms. @@ -531,8 +526,7 @@ class ScaledAdam(BatchedOptimizer): grad = p.grad exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, - value=1-beta2) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) # bias_correction2 is like in Adam. Don't bother with bias_correction1; # slower update at the start will help stability anyway. @@ -540,12 +534,11 @@ class ScaledAdam(BatchedOptimizer): denom = (exp_avg_sq / bias_correction2).sqrt() + eps delta = state["delta"] - delta.add_(grad / denom, alpha=-lr*(1-beta1)) + delta.add_(grad / denom, alpha=-lr * (1 - beta1)) p.clamp_(min=-scalar_max, max=scalar_max) p.add_(delta) - class LRScheduler(object): """ Base-class for learning rate schedulers where the learning-rate depends on both the @@ -555,18 +548,14 @@ class LRScheduler(object): def __init__(self, optimizer: Optimizer, verbose: bool = False): # Attach optimizer if not isinstance(optimizer, Optimizer): - raise TypeError( - "{} is not an Optimizer".format(type(optimizer).__name__) - ) + raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) self.optimizer = optimizer self.verbose = verbose for group in optimizer.param_groups: group.setdefault("base_lr", group["lr"]) - self.base_lrs = [ - group["base_lr"] for group in optimizer.param_groups - ] + self.base_lrs = [group["base_lr"] for group in optimizer.param_groups] self.epoch = 0 self.batch = 0 @@ -680,13 +669,15 @@ class Eden(LRScheduler): def get_lr(self): factor = ( - (self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2 + (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 ) ** -0.25 * ( - ((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2) - ** -0.25 + ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25 + ) + warmup_factor = ( + 1.0 + if self.batch >= self.warmup_batches + else 0.5 + 0.5 * (self.batch / self.warmup_batches) ) - warmup_factor = (1.0 if self.batch >= self.warmup_batches - else 0.5 + 0.5 * (self.batch / self.warmup_batches)) return [x * factor * warmup_factor for x in self.base_lrs] @@ -745,13 +736,14 @@ class Eve(Optimizer): parameters, if they fall below this we will stop applying weight decay. - .. _Adam\: A Method for Stochastic Optimization: + .. _Adam: A Method for Stochastic Optimization: https://arxiv.org/abs/1412.6980 .. _Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101 .. _On the Convergence of Adam and Beyond: https://openreview.net/forum?id=ryQu7f-RZ """ + def __init__( self, params, @@ -766,17 +758,11 @@ class Eve(Optimizer): if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if not 0.0 <= betas[0] < 1.0: - raise ValueError( - "Invalid beta parameter at index 0: {}".format(betas[0]) - ) + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) if not 0.0 <= betas[1] < 1.0: - raise ValueError( - "Invalid beta parameter at index 1: {}".format(betas[1]) - ) + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) if not 0 <= weight_decay <= 0.1: - raise ValueError( - "Invalid weight_decay value: {}".format(weight_decay) - ) + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) if not 0 < target_rms <= 10.0: raise ValueError("Invalid target_rms value: {}".format(target_rms)) defaults = dict( @@ -812,9 +798,7 @@ class Eve(Optimizer): # Perform optimization step grad = p.grad if grad.is_sparse: - raise RuntimeError( - "AdamW does not support sparse gradients" - ) + raise RuntimeError("AdamW does not support sparse gradients") state = self.state[p] @@ -841,7 +825,7 @@ class Eve(Optimizer): # Decay the first and second moment running average coefficient exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_( + denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_( group["eps"] ) @@ -852,30 +836,31 @@ class Eve(Optimizer): if p.numel() > 1: # avoid applying this weight-decay on "scaling factors" # (which are scalar). - is_above_target_rms = p.norm() > ( - target_rms * (p.numel() ** 0.5) - ) + is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5)) p.mul_(1 - (weight_decay * is_above_target_rms)) p.addcdiv_(exp_avg, denom, value=-step_size) if random.random() < 0.0005: - step = (exp_avg/denom) * step_size - logging.info(f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}") - + step = (exp_avg / denom) * step_size + logging.info( + f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}" + ) return loss def _test_scaled_adam(hidden_dim: int): import timeit + from scaling import ScaledLinear + E = 100 B = 4 T = 2 logging.info("in test_eve_cain") - #device = torch.device('cuda') - device = torch.device('cpu') + # device = torch.device('cuda') + device = torch.device("cpu") dtype = torch.float32 fix_random_seed(42) @@ -889,79 +874,93 @@ def _test_scaled_adam(hidden_dim: int): fix_random_seed(42) Linear = torch.nn.Linear if iter == 0 else ScaledLinear - m = torch.nn.Sequential(Linear(E, hidden_dim), - torch.nn.PReLU(), - Linear(hidden_dim, hidden_dim), - torch.nn.PReLU(), - Linear(hidden_dim, E), - ).to(device) + m = torch.nn.Sequential( + Linear(E, hidden_dim), + torch.nn.PReLU(), + Linear(hidden_dim, hidden_dim), + torch.nn.PReLU(), + Linear(hidden_dim, E), + ).to(device) - train_pairs = [ (100.0 * torch.randn(B, T, E, device=device, dtype=dtype) * input_magnitudes, - torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes) for _ in range(20) ] + train_pairs = [ + ( + 100.0 + * torch.randn(B, T, E, device=device, dtype=dtype) + * input_magnitudes, + torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes, + ) + for _ in range(20) + ] - if iter == 0: optim = Eve(m.parameters(), lr=0.003) - elif iter == 1: optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0) + if iter == 0: + optim = Eve(m.parameters(), lr=0.003) + elif iter == 1: + optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0) scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False) - start = timeit.default_timer() avg_loss = 0.0 for epoch in range(180): scheduler.step_epoch() - #if epoch == 100 and iter in [2,3]: + # if epoch == 100 and iter in [2,3]: # optim.reset_speedup() # check it doesn't crash. - #if epoch == 130: + # if epoch == 130: # opts = diagnostics.TensorDiagnosticOptions( # 2 ** 22 # ) # allow 4 megabytes per sub-module # diagnostic = diagnostics.attach_diagnostics(m, opts) - - for n, (x,y) in enumerate(train_pairs): + for n, (x, y) in enumerate(train_pairs): y_out = m(x) - loss = ((y_out - y)**2).mean() * 100.0 + loss = ((y_out - y) ** 2).mean() * 100.0 if epoch == 0 and n == 0: avg_loss = loss.item() else: avg_loss = 0.98 * avg_loss + 0.02 * loss.item() if n == 0 and epoch % 5 == 0: - #norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item() - #norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item() - #norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item() - #norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item() - #scale1 = '%.2e' % (m[0].weight_scale.exp().item()) - #scale1b = '%.2e' % (m[0].bias_scale.exp().item()) - #scale2 = '%.2e' % (m[2].weight_scale.exp().item()) - #scale2b = '%.2e' % (m[2].bias_scale.exp().item()) + # norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item() + # norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item() + # norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item() + # norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item() + # scale1 = '%.2e' % (m[0].weight_scale.exp().item()) + # scale1b = '%.2e' % (m[0].bias_scale.exp().item()) + # scale2 = '%.2e' % (m[2].weight_scale.exp().item()) + # scale2b = '%.2e' % (m[2].bias_scale.exp().item()) lr = scheduler.get_last_lr()[0] - logging.info(f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}") #, norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b} + logging.info( + f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss" + f" {avg_loss:.4g}, lr={lr:.4e}" + ) # , norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b} loss.log().backward() optim.step() optim.zero_grad() scheduler.step_batch() - #diagnostic.print_diagnostics() + # diagnostic.print_diagnostics() stop = timeit.default_timer() logging.info(f"Iter={iter}, Time taken: {stop - start}") logging.info(f"last lr = {scheduler.get_last_lr()}") - #logging.info("state dict = ", scheduler.state_dict()) - #logging.info("optim state_dict = ", optim.state_dict()) + # logging.info("state dict = ", scheduler.state_dict()) + # logging.info("optim state_dict = ", optim.state_dict()) logging.info(f"input_magnitudes = {input_magnitudes}") logging.info(f"output_magnitudes = {output_magnitudes}") - if __name__ == "__main__": torch.set_num_threads(1) torch.set_num_interop_threads(1) logging.getLogger().setLevel(logging.INFO) import subprocess - s = subprocess.check_output("git status -uno .; git log -1; git diff HEAD .", shell=True) + + s = subprocess.check_output( + "git status -uno .; git log -1; git diff HEAD .", shell=True + ) logging.info(s) import sys + if len(sys.argv) > 1: hidden_dim = int(sys.argv[1]) else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py index 7fe1e681a..8b4d88871 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py @@ -100,9 +100,11 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( @@ -127,10 +129,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -177,8 +181,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -209,10 +212,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -275,15 +277,11 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lengths - ) + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) num_waves = encoder_out.size(0) hyps = [] @@ -355,9 +353,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index 50cedba56..4040065e1 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -16,12 +16,12 @@ import collections +import logging +import random +from functools import reduce from itertools import repeat from typing import Optional, Tuple, Union -from functools import reduce -import logging -import random import torch import torch.nn as nn import torch.nn.functional as F @@ -32,27 +32,24 @@ from torch.nn import Embedding as ScaledEmbedding class ActivationBalancerFunction(torch.autograd.Function): @staticmethod def forward( - ctx, - x: Tensor, - scale_factor: Tensor, - sign_factor: Optional[Tensor], - channel_dim: int, + ctx, + x: Tensor, + scale_factor: Tensor, + sign_factor: Optional[Tensor], + channel_dim: int, ) -> Tensor: if channel_dim < 0: channel_dim += x.ndim ctx.channel_dim = channel_dim - xgt0 = (x > 0) + xgt0 = x > 0 if sign_factor is None: ctx.save_for_backward(xgt0, scale_factor) else: ctx.save_for_backward(xgt0, scale_factor, sign_factor) return x - @staticmethod - def backward( - ctx, x_grad: Tensor - ) -> Tuple[Tensor, None, None, None]: + def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]: if len(ctx.saved_tensors) == 3: xgt0, scale_factor, sign_factor = ctx.saved_tensors for _ in range(ctx.channel_dim, x_grad.ndim - 1): @@ -65,14 +62,22 @@ class ActivationBalancerFunction(torch.autograd.Function): scale_factor = scale_factor.unsqueeze(-1) factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5) neg_delta_grad = x_grad.abs() * factor - return x_grad - neg_delta_grad, None, None, None, + return ( + x_grad - neg_delta_grad, + None, + None, + None, + ) -def _compute_scale_factor(x: Tensor, - channel_dim: int, - min_abs: float, - max_abs: float, - gain_factor: float, - max_factor: float) -> Tensor: + +def _compute_scale_factor( + x: Tensor, + channel_dim: int, + min_abs: float, + max_abs: float, + gain_factor: float, + max_factor: float, +) -> Tensor: if channel_dim < 0: channel_dim += x.ndim sum_dims = [d for d in range(x.ndim) if d != channel_dim] @@ -83,71 +88,76 @@ def _compute_scale_factor(x: Tensor, else: # below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if # x_abs)_mean , min_abs. - below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(min=0, max=max_factor) + below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp( + min=0, max=max_factor + ) - above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(min=0, max=max_factor) + above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp( + min=0, max=max_factor + ) return below_threshold - above_threshold -def _compute_sign_factor(x: Tensor, - channel_dim: int, - min_positive: float, - max_positive: float, - gain_factor: float, - max_factor: float) -> Tensor: + +def _compute_sign_factor( + x: Tensor, + channel_dim: int, + min_positive: float, + max_positive: float, + gain_factor: float, + max_factor: float, +) -> Tensor: if channel_dim < 0: channel_dim += x.ndim sum_dims = [d for d in range(x.ndim) if d != channel_dim] - proportion_positive = torch.mean((x > 0).to(torch.float32), - dim=sum_dims) + proportion_positive = torch.mean((x > 0).to(torch.float32), dim=sum_dims) if min_positive == 0.0: factor1 = 0.0 else: # 0 if proportion_positive >= min_positive, else can be # as large as max_factor. - factor1 = ((min_positive - proportion_positive) * - (gain_factor / min_positive)).clamp_(min=0, max=max_factor) + factor1 = ( + (min_positive - proportion_positive) * (gain_factor / min_positive) + ).clamp_(min=0, max=max_factor) if max_positive == 1.0: factor2 = 0.0 else: # 0 if self.proportion_positive <= max_positive, else can be # as large as -max_factor. - factor2 = ((proportion_positive - max_positive) * - (gain_factor / (1.0 - max_positive))).clamp_(min=0, max=max_factor) + factor2 = ( + (proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive)) + ).clamp_(min=0, max=max_factor) sign_factor = factor1 - factor2 # require min_positive != 0 or max_positive != 1: assert not isinstance(sign_factor, float) return sign_factor - class ActivationScaleBalancerFunction(torch.autograd.Function): """ This object is used in class ActivationBalancer when the user specified min_positive=0, max_positive=1, so there are no constraints on the signs of the activations and only the absolute value has a constraint. """ + @staticmethod def forward( - ctx, - x: Tensor, - sign_factor: Tensor, - scale_factor: Tensor, - channel_dim: int, + ctx, + x: Tensor, + sign_factor: Tensor, + scale_factor: Tensor, + channel_dim: int, ) -> Tensor: if channel_dim < 0: channel_dim += x.ndim ctx.channel_dim = channel_dim - xgt0 = (x > 0) + xgt0 = x > 0 ctx.save_for_backward(xgt0, sign_factor, scale_factor) return x - @staticmethod - def backward( - ctx, x_grad: Tensor - ) -> Tuple[Tensor, None, None, None]: + def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]: xgt0, sign_factor, scale_factor = ctx.saved_tensors for _ in range(ctx.channel_dim, x_grad.ndim - 1): sign_factor = sign_factor.unsqueeze(-1) @@ -155,18 +165,24 @@ class ActivationScaleBalancerFunction(torch.autograd.Function): factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5) neg_delta_grad = x_grad.abs() * factor - return x_grad - neg_delta_grad, None, None, None, + return ( + x_grad - neg_delta_grad, + None, + None, + None, + ) class RandomClampFunction(torch.autograd.Function): @staticmethod def forward( - ctx, - x: Tensor, - min: Optional[float], - max: Optional[float], - prob: float, - reflect: float) -> Tensor: + ctx, + x: Tensor, + min: Optional[float], + max: Optional[float], + prob: float, + reflect: float, + ) -> Tensor: x_clamped = torch.clamp(x, min=min, max=max) mask = torch.rand_like(x) < prob ans = torch.where(mask, x_clamped, x) @@ -179,30 +195,32 @@ class RandomClampFunction(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None, None, None]: - is_same, = ctx.saved_tensors + (is_same,) = ctx.saved_tensors x_grad = ans_grad * is_same.to(ans_grad.dtype) reflect = ctx.reflect - if reflect != 0.0: + if reflect != 0.0: x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect) return x_grad, None, None, None, None -def random_clamp(x: Tensor, - min: Optional[float] = None, - max: Optional[float] = None, - prob: float = 0.5, - reflect: float = 0.0): + +def random_clamp( + x: Tensor, + min: Optional[float] = None, + max: Optional[float] = None, + prob: float = 0.5, + reflect: float = 0.0, +): return RandomClampFunction.apply(x, min, max, prob, reflect) -def random_cast_to_half(x: Tensor, - min_abs: float = 5.0e-06) -> Tensor: +def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor: """ A randomized way of casting a floating point value to half precision. """ if x.dtype == torch.float16: return x x_abs = x.abs() - is_too_small = (x_abs < min_abs) + is_too_small = x_abs < min_abs # for elements where is_too_small is true, random_val will contain +-min_abs with # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations, # for those elements]. @@ -215,6 +233,7 @@ class RandomGradFunction(torch.autograd.Function): Does nothing in forward pass; in backward pass, gets rid of very small grads using randomized approach that preserves expectations (intended to reduce roundoff). """ + @staticmethod def forward(ctx, x: Tensor, min_abs: float) -> Tensor: ctx.min_abs = min_abs @@ -223,35 +242,37 @@ class RandomGradFunction(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]: if ans_grad.dtype == torch.float16: - return random_cast_to_half(ans_grad.to(torch.float32), - min_abs=ctx.min_abs), None + return ( + random_cast_to_half(ans_grad.to(torch.float32), min_abs=ctx.min_abs), + None, + ) else: return ans_grad, None + class RandomGrad(torch.nn.Module): """ Gets rid of very small gradients using an expectation-preserving method, intended to increase accuracy of training when using amp (automatic mixed precision) """ - def __init__(self, - min_abs: float = 5.0e-06): + + def __init__(self, min_abs: float = 5.0e-06): super(RandomGrad, self).__init__() self.min_abs = min_abs - def forward(self, - x: Tensor): + def forward(self, x: Tensor): if torch.jit.is_scripting() or not self.training: return x else: return RandomGradFunction.apply(x, self.min_abs) - class SoftmaxFunction(torch.autograd.Function): """ Tries to handle half-precision derivatives in a randomized way that should be more accurate for training than the default behavior. """ + @staticmethod def forward(ctx, x: Tensor, dim: int): ans = x.softmax(dim=dim) @@ -267,7 +288,7 @@ class SoftmaxFunction(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor): - ans, = ctx.saved_tensors + (ans,) = ctx.saved_tensors with torch.cuda.amp.autocast(enabled=False): ans_grad = ans_grad.to(torch.float32) ans = ans.to(torch.float32) @@ -276,9 +297,7 @@ class SoftmaxFunction(torch.autograd.Function): return x_grad, None - -def softmax(x: Tensor, - dim: int): +def softmax(x: Tensor, dim: int): if torch.jit.is_scripting(): return x.softmax(dim) @@ -288,20 +307,18 @@ def softmax(x: Tensor, class MaxEigLimiterFunction(torch.autograd.Function): @staticmethod def forward( - ctx, - x: Tensor, - coeffs: Tensor, - direction: Tensor, - channel_dim: int, - grad_scale: float) -> Tensor: + ctx, + x: Tensor, + coeffs: Tensor, + direction: Tensor, + channel_dim: int, + grad_scale: float, + ) -> Tensor: ctx.channel_dim = channel_dim ctx.grad_scale = grad_scale - ctx.save_for_backward(x.detach(), - coeffs.detach(), - direction.detach()) + ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach()) return x - @staticmethod def backward(ctx, x_grad, *args): with torch.enable_grad(): @@ -311,15 +328,20 @@ class MaxEigLimiterFunction(torch.autograd.Function): x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels) new_direction.requires_grad = False x = x - x.mean(dim=0) - x_var = (x ** 2).mean() + x_var = (x**2).mean() x_residual = x - coeffs * new_direction - x_residual_var = (x_residual ** 2).mean() + x_residual_var = (x_residual**2).mean() # `variance_proportion` is the proportion of the variance accounted for # by the top eigen-direction. This is to be minimized. variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20) variance_proportion.backward() x_orig_grad = x_orig.grad - x_extra_grad = x_orig.grad * ctx.grad_scale * x_grad.norm() / (x_orig_grad.norm() + 1.0e-20) + x_extra_grad = ( + x_orig.grad + * ctx.grad_scale + * x_grad.norm() + / (x_orig_grad.norm() + 1.0e-20) + ) return x_grad + x_extra_grad.detach(), None, None, None, None @@ -385,15 +407,12 @@ class BasicNorm(torch.nn.Module): # region if it happens to exit it. eps = eps.clamp(min=self.eps_min, max=self.eps_max) scales = ( - torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps.exp() + torch.mean(x**2, dim=self.channel_dim, keepdim=True) + eps.exp() ) ** -0.5 return x * scales - -def ScaledLinear(*args, - initial_scale: float = 1.0, - **kwargs ) -> nn.Linear: +def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear: """ Behaves like a constructor of a modified version of nn.Linear that gives an easy way to set the default initial parameter scale. @@ -412,16 +431,11 @@ def ScaledLinear(*args, with torch.no_grad(): ans.weight[:] *= initial_scale if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, - -0.1 * initial_scale, - 0.1 * initial_scale) + torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) return ans - -def ScaledConv1d(*args, - initial_scale: float = 1.0, - **kwargs ) -> nn.Conv1d: +def ScaledConv1d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv1d: """ Behaves like a constructor of a modified version of nn.Conv1d that gives an easy way to set the default initial parameter scale. @@ -440,13 +454,10 @@ def ScaledConv1d(*args, with torch.no_grad(): ans.weight[:] *= initial_scale if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, - -0.1 * initial_scale, - 0.1 * initial_scale) + torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) return ans - class ActivationBalancer(torch.nn.Module): """ Modifies the backpropped derivatives of a function to try to encourage, for @@ -486,18 +497,19 @@ class ActivationBalancer(torch.nn.Module): from doing it at the same time. Early in training we may use higher probabilities than this; it will decay to this value. """ + def __init__( - self, - num_channels: int, - channel_dim: int, - min_positive: float = 0.05, - max_positive: float = 0.95, - max_factor: float = 0.04, - sign_gain_factor: float = 0.01, - scale_gain_factor: float = 0.02, - min_abs: float = 0.2, - max_abs: float = 100.0, - min_prob: float = 0.1, + self, + num_channels: int, + channel_dim: int, + min_positive: float = 0.05, + max_positive: float = 0.95, + max_factor: float = 0.04, + sign_gain_factor: float = 0.01, + scale_gain_factor: float = 0.02, + min_abs: float = 0.2, + max_abs: float = 100.0, + min_prob: float = 0.1, ): super(ActivationBalancer, self).__init__() self.num_channels = num_channels @@ -515,9 +527,7 @@ class ActivationBalancer(torch.nn.Module): # We occasionally sync this to a tensor called `count`, that exists to # make sure it is synced to disk when we load and save the model. self.cpu_count = 0 - self.register_buffer('count', torch.tensor(0, dtype=torch.int64)) - - + self.register_buffer("count", torch.tensor(0, dtype=torch.int64)) def forward(self, x: Tensor) -> Tensor: if torch.jit.is_scripting() or not x.requires_grad: @@ -535,26 +545,35 @@ class ActivationBalancer(torch.nn.Module): # the prob of doing some work exponentially decreases from 0.5 till it hits # a floor at min_prob (==0.1, by default) - prob = max(self.min_prob, 0.5 ** (1 + (count/4000.0))) + prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0))) if random.random() < prob: sign_gain_factor = 0.5 if self.min_positive != 0.0 or self.max_positive != 1.0: - sign_factor = _compute_sign_factor(x, self.channel_dim, - self.min_positive, self.max_positive, - gain_factor=self.sign_gain_factor / prob, - max_factor=self.max_factor) + sign_factor = _compute_sign_factor( + x, + self.channel_dim, + self.min_positive, + self.max_positive, + gain_factor=self.sign_gain_factor / prob, + max_factor=self.max_factor, + ) else: sign_factor = None - - scale_factor = _compute_scale_factor(x, self.channel_dim, - min_abs=self.min_abs, - max_abs=self.max_abs, - gain_factor=self.scale_gain_factor / prob, - max_factor=self.max_factor) + scale_factor = _compute_scale_factor( + x, + self.channel_dim, + min_abs=self.min_abs, + max_abs=self.max_abs, + gain_factor=self.scale_gain_factor / prob, + max_factor=self.max_factor, + ) return ActivationBalancerFunction.apply( - x, scale_factor, sign_factor, self.channel_dim, + x, + scale_factor, + sign_factor, + self.channel_dim, ) else: return _no_op(x) @@ -594,13 +613,12 @@ def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims. else: (batch, dim, dim) = x.shape x = x.reshape(batch, dim * dim) - x = x[:, ::dim+1] + x = x[:, :: dim + 1] assert x.shape == (batch, dim) return x -def _whitening_metric(x: Tensor, - num_groups: int): +def _whitening_metric(x: Tensor, num_groups: int): """ Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of of the centered feature covariance are the same within each group's covariance matrix @@ -630,19 +648,21 @@ def _whitening_metric(x: Tensor, # the following expression is what we'd get if we took the matrix product # of each covariance and measured the mean of its trace, i.e. # the same as _diag(torch.matmul(x_covar, x_covar)).mean(). - x_covarsq_mean_diag = (x_covar ** 2).sum() / (num_groups * channels_per_group) + x_covarsq_mean_diag = (x_covar**2).sum() / (num_groups * channels_per_group) # this metric will be >= 1.0; the larger it is, the less 'white' the data was. - metric = x_covarsq_mean_diag / (x_covar_mean_diag ** 2 + 1.0e-20) + metric = x_covarsq_mean_diag / (x_covar_mean_diag**2 + 1.0e-20) return metric class WhiteningPenaltyFunction(torch.autograd.Function): @staticmethod - def forward(ctx, - x: Tensor, - num_groups: int, - whitening_limit: float, - grad_scale: float) -> Tensor: + def forward( + ctx, + x: Tensor, + num_groups: int, + whitening_limit: float, + grad_scale: float, + ) -> Tensor: ctx.save_for_backward(x) ctx.num_groups = num_groups ctx.whitening_limit = whitening_limit @@ -650,9 +670,8 @@ class WhiteningPenaltyFunction(torch.autograd.Function): return x @staticmethod - def backward(ctx, - x_grad: Tensor): - x_orig, = ctx.saved_tensors + def backward(ctx, x_grad: Tensor): + (x_orig,) = ctx.saved_tensors with torch.enable_grad(): with torch.cuda.amp.autocast(enabled=False): x_detached = x_orig.to(torch.float32).detach() @@ -661,25 +680,29 @@ class WhiteningPenaltyFunction(torch.autograd.Function): metric = _whitening_metric(x_detached, ctx.num_groups) if random.random() < 0.005 or __name__ == "__main__": - logging.info(f"Whitening: num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, " - f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}") + logging.info( + f"Whitening: num_groups={ctx.num_groups}," + f" num_channels={x_orig.shape[-1]}," + f" metric={metric.item():.2f} vs. limit={ctx.whitening_limit}" + ) (metric - ctx.whitening_limit).relu().backward() penalty_grad = x_detached.grad - scale = ctx.grad_scale * (x_grad.to(torch.float32).norm() / - (penalty_grad.norm() + 1.0e-20)) + scale = ctx.grad_scale * ( + x_grad.to(torch.float32).norm() / (penalty_grad.norm() + 1.0e-20) + ) penalty_grad = penalty_grad * scale return x_grad + penalty_grad.to(x_grad.dtype), None, None, None - class Whiten(nn.Module): def __init__( - self, - num_groups: int, - whitening_limit: float, - prob: Union[float, Tuple[float,float]], - grad_scale: float): + self, + num_groups: int, + whitening_limit: float, + prob: Union[float, Tuple[float, float]], + grad_scale: float, + ): """ Args: num_groups: the number of groups to divide the channel dim into before @@ -714,8 +737,7 @@ class Whiten(nn.Module): self.grad_scale = grad_scale - def forward(self, - x: Tensor) -> Tensor: + def forward(self, x: Tensor) -> Tensor: """ In the forward pass, this function just returns the input unmodified. In the backward pass, it will modify the gradients to ensure that the @@ -735,19 +757,21 @@ class Whiten(nn.Module): if not x.requires_grad or random.random() > self.prob or self.grad_scale == 0: return _no_op(x) else: - if hasattr(self, 'min_prob') and random.random() < 0.25: + if hasattr(self, "min_prob") and random.random() < 0.25: # occasionally switch between min_prob and max_prob, based on whether # we are above or below the threshold. - if _whitening_metric(x.to(torch.float32), self.num_groups) > self.whitening_limit: + if ( + _whitening_metric(x.to(torch.float32), self.num_groups) + > self.whitening_limit + ): # there would be a change to the grad. self.prob = self.max_prob else: self.prob = self.min_prob - return WhiteningPenaltyFunction.apply(x, - self.num_groups, - self.whitening_limit, - self.grad_scale) + return WhiteningPenaltyFunction.apply( + x, self.num_groups, self.whitening_limit, self.grad_scale + ) class WithLoss(torch.autograd.Function): @@ -755,11 +779,14 @@ class WithLoss(torch.autograd.Function): def forward(ctx, x: Tensor, y: Tensor): ctx.y_shape = y.shape return x + @staticmethod def backward(ctx, ans_grad: Tensor): - return ans_grad, torch.ones(ctx.y_shape, - dtype=ans_grad.dtype, - device=ans_grad.device) + return ans_grad, torch.ones( + ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device + ) + + def with_loss(x, y): if torch.jit.is_scripting(): return x @@ -768,7 +795,7 @@ def with_loss(x, y): def _no_op(x: Tensor) -> Tensor: - if (torch.jit.is_scripting()): + if torch.jit.is_scripting(): return x else: # a no-op function that will have a node in the autograd graph, @@ -783,6 +810,7 @@ class Identity(torch.nn.Module): def forward(self, x): return _no_op(x) + class MaxEig(torch.nn.Module): """ Modifies the backpropped derivatives of a function to try to discourage @@ -803,13 +831,14 @@ class MaxEig(torch.nn.Module): scale: determines the scale with which we modify the gradients, relative to the existing / unmodified gradients """ + def __init__( - self, - num_channels: int, - channel_dim: int, - max_var_per_eig: float = 0.2, - min_prob: float = 0.01, - scale: float = 0.01, + self, + num_channels: int, + channel_dim: int, + max_var_per_eig: float = 0.2, + min_prob: float = 0.01, + scale: float = 0.01, ): super(MaxEig, self).__init__() self.num_channels = num_channels @@ -825,7 +854,7 @@ class MaxEig(torch.nn.Module): # random parameters unchanged for comparison direction = torch.arange(num_channels).to(torch.float) direction = direction / direction.norm() - self.register_buffer('max_eig_direction', direction) + self.register_buffer("max_eig_direction", direction) self.min_prob = min_prob # cur_prob is the current probability we'll use to apply the ActivationBalancer. @@ -833,12 +862,12 @@ class MaxEig(torch.nn.Module): # active. self.cur_prob = 1.0 - - def forward(self, x: Tensor) -> Tensor: - if (torch.jit.is_scripting() or - self.max_var_per_eig <= 0 or - random.random() > self.cur_prob): + if ( + torch.jit.is_scripting() + or self.max_var_per_eig <= 0 + or random.random() > self.cur_prob + ): return _no_op(x) with torch.cuda.amp.autocast(enabled=False): @@ -848,7 +877,9 @@ class MaxEig(torch.nn.Module): with torch.no_grad(): x = x.transpose(self.channel_dim, -1).reshape(-1, self.num_channels) x = x - x.mean(dim=0) - new_direction, coeffs = self._find_direction_coeffs(x, self.max_eig_direction) + new_direction, coeffs = self._find_direction_coeffs( + x, self.max_eig_direction + ) x_var = (x**2).mean() x_residual = x - coeffs * new_direction x_residual_var = (x_residual**2).mean() @@ -861,7 +892,10 @@ class MaxEig(torch.nn.Module): self._set_direction(0.1 * self.max_eig_direction + new_direction) if random.random() < 0.01 or __name__ == "__main__": - logging.info(f"variance_proportion = {variance_proportion.item()}, shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}") + logging.info( + f"variance_proportion = {variance_proportion.item()}," + f" shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}" + ) if variance_proportion >= self.max_var_per_eig: # The constraint is active. Note, we should quite rarely @@ -869,17 +903,16 @@ class MaxEig(torch.nn.Module): # starting to diverge, should this constraint be active. cur_prob = self.cur_prob self.cur_prob = 1.0 # next time, do the update with probability 1.0. - return MaxEigLimiterFunction.apply(orig_x, coeffs, new_direction, - self.channel_dim, self.scale) + return MaxEigLimiterFunction.apply( + orig_x, coeffs, new_direction, self.channel_dim, self.scale + ) else: # let self.cur_prob exponentially approach self.min_prob, as # long as the constraint is inactive. self.cur_prob = 0.75 * self.cur_prob + 0.25 * self.min_prob return orig_x - - def _set_direction(self, - direction: Tensor): + def _set_direction(self, direction: Tensor): """ Sets self.max_eig_direction to a normalized version of `direction` """ @@ -889,40 +922,39 @@ class MaxEig(torch.nn.Module): if direction_sum - direction_sum == 0: # no inf/nan self.max_eig_direction[:] = direction else: - logging.info(f"Warning: sum of direction in MaxEig is {direction_sum}, " - "num_channels={self.num_channels}, channel_dim={self.channel_dim}") + logging.info( + f"Warning: sum of direction in MaxEig is {direction_sum}, " + "num_channels={self.num_channels}, channel_dim={self.channel_dim}" + ) - - def _find_direction_coeffs(self, - x: Tensor, - prev_direction: Tensor) -> Tuple[Tensor, Tensor, Tensor]: + def _find_direction_coeffs( + self, x: Tensor, prev_direction: Tensor + ) -> Tuple[Tensor, Tensor, Tensor]: """ - Figure out (an approximation to) the proportion of the variance of a set of - feature vectors that can be attributed to the top eigen-direction. - Args: - x: a Tensor of shape (num_frames, num_channels), with num_frames > 1. - prev_direction: a Tensor of shape (num_channels,), that is our previous estimate - of the top eigen-direction, or a random direction if this is the first - iteration. Does not have to be normalized, but should be nonzero. + Figure out (an approximation to) the proportion of the variance of a set of + feature vectors that can be attributed to the top eigen-direction. + Args: + x: a Tensor of shape (num_frames, num_channels), with num_frames > 1. + prev_direction: a Tensor of shape (num_channels,), that is our previous estimate + of the top eigen-direction, or a random direction if this is the first + iteration. Does not have to be normalized, but should be nonzero. - Returns: (cur_direction, coeffs), where: - cur_direction: a Tensor of shape (num_channels,) that is the current - estimate of the top eigen-direction. - coeffs: a Tensor of shape (num_frames, 1) that minimizes, or - approximately minimizes, (x - coeffs * cur_direction).norm() - """ + Returns: (cur_direction, coeffs), where: + cur_direction: a Tensor of shape (num_channels,) that is the current + estimate of the top eigen-direction. + coeffs: a Tensor of shape (num_frames, 1) that minimizes, or + approximately minimizes, (x - coeffs * cur_direction).norm() + """ (num_frames, num_channels) = x.shape assert num_channels > 1 and num_frames > 1 assert prev_direction.shape == (num_channels,) # `coeffs` are the coefficients of `prev_direction` in x. # actually represent the coeffs up to a constant positive factor. coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10 - cur_direction = (x * coeffs).sum(dim=0) / ((coeffs ** 2).sum() + 1.0e-20) + cur_direction = (x * coeffs).sum(dim=0) / ((coeffs**2).sum() + 1.0e-20) return cur_direction, coeffs - - class DoubleSwishFunction(torch.autograd.Function): """ double_swish(x) = x * torch.sigmoid(x-1) @@ -950,7 +982,7 @@ class DoubleSwishFunction(torch.autograd.Function): y = x * s if requires_grad: - deriv = (y * (1 - s) + s) + deriv = y * (1 - s) + s # notes on derivative of x * sigmoid(x - 1): # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29 # min \simeq -0.043638. Take floor as -0.043637 so it's a lower bund @@ -959,7 +991,9 @@ class DoubleSwishFunction(torch.autograd.Function): # floors), should be expectation-preserving. floor = -0.043637 ceil = 1.2 - d_scaled = ((deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(deriv)) + d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like( + deriv + ) if __name__ == "__main__": # for self-testing only. assert d_scaled.min() >= 0.0 @@ -972,12 +1006,12 @@ class DoubleSwishFunction(torch.autograd.Function): @staticmethod def backward(ctx, y_grad: Tensor) -> Tensor: - d, = ctx.saved_tensors + (d,) = ctx.saved_tensors # the same constants as used in forward pass. floor = -0.043637 ceil = 1.2 - d = (d * ((ceil - floor) / 255.0) + floor) - return (y_grad * d) + d = d * ((ceil - floor) / 255.0) + floor + return y_grad * d class DoubleSwish(torch.nn.Module): @@ -990,7 +1024,6 @@ class DoubleSwish(torch.nn.Module): return DoubleSwishFunction.apply(x) - def _test_max_eig(): for proportion in [0.1, 0.5, 10.0]: logging.info(f"proportion = {proportion}") @@ -1002,11 +1035,9 @@ def _test_max_eig(): x.requires_grad = True num_channels = 128 - m = MaxEig(num_channels, - 1, # channel_dim - 0.5, # max_var_per_eig - scale=0.1) # grad_scale - + m = MaxEig( + num_channels, 1, 0.5, scale=0.1 # channel_dim # max_var_per_eig + ) # grad_scale for _ in range(4): y = m(x) @@ -1031,11 +1062,9 @@ def _test_whiten(): x.requires_grad = True num_channels = 128 - m = Whiten(1, # num_groups - 5.0, # whitening_limit, - prob=1.0, - grad_scale=0.1) # grad_scale - + m = Whiten( + 1, 5.0, prob=1.0, grad_scale=0.1 # num_groups # whitening_limit, + ) # grad_scale for _ in range(4): y = m(x) @@ -1049,7 +1078,6 @@ def _test_whiten(): assert not torch.allclose(x.grad, y_grad) - def _test_activation_balancer_sign(): probs = torch.arange(0, 1, 0.01) N = 1000 @@ -1077,9 +1105,7 @@ def _test_activation_balancer_sign(): def _test_activation_balancer_magnitude(): magnitudes = torch.arange(0, 1, 0.01) N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze( - -1 - ) + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) x = x.detach() x.requires_grad = True m = ActivationBalancer( @@ -1111,8 +1137,8 @@ def _test_basic_norm(): y = m(x) assert y.shape == x.shape - x_rms = (x ** 2).mean().sqrt() - y_rms = (y ** 2).mean().sqrt() + x_rms = (x**2).mean().sqrt() + y_rms = (y**2).mean().sqrt() print("x rms = ", x_rms) print("y rms = ", y_rms) assert y_rms < x_rms @@ -1124,30 +1150,27 @@ def _test_double_swish_deriv(): x.requires_grad = True m = DoubleSwish() - tol = ((1.2-(-0.043637))/255.0) + tol = (1.2 - (-0.043637)) / 255.0 torch.autograd.gradcheck(m, x, atol=tol) - # for self-test. x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 x.requires_grad = True y = m(x) - def _test_softmax(): a = torch.randn(2, 10, dtype=torch.float64) b = a.clone() a.requires_grad = True b.requires_grad = True - a.softmax(dim=1)[:,0].sum().backward() + a.softmax(dim=1)[:, 0].sum().backward() print("a grad = ", a.grad) - softmax(b, dim=1)[:,0].sum().backward() + softmax(b, dim=1)[:, 0].sum().backward() print("b grad = ", b.grad) assert torch.allclose(a.grad, b.grad) - if __name__ == "__main__": logging.getLogger().setLevel(logging.INFO) torch.set_num_threads(1) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py index 8d357b15f..46e775285 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py @@ -26,11 +26,7 @@ from typing import List import torch import torch.nn as nn -from scaling import ( - ActivationBalancer, - BasicNorm, - Whiten, -) +from scaling import ActivationBalancer, BasicNorm, Whiten class NonScaledNorm(nn.Module): @@ -75,12 +71,10 @@ def get_submodule(model, target): mod: torch.nn.Module = model for item in atoms: if not hasattr(mod, item): - raise AttributeError( - mod._get_name() + " has no " "attribute `" + item + "`" - ) + raise AttributeError(mod._get_name() + " has no attribute `" + item + "`") mod = getattr(mod, item) if not isinstance(mod, torch.nn.Module): - raise AttributeError("`" + item + "` is not " "an nn.Module") + raise AttributeError("`" + item + "` is not an nn.Module") return mod diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index 3f27736b3..7f9526104 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -84,9 +84,7 @@ from icefall.env import get_env_info from icefall.hooks import register_inf_check_hooks from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: @@ -124,7 +122,10 @@ def add_model_arguments(parser: argparse.ArgumentParser): "--encoder-dims", type=str, default="384,384,384,384,384", - help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", + help=( + "Embedding dimension in the 2 blocks of zipformer encoder layers, comma" + " separated" + ), ) parser.add_argument( @@ -139,9 +140,11 @@ def add_model_arguments(parser: argparse.ArgumentParser): "--encoder-unmasked-dims", type=str, default="256,256,256,256,256", - help="Unmasked dimensions in the encoders, relates to augmentation during training. " - "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " - " worse.", + help=( + "Unmasked dimensions in the encoders, relates to augmentation during" + " training. Must be <= each of encoder_dims. Empirically, less than 256" + " seems to make performance worse." + ), ) parser.add_argument( @@ -269,42 +272,45 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." + ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -646,11 +652,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -697,9 +699,7 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -870,9 +870,7 @@ def train_one_epoch( # of the grad scaler is configurable, but we can't configure it to have different # behavior depending on the current grad scale. cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 1.0 or ( - cur_grad_scale < 8.0 and batch_idx % 400 == 0 - ): + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): scaler.update(cur_grad_scale * 2.0) if cur_grad_scale < 0.01: logging.warning(f"Grad scale is small: {cur_grad_scale}") @@ -890,11 +888,7 @@ def train_one_epoch( f"batch {batch_idx}, loss[{loss_info}], " f"tot_loss[{tot_loss}], batch size: {batch_size}, " f"lr: {cur_lr:.2e}, " - + ( - f"grad_scale: {scaler._scale.item()}" - if params.use_fp16 - else "" - ) + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") ) if tb_writer is not None: @@ -905,9 +899,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if params.use_fp16: tb_writer.add_scalar( "train/grad_scale", @@ -915,10 +907,7 @@ def train_one_epoch( params.batch_idx_train, ) - if ( - batch_idx % params.valid_interval == 0 - and not params.print_diagnostics - ): + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: logging.info("Computing validation loss") valid_info = compute_validation_loss( params=params, @@ -930,7 +919,8 @@ def train_one_epoch( model.train() logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + "Maximum memory allocated so far is" + f" {torch.cuda.max_memory_allocated()//1000000}MB" ) if tb_writer is not None: valid_info.write_summary( @@ -1009,9 +999,7 @@ def run(rank, world_size, args): logging.info("Using DDP") model = DDP(model, device_ids=[rank], find_unused_parameters=True) - optimizer = ScaledAdam( - model.parameters(), lr=params.base_lr, clipping_scale=2.0 - ) + optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=2.0) scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) @@ -1029,7 +1017,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -1054,8 +1042,7 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" ) return False @@ -1229,7 +1216,8 @@ def scan_pessimistic_batches_for_oom( display_and_save_batch(batch, params=params, sp=sp) raise logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + "Maximum memory allocated so far is" + f" {torch.cuda.max_memory_allocated()//1000000}MB" ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index 023dec97d..fcd9858cd 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -16,32 +16,35 @@ # limitations under the License. import copy -import math -import warnings import itertools -from typing import List, Optional, Tuple, Union import logging -import torch +import math import random +import warnings +from typing import List, Optional, Tuple, Union + +import torch from encoder_interface import EncoderInterface +from scaling import ( + ScaledLinear, # not as in other dirs.. just scales down initial parameter values. +) from scaling import ( ActivationBalancer, BasicNorm, - MaxEig, DoubleSwish, - ScaledConv1d, - ScaledLinear, # not as in other dirs.. just scales down initial parameter values. - Whiten, Identity, + MaxEig, + ScaledConv1d, + Whiten, _diag, - random_clamp, penalize_abs_values_gt, + random_clamp, softmax, ) from torch import Tensor, nn -from icefall.utils import make_pad_mask from icefall.dist import get_rank +from icefall.utils import make_pad_mask class Zipformer(EncoderInterface): @@ -89,7 +92,7 @@ class Zipformer(EncoderInterface): self.batch_count = 0 self.warmup_end = warmup_batches - for u,d in zip(encoder_unmasked_dims, encoder_dims): + for u, d in zip(encoder_unmasked_dims, encoder_dims): assert u <= d, (u, d) # self.encoder_embed converts the input of shape (N, T, num_features) @@ -97,9 +100,9 @@ class Zipformer(EncoderInterface): # That is, it does two things simultaneously: # (1) subsampling: T -> (T - 7)//2 # (2) embedding: num_features -> encoder_dims - self.encoder_embed = Conv2dSubsampling(num_features, encoder_dims[0], - dropout=dropout) - + self.encoder_embed = Conv2dSubsampling( + num_features, encoder_dims[0], dropout=dropout + ) # each one will be ZipformerEncoder or DownsampledZipformerEncoder encoders = [] @@ -123,13 +126,13 @@ class Zipformer(EncoderInterface): num_encoder_layers[i], dropout, warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1), - warmup_end=warmup_batches * (i + 2) / (num_encoders + 1) + warmup_end=warmup_batches * (i + 2) / (num_encoders + 1), ) if zipformer_downsampling_factors[i] != 1: encoder = DownsampledZipformerEncoder( encoder, - input_dim=encoder_dims[i-1] if i > 0 else encoder_dims[0], + input_dim=encoder_dims[i - 1] if i > 0 else encoder_dims[0], output_dim=encoder_dims[i], downsample=zipformer_downsampling_factors[i], ) @@ -139,10 +142,11 @@ class Zipformer(EncoderInterface): # initializes self.skip_layers and self.skip_modules self._init_skip_modules() - self.downsample_output = AttentionDownsample(encoder_dims[-1], - encoder_dims[-1], - downsample=output_downsampling_factor) - + self.downsample_output = AttentionDownsample( + encoder_dims[-1], + encoder_dims[-1], + downsample=output_downsampling_factor, + ) def _get_layer_skip_dropout_prob(self): if not self.training: @@ -166,27 +170,33 @@ class Zipformer(EncoderInterface): skip_modules = [] z = self.zipformer_downsampling_factors for i in range(len(z)): - if i <= 1 or z[i-1] <= z[i]: + if i <= 1 or z[i - 1] <= z[i]: skip_layers.append(None) skip_modules.append(SimpleCombinerIdentity()) else: # TEMP - for j in range(i-2, -1, -1): + for j in range(i - 2, -1, -1): if z[j] <= z[i] or j == 0: # TEMP logging statement. - logging.info(f"At encoder stack {i}, which has downsampling_factor={z[i]}, we will " - f"combine the outputs of layers {j} and {i-1}, with downsampling_factors={z[j]} and {z[i-1]}.") + logging.info( + f"At encoder stack {i}, which has" + f" downsampling_factor={z[i]}, we will combine the outputs" + f" of layers {j} and {i-1}, with" + f" downsampling_factors={z[j]} and {z[i-1]}." + ) skip_layers.append(j) - skip_modules.append(SimpleCombiner(self.encoder_dims[j], - self.encoder_dims[i-1], - min_weight=(0.0,0.25))) + skip_modules.append( + SimpleCombiner( + self.encoder_dims[j], + self.encoder_dims[i - 1], + min_weight=(0.0, 0.25), + ) + ) break self.skip_layers = skip_layers self.skip_modules = nn.ModuleList(skip_modules) - def get_feature_masks( - self, - x: torch.Tensor) -> List[float]: + def get_feature_masks(self, x: torch.Tensor) -> List[float]: # Note: The actual return type is Union[List[float], List[Tensor]], # but to make torch.jit.script() work, we use List[float] """ @@ -206,46 +216,56 @@ class Zipformer(EncoderInterface): """ num_encoders = len(self.encoder_dims) if torch.jit.is_scripting() or not self.training: - return [ 1.0 ] * num_encoders + return [1.0] * num_encoders (num_frames0, batch_size, _encoder_dims0) = x.shape - - assert self.encoder_dims[0] == _encoder_dims0, (self.encoder_dims, _encoder_dims0) + assert self.encoder_dims[0] == _encoder_dims0, ( + self.encoder_dims, + _encoder_dims0, + ) max_downsampling_factor = max(self.zipformer_downsampling_factors) - num_frames_max = (num_frames0 + max_downsampling_factor - 1) - + num_frames_max = num_frames0 + max_downsampling_factor - 1 feature_mask_dropout_prob = 0.15 # frame_mask_max shape: (num_frames_max, batch_size, 1) - frame_mask_max = (torch.rand(num_frames_max, batch_size, 1, - device=x.device) > - feature_mask_dropout_prob).to(x.dtype) + frame_mask_max = ( + torch.rand(num_frames_max, batch_size, 1, device=x.device) + > feature_mask_dropout_prob + ).to(x.dtype) feature_masks = [] for i in range(num_encoders): ds = self.zipformer_downsampling_factors[i] - upsample_factor = (max_downsampling_factor // ds) + upsample_factor = max_downsampling_factor // ds - frame_mask = (frame_mask_max.unsqueeze(1).expand(num_frames_max, upsample_factor, - batch_size, 1) - .reshape(num_frames_max * upsample_factor, batch_size, 1)) + frame_mask = ( + frame_mask_max.unsqueeze(1) + .expand(num_frames_max, upsample_factor, batch_size, 1) + .reshape(num_frames_max * upsample_factor, batch_size, 1) + ) num_frames = (num_frames0 + ds - 1) // ds frame_mask = frame_mask[:num_frames] - feature_mask = torch.ones(num_frames, batch_size, self.encoder_dims[i], - dtype=x.dtype, device=x.device) + feature_mask = torch.ones( + num_frames, + batch_size, + self.encoder_dims[i], + dtype=x.dtype, + device=x.device, + ) u = self.encoder_unmasked_dims[i] feature_mask[:, :, u:] *= frame_mask feature_masks.append(feature_mask) return feature_masks - def forward( - self, x: torch.Tensor, x_lens: torch.Tensor, + self, + x: torch.Tensor, + x_lens: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: @@ -265,13 +285,19 @@ class Zipformer(EncoderInterface): x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) lengths = (x_lens - 7) >> 1 - assert x.size(0) == lengths.max().item(), (x.shape, lengths, lengths.max()) + assert x.size(0) == lengths.max().item(), ( + x.shape, + lengths, + lengths.max(), + ) mask = make_pad_mask(lengths) outputs = [] feature_masks = self.get_feature_masks(x) - for i, (module, skip_module) in enumerate(zip(self.encoders, self.skip_modules)): + for i, (module, skip_module) in enumerate( + zip(self.encoders, self.skip_modules) + ): ds = self.zipformer_downsampling_factors[i] k = self.skip_layers[i] if isinstance(k, int): @@ -280,9 +306,11 @@ class Zipformer(EncoderInterface): x = skip_module(outputs[k], x) elif (not self.training) or random.random() > layer_skip_dropout_prob: x = skip_module(outputs[k], x) - x = module(x, - feature_mask=feature_masks[i], - src_key_padding_mask=None if mask is None else mask[...,::ds]) + x = module( + x, + feature_mask=feature_masks[i], + src_key_padding_mask=None if mask is None else mask[..., ::ds], + ) outputs.append(x) x = self.downsample_output(x) @@ -312,15 +340,16 @@ class ZipformerEncoderLayer(nn.Module): >>> pos_emb = torch.rand(32, 19, 512) >>> out = encoder_layer(src, pos_emb) """ + def __init__( - self, - d_model: int, - attention_dim: int, - nhead: int, - feedforward_dim: int = 2048, - dropout: float = 0.1, - cnn_module_kernel: int = 31, - pos_dim: int = 4, + self, + d_model: int, + attention_dim: int, + nhead: int, + feedforward_dim: int = 2048, + dropout: float = 0.1, + cnn_module_kernel: int = 31, + pos_dim: int = 4, ) -> None: super(ZipformerEncoderLayer, self).__init__() @@ -330,29 +359,24 @@ class ZipformerEncoderLayer(nn.Module): self.batch_count = 0 self.self_attn = RelPositionMultiheadAttention( - d_model, attention_dim, nhead, pos_dim, dropout=0.0, + d_model, + attention_dim, + nhead, + pos_dim, + dropout=0.0, ) self.pooling = PoolingModule(d_model) - self.feed_forward1 = FeedforwardModule(d_model, - feedforward_dim, - dropout) + self.feed_forward1 = FeedforwardModule(d_model, feedforward_dim, dropout) - self.feed_forward2 = FeedforwardModule(d_model, - feedforward_dim, - dropout) + self.feed_forward2 = FeedforwardModule(d_model, feedforward_dim, dropout) - self.feed_forward3 = FeedforwardModule(d_model, - feedforward_dim, - dropout) + self.feed_forward3 = FeedforwardModule(d_model, feedforward_dim, dropout) + self.conv_module1 = ConvolutionModule(d_model, cnn_module_kernel) - self.conv_module1 = ConvolutionModule(d_model, - cnn_module_kernel) - - self.conv_module2 = ConvolutionModule(d_model, - cnn_module_kernel) + self.conv_module2 = ConvolutionModule(d_model, cnn_module_kernel) self.norm_final = BasicNorm(d_model) @@ -360,14 +384,18 @@ class ZipformerEncoderLayer(nn.Module): # try to ensure the output is close to zero-mean (or at least, zero-median). self.balancer = ActivationBalancer( - d_model, channel_dim=-1, - min_positive=0.45, max_positive=0.55, + d_model, + channel_dim=-1, + min_positive=0.45, + max_positive=0.55, max_abs=6.0, ) - self.whiten = Whiten(num_groups=1, - whitening_limit=5.0, - prob=(0.025, 0.25), - grad_scale=0.01) + self.whiten = Whiten( + num_groups=1, + whitening_limit=5.0, + prob=(0.025, 0.25), + grad_scale=0.01, + ) def get_bypass_scale(self): if torch.jit.is_scripting() or not self.training: @@ -382,8 +410,9 @@ class ZipformerEncoderLayer(nn.Module): if self.batch_count > warmup_period: clamp_min = final_clamp_min else: - clamp_min = (initial_clamp_min - - (self.batch_count / warmup_period) * (initial_clamp_min - final_clamp_min)) + clamp_min = initial_clamp_min - (self.batch_count / warmup_period) * ( + initial_clamp_min - final_clamp_min + ) return self.bypass_scale.clamp(min=clamp_min, max=1.0) def get_dynamic_dropout_rate(self): @@ -398,8 +427,9 @@ class ZipformerEncoderLayer(nn.Module): if self.batch_count > warmup_period: return final_dropout_rate else: - return (initial_dropout_rate - - (initial_dropout_rate * final_dropout_rate) * (self.batch_count / warmup_period)) + return initial_dropout_rate - ( + initial_dropout_rate * final_dropout_rate + ) * (self.batch_count / warmup_period) def forward( self, @@ -508,13 +538,14 @@ class ZipformerEncoder(nn.Module): >>> src = torch.rand(10, 32, 512) >>> out = zipformer_encoder(src) """ + def __init__( - self, - encoder_layer: nn.Module, - num_layers: int, - dropout: float, - warmup_begin: float, - warmup_end: float + self, + encoder_layer: nn.Module, + num_layers: int, + dropout: float, + warmup_begin: float, + warmup_end: float, ) -> None: super().__init__() # will be written to, see set_batch_count() Note: in inference time this @@ -528,8 +559,7 @@ class ZipformerEncoder(nn.Module): # so that we can keep this consistent across worker tasks (for efficiency). self.module_seed = torch.randint(0, 1000, ()).item() - self.encoder_pos = RelPositionalEncoding(encoder_layer.d_model, - dropout) + self.encoder_pos = RelPositionalEncoding(encoder_layer.d_model, dropout) self.layers = nn.ModuleList( [copy.deepcopy(encoder_layer) for i in range(num_layers)] @@ -538,15 +568,13 @@ class ZipformerEncoder(nn.Module): assert 0 <= warmup_begin <= warmup_end, (warmup_begin, warmup_end) - - delta = (1. / num_layers) * (warmup_end - warmup_begin) + delta = (1.0 / num_layers) * (warmup_end - warmup_begin) cur_begin = warmup_begin for i in range(num_layers): self.layers[i].warmup_begin = cur_begin cur_begin += delta self.layers[i].warmup_end = cur_begin - def get_layers_to_drop(self, rnd_seed: int): ans = set() if not self.training: @@ -579,12 +607,14 @@ class ZipformerEncoder(nn.Module): # linearly interpolate t = (batch_count - layer_warmup_begin) / layer_warmup_end assert 0.0 <= t < 1.001, t - return initial_layerdrop_prob + t * (final_layerdrop_prob - initial_layerdrop_prob) + return initial_layerdrop_prob + t * ( + final_layerdrop_prob - initial_layerdrop_prob + ) shared_rng = random.Random(batch_count + self.module_seed) independent_rng = random.Random(rnd_seed) - layerdrop_probs = [ get_layerdrop_prob(i) for i in range(num_layers) ] + layerdrop_probs = [get_layerdrop_prob(i) for i in range(num_layers)] tot = sum(layerdrop_probs) # Instead of drawing the samples independently, we first randomly decide # how many layers to drop out, using the same random number generator between @@ -604,11 +634,13 @@ class ZipformerEncoder(nn.Module): if len(ans) == num_to_drop: break if shared_rng.random() < 0.005 or __name__ == "__main__": - logging.info(f"warmup_begin={self.warmup_begin:.1f}, warmup_end={self.warmup_end:.1f}, " - f"batch_count={batch_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}") + logging.info( + f"warmup_begin={self.warmup_begin:.1f}," + f" warmup_end={self.warmup_end:.1f}, batch_count={batch_count:.1f}," + f" num_to_drop={num_to_drop}, layers_to_drop={ans}" + ) return ans - def forward( self, src: Tensor, @@ -639,7 +671,6 @@ class ZipformerEncoder(nn.Module): pos_emb = self.encoder_pos(src) output = src - if torch.jit.is_scripting(): layers_to_drop = [] else: @@ -670,28 +701,31 @@ class DownsampledZipformerEncoder(nn.Module): after convolutional downsampling, and then upsampled again at the output, and combined with the origin input, so that the output has the same shape as the input. """ - def __init__(self, - encoder: nn.Module, - input_dim: int, - output_dim: int, - downsample: int): + + def __init__( + self, + encoder: nn.Module, + input_dim: int, + output_dim: int, + downsample: int, + ): super(DownsampledZipformerEncoder, self).__init__() self.downsample_factor = downsample self.downsample = AttentionDownsample(input_dim, output_dim, downsample) self.encoder = encoder self.upsample = SimpleUpsample(output_dim, downsample) - self.out_combiner = SimpleCombiner(input_dim, - output_dim, - min_weight=(0.0, 0.25)) + self.out_combiner = SimpleCombiner( + input_dim, output_dim, min_weight=(0.0, 0.25) + ) - - def forward(self, - src: Tensor, - # Note: the type of feature_mask should be Unino[float, Tensor], - # but to make torch.jit.script() happ, we use float here - feature_mask: float = 1.0, - mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, + def forward( + self, + src: Tensor, + # Note: the type of feature_mask should be Unino[float, Tensor], + # but to make torch.jit.script() happ, we use float here + feature_mask: float = 1.0, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, ) -> Tensor: r"""Downsample, go through encoder, upsample. @@ -718,42 +752,43 @@ class DownsampledZipformerEncoder(nn.Module): src = self.downsample(src) ds = self.downsample_factor if mask is not None: - mask = mask[::ds,::ds] + mask = mask[::ds, ::ds] src = self.encoder( - src, feature_mask=feature_mask, mask=mask, src_key_padding_mask=mask, + src, + feature_mask=feature_mask, + mask=mask, + src_key_padding_mask=mask, ) src = self.upsample(src) # remove any extra frames that are not a multiple of downsample_factor - src = src[:src_orig.shape[0]] + src = src[: src_orig.shape[0]] return self.out_combiner(src_orig, src) + class AttentionDownsample(torch.nn.Module): """ Does downsampling with attention, by weighted sum, and a projection.. """ - def __init__(self, - in_channels: int, - out_channels: int, - downsample: int): + + def __init__(self, in_channels: int, out_channels: int, downsample: int): """ Require out_channels > in_channels. """ super(AttentionDownsample, self).__init__() - self.query = nn.Parameter(torch.randn(in_channels) * (in_channels ** -0.5)) + self.query = nn.Parameter(torch.randn(in_channels) * (in_channels**-0.5)) # fill in the extra dimensions with a projection of the input if out_channels > in_channels: - self.extra_proj = nn.Linear(in_channels * downsample, - out_channels - in_channels, - bias=False) + self.extra_proj = nn.Linear( + in_channels * downsample, out_channels - in_channels, bias=False + ) else: self.extra_proj = None self.downsample = downsample - def forward(self, - src: Tensor) -> Tensor: + def forward(self, src: Tensor) -> Tensor: """ x: (seq_len, batch_size, in_channels) Returns a tensor of shape @@ -767,16 +802,14 @@ class AttentionDownsample(torch.nn.Module): if seq_len != d_seq_len * ds: # right-pad src, repeating the last element. pad = d_seq_len * ds - seq_len - src_extra = src[src.shape[0]-1:].expand(pad, src.shape[1], src.shape[2]) + src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2]) src = torch.cat((src, src_extra), dim=0) assert src.shape[0] == d_seq_len * ds, (src.shape[0], d_seq_len, ds) src = src.reshape(d_seq_len, ds, batch_size, in_channels) scores = (src * self.query).sum(dim=-1, keepdim=True) - scores = penalize_abs_values_gt(scores, - limit=10.0, - penalty=1.0e-04) + scores = penalize_abs_values_gt(scores, limit=10.0, penalty=1.0e-04) weights = scores.softmax(dim=1) @@ -795,14 +828,12 @@ class SimpleUpsample(torch.nn.Module): A very simple form of upsampling that mostly just repeats the input, but also adds a position-specific bias. """ - def __init__(self, - num_channels: int, - upsample: int): + + def __init__(self, num_channels: int, upsample: int): super(SimpleUpsample, self).__init__() self.bias = nn.Parameter(torch.randn(upsample, num_channels) * 0.01) - def forward(self, - src: Tensor) -> Tensor: + def forward(self, src: Tensor) -> Tensor: """ x: (seq_len, batch_size, num_channels) Returns a tensor of shape @@ -815,6 +846,7 @@ class SimpleUpsample(torch.nn.Module): src = src.reshape(seq_len * upsample, batch_size, num_channels) return src + class SimpleCombinerIdentity(nn.Module): def __init__(self, *args, **kwargs): super().__init__() @@ -822,6 +854,7 @@ class SimpleCombinerIdentity(nn.Module): def forward(self, src1: Tensor, src2: Tensor) -> Tensor: return src1 + class SimpleCombiner(torch.nn.Module): """ A very simple way of combining 2 vectors of 2 different dims, via a @@ -831,18 +864,14 @@ class SimpleCombiner(torch.nn.Module): dim2: the dimension of the second input, e.g. 384. The output will have the same dimension as dim2. """ - def __init__(self, - dim1: int, - dim2: int, - min_weight: Tuple[float] = (0., 0.)): + + def __init__(self, dim1: int, dim2: int, min_weight: Tuple[float] = (0.0, 0.0)): super(SimpleCombiner, self).__init__() assert dim2 >= dim1, (dim2, dim1) self.weight1 = nn.Parameter(torch.zeros(())) self.min_weight = min_weight - def forward(self, - src1: Tensor, - src2: Tensor) -> Tensor: + def forward(self, src1: Tensor, src2: Tensor) -> Tensor: """ src1: (*, dim1) src2: (*, dim2) @@ -853,10 +882,14 @@ class SimpleCombiner(torch.nn.Module): weight1 = self.weight1 if not torch.jit.is_scripting(): - if self.training and random.random() < 0.25 and self.min_weight != (0., 0.): - weight1 = weight1.clamp(min=self.min_weight[0], - max=1.0-self.min_weight[1]) - + if ( + self.training + and random.random() < 0.25 + and self.min_weight != (0.0, 0.0) + ): + weight1 = weight1.clamp( + min=self.min_weight[0], max=1.0 - self.min_weight[1] + ) src1 = src1 * weight1 src2 = src2 * (1.0 - weight1) @@ -869,12 +902,9 @@ class SimpleCombiner(torch.nn.Module): else: src1 = src1[:src2_dim] - return src1 + src2 - - class RelPositionalEncoding(torch.nn.Module): """Relative positional encoding module. @@ -888,9 +918,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct a PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -905,9 +933,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(0) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vecotr and `j` means the @@ -955,7 +981,6 @@ class RelPositionalEncoding(torch.nn.Module): return self.dropout(pos_emb) - class RelPositionMultiheadAttention(nn.Module): r"""Multi-Head Attention layer with relative position encoding @@ -992,34 +1017,46 @@ class RelPositionMultiheadAttention(nn.Module): self.head_dim = attention_dim // num_heads self.pos_dim = pos_dim assert self.head_dim % 2 == 0, self.head_dim - assert ( - self.head_dim * num_heads == attention_dim - ), (self.head_dim, num_heads, attention_dim) + assert self.head_dim * num_heads == attention_dim, ( + self.head_dim, + num_heads, + attention_dim, + ) # the initial_scale is supposed to take over the "scaling" factor of # head_dim ** -0.5, dividing it between the query and key. - in_proj_dim = (2 * attention_dim + # query, key - attention_dim // 2 + # value - pos_dim * num_heads) # positional encoding query + in_proj_dim = ( + 2 * attention_dim + + attention_dim // 2 # query, key + + pos_dim * num_heads # value + ) # positional encoding query - self.in_proj = ScaledLinear(embed_dim, in_proj_dim, bias=True, - initial_scale=self.head_dim**-0.25) + self.in_proj = ScaledLinear( + embed_dim, + in_proj_dim, + bias=True, + initial_scale=self.head_dim**-0.25, + ) # self.whiten_values is applied on the values in forward(); # it just copies the keys but prevents low-rank distribution by modifying grads. - self.whiten_values = Whiten(num_groups=num_heads, - whitening_limit=2.0, - prob=(0.025, 0.25), - grad_scale=0.025) - self.whiten_keys = Whiten(num_groups=num_heads, - whitening_limit=2.0, - prob=(0.025, 0.25), - grad_scale=0.025) - + self.whiten_values = Whiten( + num_groups=num_heads, + whitening_limit=2.0, + prob=(0.025, 0.25), + grad_scale=0.025, + ) + self.whiten_keys = Whiten( + num_groups=num_heads, + whitening_limit=2.0, + prob=(0.025, 0.25), + grad_scale=0.025, + ) # linear transformation for positional encoding. - self.linear_pos = ScaledLinear(embed_dim, num_heads * pos_dim, bias=False, - initial_scale=0.05) + self.linear_pos = ScaledLinear( + embed_dim, num_heads * pos_dim, bias=False, initial_scale=0.05 + ) # the following are for diagnosics only, see --print-diagnostics option. # they only copy their inputs. @@ -1031,14 +1068,16 @@ class RelPositionMultiheadAttention(nn.Module): ) self.in_proj2 = nn.Linear(embed_dim, attention_dim // 2, bias=False) - self.out_proj2 = ScaledLinear(attention_dim // 2, embed_dim, bias=True, - initial_scale=0.05) + self.out_proj2 = ScaledLinear( + attention_dim // 2, embed_dim, bias=True, initial_scale=0.05 + ) # self.whiten_values2 is applied on the values in forward2() - self.whiten_values2 = Whiten(num_groups=num_heads, - whitening_limit=2.0, - prob=(0.025, 0.25), - grad_scale=0.025) - + self.whiten_values2 = Whiten( + num_groups=num_heads, + whitening_limit=2.0, + prob=(0.025, 0.25), + grad_scale=0.025, + ) def forward( self, @@ -1098,7 +1137,6 @@ class RelPositionMultiheadAttention(nn.Module): ) return x, weights - def multi_head_attention_forward( self, x_proj: Tensor, @@ -1156,26 +1194,24 @@ class RelPositionMultiheadAttention(nn.Module): head_dim = attention_dim // num_heads pos_dim = self.pos_dim # positional-encoding dim per head - assert ( - head_dim * num_heads == attention_dim - ), f"attention_dim must be divisible by num_heads: {head_dim}, {num_heads}, {attention_dim}" - + assert head_dim * num_heads == attention_dim, ( + f"attention_dim must be divisible by num_heads: {head_dim}, {num_heads}," + f" {attention_dim}" + ) # self-attention - q = x_proj[...,0:attention_dim] - k = x_proj[...,attention_dim:2*attention_dim] + q = x_proj[..., 0:attention_dim] + k = x_proj[..., attention_dim : 2 * attention_dim] value_dim = attention_dim // 2 - v = x_proj[...,2*attention_dim:2*attention_dim+value_dim] + v = x_proj[..., 2 * attention_dim : 2 * attention_dim + value_dim] # p is the position-encoding query, its dimension is num_heads*pos_dim.. - p = x_proj[...,2*attention_dim+value_dim:] - + p = x_proj[..., 2 * attention_dim + value_dim :] k = self.whiten_keys(k) # does nothing in the forward pass. v = self.whiten_values(v) # does nothing in the forward pass. q = self.copy_query(q) # for diagnostics only, does nothing. p = self.copy_pos_query(p) # for diagnostics only, does nothing. - if attn_mask is not None: assert ( attn_mask.dtype == torch.float32 @@ -1195,33 +1231,25 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, seq_len, seq_len]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, seq_len, seq_len, ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor" + " instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -1230,7 +1258,6 @@ class RelPositionMultiheadAttention(nn.Module): k = k.reshape(seq_len, bsz, num_heads, head_dim) v = v.reshape(seq_len, bsz * num_heads, head_dim // 2).transpose(0, 1) - if key_padding_mask is not None: assert key_padding_mask.size(0) == bsz, "{} == {}".format( key_padding_mask.size(0), bsz @@ -1239,13 +1266,10 @@ class RelPositionMultiheadAttention(nn.Module): key_padding_mask.size(1), seq_len ) - - q = q.permute(1, 2, 0, 3) # (batch, head, time1, head_dim) p = p.permute(1, 2, 0, 3) # (batch, head, time1, pos_dim) k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - seq_len2 = 2 * seq_len - 1 pos = pos.reshape(1, seq_len2, num_heads, pos_dim).permute(0, 2, 3, 1) # pos shape now: (batch, head, pos_dim, seq_len2) @@ -1256,13 +1280,16 @@ class RelPositionMultiheadAttention(nn.Module): # the following .as_strided() expression converts the last axis of pos_weights from relative # to absolute position. I don't know whether I might have got the time-offsets backwards or # not, but let this code define which way round it is supposed to be. - pos_weights = pos_weights.as_strided((bsz, num_heads, seq_len, seq_len), - (pos_weights.stride(0), - pos_weights.stride(1), - pos_weights.stride(2)-pos_weights.stride(3), - pos_weights.stride(3)), - storage_offset=pos_weights.stride(3) * (seq_len - 1)) - + pos_weights = pos_weights.as_strided( + (bsz, num_heads, seq_len, seq_len), + ( + pos_weights.stride(0), + pos_weights.stride(1), + pos_weights.stride(2) - pos_weights.stride(3), + pos_weights.stride(3), + ), + storage_offset=pos_weights.stride(3) * (seq_len - 1), + ) # caution: they are really scores at this point. attn_output_weights = torch.matmul(q, k) + pos_weights @@ -1275,10 +1302,9 @@ class RelPositionMultiheadAttention(nn.Module): # this mechanism instead of, say, a limit on entropy, because once the entropy # gets very small gradients through the softmax can become very small, and # some mechanisms like that become ineffective. - attn_output_weights = penalize_abs_values_gt(attn_output_weights, - limit=25.0, - penalty=1.0e-04) - + attn_output_weights = penalize_abs_values_gt( + attn_output_weights, limit=25.0, penalty=1.0e-04 + ) # attn_output_weights: (batch, head, time1, time2) attn_output_weights = attn_output_weights.view( @@ -1320,20 +1346,20 @@ class RelPositionMultiheadAttention(nn.Module): ) attn_output = torch.bmm(attn_output_weights, v) - assert list(attn_output.size()) == [bsz * num_heads, seq_len, - head_dim // 2] + assert list(attn_output.size()) == [ + bsz * num_heads, + seq_len, + head_dim // 2, + ] attn_output = ( attn_output.transpose(0, 1) .contiguous() .view(seq_len, bsz, attention_dim // 2) ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias - ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) return attn_output, attn_output_weights - def forward2( self, x: Tensor, @@ -1372,11 +1398,7 @@ class RelPositionMultiheadAttention(nn.Module): # returned value is of shape (seq_len, bsz, embed_dim), like x. return self.out_proj2(attn_output) - - def _print_attn_stats( - self, - attn_weights: Tensor, - attn_output: Tensor): + def _print_attn_stats(self, attn_weights: Tensor, attn_output: Tensor): # attn_weights: (batch_size * num_heads, seq_len, seq_len) # attn_output: (bsz * num_heads, seq_len, head_dim) (n, seq_len, head_dim) = attn_output.shape @@ -1387,39 +1409,50 @@ class RelPositionMultiheadAttention(nn.Module): with torch.cuda.amp.autocast(enabled=False): attn_weights = attn_weights.to(torch.float32) attn_output = attn_output.to(torch.float32) - attn_weights_entropy = -((attn_weights + 1.0e-20).log() * attn_weights).sum( - dim=-1).reshape(bsz, num_heads, seq_len).mean(dim=(0,2)) + attn_weights_entropy = ( + -((attn_weights + 1.0e-20).log() * attn_weights) + .sum(dim=-1) + .reshape(bsz, num_heads, seq_len) + .mean(dim=(0, 2)) + ) attn_output = attn_output.reshape(bsz, num_heads, seq_len, head_dim) - attn_output = attn_output.permute(1, 0, 2, 3).reshape(num_heads, bsz * seq_len, head_dim) + attn_output = attn_output.permute(1, 0, 2, 3).reshape( + num_heads, bsz * seq_len, head_dim + ) attn_output_mean = attn_output.mean(dim=1, keepdim=True) attn_output = attn_output - attn_output_mean - attn_covar = torch.matmul(attn_output.transpose(1, 2), attn_output) / (bsz * seq_len) + attn_covar = torch.matmul(attn_output.transpose(1, 2), attn_output) / ( + bsz * seq_len + ) # attn_covar: (num_heads, head_dim, head_dim) - #eigs, _ = torch.symeig(attn_covar) - #logging.info(f"attn_weights_entropy = {attn_weights_entropy}, output_eigs = {eigs}") + # eigs, _ = torch.symeig(attn_covar) + # logging.info(f"attn_weights_entropy = {attn_weights_entropy}, output_eigs = {eigs}") attn_covar = _diag(attn_covar).mean(dim=1) # (num_heads,) embed_dim = self.in_proj2.weight.shape[1] - in_proj_covar = (self.in_proj2.weight.reshape(num_heads, head_dim, embed_dim) ** 2).mean(dim=(1,2)) - out_proj_covar = (self.out_proj2.weight.reshape(embed_dim, num_heads, head_dim) ** 2).mean(dim=(0,2)) - logging.info(f"attn_weights_entropy = {attn_weights_entropy}, covar={attn_covar}, in_proj_covar={in_proj_covar}, out_proj_covar={out_proj_covar}") - - + in_proj_covar = ( + self.in_proj2.weight.reshape(num_heads, head_dim, embed_dim) ** 2 + ).mean(dim=(1, 2)) + out_proj_covar = ( + self.out_proj2.weight.reshape(embed_dim, num_heads, head_dim) ** 2 + ).mean(dim=(0, 2)) + logging.info( + f"attn_weights_entropy = {attn_weights_entropy}," + f" covar={attn_covar}, in_proj_covar={in_proj_covar}," + f" out_proj_covar={out_proj_covar}" + ) class PoolingModule(nn.Module): """ Averages the input over the time dimension and project with a square matrix. """ - def __init__(self, - d_model: int): - super().__init__() - self.proj = ScaledLinear(d_model, d_model, - initial_scale=0.1, bias=False) - def forward(self, - x: Tensor, - key_padding_mask: Optional[Tensor] = None): + def __init__(self, d_model: int): + super().__init__() + self.proj = ScaledLinear(d_model, d_model, initial_scale=0.1, bias=False) + + def forward(self, x: Tensor, key_padding_mask: Optional[Tensor] = None): """ Args: x: a Tensor of shape (T, N, C) @@ -1430,7 +1463,7 @@ class PoolingModule(nn.Module): """ if key_padding_mask is not None: pooling_mask = key_padding_mask.logical_not().to(x.dtype) # (N, T) - pooling_mask = (pooling_mask / pooling_mask.sum(dim=1, keepdim=True)) + pooling_mask = pooling_mask / pooling_mask.sum(dim=1, keepdim=True) pooling_mask = pooling_mask.transpose(0, 1).contiguous().unsqueeze(-1) # now pooling_mask: (T, N, 1) x = (x * pooling_mask).sum(dim=0, keepdim=True) @@ -1444,24 +1477,19 @@ class PoolingModule(nn.Module): class FeedforwardModule(nn.Module): - """Feedforward module in Zipformer model. - """ - def __init__(self, - d_model: int, - feedforward_dim: int, - dropout: float): + """Feedforward module in Zipformer model.""" + + def __init__(self, d_model: int, feedforward_dim: int, dropout: float): super(FeedforwardModule, self).__init__() self.in_proj = nn.Linear(d_model, feedforward_dim) - self.balancer = ActivationBalancer(feedforward_dim, - channel_dim=-1, max_abs=10.0, - min_prob=0.25) + self.balancer = ActivationBalancer( + feedforward_dim, channel_dim=-1, max_abs=10.0, min_prob=0.25 + ) self.activation = DoubleSwish() self.dropout = nn.Dropout(dropout) - self.out_proj = ScaledLinear(feedforward_dim, d_model, - initial_scale=0.01) + self.out_proj = ScaledLinear(feedforward_dim, d_model, initial_scale=0.01) - def forward(self, - x: Tensor): + def forward(self, x: Tensor): x = self.in_proj(x) x = self.balancer(x) x = self.activation(x) @@ -1481,9 +1509,7 @@ class ConvolutionModule(nn.Module): """ - def __init__( - self, channels: int, kernel_size: int, bias: bool = True - ) -> None: + def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding @@ -1513,7 +1539,10 @@ class ConvolutionModule(nn.Module): # the correct range. self.deriv_balancer1 = ActivationBalancer( 2 * channels, - channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0 + channel_dim=1, + max_abs=10.0, + min_positive=0.05, + max_positive=1.0, ) self.depthwise_conv = nn.Conv1d( @@ -1527,8 +1556,10 @@ class ConvolutionModule(nn.Module): ) self.deriv_balancer2 = ActivationBalancer( - channels, channel_dim=1, - min_positive=0.05, max_positive=1.0, + channels, + channel_dim=1, + min_positive=0.05, + max_positive=1.0, max_abs=20.0, ) @@ -1544,9 +1575,10 @@ class ConvolutionModule(nn.Module): initial_scale=0.05, ) - def forward(self, - x: Tensor, - src_key_padding_mask: Optional[Tensor] = None, + def forward( + self, + x: Tensor, + src_key_padding_mask: Optional[Tensor] = None, ) -> Tensor: """Compute convolution module. @@ -1626,8 +1658,7 @@ class Conv2dSubsampling(nn.Module): kernel_size=3, padding=(0, 1), # (time, freq) ), - ActivationBalancer(layer1_channels, - channel_dim=1), + ActivationBalancer(layer1_channels, channel_dim=1), DoubleSwish(), nn.Conv2d( in_channels=layer1_channels, @@ -1636,24 +1667,21 @@ class Conv2dSubsampling(nn.Module): stride=2, padding=0, ), - ActivationBalancer(layer2_channels, - channel_dim=1), + ActivationBalancer(layer2_channels, channel_dim=1), DoubleSwish(), nn.Conv2d( in_channels=layer2_channels, out_channels=layer3_channels, kernel_size=3, - stride=(1, 2), # (time, freq) + stride=(1, 2), # (time, freq) ), - ActivationBalancer(layer3_channels, - channel_dim=1), + ActivationBalancer(layer3_channels, channel_dim=1), DoubleSwish(), ) out_height = (((in_channels - 1) // 2) - 1) // 2 self.out = ScaledLinear(out_height * layer3_channels, out_channels) self.dropout = nn.Dropout(dropout) - def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. @@ -1674,6 +1702,7 @@ class Conv2dSubsampling(nn.Module): x = self.dropout(x) return x + class AttentionCombine(nn.Module): """ This module combines a list of Tensors, all with the same shape, to @@ -1717,15 +1746,12 @@ class AttentionCombine(nn.Module): self.random_prob = random_prob self.single_prob = single_prob - self.weight = torch.nn.Parameter(torch.zeros(num_channels, - num_inputs)) + self.weight = torch.nn.Parameter(torch.zeros(num_channels, num_inputs)) self.bias = torch.nn.Parameter(torch.zeros(num_inputs)) assert 0 <= random_prob <= 1, random_prob assert 0 <= single_prob <= 1, single_prob - - def forward(self, inputs: List[Tensor]) -> Tensor: """Forward function. Args: @@ -1756,28 +1782,35 @@ class AttentionCombine(nn.Module): if self.training: # random masking.. - mask_start = torch.randint(low=1, high=int(num_inputs / self.random_prob), - size=(num_frames,), device=scores.device).unsqueeze(1) + mask_start = torch.randint( + low=1, + high=int(num_inputs / self.random_prob), + size=(num_frames,), + device=scores.device, + ).unsqueeze(1) # mask will have rows like: [ False, False, False, True, True, .. ] - arange = torch.arange(num_inputs, device=scores.device).unsqueeze(0).expand( - num_frames, num_inputs) + arange = ( + torch.arange(num_inputs, device=scores.device) + .unsqueeze(0) + .expand(num_frames, num_inputs) + ) mask = arange >= mask_start - apply_single_prob = torch.logical_and(torch.rand(size=(num_frames, 1), - device=scores.device) < self.single_prob, - mask_start < num_inputs) - single_prob_mask = torch.logical_and(apply_single_prob, - arange < mask_start - 1) + apply_single_prob = torch.logical_and( + torch.rand(size=(num_frames, 1), device=scores.device) + < self.single_prob, + mask_start < num_inputs, + ) + single_prob_mask = torch.logical_and( + apply_single_prob, arange < mask_start - 1 + ) - mask = torch.logical_or(mask, - single_prob_mask) + mask = torch.logical_or(mask, single_prob_mask) - scores = scores.masked_fill(mask, float('-inf')) + scores = scores.masked_fill(mask, float("-inf")) if self.training and random.random() < 0.1: - scores = penalize_abs_values_gt(scores, - limit=10.0, - penalty=1.0e-04) + scores = penalize_abs_values_gt(scores, limit=10.0, penalty=1.0e-04) weights = scores.softmax(dim=1) @@ -1792,7 +1825,6 @@ class AttentionCombine(nn.Module): return ans - def _test_random_combine(): print("_test_random_combine()") num_inputs = 3 @@ -1801,8 +1833,8 @@ def _test_random_combine(): num_channels=num_channels, num_inputs=num_inputs, random_prob=0.5, - single_prob=0.0) - + single_prob=0.0, + ) x = [torch.ones(3, 4, num_channels) for _ in range(num_inputs)] @@ -1819,7 +1851,10 @@ def _test_zipformer_main(): # Just make sure the forward pass runs. c = Zipformer( - num_features=feature_dim, encoder_dims=(64,96), encoder_unmasked_dims=(48,64), nhead=(4,4) + num_features=feature_dim, + encoder_dims=(64, 96), + encoder_unmasked_dims=(48, 64), + nhead=(4, 4), ) batch_size = 5 seq_len = 20 @@ -1837,19 +1872,18 @@ def _test_zipformer_main(): ) f # to remove flake8 warnings + def _test_conv2d_subsampling(): num_features = 80 encoder_dims = 384 dropout = 0.1 - encoder_embed = Conv2dSubsampling(num_features, encoder_dims, - dropout=dropout) + encoder_embed = Conv2dSubsampling(num_features, encoder_dims, dropout=dropout) for i in range(20, 40): x = torch.rand(2, i, num_features) y = encoder_embed(x) assert (x.shape[1] - 7) // 2 == y.shape[1], (x.shape[1], y.shape[1]) - if __name__ == "__main__": logging.getLogger().setLevel(logging.INFO) torch.set_num_threads(1) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py index 9d7335e77..822f8e44b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py @@ -165,20 +165,24 @@ def get_parser(): "--avg", type=int, default=9, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -273,8 +277,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -394,9 +397,7 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -455,10 +456,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -589,9 +587,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -624,8 +620,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -680,9 +675,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -719,13 +712,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -753,13 +745,12 @@ def main(): ) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -788,7 +779,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -816,9 +807,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/export.py b/egs/librispeech/ASR/pruned_transducer_stateless8/export.py index 49f469e29..43eb0c1bc 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/export.py @@ -129,20 +129,24 @@ def get_parser(): "--avg", type=int, default=9, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -176,8 +180,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) add_model_arguments(parser) @@ -217,13 +220,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -252,13 +254,12 @@ def main(): ) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -287,7 +288,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -326,9 +327,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py index e79a3a3aa..ed920dc03 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py @@ -69,10 +69,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) return parser @@ -93,10 +95,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -267,9 +268,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/model.py b/egs/librispeech/ASR/pruned_transducer_stateless8/model.py index 497b89136..39a360796 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/model.py @@ -160,9 +160,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py index 373a48fc1..716136812 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py @@ -100,9 +100,11 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( @@ -127,10 +129,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -177,8 +181,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -209,10 +212,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -275,15 +277,11 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lengths - ) + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) num_waves = encoder_out.size(0) hyps = [] @@ -355,9 +353,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py index 2603bb854..381a86a67 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py @@ -92,9 +92,7 @@ from icefall.env import get_env_info from icefall.hooks import register_inf_check_hooks from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: @@ -132,7 +130,10 @@ def add_model_arguments(parser: argparse.ArgumentParser): "--encoder-dims", type=str, default="384,384,384,384,384", - help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", + help=( + "Embedding dimension in the 2 blocks of zipformer encoder layers, comma" + " separated" + ), ) parser.add_argument( @@ -147,9 +148,11 @@ def add_model_arguments(parser: argparse.ArgumentParser): "--encoder-unmasked-dims", type=str, default="256,256,256,256,256", - help="Unmasked dimensions in the encoders, relates to augmentation during training. " - "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " - " worse.", + help=( + "Unmasked dimensions in the encoders, relates to augmentation during" + " training. Must be <= each of encoder_dims. Empirically, less than 256" + " seems to make performance worse." + ), ) parser.add_argument( @@ -214,8 +217,7 @@ def get_parser(): "--full-libri", type=str2bool, default=True, - help="When enabled, use 960h LibriSpeech. " - "Otherwise, use 100h subset.", + help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.", ) parser.add_argument( @@ -285,42 +287,45 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." + ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -691,11 +696,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -744,9 +745,7 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -952,9 +951,7 @@ def train_one_epoch( # of the grad scaler is configurable, but we can't configure it to have different # behavior depending on the current grad scale. cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 1.0 or ( - cur_grad_scale < 8.0 and batch_idx % 400 == 0 - ): + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): scaler.update(cur_grad_scale * 2.0) if cur_grad_scale < 0.01: logging.warning(f"Grad scale is small: {cur_grad_scale}") @@ -975,11 +972,7 @@ def train_one_epoch( f"giga_tot_loss[{giga_tot_loss}], " f"batch size: {batch_size}, " f"lr: {cur_lr:.2e}, " - + ( - f"grad_scale: {scaler._scale.item()}" - if params.use_fp16 - else "" - ) + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") ) if tb_writer is not None: @@ -992,12 +985,8 @@ def train_one_epoch( f"train/current_{prefix}_", params.batch_idx_train, ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) libri_tot_loss.write_summary( tb_writer, "train/libri_tot_", params.batch_idx_train ) @@ -1011,10 +1000,7 @@ def train_one_epoch( params.batch_idx_train, ) - if ( - batch_idx % params.valid_interval == 0 - and not params.print_diagnostics - ): + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: logging.info("Computing validation loss") valid_info = compute_validation_loss( params=params, @@ -1026,7 +1012,8 @@ def train_one_epoch( model.train() logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + "Maximum memory allocated so far is" + f" {torch.cuda.max_memory_allocated()//1000000}MB" ) if tb_writer is not None: valid_info.write_summary( @@ -1054,8 +1041,7 @@ def filter_short_and_long_utterances( # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" ) return False @@ -1152,9 +1138,7 @@ def run(rank, world_size, args): logging.info("Using DDP") model = DDP(model, device_ids=[rank], find_unused_parameters=True) - optimizer = ScaledAdam( - model.parameters(), lr=params.base_lr, clipping_scale=2.0 - ) + optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=2.0) scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) @@ -1172,7 +1156,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -1207,9 +1191,7 @@ def run(rank, world_size, args): train_giga_cuts = train_giga_cuts.repeat(times=None) if args.enable_musan: - cuts_musan = load_manifest( - Path(args.manifest_dir) / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz") else: cuts_musan = None @@ -1364,7 +1346,8 @@ def scan_pessimistic_batches_for_oom( display_and_save_batch(batch, params=params, sp=sp) raise logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + "Maximum memory allocated so far is" + f" {torch.cuda.max_memory_allocated()//1000000}MB" ) diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/README.md b/egs/librispeech/ASR/streaming_conformer_ctc/README.md index 01be7090b..53f383c99 100644 --- a/egs/librispeech/ASR/streaming_conformer_ctc/README.md +++ b/egs/librispeech/ASR/streaming_conformer_ctc/README.md @@ -1,20 +1,20 @@ ## Train and Decode -Commands of data preparation/train/decode steps are almost the same with +Commands of data preparation/train/decode steps are almost the same with ../conformer_ctc experiment except some options. Please read the code and understand following new added options before running this experiment: For data preparation: - + Nothing new. For streaming_conformer_ctc/train.py: - + --dynamic-chunk-training --short-chunk-proportion For streaming_conformer_ctc/streaming_decode.py: - + --chunk-size --tailing-num-frames --simulate-streaming @@ -57,10 +57,10 @@ And check md5sum values again. Finally, following files will be downloaded:

-streaming_models/  
-|-- lang_bpe  
-|   |-- L.pt  
-|   |-- Linv.pt  
+streaming_models/
+|-- lang_bpe
+|   |-- L.pt
+|   |-- Linv.pt
 |   |-- bpe.model
 |   |-- tokens.txt
 |   `-- words.txt
diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py b/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py
index ff4c91446..4f7427c1f 100644
--- a/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py
+++ b/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py
@@ -309,36 +309,26 @@ class Conformer(Transformer):
 
                 # start chunk_by_chunk decoding
                 offset = 0
-                for cur in range(
-                    0, num_frames - embed_left_context + 1, stride
-                ):
+                for cur in range(0, num_frames - embed_left_context + 1, stride):
                     end = min(cur + decoding_window, num_frames)
                     cur_feature = feature[:, cur:end, :]
                     cur_feature = self.encoder_embed(cur_feature)
-                    cur_embed, cur_pos_emb = self.encoder_pos(
-                        cur_feature, offset
-                    )
-                    cur_embed = cur_embed.permute(
-                        1, 0, 2
-                    )  # (B, T, F) -> (T, B, F)
+                    cur_embed, cur_pos_emb = self.encoder_pos(cur_feature, offset)
+                    cur_embed = cur_embed.permute(1, 0, 2)  # (B, T, F) -> (T, B, F)
 
                     cur_T = cur_feature.size(1)
                     if cur == 0:
                         # for first chunk extract the central pos embedding
-                        pos_emb_central = cur_pos_emb[
-                            0, (chunk_size - 1), :
-                        ].view(1, 1, -1)
+                        pos_emb_central = cur_pos_emb[0, (chunk_size - 1), :].view(
+                            1, 1, -1
+                        )
                         cur_T -= 1
                     pos_emb_positive.append(cur_pos_emb[0, :cur_T].flip(0))
                     pos_emb_negative.append(cur_pos_emb[0, -cur_T:])
                     assert pos_emb_positive[-1].size(0) == cur_T
 
-                    pos_emb_pos = torch.cat(pos_emb_positive, dim=0).unsqueeze(
-                        0
-                    )
-                    pos_emb_neg = torch.cat(pos_emb_negative, dim=0).unsqueeze(
-                        0
-                    )
+                    pos_emb_pos = torch.cat(pos_emb_positive, dim=0).unsqueeze(0)
+                    pos_emb_neg = torch.cat(pos_emb_negative, dim=0).unsqueeze(0)
                     cur_pos_emb = torch.cat(
                         [pos_emb_pos.flip(1), pos_emb_central, pos_emb_neg],
                         dim=1,
@@ -413,9 +403,7 @@ class ConformerEncoderLayer(nn.Module):
         causal: bool = False,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
-        self.self_attn = RelPositionMultiheadAttention(
-            d_model, nhead, dropout=0.0
-        )
+        self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
 
         self.feed_forward = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
@@ -431,22 +419,16 @@ class ConformerEncoderLayer(nn.Module):
             nn.Linear(dim_feedforward, d_model),
         )
 
-        self.conv_module = ConvolutionModule(
-            d_model, cnn_module_kernel, causal=causal
-        )
+        self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal)
 
-        self.norm_ff_macaron = nn.LayerNorm(
-            d_model
-        )  # for the macaron style FNN module
+        self.norm_ff_macaron = nn.LayerNorm(d_model)  # for the macaron style FNN module
         self.norm_ff = nn.LayerNorm(d_model)  # for the FNN module
         self.norm_mha = nn.LayerNorm(d_model)  # for the MHA module
 
         self.ff_scale = 0.5
 
         self.norm_conv = nn.LayerNorm(d_model)  # for the CNN module
-        self.norm_final = nn.LayerNorm(
-            d_model
-        )  # for the final output of the block
+        self.norm_final = nn.LayerNorm(d_model)  # for the final output of the block
 
         self.dropout = nn.Dropout(dropout)
 
@@ -480,9 +462,7 @@ class ConformerEncoderLayer(nn.Module):
         residual = src
         if self.normalize_before:
             src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(
-            self.feed_forward_macaron(src)
-        )
+        src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)
 
@@ -554,9 +534,7 @@ class ConformerEncoderLayer(nn.Module):
         residual = src
         if self.normalize_before:
             src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(
-            self.feed_forward_macaron(src)
-        )
+        src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)
 
@@ -736,9 +714,7 @@ class RelPositionalEncoding(torch.nn.Module):
 
     """
 
-    def __init__(
-        self, d_model: int, dropout_rate: float, max_len: int = 5000
-    ) -> None:
+    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
@@ -755,9 +731,7 @@ class RelPositionalEncoding(torch.nn.Module):
             # the length of self.pe is 2 * input_len - 1
             if self.pe.size(1) >= x_size_1 * 2 - 1:
                 # Note: TorchScript doesn't implement operator== for torch.Device
-                if self.pe.dtype != x.dtype or str(self.pe.device) != str(
-                    x.device
-                ):
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
                     self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                 return
         # Suppose `i` means to the position of query vector and `j` means the
@@ -783,9 +757,7 @@ class RelPositionalEncoding(torch.nn.Module):
         pe = torch.cat([pe_positive, pe_negative], dim=1)
         self.pe = pe.to(device=x.device, dtype=x.dtype)
 
-    def forward(
-        self, x: torch.Tensor, offset: int = 0
-    ) -> Tuple[Tensor, Tensor]:
+    def forward(self, x: torch.Tensor, offset: int = 0) -> Tuple[Tensor, Tensor]:
         """Add positional encoding.
 
         Args:
@@ -813,9 +785,7 @@ class RelPositionalEncoding(torch.nn.Module):
             pos_emb = torch.cat(
                 [
                     pos_emb[:, : (x_T - 1)],
-                    self.pe[0, self.pe.size(1) // 2].view(
-                        1, 1, self.pe.size(-1)
-                    ),
+                    self.pe[0, self.pe.size(1) // 2].view(1, 1, self.pe.size(-1)),
                     pos_emb[:, -(x_T - 1) :],  # noqa: E203
                 ],
                 dim=1,
@@ -1050,9 +1020,9 @@ class RelPositionMultiheadAttention(nn.Module):
 
         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(
-                query, in_proj_weight, in_proj_bias
-            ).chunk(3, dim=-1)
+            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
+                3, dim=-1
+            )
 
         elif torch.equal(key, value):
             # encoder-decoder attention
@@ -1120,33 +1090,25 @@ class RelPositionMultiheadAttention(nn.Module):
             if attn_mask.dim() == 2:
                 attn_mask = attn_mask.unsqueeze(0)
                 if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
-                    raise RuntimeError(
-                        "The size of the 2D attn_mask is not correct."
-                    )
+                    raise RuntimeError("The size of the 2D attn_mask is not correct.")
             elif attn_mask.dim() == 3:
                 if list(attn_mask.size()) != [
                     bsz * num_heads,
                     query.size(0),
                     key.size(0),
                 ]:
-                    raise RuntimeError(
-                        "The size of the 3D attn_mask is not correct."
-                    )
+                    raise RuntimeError("The size of the 3D attn_mask is not correct.")
             else:
                 raise RuntimeError(
-                    "attn_mask's dimension {} is not supported".format(
-                        attn_mask.dim()
-                    )
+                    "attn_mask's dimension {} is not supported".format(attn_mask.dim())
                 )
             # attn_mask's dim is 3 now.
 
         # convert ByteTensor key_padding_mask to bool
-        if (
-            key_padding_mask is not None
-            and key_padding_mask.dtype == torch.uint8
-        ):
+        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
             warnings.warn(
-                "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
+                "Byte tensor for key_padding_mask is deprecated. Use bool tensor"
+                " instead."
             )
             key_padding_mask = key_padding_mask.to(torch.bool)
 
@@ -1185,24 +1147,16 @@ class RelPositionMultiheadAttention(nn.Module):
         # first compute matrix a and matrix c
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
-        matrix_ac = torch.matmul(
-            q_with_bias_u, k
-        )  # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch, head, time1, time2)
 
         # compute matrix b and matrix d
-        matrix_bd = torch.matmul(
-            q_with_bias_v, p
-        )  # (batch, head, time1, 2*time1-1)
-        matrix_bd = self.rel_shift(
-            matrix_bd, offset=offset
-        )  # [B, head, time1, time2]
+        matrix_bd = torch.matmul(q_with_bias_v, p)  # (batch, head, time1, 2*time1-1)
+        matrix_bd = self.rel_shift(matrix_bd, offset=offset)  # [B, head, time1, time2]
         attn_output_weights = (
             matrix_ac + matrix_bd
         ) * scaling  # (batch, head, time1, time2)
 
-        attn_output_weights = attn_output_weights.view(
-            bsz * num_heads, tgt_len, -1
-        )
+        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
 
         assert list(attn_output_weights.size()) == [
             bsz * num_heads,
@@ -1236,13 +1190,9 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output = torch.bmm(attn_output_weights, v)
         assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
         attn_output = (
-            attn_output.transpose(0, 1)
-            .contiguous()
-            .view(tgt_len, bsz, embed_dim)
-        )
-        attn_output = nn.functional.linear(
-            attn_output, out_proj_weight, out_proj_bias
+            attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
         )
+        attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
 
         if need_weights:
             # average attention weights over heads
diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py b/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py
index a74c51836..5a8149aad 100755
--- a/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py
+++ b/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py
@@ -28,6 +28,7 @@ import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
 from conformer import Conformer
+
 from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import average_checkpoints, load_checkpoint
 from icefall.lexicon import Lexicon
@@ -62,32 +63,36 @@ def get_parser():
         "--epoch",
         type=int,
         default=34,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=20,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
         "--chunk-size",
         type=int,
         default=8,
-        help="Frames of right context"
-        "-1 for whole right context, i.e. non-streaming decoding",
+        help=(
+            "Frames of right context"
+            "-1 for whole right context, i.e. non-streaming decoding"
+        ),
     )
 
     parser.add_argument(
         "--tailing-num-frames",
         type=int,
         default=20,
-        help="tailing dummy frames padded to the right,"
-        "only used during decoding",
+        help="tailing dummy frames padded to the right,only used during decoding",
     )
 
     parser.add_argument(
@@ -139,8 +144,7 @@ def get_parser():
         "--avg-models",
         type=str,
         default=None,
-        help="Manually select models to average, seperated by comma;"
-        "e.g. 60,62,63,72",
+        help="Manually select models to average, seperated by comma;e.g. 60,62,63,72",
     )
 
     return parser
@@ -248,13 +252,9 @@ def decode_one_batch(
     maxlen = nnet_output.size(1)
     topk_prob, topk_index = nnet_output.topk(1, dim=2)  # (B, maxlen, 1)
     topk_index = topk_index.view(batch_size, maxlen)  # (B, maxlen)
-    topk_index = topk_index.masked_fill_(
-        memory_key_padding_mask, 0
-    )  # (B, maxlen)
+    topk_index = topk_index.masked_fill_(memory_key_padding_mask, 0)  # (B, maxlen)
     token_ids = [token_id.tolist() for token_id in topk_index]
-    token_ids = [
-        remove_duplicates_and_blank(token_id) for token_id in token_ids
-    ]
+    token_ids = [remove_duplicates_and_blank(token_id) for token_id in token_ids]
     hyps = bpe_model.decode(token_ids)
     hyps = [s.split() for s in hyps]
     return {key: hyps}
@@ -337,9 +337,7 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
 
     return results
 
@@ -357,15 +355,18 @@ def save_results(
     test_set_wers = dict()
     if params.avg_models is not None:
         avg_models = params.avg_models.replace(",", "_")
-        result_file_prefix = f"epoch-avg-{avg_models}-chunksize \
-        -{params.chunk_size}-tailing-num-frames-{params.tailing_num_frames}-"
+        result_file_prefix = (
+            f"epoch-avg-{avg_models}-chunksize        "
+            f" -{params.chunk_size}-tailing-num-frames-{params.tailing_num_frames}-"
+        )
     else:
-        result_file_prefix = f"epoch-{params.epoch}-avg-{params.avg}-chunksize \
-        -{params.chunk_size}-tailing-num-frames-{params.tailing_num_frames}-"
+        result_file_prefix = (
+            f"epoch-{params.epoch}-avg-{params.avg}-chunksize        "
+            f" -{params.chunk_size}-tailing-num-frames-{params.tailing_num_frames}-"
+        )
     for key, results in results_dict.items():
         recog_path = (
-            params.exp_dir
-            / f"{result_file_prefix}recogs-{test_set_name}-{key}.txt"
+            params.exp_dir / f"{result_file_prefix}recogs-{test_set_name}-{key}.txt"
         )
         store_transcripts(filename=recog_path, texts=results)
         if enable_log:
@@ -374,8 +375,7 @@ def save_results(
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.exp_dir
-            / f"{result_file_prefix}-errs-{test_set_name}-{key}.txt"
+            params.exp_dir / f"{result_file_prefix}-errs-{test_set_name}-{key}.txt"
         )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
@@ -384,9 +384,7 @@ def save_results(
             test_set_wers[key] = wer
 
         if enable_log:
-            logging.info(
-                "Wrote detailed error stats to {}".format(errs_filename)
-            )
+            logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"
@@ -474,9 +472,7 @@ def main():
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save(
-            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
-        )
+        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
         return
 
     model.to(device)
@@ -507,9 +503,7 @@ def main():
             simulate_streaming=params.simulate_streaming,
         )
 
-        save_results(
-            params=params, test_set_name=test_set, results_dict=results_dict
-        )
+        save_results(params=params, test_set_name=test_set, results_dict=results_dict)
 
     logging.info("Done!")
 
diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/train.py b/egs/librispeech/ASR/streaming_conformer_ctc/train.py
index e41b7ea78..553b7d092 100755
--- a/egs/librispeech/ASR/streaming_conformer_ctc/train.py
+++ b/egs/librispeech/ASR/streaming_conformer_ctc/train.py
@@ -405,9 +405,7 @@ def compute_loss(
             #
             # See https://github.com/k2-fsa/icefall/issues/97
             # for more details
-            unsorted_token_ids = graph_compiler.texts_to_ids(
-                supervisions["text"]
-            )
+            unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"])
             att_loss = mmodel.decoder_forward(
                 encoder_memory,
                 memory_mask,
@@ -436,9 +434,7 @@ def compute_loss(
     info["utt_duration"] = supervisions["num_frames"].sum().item()
     # averaged padding proportion over utterances
     info["utt_pad_proportion"] = (
-        ((feature.size(1) - supervisions["num_frames"]) / feature.size(1))
-        .sum()
-        .item()
+        ((feature.size(1) - supervisions["num_frames"]) / feature.size(1)).sum().item()
     )
 
     return loss, info
@@ -551,9 +547,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -668,9 +662,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/transformer.py b/egs/librispeech/ASR/streaming_conformer_ctc/transformer.py
index bc78e4a41..0c87fdf1b 100644
--- a/egs/librispeech/ASR/streaming_conformer_ctc/transformer.py
+++ b/egs/librispeech/ASR/streaming_conformer_ctc/transformer.py
@@ -149,9 +149,7 @@ class Transformer(nn.Module):
                 norm=decoder_norm,
             )
 
-            self.decoder_output_layer = torch.nn.Linear(
-                d_model, self.decoder_num_class
-            )
+            self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class)
 
             self.decoder_criterion = LabelSmoothingLoss()
         else:
@@ -286,23 +284,17 @@ class Transformer(nn.Module):
         """
         ys_in = add_sos(token_ids, sos_id=sos_id)
         ys_in = [torch.tensor(y) for y in ys_in]
-        ys_in_pad = pad_sequence(
-            ys_in, batch_first=True, padding_value=float(eos_id)
-        )
+        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
 
         ys_out = add_eos(token_ids, eos_id=eos_id)
         ys_out = [torch.tensor(y) for y in ys_out]
-        ys_out_pad = pad_sequence(
-            ys_out, batch_first=True, padding_value=float(-1)
-        )
+        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
 
         device = memory.device
         ys_in_pad = ys_in_pad.to(device)
         ys_out_pad = ys_out_pad.to(device)
 
-        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
-            device
-        )
+        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
 
         tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
         # TODO: Use length information to create the decoder padding mask
@@ -363,23 +355,17 @@ class Transformer(nn.Module):
 
         ys_in = add_sos(token_ids, sos_id=sos_id)
         ys_in = [torch.tensor(y) for y in ys_in]
-        ys_in_pad = pad_sequence(
-            ys_in, batch_first=True, padding_value=float(eos_id)
-        )
+        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
 
         ys_out = add_eos(token_ids, eos_id=eos_id)
         ys_out = [torch.tensor(y) for y in ys_out]
-        ys_out_pad = pad_sequence(
-            ys_out, batch_first=True, padding_value=float(-1)
-        )
+        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
 
         device = memory.device
         ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
         ys_out_pad = ys_out_pad.to(device, dtype=torch.int64)
 
-        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
-            device
-        )
+        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
 
         tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
         # TODO: Use length information to create the decoder padding mask
@@ -652,9 +638,7 @@ def _get_activation_fn(activation: str):
     elif activation == "gelu":
         return nn.functional.gelu
 
-    raise RuntimeError(
-        "activation should be relu/gelu, not {}".format(activation)
-    )
+    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
 
 
 class PositionalEncoding(nn.Module):
@@ -856,9 +840,7 @@ def encoder_padding_mask(
         1,
     ).to(torch.int32)
 
-    lengths = [
-        0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)
-    ]
+    lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)]
     for idx in range(supervision_segments.size(0)):
         # Note: TorchScript doesn't allow to unpack tensors as tuples
         sequence_idx = supervision_segments[idx, 0].item()
@@ -879,9 +861,7 @@ def encoder_padding_mask(
     return mask
 
 
-def decoder_padding_mask(
-    ys_pad: torch.Tensor, ignore_id: int = -1
-) -> torch.Tensor:
+def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor:
     """Generate a length mask for input.
 
     The masked position are filled with True,
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
index 355ccc99a..63afd6be2 100644
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -77,17 +77,18 @@ class LibriSpeechAsrDataModule:
     def add_arguments(cls, parser: argparse.ArgumentParser):
         group = parser.add_argument_group(
             title="ASR data related options",
-            description="These options are used for the preparation of "
-            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-            "effective batch sizes, sampling strategies, applied data "
-            "augmentations, etc.",
+            description=(
+                "These options are used for the preparation of "
+                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+                "effective batch sizes, sampling strategies, applied data "
+                "augmentations, etc."
+            ),
         )
         group.add_argument(
             "--full-libri",
             type=str2bool,
             default=True,
-            help="When enabled, use 960h LibriSpeech. "
-            "Otherwise, use 100h subset.",
+            help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.",
         )
         group.add_argument(
             "--manifest-dir",
@@ -99,59 +100,74 @@ class LibriSpeechAsrDataModule:
             "--max-duration",
             type=int,
             default=200.0,
-            help="Maximum pooled recordings duration (seconds) in a "
-            "single batch. You can reduce it if it causes CUDA OOM.",
+            help=(
+                "Maximum pooled recordings duration (seconds) in a "
+                "single batch. You can reduce it if it causes CUDA OOM."
+            ),
         )
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
             default=True,
-            help="When enabled, the batches will come from buckets of "
-            "similar duration (saves padding frames).",
+            help=(
+                "When enabled, the batches will come from buckets of "
+                "similar duration (saves padding frames)."
+            ),
         )
         group.add_argument(
             "--num-buckets",
             type=int,
             default=30,
-            help="The number of buckets for the DynamicBucketingSampler"
-            "(you might want to increase it for larger datasets).",
+            help=(
+                "The number of buckets for the DynamicBucketingSampler"
+                "(you might want to increase it for larger datasets)."
+            ),
         )
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
             default=False,
-            help="When enabled, utterances (cuts) will be concatenated "
-            "to minimize the amount of padding.",
+            help=(
+                "When enabled, utterances (cuts) will be concatenated "
+                "to minimize the amount of padding."
+            ),
         )
         group.add_argument(
             "--duration-factor",
             type=float,
             default=1.0,
-            help="Determines the maximum duration of a concatenated cut "
-            "relative to the duration of the longest cut in a batch.",
+            help=(
+                "Determines the maximum duration of a concatenated cut "
+                "relative to the duration of the longest cut in a batch."
+            ),
         )
         group.add_argument(
             "--gap",
             type=float,
             default=1.0,
-            help="The amount of padding (in seconds) inserted between "
-            "concatenated cuts. This padding is filled with noise when "
-            "noise augmentation is used.",
+            help=(
+                "The amount of padding (in seconds) inserted between "
+                "concatenated cuts. This padding is filled with noise when "
+                "noise augmentation is used."
+            ),
         )
         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help="When enabled, use on-the-fly cut mixing and feature "
-            "extraction. Will drop existing precomputed feature manifests "
-            "if available.",
+            help=(
+                "When enabled, use on-the-fly cut mixing and feature "
+                "extraction. Will drop existing precomputed feature manifests "
+                "if available."
+            ),
         )
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help="When enabled (=default), the examples will be "
-            "shuffled for each epoch.",
+            help=(
+                "When enabled (=default), the examples will be shuffled for each epoch."
+            ),
         )
         group.add_argument(
             "--drop-last",
@@ -163,17 +179,18 @@ class LibriSpeechAsrDataModule:
             "--return-cuts",
             type=str2bool,
             default=True,
-            help="When enabled, each batch will have the "
-            "field: batch['supervisions']['cut'] with the cuts that "
-            "were used to construct it.",
+            help=(
+                "When enabled, each batch will have the "
+                "field: batch['supervisions']['cut'] with the cuts that "
+                "were used to construct it."
+            ),
         )
 
         group.add_argument(
             "--num-workers",
             type=int,
             default=2,
-            help="The number of training dataloader workers that "
-            "collect the batches.",
+            help="The number of training dataloader workers that collect the batches.",
         )
 
         group.add_argument(
@@ -187,18 +204,22 @@ class LibriSpeechAsrDataModule:
             "--spec-aug-time-warp-factor",
             type=int,
             default=80,
-            help="Used only when --enable-spec-aug is True. "
-            "It specifies the factor for time warping in SpecAugment. "
-            "Larger values mean more warping. "
-            "A value less than 1 means to disable time warp.",
+            help=(
+                "Used only when --enable-spec-aug is True. "
+                "It specifies the factor for time warping in SpecAugment. "
+                "Larger values mean more warping. "
+                "A value less than 1 means to disable time warp."
+            ),
         )
 
         group.add_argument(
             "--enable-musan",
             type=str2bool,
             default=True,
-            help="When enabled, select noise from MUSAN and mix it"
-            "with training dataset. ",
+            help=(
+                "When enabled, select noise from MUSAN and mix it"
+                "with training dataset. "
+            ),
         )
 
         group.add_argument(
@@ -224,20 +245,16 @@ class LibriSpeechAsrDataModule:
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             logging.info("About to get Musan cuts")
-            cuts_musan = load_manifest(
-                self.args.manifest_dir / "musan_cuts.jsonl.gz"
-            )
+            cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
             transforms.append(
-                CutMix(
-                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
-                )
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")
 
         if self.args.concatenate_cuts:
             logging.info(
-                f"Using cut concatenation with duration factor "
+                "Using cut concatenation with duration factor "
                 f"{self.args.duration_factor} and gap {self.args.gap}."
             )
             # Cut concatenation should be the first transform in the list,
@@ -252,9 +269,7 @@ class LibriSpeechAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(
-                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
-            )
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -298,9 +313,7 @@ class LibriSpeechAsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -356,9 +369,7 @@ class LibriSpeechAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 return_cuts=self.args.return_cuts,
             )
         else:
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
index 7d0cd0bf3..94ba0a4dc 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
@@ -57,16 +57,19 @@ def get_parser():
         "--epoch",
         type=int,
         default=19,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=5,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
     parser.add_argument(
         "--method",
@@ -336,9 +339,7 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -400,9 +401,7 @@ def main():
 
     logging.info(f"device: {device}")
 
-    HLG = k2.Fsa.from_dict(
-        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
-    )
+    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
 
@@ -467,9 +466,7 @@ def main():
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save(
-            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
-        )
+        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
         return
 
     model.to(device)
@@ -498,9 +495,7 @@ def main():
             G=G,
         )
 
-        save_results(
-            params=params, test_set_name=test_set, results_dict=results_dict
-        )
+        save_results(params=params, test_set_name=test_set, results_dict=results_dict)
 
     logging.info("Done!")
 
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/model.py b/egs/librispeech/ASR/tdnn_lstm_ctc/model.py
index 5e04c11b4..1731e1ebe 100644
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/model.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/model.py
@@ -66,10 +66,7 @@ class TdnnLstm(nn.Module):
             nn.BatchNorm1d(num_features=500, affine=False),
         )
         self.lstms = nn.ModuleList(
-            [
-                nn.LSTM(input_size=500, hidden_size=500, num_layers=1)
-                for _ in range(5)
-            ]
+            [nn.LSTM(input_size=500, hidden_size=500, num_layers=1) for _ in range(5)]
         )
         self.lstm_bnorms = nn.ModuleList(
             [nn.BatchNorm1d(num_features=500, affine=False) for _ in range(5)]
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py
index 2baeb6bba..722e8f003 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py
@@ -29,11 +29,7 @@ import torchaudio
 from model import TdnnLstm
 from torch.nn.utils.rnn import pad_sequence
 
-from icefall.decode import (
-    get_lattice,
-    one_best_decoding,
-    rescore_with_whole_lattice,
-)
+from icefall.decode import get_lattice, one_best_decoding, rescore_with_whole_lattice
 from icefall.utils import AttributeDict, get_texts
 
 
@@ -46,9 +42,11 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help="Path to the checkpoint. "
-        "The checkpoint is assumed to be saved by "
-        "icefall.checkpoint.save_checkpoint().",
+        help=(
+            "Path to the checkpoint. "
+            "The checkpoint is assumed to be saved by "
+            "icefall.checkpoint.save_checkpoint()."
+        ),
     )
 
     parser.add_argument(
@@ -58,9 +56,7 @@ def get_parser():
         help="Path to words.txt",
     )
 
-    parser.add_argument(
-        "--HLG", type=str, required=True, help="Path to HLG.pt."
-    )
+    parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.")
 
     parser.add_argument(
         "--method",
@@ -103,10 +99,12 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help="The input sound file(s) to transcribe. "
-        "Supported formats are those supported by torchaudio.load(). "
-        "For example, wav and flac are supported. "
-        "The sample rate has to be 16kHz.",
+        help=(
+            "The input sound file(s) to transcribe. "
+            "Supported formats are those supported by torchaudio.load(). "
+            "For example, wav and flac are supported. "
+            "The sample rate has to be 16kHz."
+        ),
     )
 
     return parser
@@ -144,10 +142,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -215,9 +212,7 @@ def main():
     logging.info("Decoding started")
     features = fbank(waves)
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
     features = features.permute(0, 2, 1)  # now features is (N, C, T)
 
     with torch.no_grad():
@@ -269,9 +264,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
index 6b37d5c23..071ac792b 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
@@ -355,9 +355,7 @@ def compute_loss(
     info["utt_duration"] = supervisions["num_frames"].sum().item()
     # averaged padding proportion over utterances
     info["utt_pad_proportion"] = (
-        ((feature.size(2) - supervisions["num_frames"]) / feature.size(2))
-        .sum()
-        .item()
+        ((feature.size(2) - supervisions["num_frames"]) / feature.size(2)).sum().item()
     )
 
     return loss, info
@@ -470,9 +468,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             valid_info = compute_validation_loss(
diff --git a/egs/librispeech/ASR/transducer/beam_search.py b/egs/librispeech/ASR/transducer/beam_search.py
index 11032f31a..b45b6a9d8 100644
--- a/egs/librispeech/ASR/transducer/beam_search.py
+++ b/egs/librispeech/ASR/transducer/beam_search.py
@@ -38,9 +38,7 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
     blank_id = model.decoder.blank_id
     device = model.device
 
-    sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(
-        1, 1
-    )
+    sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(1, 1)
     decoder_out, (h, c) = model.decoder(sos)
     T = encoder_out.size(1)
     t = 0
@@ -123,9 +121,7 @@ def beam_search(
     max_u = 20000  # terminate after this number of steps
     u = 0
 
-    cache: Dict[
-        str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
-    ] = {}
+    cache: Dict[str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = {}
 
     while t < T and u < max_u:
         # fmt: off
@@ -157,9 +153,9 @@ def beam_search(
             cached_key = "_".join(map(str, y_star.ys))
 
             if cached_key not in cache:
-                decoder_input = torch.tensor(
-                    [y_star.ys[-1]], device=device
-                ).reshape(1, 1)
+                decoder_input = torch.tensor([y_star.ys[-1]], device=device).reshape(
+                    1, 1
+                )
 
                 decoder_out, decoder_state = model.decoder(
                     decoder_input,
diff --git a/egs/librispeech/ASR/transducer/decode.py b/egs/librispeech/ASR/transducer/decode.py
index 5f233df87..f30332cea 100755
--- a/egs/librispeech/ASR/transducer/decode.py
+++ b/egs/librispeech/ASR/transducer/decode.py
@@ -71,16 +71,19 @@ def get_parser():
         "--epoch",
         type=int,
         default=34,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=11,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -228,9 +231,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []
     batch_size = encoder_out.size(0)
 
@@ -245,9 +246,7 @@ def decode_one_batch(
                 model=model, encoder_out=encoder_out_i, beam=params.beam_size
             )
         else:
-            raise ValueError(
-                f"Unsupported decoding method: {params.decoding_method}"
-            )
+            raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
         hyps.append(sp.decode(hyp).split())
 
     if params.decoding_method == "greedy_search":
@@ -318,9 +317,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -353,8 +350,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/librispeech/ASR/transducer/export.py b/egs/librispeech/ASR/transducer/export.py
index 5a5db30c4..4d9f937f5 100755
--- a/egs/librispeech/ASR/transducer/export.py
+++ b/egs/librispeech/ASR/transducer/export.py
@@ -67,17 +67,20 @@ def get_parser():
         "--epoch",
         type=int,
         default=34,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=11,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -238,9 +241,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer/pretrained.py b/egs/librispeech/ASR/transducer/pretrained.py
index 1db2df648..7aadfbcd1 100755
--- a/egs/librispeech/ASR/transducer/pretrained.py
+++ b/egs/librispeech/ASR/transducer/pretrained.py
@@ -60,9 +60,11 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help="Path to the checkpoint. "
-        "The checkpoint is assumed to be saved by "
-        "icefall.checkpoint.save_checkpoint().",
+        help=(
+            "Path to the checkpoint. "
+            "The checkpoint is assumed to be saved by "
+            "icefall.checkpoint.save_checkpoint()."
+        ),
     )
 
     parser.add_argument(
@@ -87,10 +89,12 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help="The input sound file(s) to transcribe. "
-        "Supported formats are those supported by torchaudio.load(). "
-        "For example, wav and flac are supported. "
-        "The sample rate has to be 16kHz.",
+        help=(
+            "The input sound file(s) to transcribe. "
+            "Supported formats are those supported by torchaudio.load(). "
+            "For example, wav and flac are supported. "
+            "The sample rate has to be 16kHz."
+        ),
     )
 
     parser.add_argument(
@@ -188,10 +192,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -249,9 +252,7 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -287,9 +288,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer/rnn.py b/egs/librispeech/ASR/transducer/rnn.py
index 2a165b0c1..fe8732301 100644
--- a/egs/librispeech/ASR/transducer/rnn.py
+++ b/egs/librispeech/ASR/transducer/rnn.py
@@ -117,12 +117,8 @@ class LayerNormLSTMCell(nn.Module):
         )
 
         if bias:
-            self.bias_ih = nn.Parameter(
-                torch.empty(4 * hidden_size, **factory_kwargs)
-            )
-            self.bias_hh = nn.Parameter(
-                torch.empty(4 * hidden_size, **factory_kwargs)
-            )
+            self.bias_ih = nn.Parameter(torch.empty(4 * hidden_size, **factory_kwargs))
+            self.bias_hh = nn.Parameter(torch.empty(4 * hidden_size, **factory_kwargs))
         else:
             self.register_parameter("bias_ih", None)
             self.register_parameter("bias_hh", None)
@@ -348,9 +344,7 @@ class LayerNormLSTM(nn.Module):
             device=device,
             dtype=dtype,
         )
-        first_layer = LayerNormLSTMLayer(
-            input_size=input_size, **factory_kwargs
-        )
+        first_layer = LayerNormLSTMLayer(input_size=input_size, **factory_kwargs)
         layers = [first_layer]
         for i in range(1, num_layers):
             layers.append(
@@ -385,9 +379,7 @@ class LayerNormLSTM(nn.Module):
             - List[(next_h, next_c)] containing the hidden states for all layers
 
         """
-        output_states = torch.jit.annotate(
-            List[Tuple[torch.Tensor, torch.Tensor]], []
-        )
+        output_states = torch.jit.annotate(List[Tuple[torch.Tensor, torch.Tensor]], [])
         output = input
         for i, rnn_layer in enumerate(self.layers):
             state = states[i]
@@ -456,12 +448,8 @@ class LayerNormGRUCell(nn.Module):
         )
 
         if bias:
-            self.bias_ih = nn.Parameter(
-                torch.empty(3 * hidden_size, **factory_kwargs)
-            )
-            self.bias_hh = nn.Parameter(
-                torch.empty(3 * hidden_size, **factory_kwargs)
-            )
+            self.bias_ih = nn.Parameter(torch.empty(3 * hidden_size, **factory_kwargs))
+            self.bias_hh = nn.Parameter(torch.empty(3 * hidden_size, **factory_kwargs))
         else:
             self.register_parameter("bias_ih", None)
             self.register_parameter("bias_hh", None)
diff --git a/egs/librispeech/ASR/transducer/test_rnn.py b/egs/librispeech/ASR/transducer/test_rnn.py
index 8591e2d8a..74c94cc70 100755
--- a/egs/librispeech/ASR/transducer/test_rnn.py
+++ b/egs/librispeech/ASR/transducer/test_rnn.py
@@ -254,9 +254,7 @@ def test_layernorm_lstm_layer_with_projection_forward(device="cpu"):
         for name, self_param in self_layer.cell.named_parameters():
             getattr(torch_layer, f"{name}_l0").copy_(self_param)
 
-    torch_y, (torch_h, torch_c) = torch_layer(
-        x_clone, (h.unsqueeze(0), c.unsqueeze(0))
-    )
+    torch_y, (torch_h, torch_c) = torch_layer(x_clone, (h.unsqueeze(0), c.unsqueeze(0)))
     assert_allclose(self_y, torch_y)
     assert_allclose(self_h, torch_h)
     assert_allclose(self_c, torch_c)
@@ -303,9 +301,7 @@ def test_layernorm_lstm_layer_forward(device="cpu"):
         for name, self_param in self_layer.cell.named_parameters():
             getattr(torch_layer, f"{name}_l0").copy_(self_param)
 
-    torch_y, (torch_h, torch_c) = torch_layer(
-        x_clone, (h.unsqueeze(0), c.unsqueeze(0))
-    )
+    torch_y, (torch_h, torch_c) = torch_layer(x_clone, (h.unsqueeze(0), c.unsqueeze(0)))
     assert_allclose(self_y, torch_y)
     assert_allclose(self_h, torch_h)
     assert_allclose(self_c, torch_c)
@@ -594,9 +590,7 @@ def test_layernorm_gru_cell_forward(device="cpu"):
 
     assert_allclose(self_h, torch_h, atol=1e-5)
 
-    (
-        self_h.reshape(-1) * torch.arange(self_h.numel(), device=device)
-    ).sum().backward()
+    (self_h.reshape(-1) * torch.arange(self_h.numel(), device=device)).sum().backward()
     (
         torch_h.reshape(-1) * torch.arange(torch_h.numel(), device=device)
     ).sum().backward()
@@ -718,9 +712,7 @@ def test_layernorm_gru_forward(device="cpu"):
     T = torch.randint(low=2, high=100, size=(1,))
 
     x = torch.rand(N, T, input_size, device=device).requires_grad_()
-    states = [
-        torch.rand(N, hidden_size, device=device) for _ in range(num_layers)
-    ]
+    states = [torch.rand(N, hidden_size, device=device) for _ in range(num_layers)]
 
     x_clone = x.detach().clone().requires_grad_()
 
diff --git a/egs/librispeech/ASR/transducer/train.py b/egs/librispeech/ASR/transducer/train.py
index 1dd65eddb..674ea10a6 100755
--- a/egs/librispeech/ASR/transducer/train.py
+++ b/egs/librispeech/ASR/transducer/train.py
@@ -396,9 +396,7 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -520,9 +518,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -659,9 +655,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/librispeech/ASR/transducer_lstm/beam_search.py b/egs/librispeech/ASR/transducer_lstm/beam_search.py
index 3531a9633..5342c3e8c 100644
--- a/egs/librispeech/ASR/transducer_lstm/beam_search.py
+++ b/egs/librispeech/ASR/transducer_lstm/beam_search.py
@@ -38,9 +38,7 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
     blank_id = model.decoder.blank_id
     device = model.device
 
-    sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(
-        1, 1
-    )
+    sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(1, 1)
     decoder_out, (h, c) = model.decoder(sos)
     T = encoder_out.size(1)
     t = 0
@@ -124,9 +122,7 @@ def beam_search(
     max_u = 20000  # terminate after this number of steps
     u = 0
 
-    cache: Dict[
-        str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
-    ] = {}
+    cache: Dict[str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = {}
 
     while t < T and u < max_u:
         # fmt: off
@@ -158,9 +154,9 @@ def beam_search(
             cached_key = "_".join(map(str, y_star.ys))
 
             if cached_key not in cache:
-                decoder_input = torch.tensor(
-                    [y_star.ys[-1]], device=device
-                ).reshape(1, 1)
+                decoder_input = torch.tensor([y_star.ys[-1]], device=device).reshape(
+                    1, 1
+                )
 
                 decoder_out, decoder_state = model.decoder(
                     decoder_input,
diff --git a/egs/librispeech/ASR/transducer_lstm/decode.py b/egs/librispeech/ASR/transducer_lstm/decode.py
index 604235e2a..61b9de504 100755
--- a/egs/librispeech/ASR/transducer_lstm/decode.py
+++ b/egs/librispeech/ASR/transducer_lstm/decode.py
@@ -71,16 +71,19 @@ def get_parser():
         "--epoch",
         type=int,
         default=77,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=55,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -225,9 +228,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []
     batch_size = encoder_out.size(0)
 
@@ -242,9 +243,7 @@ def decode_one_batch(
                 model=model, encoder_out=encoder_out_i, beam=params.beam_size
             )
         else:
-            raise ValueError(
-                f"Unsupported decoding method: {params.decoding_method}"
-            )
+            raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
         hyps.append(sp.decode(hyp).split())
 
     if params.decoding_method == "greedy_search":
@@ -315,9 +314,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -350,8 +347,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/librispeech/ASR/transducer_lstm/encoder.py b/egs/librispeech/ASR/transducer_lstm/encoder.py
index 3dc992dd2..038d80077 100644
--- a/egs/librispeech/ASR/transducer_lstm/encoder.py
+++ b/egs/librispeech/ASR/transducer_lstm/encoder.py
@@ -48,9 +48,7 @@ class LstmEncoder(EncoderInterface):
         if vgg_frontend:
             self.encoder_embed = VggSubsampling(num_features, real_hidden_size)
         else:
-            self.encoder_embed = Conv2dSubsampling(
-                num_features, real_hidden_size
-            )
+            self.encoder_embed = Conv2dSubsampling(num_features, real_hidden_size)
 
         self.rnn = nn.LSTM(
             input_size=hidden_size,
diff --git a/egs/librispeech/ASR/transducer_lstm/train.py b/egs/librispeech/ASR/transducer_lstm/train.py
index cdb801e79..57bda63fd 100755
--- a/egs/librispeech/ASR/transducer_lstm/train.py
+++ b/egs/librispeech/ASR/transducer_lstm/train.py
@@ -400,9 +400,7 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -524,9 +522,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -665,9 +661,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/librispeech/ASR/transducer_stateless/alignment.py b/egs/librispeech/ASR/transducer_stateless/alignment.py
index f143611ea..65f2c58d8 100644
--- a/egs/librispeech/ASR/transducer_stateless/alignment.py
+++ b/egs/librispeech/ASR/transducer_stateless/alignment.py
@@ -193,9 +193,7 @@ def force_alignment(
         decoder_out = model.decoder(decoder_input, need_pad=False)
         # decoder_output is of shape (num_active_items, 1, decoder_output_dim)
 
-        current_encoder_out = current_encoder_out.expand(
-            decoder_out.size(0), 1, -1
-        )
+        current_encoder_out = current_encoder_out.expand(decoder_out.size(0), 1, -1)
 
         logits = model.joiner(
             current_encoder_out,
diff --git a/egs/librispeech/ASR/transducer_stateless/beam_search.py b/egs/librispeech/ASR/transducer_stateless/beam_search.py
index ea985f30d..1d79eef9d 100644
--- a/egs/librispeech/ASR/transducer_stateless/beam_search.py
+++ b/egs/librispeech/ASR/transducer_stateless/beam_search.py
@@ -316,9 +316,9 @@ def greedy_search(
         y = logits.argmax().item()
         if y != blank_id:
             hyp.append(y)
-            decoder_input = torch.tensor(
-                [hyp[-context_size:]], device=device
-            ).reshape(1, context_size)
+            decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape(
+                1, context_size
+            )
 
             decoder_out = model.decoder(decoder_input, need_pad=False)
 
@@ -478,9 +478,7 @@ class HypothesisList(object):
         key = hyp.key
         if key in self:
             old_hyp = self._data[key]  # shallow copy
-            torch.logaddexp(
-                old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
-            )
+            torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob)
         else:
             self._data[key] = hyp
 
@@ -496,9 +494,7 @@ class HypothesisList(object):
           Return the hypothesis that has the largest `log_prob`.
         """
         if length_norm:
-            return max(
-                self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)
-            )
+            return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys))
         else:
             return max(self._data.values(), key=lambda hyp: hyp.log_prob)
 
@@ -786,9 +782,7 @@ def modified_beam_search(
         log_probs_shape = k2.ragged.create_ragged_shape2(
             row_splits=row_splits, cached_tot_size=log_probs.numel()
         )
-        ragged_log_probs = k2.RaggedTensor(
-            shape=log_probs_shape, value=log_probs
-        )
+        ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)
 
         for i in range(batch_size):
             topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
@@ -887,9 +881,7 @@ def _deprecated_modified_beam_search(
         decoder_out = model.decoder(decoder_input, need_pad=False)
         # decoder_output is of shape (num_hyps, 1, decoder_output_dim)
 
-        current_encoder_out = current_encoder_out.expand(
-            decoder_out.size(0), 1, -1
-        )
+        current_encoder_out = current_encoder_out.expand(decoder_out.size(0), 1, -1)
 
         logits = model.joiner(
             current_encoder_out,
@@ -959,9 +951,9 @@ def beam_search(
 
     device = model.device
 
-    decoder_input = torch.tensor(
-        [blank_id] * context_size, device=device
-    ).reshape(1, context_size)
+    decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape(
+        1, context_size
+    )
 
     decoder_out = model.decoder(decoder_input, need_pad=False)
 
diff --git a/egs/librispeech/ASR/transducer_stateless/compute_ali.py b/egs/librispeech/ASR/transducer_stateless/compute_ali.py
index 48769e9d1..89992856d 100755
--- a/egs/librispeech/ASR/transducer_stateless/compute_ali.py
+++ b/egs/librispeech/ASR/transducer_stateless/compute_ali.py
@@ -54,16 +54,19 @@ def get_parser():
         "--epoch",
         type=int,
         default=34,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=20,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -124,8 +127,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
@@ -162,9 +164,7 @@ def compute_alignments(
 
         feature_lens = supervisions["num_frames"].to(device)
 
-        encoder_out, encoder_out_lens = model.encoder(
-            x=feature, x_lens=feature_lens
-        )
+        encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
 
         batch_size = encoder_out.size(0)
 
@@ -204,9 +204,7 @@ def compute_alignments(
         if batch_idx % 2 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
 
     return CutSet.from_cuts(cuts)
 
diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py
index cde52c9fc..d279eae85 100644
--- a/egs/librispeech/ASR/transducer_stateless/conformer.py
+++ b/egs/librispeech/ASR/transducer_stateless/conformer.py
@@ -209,10 +209,7 @@ class Conformer(Transformer):
 
           NOTE: the returned tensors are on the given device.
         """
-        if (
-            len(self._init_state) == 2
-            and self._init_state[0].size(1) == left_context
-        ):
+        if len(self._init_state) == 2 and self._init_state[0].size(1) == left_context:
             # Note: It is OK to share the init state as it is
             # not going to be modified by the model
             return self._init_state
@@ -421,9 +418,7 @@ class ConformerEncoderLayer(nn.Module):
         causal: bool = False,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
-        self.self_attn = RelPositionMultiheadAttention(
-            d_model, nhead, dropout=0.0
-        )
+        self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
 
         self.feed_forward = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
@@ -439,22 +434,16 @@ class ConformerEncoderLayer(nn.Module):
             nn.Linear(dim_feedforward, d_model),
         )
 
-        self.conv_module = ConvolutionModule(
-            d_model, cnn_module_kernel, causal=causal
-        )
+        self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal)
 
-        self.norm_ff_macaron = nn.LayerNorm(
-            d_model
-        )  # for the macaron style FNN module
+        self.norm_ff_macaron = nn.LayerNorm(d_model)  # for the macaron style FNN module
         self.norm_ff = nn.LayerNorm(d_model)  # for the FNN module
         self.norm_mha = nn.LayerNorm(d_model)  # for the MHA module
 
         self.ff_scale = 0.5
 
         self.norm_conv = nn.LayerNorm(d_model)  # for the CNN module
-        self.norm_final = nn.LayerNorm(
-            d_model
-        )  # for the final output of the block
+        self.norm_final = nn.LayerNorm(d_model)  # for the final output of the block
 
         self.dropout = nn.Dropout(dropout)
 
@@ -486,9 +475,7 @@ class ConformerEncoderLayer(nn.Module):
         residual = src
         if self.normalize_before:
             src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(
-            self.feed_forward_macaron(src)
-        )
+        src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)
 
@@ -514,9 +501,7 @@ class ConformerEncoderLayer(nn.Module):
         if self.normalize_before:
             src = self.norm_conv(src)
 
-        src, _ = self.conv_module(
-            src, src_key_padding_mask=src_key_padding_mask
-        )
+        src, _ = self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
         src = residual + self.dropout(src)
 
         if not self.normalize_before:
@@ -581,9 +566,7 @@ class ConformerEncoderLayer(nn.Module):
         residual = src
         if self.normalize_before:
             src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(
-            self.feed_forward_macaron(src)
-        )
+        src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)
 
@@ -625,9 +608,7 @@ class ConformerEncoderLayer(nn.Module):
         if self.normalize_before:
             src = self.norm_conv(src)
 
-        src, conv_cache = self.conv_module(
-            src, states[1], right_context=right_context
-        )
+        src, conv_cache = self.conv_module(src, states[1], right_context=right_context)
         states[1] = conv_cache
         src = residual + self.dropout(src)
 
@@ -779,9 +760,7 @@ class RelPositionalEncoding(torch.nn.Module):
 
     """
 
-    def __init__(
-        self, d_model: int, dropout_rate: float, max_len: int = 5000
-    ) -> None:
+    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
@@ -798,9 +777,7 @@ class RelPositionalEncoding(torch.nn.Module):
             # the length of self.pe is 2 * input_len - 1
             if self.pe.size(1) >= x_size_1 * 2 - 1:
                 # Note: TorchScript doesn't implement operator== for torch.Device
-                if self.pe.dtype != x.dtype or str(self.pe.device) != str(
-                    x.device
-                ):
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
                     self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                 return
         # Suppose `i` means to the position of query vector and `j` means the
@@ -826,9 +803,7 @@ class RelPositionalEncoding(torch.nn.Module):
         pe = torch.cat([pe_positive, pe_negative], dim=1)
         self.pe = pe.to(device=x.device, dtype=x.dtype)
 
-    def forward(
-        self, x: torch.Tensor, left_context: int = 0
-    ) -> Tuple[Tensor, Tensor]:
+    def forward(self, x: torch.Tensor, left_context: int = 0) -> Tuple[Tensor, Tensor]:
         """Add positional encoding.
 
         Args:
@@ -1092,9 +1067,9 @@ class RelPositionMultiheadAttention(nn.Module):
 
         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(
-                query, in_proj_weight, in_proj_bias
-            ).chunk(3, dim=-1)
+            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
+                3, dim=-1
+            )
 
         elif torch.equal(key, value):
             # encoder-decoder attention
@@ -1163,33 +1138,25 @@ class RelPositionMultiheadAttention(nn.Module):
             if attn_mask.dim() == 2:
                 attn_mask = attn_mask.unsqueeze(0)
                 if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
-                    raise RuntimeError(
-                        "The size of the 2D attn_mask is not correct."
-                    )
+                    raise RuntimeError("The size of the 2D attn_mask is not correct.")
             elif attn_mask.dim() == 3:
                 if list(attn_mask.size()) != [
                     bsz * num_heads,
                     query.size(0),
                     key.size(0),
                 ]:
-                    raise RuntimeError(
-                        "The size of the 3D attn_mask is not correct."
-                    )
+                    raise RuntimeError("The size of the 3D attn_mask is not correct.")
             else:
                 raise RuntimeError(
-                    "attn_mask's dimension {} is not supported".format(
-                        attn_mask.dim()
-                    )
+                    "attn_mask's dimension {} is not supported".format(attn_mask.dim())
                 )
             # attn_mask's dim is 3 now.
 
         # convert ByteTensor key_padding_mask to bool
-        if (
-            key_padding_mask is not None
-            and key_padding_mask.dtype == torch.uint8
-        ):
+        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
             warnings.warn(
-                "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
+                "Byte tensor for key_padding_mask is deprecated. Use bool tensor"
+                " instead."
             )
             key_padding_mask = key_padding_mask.to(torch.bool)
 
@@ -1228,14 +1195,10 @@ class RelPositionMultiheadAttention(nn.Module):
         # first compute matrix a and matrix c
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
-        matrix_ac = torch.matmul(
-            q_with_bias_u, k
-        )  # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch, head, time1, time2)
 
         # compute matrix b and matrix d
-        matrix_bd = torch.matmul(
-            q_with_bias_v, p
-        )  # (batch, head, time1, 2*time1-1)
+        matrix_bd = torch.matmul(q_with_bias_v, p)  # (batch, head, time1, 2*time1-1)
 
         matrix_bd = self.rel_shift(matrix_bd, left_context=left_context)
 
@@ -1243,9 +1206,7 @@ class RelPositionMultiheadAttention(nn.Module):
             matrix_ac + matrix_bd
         ) * scaling  # (batch, head, time1, time2)
 
-        attn_output_weights = attn_output_weights.view(
-            bsz * num_heads, tgt_len, -1
-        )
+        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
 
         assert list(attn_output_weights.size()) == [
             bsz * num_heads,
@@ -1290,9 +1251,7 @@ class RelPositionMultiheadAttention(nn.Module):
             attn_output_weights = attn_output_weights.view(
                 bsz, num_heads, tgt_len, src_len
             )
-            attn_output_weights = attn_output_weights.masked_fill(
-                combined_mask, 0.0
-            )
+            attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0)
             attn_output_weights = attn_output_weights.view(
                 bsz * num_heads, tgt_len, src_len
             )
@@ -1304,13 +1263,9 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output = torch.bmm(attn_output_weights, v)
         assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
         attn_output = (
-            attn_output.transpose(0, 1)
-            .contiguous()
-            .view(tgt_len, bsz, embed_dim)
-        )
-        attn_output = nn.functional.linear(
-            attn_output, out_proj_weight, out_proj_bias
+            attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
         )
+        attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
 
         if need_weights:
             # average attention weights over heads
@@ -1418,16 +1373,12 @@ class ConvolutionModule(nn.Module):
                 # manualy padding self.lorder zeros to the left
                 x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
             else:
-                assert (
-                    not self.training
-                ), "Cache should be None in training time"
+                assert not self.training, "Cache should be None in training time"
                 assert cache.size(0) == self.lorder
                 x = torch.cat([cache.permute(1, 2, 0), x], dim=2)
                 if right_context > 0:
                     cache = x.permute(2, 0, 1)[
-                        -(self.lorder + right_context) : (  # noqa
-                            -right_context
-                        ),
+                        -(self.lorder + right_context) : (-right_context),  # noqa
                         ...,
                     ]
                 else:
diff --git a/egs/librispeech/ASR/transducer_stateless/decode.py b/egs/librispeech/ASR/transducer_stateless/decode.py
index 74bba9cad..314f49154 100755
--- a/egs/librispeech/ASR/transducer_stateless/decode.py
+++ b/egs/librispeech/ASR/transducer_stateless/decode.py
@@ -94,16 +94,19 @@ def get_parser():
         "--epoch",
         type=int,
         default=29,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=13,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -171,8 +174,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -230,9 +232,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
 
     hyps = []
 
@@ -248,10 +248,7 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -297,11 +294,7 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            (
-                f"beam_{params.beam}_"
-                f"max_contexts_{params.max_contexts}_"
-                f"max_states_{params.max_states}"
-            ): hyps
+            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -374,9 +367,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -409,8 +400,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -450,9 +440,7 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += (
-            f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
diff --git a/egs/librispeech/ASR/transducer_stateless/decoder.py b/egs/librispeech/ASR/transducer_stateless/decoder.py
index fbc2373a9..a182d91e2 100644
--- a/egs/librispeech/ASR/transducer_stateless/decoder.py
+++ b/egs/librispeech/ASR/transducer_stateless/decoder.py
@@ -87,9 +87,7 @@ class Decoder(nn.Module):
         if self.context_size > 1:
             embedding_out = embedding_out.permute(0, 2, 1)
             if need_pad is True:
-                embedding_out = F.pad(
-                    embedding_out, pad=(self.context_size - 1, 0)
-                )
+                embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
             else:
                 # During inference time, there is no need to do extra padding
                 # as we only need one output
diff --git a/egs/librispeech/ASR/transducer_stateless/export.py b/egs/librispeech/ASR/transducer_stateless/export.py
index 8bd0bdea1..7c10b4348 100755
--- a/egs/librispeech/ASR/transducer_stateless/export.py
+++ b/egs/librispeech/ASR/transducer_stateless/export.py
@@ -68,17 +68,20 @@ def get_parser():
         "--epoch",
         type=int,
         default=20,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=10,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -109,8 +112,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
@@ -244,9 +246,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless/joiner.py b/egs/librispeech/ASR/transducer_stateless/joiner.py
index 93cccbd8c..e1625992d 100644
--- a/egs/librispeech/ASR/transducer_stateless/joiner.py
+++ b/egs/librispeech/ASR/transducer_stateless/joiner.py
@@ -60,13 +60,9 @@ class Joiner(nn.Module):
         encoder_out_len: List[int] = encoder_out_len.tolist()
         decoder_out_len: List[int] = decoder_out_len.tolist()
 
-        encoder_out_list = [
-            encoder_out[i, : encoder_out_len[i], :] for i in range(N)
-        ]
+        encoder_out_list = [encoder_out[i, : encoder_out_len[i], :] for i in range(N)]
 
-        decoder_out_list = [
-            decoder_out[i, : decoder_out_len[i], :] for i in range(N)
-        ]
+        decoder_out_list = [decoder_out[i, : decoder_out_len[i], :] for i in range(N)]
 
         x = [
             e.unsqueeze(1) + d.unsqueeze(0)
diff --git a/egs/librispeech/ASR/transducer_stateless/pretrained.py b/egs/librispeech/ASR/transducer_stateless/pretrained.py
index b64521801..bd7eeff28 100755
--- a/egs/librispeech/ASR/transducer_stateless/pretrained.py
+++ b/egs/librispeech/ASR/transducer_stateless/pretrained.py
@@ -90,9 +90,11 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help="Path to the checkpoint. "
-        "The checkpoint is assumed to be saved by "
-        "icefall.checkpoint.save_checkpoint().",
+        help=(
+            "Path to the checkpoint. "
+            "The checkpoint is assumed to be saved by "
+            "icefall.checkpoint.save_checkpoint()."
+        ),
     )
 
     parser.add_argument(
@@ -117,10 +119,12 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help="The input sound file(s) to transcribe. "
-        "Supported formats are those supported by torchaudio.load(). "
-        "For example, wav and flac are supported. "
-        "The sample rate has to be 16kHz.",
+        help=(
+            "The input sound file(s) to transcribe. "
+            "Supported formats are those supported by torchaudio.load(). "
+            "For example, wav and flac are supported. "
+            "The sample rate has to be 16kHz."
+        ),
     )
 
     parser.add_argument(
@@ -167,8 +171,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -197,10 +200,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -259,9 +261,7 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -334,9 +334,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless/test_compute_ali.py b/egs/librispeech/ASR/transducer_stateless/test_compute_ali.py
index b00fc34f1..9af46846a 100755
--- a/egs/librispeech/ASR/transducer_stateless/test_compute_ali.py
+++ b/egs/librispeech/ASR/transducer_stateless/test_compute_ali.py
@@ -140,16 +140,13 @@ def main():
                 token_alignment[i, : token_alignment_length[i]].tolist(), sp=sp
             )
             word_starting_time = [
-                "{:.2f}".format(i * frame_shift_in_second)
-                for i in word_starting_frames
+                "{:.2f}".format(i * frame_shift_in_second) for i in word_starting_frames
             ]
 
             words = supervisions["text"][i].split()
 
             assert len(word_starting_frames) == len(words)
-            word_starting_time_dict[cuts[i].id] = list(
-                zip(words, word_starting_time)
-            )
+            word_starting_time_dict[cuts[i].id] = list(zip(words, word_starting_time))
 
         # This is a demo script and we exit here after processing
         # one batch.
@@ -160,9 +157,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless/test_conformer.py b/egs/librispeech/ASR/transducer_stateless/test_conformer.py
index d1350c8ab..65b08d425 100755
--- a/egs/librispeech/ASR/transducer_stateless/test_conformer.py
+++ b/egs/librispeech/ASR/transducer_stateless/test_conformer.py
@@ -29,9 +29,7 @@ from conformer import Conformer
 
 def test_conformer():
     feature_dim = 50
-    c = Conformer(
-        num_features=feature_dim, output_dim=256, d_model=128, nhead=4
-    )
+    c = Conformer(num_features=feature_dim, output_dim=256, d_model=128, nhead=4)
     batch_size = 5
     seq_len = 20
     # Just make sure the forward pass runs.
diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py
index ae93f3348..bcb883fa5 100755
--- a/egs/librispeech/ASR/transducer_stateless/train.py
+++ b/egs/librispeech/ASR/transducer_stateless/train.py
@@ -136,8 +136,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -422,9 +421,7 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -545,9 +542,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -664,13 +659,9 @@ def run(rank, world_size, args):
         num_removed = num_in_total - num_left
         removed_percent = num_removed / num_in_total * 100
 
-        logging.info(
-            f"Before removing short and long utterances: {num_in_total}"
-        )
+        logging.info(f"Before removing short and long utterances: {num_in_total}")
         logging.info(f"After removing short and long utterances: {num_left}")
-        logging.info(
-            f"Removed {num_removed} utterances ({removed_percent:.5f}%)"
-        )
+        logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
     except TypeError as e:
         # You can ignore this error as previous versions of Lhotse work fine
         # for the above code. In recent versions of Lhotse, it uses
@@ -698,9 +689,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/librispeech/ASR/transducer_stateless/transformer.py b/egs/librispeech/ASR/transducer_stateless/transformer.py
index e851dcc32..b3ff153c1 100644
--- a/egs/librispeech/ASR/transducer_stateless/transformer.py
+++ b/egs/librispeech/ASR/transducer_stateless/transformer.py
@@ -250,9 +250,7 @@ def _get_activation_fn(activation: str):
     elif activation == "gelu":
         return nn.functional.gelu
 
-    raise RuntimeError(
-        "activation should be relu/gelu, not {}".format(activation)
-    )
+    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
 
 
 class PositionalEncoding(nn.Module):
diff --git a/egs/librispeech/ASR/transducer_stateless2/decode.py b/egs/librispeech/ASR/transducer_stateless2/decode.py
index ac2807241..86ef9e5b6 100755
--- a/egs/librispeech/ASR/transducer_stateless2/decode.py
+++ b/egs/librispeech/ASR/transducer_stateless2/decode.py
@@ -94,16 +94,19 @@ def get_parser():
         "--epoch",
         type=int,
         default=29,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=13,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -171,8 +174,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -230,9 +232,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
 
     hyps = []
 
@@ -248,10 +248,7 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -297,11 +294,7 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            (
-                f"beam_{params.beam}_"
-                f"max_contexts_{params.max_contexts}_"
-                f"max_states_{params.max_states}"
-            ): hyps
+            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -374,9 +367,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -409,8 +400,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -450,9 +440,7 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += (
-            f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
diff --git a/egs/librispeech/ASR/transducer_stateless2/export.py b/egs/librispeech/ASR/transducer_stateless2/export.py
index 57c1a6094..d95eeb1f4 100755
--- a/egs/librispeech/ASR/transducer_stateless2/export.py
+++ b/egs/librispeech/ASR/transducer_stateless2/export.py
@@ -63,17 +63,20 @@ def get_parser():
         "--epoch",
         type=int,
         default=20,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=10,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -104,8 +107,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
@@ -176,9 +178,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless2/pretrained.py b/egs/librispeech/ASR/transducer_stateless2/pretrained.py
index 292f77f03..793931e3b 100755
--- a/egs/librispeech/ASR/transducer_stateless2/pretrained.py
+++ b/egs/librispeech/ASR/transducer_stateless2/pretrained.py
@@ -90,9 +90,11 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help="Path to the checkpoint. "
-        "The checkpoint is assumed to be saved by "
-        "icefall.checkpoint.save_checkpoint().",
+        help=(
+            "Path to the checkpoint. "
+            "The checkpoint is assumed to be saved by "
+            "icefall.checkpoint.save_checkpoint()."
+        ),
     )
 
     parser.add_argument(
@@ -117,10 +119,12 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help="The input sound file(s) to transcribe. "
-        "Supported formats are those supported by torchaudio.load(). "
-        "For example, wav and flac are supported. "
-        "The sample rate has to be 16kHz.",
+        help=(
+            "The input sound file(s) to transcribe. "
+            "Supported formats are those supported by torchaudio.load(). "
+            "For example, wav and flac are supported. "
+            "The sample rate has to be 16kHz."
+        ),
     )
 
     parser.add_argument(
@@ -167,8 +171,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -197,10 +200,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -259,9 +261,7 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -334,9 +334,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless2/train.py b/egs/librispeech/ASR/transducer_stateless2/train.py
index ea15c9040..68e247f23 100755
--- a/egs/librispeech/ASR/transducer_stateless2/train.py
+++ b/egs/librispeech/ASR/transducer_stateless2/train.py
@@ -136,8 +136,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -410,9 +409,7 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -533,9 +530,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -652,13 +647,9 @@ def run(rank, world_size, args):
         num_removed = num_in_total - num_left
         removed_percent = num_removed / num_in_total * 100
 
-        logging.info(
-            f"Before removing short and long utterances: {num_in_total}"
-        )
+        logging.info(f"Before removing short and long utterances: {num_in_total}")
         logging.info(f"After removing short and long utterances: {num_left}")
-        logging.info(
-            f"Removed {num_removed} utterances ({removed_percent:.5f}%)"
-        )
+        logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
     except TypeError as e:
         # You can ignore this error as previous versions of Lhotse work fine
         # for the above code. In recent versions of Lhotse, it uses
@@ -686,9 +677,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py
index d596e05cb..22b6ab911 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py
@@ -95,16 +95,19 @@ def get_parser():
         "--epoch",
         type=int,
         default=29,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=13,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -172,8 +175,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -231,9 +233,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
 
     hyps = []
 
@@ -249,10 +249,7 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -298,11 +295,7 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            (
-                f"beam_{params.beam}_"
-                f"max_contexts_{params.max_contexts}_"
-                f"max_states_{params.max_states}"
-            ): hyps
+            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -375,9 +368,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -410,8 +401,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -451,9 +441,7 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += (
-            f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py
index b6b69d932..fad9a6977 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py
@@ -69,17 +69,20 @@ def get_parser():
         "--epoch",
         type=int,
         default=20,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=10,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -110,8 +113,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
@@ -247,9 +249,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py
index f297fa2b2..efd257b5d 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py
@@ -90,9 +90,11 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help="Path to the checkpoint. "
-        "The checkpoint is assumed to be saved by "
-        "icefall.checkpoint.save_checkpoint().",
+        help=(
+            "Path to the checkpoint. "
+            "The checkpoint is assumed to be saved by "
+            "icefall.checkpoint.save_checkpoint()."
+        ),
     )
 
     parser.add_argument(
@@ -117,10 +119,12 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help="The input sound file(s) to transcribe. "
-        "Supported formats are those supported by torchaudio.load(). "
-        "For example, wav and flac are supported. "
-        "The sample rate has to be 16kHz.",
+        help=(
+            "The input sound file(s) to transcribe. "
+            "Supported formats are those supported by torchaudio.load(). "
+            "For example, wav and flac are supported. "
+            "The sample rate has to be 16kHz."
+        ),
     )
 
     parser.add_argument(
@@ -167,8 +171,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -197,10 +200,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -259,9 +261,7 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -334,9 +334,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/test_asr_datamodule.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/test_asr_datamodule.py
index ef51a7811..1e1188ca6 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/test_asr_datamodule.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/test_asr_datamodule.py
@@ -41,9 +41,7 @@ def test_dataset():
     print(args)
 
     if args.enable_musan:
-        cuts_musan = load_manifest(
-            Path(args.manifest_dir) / "musan_cuts.jsonl.gz"
-        )
+        cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz")
     else:
         cuts_musan = None
 
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py
index 27912738c..88987d91c 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py
@@ -114,8 +114,7 @@ def get_parser():
         "--full-libri",
         type=str2bool,
         default=True,
-        help="When enabled, use 960h LibriSpeech. "
-        "Otherwise, use 100h subset.",
+        help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.",
     )
 
     parser.add_argument(
@@ -170,8 +169,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -469,9 +467,7 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -635,9 +631,7 @@ def train_one_epoch(
                     f"train/current_{prefix}_",
                     params.batch_idx_train,
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
                 libri_tot_loss.write_summary(
                     tb_writer, "train/libri_tot_", params.batch_idx_train
                 )
@@ -784,9 +778,7 @@ def run(rank, world_size, args):
     train_giga_cuts = train_giga_cuts.repeat(times=None)
 
     if args.enable_musan:
-        cuts_musan = load_manifest(
-            Path(args.manifest_dir) / "musan_cuts.jsonl.gz"
-        )
+        cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz")
     else:
         cuts_musan = None
 
@@ -825,9 +817,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/ptb/LM/local/sort_lm_training_data.py b/egs/ptb/LM/local/sort_lm_training_data.py
index af54dbd07..bed3856e4 100755
--- a/egs/ptb/LM/local/sort_lm_training_data.py
+++ b/egs/ptb/LM/local/sort_lm_training_data.py
@@ -135,9 +135,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/ptb/LM/local/test_prepare_lm_training_data.py b/egs/ptb/LM/local/test_prepare_lm_training_data.py
index 877720e7b..3790045fa 100755
--- a/egs/ptb/LM/local/test_prepare_lm_training_data.py
+++ b/egs/ptb/LM/local/test_prepare_lm_training_data.py
@@ -54,9 +54,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/spgispeech/ASR/local/compute_fbank_musan.py b/egs/spgispeech/ASR/local/compute_fbank_musan.py
index 6cb8b65ae..9bea28a41 100755
--- a/egs/spgispeech/ASR/local/compute_fbank_musan.py
+++ b/egs/spgispeech/ASR/local/compute_fbank_musan.py
@@ -87,9 +87,7 @@ def compute_fbank_musan():
     # create chunks of Musan with duration 5 - 10 seconds
     musan_cuts = (
         CutSet.from_manifests(
-            recordings=combine(
-                part["recordings"] for part in manifests.values()
-            )
+            recordings=combine(part["recordings"] for part in manifests.values())
         )
         .cut_into_windows(10.0)
         .filter(lambda c: c.duration > 5)
@@ -108,8 +106,6 @@ def compute_fbank_musan():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     logging.basicConfig(format=formatter, level=logging.INFO)
     compute_fbank_musan()
diff --git a/egs/spgispeech/ASR/local/compute_fbank_spgispeech.py b/egs/spgispeech/ASR/local/compute_fbank_spgispeech.py
index 8116e7605..20ff6d7ab 100755
--- a/egs/spgispeech/ASR/local/compute_fbank_spgispeech.py
+++ b/egs/spgispeech/ASR/local/compute_fbank_spgispeech.py
@@ -103,11 +103,7 @@ def compute_fbank_spgispeech(args):
             chunk_size=chunk_size,
         )
         start = args.start
-        stop = (
-            min(args.stop, args.num_splits)
-            if args.stop > 0
-            else args.num_splits
-        )
+        stop = min(args.stop, args.num_splits) if args.stop > 0 else args.num_splits
         num_digits = len(str(args.num_splits))
         for i in range(start, stop):
             idx = f"{i + 1}".zfill(num_digits)
@@ -129,9 +125,7 @@ def compute_fbank_spgispeech(args):
                 logging.info(f"{partition} already exists - skipping.")
                 continue
             logging.info(f"Processing {partition}")
-            cut_set = load_manifest_lazy(
-                src_dir / f"cuts_{partition}_raw.jsonl.gz"
-            )
+            cut_set = load_manifest_lazy(src_dir / f"cuts_{partition}_raw.jsonl.gz")
             cut_set = cut_set.compute_and_store_features_batch(
                 extractor=extractor,
                 storage_path=output_dir / f"feats_{partition}",
@@ -144,9 +138,7 @@ def compute_fbank_spgispeech(args):
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     logging.basicConfig(format=formatter, level=logging.INFO)
 
     args = get_args()
diff --git a/egs/spgispeech/ASR/local/prepare_splits.py b/egs/spgispeech/ASR/local/prepare_splits.py
index 8c8f1c133..508d4acd8 100755
--- a/egs/spgispeech/ASR/local/prepare_splits.py
+++ b/egs/spgispeech/ASR/local/prepare_splits.py
@@ -55,9 +55,7 @@ def split_spgispeech_train():
 
     # Add speed perturbation
     train_cuts = (
-        train_cuts
-        + train_cuts.perturb_speed(0.9)
-        + train_cuts.perturb_speed(1.1)
+        train_cuts + train_cuts.perturb_speed(0.9) + train_cuts.perturb_speed(1.1)
     )
 
     # Write the manifests to disk.
@@ -73,9 +71,7 @@ def split_spgispeech_train():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     logging.basicConfig(format=formatter, level=logging.INFO)
 
     split_spgispeech_train()
diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
index f165f6e60..83f95d123 100644
--- a/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
+++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
@@ -70,10 +70,12 @@ class SPGISpeechAsrDataModule:
     def add_arguments(cls, parser: argparse.ArgumentParser):
         group = parser.add_argument_group(
             title="ASR data related options",
-            description="These options are used for the preparation of "
-            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-            "effective batch sizes, sampling strategies, applied data "
-            "augmentations, etc.",
+            description=(
+                "These options are used for the preparation of "
+                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+                "effective batch sizes, sampling strategies, applied data "
+                "augmentations, etc."
+            ),
         )
         group.add_argument(
             "--manifest-dir",
@@ -85,67 +87,81 @@ class SPGISpeechAsrDataModule:
             "--enable-musan",
             type=str2bool,
             default=True,
-            help="When enabled, select noise from MUSAN and mix it "
-            "with training dataset. ",
+            help=(
+                "When enabled, select noise from MUSAN and mix it "
+                "with training dataset. "
+            ),
         )
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
             default=False,
-            help="When enabled, utterances (cuts) will be concatenated "
-            "to minimize the amount of padding.",
+            help=(
+                "When enabled, utterances (cuts) will be concatenated "
+                "to minimize the amount of padding."
+            ),
         )
         group.add_argument(
             "--duration-factor",
             type=float,
             default=1.0,
-            help="Determines the maximum duration of a concatenated cut "
-            "relative to the duration of the longest cut in a batch.",
+            help=(
+                "Determines the maximum duration of a concatenated cut "
+                "relative to the duration of the longest cut in a batch."
+            ),
         )
         group.add_argument(
             "--gap",
             type=float,
             default=1.0,
-            help="The amount of padding (in seconds) inserted between "
-            "concatenated cuts. This padding is filled with noise when "
-            "noise augmentation is used.",
+            help=(
+                "The amount of padding (in seconds) inserted between "
+                "concatenated cuts. This padding is filled with noise when "
+                "noise augmentation is used."
+            ),
         )
         group.add_argument(
             "--max-duration",
             type=int,
             default=100.0,
-            help="Maximum pooled recordings duration (seconds) in a "
-            "single batch. You can reduce it if it causes CUDA OOM.",
+            help=(
+                "Maximum pooled recordings duration (seconds) in a "
+                "single batch. You can reduce it if it causes CUDA OOM."
+            ),
         )
         group.add_argument(
             "--num-buckets",
             type=int,
             default=30,
-            help="The number of buckets for the BucketingSampler"
-            "(you might want to increase it for larger datasets).",
+            help=(
+                "The number of buckets for the BucketingSampler"
+                "(you might want to increase it for larger datasets)."
+            ),
         )
         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help="When enabled, use on-the-fly cut mixing and feature "
-            "extraction. Will drop existing precomputed feature manifests "
-            "if available.",
+            help=(
+                "When enabled, use on-the-fly cut mixing and feature "
+                "extraction. Will drop existing precomputed feature manifests "
+                "if available."
+            ),
         )
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help="When enabled (=default), the examples will be "
-            "shuffled for each epoch.",
+            help=(
+                "When enabled (=default), the examples will be shuffled for each epoch."
+            ),
         )
 
         group.add_argument(
             "--num-workers",
             type=int,
             default=8,
-            help="The number of training dataloader workers that "
-            "collect the batches.",
+            help="The number of training dataloader workers that collect the batches.",
         )
         group.add_argument(
             "--enable-spec-aug",
@@ -157,10 +173,12 @@ class SPGISpeechAsrDataModule:
             "--spec-aug-time-warp-factor",
             type=int,
             default=80,
-            help="Used only when --enable-spec-aug is True. "
-            "It specifies the factor for time warping in SpecAugment. "
-            "Larger values mean more warping. "
-            "A value less than 1 means to disable time warp.",
+            help=(
+                "Used only when --enable-spec-aug is True. "
+                "It specifies the factor for time warping in SpecAugment. "
+                "Larger values mean more warping. "
+                "A value less than 1 means to disable time warp."
+            ),
         )
 
     def train_dataloaders(
@@ -176,24 +194,20 @@ class SPGISpeechAsrDataModule:
             The state dict for the training sampler.
         """
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(
-            self.args.manifest_dir / "cuts_musan.jsonl.gz"
-        )
+        cuts_musan = load_manifest(self.args.manifest_dir / "cuts_musan.jsonl.gz")
 
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(
-                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
-                )
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")
 
         if self.args.concatenate_cuts:
             logging.info(
-                f"Using cut concatenation with duration factor "
+                "Using cut concatenation with duration factor "
                 f"{self.args.duration_factor} and gap {self.args.gap}."
             )
             # Cut concatenation should be the first transform in the list,
@@ -208,9 +222,7 @@ class SPGISpeechAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(
-                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
-            )
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
             input_transforms.append(
                 SpecAugment(
                     time_warp_factor=self.args.spec_aug_time_warp_factor,
@@ -227,9 +239,7 @@ class SPGISpeechAsrDataModule:
         if self.args.on_the_fly_feats:
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 input_transforms=input_transforms,
             )
         else:
@@ -282,9 +292,7 @@ class SPGISpeechAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
             )
         else:
             validate = K2SpeechRecognitionDataset(
@@ -328,9 +336,7 @@ class SPGISpeechAsrDataModule:
     @lru_cache()
     def train_cuts(self) -> CutSet:
         logging.info("About to get SPGISpeech train cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "cuts_train_shuf.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_train_shuf.jsonl.gz")
 
     @lru_cache()
     def dev_cuts(self) -> CutSet:
diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py
index c39bd0530..72a7cd1c1 100755
--- a/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py
@@ -76,11 +76,7 @@ from beam_search import (
 )
 from train import get_params, get_transducer_model
 
-from icefall.checkpoint import (
-    average_checkpoints,
-    find_checkpoints,
-    load_checkpoint,
-)
+from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
 from icefall.utils import (
     AttributeDict,
     setup_logger,
@@ -117,9 +113,11 @@ def get_parser():
         "--avg",
         type=int,
         default=10,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch' and '--iter'",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch' and '--iter'"
+        ),
     )
 
     parser.add_argument(
@@ -187,8 +185,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -246,9 +243,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
@@ -263,10 +258,7 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -312,11 +304,7 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            (
-                f"beam_{params.beam}_"
-                f"max_contexts_{params.max_contexts}_"
-                f"max_states_{params.max_states}"
-            ): hyps
+            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -389,9 +377,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -424,9 +410,7 @@ def save_results(
         # we also compute CER for spgispeech dataset.
         results_char = []
         for res in results:
-            results_char.append(
-                (res[0], list("".join(res[1])), list("".join(res[2])))
-            )
+            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
         cers_filename = (
             params.res_dir / f"cers-{test_set_name}-{key}-{params.suffix}.txt"
         )
@@ -438,32 +422,23 @@ def save_results(
 
         logging.info("Wrote detailed error stats to {}".format(wers_filename))
 
-    test_set_wers = {
-        k: v for k, v in sorted(test_set_wers.items(), key=lambda x: x[1])
-    }
-    test_set_cers = {
-        k: v for k, v in sorted(test_set_cers.items(), key=lambda x: x[1])
-    }
+    test_set_wers = {k: v for k, v in sorted(test_set_wers.items(), key=lambda x: x[1])}
+    test_set_cers = {k: v for k, v in sorted(test_set_cers.items(), key=lambda x: x[1])}
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER\tCER", file=f)
         for key in test_set_wers:
             print(
-                "{}\t{}\t{}".format(
-                    key, test_set_wers[key], test_set_cers[key]
-                ),
+                "{}\t{}\t{}".format(key, test_set_wers[key], test_set_cers[key]),
                 file=f,
             )
 
     s = "\nFor {}, WER/CER of different settings are:\n".format(test_set_name)
     note = "\tbest for {}".format(test_set_name)
     for key in test_set_wers:
-        s += "{}\t{}\t{}{}\n".format(
-            key, test_set_wers[key], test_set_cers[key], note
-        )
+        s += "{}\t{}\t{}{}\n".format(key, test_set_wers[key], test_set_cers[key], note)
         note = ""
     logging.info(s)
 
@@ -496,9 +471,7 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += (
-            f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -530,8 +503,7 @@ def main():
         ]
         if len(filenames) == 0:
             raise ValueError(
-                f"No checkpoints found for"
-                f" --iter {params.iter}, --avg {params.avg}"
+                f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
             )
         elif len(filenames) < params.avg:
             raise ValueError(
diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py
index 77faa3c0e..1f18ae2f3 100755
--- a/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py
@@ -50,11 +50,7 @@ import sentencepiece as spm
 import torch
 from train import get_params, get_transducer_model
 
-from icefall.checkpoint import (
-    average_checkpoints,
-    find_checkpoints,
-    load_checkpoint,
-)
+from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
 from icefall.utils import str2bool
 
 
@@ -67,17 +63,20 @@ def get_parser():
         "--epoch",
         type=int,
         default=28,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=15,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -119,8 +118,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
@@ -196,9 +194,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py
index dda29b3e5..cd835a7b4 100755
--- a/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py
@@ -77,9 +77,7 @@ from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[
-    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
-]
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
 
 def get_parser():
@@ -155,8 +153,7 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need to be "
-        "changed.",
+        help="The initial learning rate.  This value should not need to be changed.",
     )
 
     parser.add_argument(
@@ -179,42 +176,45 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
         "--prune-range",
         type=int,
         default=5,
-        help="The prune range for rnnt loss, it means how many symbols(context)"
-        "we are using to compute the loss",
+        help=(
+            "The prune range for rnnt loss, it means how many symbols(context)"
+            "we are using to compute the loss"
+        ),
     )
 
     parser.add_argument(
         "--lm-scale",
         type=float,
         default=0.25,
-        help="The scale to smooth the loss with lm "
-        "(output of prediction network) part.",
+        help=(
+            "The scale to smooth the loss with lm (output of prediction network) part."
+        ),
     )
 
     parser.add_argument(
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
+        help="The scale to smooth the loss with am (output of encoder network)part.",
     )
 
     parser.add_argument(
         "--simple-loss-scale",
         type=float,
         default=0.5,
-        help="To get pruning ranges, we will calculate a simple version"
-        "loss(joiner is just addition), this simple loss also uses for"
-        "training (as a regularization item). We will scale the simple loss"
-        "with this parameter before adding to the final loss.",
+        help=(
+            "To get pruning ranges, we will calculate a simple version"
+            "loss(joiner is just addition), this simple loss also uses for"
+            "training (as a regularization item). We will scale the simple loss"
+            "with this parameter before adding to the final loss."
+        ),
     )
 
     parser.add_argument(
@@ -554,23 +554,16 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0
-            if warmup < 1.0
-            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
-        )
-        loss = (
-            params.simple_loss_scale * simple_loss
-            + pruned_loss_scale * pruned_loss
+            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
         )
+        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
 
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -733,9 +726,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
diff --git a/egs/tal_csasr/ASR/local/compute_fbank_tal_csasr.py b/egs/tal_csasr/ASR/local/compute_fbank_tal_csasr.py
index 4582609ac..602e50d29 100755
--- a/egs/tal_csasr/ASR/local/compute_fbank_tal_csasr.py
+++ b/egs/tal_csasr/ASR/local/compute_fbank_tal_csasr.py
@@ -84,9 +84,7 @@ def compute_fbank_tal_csasr(num_mel_bins: int = 80):
             )
             if "train" in partition:
                 cut_set = (
-                    cut_set
-                    + cut_set.perturb_speed(0.9)
-                    + cut_set.perturb_speed(1.1)
+                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                 )
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
@@ -112,9 +110,7 @@ def get_args():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/tal_csasr/ASR/local/prepare_char.py b/egs/tal_csasr/ASR/local/prepare_char.py
index 2c5b8b8b3..1262baf63 100755
--- a/egs/tal_csasr/ASR/local/prepare_char.py
+++ b/egs/tal_csasr/ASR/local/prepare_char.py
@@ -87,9 +87,7 @@ def lexicon_to_fst_no_sil(
         cur_state = loop_state
 
         word = word2id[word]
-        pieces = [
-            token2id[i] if i in token2id else token2id[""] for i in pieces
-        ]
+        pieces = [token2id[i] if i in token2id else token2id[""] for i in pieces]
 
         for i in range(len(pieces) - 1):
             w = word if i == 0 else eps
diff --git a/egs/tal_csasr/ASR/local/prepare_lang.py b/egs/tal_csasr/ASR/local/prepare_lang.py
index e5ae89ec4..c8cf9b881 100755
--- a/egs/tal_csasr/ASR/local/prepare_lang.py
+++ b/egs/tal_csasr/ASR/local/prepare_lang.py
@@ -317,9 +317,7 @@ def lexicon_to_fst(
 
 def get_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--lang-dir", type=str, help="The lang dir, data/lang_phone"
-    )
+    parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone")
     return parser.parse_args()
 
 
diff --git a/egs/tal_csasr/ASR/local/test_prepare_lang.py b/egs/tal_csasr/ASR/local/test_prepare_lang.py
index d4cf62bba..74e025ad7 100755
--- a/egs/tal_csasr/ASR/local/test_prepare_lang.py
+++ b/egs/tal_csasr/ASR/local/test_prepare_lang.py
@@ -88,9 +88,7 @@ def test_read_lexicon(filename: str):
     fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa.draw("L.pdf", title="L")
 
-    fsa_disambig = lexicon_to_fst(
-        lexicon_disambig, phone2id=phone2id, word2id=word2id
-    )
+    fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id)
     fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
     fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa_disambig.draw("L_disambig.pdf", title="L_disambig")
diff --git a/egs/tal_csasr/ASR/local/text2token.py b/egs/tal_csasr/ASR/local/text2token.py
index 71be2a613..2be639b7a 100755
--- a/egs/tal_csasr/ASR/local/text2token.py
+++ b/egs/tal_csasr/ASR/local/text2token.py
@@ -50,15 +50,15 @@ def get_parser():
         "-n",
         default=1,
         type=int,
-        help="number of characters to split, i.e., \
-                        aabb -> a a b b with -n 1 and aa bb with -n 2",
+        help=(
+            "number of characters to split, i.e.,                         aabb -> a a b"
+            " b with -n 1 and aa bb with -n 2"
+        ),
     )
     parser.add_argument(
         "--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
     )
-    parser.add_argument(
-        "--space", default="", type=str, help="space symbol"
-    )
+    parser.add_argument("--space", default="", type=str, help="space symbol")
     parser.add_argument(
         "--non-lang-syms",
         "-l",
@@ -66,9 +66,7 @@ def get_parser():
         type=str,
         help="list of non-linguistic symobles, e.g.,  etc.",
     )
-    parser.add_argument(
-        "text", type=str, default=False, nargs="?", help="input text"
-    )
+    parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
     parser.add_argument(
         "--trans_type",
         "-t",
@@ -108,8 +106,7 @@ def token2id(
             if token_type == "lazy_pinyin":
                 text = lazy_pinyin(chars_list)
                 sub_ids = [
-                    token_table[txt] if txt in token_table else oov_id
-                    for txt in text
+                    token_table[txt] if txt in token_table else oov_id for txt in text
                 ]
                 ids.append(sub_ids)
             else:  # token_type = "pinyin"
@@ -135,9 +132,7 @@ def main():
     if args.text:
         f = codecs.open(args.text, encoding="utf-8")
     else:
-        f = codecs.getreader("utf-8")(
-            sys.stdin if is_python2 else sys.stdin.buffer
-        )
+        f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
 
     sys.stdout = codecs.getwriter("utf-8")(
         sys.stdout if is_python2 else sys.stdout.buffer
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py
index 49bfb148b..02bd6e2cc 100644
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py
@@ -74,10 +74,12 @@ class TAL_CSASRAsrDataModule:
     def add_arguments(cls, parser: argparse.ArgumentParser):
         group = parser.add_argument_group(
             title="ASR data related options",
-            description="These options are used for the preparation of "
-            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-            "effective batch sizes, sampling strategies, applied data "
-            "augmentations, etc.",
+            description=(
+                "These options are used for the preparation of "
+                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+                "effective batch sizes, sampling strategies, applied data "
+                "augmentations, etc."
+            ),
         )
 
         group.add_argument(
@@ -91,66 +93,81 @@ class TAL_CSASRAsrDataModule:
             "--max-duration",
             type=int,
             default=200.0,
-            help="Maximum pooled recordings duration (seconds) in a "
-            "single batch. You can reduce it if it causes CUDA OOM.",
+            help=(
+                "Maximum pooled recordings duration (seconds) in a "
+                "single batch. You can reduce it if it causes CUDA OOM."
+            ),
         )
 
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
             default=True,
-            help="When enabled, the batches will come from buckets of "
-            "similar duration (saves padding frames).",
+            help=(
+                "When enabled, the batches will come from buckets of "
+                "similar duration (saves padding frames)."
+            ),
         )
 
         group.add_argument(
             "--num-buckets",
             type=int,
             default=300,
-            help="The number of buckets for the DynamicBucketingSampler"
-            "(you might want to increase it for larger datasets).",
+            help=(
+                "The number of buckets for the DynamicBucketingSampler"
+                "(you might want to increase it for larger datasets)."
+            ),
         )
 
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
             default=False,
-            help="When enabled, utterances (cuts) will be concatenated "
-            "to minimize the amount of padding.",
+            help=(
+                "When enabled, utterances (cuts) will be concatenated "
+                "to minimize the amount of padding."
+            ),
         )
 
         group.add_argument(
             "--duration-factor",
             type=float,
             default=1.0,
-            help="Determines the maximum duration of a concatenated cut "
-            "relative to the duration of the longest cut in a batch.",
+            help=(
+                "Determines the maximum duration of a concatenated cut "
+                "relative to the duration of the longest cut in a batch."
+            ),
         )
 
         group.add_argument(
             "--gap",
             type=float,
             default=1.0,
-            help="The amount of padding (in seconds) inserted between "
-            "concatenated cuts. This padding is filled with noise when "
-            "noise augmentation is used.",
+            help=(
+                "The amount of padding (in seconds) inserted between "
+                "concatenated cuts. This padding is filled with noise when "
+                "noise augmentation is used."
+            ),
         )
 
         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help="When enabled, use on-the-fly cut mixing and feature "
-            "extraction. Will drop existing precomputed feature manifests "
-            "if available.",
+            help=(
+                "When enabled, use on-the-fly cut mixing and feature "
+                "extraction. Will drop existing precomputed feature manifests "
+                "if available."
+            ),
         )
 
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help="When enabled (=default), the examples will be "
-            "shuffled for each epoch.",
+            help=(
+                "When enabled (=default), the examples will be shuffled for each epoch."
+            ),
         )
 
         group.add_argument(
@@ -164,17 +181,18 @@ class TAL_CSASRAsrDataModule:
             "--return-cuts",
             type=str2bool,
             default=True,
-            help="When enabled, each batch will have the "
-            "field: batch['supervisions']['cut'] with the cuts that "
-            "were used to construct it.",
+            help=(
+                "When enabled, each batch will have the "
+                "field: batch['supervisions']['cut'] with the cuts that "
+                "were used to construct it."
+            ),
         )
 
         group.add_argument(
             "--num-workers",
             type=int,
             default=2,
-            help="The number of training dataloader workers that "
-            "collect the batches.",
+            help="The number of training dataloader workers that collect the batches.",
         )
 
         group.add_argument(
@@ -188,18 +206,22 @@ class TAL_CSASRAsrDataModule:
             "--spec-aug-time-warp-factor",
             type=int,
             default=80,
-            help="Used only when --enable-spec-aug is True. "
-            "It specifies the factor for time warping in SpecAugment. "
-            "Larger values mean more warping. "
-            "A value less than 1 means to disable time warp.",
+            help=(
+                "Used only when --enable-spec-aug is True. "
+                "It specifies the factor for time warping in SpecAugment. "
+                "Larger values mean more warping. "
+                "A value less than 1 means to disable time warp."
+            ),
         )
 
         group.add_argument(
             "--enable-musan",
             type=str2bool,
             default=True,
-            help="When enabled, select noise from MUSAN and mix it"
-            "with training dataset. ",
+            help=(
+                "When enabled, select noise from MUSAN and mix it"
+                "with training dataset. "
+            ),
         )
 
         group.add_argument(
@@ -222,24 +244,20 @@ class TAL_CSASRAsrDataModule:
             The state dict for the training sampler.
         """
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(
-            self.args.manifest_dir / "musan_cuts.jsonl.gz"
-        )
+        cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
 
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(
-                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
-                )
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")
 
         if self.args.concatenate_cuts:
             logging.info(
-                f"Using cut concatenation with duration factor "
+                "Using cut concatenation with duration factor "
                 f"{self.args.duration_factor} and gap {self.args.gap}."
             )
             # Cut concatenation should be the first transform in the list,
@@ -254,9 +272,7 @@ class TAL_CSASRAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(
-                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
-            )
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -300,9 +316,7 @@ class TAL_CSASRAsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -360,9 +374,7 @@ class TAL_CSASRAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 return_cuts=self.args.return_cuts,
             )
         else:
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py
index b624913f5..b2aef7e86 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py
@@ -124,20 +124,24 @@ def get_parser():
         "--avg",
         type=int,
         default=15,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch' and '--iter'",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch' and '--iter'"
+        ),
     )
 
     parser.add_argument(
         "--use-averaged-model",
         type=str2bool,
         default=False,
-        help="Whether to load averaged model. Currently it only supports "
-        "using --epoch. If True, it would decode with the averaged model "
-        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
-        "Actually only the models with epoch number of `epoch-avg` and "
-        "`epoch` are loaded for averaging. ",
+        help=(
+            "Whether to load averaged model. Currently it only supports "
+            "using --epoch. If True, it would decode with the averaged model "
+            "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+            "Actually only the models with epoch number of `epoch-avg` and "
+            "`epoch` are loaded for averaging. "
+        ),
     )
 
     parser.add_argument(
@@ -208,8 +212,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -268,9 +271,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []
     zh_hyps = []
     en_hyps = []
@@ -303,10 +304,7 @@ def decode_one_batch(
             hyps.append(chars_new)
             zh_hyps.append(zh_text)
             en_hyps.append(en_text)
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -375,9 +373,7 @@ def decode_one_batch(
                     f"Unsupported decoding method: {params.decoding_method}"
                 )
             for i in range(encoder_out.size(0)):
-                hyp = sp.decode(
-                    [lexicon.token_table[idx] for idx in hyp_tokens[i]]
-                )
+                hyp = sp.decode([lexicon.token_table[idx] for idx in hyp_tokens[i]])
                 chars = pattern.split(hyp.upper())
                 chars_new = []
                 zh_text = []
@@ -396,11 +392,11 @@ def decode_one_batch(
         return {"greedy_search": (hyps, zh_hyps, en_hyps)}
     elif params.decoding_method == "fast_beam_search":
         return {
-            (
-                f"beam_{params.beam}_"
-                f"max_contexts_{params.max_contexts}_"
-                f"max_states_{params.max_states}"
-            ): (hyps, zh_hyps, en_hyps)
+            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": (
+                hyps,
+                zh_hyps,
+                en_hyps,
+            )
         }
     else:
         return {f"beam_size_{params.beam_size}": (hyps, zh_hyps, en_hyps)}
@@ -506,9 +502,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results, zh_results, en_results
 
 
@@ -541,8 +535,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -585,9 +578,7 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += (
-            f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -619,13 +610,12 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for"
-                    f" --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg:
                 raise ValueError(
@@ -648,13 +638,12 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg + 1]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for"
-                    f" --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg + 1:
                 raise ValueError(
@@ -682,7 +671,7 @@ def main():
             filename_start = f"{params.exp_dir}/epoch-{start}.pt"
             filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
             logging.info(
-                f"Calculating the averaged model over epoch range from "
+                "Calculating the averaged model over epoch range from "
                 f"{start} (excluded) to {params.epoch}"
             )
             model.to(device)
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py
index 8f900208a..94a4c7a2e 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py
@@ -92,20 +92,24 @@ def get_parser():
         "--avg",
         type=int,
         default=15,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch' and '--iter'",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch' and '--iter'"
+        ),
     )
 
     parser.add_argument(
         "--use-averaged-model",
         type=str2bool,
         default=False,
-        help="Whether to load averaged model. Currently it only supports "
-        "using --epoch. If True, it would decode with the averaged model "
-        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
-        "Actually only the models with epoch number of `epoch-avg` and "
-        "`epoch` are loaded for averaging. ",
+        help=(
+            "Whether to load averaged model. Currently it only supports "
+            "using --epoch. If True, it would decode with the averaged model "
+            "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+            "Actually only the models with epoch number of `epoch-avg` and "
+            "`epoch` are loaded for averaging. "
+        ),
     )
 
     parser.add_argument(
@@ -139,8 +143,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     add_model_arguments(parser)
@@ -176,13 +179,12 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for"
-                    f" --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg:
                 raise ValueError(
@@ -205,13 +207,12 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg + 1]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for"
-                    f" --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg + 1:
                 raise ValueError(
@@ -239,7 +240,7 @@ def main():
             filename_start = f"{params.exp_dir}/epoch-{start}.pt"
             filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
             logging.info(
-                f"Calculating the averaged model over epoch range from "
+                "Calculating the averaged model over epoch range from "
                 f"{start} (excluded) to {params.epoch}"
             )
             model.to(device)
@@ -277,9 +278,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py
index dbe213b24..198242129 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py
@@ -84,9 +84,11 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help="Path to the checkpoint. "
-        "The checkpoint is assumed to be saved by "
-        "icefall.checkpoint.save_checkpoint().",
+        help=(
+            "Path to the checkpoint. "
+            "The checkpoint is assumed to be saved by "
+            "icefall.checkpoint.save_checkpoint()."
+        ),
     )
 
     parser.add_argument(
@@ -115,10 +117,12 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help="The input sound file(s) to transcribe. "
-        "Supported formats are those supported by torchaudio.load(). "
-        "For example, wav and flac are supported. "
-        "The sample rate has to be 16kHz.",
+        help=(
+            "The input sound file(s) to transcribe. "
+            "Supported formats are those supported by torchaudio.load(). "
+            "For example, wav and flac are supported. "
+            "The sample rate has to be 16kHz."
+        ),
     )
 
     parser.add_argument(
@@ -165,8 +169,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -197,10 +200,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -263,15 +265,11 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=features, x_lens=feature_lengths
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
 
     num_waves = encoder_out.size(0)
     hyps = []
@@ -367,9 +365,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py
index ca35eba45..676e8c904 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py
@@ -86,9 +86,7 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[
-    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
-]
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
 
 def add_model_arguments(parser: argparse.ArgumentParser):
@@ -214,8 +212,7 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need "
-        "to be changed.",
+        help="The initial learning rate.  This value should not need to be changed.",
     )
 
     parser.add_argument(
@@ -238,42 +235,45 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
         "--prune-range",
         type=int,
         default=5,
-        help="The prune range for rnnt loss, it means how many symbols(context)"
-        "we are using to compute the loss",
+        help=(
+            "The prune range for rnnt loss, it means how many symbols(context)"
+            "we are using to compute the loss"
+        ),
     )
 
     parser.add_argument(
         "--lm-scale",
         type=float,
         default=0.25,
-        help="The scale to smooth the loss with lm "
-        "(output of prediction network) part.",
+        help=(
+            "The scale to smooth the loss with lm (output of prediction network) part."
+        ),
     )
 
     parser.add_argument(
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
+        help="The scale to smooth the loss with am (output of encoder network)part.",
     )
 
     parser.add_argument(
         "--simple-loss-scale",
         type=float,
         default=0.5,
-        help="To get pruning ranges, we will calculate a simple version"
-        "loss(joiner is just addition), this simple loss also uses for"
-        "training (as a regularization item). We will scale the simple loss"
-        "with this parameter before adding to the final loss.",
+        help=(
+            "To get pruning ranges, we will calculate a simple version"
+            "loss(joiner is just addition), this simple loss also uses for"
+            "training (as a regularization item). We will scale the simple loss"
+            "with this parameter before adding to the final loss."
+        ),
     )
 
     parser.add_argument(
@@ -600,11 +600,7 @@ def compute_loss(
      warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
-    device = (
-        model.device
-        if isinstance(model, DDP)
-        else next(model.parameters()).device
-    )
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
@@ -634,22 +630,15 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0
-            if warmup < 1.0
-            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
-        )
-        loss = (
-            params.simple_loss_scale * simple_loss
-            + pruned_loss_scale * pruned_loss
+            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
         )
+        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -828,9 +817,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -944,7 +931,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
+            2**22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
diff --git a/egs/tedlium3/ASR/local/compute_fbank_tedlium.py b/egs/tedlium3/ASR/local/compute_fbank_tedlium.py
index 327962a79..733ebf235 100755
--- a/egs/tedlium3/ASR/local/compute_fbank_tedlium.py
+++ b/egs/tedlium3/ASR/local/compute_fbank_tedlium.py
@@ -83,9 +83,7 @@ def compute_fbank_tedlium():
             )
             if "train" in partition:
                 cut_set = (
-                    cut_set
-                    + cut_set.perturb_speed(0.9)
-                    + cut_set.perturb_speed(1.1)
+                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                 )
             cur_num_jobs = num_jobs if ex is None else 80
             cur_num_jobs = min(cur_num_jobs, len(cut_set))
@@ -104,9 +102,7 @@ def compute_fbank_tedlium():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py b/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py
index 49544ccb3..9dbcc9d9e 100644
--- a/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py
+++ b/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py
@@ -25,9 +25,7 @@ import sentencepiece as spm
 
 def get_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--texts", type=List[str], help="The input transcripts list."
-    )
+    parser.add_argument("--texts", type=List[str], help="The input transcripts list.")
     parser.add_argument(
         "--bpe-model",
         type=str,
diff --git a/egs/tedlium3/ASR/local/prepare_lexicon.py b/egs/tedlium3/ASR/local/prepare_lexicon.py
index 35dd332e8..b9160b6d4 100755
--- a/egs/tedlium3/ASR/local/prepare_lexicon.py
+++ b/egs/tedlium3/ASR/local/prepare_lexicon.py
@@ -23,11 +23,12 @@ consisting of supervisions_train.json and does the following:
 1. Generate lexicon_words.txt.
 
 """
-import lhotse
 import argparse
 import logging
 from pathlib import Path
 
+import lhotse
+
 
 def get_args():
     parser = argparse.ArgumentParser()
@@ -61,9 +62,7 @@ def prepare_lexicon(manifests_dir: str, lang_dir: str):
     words = set()
 
     lexicon = Path(lang_dir) / "lexicon_words.txt"
-    sups = lhotse.load_manifest(
-        f"{manifests_dir}/tedlium_supervisions_train.jsonl.gz"
-    )
+    sups = lhotse.load_manifest(f"{manifests_dir}/tedlium_supervisions_train.jsonl.gz")
     for s in sups:
         # list the words units and filter the empty item
         words_list = list(filter(None, s.text.split()))
@@ -88,9 +87,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/tedlium3/ASR/local/prepare_transcripts.py b/egs/tedlium3/ASR/local/prepare_transcripts.py
index 1039ac5bb..7ea4e89a4 100755
--- a/egs/tedlium3/ASR/local/prepare_transcripts.py
+++ b/egs/tedlium3/ASR/local/prepare_transcripts.py
@@ -23,11 +23,12 @@ consisting of supervisions_train.json and does the following:
 1. Generate train.text.
 
 """
-import lhotse
 import argparse
 import logging
 from pathlib import Path
 
+import lhotse
+
 
 def get_args():
     parser = argparse.ArgumentParser()
@@ -61,9 +62,7 @@ def prepare_transcripts(manifests_dir: str, lang_dir: str):
     texts = []
 
     train_text = Path(lang_dir) / "train.text"
-    sups = lhotse.load_manifest(
-        f"{manifests_dir}/tedlium_supervisions_train.jsonl.gz"
-    )
+    sups = lhotse.load_manifest(f"{manifests_dir}/tedlium_supervisions_train.jsonl.gz")
     for s in sups:
         texts.append(s.text)
 
@@ -83,9 +82,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py b/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py
index 2b294e601..6bae33e65 100755
--- a/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py
@@ -94,17 +94,20 @@ def get_parser():
         "--epoch",
         type=int,
         default=29,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=13,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -172,8 +175,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -231,9 +233,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
@@ -248,10 +248,7 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -297,11 +294,7 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            (
-                f"beam_{params.beam}_"
-                f"max_contexts_{params.max_contexts}_"
-                f"max_states_{params.max_states}"
-            ): hyps
+            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -374,9 +367,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -409,8 +400,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/export.py b/egs/tedlium3/ASR/pruned_transducer_stateless/export.py
index a1c3bcea3..244740932 100644
--- a/egs/tedlium3/ASR/pruned_transducer_stateless/export.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless/export.py
@@ -65,17 +65,20 @@ def get_parser():
         "--epoch",
         type=int,
         default=30,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=13,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -106,8 +109,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
@@ -179,9 +181,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py b/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py
index 8480ac029..00545f107 100644
--- a/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py
@@ -93,9 +93,11 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help="Path to the checkpoint. "
-        "The checkpoint is assumed to be saved by "
-        "icefall.checkpoint.save_checkpoint().",
+        help=(
+            "Path to the checkpoint. "
+            "The checkpoint is assumed to be saved by "
+            "icefall.checkpoint.save_checkpoint()."
+        ),
     )
 
     parser.add_argument(
@@ -122,10 +124,12 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help="The input sound file(s) to transcribe. "
-        "Supported formats are those supported by torchaudio.load(). "
-        "For example, wav and flac are supported. "
-        "The sample rate has to be 16kHz.",
+        help=(
+            "The input sound file(s) to transcribe. "
+            "Supported formats are those supported by torchaudio.load(). "
+            "For example, wav and flac are supported. "
+            "The sample rate has to be 16kHz."
+        ),
     )
 
     parser.add_argument(
@@ -165,8 +169,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -203,10 +206,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -271,9 +273,7 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -298,10 +298,7 @@ def main():
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -353,9 +350,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/train.py b/egs/tedlium3/ASR/pruned_transducer_stateless/train.py
index 8d5cdf683..70c5e290f 100755
--- a/egs/tedlium3/ASR/pruned_transducer_stateless/train.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless/train.py
@@ -133,42 +133,45 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
         "--prune-range",
         type=int,
         default=5,
-        help="The prune range for rnnt loss, it means how many symbols(context)"
-        "we are using to compute the loss",
+        help=(
+            "The prune range for rnnt loss, it means how many symbols(context)"
+            "we are using to compute the loss"
+        ),
     )
 
     parser.add_argument(
         "--lm-scale",
         type=float,
         default=0.25,
-        help="The scale to smooth the loss with lm "
-        "(output of prediction network) part.",
+        help=(
+            "The scale to smooth the loss with lm (output of prediction network) part."
+        ),
     )
 
     parser.add_argument(
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
+        help="The scale to smooth the loss with am (output of encoder network)part.",
     )
 
     parser.add_argument(
         "--simple-loss-scale",
         type=float,
         default=0.5,
-        help="To get pruning ranges, we will calculate a simple version"
-        "loss(joiner is just addition), this simple loss also uses for"
-        "training (as a regularization item). We will scale the simple loss"
-        "with this parameter before adding to the final loss.",
+        help=(
+            "To get pruning ranges, we will calculate a simple version"
+            "loss(joiner is just addition), this simple loss also uses for"
+            "training (as a regularization item). We will scale the simple loss"
+            "with this parameter before adding to the final loss."
+        ),
     )
 
     parser.add_argument(
@@ -556,9 +559,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -678,9 +679,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py b/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py
index 51de46ae8..86ac2fea3 100644
--- a/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py
+++ b/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py
@@ -63,10 +63,12 @@ class TedLiumAsrDataModule:
     def add_arguments(cls, parser: argparse.ArgumentParser):
         group = parser.add_argument_group(
             title="ASR data related options",
-            description="These options are used for the preparation of "
-            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-            "effective batch sizes, sampling strategies, applied data "
-            "augmentations, etc.",
+            description=(
+                "These options are used for the preparation of "
+                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+                "effective batch sizes, sampling strategies, applied data "
+                "augmentations, etc."
+            ),
         )
         group.add_argument(
             "--manifest-dir",
@@ -78,75 +80,91 @@ class TedLiumAsrDataModule:
             "--max-duration",
             type=int,
             default=200.0,
-            help="Maximum pooled recordings duration (seconds) in a "
-            "single batch. You can reduce it if it causes CUDA OOM.",
+            help=(
+                "Maximum pooled recordings duration (seconds) in a "
+                "single batch. You can reduce it if it causes CUDA OOM."
+            ),
         )
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
             default=True,
-            help="When enabled, the batches will come from buckets of "
-            "similar duration (saves padding frames).",
+            help=(
+                "When enabled, the batches will come from buckets of "
+                "similar duration (saves padding frames)."
+            ),
         )
         group.add_argument(
             "--num-buckets",
             type=int,
             default=30,
-            help="The number of buckets for the DynamicBucketingSampler"
-            "(you might want to increase it for larger datasets).",
+            help=(
+                "The number of buckets for the DynamicBucketingSampler"
+                "(you might want to increase it for larger datasets)."
+            ),
         )
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
             default=False,
-            help="When enabled, utterances (cuts) will be concatenated "
-            "to minimize the amount of padding.",
+            help=(
+                "When enabled, utterances (cuts) will be concatenated "
+                "to minimize the amount of padding."
+            ),
         )
         group.add_argument(
             "--duration-factor",
             type=float,
             default=1.0,
-            help="Determines the maximum duration of a concatenated cut "
-            "relative to the duration of the longest cut in a batch.",
+            help=(
+                "Determines the maximum duration of a concatenated cut "
+                "relative to the duration of the longest cut in a batch."
+            ),
         )
         group.add_argument(
             "--gap",
             type=float,
             default=1.0,
-            help="The amount of padding (in seconds) inserted between "
-            "concatenated cuts. This padding is filled with noise when "
-            "noise augmentation is used.",
+            help=(
+                "The amount of padding (in seconds) inserted between "
+                "concatenated cuts. This padding is filled with noise when "
+                "noise augmentation is used."
+            ),
         )
         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help="When enabled, use on-the-fly cut mixing and feature "
-            "extraction. Will drop existing precomputed feature manifests "
-            "if available.",
+            help=(
+                "When enabled, use on-the-fly cut mixing and feature "
+                "extraction. Will drop existing precomputed feature manifests "
+                "if available."
+            ),
         )
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help="When enabled (=default), the examples will be "
-            "shuffled for each epoch.",
+            help=(
+                "When enabled (=default), the examples will be shuffled for each epoch."
+            ),
         )
         group.add_argument(
             "--return-cuts",
             type=str2bool,
             default=True,
-            help="When enabled, each batch will have the "
-            "field: batch['supervisions']['cut'] with the cuts that "
-            "were used to construct it.",
+            help=(
+                "When enabled, each batch will have the "
+                "field: batch['supervisions']['cut'] with the cuts that "
+                "were used to construct it."
+            ),
         )
 
         group.add_argument(
             "--num-workers",
             type=int,
             default=2,
-            help="The number of training dataloader workers that "
-            "collect the batches.",
+            help="The number of training dataloader workers that collect the batches.",
         )
 
         group.add_argument(
@@ -160,18 +178,22 @@ class TedLiumAsrDataModule:
             "--spec-aug-time-warp-factor",
             type=int,
             default=80,
-            help="Used only when --enable-spec-aug is True. "
-            "It specifies the factor for time warping in SpecAugment. "
-            "Larger values mean more warping. "
-            "A value less than 1 means to disable time warp.",
+            help=(
+                "Used only when --enable-spec-aug is True. "
+                "It specifies the factor for time warping in SpecAugment. "
+                "Larger values mean more warping. "
+                "A value less than 1 means to disable time warp."
+            ),
         )
 
         group.add_argument(
             "--enable-musan",
             type=str2bool,
             default=True,
-            help="When enabled, select noise from MUSAN and mix it"
-            "with training dataset. ",
+            help=(
+                "When enabled, select noise from MUSAN and mix it"
+                "with training dataset. "
+            ),
         )
 
     def train_dataloaders(self, cuts_train: CutSet) -> DataLoader:
@@ -179,20 +201,16 @@ class TedLiumAsrDataModule:
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
-            cuts_musan = load_manifest(
-                self.args.manifest_dir / "musan_cuts.jsonl.gz"
-            )
+            cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
             transforms.append(
-                CutMix(
-                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
-                )
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")
 
         if self.args.concatenate_cuts:
             logging.info(
-                f"Using cut concatenation with duration factor "
+                "Using cut concatenation with duration factor "
                 f"{self.args.duration_factor} and gap {self.args.gap}."
             )
             # Cut concatenation should be the first transform in the list,
@@ -207,9 +225,7 @@ class TedLiumAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(
-                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
-            )
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -253,9 +269,7 @@ class TedLiumAsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -300,9 +314,7 @@ class TedLiumAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 return_cuts=self.args.return_cuts,
             )
         else:
@@ -358,13 +370,9 @@ class TedLiumAsrDataModule:
     @lru_cache()
     def dev_cuts(self) -> CutSet:
         logging.info("About to get dev cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "tedlium_cuts_dev.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "tedlium_cuts_dev.jsonl.gz")
 
     @lru_cache()
     def test_cuts(self) -> CutSet:
         logging.info("About to get test cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "tedlium_cuts_test.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "tedlium_cuts_test.jsonl.gz")
diff --git a/egs/tedlium3/ASR/transducer_stateless/beam_search.py b/egs/tedlium3/ASR/transducer_stateless/beam_search.py
index 77caf6460..1f99edaf3 100644
--- a/egs/tedlium3/ASR/transducer_stateless/beam_search.py
+++ b/egs/tedlium3/ASR/transducer_stateless/beam_search.py
@@ -87,9 +87,9 @@ def greedy_search(
         y = logits.argmax().item()
         if y != blank_id and y != unk_id:
             hyp.append(y)
-            decoder_input = torch.tensor(
-                [hyp[-context_size:]], device=device
-            ).reshape(1, context_size)
+            decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape(
+                1, context_size
+            )
 
             decoder_out = model.decoder(decoder_input, need_pad=False)
 
@@ -148,9 +148,7 @@ class HypothesisList(object):
         key = hyp.key
         if key in self:
             old_hyp = self._data[key]  # shallow copy
-            torch.logaddexp(
-                old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
-            )
+            torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob)
         else:
             self._data[key] = hyp
 
@@ -166,9 +164,7 @@ class HypothesisList(object):
           Return the hypothesis that has the largest `log_prob`.
         """
         if length_norm:
-            return max(
-                self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)
-            )
+            return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys))
         else:
             return max(self._data.values(), key=lambda hyp: hyp.log_prob)
 
@@ -344,9 +340,9 @@ def modified_beam_search(
 
     device = model.device
 
-    decoder_input = torch.tensor(
-        [blank_id] * context_size, device=device
-    ).reshape(1, context_size)
+    decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape(
+        1, context_size
+    )
 
     decoder_out = model.decoder(decoder_input, need_pad=False)
 
@@ -383,9 +379,7 @@ def modified_beam_search(
         decoder_out = model.decoder(decoder_input, need_pad=False)
         # decoder_output is of shape (num_hyps, 1, decoder_output_dim)
 
-        current_encoder_out = current_encoder_out.expand(
-            decoder_out.size(0), 1, -1
-        )
+        current_encoder_out = current_encoder_out.expand(decoder_out.size(0), 1, -1)
 
         logits = model.joiner(
             current_encoder_out,
@@ -454,9 +448,9 @@ def beam_search(
 
     device = model.device
 
-    decoder_input = torch.tensor(
-        [blank_id] * context_size, device=device
-    ).reshape(1, context_size)
+    decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape(
+        1, context_size
+    )
 
     decoder_out = model.decoder(decoder_input, need_pad=False)
 
diff --git a/egs/tedlium3/ASR/transducer_stateless/decode.py b/egs/tedlium3/ASR/transducer_stateless/decode.py
index d3e9e55e7..12d0e2652 100755
--- a/egs/tedlium3/ASR/transducer_stateless/decode.py
+++ b/egs/tedlium3/ASR/transducer_stateless/decode.py
@@ -81,16 +81,19 @@ def get_parser():
         "--epoch",
         type=int,
         default=29,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=13,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -130,8 +133,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -250,9 +252,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []
     batch_size = encoder_out.size(0)
 
@@ -275,9 +275,7 @@ def decode_one_batch(
                 model=model, encoder_out=encoder_out_i, beam=params.beam_size
             )
         else:
-            raise ValueError(
-                f"Unsupported decoding method: {params.decoding_method}"
-            )
+            raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
         hyps.append(sp.decode(hyp).split())
 
     if params.decoding_method == "greedy_search":
@@ -348,9 +346,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -383,8 +379,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/tedlium3/ASR/transducer_stateless/decoder.py b/egs/tedlium3/ASR/transducer_stateless/decoder.py
index f0c6f32b6..f9a3814c6 100644
--- a/egs/tedlium3/ASR/transducer_stateless/decoder.py
+++ b/egs/tedlium3/ASR/transducer_stateless/decoder.py
@@ -90,9 +90,7 @@ class Decoder(nn.Module):
         if self.context_size > 1:
             embedding_out = embedding_out.permute(0, 2, 1)
             if need_pad is True:
-                embedding_out = F.pad(
-                    embedding_out, pad=(self.context_size - 1, 0)
-                )
+                embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
             else:
                 # During inference time, there is no need to do extra padding
                 # as we only need one output
diff --git a/egs/tedlium3/ASR/transducer_stateless/export.py b/egs/tedlium3/ASR/transducer_stateless/export.py
index c32b1d002..0b2ae970b 100644
--- a/egs/tedlium3/ASR/transducer_stateless/export.py
+++ b/egs/tedlium3/ASR/transducer_stateless/export.py
@@ -69,17 +69,20 @@ def get_parser():
         "--epoch",
         type=int,
         default=20,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=10,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -110,8 +113,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
@@ -247,9 +249,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/tedlium3/ASR/transducer_stateless/pretrained.py b/egs/tedlium3/ASR/transducer_stateless/pretrained.py
index c0e3bb844..912d65497 100644
--- a/egs/tedlium3/ASR/transducer_stateless/pretrained.py
+++ b/egs/tedlium3/ASR/transducer_stateless/pretrained.py
@@ -82,9 +82,11 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help="Path to the checkpoint. "
-        "The checkpoint is assumed to be saved by "
-        "icefall.checkpoint.save_checkpoint().",
+        help=(
+            "Path to the checkpoint. "
+            "The checkpoint is assumed to be saved by "
+            "icefall.checkpoint.save_checkpoint()."
+        ),
     )
 
     parser.add_argument(
@@ -110,10 +112,12 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help="The input sound file(s) to transcribe. "
-        "Supported formats are those supported by torchaudio.load(). "
-        "For example, wav and flac are supported. "
-        "The sample rate has to be 16kHz.",
+        help=(
+            "The input sound file(s) to transcribe. "
+            "Supported formats are those supported by torchaudio.load(). "
+            "For example, wav and flac are supported. "
+            "The sample rate has to be 16kHz."
+        ),
     )
 
     parser.add_argument(
@@ -127,8 +131,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -222,10 +225,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -285,9 +287,7 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -335,9 +335,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/tedlium3/ASR/transducer_stateless/train.py b/egs/tedlium3/ASR/transducer_stateless/train.py
index 09cbf4a00..6fed32e81 100755
--- a/egs/tedlium3/ASR/transducer_stateless/train.py
+++ b/egs/tedlium3/ASR/transducer_stateless/train.py
@@ -133,8 +133,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -525,9 +524,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -647,9 +644,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/timit/ASR/RESULTS.md b/egs/timit/ASR/RESULTS.md
index b78c16b88..d8ceb82b6 100644
--- a/egs/timit/ASR/RESULTS.md
+++ b/egs/timit/ASR/RESULTS.md
@@ -71,4 +71,4 @@ python tdnn_ligru_ctc/decode.py --epoch 25 \
                                --avg 17 \
                                --max-duration 20 \
                                --lang-dir data/lang_phone
-```
\ No newline at end of file
+```
diff --git a/egs/timit/ASR/local/compile_hlg.py b/egs/timit/ASR/local/compile_hlg.py
index 58cab4cf2..32c248d7e 100644
--- a/egs/timit/ASR/local/compile_hlg.py
+++ b/egs/timit/ASR/local/compile_hlg.py
@@ -146,9 +146,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/timit/ASR/local/compute_fbank_timit.py b/egs/timit/ASR/local/compute_fbank_timit.py
index f25786a0c..ecdf10ba9 100644
--- a/egs/timit/ASR/local/compute_fbank_timit.py
+++ b/egs/timit/ASR/local/compute_fbank_timit.py
@@ -85,9 +85,7 @@ def compute_fbank_timit():
             )
             if partition == "TRAIN":
                 cut_set = (
-                    cut_set
-                    + cut_set.perturb_speed(0.9)
-                    + cut_set.perturb_speed(1.1)
+                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                 )
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
@@ -101,9 +99,7 @@ def compute_fbank_timit():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/timit/ASR/local/prepare_lexicon.py b/egs/timit/ASR/local/prepare_lexicon.py
index 04023a9ab..0cf0f0deb 100644
--- a/egs/timit/ASR/local/prepare_lexicon.py
+++ b/egs/timit/ASR/local/prepare_lexicon.py
@@ -62,9 +62,7 @@ def prepare_lexicon(manifests_dir: str, lang_dir: str):
 
     phones = set()
 
-    supervisions_train = (
-        Path(manifests_dir) / "timit_supervisions_TRAIN.jsonl.gz"
-    )
+    supervisions_train = Path(manifests_dir) / "timit_supervisions_TRAIN.jsonl.gz"
     lexicon = Path(lang_dir) / "lexicon.txt"
 
     logging.info(f"Loading {supervisions_train}!")
@@ -97,9 +95,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/timit/ASR/prepare.sh b/egs/timit/ASR/prepare.sh
index ae1b96a68..d11cd3a05 100644
--- a/egs/timit/ASR/prepare.sh
+++ b/egs/timit/ASR/prepare.sh
@@ -20,9 +20,9 @@ stop_stage=100
 #  - $dl_dir/lm
 #      This directory contains the language model(LM) downloaded from
 #      https://huggingface.co/luomingshuang/timit_lm, and the LM is based
-#	     on 39 phones. About how to get these LM files, you can know it 
+#	     on 39 phones. About how to get these LM files, you can know it
 #      from https://github.com/luomingshuang/Train_LM_with_kaldilm.
-#	
+#
 #	    - lm_3_gram.arpa
 #     - lm_4_gram.arpa
 #
diff --git a/egs/timit/ASR/tdnn_ligru_ctc/decode.py b/egs/timit/ASR/tdnn_ligru_ctc/decode.py
index 4f2aa2340..5a59a13ce 100644
--- a/egs/timit/ASR/tdnn_ligru_ctc/decode.py
+++ b/egs/timit/ASR/tdnn_ligru_ctc/decode.py
@@ -57,16 +57,19 @@ def get_parser():
         "--epoch",
         type=int,
         default=19,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=5,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
     parser.add_argument(
         "--method",
@@ -336,9 +339,7 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -400,9 +401,7 @@ def main():
 
     logging.info(f"device: {device}")
 
-    HLG = k2.Fsa.from_dict(
-        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
-    )
+    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
 
@@ -462,9 +461,7 @@ def main():
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save(
-            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
-        )
+        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
         return
 
     model.to(device)
@@ -485,9 +482,7 @@ def main():
         G=G,
     )
 
-    save_results(
-        params=params, test_set_name=test_set, results_dict=results_dict
-    )
+    save_results(params=params, test_set_name=test_set, results_dict=results_dict)
 
     logging.info("Done!")
 
diff --git a/egs/timit/ASR/tdnn_ligru_ctc/model.py b/egs/timit/ASR/tdnn_ligru_ctc/model.py
index 4d2199ace..9a594a969 100644
--- a/egs/timit/ASR/tdnn_ligru_ctc/model.py
+++ b/egs/timit/ASR/tdnn_ligru_ctc/model.py
@@ -16,11 +16,11 @@
 # limitations under the License.
 
 
+from typing import Optional
+
 import torch
 import torch.nn as nn
-
 from torch import Tensor
-from typing import Optional
 
 
 class TdnnLiGRU(nn.Module):
@@ -261,9 +261,7 @@ class LiGRU(torch.nn.Module):
         h = []
         if hx is not None:
             if self.bidirectional:
-                hx = hx.reshape(
-                    self.num_layers, self.batch_size * 2, self.hidden_size
-                )
+                hx = hx.reshape(self.num_layers, self.batch_size * 2, self.hidden_size)
         # Processing the different layers
         for i, ligru_lay in enumerate(self.rnn):
             if hx is not None:
@@ -445,9 +443,7 @@ class LiGRU_Layer(torch.nn.Module):
             if self.drop_mask_cnt + self.batch_size > self.N_drop_masks:
                 self.drop_mask_cnt = 0
                 self.drop_masks = self.drop(
-                    torch.ones(
-                        self.N_drop_masks, self.hidden_size, device=w.device
-                    )
+                    torch.ones(self.N_drop_masks, self.hidden_size, device=w.device)
                 ).data
 
             # Sampling the mask
diff --git a/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py b/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py
index 7da285944..da669bc39 100644
--- a/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py
+++ b/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py
@@ -29,11 +29,7 @@ import torchaudio
 from model import TdnnLiGRU
 from torch.nn.utils.rnn import pad_sequence
 
-from icefall.decode import (
-    get_lattice,
-    one_best_decoding,
-    rescore_with_whole_lattice,
-)
+from icefall.decode import get_lattice, one_best_decoding, rescore_with_whole_lattice
 from icefall.utils import AttributeDict, get_texts
 
 
@@ -46,9 +42,11 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help="Path to the checkpoint. "
-        "The checkpoint is assumed to be saved by "
-        "icefall.checkpoint.save_checkpoint().",
+        help=(
+            "Path to the checkpoint. "
+            "The checkpoint is assumed to be saved by "
+            "icefall.checkpoint.save_checkpoint()."
+        ),
     )
 
     parser.add_argument(
@@ -58,9 +56,7 @@ def get_parser():
         help="Path to words.txt",
     )
 
-    parser.add_argument(
-        "--HLG", type=str, required=True, help="Path to HLG.pt."
-    )
+    parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.")
 
     parser.add_argument(
         "--method",
@@ -103,10 +99,12 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help="The input sound file(s) to transcribe. "
-        "Supported formats are those supported by torchaudio.load(). "
-        "For example, wav and flac are supported. "
-        "The sample rate has to be 16kHz.",
+        help=(
+            "The input sound file(s) to transcribe. "
+            "Supported formats are those supported by torchaudio.load(). "
+            "For example, wav and flac are supported. "
+            "The sample rate has to be 16kHz."
+        ),
     )
 
     return parser
@@ -144,10 +142,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -215,9 +212,7 @@ def main():
     logging.info("Decoding started")
     features = fbank(waves)
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
     features = features.permute(0, 2, 1)  # now features is (N, C, T)
 
     with torch.no_grad():
@@ -269,9 +264,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/timit/ASR/tdnn_ligru_ctc/train.py b/egs/timit/ASR/tdnn_ligru_ctc/train.py
index 452c2a7cb..48b7feda0 100644
--- a/egs/timit/ASR/tdnn_ligru_ctc/train.py
+++ b/egs/timit/ASR/tdnn_ligru_ctc/train.py
@@ -449,9 +449,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             valid_info = compute_validation_loss(
diff --git a/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py
index 1554e987f..d957c22e1 100644
--- a/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -63,10 +63,12 @@ class TimitAsrDataModule(DataModule):
         super().add_arguments(parser)
         group = parser.add_argument_group(
             title="ASR data related options",
-            description="These options are used for the preparation of "
-            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-            "effective batch sizes, sampling strategies, applied data "
-            "augmentations, etc.",
+            description=(
+                "These options are used for the preparation of "
+                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+                "effective batch sizes, sampling strategies, applied data "
+                "augmentations, etc."
+            ),
         )
         group.add_argument(
             "--feature-dir",
@@ -78,75 +80,91 @@ class TimitAsrDataModule(DataModule):
             "--max-duration",
             type=int,
             default=200.0,
-            help="Maximum pooled recordings duration (seconds) in a "
-            "single batch. You can reduce it if it causes CUDA OOM.",
+            help=(
+                "Maximum pooled recordings duration (seconds) in a "
+                "single batch. You can reduce it if it causes CUDA OOM."
+            ),
         )
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
             default=True,
-            help="When enabled, the batches will come from buckets of "
-            "similar duration (saves padding frames).",
+            help=(
+                "When enabled, the batches will come from buckets of "
+                "similar duration (saves padding frames)."
+            ),
         )
         group.add_argument(
             "--num-buckets",
             type=int,
             default=30,
-            help="The number of buckets for the DynamicBucketingSampler"
-            "(you might want to increase it for larger datasets).",
+            help=(
+                "The number of buckets for the DynamicBucketingSampler"
+                "(you might want to increase it for larger datasets)."
+            ),
         )
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
             default=False,
-            help="When enabled, utterances (cuts) will be concatenated "
-            "to minimize the amount of padding.",
+            help=(
+                "When enabled, utterances (cuts) will be concatenated "
+                "to minimize the amount of padding."
+            ),
         )
         group.add_argument(
             "--duration-factor",
             type=float,
             default=1.0,
-            help="Determines the maximum duration of a concatenated cut "
-            "relative to the duration of the longest cut in a batch.",
+            help=(
+                "Determines the maximum duration of a concatenated cut "
+                "relative to the duration of the longest cut in a batch."
+            ),
         )
         group.add_argument(
             "--gap",
             type=float,
             default=1.0,
-            help="The amount of padding (in seconds) inserted between "
-            "concatenated cuts. This padding is filled with noise when "
-            "noise augmentation is used.",
+            help=(
+                "The amount of padding (in seconds) inserted between "
+                "concatenated cuts. This padding is filled with noise when "
+                "noise augmentation is used."
+            ),
         )
         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help="When enabled, use on-the-fly cut mixing and feature "
-            "extraction. Will drop existing precomputed feature manifests "
-            "if available.",
+            help=(
+                "When enabled, use on-the-fly cut mixing and feature "
+                "extraction. Will drop existing precomputed feature manifests "
+                "if available."
+            ),
         )
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help="When enabled (=default), the examples will be "
-            "shuffled for each epoch.",
+            help=(
+                "When enabled (=default), the examples will be shuffled for each epoch."
+            ),
         )
         group.add_argument(
             "--return-cuts",
             type=str2bool,
             default=True,
-            help="When enabled, each batch will have the "
-            "field: batch['supervisions']['cut'] with the cuts that "
-            "were used to construct it.",
+            help=(
+                "When enabled, each batch will have the "
+                "field: batch['supervisions']['cut'] with the cuts that "
+                "were used to construct it."
+            ),
         )
 
         group.add_argument(
             "--num-workers",
             type=int,
             default=2,
-            help="The number of training dataloader workers that "
-            "collect the batches.",
+            help="The number of training dataloader workers that collect the batches.",
         )
 
     def train_dataloaders(self) -> DataLoader:
@@ -154,15 +172,13 @@ class TimitAsrDataModule(DataModule):
         cuts_train = self.train_cuts()
 
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(
-            self.args.feature_dir / "musan_cuts.jsonl.gz"
-        )
+        cuts_musan = load_manifest(self.args.feature_dir / "musan_cuts.jsonl.gz")
 
         logging.info("About to create train dataset")
         transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))]
         if self.args.concatenate_cuts:
             logging.info(
-                f"Using cut concatenation with duration factor "
+                "Using cut concatenation with duration factor "
                 f"{self.args.duration_factor} and gap {self.args.gap}."
             )
             # Cut concatenation should be the first transform in the list,
@@ -178,9 +194,9 @@ class TimitAsrDataModule(DataModule):
         # In different Lhotse's versions, the default of num_frame_masks is
         # different.
         num_frame_masks = 10
-        num_frame_masks_parameter = inspect.signature(
-            SpecAugment.__init__
-        ).parameters["num_frame_masks"]
+        num_frame_masks_parameter = inspect.signature(SpecAugment.__init__).parameters[
+            "num_frame_masks"
+        ]
         if num_frame_masks_parameter.default == 1:
             num_frame_masks = 2
         logging.info(f"Num frame mask: {num_frame_masks}")
@@ -212,9 +228,7 @@ class TimitAsrDataModule(DataModule):
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -263,9 +277,7 @@ class TimitAsrDataModule(DataModule):
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 return_cuts=self.args.return_cuts,
             )
         else:
@@ -299,20 +311,14 @@ class TimitAsrDataModule(DataModule):
         for cuts_test in cuts:
             logging.debug("About to create test dataset")
             test = K2SpeechRecognitionDataset(
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                )
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
                 if self.args.on_the_fly_feats
                 else PrecomputedFeatures(),
                 return_cuts=self.args.return_cuts,
             )
-            sampler = SingleCutSampler(
-                cuts_test, max_duration=self.args.max_duration
-            )
+            sampler = SingleCutSampler(cuts_test, max_duration=self.args.max_duration)
             logging.debug("About to create test dataloader")
-            test_dl = DataLoader(
-                test, batch_size=None, sampler=sampler, num_workers=1
-            )
+            test_dl = DataLoader(test, batch_size=None, sampler=sampler, num_workers=1)
             test_loaders.append(test_dl)
 
         if is_list:
diff --git a/egs/timit/ASR/tdnn_lstm_ctc/decode.py b/egs/timit/ASR/tdnn_lstm_ctc/decode.py
index 5e7300cf2..319ee5515 100644
--- a/egs/timit/ASR/tdnn_lstm_ctc/decode.py
+++ b/egs/timit/ASR/tdnn_lstm_ctc/decode.py
@@ -56,16 +56,19 @@ def get_parser():
         "--epoch",
         type=int,
         default=25,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=5,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
     parser.add_argument(
         "--method",
@@ -335,9 +338,7 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -399,9 +400,7 @@ def main():
 
     logging.info(f"device: {device}")
 
-    HLG = k2.Fsa.from_dict(
-        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
-    )
+    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
 
@@ -461,9 +460,7 @@ def main():
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save(
-            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
-        )
+        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
         return
 
     model.to(device)
@@ -483,9 +480,7 @@ def main():
         G=G,
     )
 
-    save_results(
-        params=params, test_set_name=test_set, results_dict=results_dict
-    )
+    save_results(params=params, test_set_name=test_set, results_dict=results_dict)
 
     logging.info("Done!")
 
diff --git a/egs/timit/ASR/tdnn_lstm_ctc/model.py b/egs/timit/ASR/tdnn_lstm_ctc/model.py
index 51edb97e2..e211ad80d 100644
--- a/egs/timit/ASR/tdnn_lstm_ctc/model.py
+++ b/egs/timit/ASR/tdnn_lstm_ctc/model.py
@@ -74,10 +74,7 @@ class TdnnLstm(nn.Module):
             nn.BatchNorm1d(num_features=512, affine=False),
         )
         self.lstms = nn.ModuleList(
-            [
-                nn.LSTM(input_size=512, hidden_size=512, num_layers=1)
-                for _ in range(4)
-            ]
+            [nn.LSTM(input_size=512, hidden_size=512, num_layers=1) for _ in range(4)]
         )
         self.lstm_bnorms = nn.ModuleList(
             [nn.BatchNorm1d(num_features=512, affine=False) for _ in range(5)]
diff --git a/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py b/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py
index 5f478da1c..0c72c973b 100644
--- a/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py
+++ b/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py
@@ -29,11 +29,7 @@ import torchaudio
 from model import TdnnLstm
 from torch.nn.utils.rnn import pad_sequence
 
-from icefall.decode import (
-    get_lattice,
-    one_best_decoding,
-    rescore_with_whole_lattice,
-)
+from icefall.decode import get_lattice, one_best_decoding, rescore_with_whole_lattice
 from icefall.utils import AttributeDict, get_texts
 
 
@@ -46,9 +42,11 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help="Path to the checkpoint. "
-        "The checkpoint is assumed to be saved by "
-        "icefall.checkpoint.save_checkpoint().",
+        help=(
+            "Path to the checkpoint. "
+            "The checkpoint is assumed to be saved by "
+            "icefall.checkpoint.save_checkpoint()."
+        ),
     )
 
     parser.add_argument(
@@ -58,9 +56,7 @@ def get_parser():
         help="Path to words.txt",
     )
 
-    parser.add_argument(
-        "--HLG", type=str, required=True, help="Path to HLG.pt."
-    )
+    parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.")
 
     parser.add_argument(
         "--method",
@@ -103,10 +99,12 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help="The input sound file(s) to transcribe. "
-        "Supported formats are those supported by torchaudio.load(). "
-        "For example, wav and flac are supported. "
-        "The sample rate has to be 16kHz.",
+        help=(
+            "The input sound file(s) to transcribe. "
+            "Supported formats are those supported by torchaudio.load(). "
+            "For example, wav and flac are supported. "
+            "The sample rate has to be 16kHz."
+        ),
     )
 
     return parser
@@ -144,10 +142,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -215,9 +212,7 @@ def main():
     logging.info("Decoding started")
     features = fbank(waves)
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
     features = features.permute(0, 2, 1)  # now features is (N, C, T)
 
     with torch.no_grad():
@@ -269,9 +264,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/timit/ASR/tdnn_lstm_ctc/train.py b/egs/timit/ASR/tdnn_lstm_ctc/train.py
index 849256b98..be1ecffaa 100644
--- a/egs/timit/ASR/tdnn_lstm_ctc/train.py
+++ b/egs/timit/ASR/tdnn_lstm_ctc/train.py
@@ -449,9 +449,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             valid_info = compute_validation_loss(
diff --git a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py
index 8a9f6ed30..bd73e520e 100755
--- a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py
+++ b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py
@@ -20,12 +20,7 @@ import logging
 from pathlib import Path
 
 import torch
-from lhotse import (
-    CutSet,
-    KaldifeatFbank,
-    KaldifeatFbankConfig,
-    LilcomHdf5Writer,
-)
+from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, LilcomHdf5Writer
 
 # Torch's multithreaded behavior needs to be disabled or
 # it wastes a lot of CPU and slow things down.
@@ -83,9 +78,7 @@ def compute_fbank_wenetspeech_dev_test():
 
 
 def main():
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     logging.basicConfig(format=formatter, level=logging.INFO)
 
     compute_fbank_wenetspeech_dev_test()
diff --git a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py
index a882b6113..c228597b8 100755
--- a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py
+++ b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py
@@ -62,8 +62,10 @@ def get_parser():
         "--batch-duration",
         type=float,
         default=600.0,
-        help="The maximum number of audio seconds in a batch."
-        "Determines batch size dynamically.",
+        help=(
+            "The maximum number of audio seconds in a batch."
+            "Determines batch size dynamically."
+        ),
     )
 
     parser.add_argument(
@@ -152,9 +154,7 @@ def main():
     date_time = now.strftime("%Y-%m-%d-%H-%M-%S")
 
     log_filename = "log-compute_fbank_wenetspeech_splits"
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     log_filename = f"{log_filename}-{date_time}"
 
     logging.basicConfig(
diff --git a/egs/wenetspeech/ASR/local/prepare_char.py b/egs/wenetspeech/ASR/local/prepare_char.py
index 8bc073c75..d8622842f 100755
--- a/egs/wenetspeech/ASR/local/prepare_char.py
+++ b/egs/wenetspeech/ASR/local/prepare_char.py
@@ -83,9 +83,7 @@ def lexicon_to_fst_no_sil(
         cur_state = loop_state
 
         word = word2id[word]
-        pieces = [
-            token2id[i] if i in token2id else token2id[""] for i in pieces
-        ]
+        pieces = [token2id[i] if i in token2id else token2id[""] for i in pieces]
 
         for i in range(len(pieces) - 1):
             w = word if i == 0 else eps
@@ -138,9 +136,7 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
     return False
 
 
-def generate_lexicon(
-    token_sym_table: Dict[str, int], words: List[str]
-) -> Lexicon:
+def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon:
     """Generate a lexicon from a word list and token_sym_table.
     Args:
       token_sym_table:
diff --git a/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py b/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py
index 817969c47..93ce750f8 100755
--- a/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py
+++ b/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py
@@ -115,11 +115,7 @@ def preprocess_wenet_speech():
                 f"Speed perturb for {partition} with factors 0.9 and 1.1 "
                 "(Perturbing may take 8 minutes and saving may take 20 minutes)"
             )
-            cut_set = (
-                cut_set
-                + cut_set.perturb_speed(0.9)
-                + cut_set.perturb_speed(1.1)
-            )
+            cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
         logging.info(f"Saving to {raw_cuts_path}")
         cut_set.to_file(raw_cuts_path)
 
diff --git a/egs/wenetspeech/ASR/local/text2token.py b/egs/wenetspeech/ASR/local/text2token.py
index 1c463cf1c..e121d842c 100755
--- a/egs/wenetspeech/ASR/local/text2token.py
+++ b/egs/wenetspeech/ASR/local/text2token.py
@@ -50,15 +50,15 @@ def get_parser():
         "-n",
         default=1,
         type=int,
-        help="number of characters to split, i.e., \
-                        aabb -> a a b b with -n 1 and aa bb with -n 2",
+        help=(
+            "number of characters to split, i.e.,                         aabb -> a a b"
+            " b with -n 1 and aa bb with -n 2"
+        ),
     )
     parser.add_argument(
         "--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
     )
-    parser.add_argument(
-        "--space", default="", type=str, help="space symbol"
-    )
+    parser.add_argument("--space", default="", type=str, help="space symbol")
     parser.add_argument(
         "--non-lang-syms",
         "-l",
@@ -66,9 +66,7 @@ def get_parser():
         type=str,
         help="list of non-linguistic symobles, e.g.,  etc.",
     )
-    parser.add_argument(
-        "text", type=str, default=False, nargs="?", help="input text"
-    )
+    parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
     parser.add_argument(
         "--trans_type",
         "-t",
@@ -108,8 +106,7 @@ def token2id(
             if token_type == "lazy_pinyin":
                 text = lazy_pinyin(chars_list)
                 sub_ids = [
-                    token_table[txt] if txt in token_table else oov_id
-                    for txt in text
+                    token_table[txt] if txt in token_table else oov_id for txt in text
                 ]
                 ids.append(sub_ids)
             else:  # token_type = "pinyin"
@@ -135,9 +132,7 @@ def main():
     if args.text:
         f = codecs.open(args.text, encoding="utf-8")
     else:
-        f = codecs.getreader("utf-8")(
-            sys.stdin if is_python2 else sys.stdin.buffer
-        )
+        f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
 
     sys.stdout = codecs.getwriter("utf-8")(
         sys.stdout if is_python2 else sys.stdout.buffer
diff --git a/egs/wenetspeech/ASR/prepare.sh b/egs/wenetspeech/ASR/prepare.sh
index 755fbb2d7..da7d7e061 100755
--- a/egs/wenetspeech/ASR/prepare.sh
+++ b/egs/wenetspeech/ASR/prepare.sh
@@ -190,7 +190,7 @@ if [ $stage -le 15 ] && [ $stop_stage -ge 15 ]; then
   mkdir -p $lang_char_dir
 
   if ! which jq; then
-      echo "This script is intended to be used with jq but you have not installed jq 
+      echo "This script is intended to be used with jq but you have not installed jq
       Note: in Linux, you can install jq with the following command:
       1. wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64
       2. chmod +x ./jq
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
index 10c953e3b..bd92ac115 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
@@ -81,10 +81,12 @@ class WenetSpeechAsrDataModule:
     def add_arguments(cls, parser: argparse.ArgumentParser):
         group = parser.add_argument_group(
             title="ASR data related options",
-            description="These options are used for the preparation of "
-            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-            "effective batch sizes, sampling strategies, applied data "
-            "augmentations, etc.",
+            description=(
+                "These options are used for the preparation of "
+                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+                "effective batch sizes, sampling strategies, applied data "
+                "augmentations, etc."
+            ),
         )
         group.add_argument(
             "--manifest-dir",
@@ -96,75 +98,91 @@ class WenetSpeechAsrDataModule:
             "--max-duration",
             type=int,
             default=200.0,
-            help="Maximum pooled recordings duration (seconds) in a "
-            "single batch. You can reduce it if it causes CUDA OOM.",
+            help=(
+                "Maximum pooled recordings duration (seconds) in a "
+                "single batch. You can reduce it if it causes CUDA OOM."
+            ),
         )
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
             default=True,
-            help="When enabled, the batches will come from buckets of "
-            "similar duration (saves padding frames).",
+            help=(
+                "When enabled, the batches will come from buckets of "
+                "similar duration (saves padding frames)."
+            ),
         )
         group.add_argument(
             "--num-buckets",
             type=int,
             default=300,
-            help="The number of buckets for the DynamicBucketingSampler"
-            "(you might want to increase it for larger datasets).",
+            help=(
+                "The number of buckets for the DynamicBucketingSampler"
+                "(you might want to increase it for larger datasets)."
+            ),
         )
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
             default=False,
-            help="When enabled, utterances (cuts) will be concatenated "
-            "to minimize the amount of padding.",
+            help=(
+                "When enabled, utterances (cuts) will be concatenated "
+                "to minimize the amount of padding."
+            ),
         )
         group.add_argument(
             "--duration-factor",
             type=float,
             default=1.0,
-            help="Determines the maximum duration of a concatenated cut "
-            "relative to the duration of the longest cut in a batch.",
+            help=(
+                "Determines the maximum duration of a concatenated cut "
+                "relative to the duration of the longest cut in a batch."
+            ),
         )
         group.add_argument(
             "--gap",
             type=float,
             default=1.0,
-            help="The amount of padding (in seconds) inserted between "
-            "concatenated cuts. This padding is filled with noise when "
-            "noise augmentation is used.",
+            help=(
+                "The amount of padding (in seconds) inserted between "
+                "concatenated cuts. This padding is filled with noise when "
+                "noise augmentation is used."
+            ),
         )
         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help="When enabled, use on-the-fly cut mixing and feature "
-            "extraction. Will drop existing precomputed feature manifests "
-            "if available.",
+            help=(
+                "When enabled, use on-the-fly cut mixing and feature "
+                "extraction. Will drop existing precomputed feature manifests "
+                "if available."
+            ),
         )
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help="When enabled (=default), the examples will be "
-            "shuffled for each epoch.",
+            help=(
+                "When enabled (=default), the examples will be shuffled for each epoch."
+            ),
         )
         group.add_argument(
             "--return-cuts",
             type=str2bool,
             default=True,
-            help="When enabled, each batch will have the "
-            "field: batch['supervisions']['cut'] with the cuts that "
-            "were used to construct it.",
+            help=(
+                "When enabled, each batch will have the "
+                "field: batch['supervisions']['cut'] with the cuts that "
+                "were used to construct it."
+            ),
         )
 
         group.add_argument(
             "--num-workers",
             type=int,
             default=2,
-            help="The number of training dataloader workers that "
-            "collect the batches.",
+            help="The number of training dataloader workers that collect the batches.",
         )
 
         group.add_argument(
@@ -178,18 +196,22 @@ class WenetSpeechAsrDataModule:
             "--spec-aug-time-warp-factor",
             type=int,
             default=80,
-            help="Used only when --enable-spec-aug is True. "
-            "It specifies the factor for time warping in SpecAugment. "
-            "Larger values mean more warping. "
-            "A value less than 1 means to disable time warp.",
+            help=(
+                "Used only when --enable-spec-aug is True. "
+                "It specifies the factor for time warping in SpecAugment. "
+                "Larger values mean more warping. "
+                "A value less than 1 means to disable time warp."
+            ),
         )
 
         group.add_argument(
             "--enable-musan",
             type=str2bool,
             default=True,
-            help="When enabled, select noise from MUSAN and mix it"
-            "with training dataset. ",
+            help=(
+                "When enabled, select noise from MUSAN and mix it"
+                "with training dataset. "
+            ),
         )
 
         group.add_argument(
@@ -212,24 +234,20 @@ class WenetSpeechAsrDataModule:
             The state dict for the training sampler.
         """
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(
-            self.args.manifest_dir / "musan_cuts.jsonl.gz"
-        )
+        cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
 
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(
-                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
-                )
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")
 
         if self.args.concatenate_cuts:
             logging.info(
-                f"Using cut concatenation with duration factor "
+                "Using cut concatenation with duration factor "
                 f"{self.args.duration_factor} and gap {self.args.gap}."
             )
             # Cut concatenation should be the first transform in the list,
@@ -244,9 +262,7 @@ class WenetSpeechAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(
-                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
-            )
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -289,9 +305,7 @@ class WenetSpeechAsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -348,9 +362,7 @@ class WenetSpeechAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 return_cuts=self.args.return_cuts,
             )
         else:
@@ -414,8 +426,7 @@ class WenetSpeechAsrDataModule:
     def train_cuts(self) -> CutSet:
         logging.info("About to get train cuts")
         cuts_train = load_manifest_lazy(
-            self.args.manifest_dir
-            / f"cuts_{self.args.training_subset}.jsonl.gz"
+            self.args.manifest_dir / f"cuts_{self.args.training_subset}.jsonl.gz"
         )
         return cuts_train
 
@@ -427,13 +438,9 @@ class WenetSpeechAsrDataModule:
     @lru_cache()
     def test_net_cuts(self) -> List[CutSet]:
         logging.info("About to get TEST_NET cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "cuts_TEST_NET.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_NET.jsonl.gz")
 
     @lru_cache()
     def test_meeting_cuts(self) -> List[CutSet]:
         logging.info("About to get TEST_MEETING cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "cuts_TEST_MEETING.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_MEETING.jsonl.gz")
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py
index f0c9bebec..6e856248c 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py
@@ -114,11 +114,7 @@ from beam_search import (
 from train import get_params, get_transducer_model
 
 from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
-from icefall.checkpoint import (
-    average_checkpoints,
-    find_checkpoints,
-    load_checkpoint,
-)
+from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
@@ -137,25 +133,30 @@ def get_parser():
         "--epoch",
         type=int,
         default=28,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
 
     parser.add_argument(
         "--batch",
         type=int,
         default=None,
-        help="It specifies the batch checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the batch checkpoint to use for decoding."
+            "Note: Epoch counts from 0."
+        ),
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=15,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -252,8 +253,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -328,9 +328,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
@@ -389,10 +387,7 @@ def decode_one_batch(
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -438,11 +433,7 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            (
-                f"beam_{params.beam}_"
-                f"max_contexts_{params.max_contexts}_"
-                f"max_states_{params.max_states}"
-            ): hyps
+            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -515,9 +506,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -550,8 +539,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -663,9 +651,7 @@ def main():
             )
             decoding_graph.scores *= params.ngram_lm_scale
         else:
-            decoding_graph = k2.trivial_graph(
-                params.vocab_size - 1, device=device
-            )
+            decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
     else:
         decoding_graph = None
 
@@ -716,8 +702,7 @@ def main():
         )
 
     dev_shards = [
-        str(path)
-        for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
+        str(path) for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
     ]
     cuts_dev_webdataset = CutSet.from_webdataset(
         dev_shards,
@@ -727,8 +712,7 @@ def main():
     )
 
     test_net_shards = [
-        str(path)
-        for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
+        str(path) for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
     ]
     cuts_test_net_webdataset = CutSet.from_webdataset(
         test_net_shards,
@@ -739,9 +723,7 @@ def main():
 
     test_meeting_shards = [
         str(path)
-        for path in sorted(
-            glob.glob(os.path.join(test_meeting, "shared-*.tar"))
-        )
+        for path in sorted(glob.glob(os.path.join(test_meeting, "shared-*.tar")))
     ]
     cuts_test_meeting_webdataset = CutSet.from_webdataset(
         test_meeting_shards,
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py
index 933642a0f..c742593df 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py
@@ -126,17 +126,20 @@ def get_parser():
         "--epoch",
         type=int,
         default=28,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=15,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -205,8 +208,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
@@ -468,13 +470,9 @@ def export_joiner_model_onnx(
 
         - projected_decoder_out: a tensor of shape (N, joiner_dim)
     """
-    encoder_proj_filename = str(joiner_filename).replace(
-        ".onnx", "_encoder_proj.onnx"
-    )
+    encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx")
 
-    decoder_proj_filename = str(joiner_filename).replace(
-        ".onnx", "_decoder_proj.onnx"
-    )
+    decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx")
 
     encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
     decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
@@ -645,9 +643,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py
index e5cc47bfe..ed9020c67 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py
@@ -107,10 +107,12 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help="The input sound file(s) to transcribe. "
-        "Supported formats are those supported by torchaudio.load(). "
-        "For example, wav and flac are supported. "
-        "The sample rate has to be 16kHz.",
+        help=(
+            "The input sound file(s) to transcribe. "
+            "Supported formats are those supported by torchaudio.load(). "
+            "For example, wav and flac are supported. "
+            "The sample rate has to be 16kHz."
+        ),
     )
 
     parser.add_argument(
@@ -145,10 +147,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -331,9 +332,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py
index c396c50ef..a46ff5a07 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py
@@ -219,9 +219,7 @@ def test_joiner(
         )
 
         # Now test encoder_proj
-        joiner_encoder_proj_inputs = {
-            encoder_proj_input_name: encoder_out.numpy()
-        }
+        joiner_encoder_proj_inputs = {encoder_proj_input_name: encoder_out.numpy()}
         joiner_encoder_proj_out = joiner_encoder_proj_session.run(
             [encoder_proj_output_name], joiner_encoder_proj_inputs
         )[0]
@@ -230,16 +228,10 @@ def test_joiner(
         torch_joiner_encoder_proj_out = model.joiner.encoder_proj(encoder_out)
         assert torch.allclose(
             joiner_encoder_proj_out, torch_joiner_encoder_proj_out, atol=1e-5
-        ), (
-            (joiner_encoder_proj_out - torch_joiner_encoder_proj_out)
-            .abs()
-            .max()
-        )
+        ), ((joiner_encoder_proj_out - torch_joiner_encoder_proj_out).abs().max())
 
         # Now test decoder_proj
-        joiner_decoder_proj_inputs = {
-            decoder_proj_input_name: decoder_out.numpy()
-        }
+        joiner_decoder_proj_inputs = {decoder_proj_input_name: decoder_out.numpy()}
         joiner_decoder_proj_out = joiner_decoder_proj_session.run(
             [decoder_proj_output_name], joiner_decoder_proj_inputs
         )[0]
@@ -248,11 +240,7 @@ def test_joiner(
         torch_joiner_decoder_proj_out = model.joiner.decoder_proj(decoder_out)
         assert torch.allclose(
             joiner_decoder_proj_out, torch_joiner_decoder_proj_out, atol=1e-5
-        ), (
-            (joiner_decoder_proj_out - torch_joiner_decoder_proj_out)
-            .abs()
-            .max()
-        )
+        ), ((joiner_decoder_proj_out - torch_joiner_decoder_proj_out).abs().max())
 
 
 @torch.no_grad()
@@ -304,9 +292,7 @@ def main():
 
 if __name__ == "__main__":
     torch.manual_seed(20220727)
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py
index 3770fbbb4..f7d962008 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py
@@ -111,10 +111,12 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help="The input sound file(s) to transcribe. "
-        "Supported formats are those supported by torchaudio.load(). "
-        "For example, wav and flac are supported. "
-        "The sample rate has to be 16kHz.",
+        help=(
+            "The input sound file(s) to transcribe. "
+            "Supported formats are those supported by torchaudio.load(). "
+            "For example, wav and flac are supported. "
+            "The sample rate has to be 16kHz."
+        ),
     )
 
     parser.add_argument(
@@ -149,10 +151,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -200,11 +201,7 @@ def greedy_search(
 
     projected_encoder_out = joiner_encoder_proj.run(
         [joiner_encoder_proj.get_outputs()[0].name],
-        {
-            joiner_encoder_proj.get_inputs()[
-                0
-            ].name: packed_encoder_out.data.numpy()
-        },
+        {joiner_encoder_proj.get_inputs()[0].name: packed_encoder_out.data.numpy()},
     )[0]
 
     blank_id = 0  # hard-code to 0
@@ -389,9 +386,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py
index 9a549efd9..26c9c2b8c 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py
@@ -80,9 +80,11 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help="Path to the checkpoint. "
-        "The checkpoint is assumed to be saved by "
-        "icefall.checkpoint.save_checkpoint().",
+        help=(
+            "Path to the checkpoint. "
+            "The checkpoint is assumed to be saved by "
+            "icefall.checkpoint.save_checkpoint()."
+        ),
     )
 
     parser.add_argument(
@@ -107,10 +109,12 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help="The input sound file(s) to transcribe. "
-        "Supported formats are those supported by torchaudio.load(). "
-        "For example, wav and flac are supported. "
-        "The sample rate has to be 16kHz.",
+        help=(
+            "The input sound file(s) to transcribe. "
+            "Supported formats are those supported by torchaudio.load(). "
+            "For example, wav and flac are supported. "
+            "The sample rate has to be 16kHz."
+        ),
     )
 
     parser.add_argument(
@@ -158,8 +162,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -189,10 +192,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -253,9 +255,7 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -280,10 +280,7 @@ def main():
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -335,9 +332,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
index d3cc7c9c9..e020c4c05 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
@@ -115,9 +115,7 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[
-    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
-]
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
 
 def get_parser():
@@ -219,42 +217,45 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
         "--prune-range",
         type=int,
         default=5,
-        help="The prune range for rnnt loss, it means how many symbols(context)"
-        "we are using to compute the loss",
+        help=(
+            "The prune range for rnnt loss, it means how many symbols(context)"
+            "we are using to compute the loss"
+        ),
     )
 
     parser.add_argument(
         "--lm-scale",
         type=float,
         default=0.25,
-        help="The scale to smooth the loss with lm "
-        "(output of prediction network) part.",
+        help=(
+            "The scale to smooth the loss with lm (output of prediction network) part."
+        ),
     )
 
     parser.add_argument(
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
+        help="The scale to smooth the loss with am (output of encoder network)part.",
     )
 
     parser.add_argument(
         "--simple-loss-scale",
         type=float,
         default=0.5,
-        help="To get pruning ranges, we will calculate a simple version"
-        "loss(joiner is just addition), this simple loss also uses for"
-        "training (as a regularization item). We will scale the simple loss"
-        "with this parameter before adding to the final loss.",
+        help=(
+            "To get pruning ranges, we will calculate a simple version"
+            "loss(joiner is just addition), this simple loss also uses for"
+            "training (as a regularization item). We will scale the simple loss"
+            "with this parameter before adding to the final loss."
+        ),
     )
 
     parser.add_argument(
@@ -590,22 +591,15 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0
-            if warmup < 1.0
-            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
-        )
-        loss = (
-            params.simple_loss_scale * simple_loss
-            + pruned_loss_scale * pruned_loss
+            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
         )
+        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -762,9 +756,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -864,7 +856,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
+            2**22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py
index dd27c17f0..1023c931a 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py
@@ -210,10 +210,7 @@ class Conformer(EncoderInterface):
           (num_encoder_layers, cnn_module_kernel - 1, encoder_dim).
           NOTE: the returned tensors are on the given device.
         """
-        if (
-            len(self._init_state) == 2
-            and self._init_state[0].size(1) == left_context
-        ):
+        if len(self._init_state) == 2 and self._init_state[0].size(1) == left_context:
             # Note: It is OK to share the init state as it is
             # not going to be modified by the model
             return self._init_state
@@ -433,9 +430,7 @@ class ConformerEncoderLayer(nn.Module):
 
         self.d_model = d_model
 
-        self.self_attn = RelPositionMultiheadAttention(
-            d_model, nhead, dropout=0.0
-        )
+        self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
 
         self.feed_forward = nn.Sequential(
             ScaledLinear(d_model, dim_feedforward),
@@ -453,9 +448,7 @@ class ConformerEncoderLayer(nn.Module):
             ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
         )
 
-        self.conv_module = ConvolutionModule(
-            d_model, cnn_module_kernel, causal=causal
-        )
+        self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal)
 
         self.norm_final = BasicNorm(d_model)
 
@@ -520,9 +513,7 @@ class ConformerEncoderLayer(nn.Module):
         src = src + self.dropout(src_att)
 
         # convolution module
-        conv, _ = self.conv_module(
-            src, src_key_padding_mask=src_key_padding_mask
-        )
+        conv, _ = self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
         src = src + self.dropout(conv)
 
         # feed forward module
@@ -766,9 +757,7 @@ class RelPositionalEncoding(torch.nn.Module):
         max_len: Maximum input length.
     """
 
-    def __init__(
-        self, d_model: int, dropout_rate: float, max_len: int = 5000
-    ) -> None:
+    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
@@ -784,9 +773,7 @@ class RelPositionalEncoding(torch.nn.Module):
             # the length of self.pe is 2 * input_len - 1
             if self.pe.size(1) >= x_size_1 * 2 - 1:
                 # Note: TorchScript doesn't implement operator== for torch.Device
-                if self.pe.dtype != x.dtype or str(self.pe.device) != str(
-                    x.device
-                ):
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
                     self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                 return
         # Suppose `i` means to the position of query vector and `j` means the
@@ -1073,9 +1060,9 @@ class RelPositionMultiheadAttention(nn.Module):
 
         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(
-                query, in_proj_weight, in_proj_bias
-            ).chunk(3, dim=-1)
+            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
+                3, dim=-1
+            )
 
         elif torch.equal(key, value):
             # encoder-decoder attention
@@ -1144,33 +1131,25 @@ class RelPositionMultiheadAttention(nn.Module):
             if attn_mask.dim() == 2:
                 attn_mask = attn_mask.unsqueeze(0)
                 if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
-                    raise RuntimeError(
-                        "The size of the 2D attn_mask is not correct."
-                    )
+                    raise RuntimeError("The size of the 2D attn_mask is not correct.")
             elif attn_mask.dim() == 3:
                 if list(attn_mask.size()) != [
                     bsz * num_heads,
                     query.size(0),
                     key.size(0),
                 ]:
-                    raise RuntimeError(
-                        "The size of the 3D attn_mask is not correct."
-                    )
+                    raise RuntimeError("The size of the 3D attn_mask is not correct.")
             else:
                 raise RuntimeError(
-                    "attn_mask's dimension {} is not supported".format(
-                        attn_mask.dim()
-                    )
+                    "attn_mask's dimension {} is not supported".format(attn_mask.dim())
                 )
             # attn_mask's dim is 3 now.
 
         # convert ByteTensor key_padding_mask to bool
-        if (
-            key_padding_mask is not None
-            and key_padding_mask.dtype == torch.uint8
-        ):
+        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
             warnings.warn(
-                "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
+                "Byte tensor for key_padding_mask is deprecated. Use bool tensor"
+                " instead."
             )
             key_padding_mask = key_padding_mask.to(torch.bool)
 
@@ -1208,23 +1187,15 @@ class RelPositionMultiheadAttention(nn.Module):
         # first compute matrix a and matrix c
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
-        matrix_ac = torch.matmul(
-            q_with_bias_u, k
-        )  # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch, head, time1, time2)
 
         # compute matrix b and matrix d
-        matrix_bd = torch.matmul(
-            q_with_bias_v, p
-        )  # (batch, head, time1, 2*time1-1)
+        matrix_bd = torch.matmul(q_with_bias_v, p)  # (batch, head, time1, 2*time1-1)
         matrix_bd = self.rel_shift(matrix_bd, left_context)
 
-        attn_output_weights = (
-            matrix_ac + matrix_bd
-        )  # (batch, head, time1, time2)
+        attn_output_weights = matrix_ac + matrix_bd  # (batch, head, time1, time2)
 
-        attn_output_weights = attn_output_weights.view(
-            bsz * num_heads, tgt_len, -1
-        )
+        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
 
         assert list(attn_output_weights.size()) == [
             bsz * num_heads,
@@ -1265,21 +1236,17 @@ class RelPositionMultiheadAttention(nn.Module):
         ):
             if attn_mask.size(0) != 1:
                 attn_mask = attn_mask.view(bsz, num_heads, tgt_len, src_len)
-                combined_mask = attn_mask | key_padding_mask.unsqueeze(
-                    1
-                ).unsqueeze(2)
+                combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2)
             else:
                 # attn_mask.shape == (1, tgt_len, src_len)
-                combined_mask = attn_mask.unsqueeze(
-                    0
-                ) | key_padding_mask.unsqueeze(1).unsqueeze(2)
+                combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze(
+                    1
+                ).unsqueeze(2)
 
             attn_output_weights = attn_output_weights.view(
                 bsz, num_heads, tgt_len, src_len
             )
-            attn_output_weights = attn_output_weights.masked_fill(
-                combined_mask, 0.0
-            )
+            attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0)
             attn_output_weights = attn_output_weights.view(
                 bsz * num_heads, tgt_len, src_len
             )
@@ -1291,13 +1258,9 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output = torch.bmm(attn_output_weights, v)
         assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
         attn_output = (
-            attn_output.transpose(0, 1)
-            .contiguous()
-            .view(tgt_len, bsz, embed_dim)
-        )
-        attn_output = nn.functional.linear(
-            attn_output, out_proj_weight, out_proj_bias
+            attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
         )
+        attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
 
         if need_weights:
             # average attention weights over heads
@@ -1430,16 +1393,12 @@ class ConvolutionModule(nn.Module):
                 # manualy padding self.lorder zeros to the left
                 x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
             else:
-                assert (
-                    not self.training
-                ), "Cache should be None in training time"
+                assert not self.training, "Cache should be None in training time"
                 assert cache.size(0) == self.lorder
                 x = torch.cat([cache.permute(1, 2, 0), x], dim=2)
                 if right_context > 0:
                     cache = x.permute(2, 0, 1)[
-                        -(self.lorder + right_context) : (  # noqa
-                            -right_context
-                        ),
+                        -(self.lorder + right_context) : (-right_context),  # noqa
                         ...,
                     ]
                 else:
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
index 344e31283..3d66f9dc9 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
@@ -160,20 +160,24 @@ def get_parser():
         "--avg",
         type=int,
         default=15,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch' and '--iter'",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch' and '--iter'"
+        ),
     )
 
     parser.add_argument(
         "--use-averaged-model",
         type=str2bool,
         default=True,
-        help="Whether to load averaged model. Currently it only supports "
-        "using --epoch. If True, it would decode with the averaged model "
-        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
-        "Actually only the models with epoch number of `epoch-avg` and "
-        "`epoch` are loaded for averaging. ",
+        help=(
+            "Whether to load averaged model. Currently it only supports "
+            "using --epoch. If True, it would decode with the averaged model "
+            "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+            "Actually only the models with epoch number of `epoch-avg` and "
+            "`epoch` are loaded for averaging. "
+        ),
     )
 
     parser.add_argument(
@@ -244,8 +248,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -342,9 +345,7 @@ def decode_one_batch(
             simulate_streaming=True,
         )
     else:
-        encoder_out, encoder_out_lens = model.encoder(
-            x=feature, x_lens=feature_lens
-        )
+        encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
 
     hyps = []
 
@@ -360,10 +361,7 @@ def decode_one_batch(
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -409,11 +407,7 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            (
-                f"beam_{params.beam}_"
-                f"max_contexts_{params.max_contexts}_"
-                f"max_states_{params.max_states}"
-            ): hyps
+            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -484,9 +478,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -519,8 +511,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -589,13 +580,12 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for"
-                    f" --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg:
                 raise ValueError(
@@ -618,13 +608,12 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg + 1]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for"
-                    f" --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg + 1:
                 raise ValueError(
@@ -652,7 +641,7 @@ def main():
             filename_start = f"{params.exp_dir}/epoch-{start}.pt"
             filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
             logging.info(
-                f"Calculating the averaged model over epoch range from "
+                "Calculating the averaged model over epoch range from "
                 f"{start} (excluded) to {params.epoch}"
             )
             model.to(device)
@@ -720,8 +709,7 @@ def main():
         )
 
     dev_shards = [
-        str(path)
-        for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
+        str(path) for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
     ]
     cuts_dev_webdataset = CutSet.from_webdataset(
         dev_shards,
@@ -731,8 +719,7 @@ def main():
     )
 
     test_net_shards = [
-        str(path)
-        for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
+        str(path) for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
     ]
     cuts_test_net_webdataset = CutSet.from_webdataset(
         test_net_shards,
@@ -743,9 +730,7 @@ def main():
 
     test_meeting_shards = [
         str(path)
-        for path in sorted(
-            glob.glob(os.path.join(test_meeting, "shared-*.tar"))
-        )
+        for path in sorted(glob.glob(os.path.join(test_meeting, "shared-*.tar")))
     ]
     cuts_test_meeting_webdataset = CutSet.from_webdataset(
         test_meeting_shards,
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode_stream.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode_stream.py
index 386248554..e522943c0 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode_stream.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode_stream.py
@@ -75,9 +75,7 @@ class DecodeStream(object):
         # encoder.streaming_forward
         self.done_frames: int = 0
 
-        self.pad_length = (
-            params.right_context + 2
-        ) * params.subsampling_factor + 3
+        self.pad_length = (params.right_context + 2) * params.subsampling_factor + 3
 
         if params.decoding_method == "greedy_search":
             self.hyp = [params.blank_id] * params.context_size
@@ -91,13 +89,11 @@ class DecodeStream(object):
             )
         elif params.decoding_method == "fast_beam_search":
             # The rnnt_decoding_stream for fast_beam_search.
-            self.rnnt_decoding_stream: k2.RnntDecodingStream = (
-                k2.RnntDecodingStream(decoding_graph)
+            self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream(
+                decoding_graph
             )
         else:
-            raise ValueError(
-                f"Unsupported decoding method: {params.decoding_method}"
-            )
+            raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
 
     @property
     def done(self) -> bool:
@@ -126,13 +122,10 @@ class DecodeStream(object):
         """Consume chunk_size frames of features"""
         chunk_length = chunk_size + self.pad_length
 
-        ret_length = min(
-            self.num_frames - self.num_processed_frames, chunk_length
-        )
+        ret_length = min(self.num_frames - self.num_processed_frames, chunk_length)
 
         ret_features = self.features[
-            self.num_processed_frames : self.num_processed_frames  # noqa
-            + ret_length
+            self.num_processed_frames : self.num_processed_frames + ret_length  # noqa
         ]
 
         self.num_processed_frames += chunk_size
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py
index d0a7fd69f..fb53f70ab 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py
@@ -90,17 +90,20 @@ def get_parser():
         "--epoch",
         type=int,
         default=28,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=15,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -131,8 +134,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     add_model_arguments(parser)
 
@@ -201,9 +203,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py
index 1b064c874..9834189d8 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py
@@ -80,9 +80,11 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help="Path to the checkpoint. "
-        "The checkpoint is assumed to be saved by "
-        "icefall.checkpoint.save_checkpoint().",
+        help=(
+            "Path to the checkpoint. "
+            "The checkpoint is assumed to be saved by "
+            "icefall.checkpoint.save_checkpoint()."
+        ),
     )
 
     parser.add_argument(
@@ -107,10 +109,12 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help="The input sound file(s) to transcribe. "
-        "Supported formats are those supported by torchaudio.load(). "
-        "For example, wav and flac are supported. "
-        "The sample rate has to be 16kHz.",
+        help=(
+            "The input sound file(s) to transcribe. "
+            "Supported formats are those supported by torchaudio.load(). "
+            "For example, wav and flac are supported. "
+            "The sample rate has to be 16kHz."
+        ),
     )
 
     parser.add_argument(
@@ -157,8 +161,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -189,10 +192,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -253,9 +255,7 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -280,10 +280,7 @@ def main():
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -335,9 +332,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_beam_search.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_beam_search.py
index 651aff6c9..810d94135 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_beam_search.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_beam_search.py
@@ -173,14 +173,10 @@ def modified_beam_search(
         log_probs_shape = k2.ragged.create_ragged_shape2(
             row_splits=row_splits, cached_tot_size=log_probs.numel()
         )
-        ragged_log_probs = k2.RaggedTensor(
-            shape=log_probs_shape, value=log_probs
-        )
+        ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)
 
         for i in range(batch_size):
-            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(
-                num_active_paths
-            )
+            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(num_active_paths)
 
             with warnings.catch_warnings():
                 warnings.simplefilter("ignore")
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py
index ff96c6487..31a7fe605 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py
@@ -119,20 +119,24 @@ def get_parser():
         "--avg",
         type=int,
         default=15,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch' and '--iter'",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch' and '--iter'"
+        ),
     )
 
     parser.add_argument(
         "--use-averaged-model",
         type=str2bool,
         default=True,
-        help="Whether to load averaged model. Currently it only supports "
-        "using --epoch. If True, it would decode with the averaged model "
-        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
-        "Actually only the models with epoch number of `epoch-avg` and "
-        "`epoch` are loaded for averaging. ",
+        help=(
+            "Whether to load averaged model. Currently it only supports "
+            "using --epoch. If True, it would decode with the averaged model "
+            "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+            "Actually only the models with epoch number of `epoch-avg` and "
+            "`epoch` are loaded for averaging. "
+        ),
     )
 
     parser.add_argument(
@@ -201,8 +205,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -311,9 +314,7 @@ def decode_one_chunk(
     encoder_out = model.joiner.encoder_proj(encoder_out)
 
     if params.decoding_method == "greedy_search":
-        greedy_search(
-            model=model, encoder_out=encoder_out, streams=decode_streams
-        )
+        greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams)
     elif params.decoding_method == "fast_beam_search":
         processed_lens = processed_lens + encoder_out_lens
         fast_beam_search_one_best(
@@ -333,9 +334,7 @@ def decode_one_chunk(
             num_active_paths=params.num_active_paths,
         )
     else:
-        raise ValueError(
-            f"Unsupported decoding method: {params.decoding_method}"
-        )
+        raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
 
     states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)]
 
@@ -389,9 +388,7 @@ def decode_dataset(
     decode_results = []
     # Contain decode streams currently running.
     decode_streams = []
-    initial_states = model.encoder.get_init_state(
-        params.left_context, device=device
-    )
+    initial_states = model.encoder.get_init_state(params.left_context, device=device)
     for num, cut in enumerate(cuts):
         # each utterance has a DecodeStream.
         decode_stream = DecodeStream(
@@ -461,9 +458,7 @@ def decode_dataset(
     elif params.decoding_method == "modified_beam_search":
         key = f"num_active_paths_{params.num_active_paths}"
     else:
-        raise ValueError(
-            f"Unsupported decoding method: {params.decoding_method}"
-        )
+        raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
 
     return {key: decode_results}
 
@@ -499,8 +494,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -565,13 +559,12 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for"
-                    f" --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg:
                 raise ValueError(
@@ -594,13 +587,12 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg + 1]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for"
-                    f" --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg + 1:
                 raise ValueError(
@@ -628,7 +620,7 @@ def main():
             filename_start = f"{params.exp_dir}/epoch-{start}.pt"
             filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
             logging.info(
-                f"Calculating the averaged model over epoch range from "
+                "Calculating the averaged model over epoch range from "
                 f"{start} (excluded) to {params.epoch}"
             )
             model.to(device)
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py
index 2052e9da7..40c9665f7 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py
@@ -98,9 +98,7 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[
-    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
-]
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
 
 def add_model_arguments(parser: argparse.ArgumentParser):
@@ -260,8 +258,7 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need "
-        "to be changed.",
+        help="The initial learning rate.  This value should not need to be changed.",
     )
 
     parser.add_argument(
@@ -284,42 +281,45 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
         "--prune-range",
         type=int,
         default=5,
-        help="The prune range for rnnt loss, it means how many symbols(context)"
-        "we are using to compute the loss",
+        help=(
+            "The prune range for rnnt loss, it means how many symbols(context)"
+            "we are using to compute the loss"
+        ),
     )
 
     parser.add_argument(
         "--lm-scale",
         type=float,
         default=0.25,
-        help="The scale to smooth the loss with lm "
-        "(output of prediction network) part.",
+        help=(
+            "The scale to smooth the loss with lm (output of prediction network) part."
+        ),
     )
 
     parser.add_argument(
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
+        help="The scale to smooth the loss with am (output of encoder network)part.",
     )
 
     parser.add_argument(
         "--simple-loss-scale",
         type=float,
         default=0.5,
-        help="To get pruning ranges, we will calculate a simple version"
-        "loss(joiner is just addition), this simple loss also uses for"
-        "training (as a regularization item). We will scale the simple loss"
-        "with this parameter before adding to the final loss.",
+        help=(
+            "To get pruning ranges, we will calculate a simple version"
+            "loss(joiner is just addition), this simple loss also uses for"
+            "training (as a regularization item). We will scale the simple loss"
+            "with this parameter before adding to the final loss."
+        ),
     )
 
     parser.add_argument(
@@ -665,11 +665,7 @@ def compute_loss(
      warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
-    device = (
-        model.device
-        if isinstance(model, DDP)
-        else next(model.parameters()).device
-    )
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
@@ -701,23 +697,16 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0
-            if warmup < 1.0
-            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
-        )
-        loss = (
-            params.simple_loss_scale * simple_loss
-            + pruned_loss_scale * pruned_loss
+            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
         )
+        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
 
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -841,9 +830,7 @@ def train_one_epoch(
             scaler.update()
             optimizer.zero_grad()
         except:  # noqa
-            display_and_save_batch(
-                batch, params=params, graph_compiler=graph_compiler
-            )
+            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
             raise
 
         if params.print_diagnostics and batch_idx == 5:
@@ -901,9 +888,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -1016,7 +1001,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
+            2**22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
@@ -1184,9 +1169,7 @@ def scan_pessimistic_batches_for_oom(
                     f"Failing criterion: {criterion} "
                     f"(={crit_values[criterion]}) ..."
                 )
-            display_and_save_batch(
-                batch, params=params, graph_compiler=graph_compiler
-            )
+            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
             raise
 
 
diff --git a/egs/yesno/ASR/local/compile_hlg.py b/egs/yesno/ASR/local/compile_hlg.py
index f83be05cf..7234ca929 100755
--- a/egs/yesno/ASR/local/compile_hlg.py
+++ b/egs/yesno/ASR/local/compile_hlg.py
@@ -128,9 +128,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/yesno/ASR/local/compute_fbank_yesno.py b/egs/yesno/ASR/local/compute_fbank_yesno.py
index 9a4e8a36f..75d95df68 100755
--- a/egs/yesno/ASR/local/compute_fbank_yesno.py
+++ b/egs/yesno/ASR/local/compute_fbank_yesno.py
@@ -54,9 +54,7 @@ def compute_fbank_yesno():
         dataset_parts,
     )
 
-    extractor = Fbank(
-        FbankConfig(sampling_rate=8000, num_mel_bins=num_mel_bins)
-    )
+    extractor = Fbank(FbankConfig(sampling_rate=8000, num_mel_bins=num_mel_bins))
 
     with get_executor() as ex:  # Initialize the executor only once.
         for partition, m in manifests.items():
@@ -71,9 +69,7 @@ def compute_fbank_yesno():
             )
             if "train" in partition:
                 cut_set = (
-                    cut_set
-                    + cut_set.perturb_speed(0.9)
-                    + cut_set.perturb_speed(1.1)
+                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                 )
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
@@ -87,9 +83,7 @@ def compute_fbank_yesno():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/yesno/ASR/tdnn/asr_datamodule.py b/egs/yesno/ASR/tdnn/asr_datamodule.py
index 85e5f1358..21860d2f5 100644
--- a/egs/yesno/ASR/tdnn/asr_datamodule.py
+++ b/egs/yesno/ASR/tdnn/asr_datamodule.py
@@ -56,10 +56,12 @@ class YesNoAsrDataModule(DataModule):
         super().add_arguments(parser)
         group = parser.add_argument_group(
             title="ASR data related options",
-            description="These options are used for the preparation of "
-            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-            "effective batch sizes, sampling strategies, applied data "
-            "augmentations, etc.",
+            description=(
+                "These options are used for the preparation of "
+                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+                "effective batch sizes, sampling strategies, applied data "
+                "augmentations, etc."
+            ),
         )
         group.add_argument(
             "--feature-dir",
@@ -71,75 +73,91 @@ class YesNoAsrDataModule(DataModule):
             "--max-duration",
             type=int,
             default=30.0,
-            help="Maximum pooled recordings duration (seconds) in a "
-            "single batch. You can reduce it if it causes CUDA OOM.",
+            help=(
+                "Maximum pooled recordings duration (seconds) in a "
+                "single batch. You can reduce it if it causes CUDA OOM."
+            ),
         )
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
             default=False,
-            help="When enabled, the batches will come from buckets of "
-            "similar duration (saves padding frames).",
+            help=(
+                "When enabled, the batches will come from buckets of "
+                "similar duration (saves padding frames)."
+            ),
         )
         group.add_argument(
             "--num-buckets",
             type=int,
             default=10,
-            help="The number of buckets for the DynamicBucketingSampler"
-            "(you might want to increase it for larger datasets).",
+            help=(
+                "The number of buckets for the DynamicBucketingSampler"
+                "(you might want to increase it for larger datasets)."
+            ),
         )
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
             default=False,
-            help="When enabled, utterances (cuts) will be concatenated "
-            "to minimize the amount of padding.",
+            help=(
+                "When enabled, utterances (cuts) will be concatenated "
+                "to minimize the amount of padding."
+            ),
         )
         group.add_argument(
             "--duration-factor",
             type=float,
             default=1.0,
-            help="Determines the maximum duration of a concatenated cut "
-            "relative to the duration of the longest cut in a batch.",
+            help=(
+                "Determines the maximum duration of a concatenated cut "
+                "relative to the duration of the longest cut in a batch."
+            ),
         )
         group.add_argument(
             "--gap",
             type=float,
             default=1.0,
-            help="The amount of padding (in seconds) inserted between "
-            "concatenated cuts. This padding is filled with noise when "
-            "noise augmentation is used.",
+            help=(
+                "The amount of padding (in seconds) inserted between "
+                "concatenated cuts. This padding is filled with noise when "
+                "noise augmentation is used."
+            ),
         )
         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help="When enabled, use on-the-fly cut mixing and feature "
-            "extraction. Will drop existing precomputed feature manifests "
-            "if available.",
+            help=(
+                "When enabled, use on-the-fly cut mixing and feature "
+                "extraction. Will drop existing precomputed feature manifests "
+                "if available."
+            ),
         )
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help="When enabled (=default), the examples will be "
-            "shuffled for each epoch.",
+            help=(
+                "When enabled (=default), the examples will be shuffled for each epoch."
+            ),
         )
         group.add_argument(
             "--return-cuts",
             type=str2bool,
             default=True,
-            help="When enabled, each batch will have the "
-            "field: batch['supervisions']['cut'] with the cuts that "
-            "were used to construct it.",
+            help=(
+                "When enabled, each batch will have the "
+                "field: batch['supervisions']['cut'] with the cuts that "
+                "were used to construct it."
+            ),
         )
 
         group.add_argument(
             "--num-workers",
             type=int,
             default=2,
-            help="The number of training dataloader workers that "
-            "collect the batches.",
+            help="The number of training dataloader workers that collect the batches.",
         )
 
     def train_dataloaders(self) -> DataLoader:
@@ -150,7 +168,7 @@ class YesNoAsrDataModule(DataModule):
         transforms = []
         if self.args.concatenate_cuts:
             logging.info(
-                f"Using cut concatenation with duration factor "
+                "Using cut concatenation with duration factor "
                 f"{self.args.duration_factor} and gap {self.args.gap}."
             )
             # Cut concatenation should be the first transform in the list,
diff --git a/egs/yesno/ASR/tdnn/decode.py b/egs/yesno/ASR/tdnn/decode.py
index 9d4ab4b61..41afe0404 100755
--- a/egs/yesno/ASR/tdnn/decode.py
+++ b/egs/yesno/ASR/tdnn/decode.py
@@ -35,16 +35,19 @@ def get_parser():
         "--epoch",
         type=int,
         default=14,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=2,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -201,9 +204,7 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -274,9 +275,7 @@ def main():
 
     logging.info(f"device: {device}")
 
-    HLG = k2.Fsa.from_dict(
-        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
-    )
+    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
 
@@ -297,9 +296,7 @@ def main():
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save(
-            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
-        )
+        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
         return
 
     model.to(device)
@@ -317,9 +314,7 @@ def main():
         word_table=lexicon.word_table,
     )
 
-    save_results(
-        exp_dir=params.exp_dir, test_set_name="test_set", results=results
-    )
+    save_results(exp_dir=params.exp_dir, test_set_name="test_set", results=results)
 
     logging.info("Done!")
 
diff --git a/egs/yesno/ASR/tdnn/pretrained.py b/egs/yesno/ASR/tdnn/pretrained.py
index 14220be19..09a8672ae 100755
--- a/egs/yesno/ASR/tdnn/pretrained.py
+++ b/egs/yesno/ASR/tdnn/pretrained.py
@@ -41,9 +41,11 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help="Path to the checkpoint. "
-        "The checkpoint is assumed to be saved by "
-        "icefall.checkpoint.save_checkpoint().",
+        help=(
+            "Path to the checkpoint. "
+            "The checkpoint is assumed to be saved by "
+            "icefall.checkpoint.save_checkpoint()."
+        ),
     )
 
     parser.add_argument(
@@ -53,18 +55,18 @@ def get_parser():
         help="Path to words.txt",
     )
 
-    parser.add_argument(
-        "--HLG", type=str, required=True, help="Path to HLG.pt."
-    )
+    parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.")
 
     parser.add_argument(
         "sound_files",
         type=str,
         nargs="+",
-        help="The input sound file(s) to transcribe. "
-        "Supported formats are those supported by torchaudio.load(). "
-        "For example, wav and flac are supported. "
-        "The sample rate has to be 16kHz.",
+        help=(
+            "The input sound file(s) to transcribe. "
+            "Supported formats are those supported by torchaudio.load(). "
+            "For example, wav and flac are supported. "
+            "The sample rate has to be 16kHz."
+        ),
     )
 
     return parser
@@ -101,10 +103,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -159,9 +160,7 @@ def main():
     logging.info("Decoding started")
     features = fbank(waves)
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     # Note: We don't use key padding mask for attention during decoding
     with torch.no_grad():
@@ -201,9 +200,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/yesno/ASR/tdnn/train.py b/egs/yesno/ASR/tdnn/train.py
index f32a27f35..335493491 100755
--- a/egs/yesno/ASR/tdnn/train.py
+++ b/egs/yesno/ASR/tdnn/train.py
@@ -430,9 +430,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             valid_info = compute_validation_loss(
diff --git a/egs/yesno/ASR/transducer/decode.py b/egs/yesno/ASR/transducer/decode.py
index 6714180db..de478334e 100755
--- a/egs/yesno/ASR/transducer/decode.py
+++ b/egs/yesno/ASR/transducer/decode.py
@@ -48,16 +48,19 @@ def get_parser():
         "--epoch",
         type=int,
         default=125,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=20,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
     parser.add_argument(
         "--exp-dir",
@@ -116,9 +119,7 @@ def decode_one_batch(
     # at entry, feature is (N, T, C)
     feature_lens = batch["supervisions"]["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
 
     hyps = []
     batch_size = encoder_out.size(0)
@@ -186,9 +187,7 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -303,9 +302,7 @@ def main():
         model=model,
     )
 
-    save_results(
-        exp_dir=params.exp_dir, test_set_name="test_set", results=results
-    )
+    save_results(exp_dir=params.exp_dir, test_set_name="test_set", results=results)
 
     logging.info("Done!")
 
diff --git a/egs/yesno/ASR/transducer/train.py b/egs/yesno/ASR/transducer/train.py
index deb92107d..88866ae81 100755
--- a/egs/yesno/ASR/transducer/train.py
+++ b/egs/yesno/ASR/transducer/train.py
@@ -430,9 +430,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             valid_info = compute_validation_loss(
diff --git a/icefall/char_graph_compiler.py b/icefall/char_graph_compiler.py
index 235160e14..c31db6e4c 100644
--- a/icefall/char_graph_compiler.py
+++ b/icefall/char_graph_compiler.py
@@ -71,9 +71,7 @@ class CharCtcTrainingGraphCompiler(object):
         for text in texts:
             text = re.sub(whitespace, "", text)
             sub_ids = [
-                self.token_table[txt]
-                if txt in self.token_table
-                else self.oov_id
+                self.token_table[txt] if txt in self.token_table else self.oov_id
                 for txt in text
             ]
             ids.append(sub_ids)
@@ -96,9 +94,7 @@ class CharCtcTrainingGraphCompiler(object):
         for text in texts:
             text = text.split("/")
             sub_ids = [
-                self.token_table[txt]
-                if txt in self.token_table
-                else self.oov_id
+                self.token_table[txt] if txt in self.token_table else self.oov_id
                 for txt in text
             ]
             ids.append(sub_ids)
diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py
index 5069b78e8..8aa0a8eeb 100644
--- a/icefall/checkpoint.py
+++ b/icefall/checkpoint.py
@@ -292,15 +292,11 @@ def find_checkpoints(out_dir: Path, iteration: int = 0) -> List[str]:
     """
     checkpoints = list(glob.glob(f"{out_dir}/checkpoint-[0-9]*.pt"))
     pattern = re.compile(r"checkpoint-([0-9]+).pt")
-    iter_checkpoints = [
-        (int(pattern.search(c).group(1)), c) for c in checkpoints
-    ]
+    iter_checkpoints = [(int(pattern.search(c).group(1)), c) for c in checkpoints]
     # iter_checkpoints is a list of tuples. Each tuple contains
     # two elements: (iteration_number, checkpoint-iteration_number.pt)
 
-    iter_checkpoints = sorted(
-        iter_checkpoints, reverse=True, key=lambda x: x[0]
-    )
+    iter_checkpoints = sorted(iter_checkpoints, reverse=True, key=lambda x: x[0])
     if iteration >= 0:
         ans = [ic[1] for ic in iter_checkpoints if ic[0] >= iteration]
     else:
@@ -469,7 +465,5 @@ def average_state_dict(
         v = state_dict_1[k]
         if torch.is_floating_point(v):
             v *= weight_1
-            v += (
-                state_dict_2[k].to(device=state_dict_1[k].device) * weight_2
-            )
+            v += state_dict_2[k].to(device=state_dict_1[k].device) * weight_2
             v *= scaling_factor
diff --git a/icefall/decode.py b/icefall/decode.py
index f04ee368c..6cd87bdc0 100644
--- a/icefall/decode.py
+++ b/icefall/decode.py
@@ -334,13 +334,9 @@ class Nbest(object):
         if hasattr(lattice, "aux_labels"):
             # delete token IDs as it is not needed
             del word_fsa.aux_labels
-            word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(
-                word_fsa
-            )
+            word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa)
         else:
-            word_fsa_with_epsilon_loops = k2.linear_fst_with_self_loops(
-                word_fsa
-            )
+            word_fsa_with_epsilon_loops = k2.linear_fst_with_self_loops(word_fsa)
 
         path_to_utt_map = self.shape.row_ids(1)
 
@@ -370,9 +366,7 @@ class Nbest(object):
         # path_lattice has word IDs as labels and token IDs as aux_labels
         path_lattice = k2.top_sort(k2.connect(path_lattice))
 
-        one_best = k2.shortest_path(
-            path_lattice, use_double_scores=use_double_scores
-        )
+        one_best = k2.shortest_path(path_lattice, use_double_scores=use_double_scores)
 
         one_best = k2.invert(one_best)
         # Now one_best has token IDs as labels and word IDs as aux_labels
@@ -442,9 +436,7 @@ class Nbest(object):
         scores_shape = self.fsa.arcs.shape().remove_axis(1)
         # scores_shape has axes [path][arc]
 
-        ragged_scores = k2.RaggedTensor(
-            scores_shape, self.fsa.scores.contiguous()
-        )
+        ragged_scores = k2.RaggedTensor(scores_shape, self.fsa.scores.contiguous())
 
         tot_scores = ragged_scores.sum()
 
@@ -678,9 +670,7 @@ def rescore_with_n_best_list(
             logging.info(f"num_paths before decreasing: {num_paths}")
             num_paths = int(num_paths / 2)
             if loop_count >= max_loop_count or num_paths <= 0:
-                logging.info(
-                    "Return None as the resulting lattice is too large."
-                )
+                logging.info("Return None as the resulting lattice is too large.")
                 return None
             logging.info(
                 "This OOM is not an error. You can ignore it. "
@@ -787,13 +777,9 @@ def rescore_with_whole_lattice(
         except RuntimeError as e:
             logging.info(f"Caught exception:\n{e}\n")
             if loop_count >= max_loop_count:
-                logging.info(
-                    "Return None as the resulting lattice is too large."
-                )
+                logging.info("Return None as the resulting lattice is too large.")
                 return None
-            logging.info(
-                f"num_arcs before pruning: {inv_lattice.arcs.num_elements()}"
-            )
+            logging.info(f"num_arcs before pruning: {inv_lattice.arcs.num_elements()}")
             logging.info(
                 "This OOM is not an error. You can ignore it. "
                 "If your model does not converge well, or --max-duration "
@@ -805,9 +791,7 @@ def rescore_with_whole_lattice(
                 prune_th_list[loop_count],
                 True,
             )
-            logging.info(
-                f"num_arcs after pruning: {inv_lattice.arcs.num_elements()}"
-            )
+            logging.info(f"num_arcs after pruning: {inv_lattice.arcs.num_elements()}")
         loop_count += 1
 
     # lat has token IDs as labels
@@ -894,9 +878,7 @@ def rescore_with_attention_decoder(
             logging.info(f"num_paths before decreasing: {num_paths}")
             num_paths = int(num_paths / 2)
             if loop_count >= max_loop_count or num_paths <= 0:
-                logging.info(
-                    "Return None as the resulting lattice is too large."
-                )
+                logging.info("Return None as the resulting lattice is too large.")
                 return None
             logging.info(
                 "This OOM is not an error. You can ignore it. "
diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py
index b075aceac..7b58ffbd4 100644
--- a/icefall/diagnostics.py
+++ b/icefall/diagnostics.py
@@ -19,7 +19,7 @@
 
 import random
 from dataclasses import dataclass
-from typing import Optional, Tuple, List
+from typing import List, Optional, Tuple
 
 import torch
 from torch import Tensor, nn
@@ -78,11 +78,11 @@ def get_tensor_stats(
     elif stats_type == "abs":
         x = x.abs()
     elif stats_type == "rms":
-        x = x ** 2
+        x = x**2
     elif stats_type == "positive":
         x = (x > 0).to(dtype=torch.float)
     else:
-        assert stats_type in [ "value", "max", "min" ]
+        assert stats_type in ["value", "max", "min"]
 
     sum_dims = [d for d in range(x.ndim) if d != dim]
     if len(sum_dims) > 0:
@@ -121,7 +121,9 @@ class TensorDiagnostic(object):
         self.name = name
         self.class_name = None  # will assign in accumulate()
 
-        self.stats = None  # we'll later assign a list to this data member.  It's a list of dict.
+        self.stats = (
+            None  # we'll later assign a list to this data member.  It's a list of dict.
+        )
 
         # the keys into self.stats[dim] are strings, whose values can be
         # "abs", "max", "min" ,"value", "positive", "rms", "value".
@@ -133,7 +135,6 @@ class TensorDiagnostic(object):
         # only adding a new element to the list if there was a different dim.
         # if the string in the key is "eigs", if we detect a length mismatch we put None as the value.
 
-
     def accumulate(self, x, class_name: Optional[str] = None):
         """
         Accumulate tensors.
@@ -185,17 +186,12 @@ class TensorDiagnostic(object):
                         done = True
                         break
                 if not done:
-                    if (
-                        this_dim_stats[stats_type] != []
-                        and stats_type == "eigs"
-                    ):
+                    if this_dim_stats[stats_type] != [] and stats_type == "eigs":
                         # >1 size encountered on this dim, e.g. it's a batch or time dimension,
                         # don't accumulat "eigs" stats type, it uses too much memory
                         this_dim_stats[stats_type] = None
                     else:
-                        this_dim_stats[stats_type].append(
-                            TensorAndCount(stats, count)
-                        )
+                        this_dim_stats[stats_type].append(TensorAndCount(stats, count))
 
     def print_diagnostics(self):
         """Print diagnostics for each dimension of the tensor."""
@@ -211,7 +207,6 @@ class TensorDiagnostic(object):
                     assert stats_type == "eigs"
                     continue
 
-
                 def get_count(count):
                     return 1 if stats_type in ["max", "min"] else count
 
@@ -221,7 +216,8 @@ class TensorDiagnostic(object):
                     # a dimension that has variable size in different nnet
                     # forwards, e.g. a time dimension in an ASR model.
                     stats = torch.cat(
-                        [x.tensor / get_count(x.count) for x in stats_list], dim=0
+                        [x.tensor / get_count(x.count) for x in stats_list],
+                        dim=0,
                     )
 
                 if stats_type == "eigs":
@@ -229,9 +225,7 @@ class TensorDiagnostic(object):
                         eigs, _ = torch.symeig(stats)
                         stats = eigs.abs().sqrt()
                     except:  # noqa
-                        print(
-                            "Error getting eigenvalues, trying another method."
-                        )
+                        print("Error getting eigenvalues, trying another method.")
                         eigs, _ = torch.eig(stats)
                         stats = eigs.abs().sqrt()
                         # sqrt so it reflects data magnitude, like stddev- not variance
@@ -242,9 +236,9 @@ class TensorDiagnostic(object):
 
                 # if `summarize` we print percentiles of the stats; else,
                 # we print out individual elements.
-                summarize = (
-                    len(stats_list) > 1
-                ) or self.opts.dim_is_summarized(stats.numel())
+                summarize = (len(stats_list) > 1) or self.opts.dim_is_summarized(
+                    stats.numel()
+                )
                 if summarize:  # usually `summarize` will be true
                     # print out percentiles.
                     stats = stats.sort()[0]
@@ -261,15 +255,15 @@ class TensorDiagnostic(object):
                     ans = stats.tolist()
                     ans = ["%.2g" % x for x in ans]
                     ans = "[" + " ".join(ans) + "]"
-                if stats_type in [ "value", "rms", "eigs" ]:
+                if stats_type in ["value", "rms", "eigs"]:
                     # This norm is useful because it is strictly less than the largest
                     # sqrt(eigenvalue) of the variance, which we print out, and shows,
                     # speaking in an approximate way, how much of that largest eigenvalue
                     # can be attributed to the mean of the distribution.
-                    norm = (stats ** 2).sum().sqrt().item()
+                    norm = (stats**2).sum().sqrt().item()
                     ans += f", norm={norm:.2g}"
                 mean = stats.mean().item()
-                rms = (stats ** 2).mean().sqrt().item()
+                rms = (stats**2).mean().sqrt().item()
                 ans += f", mean={mean:.3g}, rms={rms:.3g}"
 
                 # OK, "ans" contains the actual stats, e.g.
@@ -277,17 +271,17 @@ class TensorDiagnostic(object):
 
                 sizes = [x.tensor.shape[0] for x in stats_list]
                 size_str = (
-                    f"{sizes[0]}"
-                    if len(sizes) == 1
-                    else f"{min(sizes)}..{max(sizes)}"
+                    f"{sizes[0]}" if len(sizes) == 1 else f"{min(sizes)}..{max(sizes)}"
+                )
+                maybe_class_name = (
+                    f" type={self.class_name}," if self.class_name is not None else ""
                 )
-                maybe_class_name = f" type={self.class_name}," if self.class_name is not None else ""
                 print(
-                    f"module={self.name},{maybe_class_name} dim={dim}, size={size_str}, {stats_type} {ans}"
+                    f"module={self.name},{maybe_class_name} dim={dim}, size={size_str},"
+                    f" {stats_type} {ans}"
                 )
 
 
-
 class ModelDiagnostic(object):
     """This class stores diagnostics for all tensors in the torch.nn.Module.
 
@@ -345,32 +339,32 @@ def attach_diagnostics(
         # (matters for name, since the variable gets overwritten).
         # These closures don't really capture by value, only by
         # "the final value the variable got in the function" :-(
-        def forward_hook(
-            _module, _input, _output, _model_diagnostic=ans, _name=name
-        ):
+        def forward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name):
             if isinstance(_output, tuple) and len(_output) == 1:
                 _output = _output[0]
 
             if isinstance(_output, Tensor):
-                _model_diagnostic[f"{_name}.output"].accumulate(_output,
-                                                                class_name=type(_module).__name__)
+                _model_diagnostic[f"{_name}.output"].accumulate(
+                    _output, class_name=type(_module).__name__
+                )
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
-                    _model_diagnostic[f"{_name}.output[{i}]"].accumulate(o,
-                                                                         class_name=type(_module).__name__)
+                    _model_diagnostic[f"{_name}.output[{i}]"].accumulate(
+                        o, class_name=type(_module).__name__
+                    )
 
-        def backward_hook(
-            _module, _input, _output, _model_diagnostic=ans, _name=name
-        ):
+        def backward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name):
             if isinstance(_output, tuple) and len(_output) == 1:
                 _output = _output[0]
             if isinstance(_output, Tensor):
-                _model_diagnostic[f"{_name}.grad"].accumulate(_output,
-                                                              class_name=type(_module).__name__)
+                _model_diagnostic[f"{_name}.grad"].accumulate(
+                    _output, class_name=type(_module).__name__
+                )
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
-                    _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o,
-                                                                       class_name=type(_module).__name__)
+                    _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(
+                        o, class_name=type(_module).__name__
+                    )
 
         module.register_forward_hook(forward_hook)
         module.register_backward_hook(backward_hook)
diff --git a/icefall/dist.py b/icefall/dist.py
index 7016beafb..9df1c5bd1 100644
--- a/icefall/dist.py
+++ b/icefall/dist.py
@@ -29,9 +29,7 @@ def setup_dist(rank, world_size, master_port=None, use_ddp_launch=False):
         os.environ["MASTER_ADDR"] = "localhost"
 
     if "MASTER_PORT" not in os.environ:
-        os.environ["MASTER_PORT"] = (
-            "12354" if master_port is None else str(master_port)
-        )
+        os.environ["MASTER_PORT"] = "12354" if master_port is None else str(master_port)
 
     if use_ddp_launch is False:
         dist.init_process_group("nccl", rank=rank, world_size=world_size)
diff --git a/icefall/env.py b/icefall/env.py
index 8aeda6be2..373e9a9ff 100644
--- a/icefall/env.py
+++ b/icefall/env.py
@@ -53,9 +53,7 @@ def get_git_sha1():
             )
             > 0
         )
-        git_commit = (
-            git_commit + "-dirty" if dirty_commit else git_commit + "-clean"
-        )
+        git_commit = git_commit + "-dirty" if dirty_commit else git_commit + "-clean"
     except:  # noqa
         return None
 
diff --git a/icefall/graph_compiler.py b/icefall/graph_compiler.py
index 570ed7d7a..e2ff03f61 100644
--- a/icefall/graph_compiler.py
+++ b/icefall/graph_compiler.py
@@ -75,9 +75,7 @@ class CtcTrainingGraphCompiler(object):
 
         # NOTE: k2.compose runs on CUDA only when treat_epsilons_specially
         # is False, so we add epsilon self-loops here
-        fsa_with_self_loops = k2.remove_epsilon_and_add_self_loops(
-            transcript_fsa
-        )
+        fsa_with_self_loops = k2.remove_epsilon_and_add_self_loops(transcript_fsa)
 
         fsa_with_self_loops = k2.arc_sort(fsa_with_self_loops)
 
diff --git a/icefall/hooks.py b/icefall/hooks.py
index fbcf5e148..398a5f689 100644
--- a/icefall/hooks.py
+++ b/icefall/hooks.py
@@ -14,10 +14,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 import random
+
 import torch
 from torch import Tensor, nn
-import logging
 
 
 def register_inf_check_hooks(model: nn.Module) -> None:
@@ -56,7 +57,7 @@ def register_inf_check_hooks(model: nn.Module) -> None:
             if isinstance(_output, Tensor):
                 if not torch.isfinite(_output.to(torch.float32).sum()):
                     logging.warning(
-                        f"The sum of {_name}.grad is not finite" # ": {_output}"
+                        f"The sum of {_name}.grad is not finite"  # ": {_output}"
                     )
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
@@ -65,28 +66,20 @@ def register_inf_check_hooks(model: nn.Module) -> None:
                     if not isinstance(o, Tensor):
                         continue
                     if not torch.isfinite(o.to(torch.float32).sum()):
-                        logging.warning(
-                            f"The sum of {_name}.grad[{i}] is not finite"
-                        )
+                        logging.warning(f"The sum of {_name}.grad[{i}] is not finite")
 
         module.register_forward_hook(forward_hook)
         module.register_backward_hook(backward_hook)
 
-
     for name, parameter in model.named_parameters():
 
-        def param_backward_hook(
-                grad, _name=name
-        ):
+        def param_backward_hook(grad, _name=name):
             if not torch.isfinite(grad.to(torch.float32).sum()):
-                logging.warning(
-                    f"The sum of {_name}.param_grad is not finite"
-                )
+                logging.warning(f"The sum of {_name}.param_grad is not finite")
 
         parameter.register_hook(param_backward_hook)
 
 
-
 def _test_inf_check_hooks():
     model = nn.Sequential(nn.Linear(100, 50), nn.Linear(50, 80))
 
diff --git a/icefall/lexicon.py b/icefall/lexicon.py
index 80bd7c1ee..22e1b78bb 100644
--- a/icefall/lexicon.py
+++ b/icefall/lexicon.py
@@ -49,18 +49,12 @@ def read_lexicon(filename: str) -> List[Tuple[str, List[str]]]:
                 continue
 
             if len(a) < 2:
-                logging.info(
-                    f"Found bad line {line} in lexicon file {filename}"
-                )
-                logging.info(
-                    "Every line is expected to contain at least 2 fields"
-                )
+                logging.info(f"Found bad line {line} in lexicon file {filename}")
+                logging.info("Every line is expected to contain at least 2 fields")
                 sys.exit(1)
             word = a[0]
             if word == "":
-                logging.info(
-                    f"Found bad line {line} in lexicon file {filename}"
-                )
+                logging.info(f"Found bad line {line} in lexicon file {filename}")
                 logging.info(" should not be a valid word")
                 sys.exit(1)
 
@@ -119,9 +113,7 @@ def convert_lexicon_to_ragged(
     lexicon_tmp = read_lexicon(filename)
     lexicon = dict(lexicon_tmp)
     if len(lexicon_tmp) != len(lexicon):
-        raise RuntimeError(
-            "It's assumed that each word has a unique pronunciation"
-        )
+        raise RuntimeError("It's assumed that each word has a unique pronunciation")
 
     for i in range(disambig_id):
         w = word_table[i]
diff --git a/icefall/mmi.py b/icefall/mmi.py
index 2c479fc2c..16ed6e032 100644
--- a/icefall/mmi.py
+++ b/icefall/mmi.py
@@ -63,10 +63,7 @@ def _compute_mmi_loss_exact_optimized(
 
     # [0, num_fsas, 1, num_fsas, 2, num_fsas, ... ]
     num_den_graphs_indexes = (
-        torch.stack([num_graphs_indexes, den_graphs_indexes])
-        .t()
-        .reshape(-1)
-        .to(device)
+        torch.stack([num_graphs_indexes, den_graphs_indexes]).t().reshape(-1).to(device)
     )
 
     num_den_reordered_graphs = k2.index(num_den_graphs, num_den_graphs_indexes)
@@ -115,20 +112,12 @@ def _compute_mmi_loss_exact_non_optimized(
     num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=True)
 
     # TODO: pass output_beam as function argument
-    num_lats = k2.intersect_dense(
-        num_graphs, dense_fsa_vec, output_beam=beam_size
-    )
-    den_lats = k2.intersect_dense(
-        den_graphs, dense_fsa_vec, output_beam=beam_size
-    )
+    num_lats = k2.intersect_dense(num_graphs, dense_fsa_vec, output_beam=beam_size)
+    den_lats = k2.intersect_dense(den_graphs, dense_fsa_vec, output_beam=beam_size)
 
-    num_tot_scores = num_lats.get_tot_scores(
-        log_semiring=True, use_double_scores=True
-    )
+    num_tot_scores = num_lats.get_tot_scores(log_semiring=True, use_double_scores=True)
 
-    den_tot_scores = den_lats.get_tot_scores(
-        log_semiring=True, use_double_scores=True
-    )
+    den_tot_scores = den_lats.get_tot_scores(log_semiring=True, use_double_scores=True)
 
     tot_scores = num_tot_scores - den_scale * den_tot_scores
 
@@ -168,13 +157,9 @@ def _compute_mmi_loss_pruned(
         max_active_states=10000,
     )
 
-    num_tot_scores = num_lats.get_tot_scores(
-        log_semiring=True, use_double_scores=True
-    )
+    num_tot_scores = num_lats.get_tot_scores(log_semiring=True, use_double_scores=True)
 
-    den_tot_scores = den_lats.get_tot_scores(
-        log_semiring=True, use_double_scores=True
-    )
+    den_tot_scores = den_lats.get_tot_scores(log_semiring=True, use_double_scores=True)
 
     tot_scores = num_tot_scores - den_scale * den_tot_scores
 
diff --git a/icefall/mmi_graph_compiler.py b/icefall/mmi_graph_compiler.py
index 0d901227d..9f680f83d 100644
--- a/icefall/mmi_graph_compiler.py
+++ b/icefall/mmi_graph_compiler.py
@@ -137,9 +137,7 @@ class MmiTrainingGraphCompiler(object):
             transcript_fsa
         )
 
-        transcript_fsa_with_self_loops = k2.arc_sort(
-            transcript_fsa_with_self_loops
-        )
+        transcript_fsa_with_self_loops = k2.arc_sort(transcript_fsa_with_self_loops)
 
         num = k2.compose(
             self.ctc_topo_P,
@@ -155,9 +153,7 @@ class MmiTrainingGraphCompiler(object):
 
         ctc_topo_P_vec = k2.create_fsa_vec([self.ctc_topo_P])
         if replicate_den:
-            indexes = torch.zeros(
-                len(texts), dtype=torch.int32, device=self.device
-            )
+            indexes = torch.zeros(len(texts), dtype=torch.int32, device=self.device)
             den = k2.index_fsa(ctc_topo_P_vec, indexes)
         else:
             den = ctc_topo_P_vec
diff --git a/icefall/rnn_lm/compute_perplexity.py b/icefall/rnn_lm/compute_perplexity.py
index 550801a8f..9a275bf28 100755
--- a/icefall/rnn_lm/compute_perplexity.py
+++ b/icefall/rnn_lm/compute_perplexity.py
@@ -46,16 +46,19 @@ def get_parser():
         "--epoch",
         type=int,
         default=49,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=20,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -194,7 +197,7 @@ def main():
 
     logging.info(f"Number of model parameters: {num_param}")
     logging.info(
-        f"Number of model parameters (requires_grad): "
+        "Number of model parameters (requires_grad): "
         f"{num_param_requires_grad} "
         f"({num_param_requires_grad/num_param_requires_grad*100}%)"
     )
diff --git a/icefall/rnn_lm/dataset.py b/icefall/rnn_lm/dataset.py
index 598e329c4..4bf982503 100644
--- a/icefall/rnn_lm/dataset.py
+++ b/icefall/rnn_lm/dataset.py
@@ -155,12 +155,8 @@ class LmDatasetCollate:
         sentence_tokens_with_sos = add_sos(sentence_tokens, self.sos_id)
         sentence_tokens_with_eos = add_eos(sentence_tokens, self.eos_id)
 
-        x = sentence_tokens_with_sos.pad(
-            mode="constant", padding_value=self.blank_id
-        )
-        y = sentence_tokens_with_eos.pad(
-            mode="constant", padding_value=self.blank_id
-        )
+        x = sentence_tokens_with_sos.pad(mode="constant", padding_value=self.blank_id)
+        y = sentence_tokens_with_eos.pad(mode="constant", padding_value=self.blank_id)
         sentence_token_lengths += 1  # plus 1 since we added a SOS
 
         return x.to(torch.int64), y.to(torch.int64), sentence_token_lengths
diff --git a/icefall/rnn_lm/export.py b/icefall/rnn_lm/export.py
index 094035fce..2e878f5c8 100644
--- a/icefall/rnn_lm/export.py
+++ b/icefall/rnn_lm/export.py
@@ -38,17 +38,20 @@ def get_parser():
         "--epoch",
         type=int,
         default=29,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=5,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -159,9 +162,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/icefall/rnn_lm/model.py b/icefall/rnn_lm/model.py
index a6144727a..9eef88840 100644
--- a/icefall/rnn_lm/model.py
+++ b/icefall/rnn_lm/model.py
@@ -129,9 +129,7 @@ class RnnLmModel(torch.nn.Module):
         tokens_eos = add_eos(tokens, eos_id)
         sos_tokens_row_splits = sos_tokens.shape.row_splits(1)
 
-        sentence_lengths = (
-            sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]
-        )
+        sentence_lengths = sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]
 
         x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id)
         y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id)
@@ -161,12 +159,12 @@ class RnnLmModel(torch.nn.Module):
         if state:
             h, c = state
         else:
-            h = torch.zeros(
-                self.rnn.num_layers, batch_size, self.rnn.input_size
-            ).to(device)
-            c = torch.zeros(
-                self.rnn.num_layers, batch_size, self.rnn.input_size
-            ).to(device)
+            h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size).to(
+                device
+            )
+            c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size).to(
+                device
+            )
 
         embedding = self.input_embedding(tokens)
         rnn_out, states = self.rnn(embedding, (h, c))
@@ -181,12 +179,8 @@ class RnnLmModel(torch.nn.Module):
         if state:
             h, c = state
         else:
-            h = torch.zeros(
-                self.rnn.num_layers, batch_size, self.rnn.input_size
-            )
-            c = torch.zeros(
-                self.rnn.num_layers, batch_size, self.rnn.input_size
-            )
+            h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size)
+            c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size)
 
         device = next(self.parameters()).device
 
@@ -194,9 +188,7 @@ class RnnLmModel(torch.nn.Module):
         tokens_eos = add_eos(tokens, eos_id)
         sos_tokens_row_splits = sos_tokens.shape.row_splits(1)
 
-        sentence_lengths = (
-            sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]
-        )
+        sentence_lengths = sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]
 
         x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id)
         y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id)
diff --git a/icefall/rnn_lm/train.py b/icefall/rnn_lm/train.py
index bb5f03fb9..e17b50332 100755
--- a/icefall/rnn_lm/train.py
+++ b/icefall/rnn_lm/train.py
@@ -446,17 +446,13 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
                 tb_writer.add_scalar(
                     "train/current_ppl", this_batch_ppl, params.batch_idx_train
                 )
 
-                tb_writer.add_scalar(
-                    "train/tot_ppl", tot_ppl, params.batch_idx_train
-                )
+                tb_writer.add_scalar("train/tot_ppl", tot_ppl, params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -471,8 +467,7 @@ def train_one_epoch(
 
             valid_ppl = math.exp(valid_info["loss"] / valid_info["frames"])
             logging.info(
-                f"Epoch {params.cur_epoch}, validation: {valid_info}, "
-                f"ppl: {valid_ppl}"
+                f"Epoch {params.cur_epoch}, validation: {valid_info}, ppl: {valid_ppl}"
             )
 
             if tb_writer is not None:
diff --git a/icefall/shared/make_kn_lm.py b/icefall/shared/make_kn_lm.py
index c2edd823e..a3bf1ef4c 100755
--- a/icefall/shared/make_kn_lm.py
+++ b/icefall/shared/make_kn_lm.py
@@ -15,30 +15,50 @@
 # The data structure is based on: kaldi/egs/wsj/s5/utils/lang/make_phone_lm.py
 # The smoothing algorithm is based on: http://www.speech.sri.com/projects/srilm/manpages/ngram-discount.7.html
 
-import sys
-import os
-import re
+import argparse
 import io
 import math
-import argparse
+import os
+import re
+import sys
 from collections import Counter, defaultdict
 
-
-parser = argparse.ArgumentParser(description="""
+parser = argparse.ArgumentParser(
+    description="""
     Generate kneser-ney language model as arpa format. By default,
     it will read the corpus from standard input, and output to standard output.
-    """)
-parser.add_argument("-ngram-order", type=int, default=4, choices=[2, 3, 4, 5, 6, 7], help="Order of n-gram")
+    """
+)
+parser.add_argument(
+    "-ngram-order",
+    type=int,
+    default=4,
+    choices=[2, 3, 4, 5, 6, 7],
+    help="Order of n-gram",
+)
 parser.add_argument("-text", type=str, default=None, help="Path to the corpus file")
-parser.add_argument("-lm", type=str, default=None, help="Path to output arpa file for language models")
-parser.add_argument("-verbose", type=int, default=0, choices=[0, 1, 2, 3, 4, 5], help="Verbose level")
+parser.add_argument(
+    "-lm",
+    type=str,
+    default=None,
+    help="Path to output arpa file for language models",
+)
+parser.add_argument(
+    "-verbose",
+    type=int,
+    default=0,
+    choices=[0, 1, 2, 3, 4, 5],
+    help="Verbose level",
+)
 args = parser.parse_args()
 
-default_encoding = "latin-1"  # For encoding-agnostic scripts, we assume byte stream as input.
-                              # Need to be very careful about the use of strip() and split()
-                              # in this case, because there is a latin-1 whitespace character
-                              # (nbsp) which is part of the unicode encoding range.
-                              # Ref: kaldi/egs/wsj/s5/utils/lang/bpe/prepend_words.py @ 69cd717
+default_encoding = (
+    "latin-1"  # For encoding-agnostic scripts, we assume byte stream as input.
+)
+# Need to be very careful about the use of strip() and split()
+# in this case, because there is a latin-1 whitespace character
+# (nbsp) which is part of the unicode encoding range.
+# Ref: kaldi/egs/wsj/s5/utils/lang/bpe/prepend_words.py @ 69cd717
 strip_chars = " \t\r\n"
 whitespace = re.compile("[ \t]+")
 
@@ -52,7 +72,9 @@ class CountsForHistory:
         # The 'lambda: defaultdict(float)' is an anonymous function taking no
         # arguments that returns a new defaultdict(float).
         self.word_to_count = defaultdict(int)
-        self.word_to_context = defaultdict(set)  # using a set to count the number of unique contexts
+        self.word_to_context = defaultdict(
+            set
+        )  # using a set to count the number of unique contexts
         self.word_to_f = dict()  # discounted probability
         self.word_to_bow = dict()  # back-off weight
         self.total_count = 0
@@ -62,10 +84,15 @@ class CountsForHistory:
 
     def __str__(self):
         # e.g. returns ' total=12: 3->4, 4->6, -1->2'
-        return ' total={0}: {1}'.format(
+        return " total={0}: {1}".format(
             str(self.total_count),
-            ', '.join(['{0} -> {1}'.format(word, count)
-                      for word, count in self.word_to_count.items()]))
+            ", ".join(
+                [
+                    "{0} -> {1}".format(word, count)
+                    for word, count in self.word_to_count.items()
+                ]
+            ),
+        )
 
     def add_count(self, predicted_word, context_word, count):
         assert count >= 0
@@ -85,7 +112,7 @@ class NgramCounts:
     # accumulating the 4-gram count for the '8' in the sequence '5 6 7 8', we'd
     # do as follows: self.counts[3][[5,6,7]][8] += 1.0 where the [3] indexes an
     # array, the [[5,6,7]] indexes a dict, and the [8] indexes a dict.
-    def __init__(self, ngram_order, bos_symbol='', eos_symbol=''):
+    def __init__(self, ngram_order, bos_symbol="", eos_symbol=""):
         assert ngram_order >= 2
 
         self.ngram_order = ngram_order
@@ -103,39 +130,48 @@ class NgramCounts:
     # would be (6,7,8) and 'predicted_word' would be 9; 'count' would be
     # 1.
     def add_count(self, history, predicted_word, context_word, count):
-        self.counts[len(history)][history].add_count(predicted_word, context_word, count)
+        self.counts[len(history)][history].add_count(
+            predicted_word, context_word, count
+        )
 
     # 'line' is a string containing a sequence of integer word-ids.
     # This function adds the un-smoothed counts from this line of text.
     def add_raw_counts_from_line(self, line):
-        if line == '':
+        if line == "":
             words = [self.bos_symbol, self.eos_symbol]
         else:
             words = [self.bos_symbol] + whitespace.split(line) + [self.eos_symbol]
 
         for i in range(len(words)):
-            for n in range(1, self.ngram_order+1):
+            for n in range(1, self.ngram_order + 1):
                 if i + n > len(words):
                     break
-                ngram = words[i: i + n]
+                ngram = words[i : i + n]
                 predicted_word = ngram[-1]
-                history = tuple(ngram[: -1])
+                history = tuple(ngram[:-1])
                 if i == 0 or n == self.ngram_order:
                     context_word = None
                 else:
-                    context_word = words[i-1]
+                    context_word = words[i - 1]
 
                 self.add_count(history, predicted_word, context_word, 1)
 
     def add_raw_counts_from_standard_input(self):
         lines_processed = 0
-        infile = io.TextIOWrapper(sys.stdin.buffer, encoding=default_encoding)  # byte stream as input
+        infile = io.TextIOWrapper(
+            sys.stdin.buffer, encoding=default_encoding
+        )  # byte stream as input
         for line in infile:
             line = line.strip(strip_chars)
             self.add_raw_counts_from_line(line)
             lines_processed += 1
         if lines_processed == 0 or args.verbose > 0:
-            print("make_phone_lm.py: processed {0} lines of input".format(lines_processed), file=sys.stderr)
+            print(
+                "make_phone_lm.py: processed {0} lines of input".format(
+                    lines_processed
+                ),
+                file=sys.stderr,
+            )
 
     def add_raw_counts_from_file(self, filename):
         lines_processed = 0
@@ -145,7 +181,12 @@ class NgramCounts:
                 self.add_raw_counts_from_line(line)
                 lines_processed += 1
         if lines_processed == 0 or args.verbose > 0:
-            print("make_phone_lm.py: processed {0} lines of input".format(lines_processed), file=sys.stderr)
+            print(
+                "make_phone_lm.py: processed {0} lines of input".format(
+                    lines_processed
+                ),
+                file=sys.stderr,
+            )
 
     def cal_discounting_constants(self):
         # For each order N of N-grams, we calculate discounting constant D_N = n1_N / (n1_N + 2 * n2_N),
@@ -153,9 +194,11 @@ class NgramCounts:
         # This constant is used similarly to absolute discounting.
         # Return value: d is a list of floats, where d[N+1] = D_N
 
-        self.d = [0]  # for the lowest order, i.e., 1-gram, we do not need to discount, thus the constant is 0
-                      # This is a special case: as we currently assumed having seen all vocabularies in the dictionary,
-                      # but perhaps this is not the case for some other scenarios.
+        self.d = [
+            0
+        ]  # for the lowest order, i.e., 1-gram, we do not need to discount, thus the constant is 0
+        # This is a special case: as we currently assumed having seen all vocabularies in the dictionary,
+        # but perhaps this is not the case for some other scenarios.
         for n in range(1, self.ngram_order):
             this_order_counts = self.counts[n]
             n1 = 0
@@ -165,9 +208,11 @@ class NgramCounts:
                 n1 += stat[1]
                 n2 += stat[2]
             assert n1 + 2 * n2 > 0
-            self.d.append(max(0.1, n1 * 1.0) / (n1 + 2 * n2))   # We are doing this max(0.001, xxx) to avoid zero discounting constant D due to n1=0, 
-                                                                # which could happen if the number of symbols is small.
-                                                                # Otherwise, zero discounting constant can cause division by zero in computing BOW.
+            self.d.append(
+                max(0.1, n1 * 1.0) / (n1 + 2 * n2)
+            )  # We are doing this max(0.001, xxx) to avoid zero discounting constant D due to n1=0,
+            # which could happen if the number of symbols is small.
+            # Otherwise, zero discounting constant can cause division by zero in computing BOW.
 
     def cal_f(self):
         # f(a_z) is a probability distribution of word sequence a_z.
@@ -182,7 +227,9 @@ class NgramCounts:
         this_order_counts = self.counts[n]
         for hist, counts_for_hist in this_order_counts.items():
             for w, c in counts_for_hist.word_to_count.items():
-                counts_for_hist.word_to_f[w] = max((c - self.d[n]), 0) * 1.0 / counts_for_hist.total_count
+                counts_for_hist.word_to_f[w] = (
+                    max((c - self.d[n]), 0) * 1.0 / counts_for_hist.total_count
+                )
 
         # lower order N-grams
         for n in range(0, self.ngram_order - 1):
@@ -196,11 +243,17 @@ class NgramCounts:
                 if n_star_star != 0:
                     for w in counts_for_hist.word_to_count.keys():
                         n_star_z = len(counts_for_hist.word_to_context[w])
-                        counts_for_hist.word_to_f[w] = max((n_star_z - self.d[n]), 0) * 1.0 / n_star_star
+                        counts_for_hist.word_to_f[w] = (
+                            max((n_star_z - self.d[n]), 0) * 1.0 / n_star_star
+                        )
                 else:  # patterns begin with , they do not have "modified count", so use raw count instead
                     for w in counts_for_hist.word_to_count.keys():
                         n_star_z = counts_for_hist.word_to_count[w]
-                        counts_for_hist.word_to_f[w] = max((n_star_z - self.d[n]), 0) * 1.0 / counts_for_hist.total_count
+                        counts_for_hist.word_to_f[w] = (
+                            max((n_star_z - self.d[n]), 0)
+                            * 1.0
+                            / counts_for_hist.total_count
+                        )
 
     def cal_bow(self):
         # Backoff weights are only necessary for ngrams which form a prefix of a longer ngram.
@@ -240,12 +293,18 @@ class NgramCounts:
                         sum_z1_f_z = 0
                         _ = a_[1:]
                         _counts_for_hist = self.counts[len(_)][_]
-                        for u in a_counts_for_hist.word_to_count.keys():  # Should be careful here: what is Z1
+                        for (
+                            u
+                        ) in (
+                            a_counts_for_hist.word_to_count.keys()
+                        ):  # Should be careful here: what is Z1
                             sum_z1_f_z += _counts_for_hist.word_to_f[u]
 
                         if sum_z1_f_z < 1:
                             # assert sum_z1_f_a_z < 1
-                            counts_for_hist.word_to_bow[w] = (1.0 - sum_z1_f_a_z) / (1.0 - sum_z1_f_z)
+                            counts_for_hist.word_to_bow[w] = (1.0 - sum_z1_f_a_z) / (
+                                1.0 - sum_z1_f_z
+                            )
                         else:
                             counts_for_hist.word_to_bow[w] = None
 
@@ -259,7 +318,9 @@ class NgramCounts:
                     ngram = " ".join(hist) + " " + w
                     ngram = ngram.strip(strip_chars)
 
-                    res.append("{0}\t{1}".format(ngram, counts_for_hist.word_to_count[w]))
+                    res.append(
+                        "{0}\t{1}".format(ngram, counts_for_hist.word_to_count[w])
+                    )
         res.sort(reverse=True)
         for r in res:
             print(r)
@@ -322,27 +383,40 @@ class NgramCounts:
                     if bow is None:
                         res.append("{1}\t{0}".format(ngram, math.log(f, 10)))
                     else:
-                        res.append("{1}\t{0}\t{2}".format(ngram, math.log(f, 10), math.log(bow, 10)))
+                        res.append(
+                            "{1}\t{0}\t{2}".format(
+                                ngram, math.log(f, 10), math.log(bow, 10)
+                            )
+                        )
         res.sort(reverse=True)
         for r in res:
             print(r)
 
-    def print_as_arpa(self, fout=io.TextIOWrapper(sys.stdout.buffer, encoding='latin-1')):
+    def print_as_arpa(
+        self, fout=io.TextIOWrapper(sys.stdout.buffer, encoding="latin-1")
+    ):
         # print as ARPA format.
 
-        print('\\data\\', file=fout)
+        print("\\data\\", file=fout)
         for hist_len in range(self.ngram_order):
             # print the number of n-grams.
-            print('ngram {0}={1}'.format(
-                hist_len + 1,
-                sum([len(counts_for_hist.word_to_f) for counts_for_hist in self.counts[hist_len].values()])),
-                file=fout
+            print(
+                "ngram {0}={1}".format(
+                    hist_len + 1,
+                    sum(
+                        [
+                            len(counts_for_hist.word_to_f)
+                            for counts_for_hist in self.counts[hist_len].values()
+                        ]
+                    ),
+                ),
+                file=fout,
             )
 
-        print('', file=fout)
+        print("", file=fout)
 
         for hist_len in range(self.ngram_order):
-            print('\\{0}-grams:'.format(hist_len + 1), file=fout)
+            print("\\{0}-grams:".format(hist_len + 1), file=fout)
 
             this_order_counts = self.counts[hist_len]
             for hist, counts_for_hist in this_order_counts.items():
@@ -354,12 +428,12 @@ class NgramCounts:
                     if prob == 0:  # f() is always 0
                         prob = 1e-99
 
-                    line = '{0}\t{1}'.format('%.7f' % math.log10(prob), ' '.join(ngram))
+                    line = "{0}\t{1}".format("%.7f" % math.log10(prob), " ".join(ngram))
                     if bow is not None:
-                        line += '\t{0}'.format('%.7f' % math.log10(bow))
+                        line += "\t{0}".format("%.7f" % math.log10(bow))
                     print(line, file=fout)
-            print('', file=fout)
-        print('\\end\\', file=fout)
+            print("", file=fout)
+        print("\\end\\", file=fout)
 
 
 if __name__ == "__main__":
@@ -379,5 +453,5 @@ if __name__ == "__main__":
     if args.lm is None:
         ngram_counts.print_as_arpa()
     else:
-        with open(args.lm, 'w', encoding=default_encoding) as f:
+        with open(args.lm, "w", encoding=default_encoding) as f:
             ngram_counts.print_as_arpa(fout=f)
diff --git a/icefall/utils.py b/icefall/utils.py
index c502cb4d8..0beb94b2e 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -130,9 +130,7 @@ def setup_logger(
         formatter = f"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] ({rank}/{world_size}) %(message)s"  # noqa
         log_filename = f"{log_filename}-{date_time}-{rank}"
     else:
-        formatter = (
-            "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-        )
+        formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
         log_filename = f"{log_filename}-{date_time}"
 
     os.makedirs(os.path.dirname(log_filename), exist_ok=True)
@@ -280,13 +278,9 @@ def get_texts_with_timestamp(
     """
     if isinstance(best_paths.aux_labels, k2.RaggedTensor):
         all_aux_shape = (
-            best_paths.arcs.shape()
-            .remove_axis(1)
-            .compose(best_paths.aux_labels.shape)
-        )
-        all_aux_labels = k2.RaggedTensor(
-            all_aux_shape, best_paths.aux_labels.values
+            best_paths.arcs.shape().remove_axis(1).compose(best_paths.aux_labels.shape)
         )
+        all_aux_labels = k2.RaggedTensor(all_aux_shape, best_paths.aux_labels.values)
         # remove 0's and -1's.
         aux_labels = best_paths.aux_labels.remove_values_leq(0)
         # TODO: change arcs.shape() to arcs.shape
@@ -355,9 +349,7 @@ def get_alignments(best_paths: k2.Fsa, kind: str) -> List[List[int]]:
     # arc.shape() has axes [fsa][state][arc], we remove "state"-axis here
     token_shape = best_paths.arcs.shape().remove_axis(1)
     # token_shape has axes [fsa][arc]
-    tokens = k2.RaggedTensor(
-        token_shape, getattr(best_paths, kind).contiguous()
-    )
+    tokens = k2.RaggedTensor(token_shape, getattr(best_paths, kind).contiguous())
     tokens = tokens.remove_values_eq(-1)
     return tokens.tolist()
 
@@ -578,9 +570,7 @@ def write_error_stats(
             f"{cut_id}:\t"
             + " ".join(
                 (
-                    ref_word
-                    if ref_word == hyp_word
-                    else f"({ref_word}->{hyp_word})"
+                    ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
                     for ref_word, hyp_word in ali
                 )
             ),
@@ -590,9 +580,7 @@ def write_error_stats(
     print("", file=f)
     print("SUBSTITUTIONS: count ref -> hyp", file=f)
 
-    for count, (ref, hyp) in sorted(
-        [(v, k) for k, v in subs.items()], reverse=True
-    ):
+    for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
         print(f"{count}   {ref} -> {hyp}", file=f)
 
     print("", file=f)
@@ -606,9 +594,7 @@ def write_error_stats(
         print(f"{count}   {hyp}", file=f)
 
     print("", file=f)
-    print(
-        "PER-WORD STATS: word  corr tot_errs count_in_ref count_in_hyp", file=f
-    )
+    print("PER-WORD STATS: word  corr tot_errs count_in_ref count_in_hyp", file=f)
     for _, word, counts in sorted(
         [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
     ):
@@ -783,9 +769,7 @@ def write_error_stats_with_timestamps(
             f"{cut_id}:\t"
             + " ".join(
                 (
-                    ref_word
-                    if ref_word == hyp_word
-                    else f"({ref_word}->{hyp_word})"
+                    ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
                     for ref_word, hyp_word in ali
                 )
             ),
@@ -795,9 +779,7 @@ def write_error_stats_with_timestamps(
     print("", file=f)
     print("SUBSTITUTIONS: count ref -> hyp", file=f)
 
-    for count, (ref, hyp) in sorted(
-        [(v, k) for k, v in subs.items()], reverse=True
-    ):
+    for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
         print(f"{count}   {ref} -> {hyp}", file=f)
 
     print("", file=f)
@@ -811,9 +793,7 @@ def write_error_stats_with_timestamps(
         print(f"{count}   {hyp}", file=f)
 
     print("", file=f)
-    print(
-        "PER-WORD STATS: word  corr tot_errs count_in_ref count_in_hyp", file=f
-    )
+    print("PER-WORD STATS: word  corr tot_errs count_in_ref count_in_hyp", file=f)
     for _, word, counts in sorted(
         [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
     ):
@@ -883,9 +863,7 @@ class MetricsTracker(collections.defaultdict):
             if k == "frames" or k == "utterances":
                 continue
             norm_value = (
-                float(v) / num_frames
-                if "utt_" not in k
-                else float(v) / num_utterances
+                float(v) / num_frames if "utt_" not in k else float(v) / num_utterances
             )
             ans.append((k, norm_value))
         return ans
@@ -919,9 +897,7 @@ class MetricsTracker(collections.defaultdict):
             tb_writer.add_scalar(prefix + k, v, batch_idx)
 
 
-def concat(
-    ragged: k2.RaggedTensor, value: int, direction: str
-) -> k2.RaggedTensor:
+def concat(ragged: k2.RaggedTensor, value: int, direction: str) -> k2.RaggedTensor:
     """Prepend a value to the beginning of each sublist or append a value.
     to the end of each sublist.
 
@@ -967,8 +943,8 @@ def concat(
         ans = k2.ragged.cat([ragged, pad], axis=1)
     else:
         raise ValueError(
-            f'Unsupported direction: {direction}. " \
-            "Expect either "left" or "right"'
+            f'Unsupported direction: {direction}. "             "Expect either "left"'
+            ' or "right"'
         )
     return ans
 
@@ -1093,9 +1069,7 @@ def linf_norm(x):
     return torch.max(torch.abs(x))
 
 
-def measure_weight_norms(
-    model: nn.Module, norm: str = "l2"
-) -> Dict[str, float]:
+def measure_weight_norms(model: nn.Module, norm: str = "l2") -> Dict[str, float]:
     """
     Compute the norms of the model's parameters.
 
@@ -1118,9 +1092,7 @@ def measure_weight_norms(
         return norms
 
 
-def measure_gradient_norms(
-    model: nn.Module, norm: str = "l1"
-) -> Dict[str, float]:
+def measure_gradient_norms(model: nn.Module, norm: str = "l1") -> Dict[str, float]:
     """
     Compute the norms of the gradients for each of model's parameters.
 
@@ -1405,9 +1377,7 @@ def parse_hyp_and_timestamp(
         use_word_table = True
 
     for i in range(N):
-        time = convert_timestamp(
-            res.timestamps[i], subsampling_factor, frame_shift_ms
-        )
+        time = convert_timestamp(res.timestamps[i], subsampling_factor, frame_shift_ms)
         if use_word_table:
             words = [word_table[i] for i in res.hyps[i]]
         else:
diff --git a/pyproject.toml b/pyproject.toml
index b4f8c3377..3183055d4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -3,7 +3,7 @@ profile = "black"
 skip = ["icefall/__init__.py"]
 
 [tool.black]
-line-length = 80
+line-length = 88
 exclude = '''
 /(
     \.git
diff --git a/setup.py b/setup.py
index 6c720e121..ccd2503ff 100644
--- a/setup.py
+++ b/setup.py
@@ -1,8 +1,9 @@
 #!/usr/bin/env python3
 
-from setuptools import find_packages, setup
 from pathlib import Path
 
+from setuptools import find_packages, setup
+
 icefall_dir = Path(__file__).parent
 install_requires = (icefall_dir / "requirements.txt").read_text().splitlines()
 
diff --git a/test/test_checkpoint.py b/test/test_checkpoint.py
index 511a11c23..34e829642 100644
--- a/test/test_checkpoint.py
+++ b/test/test_checkpoint.py
@@ -20,11 +20,7 @@ import pytest
 import torch
 import torch.nn as nn
 
-from icefall.checkpoint import (
-    average_checkpoints,
-    load_checkpoint,
-    save_checkpoint,
-)
+from icefall.checkpoint import average_checkpoints, load_checkpoint, save_checkpoint
 
 
 @pytest.fixture
diff --git a/test/test_decode.py b/test/test_decode.py
index 97964ac67..4c2e192a7 100644
--- a/test/test_decode.py
+++ b/test/test_decode.py
@@ -23,6 +23,7 @@ You can run this file in one of the two ways:
 """
 
 import k2
+
 from icefall.decode import Nbest
 
 
diff --git a/test/test_graph_compiler.py b/test/test_graph_compiler.py
index ccfb57d49..10443cf22 100644
--- a/test/test_graph_compiler.py
+++ b/test/test_graph_compiler.py
@@ -154,9 +154,7 @@ class TestCtcTrainingGraphCompiler(object):
         fsas = k2.Fsa.from_fsas([fsa1, fsa2])
 
         decoding_graph = k2.arc_sort(decoding_graph)
-        lattice = k2.intersect(
-            decoding_graph, fsas, treat_epsilons_specially=False
-        )
+        lattice = k2.intersect(decoding_graph, fsas, treat_epsilons_specially=False)
         lattice = k2.connect(lattice)
 
         aux_labels0 = lattice[0].aux_labels[:-1]
diff --git a/test/test_utils.py b/test/test_utils.py
index 6a9ce7853..31f06bd51 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -50,9 +50,7 @@ def test_encode_supervisions(sup):
     assert torch.all(
         torch.eq(
             supervision_segments,
-            torch.tensor(
-                [[1, 0, 30 // 4], [0, 0, 20 // 4], [2, 9 // 4, 10 // 4]]
-            ),
+            torch.tensor([[1, 0, 30 // 4], [0, 0, 20 // 4], [2, 9 // 4, 10 // 4]]),
         )
     )
     assert texts == ["two", "one", "three"]

From d89766d85dbb023a8fcb47221545d00b1a015c69 Mon Sep 17 00:00:00 2001
From: Desh Raj 
Date: Wed, 16 Nov 2022 13:10:55 -0500
Subject: [PATCH 007/174] add git blame ignore revs file

---
 .git-blame-ignore-revs | 2 ++
 1 file changed, 2 insertions(+)
 create mode 100644 .git-blame-ignore-revs

diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs
new file mode 100644
index 000000000..c5908fc89
--- /dev/null
+++ b/.git-blame-ignore-revs
@@ -0,0 +1,2 @@
+# Migrate to 88 characters per line (see: https://github.com/lhotse-speech/lhotse/issues/890)
+d110b04ad389134c82fa314e3aafc7b40043efb0

From 7a8e8e735d21bd9d992e291e6c39c922154168b5 Mon Sep 17 00:00:00 2001
From: Desh Raj 
Date: Wed, 16 Nov 2022 14:43:21 -0500
Subject: [PATCH 008/174] change click version in pre-commit

---
 .pre-commit-config.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index e2055801b..5cb213327 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -4,7 +4,7 @@ repos:
     hooks:
       - id: black
         args: ["--line-length=88"]
-        additional_dependencies: ['click==8.0.1']
+        additional_dependencies: ['click==8.1.0']
         exclude: icefall\/__init__\.py
 
   - repo: https://github.com/PyCQA/flake8

From fca796cc2c9f4e86cad6ea0f5fe4305c071a2293 Mon Sep 17 00:00:00 2001
From: Daniil 
Date: Wed, 16 Nov 2022 17:55:53 -0500
Subject: [PATCH 009/174] Small code refactoring (#687)

---
 egs/librispeech/ASR/conformer_ctc2/train.py   |  15 ---
 .../ASR/conformer_ctc2/transformer.py         |  27 +---
 egs/librispeech/ASR/local/compile_hlg.py      |  25 ++--
 .../transducer_stateless/asr_datamodule.py    | 123 ++++++++++--------
 icefall/decode.py                             |  24 +++-
 icefall/utils.py                              |  12 +-
 6 files changed, 120 insertions(+), 106 deletions(-)

diff --git a/egs/librispeech/ASR/conformer_ctc2/train.py b/egs/librispeech/ASR/conformer_ctc2/train.py
index 9d9c2af1f..18fa3e69f 100755
--- a/egs/librispeech/ASR/conformer_ctc2/train.py
+++ b/egs/librispeech/ASR/conformer_ctc2/train.py
@@ -166,13 +166,6 @@ def get_parser():
         """,
     )
 
-    parser.add_argument(
-        "--bpe-model",
-        type=str,
-        default="data/lang_bpe_500/bpe.model",
-        help="Path to the BPE model",
-    )
-
     parser.add_argument(
         "--initial-lr",
         type=float,
@@ -522,14 +515,6 @@ def compute_loss(
         nnet_output, encoder_memory, memory_mask = model(
             feature, supervisions, warmup=warmup
         )
-        # logging.info('feature shape: {}'.format(feature.shape))
-        # logging.info('nnet_output shape: {}'.format(nnet_output.shape))
-        # logging.info('encoder_memory shape: {}'.format(encoder_memory.shape))
-        # logging.info('memory_mask shape: {}'.format(memory_mask.shape))
-        # after the main warmup step, we keep pruned_loss_scale small
-        # for the same amount of time (model_warm_step), to avoid
-        # overwhelming the simple_loss and causing it to diverge,
-        # in case it had not fully learned the alignment yet.
 
     # NOTE: We need `encode_supervisions` to sort sequences with
     # different duration in decreasing order, required by
diff --git a/egs/librispeech/ASR/conformer_ctc2/transformer.py b/egs/librispeech/ASR/conformer_ctc2/transformer.py
index fa179acc0..3ef7edc23 100644
--- a/egs/librispeech/ASR/conformer_ctc2/transformer.py
+++ b/egs/librispeech/ASR/conformer_ctc2/transformer.py
@@ -417,7 +417,6 @@ class TransformerEncoderLayer(nn.Module):
         dim_feedforward: int = 2048,
         dropout: float = 0.1,
         layer_dropout: float = 0.075,
-        activation: str = "relu",
     ) -> None:
         super(TransformerEncoderLayer, self).__init__()
 
@@ -443,11 +442,6 @@ class TransformerEncoderLayer(nn.Module):
 
         self.dropout = nn.Dropout(dropout)
 
-    # def __setstate__(self, state):
-    #     if "activation" not in state:
-    #         state["activation"] = nn.functional.relu
-    #     super(TransformerEncoderLayer, self).__setstate__(state)
-
     def forward(
         self,
         src: torch.Tensor,
@@ -539,7 +533,6 @@ class TransformerDecoderLayer(nn.Module):
         dim_feedforward: int = 2048,
         dropout: float = 0.1,
         layer_dropout: float = 0.075,
-        # activation: str = "relu",
         normalize_before: bool = True,
     ) -> None:
         super(TransformerDecoderLayer, self).__init__()
@@ -564,11 +557,6 @@ class TransformerDecoderLayer(nn.Module):
 
         self.dropout = nn.Dropout(dropout)
 
-    # def __setstate__(self, state):
-    #     if "activation" not in state:
-    #         state["activation"] = nn.functional.relu
-    #     super(TransformerDecoderLayer, self).__setstate__(state)
-
     def forward(
         self,
         tgt: torch.Tensor,
@@ -653,17 +641,6 @@ class TransformerDecoderLayer(nn.Module):
         return tgt
 
 
-def _get_activation_fn(activation: str):
-    if activation == "relu":
-        return nn.functional.relu
-    elif activation == "gelu":
-        return nn.functional.gelu
-
-    raise RuntimeError(
-        "activation should be relu/gelu, not {}".format(activation)
-    )
-
-
 class TransformerEncoder(nn.Module):
     r"""TransformerEncoder is a stack of N encoder layers
 
@@ -708,7 +685,7 @@ class TransformerEncoder(nn.Module):
         """
         output = src
 
-        for i, mod in enumerate(self.layers):
+        for mod in self.layers:
             output = mod(
                 output,
                 src_mask=mask,
@@ -769,7 +746,7 @@ class TransformerDecoder(nn.Module):
         """
         output = tgt
 
-        for i, mod in enumerate(self.layers):
+        for mod in self.layers:
             output = mod(
                 output,
                 memory,
diff --git a/egs/librispeech/ASR/local/compile_hlg.py b/egs/librispeech/ASR/local/compile_hlg.py
index 9a35750e0..c628dfd53 100755
--- a/egs/librispeech/ASR/local/compile_hlg.py
+++ b/egs/librispeech/ASR/local/compile_hlg.py
@@ -40,6 +40,13 @@ from icefall.lexicon import Lexicon
 
 def get_args():
     parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--lm",
+        type=str,
+        default="G_3_gram",
+        help="""Stem name for LM used in HLG compiling.
+        """,
+    )
     parser.add_argument(
         "--lang-dir",
         type=str,
@@ -50,11 +57,13 @@ def get_args():
     return parser.parse_args()
 
 
-def compile_HLG(lang_dir: str) -> k2.Fsa:
+def compile_HLG(lang_dir: str, lm: str="G_3_gram") -> k2.Fsa:
     """
     Args:
       lang_dir:
         The language directory, e.g., data/lang_phone or data/lang_bpe_5000.
+      lm:
+        The language stem base name.
 
     Return:
       An FSA representing HLG.
@@ -65,15 +74,15 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
     H = k2.ctc_topo(max_token_id)
     L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
 
-    if Path("data/lm/G_3_gram.pt").is_file():
-        logging.info("Loading pre-compiled G_3_gram")
-        d = torch.load("data/lm/G_3_gram.pt")
+    if Path(f"data/lm/{lm}.pt").is_file():
+        logging.info(f"Loading pre-compiled {lm}")
+        d = torch.load(f"data/lm/{lm}.pt")
         G = k2.Fsa.from_dict(d)
     else:
-        logging.info("Loading G_3_gram.fst.txt")
-        with open("data/lm/G_3_gram.fst.txt") as f:
+        logging.info(f"Loading {lm}.fst.txt")
+        with open(f"data/lm/{lm}.fst.txt") as f:
             G = k2.Fsa.from_openfst(f.read(), acceptor=False)
-            torch.save(G.as_dict(), "data/lm/G_3_gram.pt")
+            torch.save(G.as_dict(), f"data/lm/{lm}.pt")
 
     first_token_disambig_id = lexicon.token_table["#0"]
     first_word_disambig_id = lexicon.word_table["#0"]
@@ -144,7 +153,7 @@ def main():
 
     logging.info(f"Processing {lang_dir}")
 
-    HLG = compile_HLG(lang_dir)
+    HLG = compile_HLG(lang_dir, args.lm)
     logging.info(f"Saving HLG.pt to {lang_dir}")
     torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt")
 
diff --git a/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py b/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py
index 51de46ae8..94784c4c4 100644
--- a/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py
+++ b/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py
@@ -17,10 +17,11 @@
 
 
 import argparse
-import inspect
 import logging
+
 from functools import lru_cache
 from pathlib import Path
+from typing import Any, Dict, Optional
 
 from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
 from lhotse.dataset import (
@@ -28,7 +29,6 @@ from lhotse.dataset import (
     CutMix,
     DynamicBucketingSampler,
     K2SpeechRecognitionDataset,
-    PrecomputedFeatures,
     SingleCutSampler,
     SpecAugment,
 )
@@ -140,7 +140,6 @@ class TedLiumAsrDataModule:
             "field: batch['supervisions']['cut'] with the cuts that "
             "were used to construct it.",
         )
-
         group.add_argument(
             "--num-workers",
             type=int,
@@ -148,14 +147,12 @@ class TedLiumAsrDataModule:
             help="The number of training dataloader workers that "
             "collect the batches.",
         )
-
         group.add_argument(
             "--enable-spec-aug",
             type=str2bool,
             default=True,
             help="When enabled, use SpecAugment for training dataset.",
         )
-
         group.add_argument(
             "--spec-aug-time-warp-factor",
             type=int,
@@ -165,16 +162,48 @@ class TedLiumAsrDataModule:
             "Larger values mean more warping. "
             "A value less than 1 means to disable time warp.",
         )
-
         group.add_argument(
             "--enable-musan",
             type=str2bool,
             default=True,
             help="When enabled, select noise from MUSAN and mix it"
-            "with training dataset. ",
+            "with training dataset.",
         )
 
-    def train_dataloaders(self, cuts_train: CutSet) -> DataLoader:
+    def train_dataloaders(
+        self,
+        cuts_train: CutSet,
+        sampler_state_dict: Optional[Dict[str, Any]] = None
+    ) -> DataLoader:
+        """
+        Args:
+          cuts_train:
+            CutSet for training.
+          sampler_state_dict:
+            The state dict for the training sampler.
+        """
+
+        input_transforms = []
+        if self.args.enable_spec_aug:
+            logging.info("Enable SpecAugment")
+            logging.info(
+                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
+            )
+
+            input_transforms.append(
+                SpecAugment(
+                    time_warp_factor=self.args.spec_aug_time_warp_factor,
+                    num_frame_masks=10,
+                    features_mask_size=27,
+                    num_feature_masks=2,
+                    frames_mask_size=100,
+                    max_frames_mask_fraction=0.15,
+                    p=0.9,
+                )
+            )
+        else:
+            logging.info("Disable SpecAugment")
+
         logging.info("About to get Musan cuts")
         transforms = []
         if self.args.enable_musan:
@@ -204,42 +233,7 @@ class TedLiumAsrDataModule:
                 )
             ] + transforms
 
-        input_transforms = []
-        if self.args.enable_spec_aug:
-            logging.info("Enable SpecAugment")
-            logging.info(
-                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
-            )
-            # Set the value of num_frame_masks according to Lhotse's version.
-            # In different Lhotse's versions, the default of num_frame_masks is
-            # different.
-            num_frame_masks = 10
-            num_frame_masks_parameter = inspect.signature(
-                SpecAugment.__init__
-            ).parameters["num_frame_masks"]
-            if num_frame_masks_parameter.default == 1:
-                num_frame_masks = 2
-            logging.info(f"Num frame mask: {num_frame_masks}")
-            input_transforms.append(
-                SpecAugment(
-                    time_warp_factor=self.args.spec_aug_time_warp_factor,
-                    num_frame_masks=num_frame_masks,
-                    features_mask_size=27,
-                    num_feature_masks=2,
-                    frames_mask_size=100,
-                    max_frames_mask_fraction=0.15,
-                    p=0.9,
-                )
-            )
-        else:
-            logging.info("Disable SpecAugment")
-
         logging.info("About to create train dataset")
-        train = K2SpeechRecognitionDataset(
-            cut_transforms=transforms,
-            input_transforms=input_transforms,
-            return_cuts=self.args.return_cuts,
-        )
         if self.args.on_the_fly_feats:
             # NOTE: the PerturbSpeed transform should be added only if we
             # remove it from data prep stage.
@@ -259,6 +253,12 @@ class TedLiumAsrDataModule:
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
+        else:
+            train = K2SpeechRecognitionDataset(
+                cut_transforms=transforms,
+                input_transforms=input_transforms,
+                return_cuts=self.args.return_cuts,
+            )
 
         if self.args.bucketing_sampler:
             logging.info("Using DynamicBucketingSampler.")
@@ -276,6 +276,11 @@ class TedLiumAsrDataModule:
                 max_duration=self.args.max_duration,
                 shuffle=self.args.shuffle,
             )
+
+        if sampler_state_dict is not None:
+            logging.info("Loading sampler state dict")
+            train_sampler.load_state_dict(sampler_state_dict)
+
         logging.info("About to create train dataloader")
         train_dl = DataLoader(
             train,
@@ -288,6 +293,7 @@ class TedLiumAsrDataModule:
         return train_dl
 
     def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
+
         transforms = []
         if self.args.concatenate_cuts:
             transforms = [
@@ -310,11 +316,13 @@ class TedLiumAsrDataModule:
                 cut_transforms=transforms,
                 return_cuts=self.args.return_cuts,
             )
+
         valid_sampler = DynamicBucketingSampler(
             cuts_valid,
             max_duration=self.args.max_duration,
             shuffle=False,
         )
+
         logging.info("About to create dev dataloader")
         valid_dl = DataLoader(
             validate,
@@ -326,25 +334,34 @@ class TedLiumAsrDataModule:
 
         return valid_dl
 
-    def test_dataloaders(self, cuts: CutSet) -> DataLoader:
+    def test_dataloaders(self, cuts_test: CutSet) -> DataLoader:
+
         logging.debug("About to create test dataset")
-        test = K2SpeechRecognitionDataset(
-            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
-            if self.args.on_the_fly_feats
-            else PrecomputedFeatures(),
-            return_cuts=self.args.return_cuts,
-        )
-        sampler = DynamicBucketingSampler(
-            cuts,
+        if self.args.on_the_fly_feats:
+            test = K2SpeechRecognitionDataset(
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
+                return_cuts=self.args.return_cuts,
+            )
+        else:
+            test = K2SpeechRecognitionDataset(
+                return_cuts=self.args.return_cuts,
+            )
+
+        test_sampler = DynamicBucketingSampler(
+            cuts_test,
             max_duration=self.args.max_duration,
             shuffle=False,
         )
+
         logging.debug("About to create test dataloader")
         test_dl = DataLoader(
             test,
             batch_size=None,
-            sampler=sampler,
+            sampler=test_sampler,
             num_workers=self.args.num_workers,
+            persistent_workers=False,
         )
         return test_dl
 
diff --git a/icefall/decode.py b/icefall/decode.py
index f04ee368c..099e2d171 100644
--- a/icefall/decode.py
+++ b/icefall/decode.py
@@ -459,7 +459,8 @@ class Nbest(object):
 def one_best_decoding(
     lattice: k2.Fsa,
     use_double_scores: bool = True,
-) -> k2.Fsa:
+    lm_scale_list: Optional[List[float]] = None,
+) -> Union[k2.Fsa, Dict[str, k2.Fsa]]:
     """Get the best path from a lattice.
 
     Args:
@@ -468,11 +469,28 @@ def one_best_decoding(
       use_double_scores:
         True to use double precision floating point in the computation.
         False to use single precision.
+      lm_scale_list:
+        A list of floats representing LM score scales.
     Return:
       An FsaVec containing linear paths.
     """
-    best_path = k2.shortest_path(lattice, use_double_scores=use_double_scores)
-    return best_path
+
+    if lm_scale_list is not None:
+
+        ans = dict()
+        saved_am_scores = lattice.scores - lattice.lm_scores
+        for lm_scale in lm_scale_list:
+            am_scores = saved_am_scores / lm_scale
+            lattice.scores = am_scores + lattice.lm_scores
+
+            best_path = k2.shortest_path(
+                lattice, use_double_scores=use_double_scores
+            )
+            key = f"lm_scale_{lm_scale}"
+            ans[key] = best_path
+        return ans
+
+    return k2.shortest_path(lattice, use_double_scores=use_double_scores)
 
 
 def nbest_decoding(
diff --git a/icefall/utils.py b/icefall/utils.py
index c502cb4d8..143c79497 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -194,8 +194,16 @@ def encode_supervisions(
     supervision_segments = torch.stack(
         (
             supervisions["sequence_idx"],
-            supervisions["start_frame"] // subsampling_factor,
-            supervisions["num_frames"] // subsampling_factor,
+            torch.div(
+                supervisions["start_frame"],
+                subsampling_factor,
+                rounding_mode="floor",
+            ),
+            torch.div(
+                supervisions["num_frames"],
+                subsampling_factor,
+                rounding_mode="floor",
+            )
         ),
         1,
     ).to(torch.int32)

From 60317120caef28a43ed904a26263c57536ca95ab Mon Sep 17 00:00:00 2001
From: Fangjun Kuang 
Date: Thu, 17 Nov 2022 20:19:32 +0800
Subject: [PATCH 010/174] Revert "Apply new Black style changes"

---
 .git-blame-ignore-revs                        |    2 -
 .github/workflows/style_check.yml             |   11 +-
 .pre-commit-config.yaml                       |   28 +-
 docker/README.md                              |   24 +-
 .../Dockerfile                                |   14 +-
 .../Dockerfile                                |   17 +-
 .../images/k2-gt-v1.9-blueviolet.svg          |    2 +-
 .../images/python-gt-v3.6-blue.svg            |    2 +-
 .../images/torch-gt-v1.6.0-green.svg          |    2 +-
 docs/source/recipes/aishell/index.rst         |    1 +
 docs/source/recipes/timit/index.rst           |    1 +
 docs/source/recipes/timit/tdnn_ligru_ctc.rst  |   28 +-
 docs/source/recipes/timit/tdnn_lstm_ctc.rst   |   24 +-
 .../local/compute_fbank_aidatatang_200zh.py   |    8 +-
 .../ASR/local/prepare_char.py                 |    8 +-
 .../ASR/local/prepare_lang.py                 |    4 +-
 .../ASR/local/test_prepare_lang.py            |    4 +-
 egs/aidatatang_200zh/ASR/local/text2token.py  |   21 +-
 egs/aidatatang_200zh/ASR/prepare.sh           |    3 +-
 .../asr_datamodule.py                         |  110 +-
 .../pruned_transducer_stateless2/decode.py    |   50 +-
 .../pruned_transducer_stateless2/export.py    |   20 +-
 .../pretrained.py                             |   41 +-
 .../ASR/pruned_transducer_stateless2/train.py |   50 +-
 egs/aishell/ASR/conformer_ctc/conformer.py    |   70 +-
 egs/aishell/ASR/conformer_ctc/decode.py       |   29 +-
 egs/aishell/ASR/conformer_ctc/export.py       |   17 +-
 egs/aishell/ASR/conformer_ctc/pretrained.py   |   39 +-
 egs/aishell/ASR/conformer_ctc/subsampling.py  |   16 +-
 .../ASR/conformer_ctc/test_subsampling.py     |    3 +-
 egs/aishell/ASR/conformer_ctc/train.py        |   12 +-
 egs/aishell/ASR/conformer_ctc/transformer.py  |   44 +-
 egs/aishell/ASR/conformer_mmi/conformer.py    |   70 +-
 egs/aishell/ASR/conformer_mmi/decode.py       |   33 +-
 egs/aishell/ASR/conformer_mmi/subsampling.py  |   16 +-
 egs/aishell/ASR/conformer_mmi/train.py        |    8 +-
 egs/aishell/ASR/conformer_mmi/transformer.py  |   44 +-
 .../local/compute_fbank_aidatatang_200zh.py   |    8 +-
 .../ASR/local/compute_fbank_aishell.py        |    8 +-
 egs/aishell/ASR/local/prepare_char.py         |    8 +-
 egs/aishell/ASR/local/prepare_lang.py         |    4 +-
 egs/aishell/ASR/local/test_prepare_lang.py    |    4 +-
 .../pruned_transducer_stateless2/decode.py    |   50 +-
 .../pruned_transducer_stateless2/export.py    |   31 +-
 .../pretrained.py                             |   50 +-
 .../ASR/pruned_transducer_stateless2/train.py |   64 +-
 .../pruned_transducer_stateless3/decode.py    |   73 +-
 .../pruned_transducer_stateless3/export.py    |   54 +-
 .../ASR/pruned_transducer_stateless3/model.py |    8 +-
 .../pretrained.py                             |   50 +-
 .../ASR/pruned_transducer_stateless3/train.py |   79 +-
 .../ASR/tdnn_lstm_ctc/asr_datamodule.py       |  118 +-
 egs/aishell/ASR/tdnn_lstm_ctc/decode.py       |   33 +-
 egs/aishell/ASR/tdnn_lstm_ctc/model.py        |    5 +-
 egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py   |   37 +-
 egs/aishell/ASR/tdnn_lstm_ctc/train.py        |    7 +-
 .../ASR/transducer_stateless/beam_search.py   |   22 +-
 .../ASR/transducer_stateless/conformer.py     |   70 +-
 .../ASR/transducer_stateless/decode.py        |   39 +-
 .../ASR/transducer_stateless/decoder.py       |    4 +-
 .../ASR/transducer_stateless/export.py        |   20 +-
 egs/aishell/ASR/transducer_stateless/model.py |    4 +-
 .../ASR/transducer_stateless/pretrained.py    |   36 +-
 egs/aishell/ASR/transducer_stateless/train.py |   15 +-
 .../ASR/transducer_stateless/transformer.py   |    4 +-
 .../asr_datamodule.py                         |   85 +-
 .../transducer_stateless_modified-2/decode.py |   46 +-
 .../transducer_stateless_modified-2/export.py |   20 +-
 .../pretrained.py                             |   50 +-
 .../transducer_stateless_modified-2/train.py  |   22 +-
 .../transducer_stateless_modified/decode.py   |   46 +-
 .../transducer_stateless_modified/export.py   |   20 +-
 .../pretrained.py                             |   50 +-
 .../transducer_stateless_modified/train.py    |   15 +-
 egs/aishell2/ASR/local/__init__.py            |    0
 .../ASR/local/compute_fbank_aishell2.py       |    8 +-
 .../pruned_transducer_stateless5/__init__.py  |    0
 .../asr_datamodule.py                         |  114 +-
 .../pruned_transducer_stateless5/decode.py    |   67 +-
 .../pruned_transducer_stateless5/export.py    |   47 +-
 .../pretrained.py                             |   40 +-
 .../ASR/pruned_transducer_stateless5/train.py |   67 +-
 .../ASR/local/compute_fbank_aishell4.py       |    8 +-
 egs/aishell4/ASR/local/prepare_char.py        |    8 +-
 egs/aishell4/ASR/local/prepare_lang.py        |    4 +-
 egs/aishell4/ASR/local/test_prepare_lang.py   |    4 +-
 egs/aishell4/ASR/local/text2token.py          |   21 +-
 .../asr_datamodule.py                         |  110 +-
 .../pruned_transducer_stateless5/decode.py    |   69 +-
 .../pruned_transducer_stateless5/export.py    |   47 +-
 .../pretrained.py                             |   45 +-
 .../ASR/pruned_transducer_stateless5/train.py |   59 +-
 .../ASR/local/compute_fbank_alimeeting.py     |    8 +-
 egs/alimeeting/ASR/local/prepare_char.py      |    8 +-
 egs/alimeeting/ASR/local/prepare_lang.py      |    4 +-
 egs/alimeeting/ASR/local/test_prepare_lang.py |    4 +-
 egs/alimeeting/ASR/local/text2segments.py     |    2 +-
 egs/alimeeting/ASR/local/text2token.py        |   21 +-
 .../asr_datamodule.py                         |  110 +-
 .../pruned_transducer_stateless2/decode.py    |   60 +-
 .../pruned_transducer_stateless2/export.py    |   20 +-
 .../pretrained.py                             |   41 +-
 .../ASR/pruned_transducer_stateless2/train.py |   50 +-
 egs/csj/ASR/.gitignore                        |    2 +-
 egs/csj/ASR/local/compute_fbank_csj.py        |   38 +-
 egs/csj/ASR/local/compute_fbank_musan.py      |   17 +-
 egs/csj/ASR/local/conf/disfluent.ini          |   55 +-
 egs/csj/ASR/local/conf/fluent.ini             |   55 +-
 egs/csj/ASR/local/conf/number.ini             |   55 +-
 egs/csj/ASR/local/conf/symbol.ini             |   55 +-
 .../ASR/local/display_manifest_statistics.py  |    4 +-
 egs/csj/ASR/local/prepare_lang_char.py        |   17 +-
 egs/csj/ASR/local/validate_manifest.py        |    7 +-
 .../ASR/conformer_ctc/asr_datamodule.py       |  117 +-
 egs/gigaspeech/ASR/conformer_ctc/conformer.py |   66 +-
 egs/gigaspeech/ASR/conformer_ctc/decode.py    |   29 +-
 .../ASR/conformer_ctc/gigaspeech_scoring.py   |    3 +-
 .../ASR/conformer_ctc/label_smoothing.py      |    7 +-
 .../ASR/conformer_ctc/subsampling.py          |   16 +-
 egs/gigaspeech/ASR/conformer_ctc/train.py     |   12 +-
 .../ASR/conformer_ctc/transformer.py          |   49 +-
 .../compute_fbank_gigaspeech_dev_test.py      |    4 +-
 .../local/compute_fbank_gigaspeech_splits.py  |   10 +-
 .../ASR/local/preprocess_gigaspeech.py        |   10 +-
 .../asr_datamodule.py                         |  117 +-
 .../pruned_transducer_stateless2/decode.py    |   42 +-
 .../pruned_transducer_stateless2/export.py    |   24 +-
 .../ASR/pruned_transducer_stateless2/train.py |   48 +-
 egs/librispeech/ASR/conformer_ctc/ali.py      |   25 +-
 .../ASR/conformer_ctc/conformer.py            |   66 +-
 egs/librispeech/ASR/conformer_ctc/decode.py   |   29 +-
 egs/librispeech/ASR/conformer_ctc/export.py   |   17 +-
 .../ASR/conformer_ctc/label_smoothing.py      |    7 +-
 .../ASR/conformer_ctc/pretrained.py           |   33 +-
 .../ASR/conformer_ctc/subsampling.py          |   16 +-
 egs/librispeech/ASR/conformer_ctc/train.py    |   22 +-
 .../ASR/conformer_ctc/transformer.py          |   49 +-
 .../ASR/conformer_ctc2/attention.py           |   19 +-
 .../ASR/conformer_ctc2/conformer.py           |   65 +-
 egs/librispeech/ASR/conformer_ctc2/decode.py  |   56 +-
 egs/librispeech/ASR/conformer_ctc2/export.py  |   49 +-
 egs/librispeech/ASR/conformer_ctc2/train.py   |   39 +-
 .../ASR/conformer_ctc2/transformer.py         |   46 +-
 .../ASR/conformer_mmi/conformer.py            |   70 +-
 egs/librispeech/ASR/conformer_mmi/decode.py   |   29 +-
 .../ASR/conformer_mmi/subsampling.py          |   16 +-
 .../ASR/conformer_mmi/test_subsampling.py     |    3 +-
 .../ASR/conformer_mmi/test_transformer.py     |    9 +-
 .../ASR/conformer_mmi/train-with-attention.py |   27 +-
 egs/librispeech/ASR/conformer_mmi/train.py    |   27 +-
 .../ASR/conformer_mmi/transformer.py          |   28 +-
 .../decode.py                                 |   69 +-
 .../emformer.py                               |  119 +-
 .../export.py                                 |   47 +-
 .../stream.py                                 |    8 +-
 .../streaming_decode.py                       |   75 +-
 .../train.py                                  |   56 +-
 .../decode.py                                 |   69 +-
 .../emformer.py                               |  108 +-
 .../export.py                                 |   47 +-
 .../streaming_decode.py                       |   75 +-
 .../train.py                                  |   56 +-
 .../ASR/local/add_alignment_librispeech.py    |   12 +-
 egs/librispeech/ASR/local/compile_hlg.py      |    6 +-
 egs/librispeech/ASR/local/compile_lg.py       |    4 +-
 .../compute_fbank_gigaspeech_dev_test.py      |    4 +-
 .../local/compute_fbank_gigaspeech_splits.py  |   10 +-
 .../ASR/local/compute_fbank_librispeech.py    |    8 +-
 .../ASR/local/compute_fbank_musan.py          |    8 +-
 .../convert_transcript_words_to_tokens.py     |   16 +-
 egs/librispeech/ASR/local/download_lm.py      |    4 +-
 egs/librispeech/ASR/local/filter_cuts.py      |   10 +-
 .../ASR/local/generate_unique_lexicon.py      |    4 +-
 egs/librispeech/ASR/local/prepare_lang_bpe.py |    4 +-
 .../ASR/local/prepare_lm_training_data.py     |   11 +-
 .../ASR/local/preprocess_gigaspeech.py        |    4 +-
 .../ASR/local/test_prepare_lang.py            |    4 +-
 .../ASR/local/validate_manifest.py            |    7 +-
 .../ASR/lstm_transducer_stateless/decode.py   |  818 ++++++++++++
 .../ASR/lstm_transducer_stateless/export.py   |  388 ++++++
 .../jit_pretrained.py                         |  322 +++++
 .../ASR/lstm_transducer_stateless/lstm.py     |  871 +++++++++++++
 .../ASR/lstm_transducer_stateless/model.py    |  210 +++
 .../lstm_transducer_stateless/pretrained.py   |  352 +++++
 .../ASR/lstm_transducer_stateless/stream.py   |  148 +++
 .../streaming_decode.py                       |  968 ++++++++++++++
 .../ASR/lstm_transducer_stateless/train.py    | 1157 +++++++++++++++++
 .../ASR/lstm_transducer_stateless2/decode.py  |   67 +-
 .../ASR/lstm_transducer_stateless2/export.py  |   59 +-
 .../jit_pretrained.py                         |   21 +-
 .../ASR/lstm_transducer_stateless2/model.py   |    8 +-
 .../lstm_transducer_stateless2/ncnn-decode.py |   15 +-
 .../lstm_transducer_stateless2/pretrained.py  |   40 +-
 .../streaming-ncnn-decode.py                  |   27 +-
 .../streaming-onnx-decode.py                  |   45 +-
 .../ASR/lstm_transducer_stateless2/train.py   |   68 +-
 .../ASR/lstm_transducer_stateless3/decode.py  |   79 +-
 .../ASR/lstm_transducer_stateless3/export.py  |   47 +-
 .../jit_pretrained.py                         |   21 +-
 .../ASR/lstm_transducer_stateless3/lstm.py    |   14 +-
 .../lstm_transducer_stateless3/pretrained.py  |   40 +-
 .../streaming_decode.py                       |   74 +-
 .../ASR/lstm_transducer_stateless3/train.py   |   66 +-
 .../ASR/pruned2_knowledge/asr_datamodule.py   |  125 +-
 .../ASR/pruned2_knowledge/beam_search.py      |   18 +-
 .../ASR/pruned2_knowledge/conformer.py        |   90 +-
 .../ASR/pruned2_knowledge/decode.py           |   44 +-
 .../ASR/pruned2_knowledge/decoder.py          |    4 +-
 .../ASR/pruned2_knowledge/decoder2.py         |   84 +-
 .../ASR/pruned2_knowledge/export.py           |   20 +-
 .../ASR/pruned2_knowledge/joiner.py           |    4 +-
 .../ASR/pruned2_knowledge/model.py            |    8 +-
 .../ASR/pruned2_knowledge/optim.py            |   35 +-
 .../ASR/pruned2_knowledge/sampling.py         |  180 ++-
 .../ASR/pruned2_knowledge/scaling.py          |   51 +-
 .../ASR/pruned2_knowledge/scaling_tmp.py      |  355 ++---
 .../ASR/pruned2_knowledge/train.py            |   50 +-
 .../pruned_stateless_emformer_rnnt2/decode.py |   69 +-
 .../emformer.py                               |    8 +-
 .../pruned_stateless_emformer_rnnt2/export.py |   47 +-
 .../pruned_stateless_emformer_rnnt2/model.py  |    4 +-
 .../pruned_stateless_emformer_rnnt2/train.py  |   44 +-
 .../beam_search.py                            |   26 +-
 .../ASR/pruned_transducer_stateless/decode.py |   44 +-
 .../decode_stream.py                          |   19 +-
 .../pruned_transducer_stateless/decoder.py    |    4 +-
 .../ASR/pruned_transducer_stateless/export.py |   20 +-
 .../ASR/pruned_transducer_stateless/model.py  |    4 +-
 .../pruned_transducer_stateless/pretrained.py |   36 +-
 .../streaming_beam_search.py                  |    8 +-
 .../streaming_decode.py                       |   39 +-
 .../ASR/pruned_transducer_stateless/train.py  |   46 +-
 .../beam_search.py                            |   51 +-
 .../pruned_transducer_stateless2/conformer.py |   97 +-
 .../pruned_transducer_stateless2/decode.py    |   50 +-
 .../pruned_transducer_stateless2/decoder.py   |    8 +-
 .../pruned_transducer_stateless2/export.py    |   24 +-
 .../pruned_transducer_stateless2/joiner.py    |    4 +-
 .../ASR/pruned_transducer_stateless2/model.py |    8 +-
 .../ASR/pruned_transducer_stateless2/optim.py |   35 +-
 .../pretrained.py                             |   36 +-
 .../pruned_transducer_stateless2/scaling.py   |   56 +-
 .../streaming_beam_search.py                  |   12 +-
 .../streaming_decode.py                       |   39 +-
 .../ASR/pruned_transducer_stateless2/train.py |   58 +-
 .../asr_datamodule.py                         |   85 +-
 .../decode-giga.py                            |   54 +-
 .../pruned_transducer_stateless3/decode.py    |   74 +-
 .../pruned_transducer_stateless3/export.py    |   32 +-
 .../gigaspeech.py                             |    8 +-
 .../jit_pretrained.py                         |   21 +-
 .../ASR/pruned_transducer_stateless3/model.py |    8 +-
 .../onnx_check.py                             |   24 +-
 .../onnx_pretrained.py                        |   27 +-
 .../pretrained.py                             |   36 +-
 .../scaling_converter.py                      |   10 +-
 .../streaming_decode.py                       |   39 +-
 .../pruned_transducer_stateless3/test_onnx.py |   24 +-
 .../ASR/pruned_transducer_stateless3/train.py |   65 +-
 .../pruned_transducer_stateless4/decode.py    |   79 +-
 .../pruned_transducer_stateless4/export.py    |   47 +-
 .../streaming_decode.py                       |   62 +-
 .../ASR/pruned_transducer_stateless4/train.py |   61 +-
 .../pruned_transducer_stateless5/conformer.py |  118 +-
 .../pruned_transducer_stateless5/decode.py    |   67 +-
 .../pruned_transducer_stateless5/export.py    |   47 +-
 .../pretrained.py                             |   40 +-
 .../streaming_decode.py                       |   62 +-
 .../ASR/pruned_transducer_stateless5/train.py |   66 +-
 .../pruned_transducer_stateless6/conformer.py |   67 +-
 .../pruned_transducer_stateless6/decode.py    |   69 +-
 .../pruned_transducer_stateless6/export.py    |   24 +-
 .../extract_codebook_index.py                 |    3 +-
 .../hubert_decode.py                          |   17 +-
 .../hubert_xlarge.py                          |   22 +-
 .../ASR/pruned_transducer_stateless6/model.py |   12 +-
 .../ASR/pruned_transducer_stateless6/train.py |   65 +-
 .../pruned_transducer_stateless6/vq_utils.py  |   31 +-
 .../pruned_transducer_stateless7/decode.py    |   67 +-
 .../pruned_transducer_stateless7/decoder.py   |    6 +-
 .../pruned_transducer_stateless7/export.py    |   47 +-
 .../jit_pretrained.py                         |   21 +-
 .../pruned_transducer_stateless7/joiner.py    |    4 +-
 .../ASR/pruned_transducer_stateless7/model.py |   16 +-
 .../ASR/pruned_transducer_stateless7/optim.py |  435 +++----
 .../pretrained.py                             |   40 +-
 .../pruned_transducer_stateless7/scaling.py   |  487 ++++---
 .../scaling_converter.py                      |   12 +-
 .../ASR/pruned_transducer_stateless7/train.py |   88 +-
 .../pruned_transducer_stateless7/zipformer.py |  654 +++++-----
 .../pruned_transducer_stateless8/decode.py    |   67 +-
 .../pruned_transducer_stateless8/export.py    |   47 +-
 .../jit_pretrained.py                         |   21 +-
 .../ASR/pruned_transducer_stateless8/model.py |    4 +-
 .../pretrained.py                             |   40 +-
 .../ASR/pruned_transducer_stateless8/train.py |   99 +-
 .../ASR/streaming_conformer_ctc/README.md     |   16 +-
 .../ASR/streaming_conformer_ctc/conformer.py  |  116 +-
 .../streaming_decode.py                       |   68 +-
 .../ASR/streaming_conformer_ctc/train.py      |   16 +-
 .../streaming_conformer_ctc/transformer.py    |   40 +-
 .../ASR/tdnn_lstm_ctc/asr_datamodule.py       |  113 +-
 egs/librispeech/ASR/tdnn_lstm_ctc/decode.py   |   29 +-
 egs/librispeech/ASR/tdnn_lstm_ctc/model.py    |    5 +-
 .../ASR/tdnn_lstm_ctc/pretrained.py           |   43 +-
 egs/librispeech/ASR/tdnn_lstm_ctc/train.py    |    8 +-
 egs/librispeech/ASR/transducer/beam_search.py |   14 +-
 egs/librispeech/ASR/transducer/decode.py      |   28 +-
 egs/librispeech/ASR/transducer/export.py      |   17 +-
 egs/librispeech/ASR/transducer/pretrained.py  |   33 +-
 egs/librispeech/ASR/transducer/rnn.py         |   24 +-
 egs/librispeech/ASR/transducer/test_rnn.py    |   16 +-
 egs/librispeech/ASR/transducer/train.py       |   12 +-
 .../ASR/transducer_lstm/beam_search.py        |   14 +-
 egs/librispeech/ASR/transducer_lstm/decode.py |   28 +-
 .../ASR/transducer_lstm/encoder.py            |    4 +-
 egs/librispeech/ASR/transducer_lstm/train.py  |   12 +-
 .../ASR/transducer_stateless/alignment.py     |    4 +-
 .../ASR/transducer_stateless/beam_search.py   |   28 +-
 .../ASR/transducer_stateless/compute_ali.py   |   24 +-
 .../ASR/transducer_stateless/conformer.py     |  107 +-
 .../ASR/transducer_stateless/decode.py        |   42 +-
 .../ASR/transducer_stateless/decoder.py       |    4 +-
 .../ASR/transducer_stateless/export.py        |   20 +-
 .../ASR/transducer_stateless/joiner.py        |    8 +-
 .../ASR/transducer_stateless/pretrained.py    |   36 +-
 .../transducer_stateless/test_compute_ali.py  |   11 +-
 .../transducer_stateless/test_conformer.py    |    4 +-
 .../ASR/transducer_stateless/train.py         |   23 +-
 .../ASR/transducer_stateless/transformer.py   |    4 +-
 .../ASR/transducer_stateless2/decode.py       |   42 +-
 .../ASR/transducer_stateless2/export.py       |   20 +-
 .../ASR/transducer_stateless2/pretrained.py   |   36 +-
 .../ASR/transducer_stateless2/train.py        |   23 +-
 .../decode.py                                 |   42 +-
 .../export.py                                 |   20 +-
 .../pretrained.py                             |   36 +-
 .../test_asr_datamodule.py                    |    4 +-
 .../train.py                                  |   22 +-
 egs/ptb/LM/local/sort_lm_training_data.py     |    4 +-
 .../LM/local/test_prepare_lm_training_data.py |    4 +-
 .../ASR/local/compute_fbank_musan.py          |    8 +-
 .../ASR/local/compute_fbank_spgispeech.py     |   14 +-
 egs/spgispeech/ASR/local/prepare_splits.py    |    8 +-
 .../asr_datamodule.py                         |  100 +-
 .../pruned_transducer_stateless2/decode.py    |   66 +-
 .../pruned_transducer_stateless2/export.py    |   26 +-
 .../ASR/pruned_transducer_stateless2/train.py |   51 +-
 .../ASR/local/compute_fbank_tal_csasr.py      |    8 +-
 egs/tal_csasr/ASR/local/prepare_char.py       |    4 +-
 egs/tal_csasr/ASR/local/prepare_lang.py       |    4 +-
 egs/tal_csasr/ASR/local/test_prepare_lang.py  |    4 +-
 egs/tal_csasr/ASR/local/text2token.py         |   21 +-
 .../asr_datamodule.py                         |  110 +-
 .../pruned_transducer_stateless5/decode.py    |   77 +-
 .../pruned_transducer_stateless5/export.py    |   47 +-
 .../pretrained.py                             |   40 +-
 .../ASR/pruned_transducer_stateless5/train.py |   59 +-
 .../ASR/local/compute_fbank_tedlium.py        |    8 +-
 .../convert_transcript_words_to_bpe_ids.py    |    4 +-
 egs/tedlium3/ASR/local/prepare_lexicon.py     |   11 +-
 egs/tedlium3/ASR/local/prepare_transcripts.py |   11 +-
 .../ASR/pruned_transducer_stateless/decode.py |   38 +-
 .../ASR/pruned_transducer_stateless/export.py |   20 +-
 .../pruned_transducer_stateless/pretrained.py |   41 +-
 .../ASR/pruned_transducer_stateless/train.py  |   35 +-
 .../transducer_stateless/asr_datamodule.py    |  127 +-
 .../ASR/transducer_stateless/beam_search.py   |   30 +-
 .../ASR/transducer_stateless/decode.py        |   31 +-
 .../ASR/transducer_stateless/decoder.py       |    4 +-
 .../ASR/transducer_stateless/export.py        |   20 +-
 .../ASR/transducer_stateless/pretrained.py    |   36 +-
 .../ASR/transducer_stateless/train.py         |   11 +-
 egs/timit/ASR/RESULTS.md                      |    2 +-
 egs/timit/ASR/local/compile_hlg.py            |    4 +-
 egs/timit/ASR/local/compute_fbank_timit.py    |    8 +-
 egs/timit/ASR/local/prepare_lexicon.py        |    8 +-
 egs/timit/ASR/prepare.sh                      |    4 +-
 egs/timit/ASR/tdnn_ligru_ctc/decode.py        |   29 +-
 egs/timit/ASR/tdnn_ligru_ctc/model.py         |   12 +-
 egs/timit/ASR/tdnn_ligru_ctc/pretrained.py    |   43 +-
 egs/timit/ASR/tdnn_ligru_ctc/train.py         |    4 +-
 egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py |  104 +-
 egs/timit/ASR/tdnn_lstm_ctc/decode.py         |   29 +-
 egs/timit/ASR/tdnn_lstm_ctc/model.py          |    5 +-
 egs/timit/ASR/tdnn_lstm_ctc/pretrained.py     |   43 +-
 egs/timit/ASR/tdnn_lstm_ctc/train.py          |    4 +-
 .../compute_fbank_wenetspeech_dev_test.py     |   11 +-
 .../local/compute_fbank_wenetspeech_splits.py |   10 +-
 egs/wenetspeech/ASR/local/prepare_char.py     |    8 +-
 .../ASR/local/preprocess_wenetspeech.py       |    6 +-
 egs/wenetspeech/ASR/local/text2token.py       |   21 +-
 egs/wenetspeech/ASR/prepare.sh                |    2 +-
 .../asr_datamodule.py                         |  121 +-
 .../pruned_transducer_stateless2/decode.py    |   64 +-
 .../pruned_transducer_stateless2/export.py    |   28 +-
 .../jit_pretrained.py                         |   21 +-
 .../onnx_check.py                             |   24 +-
 .../onnx_pretrained.py                        |   27 +-
 .../pretrained.py                             |   41 +-
 .../ASR/pruned_transducer_stateless2/train.py |   50 +-
 .../pruned_transducer_stateless5/conformer.py |   97 +-
 .../pruned_transducer_stateless5/decode.py    |   75 +-
 .../decode_stream.py                          |   19 +-
 .../pruned_transducer_stateless5/export.py    |   20 +-
 .../pretrained.py                             |   41 +-
 .../streaming_beam_search.py                  |    8 +-
 .../streaming_decode.py                       |   62 +-
 .../ASR/pruned_transducer_stateless5/train.py |   67 +-
 egs/yesno/ASR/local/compile_hlg.py            |    4 +-
 egs/yesno/ASR/local/compute_fbank_yesno.py    |   12 +-
 egs/yesno/ASR/tdnn/asr_datamodule.py          |   74 +-
 egs/yesno/ASR/tdnn/decode.py                  |   29 +-
 egs/yesno/ASR/tdnn/pretrained.py              |   37 +-
 egs/yesno/ASR/tdnn/train.py                   |    4 +-
 egs/yesno/ASR/transducer/decode.py            |   25 +-
 egs/yesno/ASR/transducer/train.py             |    4 +-
 icefall/char_graph_compiler.py                |    8 +-
 icefall/checkpoint.py                         |   12 +-
 icefall/decode.py                             |   40 +-
 icefall/diagnostics.py                        |   80 +-
 icefall/dist.py                               |    4 +-
 icefall/env.py                                |    4 +-
 icefall/graph_compiler.py                     |    4 +-
 icefall/hooks.py                              |   19 +-
 icefall/lexicon.py                            |   16 +-
 icefall/mmi.py                                |   29 +-
 icefall/mmi_graph_compiler.py                 |    8 +-
 icefall/rnn_lm/compute_perplexity.py          |   15 +-
 icefall/rnn_lm/dataset.py                     |    8 +-
 icefall/rnn_lm/export.py                      |   17 +-
 icefall/rnn_lm/model.py                       |   28 +-
 icefall/rnn_lm/train.py                       |   11 +-
 icefall/shared/make_kn_lm.py                  |  184 +--
 icefall/utils.py                              |   66 +-
 pyproject.toml                                |    2 +-
 setup.py                                      |    3 +-
 test/test_checkpoint.py                       |    6 +-
 test/test_decode.py                           |    1 -
 test/test_graph_compiler.py                   |    4 +-
 test/test_utils.py                            |    4 +-
 441 files changed, 14535 insertions(+), 6789 deletions(-)
 delete mode 100644 .git-blame-ignore-revs
 mode change 100644 => 100755 egs/aishell2/ASR/local/__init__.py
 mode change 100644 => 100755 egs/aishell2/ASR/pruned_transducer_stateless5/__init__.py
 mode change 100644 => 100755 egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py
 mode change 100644 => 100755 egs/librispeech/ASR/lstm_transducer_stateless/decode.py
 mode change 100644 => 100755 egs/librispeech/ASR/lstm_transducer_stateless/export.py
 mode change 100644 => 100755 egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py
 mode change 100644 => 100755 egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py
 mode change 100644 => 100755 egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py
 mode change 100644 => 100755 egs/librispeech/ASR/lstm_transducer_stateless/train.py

diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs
deleted file mode 100644
index c5908fc89..000000000
--- a/.git-blame-ignore-revs
+++ /dev/null
@@ -1,2 +0,0 @@
-# Migrate to 88 characters per line (see: https://github.com/lhotse-speech/lhotse/issues/890)
-d110b04ad389134c82fa314e3aafc7b40043efb0
diff --git a/.github/workflows/style_check.yml b/.github/workflows/style_check.yml
index 45d261ccc..90459bc1c 100644
--- a/.github/workflows/style_check.yml
+++ b/.github/workflows/style_check.yml
@@ -45,18 +45,17 @@ jobs:
 
       - name: Install Python dependencies
         run: |
-          python3 -m pip install --upgrade pip black==22.3.0 flake8==5.0.4 click==8.1.0
-          # Click issue fixed in https://github.com/psf/black/pull/2966
+          python3 -m pip install --upgrade pip black==21.6b0 flake8==3.9.2 click==8.0.4
+          # See https://github.com/psf/black/issues/2964
+          # The version of click should be selected from 8.0.0, 8.0.1, 8.0.2, 8.0.3, and 8.0.4
 
       - name: Run flake8
         shell: bash
         working-directory: ${{github.workspace}}
         run: |
           # stop the build if there are Python syntax errors or undefined names
-          flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
-          # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
-          flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 \
-            --statistics --extend-ignore=E203,E266,E501,F401,E402,F403,F841,W503
+          flake8 . --count --show-source --statistics
+          flake8 .
 
       - name: Run black
         shell: bash
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 5cb213327..446ba0fe7 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,38 +1,26 @@
 repos:
   - repo: https://github.com/psf/black
-    rev: 22.3.0
+    rev: 21.6b0
     hooks:
       - id: black
-        args: ["--line-length=88"]
-        additional_dependencies: ['click==8.1.0']
+        args: [--line-length=80]
+        additional_dependencies: ['click==8.0.1']
         exclude: icefall\/__init__\.py
 
   - repo: https://github.com/PyCQA/flake8
-    rev: 5.0.4
+    rev: 3.9.2
     hooks:
       - id: flake8
-        args: ["--max-line-length=88", "--extend-ignore=E203,E266,E501,F401,E402,F403,F841,W503"]
-
-      # What are we ignoring here?
-      # E203: whitespace before ':'
-      # E266: too many leading '#' for block comment
-      # E501: line too long
-      # F401: module imported but unused
-      # E402: module level import not at top of file
-      # F403: 'from module import *' used; unable to detect undefined names
-      # F841: local variable is assigned to but never used
-      # W503: line break before binary operator
-      # In addition, the default ignore list is:
-      # E121,E123,E126,E226,E24,E704,W503,W504
+        args: [--max-line-length=80]
 
   - repo: https://github.com/pycqa/isort
-    rev: 5.10.1
+    rev: 5.9.2
     hooks:
       - id: isort
-        args: ["--profile=black"]
+        args: [--profile=black, --line-length=80]
 
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.2.0
+    rev: v4.0.1
     hooks:
       - id: check-executables-have-shebangs
       - id: end-of-file-fixer
diff --git a/docker/README.md b/docker/README.md
index c14b9bf75..6f2314e96 100644
--- a/docker/README.md
+++ b/docker/README.md
@@ -2,7 +2,7 @@
 
 2 sets of configuration are provided - (a) Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8, and (b) Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8.
 
-If your NVIDIA driver supports CUDA Version: 11.3, please go for case (a) Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8.
+If your NVIDIA driver supports CUDA Version: 11.3, please go for case (a) Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8. 
 
 Otherwise, since the older PyTorch images are not updated with the [apt-key rotation by NVIDIA](https://developer.nvidia.com/blog/updating-the-cuda-linux-gpg-repository-key), you have to go for case (b) Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8. Ensure that your NVDIA driver supports at least CUDA 11.0.
 
@@ -10,7 +10,7 @@ You can check the highest CUDA version within your NVIDIA driver's support with
 
 ```bash
 $ nvidia-smi
-Tue Sep 20 00:26:13 2022
+Tue Sep 20 00:26:13 2022       
 +-----------------------------------------------------------------------------+
 | NVIDIA-SMI 450.119.03   Driver Version: 450.119.03   CUDA Version: 11.0     |
 |-------------------------------+----------------------+----------------------+
@@ -26,7 +26,7 @@ Tue Sep 20 00:26:13 2022
 | 41%   30C    P8    11W / 280W |      6MiB / 24220MiB |      0%      Default |
 |                               |                      |                  N/A |
 +-------------------------------+----------------------+----------------------+
-
+                                                                               
 +-----------------------------------------------------------------------------+
 | Processes:                                                                  |
 |  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
@@ -40,15 +40,15 @@ Tue Sep 20 00:26:13 2022
 ```
 
 ## Building images locally
-If your environment requires a proxy to access the Internet, remember to add those information into the Dockerfile directly.
-For most cases, you can uncomment these lines in the Dockerfile and add in your proxy details.
+If your environment requires a proxy to access the Internet, remember to add those information into the Dockerfile directly. 
+For most cases, you can uncomment these lines in the Dockerfile and add in your proxy details. 
 
 ```dockerfile
 ENV http_proxy=http://aaa.bb.cc.net:8080 \
     https_proxy=http://aaa.bb.cc.net:8080
 ```
 
-Then, proceed with these commands.
+Then, proceed with these commands. 
 
 ### If you are case (a), i.e. your NVIDIA driver supports CUDA version >= 11.3:
 
@@ -72,11 +72,11 @@ docker run -it --runtime=nvidia --shm-size=2gb --name=icefall --gpus all icefall
 ```
 
 ### Tips:
-1. Since your data and models most probably won't be in the docker, you must use the -v flag to access the host machine. Do this by specifying `-v {/path/in/host/machine}:{/path/in/docker}`.
+1. Since your data and models most probably won't be in the docker, you must use the -v flag to access the host machine. Do this by specifying `-v {/path/in/host/machine}:{/path/in/docker}`. 
 
 2. Also, if your environment requires a proxy, this would be a good time to add it in too: `-e http_proxy=http://aaa.bb.cc.net:8080 -e https_proxy=http://aaa.bb.cc.net:8080`.
 
-Overall, your docker run command should look like this.
+Overall, your docker run command should look like this. 
 
 ```bash
 docker run -it --runtime=nvidia --shm-size=2gb --name=icefall --gpus all -v {/path/in/host/machine}:{/path/in/docker} -e http_proxy=http://aaa.bb.cc.net:8080 -e https_proxy=http://aaa.bb.cc.net:8080 icefall/pytorch1.12.1
@@ -86,9 +86,9 @@ You can explore more docker run options [here](https://docs.docker.com/engine/re
 
 ### Linking to icefall in your host machine
 
-If you already have icefall downloaded onto your host machine, you can use that repository instead so that changes in your code are visible inside and outside of the container.
+If you already have icefall downloaded onto your host machine, you can use that repository instead so that changes in your code are visible inside and outside of the container. 
 
-Note: Remember to set the -v flag above during the first run of the container, as that is the only way for your container to access your host machine.
+Note: Remember to set the -v flag above during the first run of the container, as that is the only way for your container to access your host machine. 
 Warning: Check that the icefall in your host machine is visible from within your container before proceeding to the commands below.
 
 Use these commands once you are inside the container.
@@ -103,7 +103,7 @@ ln -s {/path/in/docker/to/icefall} /workspace/icefall
 docker exec -it icefall /bin/bash
 ```
 
-## Restarting a killed container that has been run before.
+## Restarting a killed container that has been run before. 
 ```bash
 docker start -ai icefall
 ```
@@ -111,4 +111,4 @@ docker start -ai icefall
 ## Sample usage of the CPU based images:
 ```bash
 docker run -it icefall /bin/bash
-```
+``` 
diff --git a/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile b/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile
index ff9e40604..3637d2f11 100644
--- a/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile
+++ b/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile
@@ -1,7 +1,7 @@
 FROM pytorch/pytorch:1.12.1-cuda11.3-cudnn8-devel
 
 # ENV http_proxy=http://aaa.bbb.cc.net:8080 \
-#	https_proxy=http://aaa.bbb.cc.net:8080
+#	https_proxy=http://aaa.bbb.cc.net:8080 
 
 # install normal source
 RUN apt-get update && \
@@ -38,10 +38,10 @@ RUN wget -P /opt https://cmake.org/files/v3.18/cmake-3.18.0.tar.gz && \
     rm -rf cmake-3.18.0.tar.gz && \
     find /opt/cmake-3.18.0 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \
     cd -
-
-# flac
+	
+# flac 
 RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz  && \
-    cd /opt && \
+    cd /opt && \ 
     xz -d flac-1.3.2.tar.xz && \
     tar -xvf flac-1.3.2.tar && \
     cd flac-1.3.2 && \
@@ -49,11 +49,11 @@ RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz  &&
     make && make install && \
     rm -rf flac-1.3.2.tar && \
     find /opt/flac-1.3.2  -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \
-    cd -
+    cd - 
 
 RUN conda install -y -c pytorch torchaudio=0.12 && \
     pip install graphviz
-
+	
 
 #install k2 from source
 RUN git clone https://github.com/k2-fsa/k2.git /opt/k2 && \
@@ -68,7 +68,7 @@ RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
 	cd /workspace/icefall && \
 	pip install -r requirements.txt
 
-RUN pip install kaldifeat
+RUN pip install kaldifeat 
 ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
 
 WORKDIR /workspace/icefall
diff --git a/docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile b/docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile
index 5c7423fa5..17a8215f9 100644
--- a/docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile
+++ b/docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile
@@ -1,12 +1,12 @@
 FROM pytorch/pytorch:1.7.1-cuda11.0-cudnn8-devel
 
 # ENV http_proxy=http://aaa.bbb.cc.net:8080 \
-#	https_proxy=http://aaa.bbb.cc.net:8080
+#	https_proxy=http://aaa.bbb.cc.net:8080 
 
 RUN rm /etc/apt/sources.list.d/cuda.list && \
 	rm /etc/apt/sources.list.d/nvidia-ml.list && \
 	apt-key del 7fa2af80
-
+	
 # install normal source
 RUN apt-get update && \
     apt-get install -y --no-install-recommends \
@@ -36,7 +36,7 @@ RUN curl -fsSL https://developer.download.nvidia.com/compute/cuda/repos/ubuntu18
 	curl -fsSL https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub | apt-key add - && \
 	echo "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64 /" > /etc/apt/sources.list.d/cuda.list && \
 	echo "deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64 /" > /etc/apt/sources.list.d/nvidia-ml.list && \
-	rm -rf /var/lib/apt/lists/* && \
+	rm -rf /var/lib/apt/lists/* && \ 
 	mv /opt/conda/lib/libcufft.so.10 /opt/libcufft.so.10.bak && \
     mv /opt/conda/lib/libcurand.so.10 /opt/libcurand.so.10.bak && \
     mv /opt/conda/lib/libcublas.so.11 /opt/libcublas.so.11.bak && \
@@ -56,10 +56,10 @@ RUN wget -P /opt https://cmake.org/files/v3.18/cmake-3.18.0.tar.gz && \
     rm -rf cmake-3.18.0.tar.gz && \
     find /opt/cmake-3.18.0 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \
     cd -
-
-# flac
+	
+# flac 
 RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz  && \
-    cd /opt && \
+    cd /opt && \ 
     xz -d flac-1.3.2.tar.xz && \
     tar -xvf flac-1.3.2.tar && \
     cd flac-1.3.2 && \
@@ -67,7 +67,7 @@ RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz  &&
     make && make install && \
     rm -rf flac-1.3.2.tar && \
     find /opt/flac-1.3.2  -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \
-    cd -
+    cd - 
 
 RUN conda install -y -c pytorch torchaudio=0.7.1 && \
     pip install graphviz
@@ -79,7 +79,7 @@ RUN git clone https://github.com/k2-fsa/k2.git /opt/k2 && \
     cd -
 
 # install  lhotse
-RUN pip install git+https://github.com/lhotse-speech/lhotse
+RUN pip install git+https://github.com/lhotse-speech/lhotse 
 
 RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
 	cd /workspace/icefall && \
@@ -88,3 +88,4 @@ RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
 ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
 
 WORKDIR /workspace/icefall
+
diff --git a/docs/source/installation/images/k2-gt-v1.9-blueviolet.svg b/docs/source/installation/images/k2-gt-v1.9-blueviolet.svg
index 3019ff03d..534b2e534 100644
--- a/docs/source/installation/images/k2-gt-v1.9-blueviolet.svg
+++ b/docs/source/installation/images/k2-gt-v1.9-blueviolet.svg
@@ -1 +1 @@
-k2: >= v1.9k2>= v1.9
+k2: >= v1.9k2>= v1.9
\ No newline at end of file
diff --git a/docs/source/installation/images/python-gt-v3.6-blue.svg b/docs/source/installation/images/python-gt-v3.6-blue.svg
index df677ad09..4254dc58a 100644
--- a/docs/source/installation/images/python-gt-v3.6-blue.svg
+++ b/docs/source/installation/images/python-gt-v3.6-blue.svg
@@ -1 +1 @@
-python: >= 3.6python>= 3.6
+python: >= 3.6python>= 3.6
\ No newline at end of file
diff --git a/docs/source/installation/images/torch-gt-v1.6.0-green.svg b/docs/source/installation/images/torch-gt-v1.6.0-green.svg
index d7007d742..d3ece9a17 100644
--- a/docs/source/installation/images/torch-gt-v1.6.0-green.svg
+++ b/docs/source/installation/images/torch-gt-v1.6.0-green.svg
@@ -1 +1 @@
-torch: >= 1.6.0torch>= 1.6.0
+torch: >= 1.6.0torch>= 1.6.0
\ No newline at end of file
diff --git a/docs/source/recipes/aishell/index.rst b/docs/source/recipes/aishell/index.rst
index b77d59bca..d072d6e9c 100644
--- a/docs/source/recipes/aishell/index.rst
+++ b/docs/source/recipes/aishell/index.rst
@@ -19,3 +19,4 @@ It can be downloaded from ``_
    tdnn_lstm_ctc
    conformer_ctc
    stateless_transducer
+
diff --git a/docs/source/recipes/timit/index.rst b/docs/source/recipes/timit/index.rst
index 5ee147be7..17f40cdb7 100644
--- a/docs/source/recipes/timit/index.rst
+++ b/docs/source/recipes/timit/index.rst
@@ -6,3 +6,4 @@ TIMIT
 
    tdnn_ligru_ctc
    tdnn_lstm_ctc
+
diff --git a/docs/source/recipes/timit/tdnn_ligru_ctc.rst b/docs/source/recipes/timit/tdnn_ligru_ctc.rst
index 3d7aefe02..186420ee7 100644
--- a/docs/source/recipes/timit/tdnn_ligru_ctc.rst
+++ b/docs/source/recipes/timit/tdnn_ligru_ctc.rst
@@ -148,10 +148,10 @@ Some commonly used options are:
 
         $ ./tdnn_ligru_ctc/decode.py --epoch 25 --avg 17
 
-    uses the average of ``epoch-9.pt``, ``epoch-10.pt``, ``epoch-11.pt``,
-    ``epoch-12.pt``, ``epoch-13.pt``, ``epoch-14.pt``, ``epoch-15.pt``,
-    ``epoch-16.pt``, ``epoch-17.pt``, ``epoch-18.pt``, ``epoch-19.pt``,
-    ``epoch-20.pt``, ``epoch-21.pt``, ``epoch-22.pt``, ``epoch-23.pt``,
+    uses the average of ``epoch-9.pt``, ``epoch-10.pt``, ``epoch-11.pt``, 
+    ``epoch-12.pt``, ``epoch-13.pt``, ``epoch-14.pt``, ``epoch-15.pt``, 
+    ``epoch-16.pt``, ``epoch-17.pt``, ``epoch-18.pt``, ``epoch-19.pt``, 
+    ``epoch-20.pt``, ``epoch-21.pt``, ``epoch-22.pt``, ``epoch-23.pt``, 
     ``epoch-24.pt`` and ``epoch-25.pt``
     for decoding.
 
@@ -317,13 +317,13 @@ To decode with ``1best`` method, we can use:
 
 .. code-block:: bash
 
-  ./tdnn_ligru_ctc/pretrained.py
+  ./tdnn_ligru_ctc/pretrained.py 
     --method 1best
-    --checkpoint ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/exp/pretrained_average_9_25.pt
-    --words-file ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/words.txt
-    --HLG ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/HLG.pt
-    ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV
-    ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV
+    --checkpoint ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/exp/pretrained_average_9_25.pt 
+    --words-file ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/words.txt 
+    --HLG ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/HLG.pt 
+    ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV 
+    ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV 
     ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FMGD0_SI1564.WAV
 
 The output is:
@@ -337,7 +337,7 @@ The output is:
   2021-11-08 20:41:38,697 INFO [pretrained.py:210] Reading sound files: ['./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV', './tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV', './tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FMGD0_SI1564.WAV']
   2021-11-08 20:41:38,704 INFO [pretrained.py:216] Decoding started
   2021-11-08 20:41:39,819 INFO [pretrained.py:246] Use HLG decoding
-  2021-11-08 20:41:39,829 INFO [pretrained.py:267]
+  2021-11-08 20:41:39,829 INFO [pretrained.py:267] 
   ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV:
   sil dh ih sh uw ah l iy v iy z ih sil p r aa sil k s ih m ey dx ih sil d w uh dx ih w ih s f iy l ih ng w ih th ih n ih m s eh l f sil jh
 
@@ -362,8 +362,8 @@ To decode with ``whole-lattice-rescoring`` methond, you can use
     --HLG ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/HLG.pt \
     --G ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lm/G_4_gram.pt \
     --ngram-lm-scale 0.1 \
-    ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV
-    ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV
+    ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV 
+    ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV 
     ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FMGD0_SI1564.WAV
 
 The decoding output is:
@@ -378,7 +378,7 @@ The decoding output is:
   2021-11-08 20:37:54,715 INFO [pretrained.py:210] Reading sound files: ['./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV', './tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV', './tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FMGD0_SI1564.WAV']
   2021-11-08 20:37:54,720 INFO [pretrained.py:216] Decoding started
   2021-11-08 20:37:55,808 INFO [pretrained.py:251] Use HLG decoding + LM rescoring
-  2021-11-08 20:37:56,348 INFO [pretrained.py:267]
+  2021-11-08 20:37:56,348 INFO [pretrained.py:267] 
   ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV:
   sil dh ih sh uw ah l iy v iy z ah sil p r aa sil k s ih m ey dx ih sil d w uh dx iy w ih s f iy l iy ng w ih th ih n ih m s eh l f sil jh
 
diff --git a/docs/source/recipes/timit/tdnn_lstm_ctc.rst b/docs/source/recipes/timit/tdnn_lstm_ctc.rst
index ee67a6edc..6f760a9ce 100644
--- a/docs/source/recipes/timit/tdnn_lstm_ctc.rst
+++ b/docs/source/recipes/timit/tdnn_lstm_ctc.rst
@@ -148,8 +148,8 @@ Some commonly used options are:
 
         $ ./tdnn_lstm_ctc/decode.py --epoch 25 --avg 10
 
-    uses the average of ``epoch-16.pt``, ``epoch-17.pt``, ``epoch-18.pt``,
-    ``epoch-19.pt``, ``epoch-20.pt``, ``epoch-21.pt``, ``epoch-22.pt``,
+    uses the average of ``epoch-16.pt``, ``epoch-17.pt``, ``epoch-18.pt``, 
+    ``epoch-19.pt``, ``epoch-20.pt``, ``epoch-21.pt``, ``epoch-22.pt``, 
     ``epoch-23.pt``, ``epoch-24.pt`` and ``epoch-25.pt``
     for decoding.
 
@@ -315,13 +315,13 @@ To decode with ``1best`` method, we can use:
 
 .. code-block:: bash
 
-  ./tdnn_lstm_ctc/pretrained.py
+  ./tdnn_lstm_ctc/pretrained.py 
     --method 1best
-    --checkpoint ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/exp/pretrained_average_16_25.pt
-    --words-file ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/words.txt
-    --HLG ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/HLG.pt
-    ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV
-    ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV
+    --checkpoint ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/exp/pretrained_average_16_25.pt 
+    --words-file ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/words.txt 
+    --HLG ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/HLG.pt 
+    ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV 
+    ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV 
     ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV
 
 The output is:
@@ -335,7 +335,7 @@ The output is:
   2021-11-08 21:02:53,827 INFO [pretrained.py:210] Reading sound files: ['./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV', './tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV', './tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV']
   2021-11-08 21:02:53,831 INFO [pretrained.py:216] Decoding started
   2021-11-08 21:02:54,380 INFO [pretrained.py:246] Use HLG decoding
-  2021-11-08 21:02:54,387 INFO [pretrained.py:267]
+  2021-11-08 21:02:54,387 INFO [pretrained.py:267] 
   ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV:
   sil dh ih sh uw ah l iy v iy z ih sil p r aa sil k s ih m ey dx ih sil d w uh dx iy w ih s f iy l iy w ih th ih n ih m s eh l f sil jh
 
@@ -360,8 +360,8 @@ To decode with ``whole-lattice-rescoring`` methond, you can use
     --HLG ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/HLG.pt \
     --G ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lm/G_4_gram.pt \
     --ngram-lm-scale 0.08 \
-    ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV
-    ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV
+    ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV 
+    ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV 
     ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV
 
 The decoding output is:
@@ -376,7 +376,7 @@ The decoding output is:
   2021-11-08 20:05:26,978 INFO [pretrained.py:210] Reading sound files: ['./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV', './tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV', './tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV']
   2021-11-08 20:05:26,981 INFO [pretrained.py:216] Decoding started
   2021-11-08 20:05:27,519 INFO [pretrained.py:251] Use HLG decoding + LM rescoring
-  2021-11-08 20:05:27,878 INFO [pretrained.py:267]
+  2021-11-08 20:05:27,878 INFO [pretrained.py:267] 
   ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV:
   sil dh ih sh uw l iy v iy z ih sil p r aa sil k s ah m ey dx ih sil w uh dx iy w ih s f iy l ih ng w ih th ih n ih m s eh l f sil jh
 
diff --git a/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py b/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py
index 387c14acf..fb2751c0f 100755
--- a/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py
+++ b/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py
@@ -87,7 +87,9 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
             )
             if "train" in partition:
                 cut_set = (
-                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+                    cut_set
+                    + cut_set.perturb_speed(0.9)
+                    + cut_set.perturb_speed(1.1)
                 )
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
@@ -114,7 +116,9 @@ def get_args():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/aidatatang_200zh/ASR/local/prepare_char.py b/egs/aidatatang_200zh/ASR/local/prepare_char.py
index 6b440dfb3..d9e47d17a 100755
--- a/egs/aidatatang_200zh/ASR/local/prepare_char.py
+++ b/egs/aidatatang_200zh/ASR/local/prepare_char.py
@@ -86,7 +86,9 @@ def lexicon_to_fst_no_sil(
         cur_state = loop_state
 
         word = word2id[word]
-        pieces = [token2id[i] if i in token2id else token2id[""] for i in pieces]
+        pieces = [
+            token2id[i] if i in token2id else token2id[""] for i in pieces
+        ]
 
         for i in range(len(pieces) - 1):
             w = word if i == 0 else eps
@@ -140,7 +142,9 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
     return False
 
 
-def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon:
+def generate_lexicon(
+    token_sym_table: Dict[str, int], words: List[str]
+) -> Lexicon:
     """Generate a lexicon from a word list and token_sym_table.
 
     Args:
diff --git a/egs/aidatatang_200zh/ASR/local/prepare_lang.py b/egs/aidatatang_200zh/ASR/local/prepare_lang.py
index c8cf9b881..e5ae89ec4 100755
--- a/egs/aidatatang_200zh/ASR/local/prepare_lang.py
+++ b/egs/aidatatang_200zh/ASR/local/prepare_lang.py
@@ -317,7 +317,9 @@ def lexicon_to_fst(
 
 def get_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone")
+    parser.add_argument(
+        "--lang-dir", type=str, help="The lang dir, data/lang_phone"
+    )
     return parser.parse_args()
 
 
diff --git a/egs/aidatatang_200zh/ASR/local/test_prepare_lang.py b/egs/aidatatang_200zh/ASR/local/test_prepare_lang.py
index 74e025ad7..d4cf62bba 100755
--- a/egs/aidatatang_200zh/ASR/local/test_prepare_lang.py
+++ b/egs/aidatatang_200zh/ASR/local/test_prepare_lang.py
@@ -88,7 +88,9 @@ def test_read_lexicon(filename: str):
     fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa.draw("L.pdf", title="L")
 
-    fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id)
+    fsa_disambig = lexicon_to_fst(
+        lexicon_disambig, phone2id=phone2id, word2id=word2id
+    )
     fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
     fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa_disambig.draw("L_disambig.pdf", title="L_disambig")
diff --git a/egs/aidatatang_200zh/ASR/local/text2token.py b/egs/aidatatang_200zh/ASR/local/text2token.py
index 2be639b7a..71be2a613 100755
--- a/egs/aidatatang_200zh/ASR/local/text2token.py
+++ b/egs/aidatatang_200zh/ASR/local/text2token.py
@@ -50,15 +50,15 @@ def get_parser():
         "-n",
         default=1,
         type=int,
-        help=(
-            "number of characters to split, i.e.,                         aabb -> a a b"
-            " b with -n 1 and aa bb with -n 2"
-        ),
+        help="number of characters to split, i.e., \
+                        aabb -> a a b b with -n 1 and aa bb with -n 2",
     )
     parser.add_argument(
         "--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
     )
-    parser.add_argument("--space", default="", type=str, help="space symbol")
+    parser.add_argument(
+        "--space", default="", type=str, help="space symbol"
+    )
     parser.add_argument(
         "--non-lang-syms",
         "-l",
@@ -66,7 +66,9 @@ def get_parser():
         type=str,
         help="list of non-linguistic symobles, e.g.,  etc.",
     )
-    parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
+    parser.add_argument(
+        "text", type=str, default=False, nargs="?", help="input text"
+    )
     parser.add_argument(
         "--trans_type",
         "-t",
@@ -106,7 +108,8 @@ def token2id(
             if token_type == "lazy_pinyin":
                 text = lazy_pinyin(chars_list)
                 sub_ids = [
-                    token_table[txt] if txt in token_table else oov_id for txt in text
+                    token_table[txt] if txt in token_table else oov_id
+                    for txt in text
                 ]
                 ids.append(sub_ids)
             else:  # token_type = "pinyin"
@@ -132,7 +135,9 @@ def main():
     if args.text:
         f = codecs.open(args.text, encoding="utf-8")
     else:
-        f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
+        f = codecs.getreader("utf-8")(
+            sys.stdin if is_python2 else sys.stdin.buffer
+        )
 
     sys.stdout = codecs.getwriter("utf-8")(
         sys.stdout if is_python2 else sys.stdout.buffer
diff --git a/egs/aidatatang_200zh/ASR/prepare.sh b/egs/aidatatang_200zh/ASR/prepare.sh
index 4749e1b7f..039951354 100755
--- a/egs/aidatatang_200zh/ASR/prepare.sh
+++ b/egs/aidatatang_200zh/ASR/prepare.sh
@@ -106,10 +106,11 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
   if [ ! -f $lang_char_dir/words.txt ]; then
     ./local/prepare_words.py \
       --input-file $lang_char_dir/words_no_ids.txt \
-      --output-file $lang_char_dir/words.txt
+      --output-file $lang_char_dir/words.txt 
   fi
 
   if [ ! -f $lang_char_dir/L_disambig.pt ]; then
     ./local/prepare_char.py
   fi
 fi
+
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py
index 8c94f5bea..6a5b57e24 100644
--- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py
@@ -81,12 +81,10 @@ class Aidatatang_200zhAsrDataModule:
     def add_arguments(cls, parser: argparse.ArgumentParser):
         group = parser.add_argument_group(
             title="ASR data related options",
-            description=(
-                "These options are used for the preparation of "
-                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-                "effective batch sizes, sampling strategies, applied data "
-                "augmentations, etc."
-            ),
+            description="These options are used for the preparation of "
+            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+            "effective batch sizes, sampling strategies, applied data "
+            "augmentations, etc.",
         )
         group.add_argument(
             "--manifest-dir",
@@ -98,91 +96,75 @@ class Aidatatang_200zhAsrDataModule:
             "--max-duration",
             type=int,
             default=200.0,
-            help=(
-                "Maximum pooled recordings duration (seconds) in a "
-                "single batch. You can reduce it if it causes CUDA OOM."
-            ),
+            help="Maximum pooled recordings duration (seconds) in a "
+            "single batch. You can reduce it if it causes CUDA OOM.",
         )
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, the batches will come from buckets of "
-                "similar duration (saves padding frames)."
-            ),
+            help="When enabled, the batches will come from buckets of "
+            "similar duration (saves padding frames).",
         )
         group.add_argument(
             "--num-buckets",
             type=int,
             default=300,
-            help=(
-                "The number of buckets for the DynamicBucketingSampler"
-                "(you might want to increase it for larger datasets)."
-            ),
+            help="The number of buckets for the DynamicBucketingSampler"
+            "(you might want to increase it for larger datasets).",
         )
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, utterances (cuts) will be concatenated "
-                "to minimize the amount of padding."
-            ),
+            help="When enabled, utterances (cuts) will be concatenated "
+            "to minimize the amount of padding.",
         )
         group.add_argument(
             "--duration-factor",
             type=float,
             default=1.0,
-            help=(
-                "Determines the maximum duration of a concatenated cut "
-                "relative to the duration of the longest cut in a batch."
-            ),
+            help="Determines the maximum duration of a concatenated cut "
+            "relative to the duration of the longest cut in a batch.",
         )
         group.add_argument(
             "--gap",
             type=float,
             default=1.0,
-            help=(
-                "The amount of padding (in seconds) inserted between "
-                "concatenated cuts. This padding is filled with noise when "
-                "noise augmentation is used."
-            ),
+            help="The amount of padding (in seconds) inserted between "
+            "concatenated cuts. This padding is filled with noise when "
+            "noise augmentation is used.",
         )
         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, use on-the-fly cut mixing and feature "
-                "extraction. Will drop existing precomputed feature manifests "
-                "if available."
-            ),
+            help="When enabled, use on-the-fly cut mixing and feature "
+            "extraction. Will drop existing precomputed feature manifests "
+            "if available.",
         )
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled (=default), the examples will be shuffled for each epoch."
-            ),
+            help="When enabled (=default), the examples will be "
+            "shuffled for each epoch.",
         )
         group.add_argument(
             "--return-cuts",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, each batch will have the "
-                "field: batch['supervisions']['cut'] with the cuts that "
-                "were used to construct it."
-            ),
+            help="When enabled, each batch will have the "
+            "field: batch['supervisions']['cut'] with the cuts that "
+            "were used to construct it.",
         )
 
         group.add_argument(
             "--num-workers",
             type=int,
             default=2,
-            help="The number of training dataloader workers that collect the batches.",
+            help="The number of training dataloader workers that "
+            "collect the batches.",
         )
 
         group.add_argument(
@@ -196,22 +178,18 @@ class Aidatatang_200zhAsrDataModule:
             "--spec-aug-time-warp-factor",
             type=int,
             default=80,
-            help=(
-                "Used only when --enable-spec-aug is True. "
-                "It specifies the factor for time warping in SpecAugment. "
-                "Larger values mean more warping. "
-                "A value less than 1 means to disable time warp."
-            ),
+            help="Used only when --enable-spec-aug is True. "
+            "It specifies the factor for time warping in SpecAugment. "
+            "Larger values mean more warping. "
+            "A value less than 1 means to disable time warp.",
         )
 
         group.add_argument(
             "--enable-musan",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, select noise from MUSAN and mix it"
-                "with training dataset. "
-            ),
+            help="When enabled, select noise from MUSAN and mix it"
+            "with training dataset. ",
         )
 
     def train_dataloaders(
@@ -227,20 +205,24 @@ class Aidatatang_200zhAsrDataModule:
             The state dict for the training sampler.
         """
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
+        cuts_musan = load_manifest(
+            self.args.manifest_dir / "musan_cuts.jsonl.gz"
+        )
 
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+                CutMix(
+                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
+                )
             )
         else:
             logging.info("Disable MUSAN")
 
         if self.args.concatenate_cuts:
             logging.info(
-                "Using cut concatenation with duration factor "
+                f"Using cut concatenation with duration factor "
                 f"{self.args.duration_factor} and gap {self.args.gap}."
             )
             # Cut concatenation should be the first transform in the list,
@@ -255,7 +237,9 @@ class Aidatatang_200zhAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+            logging.info(
+                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
+            )
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -298,7 +282,9 @@ class Aidatatang_200zhAsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -354,7 +340,9 @@ class Aidatatang_200zhAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 return_cuts=self.args.return_cuts,
             )
         else:
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py
index 3f582ef04..f0407f429 100755
--- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py
@@ -69,7 +69,11 @@ from beam_search import (
 )
 from train import get_params, get_transducer_model
 
-from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
+from icefall.checkpoint import (
+    average_checkpoints,
+    find_checkpoints,
+    load_checkpoint,
+)
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
@@ -88,30 +92,25 @@ def get_parser():
         "--epoch",
         type=int,
         default=28,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--batch",
         type=int,
         default=None,
-        help=(
-            "It specifies the batch checkpoint to use for decoding."
-            "Note: Epoch counts from 0."
-        ),
+        help="It specifies the batch checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=15,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -193,7 +192,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -249,7 +249,9 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
@@ -264,7 +266,10 @@ def decode_one_batch(
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -310,7 +315,11 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}"
+            ): hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -381,7 +390,9 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -414,7 +425,8 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py
index 34f4d3ddf..00b54c39f 100644
--- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py
@@ -62,20 +62,17 @@ def get_parser():
         "--epoch",
         type=int,
         default=28,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=15,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -106,7 +103,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     return parser
@@ -175,7 +173,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py
index 3c96ed07b..eb5e6b0d4 100644
--- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py
@@ -85,11 +85,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -114,12 +112,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     parser.add_argument(
@@ -166,7 +162,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -196,9 +193,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -259,7 +257,9 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -284,7 +284,10 @@ def main():
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -336,7 +339,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py
index c7b1a4266..d46838b68 100644
--- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py
@@ -81,7 +81,9 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+LRSchedulerType = Union[
+    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
+]
 
 os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
 
@@ -185,45 +187,42 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
         "--prune-range",
         type=int,
         default=5,
-        help=(
-            "The prune range for rnnt loss, it means how many symbols(context)"
-            "we are using to compute the loss"
-        ),
+        help="The prune range for rnnt loss, it means how many symbols(context)"
+        "we are using to compute the loss",
     )
 
     parser.add_argument(
         "--lm-scale",
         type=float,
         default=0.25,
-        help=(
-            "The scale to smooth the loss with lm (output of prediction network) part."
-        ),
+        help="The scale to smooth the loss with lm "
+        "(output of prediction network) part.",
     )
 
     parser.add_argument(
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)part.",
+        help="The scale to smooth the loss with am (output of encoder network)"
+        "part.",
     )
 
     parser.add_argument(
         "--simple-loss-scale",
         type=float,
         default=0.5,
-        help=(
-            "To get pruning ranges, we will calculate a simple version"
-            "loss(joiner is just addition), this simple loss also uses for"
-            "training (as a regularization item). We will scale the simple loss"
-            "with this parameter before adding to the final loss."
-        ),
+        help="To get pruning ranges, we will calculate a simple version"
+        "loss(joiner is just addition), this simple loss also uses for"
+        "training (as a regularization item). We will scale the simple loss"
+        "with this parameter before adding to the final loss.",
     )
 
     parser.add_argument(
@@ -543,15 +542,22 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+            0.0
+            if warmup < 1.0
+            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+        )
+        loss = (
+            params.simple_loss_scale * simple_loss
+            + pruned_loss_scale * pruned_loss
         )
-        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -705,7 +711,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -805,7 +813,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            2 ** 22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
diff --git a/egs/aishell/ASR/conformer_ctc/conformer.py b/egs/aishell/ASR/conformer_ctc/conformer.py
index f5b5873b4..cb7205e51 100644
--- a/egs/aishell/ASR/conformer_ctc/conformer.py
+++ b/egs/aishell/ASR/conformer_ctc/conformer.py
@@ -157,7 +157,9 @@ class ConformerEncoderLayer(nn.Module):
         normalize_before: bool = True,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
-        self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
+        self.self_attn = RelPositionMultiheadAttention(
+            d_model, nhead, dropout=0.0
+        )
 
         self.feed_forward = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
@@ -175,14 +177,18 @@ class ConformerEncoderLayer(nn.Module):
 
         self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
 
-        self.norm_ff_macaron = nn.LayerNorm(d_model)  # for the macaron style FNN module
+        self.norm_ff_macaron = nn.LayerNorm(
+            d_model
+        )  # for the macaron style FNN module
         self.norm_ff = nn.LayerNorm(d_model)  # for the FNN module
         self.norm_mha = nn.LayerNorm(d_model)  # for the MHA module
 
         self.ff_scale = 0.5
 
         self.norm_conv = nn.LayerNorm(d_model)  # for the CNN module
-        self.norm_final = nn.LayerNorm(d_model)  # for the final output of the block
+        self.norm_final = nn.LayerNorm(
+            d_model
+        )  # for the final output of the block
 
         self.dropout = nn.Dropout(dropout)
 
@@ -216,7 +222,9 @@ class ConformerEncoderLayer(nn.Module):
         residual = src
         if self.normalize_before:
             src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
+        src = residual + self.ff_scale * self.dropout(
+            self.feed_forward_macaron(src)
+        )
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)
 
@@ -335,7 +343,9 @@ class RelPositionalEncoding(torch.nn.Module):
 
     """
 
-    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
+    def __init__(
+        self, d_model: int, dropout_rate: float, max_len: int = 5000
+    ) -> None:
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
@@ -351,7 +361,9 @@ class RelPositionalEncoding(torch.nn.Module):
             # the length of self.pe is 2 * input_len - 1
             if self.pe.size(1) >= x.size(1) * 2 - 1:
                 # Note: TorchScript doesn't implement operator== for torch.Device
-                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(
+                    x.device
+                ):
                     self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                 return
         # Suppose `i` means to the position of query vector and `j` means the
@@ -621,9 +633,9 @@ class RelPositionMultiheadAttention(nn.Module):
 
         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
-                3, dim=-1
-            )
+            q, k, v = nn.functional.linear(
+                query, in_proj_weight, in_proj_bias
+            ).chunk(3, dim=-1)
 
         elif torch.equal(key, value):
             # encoder-decoder attention
@@ -691,25 +703,33 @@ class RelPositionMultiheadAttention(nn.Module):
             if attn_mask.dim() == 2:
                 attn_mask = attn_mask.unsqueeze(0)
                 if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
-                    raise RuntimeError("The size of the 2D attn_mask is not correct.")
+                    raise RuntimeError(
+                        "The size of the 2D attn_mask is not correct."
+                    )
             elif attn_mask.dim() == 3:
                 if list(attn_mask.size()) != [
                     bsz * num_heads,
                     query.size(0),
                     key.size(0),
                 ]:
-                    raise RuntimeError("The size of the 3D attn_mask is not correct.")
+                    raise RuntimeError(
+                        "The size of the 3D attn_mask is not correct."
+                    )
             else:
                 raise RuntimeError(
-                    "attn_mask's dimension {} is not supported".format(attn_mask.dim())
+                    "attn_mask's dimension {} is not supported".format(
+                        attn_mask.dim()
+                    )
                 )
             # attn_mask's dim is 3 now.
 
         # convert ByteTensor key_padding_mask to bool
-        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
+        if (
+            key_padding_mask is not None
+            and key_padding_mask.dtype == torch.uint8
+        ):
             warnings.warn(
-                "Byte tensor for key_padding_mask is deprecated. Use bool tensor"
-                " instead."
+                "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
             )
             key_padding_mask = key_padding_mask.to(torch.bool)
 
@@ -746,7 +766,9 @@ class RelPositionMultiheadAttention(nn.Module):
         # first compute matrix a and matrix c
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
-        matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(
+            q_with_bias_u, k
+        )  # (batch, head, time1, time2)
 
         # compute matrix b and matrix d
         matrix_bd = torch.matmul(
@@ -758,7 +780,9 @@ class RelPositionMultiheadAttention(nn.Module):
             matrix_ac + matrix_bd
         ) * scaling  # (batch, head, time1, time2)
 
-        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
+        attn_output_weights = attn_output_weights.view(
+            bsz * num_heads, tgt_len, -1
+        )
 
         assert list(attn_output_weights.size()) == [
             bsz * num_heads,
@@ -792,9 +816,13 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output = torch.bmm(attn_output_weights, v)
         assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
         attn_output = (
-            attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+            attn_output.transpose(0, 1)
+            .contiguous()
+            .view(tgt_len, bsz, embed_dim)
+        )
+        attn_output = nn.functional.linear(
+            attn_output, out_proj_weight, out_proj_bias
         )
-        attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
 
         if need_weights:
             # average attention weights over heads
@@ -817,7 +845,9 @@ class ConvolutionModule(nn.Module):
 
     """
 
-    def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None:
+    def __init__(
+        self, channels: int, kernel_size: int, bias: bool = True
+    ) -> None:
         """Construct an ConvolutionModule object."""
         super(ConvolutionModule, self).__init__()
         # kernerl_size should be a odd number for 'SAME' padding
diff --git a/egs/aishell/ASR/conformer_ctc/decode.py b/egs/aishell/ASR/conformer_ctc/decode.py
index a30fa52df..751b7d5b5 100755
--- a/egs/aishell/ASR/conformer_ctc/decode.py
+++ b/egs/aishell/ASR/conformer_ctc/decode.py
@@ -58,19 +58,16 @@ def get_parser():
         "--epoch",
         type=int,
         default=49,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=20,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -404,7 +401,9 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -432,7 +431,9 @@ def save_results(
         # we compute CER for aishell dataset.
         results_char = []
         for res in results:
-            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
+            results_char.append(
+                (res[0], list("".join(res[1])), list("".join(res[2])))
+            )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
@@ -440,7 +441,9 @@ def save_results(
             test_set_wers[key] = wer
 
         if enable_log:
-            logging.info("Wrote detailed error stats to {}".format(errs_filename))
+            logging.info(
+                "Wrote detailed error stats to {}".format(errs_filename)
+            )
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = params.exp_dir / f"cer-summary-{test_set_name}.txt"
@@ -559,7 +562,9 @@ def main():
             eos_id=eos_id,
         )
 
-        save_results(params=params, test_set_name=test_set, results_dict=results_dict)
+        save_results(
+            params=params, test_set_name=test_set, results_dict=results_dict
+        )
 
     logging.info("Done!")
 
diff --git a/egs/aishell/ASR/conformer_ctc/export.py b/egs/aishell/ASR/conformer_ctc/export.py
index 9ee405e8b..42b8c29e7 100644
--- a/egs/aishell/ASR/conformer_ctc/export.py
+++ b/egs/aishell/ASR/conformer_ctc/export.py
@@ -40,20 +40,17 @@ def get_parser():
         "--epoch",
         type=int,
         default=84,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=25,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -160,7 +157,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell/ASR/conformer_ctc/pretrained.py b/egs/aishell/ASR/conformer_ctc/pretrained.py
index e3d5a20e3..27776bc24 100755
--- a/egs/aishell/ASR/conformer_ctc/pretrained.py
+++ b/egs/aishell/ASR/conformer_ctc/pretrained.py
@@ -46,29 +46,27 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
         "--tokens-file",
         type=str,
-        help="Path to tokens.txtUsed only when method is ctc-decoding",
+        help="Path to tokens.txt" "Used only when method is ctc-decoding",
     )
 
     parser.add_argument(
         "--words-file",
         type=str,
-        help="Path to words.txtUsed when method is NOT ctc-decoding",
+        help="Path to words.txt" "Used when method is NOT ctc-decoding",
     )
 
     parser.add_argument(
         "--HLG",
         type=str,
-        help="Path to HLG.pt.Used when method is NOT ctc-decoding",
+        help="Path to HLG.pt." "Used when method is NOT ctc-decoding",
     )
 
     parser.add_argument(
@@ -165,12 +163,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     return parser
@@ -214,9 +210,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -277,7 +274,9 @@ def main():
     logging.info("Decoding started")
     features = fbank(waves)
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
 
     # Note: We don't use key padding mask for attention during decoding
     with torch.no_grad():
@@ -372,7 +371,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell/ASR/conformer_ctc/subsampling.py b/egs/aishell/ASR/conformer_ctc/subsampling.py
index 8e0f73d05..542fb0364 100644
--- a/egs/aishell/ASR/conformer_ctc/subsampling.py
+++ b/egs/aishell/ASR/conformer_ctc/subsampling.py
@@ -42,9 +42,13 @@ class Conv2dSubsampling(nn.Module):
         assert idim >= 7
         super().__init__()
         self.conv = nn.Sequential(
-            nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2),
+            nn.Conv2d(
+                in_channels=1, out_channels=odim, kernel_size=3, stride=2
+            ),
             nn.ReLU(),
-            nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2),
+            nn.Conv2d(
+                in_channels=odim, out_channels=odim, kernel_size=3, stride=2
+            ),
             nn.ReLU(),
         )
         self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
@@ -128,13 +132,17 @@ class VggSubsampling(nn.Module):
                 )
             )
             layers.append(
-                torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True)
+                torch.nn.MaxPool2d(
+                    kernel_size=2, stride=2, padding=0, ceil_mode=True
+                )
             )
             cur_channels = block_dim
 
         self.layers = nn.Sequential(*layers)
 
-        self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim)
+        self.out = nn.Linear(
+            block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim
+        )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Subsample x.
diff --git a/egs/aishell/ASR/conformer_ctc/test_subsampling.py b/egs/aishell/ASR/conformer_ctc/test_subsampling.py
index 81fa234dd..e3361d0c9 100755
--- a/egs/aishell/ASR/conformer_ctc/test_subsampling.py
+++ b/egs/aishell/ASR/conformer_ctc/test_subsampling.py
@@ -16,8 +16,9 @@
 # limitations under the License.
 
 
+from subsampling import Conv2dSubsampling
+from subsampling import VggSubsampling
 import torch
-from subsampling import Conv2dSubsampling, VggSubsampling
 
 
 def test_conv2d_subsampling():
diff --git a/egs/aishell/ASR/conformer_ctc/train.py b/egs/aishell/ASR/conformer_ctc/train.py
index c2cbe6e3b..a228cc1fe 100755
--- a/egs/aishell/ASR/conformer_ctc/train.py
+++ b/egs/aishell/ASR/conformer_ctc/train.py
@@ -382,7 +382,9 @@ def compute_loss(
             #
             # See https://github.com/k2-fsa/icefall/issues/97
             # for more details
-            unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"])
+            unsorted_token_ids = graph_compiler.texts_to_ids(
+                supervisions["text"]
+            )
             att_loss = mmodel.decoder_forward(
                 encoder_memory,
                 memory_mask,
@@ -518,7 +520,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -626,7 +630,9 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
+            tb_writer.add_scalar(
+                "train/learning_rate", cur_lr, params.batch_idx_train
+            )
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/aishell/ASR/conformer_ctc/transformer.py b/egs/aishell/ASR/conformer_ctc/transformer.py
index a3e50e385..f93914aaa 100644
--- a/egs/aishell/ASR/conformer_ctc/transformer.py
+++ b/egs/aishell/ASR/conformer_ctc/transformer.py
@@ -149,7 +149,9 @@ class Transformer(nn.Module):
                 norm=decoder_norm,
             )
 
-            self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class)
+            self.decoder_output_layer = torch.nn.Linear(
+                d_model, self.decoder_num_class
+            )
 
             self.decoder_criterion = LabelSmoothingLoss()
         else:
@@ -181,7 +183,9 @@ class Transformer(nn.Module):
             x = x.permute(0, 2, 1)  # (N, T, C) -> (N, C, T)
             x = self.feat_batchnorm(x)
             x = x.permute(0, 2, 1)  # (N, C, T) -> (N, T, C)
-        encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision)
+        encoder_memory, memory_key_padding_mask = self.run_encoder(
+            x, supervision
+        )
         x = self.ctc_output(encoder_memory)
         return x, encoder_memory, memory_key_padding_mask
 
@@ -262,17 +266,23 @@ class Transformer(nn.Module):
         """
         ys_in = add_sos(token_ids, sos_id=sos_id)
         ys_in = [torch.tensor(y) for y in ys_in]
-        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
+        ys_in_pad = pad_sequence(
+            ys_in, batch_first=True, padding_value=float(eos_id)
+        )
 
         ys_out = add_eos(token_ids, eos_id=eos_id)
         ys_out = [torch.tensor(y) for y in ys_out]
-        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
+        ys_out_pad = pad_sequence(
+            ys_out, batch_first=True, padding_value=float(-1)
+        )
 
         device = memory.device
         ys_in_pad = ys_in_pad.to(device)
         ys_out_pad = ys_out_pad.to(device)
 
-        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
+        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
+            device
+        )
 
         tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
         # TODO: Use length information to create the decoder padding mask
@@ -333,17 +343,23 @@ class Transformer(nn.Module):
 
         ys_in = add_sos(token_ids, sos_id=sos_id)
         ys_in = [torch.tensor(y) for y in ys_in]
-        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
+        ys_in_pad = pad_sequence(
+            ys_in, batch_first=True, padding_value=float(eos_id)
+        )
 
         ys_out = add_eos(token_ids, eos_id=eos_id)
         ys_out = [torch.tensor(y) for y in ys_out]
-        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
+        ys_out_pad = pad_sequence(
+            ys_out, batch_first=True, padding_value=float(-1)
+        )
 
         device = memory.device
         ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
         ys_out_pad = ys_out_pad.to(device, dtype=torch.int64)
 
-        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
+        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
+            device
+        )
 
         tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
         # TODO: Use length information to create the decoder padding mask
@@ -616,7 +632,9 @@ def _get_activation_fn(activation: str):
     elif activation == "gelu":
         return nn.functional.gelu
 
-    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
+    raise RuntimeError(
+        "activation should be relu/gelu, not {}".format(activation)
+    )
 
 
 class PositionalEncoding(nn.Module):
@@ -818,7 +836,9 @@ def encoder_padding_mask(
         1,
     ).to(torch.int32)
 
-    lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)]
+    lengths = [
+        0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)
+    ]
     for idx in range(supervision_segments.size(0)):
         # Note: TorchScript doesn't allow to unpack tensors as tuples
         sequence_idx = supervision_segments[idx, 0].item()
@@ -839,7 +859,9 @@ def encoder_padding_mask(
     return mask
 
 
-def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor:
+def decoder_padding_mask(
+    ys_pad: torch.Tensor, ignore_id: int = -1
+) -> torch.Tensor:
     """Generate a length mask for input.
 
     The masked position are filled with True,
diff --git a/egs/aishell/ASR/conformer_mmi/conformer.py b/egs/aishell/ASR/conformer_mmi/conformer.py
index f5b5873b4..cb7205e51 100644
--- a/egs/aishell/ASR/conformer_mmi/conformer.py
+++ b/egs/aishell/ASR/conformer_mmi/conformer.py
@@ -157,7 +157,9 @@ class ConformerEncoderLayer(nn.Module):
         normalize_before: bool = True,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
-        self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
+        self.self_attn = RelPositionMultiheadAttention(
+            d_model, nhead, dropout=0.0
+        )
 
         self.feed_forward = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
@@ -175,14 +177,18 @@ class ConformerEncoderLayer(nn.Module):
 
         self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
 
-        self.norm_ff_macaron = nn.LayerNorm(d_model)  # for the macaron style FNN module
+        self.norm_ff_macaron = nn.LayerNorm(
+            d_model
+        )  # for the macaron style FNN module
         self.norm_ff = nn.LayerNorm(d_model)  # for the FNN module
         self.norm_mha = nn.LayerNorm(d_model)  # for the MHA module
 
         self.ff_scale = 0.5
 
         self.norm_conv = nn.LayerNorm(d_model)  # for the CNN module
-        self.norm_final = nn.LayerNorm(d_model)  # for the final output of the block
+        self.norm_final = nn.LayerNorm(
+            d_model
+        )  # for the final output of the block
 
         self.dropout = nn.Dropout(dropout)
 
@@ -216,7 +222,9 @@ class ConformerEncoderLayer(nn.Module):
         residual = src
         if self.normalize_before:
             src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
+        src = residual + self.ff_scale * self.dropout(
+            self.feed_forward_macaron(src)
+        )
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)
 
@@ -335,7 +343,9 @@ class RelPositionalEncoding(torch.nn.Module):
 
     """
 
-    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
+    def __init__(
+        self, d_model: int, dropout_rate: float, max_len: int = 5000
+    ) -> None:
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
@@ -351,7 +361,9 @@ class RelPositionalEncoding(torch.nn.Module):
             # the length of self.pe is 2 * input_len - 1
             if self.pe.size(1) >= x.size(1) * 2 - 1:
                 # Note: TorchScript doesn't implement operator== for torch.Device
-                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(
+                    x.device
+                ):
                     self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                 return
         # Suppose `i` means to the position of query vector and `j` means the
@@ -621,9 +633,9 @@ class RelPositionMultiheadAttention(nn.Module):
 
         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
-                3, dim=-1
-            )
+            q, k, v = nn.functional.linear(
+                query, in_proj_weight, in_proj_bias
+            ).chunk(3, dim=-1)
 
         elif torch.equal(key, value):
             # encoder-decoder attention
@@ -691,25 +703,33 @@ class RelPositionMultiheadAttention(nn.Module):
             if attn_mask.dim() == 2:
                 attn_mask = attn_mask.unsqueeze(0)
                 if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
-                    raise RuntimeError("The size of the 2D attn_mask is not correct.")
+                    raise RuntimeError(
+                        "The size of the 2D attn_mask is not correct."
+                    )
             elif attn_mask.dim() == 3:
                 if list(attn_mask.size()) != [
                     bsz * num_heads,
                     query.size(0),
                     key.size(0),
                 ]:
-                    raise RuntimeError("The size of the 3D attn_mask is not correct.")
+                    raise RuntimeError(
+                        "The size of the 3D attn_mask is not correct."
+                    )
             else:
                 raise RuntimeError(
-                    "attn_mask's dimension {} is not supported".format(attn_mask.dim())
+                    "attn_mask's dimension {} is not supported".format(
+                        attn_mask.dim()
+                    )
                 )
             # attn_mask's dim is 3 now.
 
         # convert ByteTensor key_padding_mask to bool
-        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
+        if (
+            key_padding_mask is not None
+            and key_padding_mask.dtype == torch.uint8
+        ):
             warnings.warn(
-                "Byte tensor for key_padding_mask is deprecated. Use bool tensor"
-                " instead."
+                "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
             )
             key_padding_mask = key_padding_mask.to(torch.bool)
 
@@ -746,7 +766,9 @@ class RelPositionMultiheadAttention(nn.Module):
         # first compute matrix a and matrix c
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
-        matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(
+            q_with_bias_u, k
+        )  # (batch, head, time1, time2)
 
         # compute matrix b and matrix d
         matrix_bd = torch.matmul(
@@ -758,7 +780,9 @@ class RelPositionMultiheadAttention(nn.Module):
             matrix_ac + matrix_bd
         ) * scaling  # (batch, head, time1, time2)
 
-        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
+        attn_output_weights = attn_output_weights.view(
+            bsz * num_heads, tgt_len, -1
+        )
 
         assert list(attn_output_weights.size()) == [
             bsz * num_heads,
@@ -792,9 +816,13 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output = torch.bmm(attn_output_weights, v)
         assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
         attn_output = (
-            attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+            attn_output.transpose(0, 1)
+            .contiguous()
+            .view(tgt_len, bsz, embed_dim)
+        )
+        attn_output = nn.functional.linear(
+            attn_output, out_proj_weight, out_proj_bias
         )
-        attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
 
         if need_weights:
             # average attention weights over heads
@@ -817,7 +845,9 @@ class ConvolutionModule(nn.Module):
 
     """
 
-    def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None:
+    def __init__(
+        self, channels: int, kernel_size: int, bias: bool = True
+    ) -> None:
         """Construct an ConvolutionModule object."""
         super(ConvolutionModule, self).__init__()
         # kernerl_size should be a odd number for 'SAME' padding
diff --git a/egs/aishell/ASR/conformer_mmi/decode.py b/egs/aishell/ASR/conformer_mmi/decode.py
index a43183063..4db367e36 100755
--- a/egs/aishell/ASR/conformer_mmi/decode.py
+++ b/egs/aishell/ASR/conformer_mmi/decode.py
@@ -59,19 +59,16 @@ def get_parser():
         "--epoch",
         type=int,
         default=49,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=20,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -416,7 +413,9 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -444,7 +443,9 @@ def save_results(
         # we compute CER for aishell dataset.
         results_char = []
         for res in results:
-            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
+            results_char.append(
+                (res[0], list("".join(res[1])), list("".join(res[2])))
+            )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
@@ -452,7 +453,9 @@ def save_results(
             test_set_wers[key] = wer
 
         if enable_log:
-            logging.info("Wrote detailed error stats to {}".format(errs_filename))
+            logging.info(
+                "Wrote detailed error stats to {}".format(errs_filename)
+            )
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = params.exp_dir / f"cer-summary-{test_set_name}.txt"
@@ -547,7 +550,9 @@ def main():
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
+        torch.save(
+            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
+        )
         return
 
     model.to(device)
@@ -576,7 +581,9 @@ def main():
             eos_id=eos_id,
         )
 
-        save_results(params=params, test_set_name=test_set, results_dict=results_dict)
+        save_results(
+            params=params, test_set_name=test_set, results_dict=results_dict
+        )
 
     logging.info("Done!")
 
diff --git a/egs/aishell/ASR/conformer_mmi/subsampling.py b/egs/aishell/ASR/conformer_mmi/subsampling.py
index 398837a46..720ed6c22 100644
--- a/egs/aishell/ASR/conformer_mmi/subsampling.py
+++ b/egs/aishell/ASR/conformer_mmi/subsampling.py
@@ -42,9 +42,13 @@ class Conv2dSubsampling(nn.Module):
         assert idim >= 7
         super().__init__()
         self.conv = nn.Sequential(
-            nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2),
+            nn.Conv2d(
+                in_channels=1, out_channels=odim, kernel_size=3, stride=2
+            ),
             nn.ReLU(),
-            nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2),
+            nn.Conv2d(
+                in_channels=odim, out_channels=odim, kernel_size=3, stride=2
+            ),
             nn.ReLU(),
         )
         self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
@@ -128,13 +132,17 @@ class VggSubsampling(nn.Module):
                 )
             )
             layers.append(
-                torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True)
+                torch.nn.MaxPool2d(
+                    kernel_size=2, stride=2, padding=0, ceil_mode=True
+                )
             )
             cur_channels = block_dim
 
         self.layers = nn.Sequential(*layers)
 
-        self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim)
+        self.out = nn.Linear(
+            block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim
+        )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Subsample x.
diff --git a/egs/aishell/ASR/conformer_mmi/train.py b/egs/aishell/ASR/conformer_mmi/train.py
index 09cd6e60c..685831d09 100755
--- a/egs/aishell/ASR/conformer_mmi/train.py
+++ b/egs/aishell/ASR/conformer_mmi/train.py
@@ -511,7 +511,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -623,7 +625,9 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
+            tb_writer.add_scalar(
+                "train/learning_rate", cur_lr, params.batch_idx_train
+            )
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/aishell/ASR/conformer_mmi/transformer.py b/egs/aishell/ASR/conformer_mmi/transformer.py
index a3e50e385..f93914aaa 100644
--- a/egs/aishell/ASR/conformer_mmi/transformer.py
+++ b/egs/aishell/ASR/conformer_mmi/transformer.py
@@ -149,7 +149,9 @@ class Transformer(nn.Module):
                 norm=decoder_norm,
             )
 
-            self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class)
+            self.decoder_output_layer = torch.nn.Linear(
+                d_model, self.decoder_num_class
+            )
 
             self.decoder_criterion = LabelSmoothingLoss()
         else:
@@ -181,7 +183,9 @@ class Transformer(nn.Module):
             x = x.permute(0, 2, 1)  # (N, T, C) -> (N, C, T)
             x = self.feat_batchnorm(x)
             x = x.permute(0, 2, 1)  # (N, C, T) -> (N, T, C)
-        encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision)
+        encoder_memory, memory_key_padding_mask = self.run_encoder(
+            x, supervision
+        )
         x = self.ctc_output(encoder_memory)
         return x, encoder_memory, memory_key_padding_mask
 
@@ -262,17 +266,23 @@ class Transformer(nn.Module):
         """
         ys_in = add_sos(token_ids, sos_id=sos_id)
         ys_in = [torch.tensor(y) for y in ys_in]
-        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
+        ys_in_pad = pad_sequence(
+            ys_in, batch_first=True, padding_value=float(eos_id)
+        )
 
         ys_out = add_eos(token_ids, eos_id=eos_id)
         ys_out = [torch.tensor(y) for y in ys_out]
-        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
+        ys_out_pad = pad_sequence(
+            ys_out, batch_first=True, padding_value=float(-1)
+        )
 
         device = memory.device
         ys_in_pad = ys_in_pad.to(device)
         ys_out_pad = ys_out_pad.to(device)
 
-        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
+        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
+            device
+        )
 
         tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
         # TODO: Use length information to create the decoder padding mask
@@ -333,17 +343,23 @@ class Transformer(nn.Module):
 
         ys_in = add_sos(token_ids, sos_id=sos_id)
         ys_in = [torch.tensor(y) for y in ys_in]
-        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
+        ys_in_pad = pad_sequence(
+            ys_in, batch_first=True, padding_value=float(eos_id)
+        )
 
         ys_out = add_eos(token_ids, eos_id=eos_id)
         ys_out = [torch.tensor(y) for y in ys_out]
-        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
+        ys_out_pad = pad_sequence(
+            ys_out, batch_first=True, padding_value=float(-1)
+        )
 
         device = memory.device
         ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
         ys_out_pad = ys_out_pad.to(device, dtype=torch.int64)
 
-        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
+        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
+            device
+        )
 
         tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
         # TODO: Use length information to create the decoder padding mask
@@ -616,7 +632,9 @@ def _get_activation_fn(activation: str):
     elif activation == "gelu":
         return nn.functional.gelu
 
-    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
+    raise RuntimeError(
+        "activation should be relu/gelu, not {}".format(activation)
+    )
 
 
 class PositionalEncoding(nn.Module):
@@ -818,7 +836,9 @@ def encoder_padding_mask(
         1,
     ).to(torch.int32)
 
-    lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)]
+    lengths = [
+        0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)
+    ]
     for idx in range(supervision_segments.size(0)):
         # Note: TorchScript doesn't allow to unpack tensors as tuples
         sequence_idx = supervision_segments[idx, 0].item()
@@ -839,7 +859,9 @@ def encoder_padding_mask(
     return mask
 
 
-def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor:
+def decoder_padding_mask(
+    ys_pad: torch.Tensor, ignore_id: int = -1
+) -> torch.Tensor:
     """Generate a length mask for input.
 
     The masked position are filled with True,
diff --git a/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py b/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py
index 037971927..42700a972 100755
--- a/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py
+++ b/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py
@@ -87,7 +87,9 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
             )
             if "train" in partition:
                 cut_set = (
-                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+                    cut_set
+                    + cut_set.perturb_speed(0.9)
+                    + cut_set.perturb_speed(1.1)
                 )
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
@@ -114,7 +116,9 @@ def get_args():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/aishell/ASR/local/compute_fbank_aishell.py b/egs/aishell/ASR/local/compute_fbank_aishell.py
index 115ca1031..deab6c809 100755
--- a/egs/aishell/ASR/local/compute_fbank_aishell.py
+++ b/egs/aishell/ASR/local/compute_fbank_aishell.py
@@ -83,7 +83,9 @@ def compute_fbank_aishell(num_mel_bins: int = 80):
             )
             if "train" in partition:
                 cut_set = (
-                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+                    cut_set
+                    + cut_set.perturb_speed(0.9)
+                    + cut_set.perturb_speed(1.1)
                 )
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
@@ -109,7 +111,9 @@ def get_args():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/aishell/ASR/local/prepare_char.py b/egs/aishell/ASR/local/prepare_char.py
index 6b440dfb3..d9e47d17a 100755
--- a/egs/aishell/ASR/local/prepare_char.py
+++ b/egs/aishell/ASR/local/prepare_char.py
@@ -86,7 +86,9 @@ def lexicon_to_fst_no_sil(
         cur_state = loop_state
 
         word = word2id[word]
-        pieces = [token2id[i] if i in token2id else token2id[""] for i in pieces]
+        pieces = [
+            token2id[i] if i in token2id else token2id[""] for i in pieces
+        ]
 
         for i in range(len(pieces) - 1):
             w = word if i == 0 else eps
@@ -140,7 +142,9 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
     return False
 
 
-def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon:
+def generate_lexicon(
+    token_sym_table: Dict[str, int], words: List[str]
+) -> Lexicon:
     """Generate a lexicon from a word list and token_sym_table.
 
     Args:
diff --git a/egs/aishell/ASR/local/prepare_lang.py b/egs/aishell/ASR/local/prepare_lang.py
index c8cf9b881..e5ae89ec4 100755
--- a/egs/aishell/ASR/local/prepare_lang.py
+++ b/egs/aishell/ASR/local/prepare_lang.py
@@ -317,7 +317,9 @@ def lexicon_to_fst(
 
 def get_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone")
+    parser.add_argument(
+        "--lang-dir", type=str, help="The lang dir, data/lang_phone"
+    )
     return parser.parse_args()
 
 
diff --git a/egs/aishell/ASR/local/test_prepare_lang.py b/egs/aishell/ASR/local/test_prepare_lang.py
index 74e025ad7..d4cf62bba 100755
--- a/egs/aishell/ASR/local/test_prepare_lang.py
+++ b/egs/aishell/ASR/local/test_prepare_lang.py
@@ -88,7 +88,9 @@ def test_read_lexicon(filename: str):
     fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa.draw("L.pdf", title="L")
 
-    fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id)
+    fsa_disambig = lexicon_to_fst(
+        lexicon_disambig, phone2id=phone2id, word2id=word2id
+    )
     fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
     fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa_disambig.draw("L_disambig.pdf", title="L_disambig")
diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/decode.py b/egs/aishell/ASR/pruned_transducer_stateless2/decode.py
index ae926ec66..a12934d55 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless2/decode.py
@@ -76,7 +76,11 @@ from beam_search import (
 )
 from train import add_model_arguments, get_params, get_transducer_model
 
-from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
+from icefall.checkpoint import (
+    average_checkpoints,
+    find_checkpoints,
+    load_checkpoint,
+)
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
@@ -114,11 +118,9 @@ def get_parser():
         "--avg",
         type=int,
         default=15,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch' and '--iter'"
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
     )
 
     parser.add_argument(
@@ -186,7 +188,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -246,7 +249,9 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )
 
     if params.decoding_method == "fast_beam_search":
         hyp_tokens = fast_beam_search_one_best(
@@ -258,7 +263,10 @@ def decode_one_batch(
             max_contexts=params.max_contexts,
             max_states=params.max_states,
         )
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -302,7 +310,11 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}"
+            ): hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -375,7 +387,9 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -401,7 +415,9 @@ def save_results(
         # we compute CER for aishell dataset.
         results_char = []
         for res in results:
-            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
+            results_char.append(
+                (res[0], list("".join(res[1])), list("".join(res[2])))
+            )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results_char, enable_log=True
@@ -412,7 +428,8 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -456,7 +473,9 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        params.suffix += (
+            f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        )
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -485,7 +504,8 @@ def main():
         ]
         if len(filenames) == 0:
             raise ValueError(
-                f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                f"No checkpoints found for"
+                f" --iter {params.iter}, --avg {params.avg}"
             )
         elif len(filenames) < params.avg:
             raise ValueError(
diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/export.py b/egs/aishell/ASR/pruned_transducer_stateless2/export.py
index 5f6888db4..feababdd2 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless2/export.py
@@ -50,7 +50,11 @@ from pathlib import Path
 import torch
 from train import add_model_arguments, get_params, get_transducer_model
 
-from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
+from icefall.checkpoint import (
+    average_checkpoints,
+    find_checkpoints,
+    load_checkpoint,
+)
 from icefall.lexicon import Lexicon
 from icefall.utils import str2bool
 
@@ -83,11 +87,9 @@ def get_parser():
         "--avg",
         type=int,
         default=15,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch' and '--iter'"
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
     )
 
     parser.add_argument(
@@ -118,7 +120,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     add_model_arguments(parser)
@@ -154,7 +157,8 @@ def main():
         ]
         if len(filenames) == 0:
             raise ValueError(
-                f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                f"No checkpoints found for"
+                f" --iter {params.iter}, --avg {params.avg}"
             )
         elif len(filenames) < params.avg:
             raise ValueError(
@@ -187,7 +191,9 @@ def main():
         model.__class__.forward = torch.jit.ignore(model.__class__.forward)
         logging.info("Using torch.jit.script")
         model = torch.jit.script(model)
-        filename = params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt"
+        filename = (
+            params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt"
+        )
         model.save(str(filename))
         logging.info(f"Saved to {filename}")
     else:
@@ -195,14 +201,17 @@ def main():
         # Save it using a format so that it can be loaded
         # by :func:`load_checkpoint`
         filename = (
-            params.exp_dir / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt"
+            params.exp_dir
+            / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt"
         )
         torch.save({"model": model.state_dict()}, str(filename))
         logging.info(f"Saved to {filename}")
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py
index f754a7b9e..3c38e5db7 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py
@@ -87,11 +87,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -117,12 +115,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     parser.add_argument(
@@ -169,16 +165,15 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
         type=int,
         default=1,
-        help=(
-            "Maximum number of symbols per frame. "
-            "Use only when --method is greedy_search"
-        ),
+        help="Maximum number of symbols per frame. "
+        "Use only when --method is greedy_search",
     )
 
     add_model_arguments(parser)
@@ -201,9 +196,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -260,9 +256,13 @@ def main():
     feature_lens = [f.size(0) for f in features]
     feature_lens = torch.tensor(feature_lens, device=device)
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
 
-    encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=features, x_lens=feature_lens
+    )
 
     num_waves = encoder_out.size(0)
     hyp_list = []
@@ -310,7 +310,9 @@ def main():
                     beam=params.beam_size,
                 )
             else:
-                raise ValueError(f"Unsupported decoding method: {params.method}")
+                raise ValueError(
+                    f"Unsupported decoding method: {params.method}"
+                )
             hyp_list.append(hyp)
 
     hyps = []
@@ -327,7 +329,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/train.py b/egs/aishell/ASR/pruned_transducer_stateless2/train.py
index 66ca23035..97d892754 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless2/train.py
@@ -49,6 +49,7 @@ import optim
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
+
 from asr_datamodule import AishellAsrDataModule
 from conformer import Conformer
 from decoder import Decoder
@@ -74,7 +75,9 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+LRSchedulerType = Union[
+    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
+]
 
 
 def add_model_arguments(parser: argparse.ArgumentParser):
@@ -200,7 +203,8 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need to be changed.",
+        help="The initial learning rate.  This value should not need "
+        "to be changed.",
     )
 
     parser.add_argument(
@@ -223,45 +227,42 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
         "--prune-range",
         type=int,
         default=5,
-        help=(
-            "The prune range for rnnt loss, it means how many symbols(context)"
-            "we are using to compute the loss"
-        ),
+        help="The prune range for rnnt loss, it means how many symbols(context)"
+        "we are using to compute the loss",
     )
 
     parser.add_argument(
         "--lm-scale",
         type=float,
         default=0.25,
-        help=(
-            "The scale to smooth the loss with lm (output of prediction network) part."
-        ),
+        help="The scale to smooth the loss with lm "
+        "(output of prediction network) part.",
     )
 
     parser.add_argument(
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)part.",
+        help="The scale to smooth the loss with am (output of encoder network)"
+        "part.",
     )
 
     parser.add_argument(
         "--simple-loss-scale",
         type=float,
         default=0.5,
-        help=(
-            "To get pruning ranges, we will calculate a simple version"
-            "loss(joiner is just addition), this simple loss also uses for"
-            "training (as a regularization item). We will scale the simple loss"
-            "with this parameter before adding to the final loss."
-        ),
+        help="To get pruning ranges, we will calculate a simple version"
+        "loss(joiner is just addition), this simple loss also uses for"
+        "training (as a regularization item). We will scale the simple loss"
+        "with this parameter before adding to the final loss.",
     )
 
     parser.add_argument(
@@ -560,7 +561,11 @@ def compute_loss(
      warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
-    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+    device = (
+        model.device
+        if isinstance(model, DDP)
+        else next(model.parameters()).device
+    )
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
@@ -588,16 +593,23 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+            0.0
+            if warmup < 1.0
+            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+        )
+        loss = (
+            params.simple_loss_scale * simple_loss
+            + pruned_loss_scale * pruned_loss
         )
-        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
 
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -713,7 +725,9 @@ def train_one_epoch(
             scaler.update()
             optimizer.zero_grad()
         except:  # noqa
-            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
+            display_and_save_batch(
+                batch, params=params, graph_compiler=graph_compiler
+            )
             raise
 
         if params.print_diagnostics and batch_idx == 5:
@@ -877,7 +891,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            2 ** 22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
@@ -1015,7 +1029,9 @@ def scan_pessimistic_batches_for_oom(
                     f"Failing criterion: {criterion} "
                     f"(={crit_values[criterion]}) ..."
                 )
-            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
+            display_and_save_batch(
+                batch, params=params, graph_compiler=graph_compiler
+            )
             raise
 
 
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/decode.py b/egs/aishell/ASR/pruned_transducer_stateless3/decode.py
index 6c505940d..d159e420b 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless3/decode.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/decode.py
@@ -121,24 +121,20 @@ def get_parser():
         "--avg",
         type=int,
         default=15,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch' and '--iter'"
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
     )
 
     parser.add_argument(
         "--use-averaged-model",
         type=str2bool,
         default=False,
-        help=(
-            "Whether to load averaged model. Currently it only supports "
-            "using --epoch. If True, it would decode with the averaged model "
-            "over the epoch range from `epoch-avg` (excluded) to `epoch`."
-            "Actually only the models with epoch number of `epoch-avg` and "
-            "`epoch` are loaded for averaging. "
-        ),
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
     )
 
     parser.add_argument(
@@ -206,7 +202,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -266,7 +263,9 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )
 
     if params.decoding_method == "fast_beam_search":
         hyp_tokens = fast_beam_search_one_best(
@@ -278,7 +277,10 @@ def decode_one_batch(
             max_contexts=params.max_contexts,
             max_states=params.max_states,
         )
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -322,7 +324,11 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}"
+            ): hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -395,7 +401,9 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -421,7 +429,9 @@ def save_results(
         # we compute CER for aishell dataset.
         results_char = []
         for res in results:
-            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
+            results_char.append(
+                (res[0], list("".join(res[1])), list("".join(res[2])))
+            )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results_char, enable_log=True
@@ -432,7 +442,8 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tCER", file=f)
@@ -477,7 +488,9 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        params.suffix += (
+            f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        )
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -505,12 +518,13 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg
-            ]
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg:
                 raise ValueError(
@@ -537,12 +551,13 @@ def main():
             )
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg + 1
-            ]
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg + 1]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg + 1:
                 raise ValueError(
@@ -571,7 +586,7 @@ def main():
             filename_start = f"{params.exp_dir}/epoch-{start}.pt"
             filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
             logging.info(
-                "Calculating the averaged model over epoch range from "
+                f"Calculating the averaged model over epoch range from "
                 f"{start} (excluded) to {params.epoch}"
             )
             model.to(device)
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/export.py b/egs/aishell/ASR/pruned_transducer_stateless3/export.py
index e5a5d7c77..566902a85 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless3/export.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/export.py
@@ -88,24 +88,20 @@ def get_parser():
         "--avg",
         type=int,
         default=15,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch' and '--iter'"
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
     )
 
     parser.add_argument(
         "--use-averaged-model",
         type=str2bool,
         default=True,
-        help=(
-            "Whether to load averaged model. Currently it only supports "
-            "using --epoch. If True, it would decode with the averaged model "
-            "over the epoch range from `epoch-avg` (excluded) to `epoch`."
-            "Actually only the models with epoch number of `epoch-avg` and "
-            "`epoch` are loaded for averaging. "
-        ),
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
     )
 
     parser.add_argument(
@@ -136,7 +132,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     add_model_arguments(parser)
@@ -169,12 +166,13 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg
-            ]
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg:
                 raise ValueError(
@@ -197,12 +195,13 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg + 1
-            ]
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg + 1]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg + 1:
                 raise ValueError(
@@ -230,7 +229,7 @@ def main():
             filename_start = f"{params.exp_dir}/epoch-{start}.pt"
             filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
             logging.info(
-                "Calculating the averaged model over epoch range from "
+                f"Calculating the averaged model over epoch range from "
                 f"{start} (excluded) to {params.epoch}"
             )
             model.to(device)
@@ -253,7 +252,9 @@ def main():
         model.__class__.forward = torch.jit.ignore(model.__class__.forward)
         logging.info("Using torch.jit.script")
         model = torch.jit.script(model)
-        filename = params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt"
+        filename = (
+            params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt"
+        )
         model.save(str(filename))
         logging.info(f"Saved to {filename}")
     else:
@@ -261,14 +262,17 @@ def main():
         # Save it using a format so that it can be loaded
         # by :func:`load_checkpoint`
         filename = (
-            params.exp_dir / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt"
+            params.exp_dir
+            / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt"
         )
         torch.save({"model": model.state_dict()}, str(filename))
         logging.info(f"Saved to {filename}")
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/model.py b/egs/aishell/ASR/pruned_transducer_stateless3/model.py
index a4dda0d6d..e150e8230 100644
--- a/egs/aishell/ASR/pruned_transducer_stateless3/model.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/model.py
@@ -84,7 +84,9 @@ class Transducer(nn.Module):
         self.decoder_datatang = decoder_datatang
         self.joiner_datatang = joiner_datatang
 
-        self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5)
+        self.simple_am_proj = ScaledLinear(
+            encoder_dim, vocab_size, initial_speed=0.5
+        )
         self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)
 
         if decoder_datatang is not None:
@@ -177,7 +179,9 @@ class Transducer(nn.Module):
         y_padded = y.pad(mode="constant", padding_value=0)
 
         y_padded = y_padded.to(torch.int64)
-        boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device)
+        boundary = torch.zeros(
+            (x.size(0), 4), dtype=torch.int64, device=x.device
+        )
         boundary[:, 2] = y_lens
         boundary[:, 3] = encoder_out_lens
 
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py
index 109879952..04a0a882a 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py
@@ -87,11 +87,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -117,12 +115,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     parser.add_argument(
@@ -169,16 +165,15 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
         type=int,
         default=1,
-        help=(
-            "Maximum number of symbols per frame. "
-            "Use only when --method is greedy_search"
-        ),
+        help="Maximum number of symbols per frame. "
+        "Use only when --method is greedy_search",
     )
 
     add_model_arguments(parser)
@@ -201,9 +196,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -261,9 +257,13 @@ def main():
     feature_lens = [f.size(0) for f in features]
     feature_lens = torch.tensor(feature_lens, device=device)
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
 
-    encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=features, x_lens=feature_lens
+    )
 
     num_waves = encoder_out.size(0)
     hyp_list = []
@@ -311,7 +311,9 @@ def main():
                     beam=params.beam_size,
                 )
             else:
-                raise ValueError(f"Unsupported decoding method: {params.method}")
+                raise ValueError(
+                    f"Unsupported decoding method: {params.method}"
+                )
             hyp_list.append(hyp)
 
     hyps = []
@@ -328,7 +330,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/train.py b/egs/aishell/ASR/pruned_transducer_stateless3/train.py
index b24f533ff..feaef5cf6 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless3/train.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/train.py
@@ -96,7 +96,9 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+LRSchedulerType = Union[
+    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
+]
 
 
 def add_model_arguments(parser: argparse.ArgumentParser):
@@ -222,7 +224,8 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need to be changed.",
+        help="The initial learning rate.  This value should not need "
+        "to be changed.",
     )
 
     parser.add_argument(
@@ -245,45 +248,42 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
         "--prune-range",
         type=int,
         default=5,
-        help=(
-            "The prune range for rnnt loss, it means how many symbols(context)"
-            "we are using to compute the loss"
-        ),
+        help="The prune range for rnnt loss, it means how many symbols(context)"
+        "we are using to compute the loss",
     )
 
     parser.add_argument(
         "--lm-scale",
         type=float,
         default=0.25,
-        help=(
-            "The scale to smooth the loss with lm (output of prediction network) part."
-        ),
+        help="The scale to smooth the loss with lm "
+        "(output of prediction network) part.",
     )
 
     parser.add_argument(
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)part.",
+        help="The scale to smooth the loss with am (output of encoder network)"
+        "part.",
     )
 
     parser.add_argument(
         "--simple-loss-scale",
         type=float,
         default=0.5,
-        help=(
-            "To get pruning ranges, we will calculate a simple version"
-            "loss(joiner is just addition), this simple loss also uses for"
-            "training (as a regularization item). We will scale the simple loss"
-            "with this parameter before adding to the final loss."
-        ),
+        help="To get pruning ranges, we will calculate a simple version"
+        "loss(joiner is just addition), this simple loss also uses for"
+        "training (as a regularization item). We will scale the simple loss"
+        "with this parameter before adding to the final loss.",
     )
 
     parser.add_argument(
@@ -635,7 +635,11 @@ def compute_loss(
      warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
-    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+    device = (
+        model.device
+        if isinstance(model, DDP)
+        else next(model.parameters()).device
+    )
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
@@ -666,16 +670,23 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+            0.0
+            if warmup < 1.0
+            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+        )
+        loss = (
+            params.simple_loss_scale * simple_loss
+            + pruned_loss_scale * pruned_loss
         )
-        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
 
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -813,7 +824,9 @@ def train_one_epoch(
                 )
             # summary stats
             if datatang_train_dl is not None:
-                tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+                tot_loss = (
+                    tot_loss * (1 - 1 / params.reset_interval)
+                ) + loss_info
 
             if aishell:
                 aishell_tot_loss = (
@@ -834,7 +847,9 @@ def train_one_epoch(
             scaler.update()
             optimizer.zero_grad()
         except:  # noqa
-            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
+            display_and_save_batch(
+                batch, params=params, graph_compiler=graph_compiler
+            )
             raise
 
         if params.print_diagnostics and batch_idx == 5:
@@ -877,7 +892,9 @@ def train_one_epoch(
             cur_lr = scheduler.get_last_lr()[0]
             if datatang_train_dl is not None:
                 datatang_str = f"datatang_tot_loss[{datatang_tot_loss}], "
-                tot_loss_str = f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+                tot_loss_str = (
+                    f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+                )
             else:
                 tot_loss_str = ""
                 datatang_str = ""
@@ -1050,7 +1067,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            2 ** 22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
@@ -1059,7 +1076,9 @@ def run(rank, world_size, args):
     train_cuts = filter_short_and_long_utterances(train_cuts)
 
     if args.enable_musan:
-        cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz")
+        cuts_musan = load_manifest(
+            Path(args.manifest_dir) / "musan_cuts.jsonl.gz"
+        )
     else:
         cuts_musan = None
 
@@ -1074,7 +1093,9 @@ def run(rank, world_size, args):
     if params.datatang_prob > 0:
         datatang = AIDatatang200zh(manifest_dir=args.manifest_dir)
         train_datatang_cuts = datatang.train_cuts()
-        train_datatang_cuts = filter_short_and_long_utterances(train_datatang_cuts)
+        train_datatang_cuts = filter_short_and_long_utterances(
+            train_datatang_cuts
+        )
         train_datatang_cuts = train_datatang_cuts.repeat(times=None)
         datatang_train_dl = asr_datamodule.train_dataloaders(
             train_datatang_cuts,
@@ -1228,7 +1249,9 @@ def scan_pessimistic_batches_for_oom(
                     f"Failing criterion: {criterion} "
                     f"(={crit_values[criterion]}) ..."
                 )
-            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
+            display_and_save_batch(
+                batch, params=params, graph_compiler=graph_compiler
+            )
             raise
 
 
diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py
index 12ae6e7d4..d24ba6bb7 100644
--- a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -64,12 +64,10 @@ class AishellAsrDataModule:
     def add_arguments(cls, parser: argparse.ArgumentParser):
         group = parser.add_argument_group(
             title="ASR data related options",
-            description=(
-                "These options are used for the preparation of "
-                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-                "effective batch sizes, sampling strategies, applied data "
-                "augmentations, etc."
-            ),
+            description="These options are used for the preparation of "
+            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+            "effective batch sizes, sampling strategies, applied data "
+            "augmentations, etc.",
         )
         group.add_argument(
             "--manifest-dir",
@@ -81,74 +79,59 @@ class AishellAsrDataModule:
             "--max-duration",
             type=int,
             default=200.0,
-            help=(
-                "Maximum pooled recordings duration (seconds) in a "
-                "single batch. You can reduce it if it causes CUDA OOM."
-            ),
+            help="Maximum pooled recordings duration (seconds) in a "
+            "single batch. You can reduce it if it causes CUDA OOM.",
         )
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, the batches will come from buckets of "
-                "similar duration (saves padding frames)."
-            ),
+            help="When enabled, the batches will come from buckets of "
+            "similar duration (saves padding frames).",
         )
         group.add_argument(
             "--num-buckets",
             type=int,
             default=30,
-            help=(
-                "The number of buckets for the DynamicBucketingSampler"
-                "(you might want to increase it for larger datasets)."
-            ),
+            help="The number of buckets for the DynamicBucketingSampler"
+            "(you might want to increase it for larger datasets).",
         )
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, utterances (cuts) will be concatenated "
-                "to minimize the amount of padding."
-            ),
+            help="When enabled, utterances (cuts) will be concatenated "
+            "to minimize the amount of padding.",
         )
         group.add_argument(
             "--duration-factor",
             type=float,
             default=1.0,
-            help=(
-                "Determines the maximum duration of a concatenated cut "
-                "relative to the duration of the longest cut in a batch."
-            ),
+            help="Determines the maximum duration of a concatenated cut "
+            "relative to the duration of the longest cut in a batch.",
         )
         group.add_argument(
             "--gap",
             type=float,
             default=1.0,
-            help=(
-                "The amount of padding (in seconds) inserted between "
-                "concatenated cuts. This padding is filled with noise when "
-                "noise augmentation is used."
-            ),
+            help="The amount of padding (in seconds) inserted between "
+            "concatenated cuts. This padding is filled with noise when "
+            "noise augmentation is used.",
         )
         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, use on-the-fly cut mixing and feature "
-                "extraction. Will drop existing precomputed feature manifests "
-                "if available."
-            ),
+            help="When enabled, use on-the-fly cut mixing and feature "
+            "extraction. Will drop existing precomputed feature manifests "
+            "if available.",
         )
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled (=default), the examples will be shuffled for each epoch."
-            ),
+            help="When enabled (=default), the examples will be "
+            "shuffled for each epoch.",
         )
         group.add_argument(
             "--drop-last",
@@ -160,18 +143,17 @@ class AishellAsrDataModule:
             "--return-cuts",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, each batch will have the "
-                "field: batch['supervisions']['cut'] with the cuts that "
-                "were used to construct it."
-            ),
+            help="When enabled, each batch will have the "
+            "field: batch['supervisions']['cut'] with the cuts that "
+            "were used to construct it.",
         )
 
         group.add_argument(
             "--num-workers",
             type=int,
             default=2,
-            help="The number of training dataloader workers that collect the batches.",
+            help="The number of training dataloader workers that "
+            "collect the batches.",
         )
 
         group.add_argument(
@@ -185,40 +167,40 @@ class AishellAsrDataModule:
             "--spec-aug-time-warp-factor",
             type=int,
             default=80,
-            help=(
-                "Used only when --enable-spec-aug is True. "
-                "It specifies the factor for time warping in SpecAugment. "
-                "Larger values mean more warping. "
-                "A value less than 1 means to disable time warp."
-            ),
+            help="Used only when --enable-spec-aug is True. "
+            "It specifies the factor for time warping in SpecAugment. "
+            "Larger values mean more warping. "
+            "A value less than 1 means to disable time warp.",
         )
 
         group.add_argument(
             "--enable-musan",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, select noise from MUSAN and mix it"
-                "with training dataset. "
-            ),
+            help="When enabled, select noise from MUSAN and mix it"
+            "with training dataset. ",
         )
 
     def train_dataloaders(self, cuts_train: CutSet) -> DataLoader:
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
+        cuts_musan = load_manifest(
+            self.args.manifest_dir / "musan_cuts.jsonl.gz"
+        )
 
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+                CutMix(
+                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
+                )
             )
         else:
             logging.info("Disable MUSAN")
 
         if self.args.concatenate_cuts:
             logging.info(
-                "Using cut concatenation with duration factor "
+                f"Using cut concatenation with duration factor "
                 f"{self.args.duration_factor} and gap {self.args.gap}."
             )
             # Cut concatenation should be the first transform in the list,
@@ -233,7 +215,9 @@ class AishellAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+            logging.info(
+                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
+            )
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -276,7 +260,9 @@ class AishellAsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -322,7 +308,9 @@ class AishellAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 return_cuts=self.args.return_cuts,
             )
         else:
@@ -378,9 +366,13 @@ class AishellAsrDataModule:
     @lru_cache()
     def valid_cuts(self) -> CutSet:
         logging.info("About to get dev cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "aishell_cuts_dev.jsonl.gz")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "aishell_cuts_dev.jsonl.gz"
+        )
 
     @lru_cache()
     def test_cuts(self) -> List[CutSet]:
         logging.info("About to get test cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "aishell_cuts_test.jsonl.gz")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "aishell_cuts_test.jsonl.gz"
+        )
diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/decode.py b/egs/aishell/ASR/tdnn_lstm_ctc/decode.py
index 8ef247438..66b734fc4 100755
--- a/egs/aishell/ASR/tdnn_lstm_ctc/decode.py
+++ b/egs/aishell/ASR/tdnn_lstm_ctc/decode.py
@@ -49,19 +49,16 @@ def get_parser():
         "--epoch",
         type=int,
         default=19,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=5,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
     parser.add_argument(
         "--method",
@@ -268,7 +265,9 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -290,7 +289,9 @@ def save_results(
         # We compute CER for aishell dataset.
         results_char = []
         for res in results:
-            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
+            results_char.append(
+                (res[0], list("".join(res[1])), list("".join(res[2])))
+            )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(f, f"{test_set_name}-{key}", results_char)
             test_set_wers[key] = wer
@@ -334,7 +335,9 @@ def main():
 
     logging.info(f"device: {device}")
 
-    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
+    HLG = k2.Fsa.from_dict(
+        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
+    )
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
 
@@ -359,7 +362,9 @@ def main():
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
+        torch.save(
+            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
+        )
 
     model.to(device)
     model.eval()
@@ -387,7 +392,9 @@ def main():
             lexicon=lexicon,
         )
 
-        save_results(params=params, test_set_name=test_set, results_dict=results_dict)
+        save_results(
+            params=params, test_set_name=test_set, results_dict=results_dict
+        )
 
     logging.info("Done!")
 
diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/model.py b/egs/aishell/ASR/tdnn_lstm_ctc/model.py
index 1731e1ebe..5e04c11b4 100644
--- a/egs/aishell/ASR/tdnn_lstm_ctc/model.py
+++ b/egs/aishell/ASR/tdnn_lstm_ctc/model.py
@@ -66,7 +66,10 @@ class TdnnLstm(nn.Module):
             nn.BatchNorm1d(num_features=500, affine=False),
         )
         self.lstms = nn.ModuleList(
-            [nn.LSTM(input_size=500, hidden_size=500, num_layers=1) for _ in range(5)]
+            [
+                nn.LSTM(input_size=500, hidden_size=500, num_layers=1)
+                for _ in range(5)
+            ]
         )
         self.lstm_bnorms = nn.ModuleList(
             [nn.BatchNorm1d(num_features=500, affine=False) for _ in range(5)]
diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py b/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py
index 52f9410cf..9bd810809 100644
--- a/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py
+++ b/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py
@@ -41,11 +41,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -55,7 +53,9 @@ def get_parser():
         help="Path to words.txt",
     )
 
-    parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.")
+    parser.add_argument(
+        "--HLG", type=str, required=True, help="Path to HLG.pt."
+    )
 
     parser.add_argument(
         "--method",
@@ -71,12 +71,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     return parser
@@ -114,9 +112,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -174,7 +173,9 @@ def main():
     logging.info("Decoding started")
     features = fbank(waves)
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
     features = features.permute(0, 2, 1)  # now features is [N, C, T]
 
     with torch.no_grad():
@@ -218,7 +219,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/train.py b/egs/aishell/ASR/tdnn_lstm_ctc/train.py
index e574cf89b..7619b0551 100755
--- a/egs/aishell/ASR/tdnn_lstm_ctc/train.py
+++ b/egs/aishell/ASR/tdnn_lstm_ctc/train.py
@@ -49,7 +49,12 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.graph_compiler import CtcTrainingGraphCompiler
 from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, encode_supervisions, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    encode_supervisions,
+    setup_logger,
+    str2bool,
+)
 
 
 def get_parser():
diff --git a/egs/aishell/ASR/transducer_stateless/beam_search.py b/egs/aishell/ASR/transducer_stateless/beam_search.py
index de0a8d0f5..9ed9b2ad1 100644
--- a/egs/aishell/ASR/transducer_stateless/beam_search.py
+++ b/egs/aishell/ASR/transducer_stateless/beam_search.py
@@ -47,9 +47,9 @@ def greedy_search(
 
     device = model.device
 
-    decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape(
-        1, context_size
-    )
+    decoder_input = torch.tensor(
+        [blank_id] * context_size, device=device
+    ).reshape(1, context_size)
 
     decoder_out = model.decoder(decoder_input, need_pad=False)
 
@@ -81,9 +81,9 @@ def greedy_search(
         y = logits.argmax().item()
         if y != blank_id:
             hyp.append(y)
-            decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape(
-                1, context_size
-            )
+            decoder_input = torch.tensor(
+                [hyp[-context_size:]], device=device
+            ).reshape(1, context_size)
 
             decoder_out = model.decoder(decoder_input, need_pad=False)
 
@@ -157,7 +157,9 @@ class HypothesisList(object):
 
         """
         if length_norm:
-            return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys))
+            return max(
+                self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)
+            )
         else:
             return max(self._data.values(), key=lambda hyp: hyp.log_prob)
 
@@ -244,9 +246,9 @@ def beam_search(
 
     device = model.device
 
-    decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape(
-        1, context_size
-    )
+    decoder_input = torch.tensor(
+        [blank_id] * context_size, device=device
+    ).reshape(1, context_size)
 
     decoder_out = model.decoder(decoder_input, need_pad=False)
 
diff --git a/egs/aishell/ASR/transducer_stateless/conformer.py b/egs/aishell/ASR/transducer_stateless/conformer.py
index e26c6c385..64114253d 100644
--- a/egs/aishell/ASR/transducer_stateless/conformer.py
+++ b/egs/aishell/ASR/transducer_stateless/conformer.py
@@ -155,7 +155,9 @@ class ConformerEncoderLayer(nn.Module):
         normalize_before: bool = True,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
-        self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
+        self.self_attn = RelPositionMultiheadAttention(
+            d_model, nhead, dropout=0.0
+        )
 
         self.feed_forward = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
@@ -173,14 +175,18 @@ class ConformerEncoderLayer(nn.Module):
 
         self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
 
-        self.norm_ff_macaron = nn.LayerNorm(d_model)  # for the macaron style FNN module
+        self.norm_ff_macaron = nn.LayerNorm(
+            d_model
+        )  # for the macaron style FNN module
         self.norm_ff = nn.LayerNorm(d_model)  # for the FNN module
         self.norm_mha = nn.LayerNorm(d_model)  # for the MHA module
 
         self.ff_scale = 0.5
 
         self.norm_conv = nn.LayerNorm(d_model)  # for the CNN module
-        self.norm_final = nn.LayerNorm(d_model)  # for the final output of the block
+        self.norm_final = nn.LayerNorm(
+            d_model
+        )  # for the final output of the block
 
         self.dropout = nn.Dropout(dropout)
 
@@ -214,7 +220,9 @@ class ConformerEncoderLayer(nn.Module):
         residual = src
         if self.normalize_before:
             src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
+        src = residual + self.ff_scale * self.dropout(
+            self.feed_forward_macaron(src)
+        )
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)
 
@@ -333,7 +341,9 @@ class RelPositionalEncoding(torch.nn.Module):
 
     """
 
-    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
+    def __init__(
+        self, d_model: int, dropout_rate: float, max_len: int = 5000
+    ) -> None:
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
@@ -349,7 +359,9 @@ class RelPositionalEncoding(torch.nn.Module):
             # the length of self.pe is 2 * input_len - 1
             if self.pe.size(1) >= x.size(1) * 2 - 1:
                 # Note: TorchScript doesn't implement operator== for torch.Device
-                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(
+                    x.device
+                ):
                     self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                 return
         # Suppose `i` means to the position of query vector and `j` means the
@@ -619,9 +631,9 @@ class RelPositionMultiheadAttention(nn.Module):
 
         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
-                3, dim=-1
-            )
+            q, k, v = nn.functional.linear(
+                query, in_proj_weight, in_proj_bias
+            ).chunk(3, dim=-1)
 
         elif torch.equal(key, value):
             # encoder-decoder attention
@@ -689,25 +701,33 @@ class RelPositionMultiheadAttention(nn.Module):
             if attn_mask.dim() == 2:
                 attn_mask = attn_mask.unsqueeze(0)
                 if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
-                    raise RuntimeError("The size of the 2D attn_mask is not correct.")
+                    raise RuntimeError(
+                        "The size of the 2D attn_mask is not correct."
+                    )
             elif attn_mask.dim() == 3:
                 if list(attn_mask.size()) != [
                     bsz * num_heads,
                     query.size(0),
                     key.size(0),
                 ]:
-                    raise RuntimeError("The size of the 3D attn_mask is not correct.")
+                    raise RuntimeError(
+                        "The size of the 3D attn_mask is not correct."
+                    )
             else:
                 raise RuntimeError(
-                    "attn_mask's dimension {} is not supported".format(attn_mask.dim())
+                    "attn_mask's dimension {} is not supported".format(
+                        attn_mask.dim()
+                    )
                 )
             # attn_mask's dim is 3 now.
 
         # convert ByteTensor key_padding_mask to bool
-        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
+        if (
+            key_padding_mask is not None
+            and key_padding_mask.dtype == torch.uint8
+        ):
             warnings.warn(
-                "Byte tensor for key_padding_mask is deprecated. Use bool tensor"
-                " instead."
+                "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
             )
             key_padding_mask = key_padding_mask.to(torch.bool)
 
@@ -744,7 +764,9 @@ class RelPositionMultiheadAttention(nn.Module):
         # first compute matrix a and matrix c
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
-        matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(
+            q_with_bias_u, k
+        )  # (batch, head, time1, time2)
 
         # compute matrix b and matrix d
         matrix_bd = torch.matmul(
@@ -756,7 +778,9 @@ class RelPositionMultiheadAttention(nn.Module):
             matrix_ac + matrix_bd
         ) * scaling  # (batch, head, time1, time2)
 
-        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
+        attn_output_weights = attn_output_weights.view(
+            bsz * num_heads, tgt_len, -1
+        )
 
         assert list(attn_output_weights.size()) == [
             bsz * num_heads,
@@ -790,9 +814,13 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output = torch.bmm(attn_output_weights, v)
         assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
         attn_output = (
-            attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+            attn_output.transpose(0, 1)
+            .contiguous()
+            .view(tgt_len, bsz, embed_dim)
+        )
+        attn_output = nn.functional.linear(
+            attn_output, out_proj_weight, out_proj_bias
         )
-        attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
 
         if need_weights:
             # average attention weights over heads
@@ -815,7 +843,9 @@ class ConvolutionModule(nn.Module):
 
     """
 
-    def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None:
+    def __init__(
+        self, channels: int, kernel_size: int, bias: bool = True
+    ) -> None:
         """Construct an ConvolutionModule object."""
         super(ConvolutionModule, self).__init__()
         # kernerl_size should be a odd number for 'SAME' padding
diff --git a/egs/aishell/ASR/transducer_stateless/decode.py b/egs/aishell/ASR/transducer_stateless/decode.py
index 1f7bb14e1..780b0c4bb 100755
--- a/egs/aishell/ASR/transducer_stateless/decode.py
+++ b/egs/aishell/ASR/transducer_stateless/decode.py
@@ -52,19 +52,16 @@ def get_parser():
         "--epoch",
         type=int,
         default=30,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=10,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -102,7 +99,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -229,7 +227,9 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )
     hyps = []
     batch_size = encoder_out.size(0)
 
@@ -248,7 +248,9 @@ def decode_one_batch(
                 model=model, encoder_out=encoder_out_i, beam=params.beam_size
             )
         else:
-            raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
+            raise ValueError(
+                f"Unsupported decoding method: {params.decoding_method}"
+            )
         hyps.append([lexicon.token_table[i] for i in hyp])
 
     if params.decoding_method == "greedy_search":
@@ -317,7 +319,9 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -342,7 +346,9 @@ def save_results(
         # we compute CER for aishell dataset.
         results_char = []
         for res in results:
-            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
+            results_char.append(
+                (res[0], list("".join(res[1])), list("".join(res[2])))
+            )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results_char, enable_log=True
@@ -353,7 +359,8 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tCER", file=f)
@@ -423,7 +430,9 @@ def main():
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
+        torch.save(
+            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
+        )
         return
 
     model.to(device)
diff --git a/egs/aishell/ASR/transducer_stateless/decoder.py b/egs/aishell/ASR/transducer_stateless/decoder.py
index 70e9e6c96..c2c6552a9 100644
--- a/egs/aishell/ASR/transducer_stateless/decoder.py
+++ b/egs/aishell/ASR/transducer_stateless/decoder.py
@@ -86,7 +86,9 @@ class Decoder(nn.Module):
         if self.context_size > 1:
             embedding_out = embedding_out.permute(0, 2, 1)
             if need_pad is True:
-                embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
+                embedding_out = F.pad(
+                    embedding_out, pad=(self.context_size - 1, 0)
+                )
             else:
                 # During inference time, there is no need to do extra padding
                 # as we only need one output
diff --git a/egs/aishell/ASR/transducer_stateless/export.py b/egs/aishell/ASR/transducer_stateless/export.py
index e35b26fe0..4c6519b96 100755
--- a/egs/aishell/ASR/transducer_stateless/export.py
+++ b/egs/aishell/ASR/transducer_stateless/export.py
@@ -69,20 +69,17 @@ def get_parser():
         "--epoch",
         type=int,
         default=20,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=10,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -113,7 +110,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     return parser
@@ -245,7 +243,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell/ASR/transducer_stateless/model.py b/egs/aishell/ASR/transducer_stateless/model.py
index 591bbe44f..994305fc1 100644
--- a/egs/aishell/ASR/transducer_stateless/model.py
+++ b/egs/aishell/ASR/transducer_stateless/model.py
@@ -103,7 +103,9 @@ class Transducer(nn.Module):
         y_padded = y.pad(mode="constant", padding_value=0)
 
         y_padded = y_padded.to(torch.int64)
-        boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device)
+        boundary = torch.zeros(
+            (x.size(0), 4), dtype=torch.int64, device=x.device
+        )
         boundary[:, 2] = y_lens
         boundary[:, 3] = x_lens
 
diff --git a/egs/aishell/ASR/transducer_stateless/pretrained.py b/egs/aishell/ASR/transducer_stateless/pretrained.py
index 8effc9815..db89c4d67 100755
--- a/egs/aishell/ASR/transducer_stateless/pretrained.py
+++ b/egs/aishell/ASR/transducer_stateless/pretrained.py
@@ -73,11 +73,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -102,12 +100,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     parser.add_argument(
@@ -121,7 +117,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -214,9 +211,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -275,7 +273,9 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -319,7 +319,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell/ASR/transducer_stateless/train.py b/egs/aishell/ASR/transducer_stateless/train.py
index 62ffff473..d54157709 100755
--- a/egs/aishell/ASR/transducer_stateless/train.py
+++ b/egs/aishell/ASR/transducer_stateless/train.py
@@ -126,7 +126,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -388,7 +389,9 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -501,7 +504,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -620,7 +625,9 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
+            tb_writer.add_scalar(
+                "train/learning_rate", cur_lr, params.batch_idx_train
+            )
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/aishell/ASR/transducer_stateless/transformer.py b/egs/aishell/ASR/transducer_stateless/transformer.py
index b3ff153c1..e851dcc32 100644
--- a/egs/aishell/ASR/transducer_stateless/transformer.py
+++ b/egs/aishell/ASR/transducer_stateless/transformer.py
@@ -250,7 +250,9 @@ def _get_activation_fn(activation: str):
     elif activation == "gelu":
         return nn.functional.gelu
 
-    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
+    raise RuntimeError(
+        "activation should be relu/gelu, not {}".format(activation)
+    )
 
 
 class PositionalEncoding(nn.Module):
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py b/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py
index 76e209f06..838e53658 100644
--- a/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py
@@ -29,7 +29,10 @@ from lhotse.dataset import (
     K2SpeechRecognitionDataset,
     SpecAugment,
 )
-from lhotse.dataset.input_strategies import OnTheFlyFeatures, PrecomputedFeatures
+from lhotse.dataset.input_strategies import (
+    OnTheFlyFeatures,
+    PrecomputedFeatures,
+)
 from torch.utils.data import DataLoader
 
 from icefall.utils import str2bool
@@ -43,69 +46,59 @@ class AsrDataModule:
     def add_arguments(cls, parser: argparse.ArgumentParser):
         group = parser.add_argument_group(
             title="ASR data related options",
-            description=(
-                "These options are used for the preparation of "
-                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-                "effective batch sizes, sampling strategies, applied data "
-                "augmentations, etc."
-            ),
+            description="These options are used for the preparation of "
+            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+            "effective batch sizes, sampling strategies, applied data "
+            "augmentations, etc.",
         )
 
         group.add_argument(
             "--max-duration",
             type=int,
             default=200.0,
-            help=(
-                "Maximum pooled recordings duration (seconds) in a "
-                "single batch. You can reduce it if it causes CUDA OOM."
-            ),
+            help="Maximum pooled recordings duration (seconds) in a "
+            "single batch. You can reduce it if it causes CUDA OOM.",
         )
 
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, the batches will come from buckets of "
-                "similar duration (saves padding frames)."
-            ),
+            help="When enabled, the batches will come from buckets of "
+            "similar duration (saves padding frames).",
         )
 
         group.add_argument(
             "--num-buckets",
             type=int,
             default=30,
-            help=(
-                "The number of buckets for the DynamicBucketingSampler "
-                "(you might want to increase it for larger datasets)."
-            ),
+            help="The number of buckets for the DynamicBucketingSampler "
+            "(you might want to increase it for larger datasets).",
         )
 
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled (=default), the examples will be shuffled for each epoch."
-            ),
+            help="When enabled (=default), the examples will be "
+            "shuffled for each epoch.",
         )
 
         group.add_argument(
             "--return-cuts",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, each batch will have the "
-                "field: batch['supervisions']['cut'] with the cuts that "
-                "were used to construct it."
-            ),
+            help="When enabled, each batch will have the "
+            "field: batch['supervisions']['cut'] with the cuts that "
+            "were used to construct it.",
         )
 
         group.add_argument(
             "--num-workers",
             type=int,
             default=2,
-            help="The number of training dataloader workers that collect the batches.",
+            help="The number of training dataloader workers that "
+            "collect the batches.",
         )
 
         group.add_argument(
@@ -119,22 +112,18 @@ class AsrDataModule:
             "--spec-aug-time-warp-factor",
             type=int,
             default=80,
-            help=(
-                "Used only when --enable-spec-aug is True. "
-                "It specifies the factor for time warping in SpecAugment. "
-                "Larger values mean more warping. "
-                "A value less than 1 means to disable time warp."
-            ),
+            help="Used only when --enable-spec-aug is True. "
+            "It specifies the factor for time warping in SpecAugment. "
+            "Larger values mean more warping. "
+            "A value less than 1 means to disable time warp.",
         )
 
         group.add_argument(
             "--enable-musan",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, select noise from MUSAN and mix it"
-                "with training dataset. "
-            ),
+            help="When enabled, select noise from MUSAN and mix it"
+            "with training dataset. ",
         )
 
         group.add_argument(
@@ -148,11 +137,9 @@ class AsrDataModule:
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, use on-the-fly cut mixing and feature "
-                "extraction. Will drop existing precomputed feature manifests "
-                "if available. Used only in dev/test CutSet"
-            ),
+            help="When enabled, use on-the-fly cut mixing and feature "
+            "extraction. Will drop existing precomputed feature manifests "
+            "if available. Used only in dev/test CutSet",
         )
 
     def train_dataloaders(
@@ -175,7 +162,9 @@ class AsrDataModule:
         if cuts_musan is not None:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+                CutMix(
+                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
+                )
             )
         else:
             logging.info("Disable MUSAN")
@@ -184,7 +173,9 @@ class AsrDataModule:
 
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+            logging.info(
+                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
+            )
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -261,7 +252,9 @@ class AsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 return_cuts=self.args.return_cuts,
             )
         else:
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/decode.py b/egs/aishell/ASR/transducer_stateless_modified-2/decode.py
index fd4cb8385..ea3f94fd8 100755
--- a/egs/aishell/ASR/transducer_stateless_modified-2/decode.py
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/decode.py
@@ -93,19 +93,16 @@ def get_parser():
         "--epoch",
         type=int,
         default=30,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=10,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -173,7 +170,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -229,7 +227,9 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )
 
     if params.decoding_method == "fast_beam_search":
         hyp_tokens = fast_beam_search_one_best(
@@ -241,7 +241,10 @@ def decode_one_batch(
             max_contexts=params.max_contexts,
             max_states=params.max_states,
         )
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -285,7 +288,11 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}"
+            ): hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -358,7 +365,9 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -384,7 +393,9 @@ def save_results(
         # we compute CER for aishell dataset.
         results_char = []
         for res in results:
-            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
+            results_char.append(
+                (res[0], list("".join(res[1])), list("".join(res[2])))
+            )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results_char, enable_log=True
@@ -395,7 +406,8 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tCER", file=f)
@@ -436,7 +448,9 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        params.suffix += (
+            f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        )
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/export.py b/egs/aishell/ASR/transducer_stateless_modified-2/export.py
index 32481829c..3bd2ceb11 100755
--- a/egs/aishell/ASR/transducer_stateless_modified-2/export.py
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/export.py
@@ -68,20 +68,17 @@ def get_parser():
         "--epoch",
         type=int,
         default=20,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=10,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -112,7 +109,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     return parser
@@ -243,7 +241,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py b/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py
index 55701a007..a95a4bc52 100755
--- a/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py
@@ -87,11 +87,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -117,12 +115,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     parser.add_argument(
@@ -169,16 +165,15 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
         type=int,
         default=1,
-        help=(
-            "Maximum number of symbols per frame. "
-            "Use only when --method is greedy_search"
-        ),
+        help="Maximum number of symbols per frame. "
+        "Use only when --method is greedy_search",
     )
 
     return parser
@@ -199,9 +194,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -258,9 +254,13 @@ def main():
     feature_lens = [f.size(0) for f in features]
     feature_lens = torch.tensor(feature_lens, device=device)
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
 
-    encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=features, x_lens=feature_lens
+    )
 
     num_waves = encoder_out.size(0)
     hyp_list = []
@@ -308,7 +308,9 @@ def main():
                     beam=params.beam_size,
                 )
             else:
-                raise ValueError(f"Unsupported decoding method: {params.method}")
+                raise ValueError(
+                    f"Unsupported decoding method: {params.method}"
+                )
             hyp_list.append(hyp)
 
     hyps = []
@@ -325,7 +327,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/train.py b/egs/aishell/ASR/transducer_stateless_modified-2/train.py
index 8fb7d1e49..225d0d709 100755
--- a/egs/aishell/ASR/transducer_stateless_modified-2/train.py
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/train.py
@@ -149,7 +149,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -167,7 +168,8 @@ def get_parser():
         "--datatang-prob",
         type=float,
         default=0.2,
-        help="The probability to select a batch from the aidatatang_200zh dataset",
+        help="The probability to select a batch from the "
+        "aidatatang_200zh dataset",
     )
 
     return parser
@@ -447,7 +449,9 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -601,7 +605,9 @@ def train_one_epoch(
                     f"train/current_{prefix}_",
                     params.batch_idx_train,
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
                 aishell_tot_loss.write_summary(
                     tb_writer, "train/aishell_tot_", params.batch_idx_train
                 )
@@ -729,7 +735,9 @@ def run(rank, world_size, args):
     train_datatang_cuts = train_datatang_cuts.repeat(times=None)
 
     if args.enable_musan:
-        cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz")
+        cuts_musan = load_manifest(
+            Path(args.manifest_dir) / "musan_cuts.jsonl.gz"
+        )
     else:
         cuts_musan = None
 
@@ -768,7 +776,9 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
+            tb_writer.add_scalar(
+                "train/learning_rate", cur_lr, params.batch_idx_train
+            )
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/aishell/ASR/transducer_stateless_modified/decode.py b/egs/aishell/ASR/transducer_stateless_modified/decode.py
index 1e41942da..65fcda873 100755
--- a/egs/aishell/ASR/transducer_stateless_modified/decode.py
+++ b/egs/aishell/ASR/transducer_stateless_modified/decode.py
@@ -94,19 +94,16 @@ def get_parser():
         "--epoch",
         type=int,
         default=30,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=10,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -174,7 +171,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -233,7 +231,9 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )
 
     if params.decoding_method == "fast_beam_search":
         hyp_tokens = fast_beam_search_one_best(
@@ -245,7 +245,10 @@ def decode_one_batch(
             max_contexts=params.max_contexts,
             max_states=params.max_states,
         )
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -289,7 +292,11 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}"
+            ): hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -362,7 +369,9 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -388,7 +397,9 @@ def save_results(
         # we compute CER for aishell dataset.
         results_char = []
         for res in results:
-            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
+            results_char.append(
+                (res[0], list("".join(res[1])), list("".join(res[2])))
+            )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results_char, enable_log=True
@@ -399,7 +410,8 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tCER", file=f)
@@ -440,7 +452,9 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        params.suffix += (
+            f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        )
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
diff --git a/egs/aishell/ASR/transducer_stateless_modified/export.py b/egs/aishell/ASR/transducer_stateless_modified/export.py
index ca1d4bd4a..11335a834 100755
--- a/egs/aishell/ASR/transducer_stateless_modified/export.py
+++ b/egs/aishell/ASR/transducer_stateless_modified/export.py
@@ -68,20 +68,17 @@ def get_parser():
         "--epoch",
         type=int,
         default=20,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=10,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -112,7 +109,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     return parser
@@ -243,7 +241,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell/ASR/transducer_stateless_modified/pretrained.py b/egs/aishell/ASR/transducer_stateless_modified/pretrained.py
index 038090461..262e822c2 100755
--- a/egs/aishell/ASR/transducer_stateless_modified/pretrained.py
+++ b/egs/aishell/ASR/transducer_stateless_modified/pretrained.py
@@ -87,11 +87,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -117,12 +115,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     parser.add_argument(
@@ -169,16 +165,15 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
         type=int,
         default=1,
-        help=(
-            "Maximum number of symbols per frame. "
-            "Use only when --method is greedy_search"
-        ),
+        help="Maximum number of symbols per frame. "
+        "Use only when --method is greedy_search",
     )
 
     return parser
@@ -199,9 +194,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -258,9 +254,13 @@ def main():
     feature_lens = [f.size(0) for f in features]
     feature_lens = torch.tensor(feature_lens, device=device)
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
 
-    encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=features, x_lens=feature_lens
+    )
 
     num_waves = encoder_out.size(0)
     hyp_list = []
@@ -308,7 +308,9 @@ def main():
                     beam=params.beam_size,
                 )
             else:
-                raise ValueError(f"Unsupported decoding method: {params.method}")
+                raise ValueError(
+                    f"Unsupported decoding method: {params.method}"
+                )
             hyp_list.append(hyp)
 
     hyps = []
@@ -325,7 +327,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell/ASR/transducer_stateless_modified/train.py b/egs/aishell/ASR/transducer_stateless_modified/train.py
index 5f116f2bd..d3ffccafa 100755
--- a/egs/aishell/ASR/transducer_stateless_modified/train.py
+++ b/egs/aishell/ASR/transducer_stateless_modified/train.py
@@ -142,7 +142,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -413,7 +414,9 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -526,7 +529,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -652,7 +657,9 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
+            tb_writer.add_scalar(
+                "train/learning_rate", cur_lr, params.batch_idx_train
+            )
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/aishell2/ASR/local/__init__.py b/egs/aishell2/ASR/local/__init__.py
old mode 100644
new mode 100755
diff --git a/egs/aishell2/ASR/local/compute_fbank_aishell2.py b/egs/aishell2/ASR/local/compute_fbank_aishell2.py
index ec0c584ca..d8d3622bd 100755
--- a/egs/aishell2/ASR/local/compute_fbank_aishell2.py
+++ b/egs/aishell2/ASR/local/compute_fbank_aishell2.py
@@ -83,7 +83,9 @@ def compute_fbank_aishell2(num_mel_bins: int = 80):
             )
             if "train" in partition:
                 cut_set = (
-                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+                    cut_set
+                    + cut_set.perturb_speed(0.9)
+                    + cut_set.perturb_speed(1.1)
                 )
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
@@ -109,7 +111,9 @@ def get_args():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/__init__.py b/egs/aishell2/ASR/pruned_transducer_stateless5/__init__.py
old mode 100644
new mode 100755
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py
old mode 100644
new mode 100755
index e8966b554..b7a21f579
--- a/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py
+++ b/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py
@@ -76,12 +76,10 @@ class AiShell2AsrDataModule:
     def add_arguments(cls, parser: argparse.ArgumentParser):
         group = parser.add_argument_group(
             title="ASR data related options",
-            description=(
-                "These options are used for the preparation of "
-                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-                "effective batch sizes, sampling strategies, applied data "
-                "augmentations, etc."
-            ),
+            description="These options are used for the preparation of "
+            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+            "effective batch sizes, sampling strategies, applied data "
+            "augmentations, etc.",
         )
         group.add_argument(
             "--manifest-dir",
@@ -93,74 +91,59 @@ class AiShell2AsrDataModule:
             "--max-duration",
             type=int,
             default=200.0,
-            help=(
-                "Maximum pooled recordings duration (seconds) in a "
-                "single batch. You can reduce it if it causes CUDA OOM."
-            ),
+            help="Maximum pooled recordings duration (seconds) in a "
+            "single batch. You can reduce it if it causes CUDA OOM.",
         )
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, the batches will come from buckets of "
-                "similar duration (saves padding frames)."
-            ),
+            help="When enabled, the batches will come from buckets of "
+            "similar duration (saves padding frames).",
         )
         group.add_argument(
             "--num-buckets",
             type=int,
             default=30,
-            help=(
-                "The number of buckets for the DynamicBucketingSampler"
-                "(you might want to increase it for larger datasets)."
-            ),
+            help="The number of buckets for the DynamicBucketingSampler"
+            "(you might want to increase it for larger datasets).",
         )
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, utterances (cuts) will be concatenated "
-                "to minimize the amount of padding."
-            ),
+            help="When enabled, utterances (cuts) will be concatenated "
+            "to minimize the amount of padding.",
         )
         group.add_argument(
             "--duration-factor",
             type=float,
             default=1.0,
-            help=(
-                "Determines the maximum duration of a concatenated cut "
-                "relative to the duration of the longest cut in a batch."
-            ),
+            help="Determines the maximum duration of a concatenated cut "
+            "relative to the duration of the longest cut in a batch.",
         )
         group.add_argument(
             "--gap",
             type=float,
             default=1.0,
-            help=(
-                "The amount of padding (in seconds) inserted between "
-                "concatenated cuts. This padding is filled with noise when "
-                "noise augmentation is used."
-            ),
+            help="The amount of padding (in seconds) inserted between "
+            "concatenated cuts. This padding is filled with noise when "
+            "noise augmentation is used.",
         )
         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, use on-the-fly cut mixing and feature "
-                "extraction. Will drop existing precomputed feature manifests "
-                "if available."
-            ),
+            help="When enabled, use on-the-fly cut mixing and feature "
+            "extraction. Will drop existing precomputed feature manifests "
+            "if available.",
         )
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled (=default), the examples will be shuffled for each epoch."
-            ),
+            help="When enabled (=default), the examples will be "
+            "shuffled for each epoch.",
         )
         group.add_argument(
             "--drop-last",
@@ -172,18 +155,17 @@ class AiShell2AsrDataModule:
             "--return-cuts",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, each batch will have the "
-                "field: batch['supervisions']['cut'] with the cuts that "
-                "were used to construct it."
-            ),
+            help="When enabled, each batch will have the "
+            "field: batch['supervisions']['cut'] with the cuts that "
+            "were used to construct it.",
         )
 
         group.add_argument(
             "--num-workers",
             type=int,
             default=2,
-            help="The number of training dataloader workers that collect the batches.",
+            help="The number of training dataloader workers that "
+            "collect the batches.",
         )
 
         group.add_argument(
@@ -197,22 +179,18 @@ class AiShell2AsrDataModule:
             "--spec-aug-time-warp-factor",
             type=int,
             default=80,
-            help=(
-                "Used only when --enable-spec-aug is True. "
-                "It specifies the factor for time warping in SpecAugment. "
-                "Larger values mean more warping. "
-                "A value less than 1 means to disable time warp."
-            ),
+            help="Used only when --enable-spec-aug is True. "
+            "It specifies the factor for time warping in SpecAugment. "
+            "Larger values mean more warping. "
+            "A value less than 1 means to disable time warp.",
         )
 
         group.add_argument(
             "--enable-musan",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, select noise from MUSAN and mix it"
-                "with training dataset. "
-            ),
+            help="When enabled, select noise from MUSAN and mix it"
+            "with training dataset. ",
         )
 
         group.add_argument(
@@ -238,16 +216,20 @@ class AiShell2AsrDataModule:
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             logging.info("About to get Musan cuts")
-            cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
+            cuts_musan = load_manifest(
+                self.args.manifest_dir / "musan_cuts.jsonl.gz"
+            )
             transforms.append(
-                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+                CutMix(
+                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
+                )
             )
         else:
             logging.info("Disable MUSAN")
 
         if self.args.concatenate_cuts:
             logging.info(
-                "Using cut concatenation with duration factor "
+                f"Using cut concatenation with duration factor "
                 f"{self.args.duration_factor} and gap {self.args.gap}."
             )
             # Cut concatenation should be the first transform in the list,
@@ -262,7 +244,9 @@ class AiShell2AsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+            logging.info(
+                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
+            )
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -306,7 +290,9 @@ class AiShell2AsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -362,7 +348,9 @@ class AiShell2AsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 return_cuts=self.args.return_cuts,
             )
         else:
@@ -418,7 +406,9 @@ class AiShell2AsrDataModule:
     @lru_cache()
     def valid_cuts(self) -> CutSet:
         logging.info("About to gen cuts from aishell2_cuts_dev.jsonl.gz")
-        return load_manifest_lazy(self.args.manifest_dir / "aishell2_cuts_dev.jsonl.gz")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "aishell2_cuts_dev.jsonl.gz"
+        )
 
     @lru_cache()
     def test_cuts(self) -> CutSet:
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py b/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py
index 64b64d1b1..915737f4a 100755
--- a/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py
@@ -168,24 +168,20 @@ def get_parser():
         "--avg",
         type=int,
         default=15,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch' and '--iter'"
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
     )
 
     parser.add_argument(
         "--use-averaged-model",
         type=str2bool,
         default=True,
-        help=(
-            "Whether to load averaged model. Currently it only supports "
-            "using --epoch. If True, it would decode with the averaged model "
-            "over the epoch range from `epoch-avg` (excluded) to `epoch`."
-            "Actually only the models with epoch number of `epoch-avg` and "
-            "`epoch` are loaded for averaging. "
-        ),
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
     )
 
     parser.add_argument(
@@ -273,7 +269,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -351,7 +348,9 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
@@ -410,7 +409,10 @@ def decode_one_batch(
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -536,7 +538,9 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -569,7 +573,8 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -620,7 +625,9 @@ def main():
             if "LG" in params.decoding_method:
                 params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        params.suffix += (
+            f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        )
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -654,12 +661,13 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg
-            ]
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg:
                 raise ValueError(
@@ -682,12 +690,13 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg + 1
-            ]
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg + 1]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg + 1:
                 raise ValueError(
@@ -715,7 +724,7 @@ def main():
             filename_start = f"{params.exp_dir}/epoch-{start}.pt"
             filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
             logging.info(
-                "Calculating the averaged model over epoch range from "
+                f"Calculating the averaged model over epoch range from "
                 f"{start} (excluded) to {params.epoch}"
             )
             model.to(device)
@@ -740,7 +749,9 @@ def main():
             )
             decoding_graph.scores *= params.ngram_lm_scale
         else:
-            decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+            decoding_graph = k2.trivial_graph(
+                params.vocab_size - 1, device=device
+            )
     else:
         decoding_graph = None
 
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/export.py b/egs/aishell2/ASR/pruned_transducer_stateless5/export.py
index 547ce2069..bc7bd71cb 100755
--- a/egs/aishell2/ASR/pruned_transducer_stateless5/export.py
+++ b/egs/aishell2/ASR/pruned_transducer_stateless5/export.py
@@ -89,24 +89,20 @@ def get_parser():
         "--avg",
         type=int,
         default=15,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch' and '--iter'"
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
     )
 
     parser.add_argument(
         "--use-averaged-model",
         type=str2bool,
         default=False,
-        help=(
-            "Whether to load averaged model. Currently it only supports "
-            "using --epoch. If True, it would decode with the averaged model "
-            "over the epoch range from `epoch-avg` (excluded) to `epoch`."
-            "Actually only the models with epoch number of `epoch-avg` and "
-            "`epoch` are loaded for averaging. "
-        ),
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
     )
 
     parser.add_argument(
@@ -137,7 +133,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     add_model_arguments(parser)
@@ -170,12 +167,13 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg
-            ]
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg:
                 raise ValueError(
@@ -198,12 +196,13 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg + 1
-            ]
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg + 1]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg + 1:
                 raise ValueError(
@@ -231,7 +230,7 @@ def main():
             filename_start = f"{params.exp_dir}/epoch-{start}.pt"
             filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
             logging.info(
-                "Calculating the averaged model over epoch range from "
+                f"Calculating the averaged model over epoch range from "
                 f"{start} (excluded) to {params.epoch}"
             )
             model.to(device)
@@ -267,7 +266,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py b/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py
index 4b16511e8..09de1bece 100755
--- a/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py
+++ b/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py
@@ -81,11 +81,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -111,12 +109,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     parser.add_argument(
@@ -163,7 +159,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -194,9 +191,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -256,11 +254,15 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
-    encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=features, x_lens=feature_lengths
+    )
 
     num_waves = encoder_out.size(0)
     hyps = []
@@ -332,7 +334,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/train.py b/egs/aishell2/ASR/pruned_transducer_stateless5/train.py
index d37e7bdca..838a0497f 100755
--- a/egs/aishell2/ASR/pruned_transducer_stateless5/train.py
+++ b/egs/aishell2/ASR/pruned_transducer_stateless5/train.py
@@ -92,7 +92,9 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+LRSchedulerType = Union[
+    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
+]
 
 
 def add_model_arguments(parser: argparse.ArgumentParser):
@@ -218,7 +220,8 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need to be changed.",
+        help="The initial learning rate.  This value should not need "
+        "to be changed.",
     )
 
     parser.add_argument(
@@ -241,45 +244,42 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
         "--prune-range",
         type=int,
         default=5,
-        help=(
-            "The prune range for rnnt loss, it means how many symbols(context)"
-            "we are using to compute the loss"
-        ),
+        help="The prune range for rnnt loss, it means how many symbols(context)"
+        "we are using to compute the loss",
     )
 
     parser.add_argument(
         "--lm-scale",
         type=float,
         default=0.25,
-        help=(
-            "The scale to smooth the loss with lm (output of prediction network) part."
-        ),
+        help="The scale to smooth the loss with lm "
+        "(output of prediction network) part.",
     )
 
     parser.add_argument(
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)part.",
+        help="The scale to smooth the loss with am (output of encoder network)"
+        "part.",
     )
 
     parser.add_argument(
         "--simple-loss-scale",
         type=float,
         default=0.5,
-        help=(
-            "To get pruning ranges, we will calculate a simple version"
-            "loss(joiner is just addition), this simple loss also uses for"
-            "training (as a regularization item). We will scale the simple loss"
-            "with this parameter before adding to the final loss."
-        ),
+        help="To get pruning ranges, we will calculate a simple version"
+        "loss(joiner is just addition), this simple loss also uses for"
+        "training (as a regularization item). We will scale the simple loss"
+        "with this parameter before adding to the final loss.",
     )
 
     parser.add_argument(
@@ -603,7 +603,11 @@ def compute_loss(
      warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
-    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+    device = (
+        model.device
+        if isinstance(model, DDP)
+        else next(model.parameters()).device
+    )
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
@@ -632,16 +636,23 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+            0.0
+            if warmup < 1.0
+            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+        )
+        loss = (
+            params.simple_loss_scale * simple_loss
+            + pruned_loss_scale * pruned_loss
         )
-        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
 
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -760,7 +771,9 @@ def train_one_epoch(
             scaler.update()
             optimizer.zero_grad()
         except:  # noqa
-            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
+            display_and_save_batch(
+                batch, params=params, graph_compiler=graph_compiler
+            )
             raise
 
         if params.print_diagnostics and batch_idx == 5:
@@ -816,7 +829,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -924,7 +939,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            2 ** 22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
@@ -1089,7 +1104,9 @@ def scan_pessimistic_batches_for_oom(
                     f"Failing criterion: {criterion} "
                     f"(={crit_values[criterion]}) ..."
                 )
-            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
+            display_and_save_batch(
+                batch, params=params, graph_compiler=graph_compiler
+            )
             raise
 
 
diff --git a/egs/aishell4/ASR/local/compute_fbank_aishell4.py b/egs/aishell4/ASR/local/compute_fbank_aishell4.py
index 400c406f0..3f50d9e3e 100755
--- a/egs/aishell4/ASR/local/compute_fbank_aishell4.py
+++ b/egs/aishell4/ASR/local/compute_fbank_aishell4.py
@@ -85,7 +85,9 @@ def compute_fbank_aishell4(num_mel_bins: int = 80):
             )
             if "train" in partition:
                 cut_set = (
-                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+                    cut_set
+                    + cut_set.perturb_speed(0.9)
+                    + cut_set.perturb_speed(1.1)
                 )
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
@@ -118,7 +120,9 @@ def get_args():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/aishell4/ASR/local/prepare_char.py b/egs/aishell4/ASR/local/prepare_char.py
index 6b440dfb3..d9e47d17a 100755
--- a/egs/aishell4/ASR/local/prepare_char.py
+++ b/egs/aishell4/ASR/local/prepare_char.py
@@ -86,7 +86,9 @@ def lexicon_to_fst_no_sil(
         cur_state = loop_state
 
         word = word2id[word]
-        pieces = [token2id[i] if i in token2id else token2id[""] for i in pieces]
+        pieces = [
+            token2id[i] if i in token2id else token2id[""] for i in pieces
+        ]
 
         for i in range(len(pieces) - 1):
             w = word if i == 0 else eps
@@ -140,7 +142,9 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
     return False
 
 
-def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon:
+def generate_lexicon(
+    token_sym_table: Dict[str, int], words: List[str]
+) -> Lexicon:
     """Generate a lexicon from a word list and token_sym_table.
 
     Args:
diff --git a/egs/aishell4/ASR/local/prepare_lang.py b/egs/aishell4/ASR/local/prepare_lang.py
index c8cf9b881..e5ae89ec4 100755
--- a/egs/aishell4/ASR/local/prepare_lang.py
+++ b/egs/aishell4/ASR/local/prepare_lang.py
@@ -317,7 +317,9 @@ def lexicon_to_fst(
 
 def get_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone")
+    parser.add_argument(
+        "--lang-dir", type=str, help="The lang dir, data/lang_phone"
+    )
     return parser.parse_args()
 
 
diff --git a/egs/aishell4/ASR/local/test_prepare_lang.py b/egs/aishell4/ASR/local/test_prepare_lang.py
index 74e025ad7..d4cf62bba 100755
--- a/egs/aishell4/ASR/local/test_prepare_lang.py
+++ b/egs/aishell4/ASR/local/test_prepare_lang.py
@@ -88,7 +88,9 @@ def test_read_lexicon(filename: str):
     fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa.draw("L.pdf", title="L")
 
-    fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id)
+    fsa_disambig = lexicon_to_fst(
+        lexicon_disambig, phone2id=phone2id, word2id=word2id
+    )
     fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
     fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa_disambig.draw("L_disambig.pdf", title="L_disambig")
diff --git a/egs/aishell4/ASR/local/text2token.py b/egs/aishell4/ASR/local/text2token.py
index 2be639b7a..71be2a613 100755
--- a/egs/aishell4/ASR/local/text2token.py
+++ b/egs/aishell4/ASR/local/text2token.py
@@ -50,15 +50,15 @@ def get_parser():
         "-n",
         default=1,
         type=int,
-        help=(
-            "number of characters to split, i.e.,                         aabb -> a a b"
-            " b with -n 1 and aa bb with -n 2"
-        ),
+        help="number of characters to split, i.e., \
+                        aabb -> a a b b with -n 1 and aa bb with -n 2",
     )
     parser.add_argument(
         "--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
     )
-    parser.add_argument("--space", default="", type=str, help="space symbol")
+    parser.add_argument(
+        "--space", default="", type=str, help="space symbol"
+    )
     parser.add_argument(
         "--non-lang-syms",
         "-l",
@@ -66,7 +66,9 @@ def get_parser():
         type=str,
         help="list of non-linguistic symobles, e.g.,  etc.",
     )
-    parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
+    parser.add_argument(
+        "text", type=str, default=False, nargs="?", help="input text"
+    )
     parser.add_argument(
         "--trans_type",
         "-t",
@@ -106,7 +108,8 @@ def token2id(
             if token_type == "lazy_pinyin":
                 text = lazy_pinyin(chars_list)
                 sub_ids = [
-                    token_table[txt] if txt in token_table else oov_id for txt in text
+                    token_table[txt] if txt in token_table else oov_id
+                    for txt in text
                 ]
                 ids.append(sub_ids)
             else:  # token_type = "pinyin"
@@ -132,7 +135,9 @@ def main():
     if args.text:
         f = codecs.open(args.text, encoding="utf-8")
     else:
-        f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
+        f = codecs.getreader("utf-8")(
+            sys.stdin if is_python2 else sys.stdin.buffer
+        )
 
     sys.stdout = codecs.getwriter("utf-8")(
         sys.stdout if is_python2 else sys.stdout.buffer
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py
index 84c7f0443..7aa53ddda 100644
--- a/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py
@@ -74,12 +74,10 @@ class Aishell4AsrDataModule:
     def add_arguments(cls, parser: argparse.ArgumentParser):
         group = parser.add_argument_group(
             title="ASR data related options",
-            description=(
-                "These options are used for the preparation of "
-                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-                "effective batch sizes, sampling strategies, applied data "
-                "augmentations, etc."
-            ),
+            description="These options are used for the preparation of "
+            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+            "effective batch sizes, sampling strategies, applied data "
+            "augmentations, etc.",
         )
 
         group.add_argument(
@@ -93,81 +91,66 @@ class Aishell4AsrDataModule:
             "--max-duration",
             type=int,
             default=200.0,
-            help=(
-                "Maximum pooled recordings duration (seconds) in a "
-                "single batch. You can reduce it if it causes CUDA OOM."
-            ),
+            help="Maximum pooled recordings duration (seconds) in a "
+            "single batch. You can reduce it if it causes CUDA OOM.",
         )
 
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, the batches will come from buckets of "
-                "similar duration (saves padding frames)."
-            ),
+            help="When enabled, the batches will come from buckets of "
+            "similar duration (saves padding frames).",
         )
 
         group.add_argument(
             "--num-buckets",
             type=int,
             default=300,
-            help=(
-                "The number of buckets for the DynamicBucketingSampler"
-                "(you might want to increase it for larger datasets)."
-            ),
+            help="The number of buckets for the DynamicBucketingSampler"
+            "(you might want to increase it for larger datasets).",
         )
 
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, utterances (cuts) will be concatenated "
-                "to minimize the amount of padding."
-            ),
+            help="When enabled, utterances (cuts) will be concatenated "
+            "to minimize the amount of padding.",
         )
 
         group.add_argument(
             "--duration-factor",
             type=float,
             default=1.0,
-            help=(
-                "Determines the maximum duration of a concatenated cut "
-                "relative to the duration of the longest cut in a batch."
-            ),
+            help="Determines the maximum duration of a concatenated cut "
+            "relative to the duration of the longest cut in a batch.",
         )
 
         group.add_argument(
             "--gap",
             type=float,
             default=1.0,
-            help=(
-                "The amount of padding (in seconds) inserted between "
-                "concatenated cuts. This padding is filled with noise when "
-                "noise augmentation is used."
-            ),
+            help="The amount of padding (in seconds) inserted between "
+            "concatenated cuts. This padding is filled with noise when "
+            "noise augmentation is used.",
         )
 
         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, use on-the-fly cut mixing and feature "
-                "extraction. Will drop existing precomputed feature manifests "
-                "if available."
-            ),
+            help="When enabled, use on-the-fly cut mixing and feature "
+            "extraction. Will drop existing precomputed feature manifests "
+            "if available.",
         )
 
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled (=default), the examples will be shuffled for each epoch."
-            ),
+            help="When enabled (=default), the examples will be "
+            "shuffled for each epoch.",
         )
 
         group.add_argument(
@@ -181,18 +164,17 @@ class Aishell4AsrDataModule:
             "--return-cuts",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, each batch will have the "
-                "field: batch['supervisions']['cut'] with the cuts that "
-                "were used to construct it."
-            ),
+            help="When enabled, each batch will have the "
+            "field: batch['supervisions']['cut'] with the cuts that "
+            "were used to construct it.",
         )
 
         group.add_argument(
             "--num-workers",
             type=int,
             default=2,
-            help="The number of training dataloader workers that collect the batches.",
+            help="The number of training dataloader workers that "
+            "collect the batches.",
         )
 
         group.add_argument(
@@ -206,22 +188,18 @@ class Aishell4AsrDataModule:
             "--spec-aug-time-warp-factor",
             type=int,
             default=80,
-            help=(
-                "Used only when --enable-spec-aug is True. "
-                "It specifies the factor for time warping in SpecAugment. "
-                "Larger values mean more warping. "
-                "A value less than 1 means to disable time warp."
-            ),
+            help="Used only when --enable-spec-aug is True. "
+            "It specifies the factor for time warping in SpecAugment. "
+            "Larger values mean more warping. "
+            "A value less than 1 means to disable time warp.",
         )
 
         group.add_argument(
             "--enable-musan",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, select noise from MUSAN and mix it"
-                "with training dataset. "
-            ),
+            help="When enabled, select noise from MUSAN and mix it"
+            "with training dataset. ",
         )
 
         group.add_argument(
@@ -244,20 +222,24 @@ class Aishell4AsrDataModule:
             The state dict for the training sampler.
         """
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
+        cuts_musan = load_manifest(
+            self.args.manifest_dir / "musan_cuts.jsonl.gz"
+        )
 
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+                CutMix(
+                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
+                )
             )
         else:
             logging.info("Disable MUSAN")
 
         if self.args.concatenate_cuts:
             logging.info(
-                "Using cut concatenation with duration factor "
+                f"Using cut concatenation with duration factor "
                 f"{self.args.duration_factor} and gap {self.args.gap}."
             )
             # Cut concatenation should be the first transform in the list,
@@ -272,7 +254,9 @@ class Aishell4AsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+            logging.info(
+                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
+            )
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -316,7 +300,9 @@ class Aishell4AsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -373,7 +359,9 @@ class Aishell4AsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 return_cuts=self.args.return_cuts,
             )
         else:
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py b/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py
index 616a88937..14e44c7d9 100755
--- a/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py
@@ -117,24 +117,20 @@ def get_parser():
         "--avg",
         type=int,
         default=15,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch' and '--iter'"
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
     )
 
     parser.add_argument(
         "--use-averaged-model",
         type=str2bool,
         default=False,
-        help=(
-            "Whether to load averaged model. Currently it only supports "
-            "using --epoch. If True, it would decode with the averaged model "
-            "over the epoch range from `epoch-avg` (excluded) to `epoch`."
-            "Actually only the models with epoch number of `epoch-avg` and "
-            "`epoch` are loaded for averaging. "
-        ),
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
     )
 
     parser.add_argument(
@@ -205,7 +201,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -263,7 +260,9 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
@@ -278,7 +277,10 @@ def decode_one_batch(
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -324,7 +326,11 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}"
+            ): hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -395,7 +401,9 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -428,7 +436,8 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -471,7 +480,9 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        params.suffix += (
+            f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        )
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -499,12 +510,13 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg
-            ]
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg:
                 raise ValueError(
@@ -531,12 +543,13 @@ def main():
             )
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg + 1
-            ]
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg + 1]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg + 1:
                 raise ValueError(
@@ -565,7 +578,7 @@ def main():
             filename_start = f"{params.exp_dir}/epoch-{start}.pt"
             filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
             logging.info(
-                "Calculating the averaged model over epoch range from "
+                f"Calculating the averaged model over epoch range from "
                 f"{start} (excluded) to {params.epoch}"
             )
             model.to(device)
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/export.py b/egs/aishell4/ASR/pruned_transducer_stateless5/export.py
index 3c580ff7b..993341131 100755
--- a/egs/aishell4/ASR/pruned_transducer_stateless5/export.py
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/export.py
@@ -89,24 +89,20 @@ def get_parser():
         "--avg",
         type=int,
         default=15,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch' and '--iter'"
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
     )
 
     parser.add_argument(
         "--use-averaged-model",
         type=str2bool,
         default=False,
-        help=(
-            "Whether to load averaged model. Currently it only supports "
-            "using --epoch. If True, it would decode with the averaged model "
-            "over the epoch range from `epoch-avg` (excluded) to `epoch`."
-            "Actually only the models with epoch number of `epoch-avg` and "
-            "`epoch` are loaded for averaging. "
-        ),
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
     )
 
     parser.add_argument(
@@ -140,7 +136,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     add_model_arguments(parser)
@@ -172,12 +169,13 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg
-            ]
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg:
                 raise ValueError(
@@ -204,12 +202,13 @@ def main():
             )
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg + 1
-            ]
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg + 1]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg + 1:
                 raise ValueError(
@@ -238,7 +237,7 @@ def main():
             filename_start = f"{params.exp_dir}/epoch-{start}.pt"
             filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
             logging.info(
-                "Calculating the averaged model over epoch range from "
+                f"Calculating the averaged model over epoch range from "
                 f"{start} (excluded) to {params.epoch}"
             )
             model.to(device)
@@ -277,7 +276,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py b/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py
index 8151442af..1fa893637 100755
--- a/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py
@@ -94,11 +94,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -124,12 +122,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     parser.add_argument(
@@ -176,7 +172,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -207,9 +204,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -268,11 +266,15 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
-    encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=features, x_lens=feature_lengths
+    )
 
     num_waves = encoder_out.size(0)
     hyps = []
@@ -304,7 +306,10 @@ def main():
 
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -345,7 +350,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/train.py b/egs/aishell4/ASR/pruned_transducer_stateless5/train.py
index aacd23ecd..0a48b9059 100755
--- a/egs/aishell4/ASR/pruned_transducer_stateless5/train.py
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/train.py
@@ -85,7 +85,9 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+LRSchedulerType = Union[
+    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
+]
 
 
 def add_model_arguments(parser: argparse.ArgumentParser):
@@ -211,7 +213,8 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need to be changed.",
+        help="The initial learning rate.  This value should not need "
+        "to be changed.",
     )
 
     parser.add_argument(
@@ -234,45 +237,42 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
         "--prune-range",
         type=int,
         default=5,
-        help=(
-            "The prune range for rnnt loss, it means how many symbols(context)"
-            "we are using to compute the loss"
-        ),
+        help="The prune range for rnnt loss, it means how many symbols(context)"
+        "we are using to compute the loss",
     )
 
     parser.add_argument(
         "--lm-scale",
         type=float,
         default=0.25,
-        help=(
-            "The scale to smooth the loss with lm (output of prediction network) part."
-        ),
+        help="The scale to smooth the loss with lm "
+        "(output of prediction network) part.",
     )
 
     parser.add_argument(
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)part.",
+        help="The scale to smooth the loss with am (output of encoder network)"
+        "part.",
     )
 
     parser.add_argument(
         "--simple-loss-scale",
         type=float,
         default=0.5,
-        help=(
-            "To get pruning ranges, we will calculate a simple version"
-            "loss(joiner is just addition), this simple loss also uses for"
-            "training (as a regularization item). We will scale the simple loss"
-            "with this parameter before adding to the final loss."
-        ),
+        help="To get pruning ranges, we will calculate a simple version"
+        "loss(joiner is just addition), this simple loss also uses for"
+        "training (as a regularization item). We will scale the simple loss"
+        "with this parameter before adding to the final loss.",
     )
 
     parser.add_argument(
@@ -599,7 +599,11 @@ def compute_loss(
      warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
-    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+    device = (
+        model.device
+        if isinstance(model, DDP)
+        else next(model.parameters()).device
+    )
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
@@ -629,15 +633,22 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+            0.0
+            if warmup < 1.0
+            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+        )
+        loss = (
+            params.simple_loss_scale * simple_loss
+            + pruned_loss_scale * pruned_loss
         )
-        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -816,7 +827,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -924,7 +937,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            2 ** 22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
diff --git a/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py b/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py
index 96115a230..af926aa53 100755
--- a/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py
+++ b/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py
@@ -84,7 +84,9 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80):
             )
             if "train" in partition:
                 cut_set = (
-                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+                    cut_set
+                    + cut_set.perturb_speed(0.9)
+                    + cut_set.perturb_speed(1.1)
                 )
             cur_num_jobs = num_jobs if ex is None else 80
             cur_num_jobs = min(cur_num_jobs, len(cut_set))
@@ -119,7 +121,9 @@ def get_args():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/alimeeting/ASR/local/prepare_char.py b/egs/alimeeting/ASR/local/prepare_char.py
index 6b440dfb3..d9e47d17a 100755
--- a/egs/alimeeting/ASR/local/prepare_char.py
+++ b/egs/alimeeting/ASR/local/prepare_char.py
@@ -86,7 +86,9 @@ def lexicon_to_fst_no_sil(
         cur_state = loop_state
 
         word = word2id[word]
-        pieces = [token2id[i] if i in token2id else token2id[""] for i in pieces]
+        pieces = [
+            token2id[i] if i in token2id else token2id[""] for i in pieces
+        ]
 
         for i in range(len(pieces) - 1):
             w = word if i == 0 else eps
@@ -140,7 +142,9 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
     return False
 
 
-def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon:
+def generate_lexicon(
+    token_sym_table: Dict[str, int], words: List[str]
+) -> Lexicon:
     """Generate a lexicon from a word list and token_sym_table.
 
     Args:
diff --git a/egs/alimeeting/ASR/local/prepare_lang.py b/egs/alimeeting/ASR/local/prepare_lang.py
index c8cf9b881..e5ae89ec4 100755
--- a/egs/alimeeting/ASR/local/prepare_lang.py
+++ b/egs/alimeeting/ASR/local/prepare_lang.py
@@ -317,7 +317,9 @@ def lexicon_to_fst(
 
 def get_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone")
+    parser.add_argument(
+        "--lang-dir", type=str, help="The lang dir, data/lang_phone"
+    )
     return parser.parse_args()
 
 
diff --git a/egs/alimeeting/ASR/local/test_prepare_lang.py b/egs/alimeeting/ASR/local/test_prepare_lang.py
index 74e025ad7..d4cf62bba 100755
--- a/egs/alimeeting/ASR/local/test_prepare_lang.py
+++ b/egs/alimeeting/ASR/local/test_prepare_lang.py
@@ -88,7 +88,9 @@ def test_read_lexicon(filename: str):
     fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa.draw("L.pdf", title="L")
 
-    fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id)
+    fsa_disambig = lexicon_to_fst(
+        lexicon_disambig, phone2id=phone2id, word2id=word2id
+    )
     fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
     fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa_disambig.draw("L_disambig.pdf", title="L_disambig")
diff --git a/egs/alimeeting/ASR/local/text2segments.py b/egs/alimeeting/ASR/local/text2segments.py
index 27b904fc8..7c1019aa8 100644
--- a/egs/alimeeting/ASR/local/text2segments.py
+++ b/egs/alimeeting/ASR/local/text2segments.py
@@ -30,8 +30,8 @@ with word segmenting:
 
 import argparse
 
-import jieba
 import paddle
+import jieba
 from tqdm import tqdm
 
 paddle.enable_static()
diff --git a/egs/alimeeting/ASR/local/text2token.py b/egs/alimeeting/ASR/local/text2token.py
index 2be639b7a..71be2a613 100755
--- a/egs/alimeeting/ASR/local/text2token.py
+++ b/egs/alimeeting/ASR/local/text2token.py
@@ -50,15 +50,15 @@ def get_parser():
         "-n",
         default=1,
         type=int,
-        help=(
-            "number of characters to split, i.e.,                         aabb -> a a b"
-            " b with -n 1 and aa bb with -n 2"
-        ),
+        help="number of characters to split, i.e., \
+                        aabb -> a a b b with -n 1 and aa bb with -n 2",
     )
     parser.add_argument(
         "--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
     )
-    parser.add_argument("--space", default="", type=str, help="space symbol")
+    parser.add_argument(
+        "--space", default="", type=str, help="space symbol"
+    )
     parser.add_argument(
         "--non-lang-syms",
         "-l",
@@ -66,7 +66,9 @@ def get_parser():
         type=str,
         help="list of non-linguistic symobles, e.g.,  etc.",
     )
-    parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
+    parser.add_argument(
+        "text", type=str, default=False, nargs="?", help="input text"
+    )
     parser.add_argument(
         "--trans_type",
         "-t",
@@ -106,7 +108,8 @@ def token2id(
             if token_type == "lazy_pinyin":
                 text = lazy_pinyin(chars_list)
                 sub_ids = [
-                    token_table[txt] if txt in token_table else oov_id for txt in text
+                    token_table[txt] if txt in token_table else oov_id
+                    for txt in text
                 ]
                 ids.append(sub_ids)
             else:  # token_type = "pinyin"
@@ -132,7 +135,9 @@ def main():
     if args.text:
         f = codecs.open(args.text, encoding="utf-8")
     else:
-        f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
+        f = codecs.getreader("utf-8")(
+            sys.stdin if is_python2 else sys.stdin.buffer
+        )
 
     sys.stdout = codecs.getwriter("utf-8")(
         sys.stdout if is_python2 else sys.stdout.buffer
diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py
index d0467a29e..bf6faad7a 100644
--- a/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py
+++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py
@@ -81,12 +81,10 @@ class AlimeetingAsrDataModule:
     def add_arguments(cls, parser: argparse.ArgumentParser):
         group = parser.add_argument_group(
             title="ASR data related options",
-            description=(
-                "These options are used for the preparation of "
-                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-                "effective batch sizes, sampling strategies, applied data "
-                "augmentations, etc."
-            ),
+            description="These options are used for the preparation of "
+            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+            "effective batch sizes, sampling strategies, applied data "
+            "augmentations, etc.",
         )
         group.add_argument(
             "--manifest-dir",
@@ -98,91 +96,75 @@ class AlimeetingAsrDataModule:
             "--max-duration",
             type=int,
             default=200.0,
-            help=(
-                "Maximum pooled recordings duration (seconds) in a "
-                "single batch. You can reduce it if it causes CUDA OOM."
-            ),
+            help="Maximum pooled recordings duration (seconds) in a "
+            "single batch. You can reduce it if it causes CUDA OOM.",
         )
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, the batches will come from buckets of "
-                "similar duration (saves padding frames)."
-            ),
+            help="When enabled, the batches will come from buckets of "
+            "similar duration (saves padding frames).",
         )
         group.add_argument(
             "--num-buckets",
             type=int,
             default=300,
-            help=(
-                "The number of buckets for the DynamicBucketingSampler"
-                "(you might want to increase it for larger datasets)."
-            ),
+            help="The number of buckets for the DynamicBucketingSampler"
+            "(you might want to increase it for larger datasets).",
         )
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, utterances (cuts) will be concatenated "
-                "to minimize the amount of padding."
-            ),
+            help="When enabled, utterances (cuts) will be concatenated "
+            "to minimize the amount of padding.",
         )
         group.add_argument(
             "--duration-factor",
             type=float,
             default=1.0,
-            help=(
-                "Determines the maximum duration of a concatenated cut "
-                "relative to the duration of the longest cut in a batch."
-            ),
+            help="Determines the maximum duration of a concatenated cut "
+            "relative to the duration of the longest cut in a batch.",
         )
         group.add_argument(
             "--gap",
             type=float,
             default=1.0,
-            help=(
-                "The amount of padding (in seconds) inserted between "
-                "concatenated cuts. This padding is filled with noise when "
-                "noise augmentation is used."
-            ),
+            help="The amount of padding (in seconds) inserted between "
+            "concatenated cuts. This padding is filled with noise when "
+            "noise augmentation is used.",
         )
         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, use on-the-fly cut mixing and feature "
-                "extraction. Will drop existing precomputed feature manifests "
-                "if available."
-            ),
+            help="When enabled, use on-the-fly cut mixing and feature "
+            "extraction. Will drop existing precomputed feature manifests "
+            "if available.",
         )
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled (=default), the examples will be shuffled for each epoch."
-            ),
+            help="When enabled (=default), the examples will be "
+            "shuffled for each epoch.",
         )
         group.add_argument(
             "--return-cuts",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, each batch will have the "
-                "field: batch['supervisions']['cut'] with the cuts that "
-                "were used to construct it."
-            ),
+            help="When enabled, each batch will have the "
+            "field: batch['supervisions']['cut'] with the cuts that "
+            "were used to construct it.",
         )
 
         group.add_argument(
             "--num-workers",
             type=int,
             default=2,
-            help="The number of training dataloader workers that collect the batches.",
+            help="The number of training dataloader workers that "
+            "collect the batches.",
         )
 
         group.add_argument(
@@ -196,22 +178,18 @@ class AlimeetingAsrDataModule:
             "--spec-aug-time-warp-factor",
             type=int,
             default=80,
-            help=(
-                "Used only when --enable-spec-aug is True. "
-                "It specifies the factor for time warping in SpecAugment. "
-                "Larger values mean more warping. "
-                "A value less than 1 means to disable time warp."
-            ),
+            help="Used only when --enable-spec-aug is True. "
+            "It specifies the factor for time warping in SpecAugment. "
+            "Larger values mean more warping. "
+            "A value less than 1 means to disable time warp.",
         )
 
         group.add_argument(
             "--enable-musan",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, select noise from MUSAN and mix it"
-                "with training dataset. "
-            ),
+            help="When enabled, select noise from MUSAN and mix it"
+            "with training dataset. ",
         )
 
     def train_dataloaders(
@@ -227,20 +205,24 @@ class AlimeetingAsrDataModule:
             The state dict for the training sampler.
         """
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
+        cuts_musan = load_manifest(
+            self.args.manifest_dir / "musan_cuts.jsonl.gz"
+        )
 
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+                CutMix(
+                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
+                )
             )
         else:
             logging.info("Disable MUSAN")
 
         if self.args.concatenate_cuts:
             logging.info(
-                "Using cut concatenation with duration factor "
+                f"Using cut concatenation with duration factor "
                 f"{self.args.duration_factor} and gap {self.args.gap}."
             )
             # Cut concatenation should be the first transform in the list,
@@ -255,7 +237,9 @@ class AlimeetingAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+            logging.info(
+                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
+            )
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -298,7 +282,9 @@ class AlimeetingAsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -355,7 +341,9 @@ class AlimeetingAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 return_cuts=self.args.return_cuts,
             )
         else:
diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py
index ffaca1021..6358fe970 100755
--- a/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py
@@ -70,7 +70,11 @@ from beam_search import (
 from lhotse.cut import Cut
 from train import get_params, get_transducer_model
 
-from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
+from icefall.checkpoint import (
+    average_checkpoints,
+    find_checkpoints,
+    load_checkpoint,
+)
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
@@ -89,30 +93,25 @@ def get_parser():
         "--epoch",
         type=int,
         default=28,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--batch",
         type=int,
         default=None,
-        help=(
-            "It specifies the batch checkpoint to use for decoding."
-            "Note: Epoch counts from 0."
-        ),
+        help="It specifies the batch checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=15,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -194,7 +193,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -249,7 +249,9 @@ def decode_one_batch(
 
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
-    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
@@ -264,7 +266,10 @@ def decode_one_batch(
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -310,7 +315,11 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}"
+            ): hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -381,7 +390,9 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -414,7 +425,8 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -551,7 +563,8 @@ def main():
         )
 
     dev_shards = [
-        str(path) for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
+        str(path)
+        for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
     ]
     cuts_dev_webdataset = CutSet.from_webdataset(
         dev_shards,
@@ -561,7 +574,8 @@ def main():
     )
 
     test_shards = [
-        str(path) for path in sorted(glob.glob(os.path.join(test, "shared-*.tar")))
+        str(path)
+        for path in sorted(glob.glob(os.path.join(test, "shared-*.tar")))
     ]
     cuts_test_webdataset = CutSet.from_webdataset(
         test_shards,
@@ -574,7 +588,9 @@ def main():
         return 1.0 <= c.duration
 
     cuts_dev_webdataset = cuts_dev_webdataset.filter(remove_short_and_long_utt)
-    cuts_test_webdataset = cuts_test_webdataset.filter(remove_short_and_long_utt)
+    cuts_test_webdataset = cuts_test_webdataset.filter(
+        remove_short_and_long_utt
+    )
 
     dev_dl = alimeeting.valid_dataloaders(cuts_dev_webdataset)
     test_dl = alimeeting.test_dataloaders(cuts_test_webdataset)
diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py
index 482e52d83..8beec1b8a 100644
--- a/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py
@@ -62,20 +62,17 @@ def get_parser():
         "--epoch",
         type=int,
         default=28,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=15,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -106,7 +103,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     return parser
@@ -175,7 +173,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py
index afbf0960a..93b1e1f57 100644
--- a/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py
+++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py
@@ -85,11 +85,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -114,12 +112,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     parser.add_argument(
@@ -166,7 +162,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -196,9 +193,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -259,7 +257,9 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -284,7 +284,10 @@ def main():
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -336,7 +339,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py
index 158ea9c1b..81a0ede7f 100644
--- a/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py
@@ -81,7 +81,9 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+LRSchedulerType = Union[
+    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
+]
 
 os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
 
@@ -185,45 +187,42 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
         "--prune-range",
         type=int,
         default=5,
-        help=(
-            "The prune range for rnnt loss, it means how many symbols(context)"
-            "we are using to compute the loss"
-        ),
+        help="The prune range for rnnt loss, it means how many symbols(context)"
+        "we are using to compute the loss",
     )
 
     parser.add_argument(
         "--lm-scale",
         type=float,
         default=0.25,
-        help=(
-            "The scale to smooth the loss with lm (output of prediction network) part."
-        ),
+        help="The scale to smooth the loss with lm "
+        "(output of prediction network) part.",
     )
 
     parser.add_argument(
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)part.",
+        help="The scale to smooth the loss with am (output of encoder network)"
+        "part.",
     )
 
     parser.add_argument(
         "--simple-loss-scale",
         type=float,
         default=0.5,
-        help=(
-            "To get pruning ranges, we will calculate a simple version"
-            "loss(joiner is just addition), this simple loss also uses for"
-            "training (as a regularization item). We will scale the simple loss"
-            "with this parameter before adding to the final loss."
-        ),
+        help="To get pruning ranges, we will calculate a simple version"
+        "loss(joiner is just addition), this simple loss also uses for"
+        "training (as a regularization item). We will scale the simple loss"
+        "with this parameter before adding to the final loss.",
     )
 
     parser.add_argument(
@@ -543,15 +542,22 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+            0.0
+            if warmup < 1.0
+            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+        )
+        loss = (
+            params.simple_loss_scale * simple_loss
+            + pruned_loss_scale * pruned_loss
         )
-        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -705,7 +711,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -805,7 +813,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            2 ** 22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
diff --git a/egs/csj/ASR/.gitignore b/egs/csj/ASR/.gitignore
index cd0e20c4c..5d965832e 100644
--- a/egs/csj/ASR/.gitignore
+++ b/egs/csj/ASR/.gitignore
@@ -5,4 +5,4 @@ notify_tg.py
 finetune_*
 misc.ini
 .vscode/*
-offline/*
+offline/*
\ No newline at end of file
diff --git a/egs/csj/ASR/local/compute_fbank_csj.py b/egs/csj/ASR/local/compute_fbank_csj.py
index 036ce925f..994dedbdd 100644
--- a/egs/csj/ASR/local/compute_fbank_csj.py
+++ b/egs/csj/ASR/local/compute_fbank_csj.py
@@ -25,10 +25,15 @@ from random import Random
 from typing import List, Tuple
 
 import torch
-from lhotse import (  # fmt: off; See the following for why LilcomChunkyWriter is preferred; https://github.com/k2-fsa/icefall/pull/404; https://github.com/lhotse-speech/lhotse/pull/527; fmt: on
+from lhotse import (
     CutSet,
     Fbank,
     FbankConfig,
+    # fmt: off
+    # See the following for why LilcomChunkyWriter is preferred
+    # https://github.com/k2-fsa/icefall/pull/404
+    # https://github.com/lhotse-speech/lhotse/pull/527
+    # fmt: on
     LilcomChunkyWriter,
     RecordingSet,
     SupervisionSet,
@@ -76,13 +81,17 @@ def make_cutset_blueprints(
         cut_sets.append((f"eval{i}", cut_set))
 
     # Create train and valid cuts
-    logging.info("Loading, trimming, and shuffling the remaining core+noncore cuts.")
+    logging.info(
+        "Loading, trimming, and shuffling the remaining core+noncore cuts."
+    )
     recording_set = RecordingSet.from_file(
         manifest_dir / "csj_recordings_core.jsonl.gz"
     ) + RecordingSet.from_file(manifest_dir / "csj_recordings_noncore.jsonl.gz")
     supervision_set = SupervisionSet.from_file(
         manifest_dir / "csj_supervisions_core.jsonl.gz"
-    ) + SupervisionSet.from_file(manifest_dir / "csj_supervisions_noncore.jsonl.gz")
+    ) + SupervisionSet.from_file(
+        manifest_dir / "csj_supervisions_noncore.jsonl.gz"
+    )
 
     cut_set = CutSet.from_manifests(
         recordings=recording_set,
@@ -92,12 +101,15 @@ def make_cutset_blueprints(
     cut_set = cut_set.shuffle(Random(RNG_SEED))
 
     logging.info(
-        f"Creating valid and train cuts from core and noncore,split at {split}."
+        "Creating valid and train cuts from core and noncore,"
+        f"split at {split}."
     )
     valid_set = CutSet.from_cuts(islice(cut_set, 0, split))
 
     train_set = CutSet.from_cuts(islice(cut_set, split, None))
-    train_set = train_set + train_set.perturb_speed(0.9) + train_set.perturb_speed(1.1)
+    train_set = (
+        train_set + train_set.perturb_speed(0.9) + train_set.perturb_speed(1.1)
+    )
 
     cut_sets.extend([("valid", valid_set), ("train", train_set)])
 
@@ -110,9 +122,15 @@ def get_args():
         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
     )
 
-    parser.add_argument("--manifest-dir", type=Path, help="Path to save manifests")
-    parser.add_argument("--fbank-dir", type=Path, help="Path to save fbank features")
-    parser.add_argument("--split", type=int, default=4000, help="Split at this index")
+    parser.add_argument(
+        "--manifest-dir", type=Path, help="Path to save manifests"
+    )
+    parser.add_argument(
+        "--fbank-dir", type=Path, help="Path to save fbank features"
+    )
+    parser.add_argument(
+        "--split", type=int, default=4000, help="Split at this index"
+    )
 
     return parser.parse_args()
 
@@ -123,7 +141,9 @@ def main():
     extractor = Fbank(FbankConfig(num_mel_bins=80))
     num_jobs = min(16, os.cpu_count())
 
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/csj/ASR/local/compute_fbank_musan.py b/egs/csj/ASR/local/compute_fbank_musan.py
index f60e62c85..44a33c4eb 100644
--- a/egs/csj/ASR/local/compute_fbank_musan.py
+++ b/egs/csj/ASR/local/compute_fbank_musan.py
@@ -26,6 +26,7 @@ from lhotse.recipes.utils import read_manifests_if_cached
 
 from icefall.utils import get_executor
 
+
 ARGPARSE_DESCRIPTION = """
 This file computes fbank features of the musan dataset.
 It looks for manifests in the directory data/manifests.
@@ -83,7 +84,9 @@ def compute_fbank_musan(manifest_dir: Path, fbank_dir: Path):
         # create chunks of Musan with duration 5 - 10 seconds
         musan_cuts = (
             CutSet.from_manifests(
-                recordings=combine(part["recordings"] for part in manifests.values())
+                recordings=combine(
+                    part["recordings"] for part in manifests.values()
+                )
             )
             .cut_into_windows(10.0)
             .filter(lambda c: c.duration > 5)
@@ -104,15 +107,21 @@ def get_args():
         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
     )
 
-    parser.add_argument("--manifest-dir", type=Path, help="Path to save manifests")
-    parser.add_argument("--fbank-dir", type=Path, help="Path to save fbank features")
+    parser.add_argument(
+        "--manifest-dir", type=Path, help="Path to save manifests"
+    )
+    parser.add_argument(
+        "--fbank-dir", type=Path, help="Path to save fbank features"
+    )
 
     return parser.parse_args()
 
 
 if __name__ == "__main__":
     args = get_args()
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     compute_fbank_musan(args.manifest_dir, args.fbank_dir)
diff --git a/egs/csj/ASR/local/conf/disfluent.ini b/egs/csj/ASR/local/conf/disfluent.ini
index c987e72c5..eb70673de 100644
--- a/egs/csj/ASR/local/conf/disfluent.ini
+++ b/egs/csj/ASR/local/conf/disfluent.ini
@@ -1,17 +1,17 @@
 ; # This section is ignored if this file is not supplied as the first config file to
-; # lhotse prepare csj
+; # lhotse prepare csj  
 [SEGMENTS]
 ; # Allowed period of nonverbal noise. If exceeded, a new segment is created.
 gap = 0.5
 ; # Maximum length of segment (s).
 maxlen = 10
-; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently.
+; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently.  
 minlen = 0.02
-; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`.
-; # Pass an empty string to avoid adding any symbol. It was "" in kaldi.
-; # If you intend to use a multicharacter string for gap_sym, remember to register the
-; # multicharacter string as part of userdef-string in prepare_lang_char.py.
-gap_sym =
+; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. 
+; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. 
+; # If you intend to use a multicharacter string for gap_sym, remember to register the 
+; # multicharacter string as part of userdef-string in prepare_lang_char.py. 
+gap_sym = 
 
 [CONSTANTS]
 ; # Name of this mode
@@ -115,59 +115,59 @@ B^ = 0
 ; # 0 to remain, 1 to delete
 ; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)'
 笑 = 0
-; # Example: 'コク(笑 サイ+(D オン))',
+; # Example: 'コク(笑 サイ+(D オン))', 
 笑^ = 0
 ; # 泣きながら発話
 ; # 0 to remain, 1 to delete
-; # Example: '(泣 ドンナニ)'
+; # Example: '(泣 ドンナニ)' 
 泣 = 0
 泣^ = 0
 ; # 咳をしながら発話
 ; # 0 to remain, 1 to delete
-; # Example: 'シャ(咳 リン) ノ'
+; # Example: 'シャ(咳 リン) ノ' 
 咳 = 0
 ; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)'
 咳^ = 0
 ; # ささやき声や独り言などの小さな声
 ; # 0 to remain, 1 to delete
-; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))'
+; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' 
 L = 0
 ; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト'
 L^ = 0
 
 [REPLACEMENTS]
 ; # ボーカルフライなどで母音が同定できない場合
- =
+ = 
 ; # 「うん/うーん/ふーん」の音の特定が困難な場合
- =
+ = 
 ; # 非語彙的な母音の引き延ばし
- =
+ = 
 ; # 非語彙的な子音の引き延ばし
- =
+ = 
 ; # 言語音と独立に講演者の笑いが生じている場合
-<笑> =
+<笑> = 
 ; # 言語音と独立に講演者の咳が生じている場合
-<咳> =
+<咳> = 
 ; # 言語音と独立に講演者の息が生じている場合
-<息> =
+<息> = 
 ; # 講演者の泣き声
-<泣> =
+<泣> = 
 ; # 聴衆(司会者なども含む)の発話
-<フロア発話> =
+<フロア発話> = 
 ; # 聴衆の笑い
-<フロア笑> =
+<フロア笑> = 
 ; # 聴衆の拍手
-<拍手> =
+<拍手> = 
 ; # 講演者が発表中に用いたデモンストレーションの音声
-<デモ> =
+<デモ> = 
 ; # 学会講演に発表時間を知らせるためにならすベルの音
-<ベル> =
+<ベル> = 
 ; # 転記単位全体が再度読み直された場合
-<朗読間違い> =
+<朗読間違い> = 
 ; # 上記以外の音で特に目立った音
-<雑音> =
+<雑音> = 
 ; # 0.2秒以上のポーズ
-

= +

= ; # Redacted information, for R ; # It is \x00D7 multiplication sign, not your normal 'x' × = × @@ -318,3 +318,4 @@ spk_id = 2 ャ = ǐa ュ = ǐu ョ = ǐo + diff --git a/egs/csj/ASR/local/conf/fluent.ini b/egs/csj/ASR/local/conf/fluent.ini index f7f27f5bc..5d22f9eb8 100644 --- a/egs/csj/ASR/local/conf/fluent.ini +++ b/egs/csj/ASR/local/conf/fluent.ini @@ -1,17 +1,17 @@ ; # This section is ignored if this file is not supplied as the first config file to -; # lhotse prepare csj +; # lhotse prepare csj [SEGMENTS] ; # Allowed period of nonverbal noise. If exceeded, a new segment is created. gap = 0.5 ; # Maximum length of segment (s). maxlen = 10 -; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. +; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. minlen = 0.02 -; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. -; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. -; # If you intend to use a multicharacter string for gap_sym, remember to register the -; # multicharacter string as part of userdef-string in prepare_lang_char.py. -gap_sym = +; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. +; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. +; # If you intend to use a multicharacter string for gap_sym, remember to register the +; # multicharacter string as part of userdef-string in prepare_lang_char.py. +gap_sym = [CONSTANTS] ; # Name of this mode @@ -115,59 +115,59 @@ B^ = 0 ; # 0 to remain, 1 to delete ; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)' 笑 = 0 -; # Example: 'コク(笑 サイ+(D オン))', +; # Example: 'コク(笑 サイ+(D オン))', 笑^ = 0 ; # 泣きながら発話 ; # 0 to remain, 1 to delete -; # Example: '(泣 ドンナニ)' +; # Example: '(泣 ドンナニ)' 泣 = 0 泣^ = 0 ; # 咳をしながら発話 ; # 0 to remain, 1 to delete -; # Example: 'シャ(咳 リン) ノ' +; # Example: 'シャ(咳 リン) ノ' 咳 = 0 ; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)' 咳^ = 0 ; # ささやき声や独り言などの小さな声 ; # 0 to remain, 1 to delete -; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' +; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' L = 0 ; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト' L^ = 0 [REPLACEMENTS] ; # ボーカルフライなどで母音が同定できない場合 - = + = ; # 「うん/うーん/ふーん」の音の特定が困難な場合 - = + = ; # 非語彙的な母音の引き延ばし - = + = ; # 非語彙的な子音の引き延ばし - = + = ; # 言語音と独立に講演者の笑いが生じている場合 -<笑> = +<笑> = ; # 言語音と独立に講演者の咳が生じている場合 -<咳> = +<咳> = ; # 言語音と独立に講演者の息が生じている場合 -<息> = +<息> = ; # 講演者の泣き声 -<泣> = +<泣> = ; # 聴衆(司会者なども含む)の発話 -<フロア発話> = +<フロア発話> = ; # 聴衆の笑い -<フロア笑> = +<フロア笑> = ; # 聴衆の拍手 -<拍手> = +<拍手> = ; # 講演者が発表中に用いたデモンストレーションの音声 -<デモ> = +<デモ> = ; # 学会講演に発表時間を知らせるためにならすベルの音 -<ベル> = +<ベル> = ; # 転記単位全体が再度読み直された場合 -<朗読間違い> = +<朗読間違い> = ; # 上記以外の音で特に目立った音 -<雑音> = +<雑音> = ; # 0.2秒以上のポーズ -

= +

= ; # Redacted information, for R ; # It is \x00D7 multiplication sign, not your normal 'x' × = × @@ -318,3 +318,4 @@ spk_id = 2 ャ = ǐa ュ = ǐu ョ = ǐo + diff --git a/egs/csj/ASR/local/conf/number.ini b/egs/csj/ASR/local/conf/number.ini index cf9038f62..2613c3409 100644 --- a/egs/csj/ASR/local/conf/number.ini +++ b/egs/csj/ASR/local/conf/number.ini @@ -1,17 +1,17 @@ ; # This section is ignored if this file is not supplied as the first config file to -; # lhotse prepare csj +; # lhotse prepare csj [SEGMENTS] ; # Allowed period of nonverbal noise. If exceeded, a new segment is created. gap = 0.5 ; # Maximum length of segment (s). maxlen = 10 -; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. +; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. minlen = 0.02 -; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. -; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. -; # If you intend to use a multicharacter string for gap_sym, remember to register the -; # multicharacter string as part of userdef-string in prepare_lang_char.py. -gap_sym = +; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. +; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. +; # If you intend to use a multicharacter string for gap_sym, remember to register the +; # multicharacter string as part of userdef-string in prepare_lang_char.py. +gap_sym = [CONSTANTS] ; # Name of this mode @@ -115,59 +115,59 @@ B^ = 0 ; # 0 to remain, 1 to delete ; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)' 笑 = 0 -; # Example: 'コク(笑 サイ+(D オン))', +; # Example: 'コク(笑 サイ+(D オン))', 笑^ = 0 ; # 泣きながら発話 ; # 0 to remain, 1 to delete -; # Example: '(泣 ドンナニ)' +; # Example: '(泣 ドンナニ)' 泣 = 0 泣^ = 0 ; # 咳をしながら発話 ; # 0 to remain, 1 to delete -; # Example: 'シャ(咳 リン) ノ' +; # Example: 'シャ(咳 リン) ノ' 咳 = 0 ; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)' 咳^ = 0 ; # ささやき声や独り言などの小さな声 ; # 0 to remain, 1 to delete -; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' +; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' L = 0 ; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト' L^ = 0 [REPLACEMENTS] ; # ボーカルフライなどで母音が同定できない場合 - = + = ; # 「うん/うーん/ふーん」の音の特定が困難な場合 - = + = ; # 非語彙的な母音の引き延ばし - = + = ; # 非語彙的な子音の引き延ばし - = + = ; # 言語音と独立に講演者の笑いが生じている場合 -<笑> = +<笑> = ; # 言語音と独立に講演者の咳が生じている場合 -<咳> = +<咳> = ; # 言語音と独立に講演者の息が生じている場合 -<息> = +<息> = ; # 講演者の泣き声 -<泣> = +<泣> = ; # 聴衆(司会者なども含む)の発話 -<フロア発話> = +<フロア発話> = ; # 聴衆の笑い -<フロア笑> = +<フロア笑> = ; # 聴衆の拍手 -<拍手> = +<拍手> = ; # 講演者が発表中に用いたデモンストレーションの音声 -<デモ> = +<デモ> = ; # 学会講演に発表時間を知らせるためにならすベルの音 -<ベル> = +<ベル> = ; # 転記単位全体が再度読み直された場合 -<朗読間違い> = +<朗読間違い> = ; # 上記以外の音で特に目立った音 -<雑音> = +<雑音> = ; # 0.2秒以上のポーズ -

= +

= ; # Redacted information, for R ; # It is \x00D7 multiplication sign, not your normal 'x' × = × @@ -318,3 +318,4 @@ spk_id = 2 ャ = ǐa ュ = ǐu ョ = ǐo + diff --git a/egs/csj/ASR/local/conf/symbol.ini b/egs/csj/ASR/local/conf/symbol.ini index f9801284b..8ba451dd5 100644 --- a/egs/csj/ASR/local/conf/symbol.ini +++ b/egs/csj/ASR/local/conf/symbol.ini @@ -1,17 +1,17 @@ ; # This section is ignored if this file is not supplied as the first config file to -; # lhotse prepare csj +; # lhotse prepare csj [SEGMENTS] ; # Allowed period of nonverbal noise. If exceeded, a new segment is created. gap = 0.5 ; # Maximum length of segment (s). maxlen = 10 -; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. +; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. minlen = 0.02 -; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. -; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. -; # If you intend to use a multicharacter string for gap_sym, remember to register the -; # multicharacter string as part of userdef-string in prepare_lang_char.py. -gap_sym = +; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. +; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. +; # If you intend to use a multicharacter string for gap_sym, remember to register the +; # multicharacter string as part of userdef-string in prepare_lang_char.py. +gap_sym = [CONSTANTS] ; # Name of this mode @@ -116,59 +116,59 @@ B^ = 0 ; # 0 to remain, 1 to delete ; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)' 笑 = 0 -; # Example: 'コク(笑 サイ+(D オン))', +; # Example: 'コク(笑 サイ+(D オン))', 笑^ = 0 ; # 泣きながら発話 ; # 0 to remain, 1 to delete -; # Example: '(泣 ドンナニ)' +; # Example: '(泣 ドンナニ)' 泣 = 0 泣^ = 0 ; # 咳をしながら発話 ; # 0 to remain, 1 to delete -; # Example: 'シャ(咳 リン) ノ' +; # Example: 'シャ(咳 リン) ノ' 咳 = 0 ; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)' 咳^ = 0 ; # ささやき声や独り言などの小さな声 ; # 0 to remain, 1 to delete -; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' +; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' L = 0 ; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト' L^ = 0 [REPLACEMENTS] ; # ボーカルフライなどで母音が同定できない場合 - = + = ; # 「うん/うーん/ふーん」の音の特定が困難な場合 - = + = ; # 非語彙的な母音の引き延ばし - = + = ; # 非語彙的な子音の引き延ばし - = + = ; # 言語音と独立に講演者の笑いが生じている場合 -<笑> = +<笑> = ; # 言語音と独立に講演者の咳が生じている場合 -<咳> = +<咳> = ; # 言語音と独立に講演者の息が生じている場合 -<息> = +<息> = ; # 講演者の泣き声 -<泣> = +<泣> = ; # 聴衆(司会者なども含む)の発話 -<フロア発話> = +<フロア発話> = ; # 聴衆の笑い -<フロア笑> = +<フロア笑> = ; # 聴衆の拍手 -<拍手> = +<拍手> = ; # 講演者が発表中に用いたデモンストレーションの音声 -<デモ> = +<デモ> = ; # 学会講演に発表時間を知らせるためにならすベルの音 -<ベル> = +<ベル> = ; # 転記単位全体が再度読み直された場合 -<朗読間違い> = +<朗読間違い> = ; # 上記以外の音で特に目立った音 -<雑音> = +<雑音> = ; # 0.2秒以上のポーズ -

= +

= ; # Redacted information, for R ; # It is \x00D7 multiplication sign, not your normal 'x' × = × @@ -319,3 +319,4 @@ spk_id = 2 ャ = ǐa ュ = ǐu ョ = ǐo + diff --git a/egs/csj/ASR/local/display_manifest_statistics.py b/egs/csj/ASR/local/display_manifest_statistics.py index c043cf853..c9de21073 100644 --- a/egs/csj/ASR/local/display_manifest_statistics.py +++ b/egs/csj/ASR/local/display_manifest_statistics.py @@ -37,7 +37,9 @@ def get_parser(): formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) - parser.add_argument("--manifest-dir", type=Path, help="Path to cutset manifests") + parser.add_argument( + "--manifest-dir", type=Path, help="Path to cutset manifests" + ) return parser.parse_args() diff --git a/egs/csj/ASR/local/prepare_lang_char.py b/egs/csj/ASR/local/prepare_lang_char.py index f0078421b..e4d996871 100644 --- a/egs/csj/ASR/local/prepare_lang_char.py +++ b/egs/csj/ASR/local/prepare_lang_char.py @@ -68,7 +68,8 @@ def get_args(): type=Path, default=None, help=( - "Name of lang dir. If not set, this will default to lang_char_{trans-mode}" + "Name of lang dir. " + "If not set, this will default to lang_char_{trans-mode}" ), ) @@ -86,7 +87,9 @@ def main(): args = get_args() logging.basicConfig( - format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s", + format=( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] " "%(message)s" + ), level=logging.INFO, ) @@ -108,7 +111,8 @@ def main(): words = set() logging.info( - f"Creating vocabulary from {args.train_cut.name} at {args.trans_mode} mode." + f"Creating vocabulary from {args.train_cut.name}" + f" at {args.trans_mode} mode." ) for cut in train_set: try: @@ -119,7 +123,8 @@ def main(): ) except KeyError: raise KeyError( - f"Could not find {args.trans_mode} in {cut.supervisions[0].custom}" + f"Could not find {args.trans_mode} in " + f"{cut.supervisions[0].custom}" ) for t in text.split(): if t in args.userdef_string: @@ -138,7 +143,9 @@ def main(): (args.lang_dir / "words_len").write_text(f"{len(words)}") - (args.lang_dir / "userdef_string").write_text("\n".join(args.userdef_string)) + (args.lang_dir / "userdef_string").write_text( + "\n".join(args.userdef_string) + ) (args.lang_dir / "trans_mode").write_text(args.trans_mode) logging.info("Done.") diff --git a/egs/csj/ASR/local/validate_manifest.py b/egs/csj/ASR/local/validate_manifest.py index 89448a49c..0c4c6c1ea 100644 --- a/egs/csj/ASR/local/validate_manifest.py +++ b/egs/csj/ASR/local/validate_manifest.py @@ -68,7 +68,8 @@ def validate_supervision_and_cut_time_bounds(c: Cut): if s.end > c.end: raise ValueError( - f"{c.id}: Supervision end time {s.end} is larger than cut end time {c.end}" + f"{c.id}: Supervision end time {s.end} is larger " + f"than cut end time {c.end}" ) @@ -88,7 +89,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py b/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py index c3e3e84bf..d78e26240 100644 --- a/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py +++ b/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py @@ -61,12 +61,10 @@ class GigaSpeechAsrDataModule: def add_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group( title="ASR data related options", - description=( - "These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc." - ), + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", ) group.add_argument( "--manifest-dir", @@ -78,91 +76,75 @@ class GigaSpeechAsrDataModule: "--max-duration", type=int, default=200.0, - help=( - "Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM." - ), + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", ) group.add_argument( "--bucketing-sampler", type=str2bool, default=True, - help=( - "When enabled, the batches will come from buckets of " - "similar duration (saves padding frames)." - ), + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", ) group.add_argument( "--num-buckets", type=int, default=30, - help=( - "The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets)." - ), + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", ) group.add_argument( "--concatenate-cuts", type=str2bool, default=False, - help=( - "When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding." - ), + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", ) group.add_argument( "--duration-factor", type=float, default=1.0, - help=( - "Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch." - ), + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", ) group.add_argument( "--gap", type=float, default=1.0, - help=( - "The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used." - ), + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", ) group.add_argument( "--on-the-fly-feats", type=str2bool, default=False, - help=( - "When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available." - ), + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", ) group.add_argument( "--shuffle", type=str2bool, default=True, - help=( - "When enabled (=default), the examples will be shuffled for each epoch." - ), + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", ) group.add_argument( "--return-cuts", type=str2bool, default=True, - help=( - "When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it." - ), + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that collect the batches.", + help="The number of training dataloader workers that " + "collect the batches.", ) group.add_argument( @@ -176,22 +158,18 @@ class GigaSpeechAsrDataModule: "--spec-aug-time-warp-factor", type=int, default=80, - help=( - "Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp." - ), + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", ) group.add_argument( "--enable-musan", type=str2bool, default=True, - help=( - "When enabled, select noise from MUSAN and mix it " - "with training dataset. " - ), + help="When enabled, select noise from MUSAN and mix it " + "with training dataset. ", ) # GigaSpeech specific arguments @@ -205,25 +183,30 @@ class GigaSpeechAsrDataModule: "--small-dev", type=str2bool, default=False, - help="Should we use only 1000 utterances for dev (speeds up training)", + help="Should we use only 1000 utterances for dev " + "(speeds up training)", ) def train_dataloaders(self, cuts_train: CutSet) -> DataLoader: logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + cuts_musan = load_manifest( + self.args.manifest_dir / "musan_cuts.jsonl.gz" + ) transforms = [] if self.args.enable_musan: logging.info("Enable MUSAN") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix( + cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True + ) ) else: logging.info("Disable MUSAN") if self.args.concatenate_cuts: logging.info( - "Using cut concatenation with duration factor " + f"Using cut concatenation with duration factor " f"{self.args.duration_factor} and gap {self.args.gap}." ) # Cut concatenation should be the first transform in the list, @@ -238,7 +221,9 @@ class GigaSpeechAsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + logging.info( + f"Time warp factor: {self.args.spec_aug_time_warp_factor}" + ) input_transforms.append( SpecAugment( time_warp_factor=self.args.spec_aug_time_warp_factor, @@ -271,7 +256,9 @@ class GigaSpeechAsrDataModule: # Drop feats to be on the safe side. train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -317,7 +304,9 @@ class GigaSpeechAsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), return_cuts=self.args.return_cuts, ) else: @@ -373,7 +362,9 @@ class GigaSpeechAsrDataModule: @lru_cache() def dev_cuts(self) -> CutSet: logging.info("About to get dev cuts") - cuts_valid = load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz") + cuts_valid = load_manifest_lazy( + self.args.manifest_dir / "cuts_DEV.jsonl.gz" + ) if self.args.small_dev: return cuts_valid.subset(first=1000) else: diff --git a/egs/gigaspeech/ASR/conformer_ctc/conformer.py b/egs/gigaspeech/ASR/conformer_ctc/conformer.py index 1153a814c..6fac07f93 100644 --- a/egs/gigaspeech/ASR/conformer_ctc/conformer.py +++ b/egs/gigaspeech/ASR/conformer_ctc/conformer.py @@ -160,7 +160,9 @@ class ConformerEncoderLayer(nn.Module): use_conv_batchnorm: bool = False, ) -> None: super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) + self.self_attn = RelPositionMultiheadAttention( + d_model, nhead, dropout=0.0 + ) self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), @@ -180,14 +182,18 @@ class ConformerEncoderLayer(nn.Module): d_model, cnn_module_kernel, use_batchnorm=use_conv_batchnorm ) - self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module + self.norm_ff_macaron = nn.LayerNorm( + d_model + ) # for the macaron style FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.ff_scale = 0.5 self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm(d_model) # for the final output of the block + self.norm_final = nn.LayerNorm( + d_model + ) # for the final output of the block self.dropout = nn.Dropout(dropout) @@ -221,7 +227,9 @@ class ConformerEncoderLayer(nn.Module): residual = src if self.normalize_before: src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) + src = residual + self.ff_scale * self.dropout( + self.feed_forward_macaron(src) + ) if not self.normalize_before: src = self.norm_ff_macaron(src) @@ -340,7 +348,9 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: + def __init__( + self, d_model: int, dropout_rate: float, max_len: int = 5000 + ) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -356,7 +366,9 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): + if self.pe.dtype != x.dtype or str(self.pe.device) != str( + x.device + ): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -626,9 +638,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( - 3, dim=-1 - ) + q, k, v = nn.functional.linear( + query, in_proj_weight, in_proj_bias + ).chunk(3, dim=-1) elif torch.equal(key, value): # encoder-decoder attention @@ -696,25 +708,33 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError("The size of the 2D attn_mask is not correct.") + raise RuntimeError( + "The size of the 2D attn_mask is not correct." + ) elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError("The size of the 3D attn_mask is not correct.") + raise RuntimeError( + "The size of the 3D attn_mask is not correct." + ) else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format(attn_mask.dim()) + "attn_mask's dimension {} is not supported".format( + attn_mask.dim() + ) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + if ( + key_padding_mask is not None + and key_padding_mask.dtype == torch.uint8 + ): warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor" - " instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -751,7 +771,9 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) + matrix_ac = torch.matmul( + q_with_bias_u, k + ) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -763,7 +785,9 @@ class RelPositionMultiheadAttention(nn.Module): matrix_ac + matrix_bd ) * scaling # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, -1 + ) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -797,9 +821,13 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn_output.transpose(0, 1) + .contiguous() + .view(tgt_len, bsz, embed_dim) + ) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads diff --git a/egs/gigaspeech/ASR/conformer_ctc/decode.py b/egs/gigaspeech/ASR/conformer_ctc/decode.py index b38ae9c8c..9c1418baa 100755 --- a/egs/gigaspeech/ASR/conformer_ctc/decode.py +++ b/egs/gigaspeech/ASR/conformer_ctc/decode.py @@ -62,19 +62,16 @@ def get_parser(): "--epoch", type=int, default=0, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=1, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -479,7 +476,9 @@ def decode_dataset( results[lm_scale].extend(this_batch) else: - assert len(results) > 0, "It should not decode to empty in the first batch!" + assert ( + len(results) > 0 + ), "It should not decode to empty in the first batch!" this_batch = [] hyp_words = [] for cut_id, ref_text in zip(cut_ids, texts): @@ -494,7 +493,9 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -527,7 +528,9 @@ def save_results( test_set_wers[key] = wer if enable_log: - logging.info("Wrote detailed error stats to {}".format(errs_filename)) + logging.info( + "Wrote detailed error stats to {}".format(errs_filename) + ) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" @@ -702,7 +705,9 @@ def main(): eos_id=eos_id, ) - save_results(params=params, test_set_name=test_set, results_dict=results_dict) + save_results( + params=params, test_set_name=test_set, results_dict=results_dict + ) logging.info("Done!") diff --git a/egs/gigaspeech/ASR/conformer_ctc/gigaspeech_scoring.py b/egs/gigaspeech/ASR/conformer_ctc/gigaspeech_scoring.py index 880aa76e2..ef53b77f8 100755 --- a/egs/gigaspeech/ASR/conformer_ctc/gigaspeech_scoring.py +++ b/egs/gigaspeech/ASR/conformer_ctc/gigaspeech_scoring.py @@ -73,7 +73,8 @@ def asr_text_post_processing(text: str) -> str: if __name__ == "__main__": parser = argparse.ArgumentParser( - description="This script evaluates GigaSpeech ASR result viaSCTK's tool sclite" + description="This script evaluates GigaSpeech ASR result via" + "SCTK's tool sclite" ) parser.add_argument( "ref", diff --git a/egs/gigaspeech/ASR/conformer_ctc/label_smoothing.py b/egs/gigaspeech/ASR/conformer_ctc/label_smoothing.py index 3b94f0c4b..cdc85ce9a 100644 --- a/egs/gigaspeech/ASR/conformer_ctc/label_smoothing.py +++ b/egs/gigaspeech/ASR/conformer_ctc/label_smoothing.py @@ -78,10 +78,13 @@ class LabelSmoothingLoss(torch.nn.Module): ignored = target == self.ignore_index target[ignored] = 0 - true_dist = torch.nn.functional.one_hot(target, num_classes=num_classes).to(x) + true_dist = torch.nn.functional.one_hot( + target, num_classes=num_classes + ).to(x) true_dist = ( - true_dist * (1 - self.label_smoothing) + self.label_smoothing / num_classes + true_dist * (1 - self.label_smoothing) + + self.label_smoothing / num_classes ) # Set the value of ignored indexes to 0 true_dist[ignored] = 0 diff --git a/egs/gigaspeech/ASR/conformer_ctc/subsampling.py b/egs/gigaspeech/ASR/conformer_ctc/subsampling.py index 8e0f73d05..542fb0364 100644 --- a/egs/gigaspeech/ASR/conformer_ctc/subsampling.py +++ b/egs/gigaspeech/ASR/conformer_ctc/subsampling.py @@ -42,9 +42,13 @@ class Conv2dSubsampling(nn.Module): assert idim >= 7 super().__init__() self.conv = nn.Sequential( - nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2), + nn.Conv2d( + in_channels=1, out_channels=odim, kernel_size=3, stride=2 + ), nn.ReLU(), - nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2), + nn.Conv2d( + in_channels=odim, out_channels=odim, kernel_size=3, stride=2 + ), nn.ReLU(), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) @@ -128,13 +132,17 @@ class VggSubsampling(nn.Module): ) ) layers.append( - torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True) + torch.nn.MaxPool2d( + kernel_size=2, stride=2, padding=0, ceil_mode=True + ) ) cur_channels = block_dim self.layers = nn.Sequential(*layers) - self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim) + self.out = nn.Linear( + block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim + ) def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. diff --git a/egs/gigaspeech/ASR/conformer_ctc/train.py b/egs/gigaspeech/ASR/conformer_ctc/train.py index 4883d04d8..2965cde18 100755 --- a/egs/gigaspeech/ASR/conformer_ctc/train.py +++ b/egs/gigaspeech/ASR/conformer_ctc/train.py @@ -386,7 +386,9 @@ def compute_loss( # # See https://github.com/k2-fsa/icefall/issues/97 # for more details - unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) + unsorted_token_ids = graph_compiler.texts_to_ids( + supervisions["text"] + ) att_loss = mmodel.decoder_forward( encoder_memory, memory_mask, @@ -519,7 +521,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -637,7 +641,9 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/gigaspeech/ASR/conformer_ctc/transformer.py b/egs/gigaspeech/ASR/conformer_ctc/transformer.py index 0566cfc81..00ca027a7 100644 --- a/egs/gigaspeech/ASR/conformer_ctc/transformer.py +++ b/egs/gigaspeech/ASR/conformer_ctc/transformer.py @@ -151,7 +151,9 @@ class Transformer(nn.Module): norm=decoder_norm, ) - self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class) + self.decoder_output_layer = torch.nn.Linear( + d_model, self.decoder_num_class + ) self.decoder_criterion = LabelSmoothingLoss() else: @@ -179,13 +181,18 @@ class Transformer(nn.Module): memory_key_padding_mask for the decoder. Its shape is (N, T). It is None if `supervision` is None. """ - if isinstance(self.use_feat_batchnorm, bool) and self.use_feat_batchnorm: + if ( + isinstance(self.use_feat_batchnorm, bool) + and self.use_feat_batchnorm + ): x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) x = self.feat_batchnorm(x) x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) if isinstance(self.use_feat_batchnorm, float): x *= self.use_feat_batchnorm - encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision) + encoder_memory, memory_key_padding_mask = self.run_encoder( + x, supervision + ) x = self.ctc_output(encoder_memory) return x, encoder_memory, memory_key_padding_mask @@ -266,17 +273,23 @@ class Transformer(nn.Module): """ ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) + ys_in_pad = pad_sequence( + ys_in, batch_first=True, padding_value=float(eos_id) + ) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) + ys_out_pad = pad_sequence( + ys_out, batch_first=True, padding_value=float(-1) + ) device = memory.device ys_in_pad = ys_in_pad.to(device) ys_out_pad = ys_out_pad.to(device) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( + device + ) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -337,17 +350,23 @@ class Transformer(nn.Module): ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) + ys_in_pad = pad_sequence( + ys_in, batch_first=True, padding_value=float(eos_id) + ) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) + ys_out_pad = pad_sequence( + ys_out, batch_first=True, padding_value=float(-1) + ) device = memory.device ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( + device + ) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -620,7 +639,9 @@ def _get_activation_fn(activation: str): elif activation == "gelu": return nn.functional.gelu - raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) + raise RuntimeError( + "activation should be relu/gelu, not {}".format(activation) + ) class PositionalEncoding(nn.Module): @@ -822,7 +843,9 @@ def encoder_padding_mask( 1, ).to(torch.int32) - lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] + lengths = [ + 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) + ] for idx in range(supervision_segments.size(0)): # Note: TorchScript doesn't allow to unpack tensors as tuples sequence_idx = supervision_segments[idx, 0].item() @@ -843,7 +866,9 @@ def encoder_padding_mask( return mask -def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: +def decoder_padding_mask( + ys_pad: torch.Tensor, ignore_id: int = -1 +) -> torch.Tensor: """Generate a length mask for input. The masked position are filled with True, diff --git a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py index 07beeb1f0..8209ee3ec 100755 --- a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py +++ b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py @@ -77,7 +77,9 @@ def compute_fbank_gigaspeech_dev_test(): def main(): - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) compute_fbank_gigaspeech_dev_test() diff --git a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py index 0ee845ec8..6410249db 100755 --- a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py +++ b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py @@ -47,10 +47,8 @@ def get_parser(): "--batch-duration", type=float, default=600.0, - help=( - "The maximum number of audio seconds in a batch." - "Determines batch size dynamically." - ), + help="The maximum number of audio seconds in a batch." + "Determines batch size dynamically.", ) parser.add_argument( @@ -136,7 +134,9 @@ def main(): date_time = now.strftime("%Y-%m-%d-%H-%M-%S") log_filename = "log-compute_fbank_gigaspeech_splits" - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) log_filename = f"{log_filename}-{date_time}" logging.basicConfig( diff --git a/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py b/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py index 31abe7fff..48d10a157 100755 --- a/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py +++ b/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py @@ -98,13 +98,19 @@ def preprocess_giga_speech(): f"Speed perturb for {partition} with factors 0.9 and 1.1 " "(Perturbing may take 8 minutes and saving may take 20 minutes)" ) - cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + cut_set = ( + cut_set + + cut_set.perturb_speed(0.9) + + cut_set.perturb_speed(1.1) + ) logging.info(f"Saving to {raw_cuts_path}") cut_set.to_file(raw_cuts_path) def main(): - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) preprocess_giga_speech() diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py index 9ae3f071e..c87686e1e 100644 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -73,12 +73,10 @@ class GigaSpeechAsrDataModule: def add_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group( title="ASR data related options", - description=( - "These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc." - ), + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", ) group.add_argument( "--manifest-dir", @@ -90,91 +88,75 @@ class GigaSpeechAsrDataModule: "--max-duration", type=int, default=200.0, - help=( - "Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM." - ), + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", ) group.add_argument( "--bucketing-sampler", type=str2bool, default=True, - help=( - "When enabled, the batches will come from buckets of " - "similar duration (saves padding frames)." - ), + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", ) group.add_argument( "--num-buckets", type=int, default=30, - help=( - "The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets)." - ), + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", ) group.add_argument( "--concatenate-cuts", type=str2bool, default=False, - help=( - "When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding." - ), + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", ) group.add_argument( "--duration-factor", type=float, default=1.0, - help=( - "Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch." - ), + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", ) group.add_argument( "--gap", type=float, default=1.0, - help=( - "The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used." - ), + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", ) group.add_argument( "--on-the-fly-feats", type=str2bool, default=False, - help=( - "When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available." - ), + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", ) group.add_argument( "--shuffle", type=str2bool, default=True, - help=( - "When enabled (=default), the examples will be shuffled for each epoch." - ), + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", ) group.add_argument( "--return-cuts", type=str2bool, default=True, - help=( - "When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it." - ), + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that collect the batches.", + help="The number of training dataloader workers that " + "collect the batches.", ) group.add_argument( @@ -188,22 +170,18 @@ class GigaSpeechAsrDataModule: "--spec-aug-time-warp-factor", type=int, default=80, - help=( - "Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp." - ), + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", ) group.add_argument( "--enable-musan", type=str2bool, default=True, - help=( - "When enabled, select noise from MUSAN and mix it " - "with training dataset. " - ), + help="When enabled, select noise from MUSAN and mix it " + "with training dataset. ", ) # GigaSpeech specific arguments @@ -217,7 +195,8 @@ class GigaSpeechAsrDataModule: "--small-dev", type=str2bool, default=False, - help="Should we use only 1000 utterances for dev (speeds up training)", + help="Should we use only 1000 utterances for dev " + "(speeds up training)", ) def train_dataloaders( @@ -237,16 +216,20 @@ class GigaSpeechAsrDataModule: if self.args.enable_musan: logging.info("Enable MUSAN") logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + cuts_musan = load_manifest( + self.args.manifest_dir / "musan_cuts.jsonl.gz" + ) transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix( + cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True + ) ) else: logging.info("Disable MUSAN") if self.args.concatenate_cuts: logging.info( - "Using cut concatenation with duration factor " + f"Using cut concatenation with duration factor " f"{self.args.duration_factor} and gap {self.args.gap}." ) # Cut concatenation should be the first transform in the list, @@ -261,7 +244,9 @@ class GigaSpeechAsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + logging.info( + f"Time warp factor: {self.args.spec_aug_time_warp_factor}" + ) # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -304,7 +289,9 @@ class GigaSpeechAsrDataModule: # Drop feats to be on the safe side. train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -360,7 +347,9 @@ class GigaSpeechAsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), return_cuts=self.args.return_cuts, ) else: @@ -416,7 +405,9 @@ class GigaSpeechAsrDataModule: @lru_cache() def dev_cuts(self) -> CutSet: logging.info("About to get dev cuts") - cuts_valid = load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz") + cuts_valid = load_manifest_lazy( + self.args.manifest_dir / "cuts_DEV.jsonl.gz" + ) if self.args.small_dev: return cuts_valid.subset(first=1000) else: diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py index 9f5d4711b..5849a3471 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py @@ -77,7 +77,11 @@ from beam_search import ( from gigaspeech_scoring import asr_text_post_processing from train import get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.utils import ( AttributeDict, setup_logger, @@ -114,11 +118,9 @@ def get_parser(): "--avg", type=int, default=8, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -186,7 +188,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -255,7 +258,9 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] if params.decoding_method == "fast_beam_search": @@ -270,7 +275,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -316,7 +324,11 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -386,7 +398,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -420,7 +434,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -496,7 +511,8 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py index 17f8614dc..cff9c7377 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py @@ -51,7 +51,11 @@ import sentencepiece as spm import torch from train import get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.utils import str2bool @@ -83,11 +87,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -118,7 +120,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) return parser @@ -157,7 +160,8 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -205,7 +209,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py index 4d1a2356d..83ae25561 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py @@ -77,7 +77,9 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def get_parser(): @@ -176,45 +178,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." - ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -554,16 +553,23 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -726,7 +732,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") diff --git a/egs/librispeech/ASR/conformer_ctc/ali.py b/egs/librispeech/ASR/conformer_ctc/ali.py index 0169d0f82..2828e309e 100755 --- a/egs/librispeech/ASR/conformer_ctc/ali.py +++ b/egs/librispeech/ASR/conformer_ctc/ali.py @@ -61,19 +61,16 @@ def get_parser(): "--epoch", type=int, default=34, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=20, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -234,7 +231,9 @@ def compute_alignments( labels_ali = get_alignments(best_path, kind="labels") aux_labels_ali = get_alignments(best_path, kind="aux_labels") assert len(labels_ali) == len(aux_labels_ali) == len(cut_list) - for cut, labels, aux_labels in zip(cut_list, labels_ali, aux_labels_ali): + for cut, labels, aux_labels in zip( + cut_list, labels_ali, aux_labels_ali + ): cut.labels_alignment = labels_writer.store_array( key=cut.id, value=np.asarray(labels, dtype=np.int32), @@ -259,7 +258,9 @@ def compute_alignments( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return CutSet.from_cuts(cuts) @@ -288,7 +289,9 @@ def main(): out_labels_ali_filename = out_dir / f"labels_{params.dataset}.h5" out_aux_labels_ali_filename = out_dir / f"aux_labels_{params.dataset}.h5" - out_manifest_filename = out_dir / f"librispeech_cuts_{params.dataset}.jsonl.gz" + out_manifest_filename = ( + out_dir / f"librispeech_cuts_{params.dataset}.jsonl.gz" + ) for f in ( out_labels_ali_filename, diff --git a/egs/librispeech/ASR/conformer_ctc/conformer.py b/egs/librispeech/ASR/conformer_ctc/conformer.py index 1153a814c..6fac07f93 100644 --- a/egs/librispeech/ASR/conformer_ctc/conformer.py +++ b/egs/librispeech/ASR/conformer_ctc/conformer.py @@ -160,7 +160,9 @@ class ConformerEncoderLayer(nn.Module): use_conv_batchnorm: bool = False, ) -> None: super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) + self.self_attn = RelPositionMultiheadAttention( + d_model, nhead, dropout=0.0 + ) self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), @@ -180,14 +182,18 @@ class ConformerEncoderLayer(nn.Module): d_model, cnn_module_kernel, use_batchnorm=use_conv_batchnorm ) - self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module + self.norm_ff_macaron = nn.LayerNorm( + d_model + ) # for the macaron style FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.ff_scale = 0.5 self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm(d_model) # for the final output of the block + self.norm_final = nn.LayerNorm( + d_model + ) # for the final output of the block self.dropout = nn.Dropout(dropout) @@ -221,7 +227,9 @@ class ConformerEncoderLayer(nn.Module): residual = src if self.normalize_before: src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) + src = residual + self.ff_scale * self.dropout( + self.feed_forward_macaron(src) + ) if not self.normalize_before: src = self.norm_ff_macaron(src) @@ -340,7 +348,9 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: + def __init__( + self, d_model: int, dropout_rate: float, max_len: int = 5000 + ) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -356,7 +366,9 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): + if self.pe.dtype != x.dtype or str(self.pe.device) != str( + x.device + ): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -626,9 +638,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( - 3, dim=-1 - ) + q, k, v = nn.functional.linear( + query, in_proj_weight, in_proj_bias + ).chunk(3, dim=-1) elif torch.equal(key, value): # encoder-decoder attention @@ -696,25 +708,33 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError("The size of the 2D attn_mask is not correct.") + raise RuntimeError( + "The size of the 2D attn_mask is not correct." + ) elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError("The size of the 3D attn_mask is not correct.") + raise RuntimeError( + "The size of the 3D attn_mask is not correct." + ) else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format(attn_mask.dim()) + "attn_mask's dimension {} is not supported".format( + attn_mask.dim() + ) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + if ( + key_padding_mask is not None + and key_padding_mask.dtype == torch.uint8 + ): warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor" - " instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -751,7 +771,9 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) + matrix_ac = torch.matmul( + q_with_bias_u, k + ) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -763,7 +785,9 @@ class RelPositionMultiheadAttention(nn.Module): matrix_ac + matrix_bd ) * scaling # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, -1 + ) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -797,9 +821,13 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn_output.transpose(0, 1) + .contiguous() + .view(tgt_len, bsz, embed_dim) + ) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py index 66fdf82d9..3f3b1acda 100755 --- a/egs/librispeech/ASR/conformer_ctc/decode.py +++ b/egs/librispeech/ASR/conformer_ctc/decode.py @@ -64,19 +64,16 @@ def get_parser(): "--epoch", type=int, default=77, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=55, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -554,7 +551,9 @@ def decode_dataset( results[lm_scale].extend(this_batch) else: - assert len(results) > 0, "It should not decode to empty in the first batch!" + assert ( + len(results) > 0 + ), "It should not decode to empty in the first batch!" this_batch = [] hyp_words = [] for ref_text in texts: @@ -569,7 +568,9 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -601,7 +602,9 @@ def save_results( test_set_wers[key] = wer if enable_log: - logging.info("Wrote detailed error stats to {}".format(errs_filename)) + logging.info( + "Wrote detailed error stats to {}".format(errs_filename) + ) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" @@ -806,7 +809,9 @@ def main(): eos_id=eos_id, ) - save_results(params=params, test_set_name=test_set, results_dict=results_dict) + save_results( + params=params, test_set_name=test_set, results_dict=results_dict + ) logging.info("Done!") diff --git a/egs/librispeech/ASR/conformer_ctc/export.py b/egs/librispeech/ASR/conformer_ctc/export.py index bdb8a85e5..28c28df01 100755 --- a/egs/librispeech/ASR/conformer_ctc/export.py +++ b/egs/librispeech/ASR/conformer_ctc/export.py @@ -40,20 +40,17 @@ def get_parser(): "--epoch", type=int, default=34, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=20, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -160,7 +157,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/conformer_ctc/label_smoothing.py b/egs/librispeech/ASR/conformer_ctc/label_smoothing.py index cb0d6e04d..1f2f3b137 100644 --- a/egs/librispeech/ASR/conformer_ctc/label_smoothing.py +++ b/egs/librispeech/ASR/conformer_ctc/label_smoothing.py @@ -82,10 +82,13 @@ class LabelSmoothingLoss(torch.nn.Module): # for why we don't use target[ignored] = 0 here target = torch.where(ignored, torch.zeros_like(target), target) - true_dist = torch.nn.functional.one_hot(target, num_classes=num_classes).to(x) + true_dist = torch.nn.functional.one_hot( + target, num_classes=num_classes + ).to(x) true_dist = ( - true_dist * (1 - self.label_smoothing) + self.label_smoothing / num_classes + true_dist * (1 - self.label_smoothing) + + self.label_smoothing / num_classes ) # Set the value of ignored indexes to 0 diff --git a/egs/librispeech/ASR/conformer_ctc/pretrained.py b/egs/librispeech/ASR/conformer_ctc/pretrained.py index 8cabf1a53..a2c0a5486 100755 --- a/egs/librispeech/ASR/conformer_ctc/pretrained.py +++ b/egs/librispeech/ASR/conformer_ctc/pretrained.py @@ -48,11 +48,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -191,12 +189,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) return parser @@ -240,9 +236,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -303,7 +300,9 @@ def main(): logging.info("Decoding started") features = fbank(waves) - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) # Note: We don't use key padding mask for attention during decoding with torch.no_grad(): @@ -428,7 +427,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 8e0f73d05..542fb0364 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -42,9 +42,13 @@ class Conv2dSubsampling(nn.Module): assert idim >= 7 super().__init__() self.conv = nn.Sequential( - nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2), + nn.Conv2d( + in_channels=1, out_channels=odim, kernel_size=3, stride=2 + ), nn.ReLU(), - nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2), + nn.Conv2d( + in_channels=odim, out_channels=odim, kernel_size=3, stride=2 + ), nn.ReLU(), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) @@ -128,13 +132,17 @@ class VggSubsampling(nn.Module): ) ) layers.append( - torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True) + torch.nn.MaxPool2d( + kernel_size=2, stride=2, padding=0, ceil_mode=True + ) ) cur_channels = block_dim self.layers = nn.Sequential(*layers) - self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim) + self.out = nn.Linear( + block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim + ) def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py index 1a1c2f4c5..6419f6816 100755 --- a/egs/librispeech/ASR/conformer_ctc/train.py +++ b/egs/librispeech/ASR/conformer_ctc/train.py @@ -393,7 +393,9 @@ def compute_loss( # Works with a phone lexicon decoding_graph = graph_compiler.compile(texts) else: - raise ValueError(f"Unsupported type of graph compiler: {type(graph_compiler)}") + raise ValueError( + f"Unsupported type of graph compiler: {type(graph_compiler)}" + ) dense_fsa_vec = k2.DenseFsaVec( nnet_output, @@ -420,7 +422,9 @@ def compute_loss( # # See https://github.com/k2-fsa/icefall/issues/97 # for more details - unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) + unsorted_token_ids = graph_compiler.texts_to_ids( + supervisions["text"] + ) att_loss = mmodel.decoder_forward( encoder_memory, memory_mask, @@ -449,7 +453,9 @@ def compute_loss( info["utt_duration"] = supervisions["num_frames"].sum().item() # averaged padding proportion over utterances info["utt_pad_proportion"] = ( - ((feature.size(1) - supervisions["num_frames"]) / feature.size(1)).sum().item() + ((feature.size(1) - supervisions["num_frames"]) / feature.size(1)) + .sum() + .item() ) return loss, info @@ -562,7 +568,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -652,7 +660,7 @@ def run(rank, world_size, args): graph_compiler.eos_id = 1 else: raise ValueError( - "Unsupported type of lang dir (we expected it to have " + f"Unsupported type of lang dir (we expected it to have " f"'lang_bpe' or 'lang_phone' in its name): {params.lang_dir}" ) @@ -725,7 +733,9 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/librispeech/ASR/conformer_ctc/transformer.py b/egs/librispeech/ASR/conformer_ctc/transformer.py index 0566cfc81..00ca027a7 100644 --- a/egs/librispeech/ASR/conformer_ctc/transformer.py +++ b/egs/librispeech/ASR/conformer_ctc/transformer.py @@ -151,7 +151,9 @@ class Transformer(nn.Module): norm=decoder_norm, ) - self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class) + self.decoder_output_layer = torch.nn.Linear( + d_model, self.decoder_num_class + ) self.decoder_criterion = LabelSmoothingLoss() else: @@ -179,13 +181,18 @@ class Transformer(nn.Module): memory_key_padding_mask for the decoder. Its shape is (N, T). It is None if `supervision` is None. """ - if isinstance(self.use_feat_batchnorm, bool) and self.use_feat_batchnorm: + if ( + isinstance(self.use_feat_batchnorm, bool) + and self.use_feat_batchnorm + ): x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) x = self.feat_batchnorm(x) x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) if isinstance(self.use_feat_batchnorm, float): x *= self.use_feat_batchnorm - encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision) + encoder_memory, memory_key_padding_mask = self.run_encoder( + x, supervision + ) x = self.ctc_output(encoder_memory) return x, encoder_memory, memory_key_padding_mask @@ -266,17 +273,23 @@ class Transformer(nn.Module): """ ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) + ys_in_pad = pad_sequence( + ys_in, batch_first=True, padding_value=float(eos_id) + ) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) + ys_out_pad = pad_sequence( + ys_out, batch_first=True, padding_value=float(-1) + ) device = memory.device ys_in_pad = ys_in_pad.to(device) ys_out_pad = ys_out_pad.to(device) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( + device + ) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -337,17 +350,23 @@ class Transformer(nn.Module): ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) + ys_in_pad = pad_sequence( + ys_in, batch_first=True, padding_value=float(eos_id) + ) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) + ys_out_pad = pad_sequence( + ys_out, batch_first=True, padding_value=float(-1) + ) device = memory.device ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( + device + ) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -620,7 +639,9 @@ def _get_activation_fn(activation: str): elif activation == "gelu": return nn.functional.gelu - raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) + raise RuntimeError( + "activation should be relu/gelu, not {}".format(activation) + ) class PositionalEncoding(nn.Module): @@ -822,7 +843,9 @@ def encoder_padding_mask( 1, ).to(torch.int32) - lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] + lengths = [ + 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) + ] for idx in range(supervision_segments.size(0)): # Note: TorchScript doesn't allow to unpack tensors as tuples sequence_idx = supervision_segments[idx, 0].item() @@ -843,7 +866,9 @@ def encoder_padding_mask( return mask -def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: +def decoder_padding_mask( + ys_pad: torch.Tensor, ignore_id: int = -1 +) -> torch.Tensor: """Generate a length mask for input. The masked position are filled with True, diff --git a/egs/librispeech/ASR/conformer_ctc2/attention.py b/egs/librispeech/ASR/conformer_ctc2/attention.py index 356d3f21b..1375d7245 100644 --- a/egs/librispeech/ASR/conformer_ctc2/attention.py +++ b/egs/librispeech/ASR/conformer_ctc2/attention.py @@ -18,10 +18,11 @@ from typing import Optional, Tuple import torch import torch.nn as nn -from scaling import ScaledLinear from torch import Tensor from torch.nn.init import xavier_normal_ +from scaling import ScaledLinear + class MultiheadAttention(nn.Module): r"""Allows the model to jointly attend to information @@ -75,7 +76,9 @@ class MultiheadAttention(nn.Module): self.embed_dim = embed_dim self.kdim = kdim if kdim is not None else embed_dim self.vdim = vdim if vdim is not None else embed_dim - self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim + self._qkv_same_embed_dim = ( + self.kdim == embed_dim and self.vdim == embed_dim + ) self.num_heads = num_heads self.dropout = dropout @@ -91,7 +94,9 @@ class MultiheadAttention(nn.Module): self.v_proj_weight = ScaledLinear(self.vdim, embed_dim, bias=bias) self.register_parameter("in_proj_weight", None) else: - self.in_proj_weight = ScaledLinear(embed_dim, 3 * embed_dim, bias=bias) + self.in_proj_weight = ScaledLinear( + embed_dim, 3 * embed_dim, bias=bias + ) self.register_parameter("q_proj_weight", None) self.register_parameter("k_proj_weight", None) self.register_parameter("v_proj_weight", None) @@ -102,8 +107,12 @@ class MultiheadAttention(nn.Module): self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=bias) if add_bias_kv: - self.bias_k = nn.Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) - self.bias_v = nn.Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) + self.bias_k = nn.Parameter( + torch.empty((1, 1, embed_dim), **factory_kwargs) + ) + self.bias_v = nn.Parameter( + torch.empty((1, 1, embed_dim), **factory_kwargs) + ) else: self.bias_k = self.bias_v = None diff --git a/egs/librispeech/ASR/conformer_ctc2/conformer.py b/egs/librispeech/ASR/conformer_ctc2/conformer.py index a6f1679ef..b906d2650 100644 --- a/egs/librispeech/ASR/conformer_ctc2/conformer.py +++ b/egs/librispeech/ASR/conformer_ctc2/conformer.py @@ -29,8 +29,9 @@ from scaling import ( ScaledConv1d, ScaledLinear, ) -from subsampling import Conv2dSubsampling from torch import Tensor, nn +from subsampling import Conv2dSubsampling + from transformer import Supervisions, Transformer, encoder_padding_mask @@ -181,7 +182,9 @@ class ConformerEncoderLayer(nn.Module): self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) + self.self_attn = RelPositionMultiheadAttention( + d_model, nhead, dropout=0.0 + ) self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), @@ -353,7 +356,9 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: + def __init__( + self, d_model: int, dropout_rate: float, max_len: int = 5000 + ) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -368,7 +373,9 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): + if self.pe.dtype != x.dtype or str(self.pe.device) != str( + x.device + ): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vecotr and `j` means the @@ -643,9 +650,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( - 3, dim=-1 - ) + q, k, v = nn.functional.linear( + query, in_proj_weight, in_proj_bias + ).chunk(3, dim=-1) elif torch.equal(key, value): # encoder-decoder attention @@ -714,25 +721,33 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError("The size of the 2D attn_mask is not correct.") + raise RuntimeError( + "The size of the 2D attn_mask is not correct." + ) elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError("The size of the 3D attn_mask is not correct.") + raise RuntimeError( + "The size of the 3D attn_mask is not correct." + ) else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format(attn_mask.dim()) + "attn_mask's dimension {} is not supported".format( + attn_mask.dim() + ) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + if ( + key_padding_mask is not None + and key_padding_mask.dtype == torch.uint8 + ): warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor" - " instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -769,7 +784,9 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) + matrix_ac = torch.matmul( + q_with_bias_u, k + ) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -777,9 +794,13 @@ class RelPositionMultiheadAttention(nn.Module): ) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd) - attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) + attn_output_weights = ( + matrix_ac + matrix_bd + ) # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, -1 + ) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -813,9 +834,13 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn_output.transpose(0, 1) + .contiguous() + .view(tgt_len, bsz, embed_dim) + ) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -838,7 +863,9 @@ class ConvolutionModule(nn.Module): """ - def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: + def __init__( + self, channels: int, kernel_size: int, bias: bool = True + ) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding diff --git a/egs/librispeech/ASR/conformer_ctc2/decode.py b/egs/librispeech/ASR/conformer_ctc2/decode.py index 934177b1f..97f2f2d39 100755 --- a/egs/librispeech/ASR/conformer_ctc2/decode.py +++ b/egs/librispeech/ASR/conformer_ctc2/decode.py @@ -90,11 +90,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -132,13 +130,11 @@ def get_parser(): "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -662,7 +658,9 @@ def decode_dataset( results[lm_scale].extend(this_batch) else: - assert len(results) > 0, "It should not decode to empty in the first batch!" + assert ( + len(results) > 0 + ), "It should not decode to empty in the first batch!" this_batch = [] hyp_words = [] for ref_text in texts: @@ -677,7 +675,9 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -709,7 +709,9 @@ def save_results( test_set_wers[key] = wer if enable_log: - logging.info("Wrote detailed error stats to {}".format(errs_filename)) + logging.info( + "Wrote detailed error stats to {}".format(errs_filename) + ) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" @@ -850,12 +852,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -878,12 +881,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -911,7 +915,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -981,7 +985,9 @@ def main(): eos_id=eos_id, ) - save_results(params=params, test_set_name=test_set, results_dict=results_dict) + save_results( + params=params, test_set_name=test_set, results_dict=results_dict + ) logging.info("Done!") diff --git a/egs/librispeech/ASR/conformer_ctc2/export.py b/egs/librispeech/ASR/conformer_ctc2/export.py index 0e1841d8d..584b3c3fc 100755 --- a/egs/librispeech/ASR/conformer_ctc2/export.py +++ b/egs/librispeech/ASR/conformer_ctc2/export.py @@ -47,7 +47,6 @@ import logging from pathlib import Path import torch -from conformer import Conformer from decode import get_params from icefall.checkpoint import ( @@ -56,8 +55,10 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.lexicon import Lexicon +from conformer import Conformer + from icefall.utils import str2bool +from icefall.lexicon import Lexicon def get_parser(): @@ -88,24 +89,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -180,12 +177,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -208,12 +206,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -241,7 +240,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -274,7 +273,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/conformer_ctc2/train.py b/egs/librispeech/ASR/conformer_ctc2/train.py index 4d7137ad7..18fa3e69f 100755 --- a/egs/librispeech/ASR/conformer_ctc2/train.py +++ b/egs/librispeech/ASR/conformer_ctc2/train.py @@ -69,8 +69,8 @@ from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter -from icefall import diagnostics from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler +from icefall import diagnostics from icefall.checkpoint import load_checkpoint, remove_checkpoints from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.checkpoint import ( @@ -89,7 +89,9 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def get_parser(): @@ -496,7 +498,11 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -525,7 +531,9 @@ def compute_loss( # Works with a phone lexicon decoding_graph = graph_compiler.compile(texts) else: - raise ValueError(f"Unsupported type of graph compiler: {type(graph_compiler)}") + raise ValueError( + f"Unsupported type of graph compiler: {type(graph_compiler)}" + ) dense_fsa_vec = k2.DenseFsaVec( nnet_output, @@ -552,7 +560,9 @@ def compute_loss( # # See https://github.com/k2-fsa/icefall/issues/97 # for more details - unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) + unsorted_token_ids = graph_compiler.texts_to_ids( + supervisions["text"] + ) att_loss = mmodel.decoder_forward( encoder_memory, memory_mask, @@ -570,7 +580,9 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) info["ctc_loss"] = ctc_loss.detach().cpu().item() if params.att_rate != 0.0: info["att_loss"] = att_loss.detach().cpu().item() @@ -708,7 +720,8 @@ def train_one_epoch( except RuntimeError as e: if "CUDA out of memory" in str(e): logging.error( - f"failing batch size:{batch_size} failing batch names {batch_name}" + f"failing batch size:{batch_size} " + f"failing batch names {batch_name}" ) raise @@ -763,9 +776,9 @@ def train_one_epoch( f"tot_loss[{tot_loss}], batch size: {batch_size}, " f"lr: {cur_lr:.2e}" ) - if loss_info["ctc_loss"] == float("inf") or loss_info["att_loss"] == float( - "inf" - ): + if loss_info["ctc_loss"] == float("inf") or loss_info[ + "att_loss" + ] == float("inf"): logging.error( "Your loss contains inf, something goes wrong" f"failing batch names {batch_name}" @@ -778,7 +791,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -870,7 +885,7 @@ def run(rank, world_size, args): graph_compiler.eos_id = 1 else: raise ValueError( - "Unsupported type of lang dir (we expected it to have " + f"Unsupported type of lang dir (we expected it to have " f"'lang_bpe' or 'lang_phone' in its name): {params.lang_dir}" ) diff --git a/egs/librispeech/ASR/conformer_ctc2/transformer.py b/egs/librispeech/ASR/conformer_ctc2/transformer.py index d3443dc94..3ef7edc23 100644 --- a/egs/librispeech/ASR/conformer_ctc2/transformer.py +++ b/egs/librispeech/ASR/conformer_ctc2/transformer.py @@ -21,17 +21,19 @@ from typing import Dict, List, Optional, Tuple import torch import torch.nn as nn -from attention import MultiheadAttention from label_smoothing import LabelSmoothingLoss +from subsampling import Conv2dSubsampling +from attention import MultiheadAttention +from torch.nn.utils.rnn import pad_sequence + from scaling import ( ActivationBalancer, BasicNorm, DoubleSwish, - ScaledEmbedding, ScaledLinear, + ScaledEmbedding, ) -from subsampling import Conv2dSubsampling -from torch.nn.utils.rnn import pad_sequence + # Note: TorchScript requires Dict/List/etc. to be fully typed. Supervisions = Dict[str, torch.Tensor] @@ -208,7 +210,9 @@ class Transformer(nn.Module): x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) mask = encoder_padding_mask(x.size(0), supervisions) mask = mask.to(x.device) if mask is not None else None - x = self.encoder(x, src_key_padding_mask=mask, warmup=warmup) # (T, N, C) + x = self.encoder( + x, src_key_padding_mask=mask, warmup=warmup + ) # (T, N, C) return x, mask @@ -257,17 +261,23 @@ class Transformer(nn.Module): """ ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) + ys_in_pad = pad_sequence( + ys_in, batch_first=True, padding_value=float(eos_id) + ) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) + ys_out_pad = pad_sequence( + ys_out, batch_first=True, padding_value=float(-1) + ) device = memory.device ys_in_pad = ys_in_pad.to(device) ys_out_pad = ys_out_pad.to(device) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( + device + ) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -328,17 +338,23 @@ class Transformer(nn.Module): ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) + ys_in_pad = pad_sequence( + ys_in, batch_first=True, padding_value=float(eos_id) + ) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) + ys_out_pad = pad_sequence( + ys_out, batch_first=True, padding_value=float(-1) + ) device = memory.device ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( + device + ) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -943,7 +959,9 @@ def encoder_padding_mask( 1, ).to(torch.int32) - lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] + lengths = [ + 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) + ] for idx in range(supervision_segments.size(0)): # Note: TorchScript doesn't allow to unpack tensors as tuples sequence_idx = supervision_segments[idx, 0].item() @@ -964,7 +982,9 @@ def encoder_padding_mask( return mask -def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: +def decoder_padding_mask( + ys_pad: torch.Tensor, ignore_id: int = -1 +) -> torch.Tensor: """Generate a length mask for input. The masked position are filled with True, diff --git a/egs/librispeech/ASR/conformer_mmi/conformer.py b/egs/librispeech/ASR/conformer_mmi/conformer.py index 4d9ddaea9..97c8d83a2 100644 --- a/egs/librispeech/ASR/conformer_mmi/conformer.py +++ b/egs/librispeech/ASR/conformer_mmi/conformer.py @@ -156,7 +156,9 @@ class ConformerEncoderLayer(nn.Module): normalize_before: bool = True, ) -> None: super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) + self.self_attn = RelPositionMultiheadAttention( + d_model, nhead, dropout=0.0 + ) self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), @@ -174,14 +176,18 @@ class ConformerEncoderLayer(nn.Module): self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) - self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module + self.norm_ff_macaron = nn.LayerNorm( + d_model + ) # for the macaron style FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.ff_scale = 0.5 self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm(d_model) # for the final output of the block + self.norm_final = nn.LayerNorm( + d_model + ) # for the final output of the block self.dropout = nn.Dropout(dropout) @@ -215,7 +221,9 @@ class ConformerEncoderLayer(nn.Module): residual = src if self.normalize_before: src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) + src = residual + self.ff_scale * self.dropout( + self.feed_forward_macaron(src) + ) if not self.normalize_before: src = self.norm_ff_macaron(src) @@ -334,7 +342,9 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: + def __init__( + self, d_model: int, dropout_rate: float, max_len: int = 5000 + ) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -350,7 +360,9 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): + if self.pe.dtype != x.dtype or str(self.pe.device) != str( + x.device + ): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -620,9 +632,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( - 3, dim=-1 - ) + q, k, v = nn.functional.linear( + query, in_proj_weight, in_proj_bias + ).chunk(3, dim=-1) elif torch.equal(key, value): # encoder-decoder attention @@ -690,25 +702,33 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError("The size of the 2D attn_mask is not correct.") + raise RuntimeError( + "The size of the 2D attn_mask is not correct." + ) elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError("The size of the 3D attn_mask is not correct.") + raise RuntimeError( + "The size of the 3D attn_mask is not correct." + ) else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format(attn_mask.dim()) + "attn_mask's dimension {} is not supported".format( + attn_mask.dim() + ) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + if ( + key_padding_mask is not None + and key_padding_mask.dtype == torch.uint8 + ): warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor" - " instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -745,7 +765,9 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) + matrix_ac = torch.matmul( + q_with_bias_u, k + ) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -757,7 +779,9 @@ class RelPositionMultiheadAttention(nn.Module): matrix_ac + matrix_bd ) * scaling # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, -1 + ) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -791,9 +815,13 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn_output.transpose(0, 1) + .contiguous() + .view(tgt_len, bsz, embed_dim) + ) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -816,7 +844,9 @@ class ConvolutionModule(nn.Module): """ - def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: + def __init__( + self, channels: int, kernel_size: int, bias: bool = True + ) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding diff --git a/egs/librispeech/ASR/conformer_mmi/decode.py b/egs/librispeech/ASR/conformer_mmi/decode.py index e8390ded9..fc9861489 100755 --- a/egs/librispeech/ASR/conformer_mmi/decode.py +++ b/egs/librispeech/ASR/conformer_mmi/decode.py @@ -60,19 +60,16 @@ def get_parser(): "--epoch", type=int, default=34, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=20, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -481,7 +478,9 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -513,7 +512,9 @@ def save_results( test_set_wers[key] = wer if enable_log: - logging.info("Wrote detailed error stats to {}".format(errs_filename)) + logging.info( + "Wrote detailed error stats to {}".format(errs_filename) + ) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" @@ -652,7 +653,9 @@ def main(): if params.export: logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") - torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt") + torch.save( + {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" + ) return model.to(device) @@ -684,7 +687,9 @@ def main(): eos_id=eos_id, ) - save_results(params=params, test_set_name=test_set, results_dict=results_dict) + save_results( + params=params, test_set_name=test_set, results_dict=results_dict + ) logging.info("Done!") diff --git a/egs/librispeech/ASR/conformer_mmi/subsampling.py b/egs/librispeech/ASR/conformer_mmi/subsampling.py index ad9415987..5c3e1222e 100644 --- a/egs/librispeech/ASR/conformer_mmi/subsampling.py +++ b/egs/librispeech/ASR/conformer_mmi/subsampling.py @@ -25,9 +25,13 @@ class Conv2dSubsampling(nn.Module): assert idim >= 7 super().__init__() self.conv = nn.Sequential( - nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2), + nn.Conv2d( + in_channels=1, out_channels=odim, kernel_size=3, stride=2 + ), nn.ReLU(), - nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2), + nn.Conv2d( + in_channels=odim, out_channels=odim, kernel_size=3, stride=2 + ), nn.ReLU(), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) @@ -111,13 +115,17 @@ class VggSubsampling(nn.Module): ) ) layers.append( - torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True) + torch.nn.MaxPool2d( + kernel_size=2, stride=2, padding=0, ceil_mode=True + ) ) cur_channels = block_dim self.layers = nn.Sequential(*layers) - self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim) + self.out = nn.Linear( + block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim + ) def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. diff --git a/egs/librispeech/ASR/conformer_mmi/test_subsampling.py b/egs/librispeech/ASR/conformer_mmi/test_subsampling.py index d0bb017dd..937845d77 100755 --- a/egs/librispeech/ASR/conformer_mmi/test_subsampling.py +++ b/egs/librispeech/ASR/conformer_mmi/test_subsampling.py @@ -1,7 +1,8 @@ #!/usr/bin/env python3 +from subsampling import Conv2dSubsampling +from subsampling import VggSubsampling import torch -from subsampling import Conv2dSubsampling, VggSubsampling def test_conv2d_subsampling(): diff --git a/egs/librispeech/ASR/conformer_mmi/test_transformer.py b/egs/librispeech/ASR/conformer_mmi/test_transformer.py index 25d18076d..08e680607 100644 --- a/egs/librispeech/ASR/conformer_mmi/test_transformer.py +++ b/egs/librispeech/ASR/conformer_mmi/test_transformer.py @@ -1,16 +1,17 @@ #!/usr/bin/env python3 import torch -from torch.nn.utils.rnn import pad_sequence from transformer import ( Transformer, - add_eos, - add_sos, - decoder_padding_mask, encoder_padding_mask, generate_square_subsequent_mask, + decoder_padding_mask, + add_sos, + add_eos, ) +from torch.nn.utils.rnn import pad_sequence + def test_encoder_padding_mask(): supervisions = { diff --git a/egs/librispeech/ASR/conformer_mmi/train-with-attention.py b/egs/librispeech/ASR/conformer_mmi/train-with-attention.py index f8c94cff9..011dadd73 100755 --- a/egs/librispeech/ASR/conformer_mmi/train-with-attention.py +++ b/egs/librispeech/ASR/conformer_mmi/train-with-attention.py @@ -36,14 +36,23 @@ from torch.nn.utils import clip_grad_norm_ from torch.utils.tensorboard import SummaryWriter from transformer import Noam -from icefall.ali import convert_alignments_to_tensor, load_alignments, lookup_alignments +from icefall.ali import ( + convert_alignments_to_tensor, + load_alignments, + lookup_alignments, +) from icefall.checkpoint import load_checkpoint from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.dist import cleanup_dist, setup_dist from icefall.lexicon import Lexicon from icefall.mmi import LFMMILoss from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler -from icefall.utils import AttributeDict, encode_supervisions, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + encode_supervisions, + setup_logger, + str2bool, +) def get_parser(): @@ -361,7 +370,10 @@ def compute_loss( nnet_output = nnet_output.clone() nnet_output[:, :min_len, :] += ali_scale * mask[:, :min_len, :] - if params.batch_idx_train > params.use_ali_until and params.beam_size < 8: + if ( + params.batch_idx_train > params.use_ali_until + and params.beam_size < 8 + ): # logging.info("Change beam size to 8") params.beam_size = 8 else: @@ -750,14 +762,19 @@ def run(rank, world_size, args): for epoch in range(params.start_epoch, params.num_epochs): train_dl.sampler.set_epoch(epoch) - if params.batch_idx_train >= params.use_ali_until and train_ali is not None: + if ( + params.batch_idx_train >= params.use_ali_until + and train_ali is not None + ): # Delete the alignments to save memory train_ali = None valid_ali = None cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/librispeech/ASR/conformer_mmi/train.py b/egs/librispeech/ASR/conformer_mmi/train.py index 5cfb2bfc7..9a5bdcce2 100755 --- a/egs/librispeech/ASR/conformer_mmi/train.py +++ b/egs/librispeech/ASR/conformer_mmi/train.py @@ -36,14 +36,23 @@ from torch.nn.utils import clip_grad_norm_ from torch.utils.tensorboard import SummaryWriter from transformer import Noam -from icefall.ali import convert_alignments_to_tensor, load_alignments, lookup_alignments +from icefall.ali import ( + convert_alignments_to_tensor, + load_alignments, + lookup_alignments, +) from icefall.checkpoint import load_checkpoint from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.dist import cleanup_dist, setup_dist from icefall.lexicon import Lexicon from icefall.mmi import LFMMILoss from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler -from icefall.utils import AttributeDict, encode_supervisions, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + encode_supervisions, + setup_logger, + str2bool, +) def get_parser(): @@ -368,7 +377,10 @@ def compute_loss( nnet_output = nnet_output.clone() nnet_output[:, :min_len, :] += ali_scale * mask[:, :min_len, :] - if params.batch_idx_train > params.use_ali_until and params.beam_size < 8: + if ( + params.batch_idx_train > params.use_ali_until + and params.beam_size < 8 + ): logging.info("Change beam size to 8") params.beam_size = 8 else: @@ -758,14 +770,19 @@ def run(rank, world_size, args): for epoch in range(params.start_epoch, params.num_epochs): fix_random_seed(params.seed + epoch) train_dl.sampler.set_epoch(epoch) - if params.batch_idx_train >= params.use_ali_until and train_ali is not None: + if ( + params.batch_idx_train >= params.use_ali_until + and train_ali is not None + ): # Delete the alignments to save memory train_ali = None valid_ali = None cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/librispeech/ASR/conformer_mmi/transformer.py b/egs/librispeech/ASR/conformer_mmi/transformer.py index 2542d9abe..68a4ff65c 100644 --- a/egs/librispeech/ASR/conformer_mmi/transformer.py +++ b/egs/librispeech/ASR/conformer_mmi/transformer.py @@ -148,7 +148,9 @@ class Transformer(nn.Module): norm=decoder_norm, ) - self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class) + self.decoder_output_layer = torch.nn.Linear( + d_model, self.decoder_num_class + ) self.decoder_criterion = LabelSmoothingLoss(self.decoder_num_class) else: @@ -180,7 +182,9 @@ class Transformer(nn.Module): x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) x = self.feat_batchnorm(x) x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) - encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision) + encoder_memory, memory_key_padding_mask = self.run_encoder( + x, supervision + ) x = self.ctc_output(encoder_memory) return x, encoder_memory, memory_key_padding_mask @@ -270,7 +274,9 @@ class Transformer(nn.Module): ys_in_pad = ys_in_pad.to(device) ys_out_pad = ys_out_pad.to(device) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( + device + ) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -335,7 +341,9 @@ class Transformer(nn.Module): ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( + device + ) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -608,7 +616,9 @@ def _get_activation_fn(activation: str): elif activation == "gelu": return nn.functional.gelu - raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) + raise RuntimeError( + "activation should be relu/gelu, not {}".format(activation) + ) class PositionalEncoding(nn.Module): @@ -877,7 +887,9 @@ def encoder_padding_mask( 1, ).to(torch.int32) - lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] + lengths = [ + 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) + ] for idx in range(supervision_segments.size(0)): # Note: TorchScript doesn't allow to unpack tensors as tuples sequence_idx = supervision_segments[idx, 0].item() @@ -898,7 +910,9 @@ def encoder_padding_mask( return mask -def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: +def decoder_padding_mask( + ys_pad: torch.Tensor, ignore_id: int = -1 +) -> torch.Tensor: """Generate a length mask for input. The masked position are filled with True, diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py index a1c43f7f5..620d69a19 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py @@ -135,24 +135,20 @@ def get_parser(): "--avg", type=int, default=10, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -219,7 +215,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -287,7 +284,9 @@ def decode_one_batch( value=LOG_EPS, ) - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] if params.decoding_method == "fast_beam_search": @@ -302,7 +301,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -348,7 +350,11 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -421,7 +427,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -454,7 +462,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -497,7 +506,9 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -529,12 +540,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -557,12 +569,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -590,7 +603,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py index 0639ba746..8ca7d5568 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py @@ -35,6 +35,7 @@ from scaling import ( from icefall.utils import make_pad_mask + LOG_EPSILON = math.log(1e-10) @@ -126,7 +127,9 @@ def stack_states( for si, s in enumerate(layer): attn_caches[li][si].append(s) if b == batch_size - 1: - attn_caches[li][si] = torch.stack(attn_caches[li][si], dim=1) + attn_caches[li][si] = torch.stack( + attn_caches[li][si], dim=1 + ) conv_caches = [] for layer in state_list[0][1]: @@ -265,7 +268,9 @@ class ConvolutionModule(nn.Module): intervals = torch.arange( 0, self.chunk_length * (num_chunks - 1), self.chunk_length ) - first = torch.arange(self.chunk_length, self.chunk_length + self.cache_size) + first = torch.arange( + self.chunk_length, self.chunk_length + self.cache_size + ) indexes = intervals.unsqueeze(1) + first.unsqueeze(0) indexes = torch.cat( [indexes, torch.arange(U_ - self.cache_size, U_).unsqueeze(0)] @@ -279,7 +284,9 @@ class ConvolutionModule(nn.Module): # (num_chunks * B, cache_size + right_context_length, D) return pad_right_context.permute(0, 2, 1) - def _merge_right_context(self, right_context: torch.Tensor, B: int) -> torch.Tensor: + def _merge_right_context( + self, right_context: torch.Tensor, B: int + ) -> torch.Tensor: """ Args: right_context: @@ -330,8 +337,12 @@ class ConvolutionModule(nn.Module): right_context = x[:, :, :R] # (B, D, R) # make causal convolution - cache = torch.zeros(B, D, self.cache_size, device=x.device, dtype=x.dtype) - pad_utterance = torch.cat([cache, utterance], dim=2) # (B, D, cache + U) + cache = torch.zeros( + B, D, self.cache_size, device=x.device, dtype=x.dtype + ) + pad_utterance = torch.cat( + [cache, utterance], dim=2 + ) # (B, D, cache + U) # depth-wise conv on utterance utterance = self.depthwise_conv(pad_utterance) # (B, D, U) @@ -344,7 +355,9 @@ class ConvolutionModule(nn.Module): right_context = self.depthwise_conv( pad_right_context ) # (num_segs * B, D, right_context_length) - right_context = self._merge_right_context(right_context, B) # (B, D, R) + right_context = self._merge_right_context( + right_context, B + ) # (B, D, R) x = torch.cat([right_context, utterance], dim=2) # (B, D, R + U) x = self.deriv_balancer2(x) @@ -445,7 +458,8 @@ class EmformerAttention(nn.Module): if embed_dim % nhead != 0: raise ValueError( - f"embed_dim ({embed_dim}) is not a multiple ofnhead ({nhead})." + f"embed_dim ({embed_dim}) is not a multiple of" + f"nhead ({nhead})." ) self.embed_dim = embed_dim @@ -455,7 +469,9 @@ class EmformerAttention(nn.Module): self.head_dim = embed_dim // nhead self.dropout = dropout - self.emb_to_key_value = ScaledLinear(embed_dim, 2 * embed_dim, bias=True) + self.emb_to_key_value = ScaledLinear( + embed_dim, 2 * embed_dim, bias=True + ) self.emb_to_query = ScaledLinear(embed_dim, embed_dim, bias=True) self.out_proj = ScaledLinear( embed_dim, embed_dim, bias=True, initial_scale=0.25 @@ -497,7 +513,9 @@ class EmformerAttention(nn.Module): if padding_mask is not None: Q = attention_weights.size(1) B = attention_weights.size(0) // self.nhead - attention_weights_float = attention_weights_float.view(B, self.nhead, Q, -1) + attention_weights_float = attention_weights_float.view( + B, self.nhead, Q, -1 + ) attention_weights_float = attention_weights_float.masked_fill( padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), self.negative_inf, @@ -533,7 +551,9 @@ class EmformerAttention(nn.Module): scaling = float(self.head_dim) ** -0.5 # compute query with [right_context, utterance, summary]. - query = self.emb_to_query(torch.cat([right_context, utterance, summary])) + query = self.emb_to_query( + torch.cat([right_context, utterance, summary]) + ) # compute key and value with [memory, right_context, utterance]. key, value = self.emb_to_key_value( torch.cat([memory, right_context, utterance]) @@ -544,12 +564,16 @@ class EmformerAttention(nn.Module): # [memory, right context, left context, uttrance] # this is used in inference mode key = torch.cat([key[: M + R], left_context_key, key[M + R :]]) - value = torch.cat([value[: M + R], left_context_val, value[M + R :]]) + value = torch.cat( + [value[: M + R], left_context_val, value[M + R :]] + ) Q = query.size(0) # KV = key.size(0) reshaped_query, reshaped_key, reshaped_value = [ - tensor.contiguous().view(-1, B * self.nhead, self.head_dim).transpose(0, 1) + tensor.contiguous() + .view(-1, B * self.nhead, self.head_dim) + .transpose(0, 1) for tensor in [query, key, value] ] # (B * nhead, Q or KV, head_dim) attention_weights = torch.bmm( @@ -564,7 +588,9 @@ class EmformerAttention(nn.Module): # compute attention outputs attention = torch.bmm(attention_probs, reshaped_value) assert attention.shape == (B * self.nhead, Q, self.head_dim) - attention = attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim) + attention = ( + attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim) + ) # apply output projection outputs = self.out_proj(attention) @@ -646,7 +672,12 @@ class EmformerAttention(nn.Module): - output of right context and utterance, with shape (R + U, B, D). - memory output, with shape (M, B, D), where M = S - 1 or M = 0. """ - (output_right_context_utterance, output_memory, _, _,) = self._forward_impl( + ( + output_right_context_utterance, + output_memory, + _, + _, + ) = self._forward_impl( utterance, right_context, summary, @@ -916,9 +947,13 @@ class EmformerEncoderLayer(nn.Module): right_context = right_context_utterance[:R] if self.use_memory: - summary = self.summary_op(utterance.permute(1, 2, 0)).permute(2, 0, 1) + summary = self.summary_op(utterance.permute(1, 2, 0)).permute( + 2, 0, 1 + ) else: - summary = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device) + summary = torch.empty(0).to( + dtype=utterance.dtype, device=utterance.device + ) output_right_context_utterance, output_memory = self.attention( utterance=utterance, right_context=right_context, @@ -957,10 +992,14 @@ class EmformerEncoderLayer(nn.Module): left_context_val = attn_cache[2] if self.use_memory: - summary = self.summary_op(utterance.permute(1, 2, 0)).permute(2, 0, 1) + summary = self.summary_op(utterance.permute(1, 2, 0)).permute( + 2, 0, 1 + ) summary = summary[:1] else: - summary = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device) + summary = torch.empty(0).to( + dtype=utterance.dtype, device=utterance.device + ) ( output_right_context_utterance, output_memory, @@ -975,7 +1014,9 @@ class EmformerEncoderLayer(nn.Module): left_context_val=left_context_val, padding_mask=padding_mask, ) - attn_cache = self._update_attn_cache(next_key, next_val, memory, attn_cache) + attn_cache = self._update_attn_cache( + next_key, next_val, memory, attn_cache + ) return output_right_context_utterance, output_memory, attn_cache def forward( @@ -1110,7 +1151,11 @@ class EmformerEncoderLayer(nn.Module): src = src + self.dropout(self.feed_forward_macaron(src)) # emformer attention module - (src_att, output_memory, attn_cache,) = self._apply_attention_module_infer( + ( + src_att, + output_memory, + attn_cache, + ) = self._apply_attention_module_infer( src, R, memory, attn_cache, padding_mask=padding_mask ) src = src + self.dropout(src_att) @@ -1250,7 +1295,9 @@ class EmformerEncoder(nn.Module): def _gen_right_context(self, x: torch.Tensor) -> torch.Tensor: """Hard copy each chunk's right context and concat them.""" T = x.shape[0] - num_chunks = math.ceil((T - self.right_context_length) / self.chunk_length) + num_chunks = math.ceil( + (T - self.right_context_length) / self.chunk_length + ) # first (num_chunks - 1) right context block intervals = torch.arange( 0, self.chunk_length * (num_chunks - 1), self.chunk_length @@ -1269,7 +1316,9 @@ class EmformerEncoder(nn.Module): right_context_blocks = x[indexes.reshape(-1)] return right_context_blocks - def _gen_attention_mask_col_widths(self, chunk_idx: int, U: int) -> List[int]: + def _gen_attention_mask_col_widths( + self, chunk_idx: int, U: int + ) -> List[int]: """Calculate column widths (key, value) in attention mask for the chunk_idx chunk.""" num_chunks = math.ceil(U / self.chunk_length) @@ -1430,7 +1479,9 @@ class EmformerEncoder(nn.Module): output_lengths = torch.clamp(lengths - self.right_context_length, min=0) attention_mask = self._gen_attention_mask(utterance) memory = ( - self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[:-1] + self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[ + :-1 + ] if self.use_memory else torch.empty(0).to(dtype=x.dtype, device=x.device) ) @@ -1592,8 +1643,12 @@ class EmformerEncoder(nn.Module): attn_caches = [ [ torch.zeros(self.memory_size, self.d_model, device=device), - torch.zeros(self.left_context_length, self.d_model, device=device), - torch.zeros(self.left_context_length, self.d_model, device=device), + torch.zeros( + self.left_context_length, self.d_model, device=device + ), + torch.zeros( + self.left_context_length, self.d_model, device=device + ), ] for _ in range(self.num_encoder_layers) ] @@ -1638,11 +1693,17 @@ class Emformer(EncoderInterface): raise NotImplementedError( "chunk_length must be a mutiple of subsampling_factor." ) - if left_context_length != 0 and left_context_length % subsampling_factor != 0: + if ( + left_context_length != 0 + and left_context_length % subsampling_factor != 0 + ): raise NotImplementedError( "left_context_length must be 0 or a mutiple of subsampling_factor." # noqa ) - if right_context_length != 0 and right_context_length % subsampling_factor != 0: + if ( + right_context_length != 0 + and right_context_length % subsampling_factor != 0 + ): raise NotImplementedError( "right_context_length must be 0 or a mutiple of subsampling_factor." # noqa ) @@ -1705,7 +1766,9 @@ class Emformer(EncoderInterface): x_lens = (((x_lens - 1) >> 1) - 1) >> 1 assert x.size(0) == x_lens.max().item() - output, output_lengths = self.encoder(x, x_lens, warmup=warmup) # (T, N, C) + output, output_lengths = self.encoder( + x, x_lens, warmup=warmup + ) # (T, N, C) output = output.permute(1, 0, 2) # (T, N, C) -> (N, T, C) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py index 59105e286..4930881ea 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py @@ -103,11 +103,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -138,20 +136,19 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) add_model_arguments(parser) @@ -184,12 +181,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -212,12 +210,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -245,7 +244,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -280,7 +279,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py index c211b215e..9494e1fc1 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py @@ -68,12 +68,14 @@ class Stream(object): elif params.decoding_method == "fast_beam_search": # feature_len is needed to get partial results. # The rnnt_decoding_stream for fast_beam_search. - self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream( - decoding_graph + self.rnnt_decoding_stream: k2.RnntDecodingStream = ( + k2.RnntDecodingStream(decoding_graph) ) self.hyp: Optional[List[int]] = None else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) self.ground_truth: str = "" diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py index abe83732a..61dbe8658 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py @@ -113,9 +113,8 @@ def get_parser(): "--epoch", type=int, default=28, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( @@ -132,24 +131,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=False, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -216,7 +211,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -375,7 +371,9 @@ def modified_beam_search( index=hyps_shape.row_ids(1).to(torch.int64), ) # (num_hyps, encoder_out_dim) - logits = model.joiner(current_encoder_out, decoder_out, project_input=False) + logits = model.joiner( + current_encoder_out, decoder_out, project_input=False + ) # logits is of shape (num_hyps, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) @@ -392,7 +390,9 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -551,10 +551,14 @@ def decode_one_chunk( feature_list, batch_first=True, padding_value=LOG_EPSILON ).to(device) feature_lens = torch.tensor(feature_len_list, device=device) - num_processed_frames = torch.tensor(num_processed_frames_list, device=device) + num_processed_frames = torch.tensor( + num_processed_frames_list, device=device + ) # Make sure it has at least 1 frame after subsampling, first-and-last-frame cutting, and right context cutting # noqa - tail_length = 3 * params.subsampling_factor + params.right_context_length + 3 + tail_length = ( + 3 * params.subsampling_factor + params.right_context_length + 3 + ) if features.size(1) < tail_length: pad_length = tail_length - features.size(1) feature_lens += pad_length @@ -601,7 +605,9 @@ def decode_one_chunk( max_states=params.max_states, ) else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) # Update cached states of each stream state_list = unstack_states(states) @@ -776,7 +782,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -824,7 +831,9 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -858,12 +867,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -886,12 +896,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -919,7 +930,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py index a76417e5f..c07d8f76b 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py @@ -95,7 +95,9 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def add_model_arguments(parser: argparse.ArgumentParser): @@ -263,45 +265,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." - ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -637,7 +636,11 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -665,16 +668,23 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -861,7 +871,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -969,7 +981,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 2 ** 22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py index 9cb4a5afc..98b8290b5 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py @@ -135,24 +135,20 @@ def get_parser(): "--avg", type=int, default=10, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -219,7 +215,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -287,7 +284,9 @@ def decode_one_batch( value=LOG_EPS, ) - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] if params.decoding_method == "fast_beam_search": @@ -302,7 +301,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -348,7 +350,11 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -421,7 +427,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -454,7 +462,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -497,7 +506,9 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -529,12 +540,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -557,12 +569,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -590,7 +603,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py index 09200f2e1..f16f5acc7 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py @@ -35,6 +35,7 @@ from scaling import ( from icefall.utils import make_pad_mask + LOG_EPSILON = math.log(1e-10) @@ -126,7 +127,9 @@ def stack_states( for si, s in enumerate(layer): attn_caches[li][si].append(s) if b == batch_size - 1: - attn_caches[li][si] = torch.stack(attn_caches[li][si], dim=1) + attn_caches[li][si] = torch.stack( + attn_caches[li][si], dim=1 + ) conv_caches = [] for layer in state_list[0][1]: @@ -265,7 +268,9 @@ class ConvolutionModule(nn.Module): intervals = torch.arange( 0, self.chunk_length * (num_chunks - 1), self.chunk_length ) - first = torch.arange(self.chunk_length, self.chunk_length + self.cache_size) + first = torch.arange( + self.chunk_length, self.chunk_length + self.cache_size + ) indexes = intervals.unsqueeze(1) + first.unsqueeze(0) indexes = torch.cat( [indexes, torch.arange(U_ - self.cache_size, U_).unsqueeze(0)] @@ -279,7 +284,9 @@ class ConvolutionModule(nn.Module): # (num_chunks * B, cache_size + right_context_length, D) return pad_right_context.permute(0, 2, 1) - def _merge_right_context(self, right_context: torch.Tensor, B: int) -> torch.Tensor: + def _merge_right_context( + self, right_context: torch.Tensor, B: int + ) -> torch.Tensor: """ Args: right_context: @@ -330,8 +337,12 @@ class ConvolutionModule(nn.Module): right_context = x[:, :, :R] # (B, D, R) # make causal convolution - cache = torch.zeros(B, D, self.cache_size, device=x.device, dtype=x.dtype) - pad_utterance = torch.cat([cache, utterance], dim=2) # (B, D, cache + U) + cache = torch.zeros( + B, D, self.cache_size, device=x.device, dtype=x.dtype + ) + pad_utterance = torch.cat( + [cache, utterance], dim=2 + ) # (B, D, cache + U) # depth-wise conv on utterance utterance = self.depthwise_conv(pad_utterance) # (B, D, U) @@ -344,7 +355,9 @@ class ConvolutionModule(nn.Module): right_context = self.depthwise_conv( pad_right_context ) # (num_segs * B, D, right_context_length) - right_context = self._merge_right_context(right_context, B) # (B, D, R) + right_context = self._merge_right_context( + right_context, B + ) # (B, D, R) x = torch.cat([right_context, utterance], dim=2) # (B, D, R + U) x = self.deriv_balancer2(x) @@ -445,7 +458,8 @@ class EmformerAttention(nn.Module): if embed_dim % nhead != 0: raise ValueError( - f"embed_dim ({embed_dim}) is not a multiple ofnhead ({nhead})." + f"embed_dim ({embed_dim}) is not a multiple of" + f"nhead ({nhead})." ) self.embed_dim = embed_dim @@ -455,7 +469,9 @@ class EmformerAttention(nn.Module): self.head_dim = embed_dim // nhead self.dropout = dropout - self.emb_to_key_value = ScaledLinear(embed_dim, 2 * embed_dim, bias=True) + self.emb_to_key_value = ScaledLinear( + embed_dim, 2 * embed_dim, bias=True + ) self.emb_to_query = ScaledLinear(embed_dim, embed_dim, bias=True) self.out_proj = ScaledLinear( embed_dim, embed_dim, bias=True, initial_scale=0.25 @@ -497,7 +513,9 @@ class EmformerAttention(nn.Module): if padding_mask is not None: Q = attention_weights.size(1) B = attention_weights.size(0) // self.nhead - attention_weights_float = attention_weights_float.view(B, self.nhead, Q, -1) + attention_weights_float = attention_weights_float.view( + B, self.nhead, Q, -1 + ) attention_weights_float = attention_weights_float.masked_fill( padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), self.negative_inf, @@ -543,12 +561,16 @@ class EmformerAttention(nn.Module): # [memory, right context, left context, uttrance] # this is used in inference mode key = torch.cat([key[: M + R], left_context_key, key[M + R :]]) - value = torch.cat([value[: M + R], left_context_val, value[M + R :]]) + value = torch.cat( + [value[: M + R], left_context_val, value[M + R :]] + ) Q = query.size(0) # KV = key.size(0) reshaped_query, reshaped_key, reshaped_value = [ - tensor.contiguous().view(-1, B * self.nhead, self.head_dim).transpose(0, 1) + tensor.contiguous() + .view(-1, B * self.nhead, self.head_dim) + .transpose(0, 1) for tensor in [query, key, value] ] # (B * nhead, Q or KV, head_dim) attention_weights = torch.bmm( @@ -563,7 +585,9 @@ class EmformerAttention(nn.Module): # compute attention outputs attention = torch.bmm(attention_probs, reshaped_value) assert attention.shape == (B * self.nhead, Q, self.head_dim) - attention = attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim) + attention = ( + attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim) + ) # apply output projection output_right_context_utterance = self.out_proj(attention) @@ -881,11 +905,13 @@ class EmformerEncoderLayer(nn.Module): right_context = right_context_utterance[:R] if self.use_memory: - memory = self.summary_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[ - :-1, :, : - ] + memory = self.summary_op(utterance.permute(1, 2, 0)).permute( + 2, 0, 1 + )[:-1, :, :] else: - memory = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device) + memory = torch.empty(0).to( + dtype=utterance.dtype, device=utterance.device + ) output_right_context_utterance = self.attention( utterance=utterance, right_context=right_context, @@ -922,12 +948,18 @@ class EmformerEncoderLayer(nn.Module): left_context_val = attn_cache[2] if self.use_memory: - memory = self.summary_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[ - :1, :, : - ] + memory = self.summary_op(utterance.permute(1, 2, 0)).permute( + 2, 0, 1 + )[:1, :, :] else: - memory = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device) - (output_right_context_utterance, next_key, next_val,) = self.attention.infer( + memory = torch.empty(0).to( + dtype=utterance.dtype, device=utterance.device + ) + ( + output_right_context_utterance, + next_key, + next_val, + ) = self.attention.infer( utterance=utterance, right_context=right_context, memory=pre_memory, @@ -935,7 +967,9 @@ class EmformerEncoderLayer(nn.Module): left_context_val=left_context_val, padding_mask=padding_mask, ) - attn_cache = self._update_attn_cache(next_key, next_val, memory, attn_cache) + attn_cache = self._update_attn_cache( + next_key, next_val, memory, attn_cache + ) return output_right_context_utterance, attn_cache def forward( @@ -1192,7 +1226,9 @@ class EmformerEncoder(nn.Module): def _gen_right_context(self, x: torch.Tensor) -> torch.Tensor: """Hard copy each chunk's right context and concat them.""" T = x.shape[0] - num_chunks = math.ceil((T - self.right_context_length) / self.chunk_length) + num_chunks = math.ceil( + (T - self.right_context_length) / self.chunk_length + ) # first (num_chunks - 1) right context block intervals = torch.arange( 0, self.chunk_length * (num_chunks - 1), self.chunk_length @@ -1211,7 +1247,9 @@ class EmformerEncoder(nn.Module): right_context_blocks = x[indexes.reshape(-1)] return right_context_blocks - def _gen_attention_mask_col_widths(self, chunk_idx: int, U: int) -> List[int]: + def _gen_attention_mask_col_widths( + self, chunk_idx: int, U: int + ) -> List[int]: """Calculate column widths (key, value) in attention mask for the chunk_idx chunk.""" num_chunks = math.ceil(U / self.chunk_length) @@ -1511,8 +1549,12 @@ class EmformerEncoder(nn.Module): attn_caches = [ [ torch.zeros(self.memory_size, self.d_model, device=device), - torch.zeros(self.left_context_length, self.d_model, device=device), - torch.zeros(self.left_context_length, self.d_model, device=device), + torch.zeros( + self.left_context_length, self.d_model, device=device + ), + torch.zeros( + self.left_context_length, self.d_model, device=device + ), ] for _ in range(self.num_encoder_layers) ] @@ -1557,11 +1599,17 @@ class Emformer(EncoderInterface): raise NotImplementedError( "chunk_length must be a mutiple of subsampling_factor." ) - if left_context_length != 0 and left_context_length % subsampling_factor != 0: + if ( + left_context_length != 0 + and left_context_length % subsampling_factor != 0 + ): raise NotImplementedError( "left_context_length must be 0 or a mutiple of subsampling_factor." # noqa ) - if right_context_length != 0 and right_context_length % subsampling_factor != 0: + if ( + right_context_length != 0 + and right_context_length % subsampling_factor != 0 + ): raise NotImplementedError( "right_context_length must be 0 or a mutiple of subsampling_factor." # noqa ) @@ -1624,7 +1672,9 @@ class Emformer(EncoderInterface): x_lens = (((x_lens - 1) >> 1) - 1) >> 1 assert x.size(0) == x_lens.max().item() - output, output_lengths = self.encoder(x, x_lens, warmup=warmup) # (T, N, C) + output, output_lengths = self.encoder( + x, x_lens, warmup=warmup + ) # (T, N, C) output = output.permute(1, 0, 2) # (T, N, C) -> (N, T, C) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py index 4d05b367c..ab15e0241 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py @@ -103,11 +103,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -138,20 +136,19 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) add_model_arguments(parser) @@ -184,12 +181,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -212,12 +210,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -245,7 +244,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -280,7 +279,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py index 0486ac2eb..71150392d 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py @@ -113,9 +113,8 @@ def get_parser(): "--epoch", type=int, default=28, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( @@ -132,24 +131,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=False, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -216,7 +211,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -375,7 +371,9 @@ def modified_beam_search( index=hyps_shape.row_ids(1).to(torch.int64), ) # (num_hyps, encoder_out_dim) - logits = model.joiner(current_encoder_out, decoder_out, project_input=False) + logits = model.joiner( + current_encoder_out, decoder_out, project_input=False + ) # logits is of shape (num_hyps, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) @@ -392,7 +390,9 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -551,10 +551,14 @@ def decode_one_chunk( feature_list, batch_first=True, padding_value=LOG_EPSILON ).to(device) feature_lens = torch.tensor(feature_len_list, device=device) - num_processed_frames = torch.tensor(num_processed_frames_list, device=device) + num_processed_frames = torch.tensor( + num_processed_frames_list, device=device + ) # Make sure it has at least 1 frame after subsampling, first-and-last-frame cutting, and right context cutting # noqa - tail_length = 3 * params.subsampling_factor + params.right_context_length + 3 + tail_length = ( + 3 * params.subsampling_factor + params.right_context_length + 3 + ) if features.size(1) < tail_length: pad_length = tail_length - features.size(1) feature_lens += pad_length @@ -601,7 +605,9 @@ def decode_one_chunk( max_states=params.max_states, ) else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) # Update cached states of each stream state_list = unstack_states(states) @@ -776,7 +782,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -824,7 +831,9 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -858,12 +867,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -886,12 +896,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -919,7 +930,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py index 2c2593b56..2bbc45d78 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py @@ -95,7 +95,9 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def add_model_arguments(parser: argparse.ArgumentParser): @@ -263,45 +265,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." - ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -637,7 +636,11 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -665,16 +668,23 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -861,7 +871,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -969,7 +981,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 2 ** 22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/local/add_alignment_librispeech.py b/egs/librispeech/ASR/local/add_alignment_librispeech.py index cc34a72d8..fe6a26c51 100755 --- a/egs/librispeech/ASR/local/add_alignment_librispeech.py +++ b/egs/librispeech/ASR/local/add_alignment_librispeech.py @@ -157,7 +157,9 @@ def add_alignment( for ali_path in part_ali_dir.rglob("*.alignment.txt"): ali = parse_alignments(ali_path) alignments.update(ali) - logging.info(f"{part} has {len(alignments.keys())} cuts with alignments.") + logging.info( + f"{part} has {len(alignments.keys())} cuts with alignments." + ) # add alignment attribute and write out cuts_in = load_manifest_lazy(cuts_in_path) @@ -168,14 +170,18 @@ def add_alignment( if origin_id in alignments: ali = alignments[origin_id] else: - logging.info(f"Warning: {origin_id} does not have alignment.") + logging.info( + f"Warning: {origin_id} does not have alignment." + ) ali = [] subcut.alignment = {"word": ali} writer.write(cut, flush=True) def main(): - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) parser = get_parser() diff --git a/egs/librispeech/ASR/local/compile_hlg.py b/egs/librispeech/ASR/local/compile_hlg.py index df6c609bb..c628dfd53 100755 --- a/egs/librispeech/ASR/local/compile_hlg.py +++ b/egs/librispeech/ASR/local/compile_hlg.py @@ -57,7 +57,7 @@ def get_args(): return parser.parse_args() -def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa: +def compile_HLG(lang_dir: str, lm: str="G_3_gram") -> k2.Fsa: """ Args: lang_dir: @@ -159,7 +159,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/compile_lg.py b/egs/librispeech/ASR/local/compile_lg.py index 19bf3bff4..45c4b7f5f 100755 --- a/egs/librispeech/ASR/local/compile_lg.py +++ b/egs/librispeech/ASR/local/compile_lg.py @@ -132,7 +132,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/compute_fbank_gigaspeech_dev_test.py b/egs/librispeech/ASR/local/compute_fbank_gigaspeech_dev_test.py index 97750f3ea..c0c7ef8c5 100644 --- a/egs/librispeech/ASR/local/compute_fbank_gigaspeech_dev_test.py +++ b/egs/librispeech/ASR/local/compute_fbank_gigaspeech_dev_test.py @@ -80,7 +80,9 @@ def compute_fbank_gigaspeech_dev_test(): def main(): - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) compute_fbank_gigaspeech_dev_test() diff --git a/egs/librispeech/ASR/local/compute_fbank_gigaspeech_splits.py b/egs/librispeech/ASR/local/compute_fbank_gigaspeech_splits.py index 37fce11f4..5587106e5 100644 --- a/egs/librispeech/ASR/local/compute_fbank_gigaspeech_splits.py +++ b/egs/librispeech/ASR/local/compute_fbank_gigaspeech_splits.py @@ -48,10 +48,8 @@ def get_parser(): "--batch-duration", type=float, default=600.0, - help=( - "The maximum number of audio seconds in a batch." - "Determines batch size dynamically." - ), + help="The maximum number of audio seconds in a batch." + "Determines batch size dynamically.", ) parser.add_argument( @@ -146,7 +144,9 @@ def main(): date_time = now.strftime("%Y-%m-%d-%H-%M-%S") log_filename = "log-compute_fbank_gigaspeech_splits" - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) log_filename = f"{log_filename}-{date_time}" logging.basicConfig( diff --git a/egs/librispeech/ASR/local/compute_fbank_librispeech.py b/egs/librispeech/ASR/local/compute_fbank_librispeech.py index 9f8503814..ce7d087f0 100755 --- a/egs/librispeech/ASR/local/compute_fbank_librispeech.py +++ b/egs/librispeech/ASR/local/compute_fbank_librispeech.py @@ -112,7 +112,9 @@ def compute_fbank_librispeech(bpe_model: Optional[str] = None): if "train" in partition: cut_set = ( - cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + cut_set + + cut_set.perturb_speed(0.9) + + cut_set.perturb_speed(1.1) ) cut_set = cut_set.compute_and_store_features( extractor=extractor, @@ -126,7 +128,9 @@ def compute_fbank_librispeech(bpe_model: Optional[str] = None): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) args = get_args() diff --git a/egs/librispeech/ASR/local/compute_fbank_musan.py b/egs/librispeech/ASR/local/compute_fbank_musan.py index 4a4093ae4..056da29e5 100755 --- a/egs/librispeech/ASR/local/compute_fbank_musan.py +++ b/egs/librispeech/ASR/local/compute_fbank_musan.py @@ -83,7 +83,9 @@ def compute_fbank_musan(): # create chunks of Musan with duration 5 - 10 seconds musan_cuts = ( CutSet.from_manifests( - recordings=combine(part["recordings"] for part in manifests.values()) + recordings=combine( + part["recordings"] for part in manifests.values() + ) ) .cut_into_windows(10.0) .filter(lambda c: c.duration > 5) @@ -99,7 +101,9 @@ def compute_fbank_musan(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) compute_fbank_musan() diff --git a/egs/librispeech/ASR/local/convert_transcript_words_to_tokens.py b/egs/librispeech/ASR/local/convert_transcript_words_to_tokens.py index f149b7871..133499c8b 100755 --- a/egs/librispeech/ASR/local/convert_transcript_words_to_tokens.py +++ b/egs/librispeech/ASR/local/convert_transcript_words_to_tokens.py @@ -46,19 +46,21 @@ def get_args(): parser.add_argument( "--transcript", type=str, - help=( - "The input transcript file." - "We assume that the transcript file consists of " - "lines. Each line consists of space separated words." - ), + help="The input transcript file." + "We assume that the transcript file consists of " + "lines. Each line consists of space separated words.", ) parser.add_argument("--lexicon", type=str, help="The input lexicon file.") - parser.add_argument("--oov", type=str, default="", help="The OOV word.") + parser.add_argument( + "--oov", type=str, default="", help="The OOV word." + ) return parser.parse_args() -def process_line(lexicon: Dict[str, List[str]], line: str, oov_token: str) -> None: +def process_line( + lexicon: Dict[str, List[str]], line: str, oov_token: str +) -> None: """ Args: lexicon: diff --git a/egs/librispeech/ASR/local/download_lm.py b/egs/librispeech/ASR/local/download_lm.py index 3518db524..030122aa7 100755 --- a/egs/librispeech/ASR/local/download_lm.py +++ b/egs/librispeech/ASR/local/download_lm.py @@ -87,7 +87,9 @@ def main(out_dir: str): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/filter_cuts.py b/egs/librispeech/ASR/local/filter_cuts.py index fbcc9e24a..dff98a954 100644 --- a/egs/librispeech/ASR/local/filter_cuts.py +++ b/egs/librispeech/ASR/local/filter_cuts.py @@ -79,7 +79,8 @@ def filter_cuts(cut_set: CutSet, sp: spm.SentencePieceProcessor): total += 1 if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " + f"Duration: {c.duration}" ) removed += 1 return False @@ -124,7 +125,8 @@ def filter_cuts(cut_set: CutSet, sp: spm.SentencePieceProcessor): ans = cut_set.filter(remove_short_and_long_utterances).to_eager() ratio = removed / total * 100 logging.info( - f"Removed {removed} cuts from {total} cuts. {ratio:.3f}% data is removed." + f"Removed {removed} cuts from {total} cuts. " + f"{ratio:.3f}% data is removed." ) return ans @@ -153,7 +155,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/generate_unique_lexicon.py b/egs/librispeech/ASR/local/generate_unique_lexicon.py index 3459c2f5a..566c0743d 100755 --- a/egs/librispeech/ASR/local/generate_unique_lexicon.py +++ b/egs/librispeech/ASR/local/generate_unique_lexicon.py @@ -91,7 +91,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/prepare_lang_bpe.py b/egs/librispeech/ASR/local/prepare_lang_bpe.py index e121aefa9..dec8a7442 100755 --- a/egs/librispeech/ASR/local/prepare_lang_bpe.py +++ b/egs/librispeech/ASR/local/prepare_lang_bpe.py @@ -150,7 +150,9 @@ def generate_lexicon( words_pieces_ids: List[List[int]] = sp.encode(words, out_type=int) # Now convert word piece IDs back to word piece strings. - words_pieces: List[List[str]] = [sp.id_to_piece(ids) for ids in words_pieces_ids] + words_pieces: List[List[str]] = [ + sp.id_to_piece(ids) for ids in words_pieces_ids + ] lexicon = [] for word, pieces in zip(words, words_pieces): diff --git a/egs/librispeech/ASR/local/prepare_lm_training_data.py b/egs/librispeech/ASR/local/prepare_lm_training_data.py index 70343fef7..5070341f1 100755 --- a/egs/librispeech/ASR/local/prepare_lm_training_data.py +++ b/egs/librispeech/ASR/local/prepare_lm_training_data.py @@ -137,7 +137,8 @@ def main(): for i in range(num_sentences): if step and i % step == 0: logging.info( - f"Processed number of lines: {i} ({i/num_sentences*100: .3f}%)" + f"Processed number of lines: {i} " + f"({i/num_sentences*100: .3f}%)" ) word_ids = sentences[i] @@ -153,14 +154,18 @@ def main(): sentence_lengths[i] = token_ids.numel() - output["sentence_lengths"] = torch.tensor(sentence_lengths, dtype=torch.int32) + output["sentence_lengths"] = torch.tensor( + sentence_lengths, dtype=torch.int32 + ) torch.save(output, args.lm_archive) logging.info(f"Saved to {args.lm_archive}") if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/preprocess_gigaspeech.py b/egs/librispeech/ASR/local/preprocess_gigaspeech.py index 8aa5e461d..077f23039 100644 --- a/egs/librispeech/ASR/local/preprocess_gigaspeech.py +++ b/egs/librispeech/ASR/local/preprocess_gigaspeech.py @@ -119,7 +119,9 @@ def preprocess_giga_speech(): def main(): - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) preprocess_giga_speech() diff --git a/egs/librispeech/ASR/local/test_prepare_lang.py b/egs/librispeech/ASR/local/test_prepare_lang.py index 74e025ad7..d4cf62bba 100755 --- a/egs/librispeech/ASR/local/test_prepare_lang.py +++ b/egs/librispeech/ASR/local/test_prepare_lang.py @@ -88,7 +88,9 @@ def test_read_lexicon(filename: str): fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa.draw("L.pdf", title="L") - fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id) + fsa_disambig = lexicon_to_fst( + lexicon_disambig, phone2id=phone2id, word2id=word2id + ) fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt") fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa_disambig.draw("L_disambig.pdf", title="L_disambig") diff --git a/egs/librispeech/ASR/local/validate_manifest.py b/egs/librispeech/ASR/local/validate_manifest.py index 807aaf891..7c57d629a 100755 --- a/egs/librispeech/ASR/local/validate_manifest.py +++ b/egs/librispeech/ASR/local/validate_manifest.py @@ -64,7 +64,8 @@ def validate_supervision_and_cut_time_bounds(c: Cut): if s.end > c.end: raise ValueError( - f"{c.id}: Supervision end time {s.end} is larger than cut end time {c.end}" + f"{c.id}: Supervision end time {s.end} is larger " + f"than cut end time {c.end}" ) @@ -84,7 +85,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless/decode.py old mode 100644 new mode 100755 index e69de29bb..27414d717 --- a/egs/librispeech/ASR/lstm_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/decode.py @@ -0,0 +1,818 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./lstm_transducer_stateless/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./lstm_transducer_stateless/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./lstm_transducer_stateless/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./lstm_transducer_stateless/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./lstm_transducer_stateless/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./lstm_transducer_stateless/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./lstm_transducer_stateless/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./lstm_transducer_stateless/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./lstm_transducer_stateless/decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./lstm_transducer_stateless/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./lstm_transducer_stateless/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./lstm_transducer_stateless/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./lstm_transducer_stateless/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="lstm_transducer_stateless/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + # tail padding here to alleviate the tail deletion problem + num_tail_padded_frames = 35 + feature = torch.nn.functional.pad( + feature, + (0, 0, 0, num_tail_padded_frames), + mode="constant", + value=LOG_EPS, + ) + feature_lens += num_tail_padded_frames + + encoder_out, encoder_out_lens, _ = model.encoder( + x=feature, x_lens=feature_lens + ) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) + else: + decoding_graph = None + word_table = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/export.py b/egs/librispeech/ASR/lstm_transducer_stateless/export.py old mode 100644 new mode 100755 index e69de29bb..13dac6009 --- a/egs/librispeech/ASR/lstm_transducer_stateless/export.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/export.py @@ -0,0 +1,388 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" + +Usage: + +(1) Export to torchscript model using torch.jit.trace() + +./lstm_transducer_stateless/export.py \ + --exp-dir ./lstm_transducer_stateless/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 35 \ + --avg 10 \ + --jit-trace 1 + +It will generate 3 files: `encoder_jit_trace.pt`, +`decoder_jit_trace.pt`, and `joiner_jit_trace.pt`. + +(2) Export `model.state_dict()` + +./lstm_transducer_stateless/export.py \ + --exp-dir ./lstm_transducer_stateless/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 35 \ + --avg 10 + +It will generate a file `pretrained.pt` in the given `exp_dir`. You can later +load it by `icefall.checkpoint.load_checkpoint()`. + +To use the generated file with `lstm_transducer_stateless/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + ./lstm_transducer_stateless/decode.py \ + --exp-dir ./lstm_transducer_stateless/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bpe_500/bpe.model + +Check ./pretrained.py for its usage. + +Note: If you don't want to train a model from scratch, we have +provided one for you. You can get it at + +https://huggingface.co/Zengwei/icefall-asr-librispeech-lstm-transducer-stateless-2022-08-18 + +with the following commands: + + sudo apt-get install git-lfs + git lfs install + git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-lstm-transducer-stateless-2022-08-18 + # You will find the pre-trained model in icefall-asr-librispeech-lstm-transducer-stateless-2022-08-18/exp +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +import torch +import torch.nn as nn +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless3/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--jit-trace", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.trace. + It will generate 3 files: + - encoder_jit_trace.pt + - decoder_jit_trace.pt + - joiner_jit_trace.pt + + Check ./jit_pretrained.py for how to use them. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +def export_encoder_model_jit_trace( + encoder_model: nn.Module, + encoder_filename: str, +) -> None: + """Export the given encoder model with torch.jit.trace() + + Note: The warmup argument is fixed to 1. + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported model. + """ + x = torch.zeros(1, 100, 80, dtype=torch.float32) + x_lens = torch.tensor([100], dtype=torch.int64) + states = encoder_model.get_init_states() + + traced_model = torch.jit.trace(encoder_model, (x, x_lens, states)) + traced_model.save(encoder_filename) + logging.info(f"Saved to {encoder_filename}") + + +def export_decoder_model_jit_trace( + decoder_model: nn.Module, + decoder_filename: str, +) -> None: + """Export the given decoder model with torch.jit.trace() + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The input decoder model + decoder_filename: + The filename to save the exported model. + """ + y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) + need_pad = torch.tensor([False]) + + traced_model = torch.jit.trace(decoder_model, (y, need_pad)) + traced_model.save(decoder_filename) + logging.info(f"Saved to {decoder_filename}") + + +def export_joiner_model_jit_trace( + joiner_model: nn.Module, + joiner_filename: str, +) -> None: + """Export the given joiner model with torch.jit.trace() + + Note: The argument project_input is fixed to True. A user should not + project the encoder_out/decoder_out by himself/herself. The exported joiner + will do that for the user. + + Args: + joiner_model: + The input joiner model + joiner_filename: + The filename to save the exported model. + + """ + encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] + decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + + traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out)) + traced_model.save(joiner_filename) + logging.info(f"Saved to {joiner_filename}") + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + if params.jit_trace is True: + convert_scaled_to_non_scaled(model, inplace=True) + logging.info("Using torch.jit.trace()") + encoder_filename = params.exp_dir / "encoder_jit_trace.pt" + export_encoder_model_jit_trace(model.encoder, encoder_filename) + + decoder_filename = params.exp_dir / "decoder_jit_trace.pt" + export_decoder_model_jit_trace(model.decoder, decoder_filename) + + joiner_filename = params.exp_dir / "joiner_jit_trace.pt" + export_joiner_model_jit_trace(model.joiner, joiner_filename) + else: + logging.info("Not using torchscript") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py old mode 100644 new mode 100755 index e69de29bb..594c33e4f --- a/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py @@ -0,0 +1,322 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models, either exported by `torch.jit.trace()` +or by `torch.jit.script()`, and uses them to decode waves. +You can use the following command to get the exported models: + +./lstm_transducer_stateless/export.py \ + --exp-dir ./lstm_transducer_stateless/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 \ + --jit-trace 1 + +Usage of this script: + +./lstm_transducer_stateless/jit_pretrained.py \ + --encoder-model-filename ./lstm_transducer_stateless/exp/encoder_jit_trace.pt \ + --decoder-model-filename ./lstm_transducer_stateless/exp/decoder_jit_trace.pt \ + --joiner-model-filename ./lstm_transducer_stateless/exp/joiner_jit_trace.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder torchscript model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder torchscript model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner torchscript model. ", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="Context size of the decoder model", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +def greedy_search( + decoder: torch.jit.ScriptModule, + joiner: torch.jit.ScriptModule, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + context_size: int, +) -> List[List[int]]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + decoder: + The decoder model. + joiner: + The joiner model. + encoder_out: + A 3-D tensor of shape (N, T, C) + encoder_out_lens: + A 1-D tensor of shape (N,). + context_size: + The context size of the decoder model. + Returns: + Return the decoded results for each utterance. + """ + assert encoder_out.ndim == 3 + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + device = encoder_out.device + blank_id = 0 # hard-code to 0 + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + hyps = [[blank_id] * context_size for _ in range(N)] + + decoder_input = torch.tensor( + hyps, + device=device, + dtype=torch.int64, + ) # (N, context_size) + + decoder_out = decoder( + decoder_input, + need_pad=torch.tensor([False]), + ).squeeze(1) + + offset = 0 + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = packed_encoder_out.data[start:end] + current_encoder_out = current_encoder_out + # current_encoder_out's shape: (batch_size, encoder_out_dim) + offset = end + + decoder_out = decoder_out[:batch_size] + + logits = joiner( + current_encoder_out, + decoder_out, + ) + # logits'shape (batch_size, vocab_size) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + hyps[i].append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = [h[-context_size:] for h in hyps[:batch_size]] + decoder_input = torch.tensor( + decoder_input, + device=device, + dtype=torch.int64, + ) + decoder_out = decoder( + decoder_input, + need_pad=torch.tensor([False]), + ) + decoder_out = decoder_out.squeeze(1) + + sorted_ans = [h[context_size:] for h in hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + encoder = torch.jit.load(args.encoder_model_filename) + decoder = torch.jit.load(args.decoder_model_filename) + joiner = torch.jit.load(args.joiner_model_filename) + + encoder.eval() + decoder.eval() + joiner.eval() + + encoder.to(device) + decoder.to(device) + joiner.to(device) + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = args.sample_rate + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + expected_sample_rate=args.sample_rate, + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + states = encoder.get_init_states(batch_size=features.size(0), device=device) + + encoder_out, encoder_out_lens, _ = encoder( + x=features, + x_lens=feature_lengths, + states=states, + ) + + hyps = greedy_search( + decoder=decoder, + joiner=joiner, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + context_size=args.context_size, + ) + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = sp.decode(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py b/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py index e69de29bb..c54a4c478 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py @@ -0,0 +1,871 @@ +# Copyright 2022 Xiaomi Corp. (authors: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import math +from typing import List, Optional, Tuple + +import torch +from encoder_interface import EncoderInterface +from scaling import ( + ActivationBalancer, + BasicNorm, + DoubleSwish, + ScaledConv2d, + ScaledLinear, + ScaledLSTM, +) +from torch import nn + +LOG_EPSILON = math.log(1e-10) + + +def unstack_states( + states: Tuple[torch.Tensor, torch.Tensor] +) -> List[Tuple[torch.Tensor, torch.Tensor]]: + """ + Unstack the lstm states corresponding to a batch of utterances into a list + of states, where the i-th entry is the state from the i-th utterance. + + Args: + states: + A tuple of 2 elements. + ``states[0]`` is the lstm hidden states, of a batch of utterance. + ``states[1]`` is the lstm cell states, of a batch of utterances. + + Returns: + A list of states. + ``states[i]`` is a tuple of 2 elememts of i-th utterance. + ``states[i][0]`` is the lstm hidden states of i-th utterance. + ``states[i][1]`` is the lstm cell states of i-th utterance. + """ + hidden_states, cell_states = states + + list_hidden_states = hidden_states.unbind(dim=1) + list_cell_states = cell_states.unbind(dim=1) + + ans = [ + (h.unsqueeze(1), c.unsqueeze(1)) + for (h, c) in zip(list_hidden_states, list_cell_states) + ] + return ans + + +def stack_states( + states_list: List[Tuple[torch.Tensor, torch.Tensor]] +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Stack list of lstm states corresponding to separate utterances into a single + lstm state so that it can be used as an input for lstm when those utterances + are formed into a batch. + + Args: + state_list: + Each element in state_list corresponds to the lstm state for a single + utterance. + ``states[i]`` is a tuple of 2 elememts of i-th utterance. + ``states[i][0]`` is the lstm hidden states of i-th utterance. + ``states[i][1]`` is the lstm cell states of i-th utterance. + + + Returns: + A new state corresponding to a batch of utterances. + It is a tuple of 2 elements. + ``states[0]`` is the lstm hidden states, of a batch of utterance. + ``states[1]`` is the lstm cell states, of a batch of utterances. + """ + hidden_states = torch.cat([s[0] for s in states_list], dim=1) + cell_states = torch.cat([s[1] for s in states_list], dim=1) + ans = (hidden_states, cell_states) + return ans + + +class RNN(EncoderInterface): + """ + Args: + num_features (int): + Number of input features. + subsampling_factor (int): + Subsampling factor of encoder (convolution layers before lstm layers) (default=4). # noqa + d_model (int): + Output dimension (default=512). + dim_feedforward (int): + Feedforward dimension (default=2048). + rnn_hidden_size (int): + Hidden dimension for lstm layers (default=1024). + num_encoder_layers (int): + Number of encoder layers (default=12). + dropout (float): + Dropout rate (default=0.1). + layer_dropout (float): + Dropout value for model-level warmup (default=0.075). + aux_layer_period (int): + Period of auxiliary layers used for random combiner during training. + If set to 0, will not use the random combiner (Default). + You can set a positive integer to use the random combiner, e.g., 3. + is_pnnx: + True to make this class exportable via PNNX. + """ + + def __init__( + self, + num_features: int, + subsampling_factor: int = 4, + d_model: int = 512, + dim_feedforward: int = 2048, + rnn_hidden_size: int = 1024, + num_encoder_layers: int = 12, + dropout: float = 0.1, + layer_dropout: float = 0.075, + aux_layer_period: int = 0, + is_pnnx: bool = False, + ) -> None: + super(RNN, self).__init__() + + self.num_features = num_features + self.subsampling_factor = subsampling_factor + if subsampling_factor != 4: + raise NotImplementedError("Support only 'subsampling_factor=4'.") + + # self.encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, T//subsampling_factor, d_model). + # That is, it does two things simultaneously: + # (1) subsampling: T -> T//subsampling_factor + # (2) embedding: num_features -> d_model + self.encoder_embed = Conv2dSubsampling( + num_features, + d_model, + is_pnnx=is_pnnx, + ) + + self.is_pnnx = is_pnnx + + self.num_encoder_layers = num_encoder_layers + self.d_model = d_model + self.rnn_hidden_size = rnn_hidden_size + + encoder_layer = RNNEncoderLayer( + d_model=d_model, + dim_feedforward=dim_feedforward, + rnn_hidden_size=rnn_hidden_size, + dropout=dropout, + layer_dropout=layer_dropout, + ) + self.encoder = RNNEncoder( + encoder_layer, + num_encoder_layers, + aux_layers=list( + range( + num_encoder_layers // 3, + num_encoder_layers - 1, + aux_layer_period, + ) + ) + if aux_layer_period > 0 + else None, + ) + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + warmup: float = 1.0, + ) -> Tuple[torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Args: + x: + The input tensor. Its shape is (N, T, C), where N is the batch size, + T is the sequence length, C is the feature dimension. + x_lens: + A tensor of shape (N,), containing the number of frames in `x` + before padding. + states: + A tuple of 2 tensors (optional). It is for streaming inference. + states[0] is the hidden states of all layers, + with shape of (num_layers, N, d_model); + states[1] is the cell states of all layers, + with shape of (num_layers, N, rnn_hidden_size). + warmup: + A floating point value that gradually increases from 0 throughout + training; when it is >= 1.0 we are "fully warmed up". It is used + to turn modules on sequentially. + + Returns: + A tuple of 3 tensors: + - embeddings: its shape is (N, T', d_model), where T' is the output + sequence lengths. + - lengths: a tensor of shape (batch_size,) containing the number of + frames in `embeddings` before padding. + - updated states, whose shape is the same as the input states. + """ + x = self.encoder_embed(x) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + # lengths = ((x_lens - 3) // 2 - 1) // 2 # issue an warning + # + # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0 + if not self.is_pnnx: + lengths = (((x_lens - 3) >> 1) - 1) >> 1 + else: + lengths1 = torch.floor((x_lens - 3) / 2) + lengths = torch.floor((lengths1 - 1) / 2) + lengths = lengths.to(x_lens) + + if not torch.jit.is_tracing(): + assert x.size(0) == lengths.max().item() + + if states is None: + x = self.encoder(x, warmup=warmup)[0] + # torch.jit.trace requires returned types to be the same as annotated # noqa + new_states = (torch.empty(0), torch.empty(0)) + else: + assert not self.training + assert len(states) == 2 + if not torch.jit.is_tracing(): + # for hidden state + assert states[0].shape == ( + self.num_encoder_layers, + x.size(1), + self.d_model, + ) + # for cell state + assert states[1].shape == ( + self.num_encoder_layers, + x.size(1), + self.rnn_hidden_size, + ) + x, new_states = self.encoder(x, states) + + x = x.permute(1, 0, 2) # (T, N, C) -> (N, T, C) + return x, lengths, new_states + + @torch.jit.export + def get_init_states( + self, batch_size: int = 1, device: torch.device = torch.device("cpu") + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Get model initial states.""" + # for rnn hidden states + hidden_states = torch.zeros( + (self.num_encoder_layers, batch_size, self.d_model), device=device + ) + cell_states = torch.zeros( + (self.num_encoder_layers, batch_size, self.rnn_hidden_size), + device=device, + ) + return (hidden_states, cell_states) + + +class RNNEncoderLayer(nn.Module): + """ + RNNEncoderLayer is made up of lstm and feedforward networks. + + Args: + d_model: + The number of expected features in the input (required). + dim_feedforward: + The dimension of feedforward network model (default=2048). + rnn_hidden_size: + The hidden dimension of rnn layer. + dropout: + The dropout value (default=0.1). + layer_dropout: + The dropout value for model-level warmup (default=0.075). + """ + + def __init__( + self, + d_model: int, + dim_feedforward: int, + rnn_hidden_size: int, + dropout: float = 0.1, + layer_dropout: float = 0.075, + ) -> None: + super(RNNEncoderLayer, self).__init__() + self.layer_dropout = layer_dropout + self.d_model = d_model + self.rnn_hidden_size = rnn_hidden_size + + assert rnn_hidden_size >= d_model, (rnn_hidden_size, d_model) + self.lstm = ScaledLSTM( + input_size=d_model, + hidden_size=rnn_hidden_size, + proj_size=d_model if rnn_hidden_size > d_model else 0, + num_layers=1, + dropout=0.0, + ) + self.feed_forward = nn.Sequential( + ScaledLinear(d_model, dim_feedforward), + ActivationBalancer(channel_dim=-1), + DoubleSwish(), + nn.Dropout(dropout), + ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), + ) + self.norm_final = BasicNorm(d_model) + + # try to ensure the output is close to zero-mean (or at least, zero-median). # noqa + self.balancer = ActivationBalancer( + channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0 + ) + self.dropout = nn.Dropout(dropout) + + def forward( + self, + src: torch.Tensor, + states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + warmup: float = 1.0, + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Pass the input through the encoder layer. + + Args: + src: + The sequence to the encoder layer (required). + Its shape is (S, N, E), where S is the sequence length, + N is the batch size, and E is the feature number. + states: + A tuple of 2 tensors (optional). It is for streaming inference. + states[0] is the hidden states of all layers, + with shape of (1, N, d_model); + states[1] is the cell states of all layers, + with shape of (1, N, rnn_hidden_size). + warmup: + It controls selective bypass of of layers; if < 1.0, we will + bypass layers more frequently. + """ + src_orig = src + + warmup_scale = min(0.1 + warmup, 1.0) + # alpha = 1.0 means fully use this encoder layer, 0.0 would mean + # completely bypass it. + if self.training: + alpha = ( + warmup_scale + if torch.rand(()).item() <= (1.0 - self.layer_dropout) + else 0.1 + ) + else: + alpha = 1.0 + + # lstm module + if states is None: + src_lstm = self.lstm(src)[0] + # torch.jit.trace requires returned types be the same as annotated + new_states = (torch.empty(0), torch.empty(0)) + else: + assert not self.training + assert len(states) == 2 + if not torch.jit.is_tracing(): + # for hidden state + assert states[0].shape == (1, src.size(1), self.d_model) + # for cell state + assert states[1].shape == (1, src.size(1), self.rnn_hidden_size) + src_lstm, new_states = self.lstm(src, states) + src = self.dropout(src_lstm) + src + + # feed forward module + src = src + self.dropout(self.feed_forward(src)) + + src = self.norm_final(self.balancer(src)) + + if alpha != 1.0: + src = alpha * src + (1 - alpha) * src_orig + + return src, new_states + + +class RNNEncoder(nn.Module): + """ + RNNEncoder is a stack of N encoder layers. + + Args: + encoder_layer: + An instance of the RNNEncoderLayer() class (required). + num_layers: + The number of sub-encoder-layers in the encoder (required). + """ + + def __init__( + self, + encoder_layer: nn.Module, + num_layers: int, + aux_layers: Optional[List[int]] = None, + ) -> None: + super(RNNEncoder, self).__init__() + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for i in range(num_layers)] + ) + self.num_layers = num_layers + self.d_model = encoder_layer.d_model + self.rnn_hidden_size = encoder_layer.rnn_hidden_size + + self.aux_layers: List[int] = [] + self.combiner: Optional[nn.Module] = None + if aux_layers is not None: + assert len(set(aux_layers)) == len(aux_layers) + assert num_layers - 1 not in aux_layers + self.aux_layers = aux_layers + [num_layers - 1] + self.combiner = RandomCombine( + num_inputs=len(self.aux_layers), + final_weight=0.5, + pure_prob=0.333, + stddev=2.0, + ) + + def forward( + self, + src: torch.Tensor, + states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + warmup: float = 1.0, + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Pass the input through the encoder layer in turn. + + Args: + src: + The sequence to the encoder layer (required). + Its shape is (S, N, E), where S is the sequence length, + N is the batch size, and E is the feature number. + states: + A tuple of 2 tensors (optional). It is for streaming inference. + states[0] is the hidden states of all layers, + with shape of (num_layers, N, d_model); + states[1] is the cell states of all layers, + with shape of (num_layers, N, rnn_hidden_size). + warmup: + It controls selective bypass of of layers; if < 1.0, we will + bypass layers more frequently. + """ + if states is not None: + assert not self.training + assert len(states) == 2 + if not torch.jit.is_tracing(): + # for hidden state + assert states[0].shape == ( + self.num_layers, + src.size(1), + self.d_model, + ) + # for cell state + assert states[1].shape == ( + self.num_layers, + src.size(1), + self.rnn_hidden_size, + ) + + output = src + + outputs = [] + + new_hidden_states = [] + new_cell_states = [] + + for i, mod in enumerate(self.layers): + if states is None: + output = mod(output, warmup=warmup)[0] + else: + layer_state = ( + states[0][i : i + 1, :, :], # h: (1, N, d_model) + states[1][i : i + 1, :, :], # c: (1, N, rnn_hidden_size) + ) + output, (h, c) = mod(output, layer_state) + new_hidden_states.append(h) + new_cell_states.append(c) + + if self.combiner is not None and i in self.aux_layers: + outputs.append(output) + + if self.combiner is not None: + output = self.combiner(outputs) + + if states is None: + new_states = (torch.empty(0), torch.empty(0)) + else: + new_states = ( + torch.cat(new_hidden_states, dim=0), + torch.cat(new_cell_states, dim=0), + ) + + return output, new_states + + +class Conv2dSubsampling(nn.Module): + """Convolutional 2D subsampling (to 1/4 length). + + Convert an input of shape (N, T, idim) to an output + with shape (N, T', odim), where + T' = ((T-3)//2-1)//2, which approximates T' == T//4 + + It is based on + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + layer1_channels: int = 8, + layer2_channels: int = 32, + layer3_channels: int = 128, + is_pnnx: bool = False, + ) -> None: + """ + Args: + in_channels: + Number of channels in. The input shape is (N, T, in_channels). + Caution: It requires: T >= 9, in_channels >= 9. + out_channels + Output dim. The output shape is (N, ((T-3)//2-1)//2, out_channels) + layer1_channels: + Number of channels in layer1 + layer1_channels: + Number of channels in layer2 + is_pnnx: + True if we are converting the model to PNNX format. + False otherwise. + """ + assert in_channels >= 9 + super().__init__() + + self.conv = nn.Sequential( + ScaledConv2d( + in_channels=1, + out_channels=layer1_channels, + kernel_size=3, + padding=0, + ), + ActivationBalancer(channel_dim=1), + DoubleSwish(), + ScaledConv2d( + in_channels=layer1_channels, + out_channels=layer2_channels, + kernel_size=3, + stride=2, + ), + ActivationBalancer(channel_dim=1), + DoubleSwish(), + ScaledConv2d( + in_channels=layer2_channels, + out_channels=layer3_channels, + kernel_size=3, + stride=2, + ), + ActivationBalancer(channel_dim=1), + DoubleSwish(), + ) + self.out = ScaledLinear( + layer3_channels * (((in_channels - 3) // 2 - 1) // 2), out_channels + ) + # set learn_eps=False because out_norm is preceded by `out`, and `out` + # itself has learned scale, so the extra degree of freedom is not + # needed. + self.out_norm = BasicNorm(out_channels, learn_eps=False) + # constrain median of output to be close to zero. + self.out_balancer = ActivationBalancer( + channel_dim=-1, min_positive=0.45, max_positive=0.55 + ) + + # ncnn supports only batch size == 1 + self.is_pnnx = is_pnnx + self.conv_out_dim = self.out.weight.shape[1] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Subsample x. + + Args: + x: + Its shape is (N, T, idim). + + Returns: + Return a tensor of shape (N, ((T-3)//2-1)//2, odim) + """ + # On entry, x is (N, T, idim) + x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) + x = self.conv(x) + + if torch.jit.is_tracing() and self.is_pnnx: + x = x.permute(0, 2, 1, 3).reshape(1, -1, self.conv_out_dim) + x = self.out(x) + else: + # Now x is of shape (N, odim, ((T-3)//2-1)//2, ((idim-3)//2-1)//2) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + + # Now x is of shape (N, ((T-3)//2-1))//2, odim) + x = self.out_norm(x) + x = self.out_balancer(x) + return x + + +class RandomCombine(nn.Module): + """ + This module combines a list of Tensors, all with the same shape, to + produce a single output of that same shape which, in training time, + is a random combination of all the inputs; but which in test time + will be just the last input. + + The idea is that the list of Tensors will be a list of outputs of multiple + conformer layers. This has a similar effect as iterated loss. (See: + DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER + NETWORKS). + """ + + def __init__( + self, + num_inputs: int, + final_weight: float = 0.5, + pure_prob: float = 0.5, + stddev: float = 2.0, + ) -> None: + """ + Args: + num_inputs: + The number of tensor inputs, which equals the number of layers' + outputs that are fed into this module. E.g. in an 18-layer neural + net if we output layers 16, 12, 18, num_inputs would be 3. + final_weight: + The amount of weight or probability we assign to the + final layer when randomly choosing layers or when choosing + continuous layer weights. + pure_prob: + The probability, on each frame, with which we choose + only a single layer to output (rather than an interpolation) + stddev: + A standard deviation that we add to log-probs for computing + randomized weights. + + The method of choosing which layers, or combinations of layers, to use, + is conceptually as follows:: + + With probability `pure_prob`:: + With probability `final_weight`: choose final layer, + Else: choose random non-final layer. + Else:: + Choose initial log-weights that correspond to assigning + weight `final_weight` to the final layer and equal + weights to other layers; then add Gaussian noise + with variance `stddev` to these log-weights, and normalize + to weights (note: the average weight assigned to the + final layer here will not be `final_weight` if stddev>0). + """ + super().__init__() + assert 0 <= pure_prob <= 1, pure_prob + assert 0 < final_weight < 1, final_weight + assert num_inputs >= 1 + + self.num_inputs = num_inputs + self.final_weight = final_weight + self.pure_prob = pure_prob + self.stddev = stddev + + self.final_log_weight = ( + torch.tensor( + (final_weight / (1 - final_weight)) * (self.num_inputs - 1) + ) + .log() + .item() + ) + + def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor: + """Forward function. + Args: + inputs: + A list of Tensor, e.g. from various layers of a transformer. + All must be the same shape, of (*, num_channels) + Returns: + A Tensor of shape (*, num_channels). In test mode + this is just the final input. + """ + num_inputs = self.num_inputs + assert len(inputs) == num_inputs + if not self.training or torch.jit.is_scripting(): + return inputs[-1] + + # Shape of weights: (*, num_inputs) + num_channels = inputs[0].shape[-1] + num_frames = inputs[0].numel() // num_channels + + ndim = inputs[0].ndim + # stacked_inputs: (num_frames, num_channels, num_inputs) + stacked_inputs = torch.stack(inputs, dim=ndim).reshape( + (num_frames, num_channels, num_inputs) + ) + + # weights: (num_frames, num_inputs) + weights = self._get_random_weights( + inputs[0].dtype, inputs[0].device, num_frames + ) + + weights = weights.reshape(num_frames, num_inputs, 1) + # ans: (num_frames, num_channels, 1) + ans = torch.matmul(stacked_inputs, weights) + # ans: (*, num_channels) + + ans = ans.reshape(inputs[0].shape[:-1] + (num_channels,)) + + # The following if causes errors for torch script in torch 1.6.0 + # if __name__ == "__main__": + # # for testing only... + # print("Weights = ", weights.reshape(num_frames, num_inputs)) + return ans + + def _get_random_weights( + self, dtype: torch.dtype, device: torch.device, num_frames: int + ) -> torch.Tensor: + """Return a tensor of random weights, of shape + `(num_frames, self.num_inputs)`, + Args: + dtype: + The data-type desired for the answer, e.g. float, double. + device: + The device needed for the answer. + num_frames: + The number of sets of weights desired + Returns: + A tensor of shape (num_frames, self.num_inputs), such that + `ans.sum(dim=1)` is all ones. + """ + pure_prob = self.pure_prob + if pure_prob == 0.0: + return self._get_random_mixed_weights(dtype, device, num_frames) + elif pure_prob == 1.0: + return self._get_random_pure_weights(dtype, device, num_frames) + else: + p = self._get_random_pure_weights(dtype, device, num_frames) + m = self._get_random_mixed_weights(dtype, device, num_frames) + return torch.where( + torch.rand(num_frames, 1, device=device) < self.pure_prob, p, m + ) + + def _get_random_pure_weights( + self, dtype: torch.dtype, device: torch.device, num_frames: int + ): + """Return a tensor of random one-hot weights, of shape + `(num_frames, self.num_inputs)`, + Args: + dtype: + The data-type desired for the answer, e.g. float, double. + device: + The device needed for the answer. + num_frames: + The number of sets of weights desired. + Returns: + A one-hot tensor of shape `(num_frames, self.num_inputs)`, with + exactly one weight equal to 1.0 on each frame. + """ + final_prob = self.final_weight + + # final contains self.num_inputs - 1 in all elements + final = torch.full((num_frames,), self.num_inputs - 1, device=device) + # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights. # noqa + nonfinal = torch.randint( + self.num_inputs - 1, (num_frames,), device=device + ) + + indexes = torch.where( + torch.rand(num_frames, device=device) < final_prob, final, nonfinal + ) + ans = torch.nn.functional.one_hot( + indexes, num_classes=self.num_inputs + ).to(dtype=dtype) + return ans + + def _get_random_mixed_weights( + self, dtype: torch.dtype, device: torch.device, num_frames: int + ): + """Return a tensor of random one-hot weights, of shape + `(num_frames, self.num_inputs)`, + Args: + dtype: + The data-type desired for the answer, e.g. float, double. + device: + The device needed for the answer. + num_frames: + The number of sets of weights desired. + Returns: + A tensor of shape (num_frames, self.num_inputs), which elements + in [0..1] that sum to one over the second axis, i.e. + `ans.sum(dim=1)` is all ones. + """ + logprobs = ( + torch.randn(num_frames, self.num_inputs, dtype=dtype, device=device) + * self.stddev # noqa + ) + logprobs[:, -1] += self.final_log_weight + return logprobs.softmax(dim=1) + + +def _test_random_combine(final_weight: float, pure_prob: float, stddev: float): + print( + f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}, stddev={stddev}" # noqa + ) + num_inputs = 3 + num_channels = 50 + m = RandomCombine( + num_inputs=num_inputs, + final_weight=final_weight, + pure_prob=pure_prob, + stddev=stddev, + ) + + x = [torch.ones(3, 4, num_channels) for _ in range(num_inputs)] + + y = m(x) + assert y.shape == x[0].shape + assert torch.allclose(y, x[0]) # .. since actually all ones. + + +def _test_random_combine_main(): + _test_random_combine(0.999, 0, 0.0) + _test_random_combine(0.5, 0, 0.0) + _test_random_combine(0.999, 0, 0.0) + _test_random_combine(0.5, 0, 0.3) + _test_random_combine(0.5, 1, 0.3) + _test_random_combine(0.5, 0.5, 0.3) + + feature_dim = 50 + c = RNN(num_features=feature_dim, d_model=128) + batch_size = 5 + seq_len = 20 + # Just make sure the forward pass runs. + f = c( + torch.randn(batch_size, seq_len, feature_dim), + torch.full((batch_size,), seq_len, dtype=torch.int64), + ) + f # to remove flake8 warnings + + +if __name__ == "__main__": + feature_dim = 80 + m = RNN( + num_features=feature_dim, + d_model=512, + rnn_hidden_size=1024, + dim_feedforward=2048, + num_encoder_layers=12, + ) + batch_size = 5 + seq_len = 20 + # Just make sure the forward pass runs. + f = m( + torch.randn(batch_size, seq_len, feature_dim), + torch.full((batch_size,), seq_len, dtype=torch.int64), + warmup=0.5, + ) + num_param = sum([p.numel() for p in m.parameters()]) + print(f"Number of model parameters: {num_param}") + + _test_random_combine_main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/model.py b/egs/librispeech/ASR/lstm_transducer_stateless/model.py index e69de29bb..d71132b4a 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless/model.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/model.py @@ -0,0 +1,210 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Tuple + +import k2 +import torch +import torch.nn as nn +from encoder_interface import EncoderInterface +from scaling import ScaledLinear + +from icefall.utils import add_sos + + +class Transducer(nn.Module): + """It implements https://arxiv.org/pdf/1211.3711.pdf + "Sequence Transduction with Recurrent Neural Networks" + """ + + def __init__( + self, + encoder: EncoderInterface, + decoder: nn.Module, + joiner: nn.Module, + encoder_dim: int, + decoder_dim: int, + joiner_dim: int, + vocab_size: int, + ): + """ + Args: + encoder: + It is the transcription network in the paper. Its accepts + two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). + It returns two tensors: `logits` of shape (N, T, encoder_dm) and + `logit_lens` of shape (N,). + decoder: + It is the prediction network in the paper. Its input shape + is (N, U) and its output shape is (N, U, decoder_dim). + It should contain one attribute: `blank_id`. + joiner: + It has two inputs with shapes: (N, T, encoder_dim) and + (N, U, decoder_dim). + Its output shape is (N, T, U, vocab_size). Note that its output + contains unnormalized probs, i.e., not processed by log-softmax. + """ + super().__init__() + assert isinstance(encoder, EncoderInterface), type(encoder) + assert hasattr(decoder, "blank_id") + + self.encoder = encoder + self.decoder = decoder + self.joiner = joiner + + self.simple_am_proj = ScaledLinear( + encoder_dim, vocab_size, initial_speed=0.5 + ) + self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + y: k2.RaggedTensor, + prune_range: int = 5, + am_scale: float = 0.0, + lm_scale: float = 0.0, + warmup: float = 1.0, + reduction: str = "sum", + delay_penalty: float = 0.0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + y: + A ragged tensor with 2 axes [utt][label]. It contains labels of each + utterance. + prune_range: + The prune range for rnnt loss, it means how many symbols(context) + we are considering for each frame to compute the loss. + am_scale: + The scale to smooth the loss with am (output of encoder network) + part + lm_scale: + The scale to smooth the loss with lm (output of predictor network) + part + warmup: + A value warmup >= 0 that determines which modules are active, values + warmup > 1 "are fully warmed up" and all modules will be active. + reduction: + "sum" to sum the losses over all utterances in the batch. + "none" to return the loss in a 1-D tensor for each utterance + in the batch. + delay_penalty: + A constant value used to penalize symbol delay, to encourage + streaming models to emit symbols earlier. + See https://github.com/k2-fsa/k2/issues/955 and + https://arxiv.org/pdf/2211.00490.pdf for more details. + Returns: + Return the transducer loss. + + Note: + Regarding am_scale & lm_scale, it will make the loss-function one of + the form: + lm_scale * lm_probs + am_scale * am_probs + + (1-lm_scale-am_scale) * combined_probs + """ + assert reduction in ("sum", "none"), reduction + assert x.ndim == 3, x.shape + assert x_lens.ndim == 1, x_lens.shape + assert y.num_axes == 2, y.num_axes + + assert x.size(0) == x_lens.size(0) == y.dim0 + + encoder_out, x_lens, _ = self.encoder(x, x_lens, warmup=warmup) + assert torch.all(x_lens > 0) + + # Now for the decoder, i.e., the prediction network + row_splits = y.shape.row_splits(1) + y_lens = row_splits[1:] - row_splits[:-1] + + blank_id = self.decoder.blank_id + sos_y = add_sos(y, sos_id=blank_id) + + # sos_y_padded: [B, S + 1], start with SOS. + sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) + + # decoder_out: [B, S + 1, decoder_dim] + decoder_out = self.decoder(sos_y_padded) + + # Note: y does not start with SOS + # y_padded : [B, S] + y_padded = y.pad(mode="constant", padding_value=0) + + y_padded = y_padded.to(torch.int64) + boundary = torch.zeros( + (x.size(0), 4), dtype=torch.int64, device=x.device + ) + boundary[:, 2] = y_lens + boundary[:, 3] = x_lens + + lm = self.simple_lm_proj(decoder_out) + am = self.simple_am_proj(encoder_out) + + with torch.cuda.amp.autocast(enabled=False): + simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( + lm=lm.float(), + am=am.float(), + symbols=y_padded, + termination_symbol=blank_id, + lm_only_scale=lm_scale, + am_only_scale=am_scale, + boundary=boundary, + reduction=reduction, + delay_penalty=delay_penalty, + return_grad=True, + ) + + # ranges : [B, T, prune_range] + ranges = k2.get_rnnt_prune_ranges( + px_grad=px_grad, + py_grad=py_grad, + boundary=boundary, + s_range=prune_range, + ) + + # am_pruned : [B, T, prune_range, encoder_dim] + # lm_pruned : [B, T, prune_range, decoder_dim] + am_pruned, lm_pruned = k2.do_rnnt_pruning( + am=self.joiner.encoder_proj(encoder_out), + lm=self.joiner.decoder_proj(decoder_out), + ranges=ranges, + ) + + # logits : [B, T, prune_range, vocab_size] + + # project_input=False since we applied the decoder's input projections + # prior to do_rnnt_pruning (this is an optimization for speed). + logits = self.joiner(am_pruned, lm_pruned, project_input=False) + + with torch.cuda.amp.autocast(enabled=False): + pruned_loss = k2.rnnt_loss_pruned( + logits=logits.float(), + symbols=y_padded, + ranges=ranges, + termination_symbol=blank_id, + boundary=boundary, + delay_penalty=delay_penalty, + reduction=reduction, + ) + + return (simple_loss, pruned_loss) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py old mode 100644 new mode 100755 index e69de29bb..2a6e2adc6 --- a/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py @@ -0,0 +1,352 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +(1) greedy search +./lstm_transducer_stateless/pretrained.py \ + --checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) beam search +./lstm_transducer_stateless/pretrained.py \ + --checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) modified beam search +./lstm_transducer_stateless/pretrained.py \ + --checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method modified_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(4) fast beam search +./lstm_transducer_stateless/pretrained.py \ + --checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method fast_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +You can also use `./lstm_transducer_stateless/exp/epoch-xx.pt`. + +Note: ./lstm_transducer_stateless/exp/pretrained.pt is generated by +./lstm_transducer_stateless/export.py +""" + + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from beam_search import ( + beam_search, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "--method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. Used only when + --method is greedy_search. + """, + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + + params.update(vars(args)) + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + model.device = device + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens, _ = model.encoder( + x=features, x_lens=feature_lengths + ) + + num_waves = encoder_out.size(0) + hyps = [] + msg = f"Using {params.method}" + if params.method == "beam_search": + msg += f" with beam size {params.beam_size}" + logging.info(msg) + + if params.method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + for i in range(num_waves): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError(f"Unsupported method: {params.method}") + + hyps.append(sp.decode(hyp).split()) + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/stream.py b/egs/librispeech/ASR/lstm_transducer_stateless/stream.py index e69de29bb..97d890c82 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless/stream.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/stream.py @@ -0,0 +1,148 @@ +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import List, Optional, Tuple + +import k2 +import torch +from beam_search import Hypothesis, HypothesisList + +from icefall.utils import AttributeDict + + +class Stream(object): + def __init__( + self, + params: AttributeDict, + cut_id: str, + decoding_graph: Optional[k2.Fsa] = None, + device: torch.device = torch.device("cpu"), + LOG_EPS: float = math.log(1e-10), + ) -> None: + """ + Args: + params: + It's the return value of :func:`get_params`. + cut_id: + The cut id of the current stream. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + device: + The device to run this stream. + LOG_EPS: + A float value used for padding. + """ + self.LOG_EPS = LOG_EPS + self.cut_id = cut_id + + # Containing attention caches and convolution caches + self.states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None + + # It uses different attributes for different decoding methods. + self.context_size = params.context_size + self.decoding_method = params.decoding_method + if params.decoding_method == "greedy_search": + self.hyp = [params.blank_id] * params.context_size + elif params.decoding_method == "modified_beam_search": + self.hyps = HypothesisList() + self.hyps.add( + Hypothesis( + ys=[params.blank_id] * params.context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + ) + ) + elif params.decoding_method == "fast_beam_search": + # feature_len is needed to get partial results. + # The rnnt_decoding_stream for fast_beam_search. + self.rnnt_decoding_stream: k2.RnntDecodingStream = ( + k2.RnntDecodingStream(decoding_graph) + ) + self.hyp: Optional[List[int]] = None + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + + self.ground_truth: str = "" + + self.feature: Optional[torch.Tensor] = None + # Make sure all feature frames can be used. + # We aim to obtain 1 frame after subsampling. + self.chunk_length = params.subsampling_factor + self.pad_length = 5 + self.num_frames = 0 + self.num_processed_frames = 0 + + # After all feature frames are processed, we set this flag to True + self._done = False + + def set_feature(self, feature: torch.Tensor) -> None: + assert feature.dim() == 2, feature.dim() + # tail padding here to alleviate the tail deletion problem + num_tail_padded_frames = 35 + self.num_frames = feature.size(0) + num_tail_padded_frames + self.feature = torch.nn.functional.pad( + feature, + (0, 0, 0, self.pad_length + num_tail_padded_frames), + mode="constant", + value=self.LOG_EPS, + ) + + def get_feature_chunk(self) -> torch.Tensor: + """Get a chunk of feature frames. + + Returns: + A tensor of shape (ret_length, feature_dim). + """ + update_length = min( + self.num_frames - self.num_processed_frames, self.chunk_length + ) + ret_length = update_length + self.pad_length + + ret_feature = self.feature[ + self.num_processed_frames : self.num_processed_frames + ret_length + ] + # Cut off used frames. + # self.feature = self.feature[update_length:] + + self.num_processed_frames += update_length + if self.num_processed_frames >= self.num_frames: + self._done = True + + return ret_feature + + @property + def id(self) -> str: + return self.cut_id + + @property + def done(self) -> bool: + """Return True if all feature frames are processed.""" + return self._done + + def decoding_result(self) -> List[int]: + """Obtain current decoding result.""" + if self.decoding_method == "greedy_search": + return self.hyp[self.context_size :] + elif self.decoding_method == "modified_beam_search": + best_hyp = self.hyps.get_most_probable(length_norm=True) + return best_hyp.ys[self.context_size :] + else: + assert self.decoding_method == "fast_beam_search" + return self.hyp diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py old mode 100644 new mode 100755 index e69de29bb..d6376bdc0 --- a/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py @@ -0,0 +1,968 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./lstm_transducer_stateless/streaming_decode.py \ + --epoch 35 \ + --avg 10 \ + --exp-dir lstm_transducer_stateless/exp \ + --num-decode-streams 2000 \ + --num-encoder-layers 12 \ + --rnn-hidden-size 1024 \ + --decoding-method greedy_search \ + --use-averaged-model True + +(2) modified beam search +./lstm_transducer_stateless/streaming_decode.py \ + --epoch 35 \ + --avg 10 \ + --exp-dir lstm_transducer_stateless/exp \ + --num-decode-streams 2000 \ + --num-encoder-layers 12 \ + --rnn-hidden-size 1024 \ + --decoding-method modified_beam_search \ + --use-averaged-model True \ + --beam-size 4 + +(3) fast beam search +./lstm_transducer_stateless/streaming_decode.py \ + --epoch 35 \ + --avg 10 \ + --exp-dir lstm_transducer_stateless/exp \ + --num-decode-streams 2000 \ + --num-encoder-layers 12 \ + --rnn-hidden-size 1024 \ + --decoding-method fast_beam_search \ + --use-averaged-model True \ + --beam 4 \ + --max-contexts 4 \ + --max-states 8 +""" +import argparse +import logging +import warnings +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from beam_search import Hypothesis, HypothesisList, get_hyps_shape +from kaldifeat import Fbank, FbankOptions +from lhotse import CutSet +from lstm import LOG_EPSILON, stack_states, unstack_states +from stream import Stream +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.decode import one_best_decoding +from icefall.utils import ( + AttributeDict, + get_texts, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=False, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="transducer_emformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An interger indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--sampling-rate", + type=float, + default=16000, + help="Sample rate of the audio", + ) + + parser.add_argument( + "--num-decode-streams", + type=int, + default=2000, + help="The number of streams that can be decoded in parallel", + ) + + add_model_arguments(parser) + + return parser + + +def greedy_search( + model: nn.Module, + encoder_out: torch.Tensor, + streams: List[Stream], +) -> None: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C), where N >= 1. + streams: + A list of Stream objects. + """ + assert len(streams) == encoder_out.size(0) + assert encoder_out.ndim == 3 + + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + device = next(model.parameters()).device + T = encoder_out.size(1) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + decoder_input = torch.tensor( + [stream.hyp[-context_size:] for stream in streams], + device=device, + dtype=torch.int64, + ) + # decoder_out is of shape (batch_size, 1, decoder_out_dim) + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + + for t in range(T): + # current_encoder_out's shape: (batch_size, 1, encoder_out_dim) + current_encoder_out = encoder_out[:, t : t + 1, :] # noqa + + logits = model.joiner( + current_encoder_out.unsqueeze(2), + decoder_out.unsqueeze(1), + project_input=False, + ) + # logits'shape (batch_size, vocab_size) + logits = logits.squeeze(1).squeeze(1) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + streams[i].hyp.append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = torch.tensor( + [stream.hyp[-context_size:] for stream in streams], + device=device, + dtype=torch.int64, + ) + decoder_out = model.decoder( + decoder_input, + need_pad=False, + ) + decoder_out = model.joiner.decoder_proj(decoder_out) + + +def modified_beam_search( + model: nn.Module, + encoder_out: torch.Tensor, + streams: List[Stream], + beam: int = 4, +): + """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. + + Args: + model: + The RNN-T model. + encoder_out: + A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of + the encoder model. + streams: + A list of stream objects. + beam: + Number of active paths during the beam search. + """ + assert encoder_out.ndim == 3, encoder_out.shape + assert len(streams) == encoder_out.size(0) + + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + device = next(model.parameters()).device + batch_size = len(streams) + T = encoder_out.size(1) + + B = [stream.hyps for stream in streams] + + encoder_out = model.joiner.encoder_proj(encoder_out) + + for t in range(T): + current_encoder_out = encoder_out[:, t].unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim) + + hyps_shape = get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.stack( + [hyp.log_prob.reshape(1) for hyps in A for hyp in hyps], dim=0 + ) # (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_out is of shape (num_hyps, 1, 1, decoder_output_dim) + + # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor + # as index, so we use `to(torch.int64)` below. + current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, decoder_out, project_input=False + ) + # logits is of shape (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) + + log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) + + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_ys = hyp.ys[:] + new_token = topk_token_indexes[k] + if new_token != blank_id: + new_ys.append(new_token) + + new_log_prob = topk_log_probs[k] + new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) + B[i].add(new_hyp) + + for i in range(batch_size): + streams[i].hyps = B[i] + + +def fast_beam_search_one_best( + model: nn.Module, + streams: List[Stream], + encoder_out: torch.Tensor, + processed_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, +) -> None: + """It limits the maximum number of symbols per frame to 1. + + A lattice is first obtained using modified beam search, and then + the shortest path within the lattice is used as the final output. + + Args: + model: + An instance of `Transducer`. + streams: + A list of stream objects. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + processed_lens: + A tensor of shape (N,) containing the number of processed frames + in `encoder_out` before padding. + beam: + Beam value, similar to the beam used in Kaldi.. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + """ + assert encoder_out.ndim == 3 + + context_size = model.decoder.context_size + vocab_size = model.decoder.vocab_size + + B, T, C = encoder_out.shape + assert B == len(streams) + + config = k2.RnntDecodingConfig( + vocab_size=vocab_size, + decoder_history_len=context_size, + beam=beam, + max_contexts=max_contexts, + max_states=max_states, + ) + individual_streams = [] + for i in range(B): + individual_streams.append(streams[i].rnnt_decoding_stream) + decoding_streams = k2.RnntDecodingStreams(individual_streams, config) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + for t in range(T): + # shape is a RaggedShape of shape (B, context) + # contexts is a Tensor of shape (shape.NumElements(), context_size) + shape, contexts = decoding_streams.get_contexts() + # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64 + contexts = contexts.to(torch.int64) + # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim) + decoder_out = model.decoder(contexts, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + # current_encoder_out is of shape + # (shape.NumElements(), 1, joiner_dim) + # fmt: off + current_encoder_out = torch.index_select( + encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64) + ) + # fmt: on + logits = model.joiner( + current_encoder_out.unsqueeze(2), + decoder_out.unsqueeze(1), + project_input=False, + ) + logits = logits.squeeze(1).squeeze(1) + log_probs = logits.log_softmax(dim=-1) + decoding_streams.advance(log_probs) + + decoding_streams.terminate_and_flush_to_streams() + + lattice = decoding_streams.format_output(processed_lens.tolist()) + + best_path = one_best_decoding(lattice) + hyps = get_texts(best_path) + + for i in range(B): + streams[i].hyp = hyps[i] + + +def decode_one_chunk( + model: nn.Module, + streams: List[Stream], + params: AttributeDict, + decoding_graph: Optional[k2.Fsa] = None, +) -> List[int]: + """ + Args: + model: + The Transducer model. + streams: + A list of Stream objects. + params: + It is returned by :func:`get_params`. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search. + + Returns: + A list of indexes indicating the finished streams. + """ + device = next(model.parameters()).device + + feature_list = [] + feature_len_list = [] + state_list = [] + num_processed_frames_list = [] + + for stream in streams: + # We should first get `stream.num_processed_frames` + # before calling `stream.get_feature_chunk()` + # since `stream.num_processed_frames` would be updated + num_processed_frames_list.append(stream.num_processed_frames) + feature = stream.get_feature_chunk() + feature_len = feature.size(0) + feature_list.append(feature) + feature_len_list.append(feature_len) + state_list.append(stream.states) + + features = pad_sequence( + feature_list, batch_first=True, padding_value=LOG_EPSILON + ).to(device) + feature_lens = torch.tensor(feature_len_list, device=device) + num_processed_frames = torch.tensor( + num_processed_frames_list, device=device + ) + + # Make sure it has at least 1 frame after subsampling + tail_length = params.subsampling_factor + 5 + if features.size(1) < tail_length: + pad_length = tail_length - features.size(1) + feature_lens += pad_length + features = torch.nn.functional.pad( + features, + (0, 0, 0, pad_length), + mode="constant", + value=LOG_EPSILON, + ) + + # Stack states of all streams + states = stack_states(state_list) + + encoder_out, encoder_out_lens, states = model.encoder( + x=features, + x_lens=feature_lens, + states=states, + ) + + if params.decoding_method == "greedy_search": + greedy_search( + model=model, + streams=streams, + encoder_out=encoder_out, + ) + elif params.decoding_method == "modified_beam_search": + modified_beam_search( + model=model, + streams=streams, + encoder_out=encoder_out, + beam=params.beam_size, + ) + elif params.decoding_method == "fast_beam_search": + # feature_len is needed to get partial results. + # The rnnt_decoding_stream for fast_beam_search. + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + processed_lens = ( + num_processed_frames // params.subsampling_factor + + encoder_out_lens + ) + fast_beam_search_one_best( + model=model, + streams=streams, + encoder_out=encoder_out, + processed_lens=processed_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + + # Update cached states of each stream + state_list = unstack_states(states) + for i, s in enumerate(state_list): + streams[i].states = s + + finished_streams = [i for i, stream in enumerate(streams) if stream.done] + return finished_streams + + +def create_streaming_feature_extractor() -> Fbank: + """Create a CPU streaming feature extractor. + + At present, we assume it returns a fbank feature extractor with + fixed options. In the future, we will support passing in the options + from outside. + + Returns: + Return a CPU streaming feature extractor. + """ + opts = FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + return Fbank(opts) + + +def decode_dataset( + cuts: CutSet, + model: nn.Module, + params: AttributeDict, + sp: spm.SentencePieceProcessor, + decoding_graph: Optional[k2.Fsa] = None, +): + """Decode dataset. + + Args: + cuts: + Lhotse Cutset containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The Transducer model. + sp: + The BPE model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search. + + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + device = next(model.parameters()).device + + log_interval = 300 + + fbank = create_streaming_feature_extractor() + + decode_results = [] + streams = [] + for num, cut in enumerate(cuts): + # Each utterance has a Stream. + stream = Stream( + params=params, + cut_id=cut.id, + decoding_graph=decoding_graph, + device=device, + LOG_EPS=LOG_EPSILON, + ) + + stream.states = model.encoder.get_init_states(device=device) + + audio: np.ndarray = cut.load_audio() + # audio.shape: (1, num_samples) + assert len(audio.shape) == 2 + assert audio.shape[0] == 1, "Should be single channel" + assert audio.dtype == np.float32, audio.dtype + # The trained model is using normalized samples + assert audio.max() <= 1, "Should be normalized to [-1, 1])" + + samples = torch.from_numpy(audio).squeeze(0) + feature = fbank(samples) + stream.set_feature(feature) + stream.ground_truth = cut.supervisions[0].text + + streams.append(stream) + + while len(streams) >= params.num_decode_streams: + finished_streams = decode_one_chunk( + model=model, + streams=streams, + params=params, + decoding_graph=decoding_graph, + ) + + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + streams[i].id, + streams[i].ground_truth.split(), + sp.decode(streams[i].decoding_result()).split(), + ) + ) + del streams[i] + + if num % log_interval == 0: + logging.info(f"Cuts processed until now is {num}.") + + while len(streams) > 0: + finished_streams = decode_one_chunk( + model=model, + streams=streams, + params=params, + decoding_graph=decoding_graph, + ) + + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + streams[i].id, + streams[i].ground_truth.split(), + sp.decode(streams[i].decoding_result()).split(), + ) + ) + del streams[i] + + if params.decoding_method == "greedy_search": + key = "greedy_search" + elif params.decoding_method == "fast_beam_search": + key = ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ) + else: + key = f"beam_size_{params.beam_size}" + + return {key: decode_results} + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + store_transcripts(filename=recog_path, texts=sorted(results)) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "fast_beam_search", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / "streaming" / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + elif "beam_search" in params.decoding_method: + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-streaming-decode") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + params.device = device + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.eval() + + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_sets = ["test-clean", "test-other"] + test_cuts = [test_clean_cuts, test_other_cuts] + + for test_set, test_cut in zip(test_sets, test_cuts): + results_dict = decode_dataset( + cuts=test_cut, + model=model, + params=params, + sp=sp, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + torch.manual_seed(20220810) + main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/train.py b/egs/librispeech/ASR/lstm_transducer_stateless/train.py old mode 100644 new mode 100755 index e69de29bb..d30fc260a --- a/egs/librispeech/ASR/lstm_transducer_stateless/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/train.py @@ -0,0 +1,1157 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./lstm_transducer_stateless/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir lstm_transducer_stateless/exp \ + --full-libri 1 \ + --max-duration 300 + +# For mix precision training: + +./lstm_transducer_stateless/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir lstm_transducer_stateless/exp \ + --full-libri 1 \ + --max-duration 550 +""" + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from lstm import RNN +from model import Transducer +from optim import Eden, Eve +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.utils import ( + AttributeDict, + MetricsTracker, + display_and_save_batch, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=int, + default=12, + help="Number of RNN encoder layers..", + ) + + parser.add_argument( + "--encoder-dim", + type=int, + default=512, + help="Encoder output dimesion.", + ) + + parser.add_argument( + "--rnn-hidden-size", + type=int, + default=1024, + help="Hidden dim for LSTM layers.", + ) + + parser.add_argument( + "--aux-layer-period", + type=int, + default=0, + help="""Peroid of auxiliary layers used for randomly combined during training. + If set to 0, will not use the random combiner (Default). + You can set a positive integer to use the random combiner, e.g., 3. + """, + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=35, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="lstm_transducer_stateless/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--initial-lr", + type=float, + default=0.003, + help="""The initial learning rate. This value should not need to be + changed.""", + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate decreases. + We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=10, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" + "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=20, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=100, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--delay-penalty", + type=float, + default=0.0, + help="""A constant value used to penalize symbol delay, + to encourage streaming models to emit symbols earlier. + See https://github.com/k2-fsa/k2/issues/955 and + https://arxiv.org/pdf/2211.00490.pdf for more details.""", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warm_step for Noam optimizer. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for conformer + "feature_dim": 80, + "subsampling_factor": 4, + "dim_feedforward": 2048, + # parameters for decoder + "decoder_dim": 512, + # parameters for joiner + "joiner_dim": 512, + # parameters for Noam + "model_warm_step": 3000, # arg given to model, not for lrate + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = RNN( + num_features=params.feature_dim, + subsampling_factor=params.subsampling_factor, + d_model=params.encoder_dim, + rnn_hidden_size=params.rnn_hidden_size, + dim_feedforward=params.dim_feedforward, + num_encoder_layers=params.num_encoder_layers, + aux_layer_period=params.aux_layer_period, + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=params.encoder_dim, + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=params.encoder_dim, + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, + warmup: float = 1.0, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute RNN-T loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Conformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + warmup=warmup, + reduction="none", + delay_penalty=params.delay_penalty if warmup >= 2.0 else 0, + ) + simple_loss_is_finite = torch.isfinite(simple_loss) + pruned_loss_is_finite = torch.isfinite(pruned_loss) + is_finite = simple_loss_is_finite & pruned_loss_is_finite + if not torch.all(is_finite): + logging.info( + "Not all losses are finite!\n" + f"simple_loss: {simple_loss}\n" + f"pruned_loss: {pruned_loss}" + ) + display_and_save_batch(batch, params=params, sp=sp) + simple_loss = simple_loss[simple_loss_is_finite] + pruned_loss = pruned_loss[pruned_loss_is_finite] + + # If either all simple_loss or pruned_loss is inf or nan, + # we stop the training process by raising an exception + if torch.all(~simple_loss_is_finite) or torch.all( + ~pruned_loss_is_finite + ): + raise ValueError( + "There are too many utterances in this batch " + "leading to inf or nan losses." + ) + + simple_loss = simple_loss.sum() + pruned_loss = pruned_loss.sum() + # after the main warmup step, we keep pruned_loss_scale small + # for the same amount of time (model_warm_step), to avoid + # overwhelming the simple_loss and causing it to diverge, + # in case it had not fully learned the alignment yet. + pruned_loss_scale = ( + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss + ) + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + # info["frames"] is an approximate number for two reasons: + # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 + # (2) If some utterances in the batch lead to inf/nan loss, they + # are filtered out. + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) + + # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa + info["utterances"] = feature.size(0) + # averaged input duration in frames over utterances + info["utt_duration"] = feature_lens.sum().item() + # averaged padding proportion over utterances + info["utt_pad_proportion"] = ( + ((feature.size(1) - feature_lens) / feature.size(1)).sum().item() + ) + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + warmup=(params.batch_idx_train / params.model_warm_step), + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 30: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}" + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) + + if batch_idx > 0 and batch_idx % params.valid_interval == 0: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + if params.full_libri is False: + params.valid_interval = 800 + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank]) + + optimizer = Eve(model.parameters(), lr=params.initial_lr) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + # # overwrite it + # scheduler.base_lrs = [params.initial_lr for _ in scheduler.base_lrs] + # print(scheduler.base_lrs) + + if params.print_diagnostics: + diagnostic = diagnostics.attach_diagnostics(model) + + librispeech = LibriSpeechAsrDataModule(args) + + train_cuts = librispeech.train_clean_100_cuts() + if params.full_libri: + train_cuts += librispeech.train_clean_360_cuts() + train_cuts += librispeech.train_other_500_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./lstm.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 3) // 2 - 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + warmup=0.0 if params.start_epoch == 1 else 1.0, + ) + + scaler = GradScaler(enabled=params.use_fp16) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, + warmup: float, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + warmup=warmup, + ) + loss.backward() + optimizer.step() + optimizer.zero_grad() + except RuntimeError as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + raise + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py index f7e1b5a54..bad4e243e 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py @@ -185,24 +185,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -299,7 +295,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -477,7 +474,9 @@ def decode_one_batch( ) feature_lens += num_tail_padded_frames - encoder_out, encoder_out_lens, _ = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens, _ = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] @@ -536,7 +535,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -698,7 +700,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -731,7 +735,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -784,7 +789,9 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -819,12 +826,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -852,12 +860,13 @@ def main(): ) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -886,7 +895,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -952,7 +961,9 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export.py index 0ad00cda3..190673638 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export.py @@ -146,24 +146,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -229,7 +225,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) add_model_arguments(parser) @@ -345,7 +342,9 @@ def export_encoder_model_onnx( x = torch.zeros(N, 9, 80, dtype=torch.float32) x_lens = torch.tensor([9], dtype=torch.int64) h = torch.rand(encoder_model.num_encoder_layers, N, encoder_model.d_model) - c = torch.rand(encoder_model.num_encoder_layers, N, encoder_model.rnn_hidden_size) + c = torch.rand( + encoder_model.num_encoder_layers, N, encoder_model.rnn_hidden_size + ) warmup = 1.0 torch.onnx.export( @@ -446,9 +445,13 @@ def export_joiner_model_onnx( - projected_decoder_out: a tensor of shape (N, joiner_dim) """ - encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx") + encoder_proj_filename = str(joiner_filename).replace( + ".onnx", "_encoder_proj.onnx" + ) - decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx") + decoder_proj_filename = str(joiner_filename).replace( + ".onnx", "_decoder_proj.onnx" + ) encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] @@ -547,12 +550,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -581,12 +585,13 @@ def main(): ) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -615,7 +620,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -689,7 +694,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py index 5a8efd718..da184b76f 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py @@ -86,12 +86,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -126,9 +124,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -316,7 +315,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/model.py b/egs/librispeech/ASR/lstm_transducer_stateless2/model.py index 4957d14b1..fadeb4ac2 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/model.py @@ -84,7 +84,9 @@ class Transducer(nn.Module): self.decoder_giga = decoder_giga self.joiner_giga = joiner_giga - self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) + self.simple_am_proj = ScaledLinear( + encoder_dim, vocab_size, initial_speed=0.5 + ) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) if decoder_giga is not None: @@ -188,7 +190,9 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) + boundary = torch.zeros( + (x.size(0), 4), dtype=torch.int64, device=x.device + ) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py index 3b471fa85..410de8d3d 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py @@ -156,7 +156,9 @@ class Model: assert ret == 0, ret encoder_out = torch.from_numpy(ncnn_out0.numpy()).clone() - encoder_out_lens = torch.from_numpy(ncnn_out1.numpy()).to(torch.int32) + encoder_out_lens = torch.from_numpy(ncnn_out1.numpy()).to( + torch.int32 + ) hx = torch.from_numpy(ncnn_out2.numpy()).clone() cx = torch.from_numpy(ncnn_out3.numpy()).clone() return encoder_out, encoder_out_lens, hx, cx @@ -198,9 +200,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -283,7 +286,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py index 7d931a286..bef0ad760 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py @@ -92,11 +92,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -121,12 +119,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -173,7 +169,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -204,9 +201,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -269,11 +267,15 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens, _ = model.encoder(x=features, x_lens=feature_lengths) + encoder_out, encoder_out_lens, _ = model.encoder( + x=features, x_lens=feature_lengths + ) num_waves = encoder_out.size(0) hyps = [] @@ -345,7 +347,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py index baff15ea6..e47a05a9e 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py @@ -144,7 +144,9 @@ class Model: assert ret == 0, ret encoder_out = torch.from_numpy(ncnn_out0.numpy()).clone() - encoder_out_lens = torch.from_numpy(ncnn_out1.numpy()).to(torch.int32) + encoder_out_lens = torch.from_numpy(ncnn_out1.numpy()).to( + torch.int32 + ) hx = torch.from_numpy(ncnn_out2.numpy()).clone() cx = torch.from_numpy(ncnn_out3.numpy()).clone() return encoder_out, encoder_out_lens, hx, cx @@ -186,9 +188,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -226,7 +229,9 @@ def greedy_search( if decoder_out is None: assert hyp is None, hyp hyp = [blank_id] * context_size - decoder_input = torch.tensor(hyp, dtype=torch.int32) # (1, context_size) + decoder_input = torch.tensor( + hyp, dtype=torch.int32 + ) # (1, context_size) decoder_out = model.run_decoder(decoder_input).squeeze(0) else: assert decoder_out.ndim == 1 @@ -305,7 +310,9 @@ def main(): frames.append(online_fbank.get_frame(num_processed_frames + i)) num_processed_frames += offset frames = torch.cat(frames, dim=0) - encoder_out, encoder_out_lens, hx, cx = model.run_encoder(frames, states) + encoder_out, encoder_out_lens, hx, cx = model.run_encoder( + frames, states + ) states = (hx, cx) hyp, decoder_out = greedy_search( model, encoder_out.squeeze(0), decoder_out, hyp @@ -321,7 +328,9 @@ def main(): frames.append(online_fbank.get_frame(num_processed_frames + i)) num_processed_frames += offset frames = torch.cat(frames, dim=0) - encoder_out, encoder_out_lens, hx, cx = model.run_encoder(frames, states) + encoder_out, encoder_out_lens, hx, cx = model.run_encoder( + frames, states + ) states = (hx, cx) hyp, decoder_out = greedy_search( model, encoder_out.squeeze(0), decoder_out, hyp @@ -334,7 +343,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py index b31fefa0a..232d3dd18 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py @@ -109,12 +109,10 @@ def get_args(): parser.add_argument( "sound_filename", type=str, - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -149,9 +147,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -200,7 +199,9 @@ class Model: sess_options=self.session_opts, ) - def run_encoder(self, x, h0, c0) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def run_encoder( + self, x, h0, c0 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Args: x: @@ -257,7 +258,9 @@ class Model: }, )[0] - return self.run_joiner_decoder_proj(torch.from_numpy(decoder_out).squeeze(1)) + return self.run_joiner_decoder_proj( + torch.from_numpy(decoder_out).squeeze(1) + ) def run_joiner( self, @@ -300,7 +303,11 @@ class Model: projected_encoder_out = self.joiner_encoder_proj.run( [self.joiner_encoder_proj.get_outputs()[0].name], - {self.joiner_encoder_proj.get_inputs()[0].name: encoder_out.numpy()}, + { + self.joiner_encoder_proj.get_inputs()[ + 0 + ].name: encoder_out.numpy() + }, )[0] return torch.from_numpy(projected_encoder_out) @@ -319,7 +326,11 @@ class Model: projected_decoder_out = self.joiner_decoder_proj.run( [self.joiner_decoder_proj.get_outputs()[0].name], - {self.joiner_decoder_proj.get_inputs()[0].name: decoder_out.numpy()}, + { + self.joiner_decoder_proj.get_inputs()[ + 0 + ].name: decoder_out.numpy() + }, )[0] return torch.from_numpy(projected_decoder_out) @@ -358,7 +369,9 @@ def greedy_search( if decoder_out is None: assert hyp is None, hyp hyp = [blank_id] * context_size - decoder_input = torch.tensor([hyp], dtype=torch.int64) # (1, context_size) + decoder_input = torch.tensor( + [hyp], dtype=torch.int64 + ) # (1, context_size) decoder_out = model.run_decoder(decoder_input) else: assert decoder_out.shape[0] == 1 @@ -461,7 +474,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py index 08a895a75..5eaaf321f 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py @@ -95,7 +95,9 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def add_model_arguments(parser: argparse.ArgumentParser): @@ -161,7 +163,8 @@ def get_parser(): "--full-libri", type=str2bool, default=True, - help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.", + help="When enabled, use 960h LibriSpeech. " + "Otherwise, use 100h subset.", ) parser.add_argument( @@ -235,45 +238,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." - ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -645,7 +645,11 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -688,7 +692,9 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): + if torch.all(~simple_loss_is_finite) or torch.all( + ~pruned_loss_is_finite + ): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -701,9 +707,14 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -714,7 +725,9 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -945,7 +958,9 @@ def train_one_epoch( f"train/current_{prefix}_", params.batch_idx_train, ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) libri_tot_loss.write_summary( tb_writer, "train/libri_tot_", params.batch_idx_train ) @@ -991,7 +1006,8 @@ def filter_short_and_long_utterances( # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " + f"Duration: {c.duration}" ) return False @@ -1139,7 +1155,9 @@ def run(rank, world_size, args): train_giga_cuts = train_giga_cuts.repeat(times=None) if args.enable_musan: - cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz") + cuts_musan = load_manifest( + Path(args.manifest_dir) / "musan_cuts.jsonl.gz" + ) else: cuts_musan = None diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py index a8d5605fb..9eee19379 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py @@ -182,24 +182,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -294,7 +290,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -389,7 +386,9 @@ def decode_one_batch( ) feature_lens += num_tail_padded_frames - encoder_out, encoder_out_lens, _ = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens, _ = model.encoder( + x=feature, x_lens=feature_lens + ) if params.decoding_method == "fast_beam_search": res = fast_beam_search_one_best( @@ -442,7 +441,10 @@ def decode_one_batch( nbest_scale=params.nbest_scale, return_timestamps=True, ) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): res = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -520,7 +522,9 @@ def decode_dataset( sp: spm.SentencePieceProcessor, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str], List[float], List[float]]]]: +) -> Dict[ + str, List[Tuple[str, List[str], List[str], List[float], List[float]]] +]: """Decode dataset. Args: @@ -595,7 +599,9 @@ def decode_dataset( cut_ids, hyps, texts, timestamps_hyp, timestamps_ref ): ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words, time_ref, time_hyp)) + this_batch.append( + (cut_id, ref_words, hyp_words, time_ref, time_hyp) + ) results[name].extend(this_batch) @@ -604,7 +610,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -642,7 +650,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -669,7 +678,9 @@ def save_results( note = "" logging.info(s) - s = "\nFor {}, symbol-delay of different settings are:\n".format(test_set_name) + s = "\nFor {}, symbol-delay of different settings are:\n".format( + test_set_name + ) note = "\tbest for {}".format(test_set_name) for key, val in test_set_delays: s += "{}\tmean: {}s, variance: {}{}\n".format(key, val[0], val[1], note) @@ -713,7 +724,9 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -745,12 +758,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -773,12 +787,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -806,7 +821,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -833,7 +848,9 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/export.py b/egs/librispeech/ASR/lstm_transducer_stateless3/export.py index 51238f768..212c7bad6 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/export.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/export.py @@ -122,24 +122,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -176,7 +172,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) add_model_arguments(parser) @@ -284,12 +281,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -312,12 +310,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -345,7 +344,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -381,7 +380,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py index 180ba8c72..a3443cf0a 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py @@ -85,12 +85,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -125,9 +123,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -315,7 +314,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py b/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py index 6e51b85e4..90bc351f4 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py @@ -661,7 +661,9 @@ class RandomCombine(nn.Module): self.stddev = stddev self.final_log_weight = ( - torch.tensor((final_weight / (1 - final_weight)) * (self.num_inputs - 1)) + torch.tensor( + (final_weight / (1 - final_weight)) * (self.num_inputs - 1) + ) .log() .item() ) @@ -758,14 +760,16 @@ class RandomCombine(nn.Module): # final contains self.num_inputs - 1 in all elements final = torch.full((num_frames,), self.num_inputs - 1, device=device) # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights. # noqa - nonfinal = torch.randint(self.num_inputs - 1, (num_frames,), device=device) + nonfinal = torch.randint( + self.num_inputs - 1, (num_frames,), device=device + ) indexes = torch.where( torch.rand(num_frames, device=device) < final_prob, final, nonfinal ) - ans = torch.nn.functional.one_hot(indexes, num_classes=self.num_inputs).to( - dtype=dtype - ) + ans = torch.nn.functional.one_hot( + indexes, num_classes=self.num_inputs + ).to(dtype=dtype) return ans def _get_random_mixed_weights( diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py index 4f8049245..0e48fef04 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py @@ -89,11 +89,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -118,12 +116,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -170,7 +166,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -201,9 +198,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -266,11 +264,15 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens, _ = model.encoder(x=features, x_lens=feature_lengths) + encoder_out, encoder_out_lens, _ = model.encoder( + x=features, x_lens=feature_lengths + ) num_waves = encoder_out.size(0) hyps = [] @@ -342,7 +344,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py b/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py index 4e9063a40..cfa918ed5 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py @@ -101,9 +101,8 @@ def get_parser(): "--epoch", type=int, default=40, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( @@ -120,24 +119,20 @@ def get_parser(): "--avg", type=int, default=20, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=False, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -204,7 +199,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -363,7 +359,9 @@ def modified_beam_search( index=hyps_shape.row_ids(1).to(torch.int64), ) # (num_hyps, encoder_out_dim) - logits = model.joiner(current_encoder_out, decoder_out, project_input=False) + logits = model.joiner( + current_encoder_out, decoder_out, project_input=False + ) # logits is of shape (num_hyps, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) @@ -380,7 +378,9 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -539,7 +539,9 @@ def decode_one_chunk( feature_list, batch_first=True, padding_value=LOG_EPSILON ).to(device) feature_lens = torch.tensor(feature_len_list, device=device) - num_processed_frames = torch.tensor(num_processed_frames_list, device=device) + num_processed_frames = torch.tensor( + num_processed_frames_list, device=device + ) # Make sure it has at least 1 frame after subsampling tail_length = params.subsampling_factor + 5 @@ -581,7 +583,8 @@ def decode_one_chunk( with warnings.catch_warnings(): warnings.simplefilter("ignore") processed_lens = ( - num_processed_frames // params.subsampling_factor + encoder_out_lens + num_processed_frames // params.subsampling_factor + + encoder_out_lens ) fast_beam_search_one_best( model=model, @@ -593,7 +596,9 @@ def decode_one_chunk( max_states=params.max_states, ) else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) # Update cached states of each stream state_list = unstack_states(states) @@ -768,7 +773,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -810,7 +816,9 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -844,12 +852,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -872,12 +881,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -905,7 +915,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py index a1d19fb73..60a5a2be7 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py @@ -87,7 +87,9 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def add_model_arguments(parser: argparse.ArgumentParser): @@ -230,45 +232,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." - ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -607,7 +606,11 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -647,7 +650,9 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): + if torch.all(~simple_loss_is_finite) or torch.all( + ~pruned_loss_is_finite + ): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -660,9 +665,14 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -673,7 +683,9 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -840,7 +852,10 @@ def train_one_epoch( rank=rank, ) - if batch_idx % params.log_interval == 0 and not params.print_diagnostics: + if ( + batch_idx % params.log_interval == 0 + and not params.print_diagnostics + ): cur_lr = scheduler.get_last_lr()[0] logging.info( f"Epoch {params.cur_epoch}, " @@ -857,7 +872,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if ( batch_idx > 0 @@ -992,7 +1009,8 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " + f"Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py b/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py index fd2a5354a..8dd1459ca 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py +++ b/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py @@ -74,18 +74,17 @@ class LibriSpeechAsrDataModule: def add_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group( title="ASR data related options", - description=( - "These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc." - ), + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", ) group.add_argument( "--full-libri", type=str2bool, default=True, - help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.", + help="When enabled, use 960h LibriSpeech. " + "Otherwise, use 100h subset.", ) group.add_argument( "--manifest-dir", @@ -97,91 +96,75 @@ class LibriSpeechAsrDataModule: "--max-duration", type=int, default=200.0, - help=( - "Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM." - ), + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", ) group.add_argument( "--bucketing-sampler", type=str2bool, default=True, - help=( - "When enabled, the batches will come from buckets of " - "similar duration (saves padding frames)." - ), + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", ) group.add_argument( "--num-buckets", type=int, default=30, - help=( - "The number of buckets for the BucketingSampler" - "(you might want to increase it for larger datasets)." - ), + help="The number of buckets for the BucketingSampler" + "(you might want to increase it for larger datasets).", ) group.add_argument( "--concatenate-cuts", type=str2bool, default=False, - help=( - "When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding." - ), + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", ) group.add_argument( "--duration-factor", type=float, default=1.0, - help=( - "Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch." - ), + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", ) group.add_argument( "--gap", type=float, default=1.0, - help=( - "The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used." - ), + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", ) group.add_argument( "--on-the-fly-feats", type=str2bool, default=False, - help=( - "When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available." - ), + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", ) group.add_argument( "--shuffle", type=str2bool, default=True, - help=( - "When enabled (=default), the examples will be shuffled for each epoch." - ), + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", ) group.add_argument( "--return-cuts", type=str2bool, default=True, - help=( - "When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it." - ), + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that collect the batches.", + help="The number of training dataloader workers that " + "collect the batches.", ) group.add_argument( @@ -195,22 +178,18 @@ class LibriSpeechAsrDataModule: "--spec-aug-time-warp-factor", type=int, default=80, - help=( - "Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp." - ), + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", ) group.add_argument( "--enable-musan", type=str2bool, default=True, - help=( - "When enabled, select noise from MUSAN and mix it" - "with training dataset. " - ), + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", ) def train_dataloaders( @@ -229,16 +208,20 @@ class LibriSpeechAsrDataModule: if self.args.enable_musan: logging.info("Enable MUSAN") logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "cuts_musan.json.gz") + cuts_musan = load_manifest( + self.args.manifest_dir / "cuts_musan.json.gz" + ) transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix( + cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True + ) ) else: logging.info("Disable MUSAN") if self.args.concatenate_cuts: logging.info( - "Using cut concatenation with duration factor " + f"Using cut concatenation with duration factor " f"{self.args.duration_factor} and gap {self.args.gap}." ) # Cut concatenation should be the first transform in the list, @@ -253,7 +236,9 @@ class LibriSpeechAsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + logging.info( + f"Time warp factor: {self.args.spec_aug_time_warp_factor}" + ) # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -296,7 +281,9 @@ class LibriSpeechAsrDataModule: # Drop feats to be on the safe side. train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -353,7 +340,9 @@ class LibriSpeechAsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), return_cuts=self.args.return_cuts, ) else: @@ -400,17 +389,23 @@ class LibriSpeechAsrDataModule: @lru_cache() def train_clean_100_cuts(self) -> CutSet: logging.info("About to get train-clean-100 cuts") - return load_manifest(self.args.manifest_dir / "cuts_train-clean-100.json.gz") + return load_manifest( + self.args.manifest_dir / "cuts_train-clean-100.json.gz" + ) @lru_cache() def train_clean_360_cuts(self) -> CutSet: logging.info("About to get train-clean-360 cuts") - return load_manifest(self.args.manifest_dir / "cuts_train-clean-360.json.gz") + return load_manifest( + self.args.manifest_dir / "cuts_train-clean-360.json.gz" + ) @lru_cache() def train_other_500_cuts(self) -> CutSet: logging.info("About to get train-other-500 cuts") - return load_manifest(self.args.manifest_dir / "cuts_train-other-500.json.gz") + return load_manifest( + self.args.manifest_dir / "cuts_train-other-500.json.gz" + ) @lru_cache() def dev_clean_cuts(self) -> CutSet: diff --git a/egs/librispeech/ASR/pruned2_knowledge/beam_search.py b/egs/librispeech/ASR/pruned2_knowledge/beam_search.py index 785a8f097..2e9bf3e0b 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/beam_search.py +++ b/egs/librispeech/ASR/pruned2_knowledge/beam_search.py @@ -172,9 +172,9 @@ def greedy_search( y = logits.argmax().item() if y != blank_id: hyp.append(y) - decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape( - 1, context_size - ) + decoder_input = torch.tensor( + [hyp[-context_size:]], device=device + ).reshape(1, context_size) decoder_out = model.decoder(decoder_input, need_pad=False) decoder_out = model.joiner.decoder_proj(decoder_out) @@ -302,7 +302,9 @@ class HypothesisList(object): key = hyp.key if key in self: old_hyp = self._data[key] # shallow copy - torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob) + torch.logaddexp( + old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob + ) else: self._data[key] = hyp @@ -318,7 +320,9 @@ class HypothesisList(object): Return the hypothesis that has the largest `log_prob`. """ if length_norm: - return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)) + return max( + self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys) + ) else: return max(self._data.values(), key=lambda hyp: hyp.log_prob) @@ -492,7 +496,9 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) diff --git a/egs/librispeech/ASR/pruned2_knowledge/conformer.py b/egs/librispeech/ASR/pruned2_knowledge/conformer.py index 3b6d0549d..295a35204 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/conformer.py +++ b/egs/librispeech/ASR/pruned2_knowledge/conformer.py @@ -18,10 +18,10 @@ import math import warnings from typing import Optional, Tuple +from sampling import create_knowledge_base, KnowledgeBaseLookup import torch from encoder_interface import EncoderInterface -from sampling import KnowledgeBaseLookup, create_knowledge_base from scaling import ( ActivationBalancer, BasicNorm, @@ -73,9 +73,9 @@ class Conformer(EncoderInterface): if subsampling_factor != 4: raise NotImplementedError("Support only 'subsampling_factor=4'.") - self.knowledge_base = create_knowledge_base( - knowledge_M, knowledge_N, knowledge_D - ) + + self.knowledge_base = create_knowledge_base(knowledge_M, knowledge_N, + knowledge_D) # self.encoder_embed converts the input of shape (N, T, num_features) # to the shape (N, T//subsampling_factor, d_model). @@ -89,7 +89,7 @@ class Conformer(EncoderInterface): # Pass in a lambda that creates a new ConformerEncoderLayer with these # args. Don't use deepcopy because we need the knowledge_base # to be shared. - encoder_layer_fn = lambda: ConformerEncoderLayer( # noqa: E731 + encoder_layer_fn = lambda: ConformerEncoderLayer( self.knowledge_base, d_model, nhead, @@ -100,7 +100,7 @@ class Conformer(EncoderInterface): knowledge_M, knowledge_N, knowledge_D, - knowledge_K, + knowledge_K ) self.encoder = ConformerEncoder(encoder_layer_fn, num_encoder_layers) @@ -187,7 +187,9 @@ class ConformerEncoderLayer(nn.Module): self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) + self.self_attn = RelPositionMultiheadAttention( + d_model, nhead, dropout=0.0 + ) self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), @@ -207,14 +209,10 @@ class ConformerEncoderLayer(nn.Module): self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) - self.lookup = KnowledgeBaseLookup( - knowledge_M, - knowledge_N, - knowledge_D, - knowledge_K, - d_model, - knowledge_base, - ) + self.lookup = KnowledgeBaseLookup(knowledge_M, knowledge_N, + knowledge_D, knowledge_K, + d_model, + knowledge_base) self.norm_final = BasicNorm(d_model) @@ -313,7 +311,9 @@ class ConformerEncoder(nn.Module): def __init__(self, encoder_layer_fn, num_layers: int) -> None: super().__init__() - self.layers = nn.ModuleList([encoder_layer_fn() for i in range(num_layers)]) + self.layers = nn.ModuleList( + [encoder_layer_fn() for i in range(num_layers)] + ) self.num_layers = num_layers def forward( @@ -367,7 +367,9 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: + def __init__( + self, d_model: int, dropout_rate: float, max_len: int = 5000 + ) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -382,7 +384,9 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): + if self.pe.dtype != x.dtype or str(self.pe.device) != str( + x.device + ): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vecotr and `j` means the @@ -657,9 +661,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( - 3, dim=-1 - ) + q, k, v = nn.functional.linear( + query, in_proj_weight, in_proj_bias + ).chunk(3, dim=-1) elif torch.equal(key, value): # encoder-decoder attention @@ -728,25 +732,33 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError("The size of the 2D attn_mask is not correct.") + raise RuntimeError( + "The size of the 2D attn_mask is not correct." + ) elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError("The size of the 3D attn_mask is not correct.") + raise RuntimeError( + "The size of the 3D attn_mask is not correct." + ) else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format(attn_mask.dim()) + "attn_mask's dimension {} is not supported".format( + attn_mask.dim() + ) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + if ( + key_padding_mask is not None + and key_padding_mask.dtype == torch.uint8 + ): warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor" - " instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -783,7 +795,9 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) + matrix_ac = torch.matmul( + q_with_bias_u, k + ) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -791,9 +805,13 @@ class RelPositionMultiheadAttention(nn.Module): ) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd) - attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) + attn_output_weights = ( + matrix_ac + matrix_bd + ) # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, -1 + ) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -827,9 +845,13 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn_output.transpose(0, 1) + .contiguous() + .view(tgt_len, bsz, embed_dim) + ) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -852,7 +874,9 @@ class ConvolutionModule(nn.Module): """ - def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: + def __init__( + self, channels: int, kernel_size: int, bias: bool = True + ) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding diff --git a/egs/librispeech/ASR/pruned2_knowledge/decode.py b/egs/librispeech/ASR/pruned2_knowledge/decode.py index 65da19f27..b4a9af55a 100755 --- a/egs/librispeech/ASR/pruned2_knowledge/decode.py +++ b/egs/librispeech/ASR/pruned2_knowledge/decode.py @@ -76,7 +76,11 @@ from beam_search import ( ) from train import get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.utils import ( AttributeDict, setup_logger, @@ -94,19 +98,16 @@ def get_parser(): "--epoch", type=int, default=28, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -185,7 +186,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -243,7 +245,9 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] if params.decoding_method == "fast_beam_search": @@ -258,7 +262,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -302,7 +309,11 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -374,7 +385,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -406,7 +419,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/librispeech/ASR/pruned2_knowledge/decoder.py b/egs/librispeech/ASR/pruned2_knowledge/decoder.py index 0b9c886c7..b6d94aaf1 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/decoder.py +++ b/egs/librispeech/ASR/pruned2_knowledge/decoder.py @@ -90,7 +90,9 @@ class Decoder(nn.Module): if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) + embedding_out = F.pad( + embedding_out, pad=(self.context_size - 1, 0) + ) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/librispeech/ASR/pruned2_knowledge/decoder2.py b/egs/librispeech/ASR/pruned2_knowledge/decoder2.py index 2ca76a30c..db51fb1cd 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/decoder2.py +++ b/egs/librispeech/ASR/pruned2_knowledge/decoder2.py @@ -14,13 +14,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional - import torch import torch.nn as nn import torch.nn.functional as F -from subsampling import ScaledConv1d from torch import Tensor +from typing import Optional +from subsampling import ScaledConv1d class Decoder(nn.Module): @@ -91,7 +90,9 @@ class Decoder(nn.Module): if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) + embedding_out = F.pad( + embedding_out, pad=(self.context_size - 1, 0) + ) else: # During inference time, there is no need to do extra padding # as we only need one output @@ -101,6 +102,7 @@ class Decoder(nn.Module): return embedding_out + class ScaledEmbedding(nn.Module): r"""A simple lookup table that stores embeddings of a fixed dictionary and size. @@ -169,13 +171,8 @@ class ScaledEmbedding(nn.Module): [ 0.0000, 0.0000, 0.0000], [-0.1655, 0.9897, 0.0635]]]) """ - __constants__ = [ - "num_embeddings", - "embedding_dim", - "padding_idx", - "scale_grad_by_freq", - "sparse", - ] + __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', + 'scale_grad_by_freq', 'sparse'] num_embeddings: int embedding_dim: int @@ -184,41 +181,34 @@ class ScaledEmbedding(nn.Module): weight: Tensor sparse: bool - def __init__( - self, - num_embeddings: int, - embedding_dim: int, - padding_idx: Optional[int] = None, - scale_grad_by_freq: bool = False, - sparse: bool = False, - scale_speed: float = 5.0, - ) -> None: + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, + scale_grad_by_freq: bool = False, + sparse: bool = False, + scale_speed: float = 5.0) -> None: super(ScaledEmbedding, self).__init__() self.num_embeddings = num_embeddings self.embedding_dim = embedding_dim if padding_idx is not None: if padding_idx > 0: - assert ( - padding_idx < self.num_embeddings - ), "Padding_idx must be within num_embeddings" + assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings' elif padding_idx < 0: - assert ( - padding_idx >= -self.num_embeddings - ), "Padding_idx must be within num_embeddings" + assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings' padding_idx = self.num_embeddings + padding_idx self.padding_idx = padding_idx self.scale_grad_by_freq = scale_grad_by_freq self.scale_speed = scale_speed - self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() + self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() self.sparse = sparse self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim)) self.reset_parameters() + + def reset_parameters(self) -> None: nn.init.normal_(self.weight, std=0.05) - nn.init.constant_(self.scale, torch.tensor(1.0 / 0.05).log() / self.scale_speed) + nn.init.constant_(self.scale, torch.tensor(1.0/0.05).log() / self.scale_speed) if self.padding_idx is not None: with torch.no_grad(): @@ -227,38 +217,22 @@ class ScaledEmbedding(nn.Module): def forward(self, input: Tensor) -> Tensor: scale = (self.scale * self.scale_speed).exp() if input.numel() < self.num_embeddings: - return ( - F.embedding( - input, - self.weight, - self.padding_idx, - None, - 2.0, # None, 2.0 relate to normalization - self.scale_grad_by_freq, - self.sparse, - ) - * scale - ) + return F.embedding( + input, self.weight, self.padding_idx, + None, 2.0, # None, 2.0 relate to normalization + self.scale_grad_by_freq, self.sparse) * scale else: return F.embedding( - input, - self.weight * scale, - self.padding_idx, - None, - 2.0, # None, 2.0 relates to normalization - self.scale_grad_by_freq, - self.sparse, - ) + input, self.weight * scale, self.padding_idx, + None, 2.0, # None, 2.0 relates to normalization + self.scale_grad_by_freq, self.sparse) def extra_repr(self) -> str: - s = ( - "{num_embeddings}, {embedding_dim}, scale_speed={scale_speed}," - " scale={scale}" - ) + s = '{num_embeddings}, {embedding_dim}, scale_speed={scale_speed}, scale={scale}' if self.padding_idx is not None: - s += ", padding_idx={padding_idx}" + s += ', padding_idx={padding_idx}' if self.scale_grad_by_freq is not False: - s += ", scale_grad_by_freq={scale_grad_by_freq}" + s += ', scale_grad_by_freq={scale_grad_by_freq}' if self.sparse is not False: - s += ", sparse=True" + s += ', sparse=True' return s.format(**self.__dict__) diff --git a/egs/librispeech/ASR/pruned2_knowledge/export.py b/egs/librispeech/ASR/pruned2_knowledge/export.py index 1af05d9c8..96d1a30fb 100755 --- a/egs/librispeech/ASR/pruned2_knowledge/export.py +++ b/egs/librispeech/ASR/pruned2_knowledge/export.py @@ -64,20 +64,17 @@ def get_parser(): "--epoch", type=int, default=28, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -108,7 +105,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) return parser @@ -176,7 +174,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned2_knowledge/joiner.py b/egs/librispeech/ASR/pruned2_knowledge/joiner.py index 68c663b66..35f75ed2a 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/joiner.py +++ b/egs/librispeech/ASR/pruned2_knowledge/joiner.py @@ -56,7 +56,9 @@ class Joiner(nn.Module): assert encoder_out.shape[:-1] == decoder_out.shape[:-1] if project_input: - logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) + logit = self.encoder_proj(encoder_out) + self.decoder_proj( + decoder_out + ) else: logit = encoder_out + decoder_out diff --git a/egs/librispeech/ASR/pruned2_knowledge/model.py b/egs/librispeech/ASR/pruned2_knowledge/model.py index ca8c28af1..599bf2506 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/model.py +++ b/egs/librispeech/ASR/pruned2_knowledge/model.py @@ -63,7 +63,9 @@ class Transducer(nn.Module): self.decoder = decoder self.joiner = joiner - self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) + self.simple_am_proj = ScaledLinear( + encoder_dim, vocab_size, initial_speed=0.5 + ) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) def forward( @@ -134,7 +136,9 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) + boundary = torch.zeros( + (x.size(0), 4), dtype=torch.int64, device=x.device + ) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/pruned2_knowledge/optim.py b/egs/librispeech/ASR/pruned2_knowledge/optim.py index 76cd4e11e..432bf8220 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/optim.py +++ b/egs/librispeech/ASR/pruned2_knowledge/optim.py @@ -72,11 +72,17 @@ class Eve(Optimizer): if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if not 0.0 <= betas[0] < 1.0: - raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + raise ValueError( + "Invalid beta parameter at index 0: {}".format(betas[0]) + ) if not 0.0 <= betas[1] < 1.0: - raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + raise ValueError( + "Invalid beta parameter at index 1: {}".format(betas[1]) + ) if not 0 <= weight_decay <= 0.1: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + raise ValueError( + "Invalid weight_decay value: {}".format(weight_decay) + ) if not 0 < target_rms <= 10.0: raise ValueError("Invalid target_rms value: {}".format(target_rms)) defaults = dict( @@ -112,7 +118,9 @@ class Eve(Optimizer): # Perform optimization step grad = p.grad if grad.is_sparse: - raise RuntimeError("AdamW does not support sparse gradients") + raise RuntimeError( + "AdamW does not support sparse gradients" + ) state = self.state[p] @@ -139,7 +147,7 @@ class Eve(Optimizer): # Decay the first and second moment running average coefficient exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_( + denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_( group["eps"] ) @@ -150,7 +158,9 @@ class Eve(Optimizer): if p.numel() > 1: # avoid applying this weight-decay on "scaling factors" # (which are scalar). - is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5)) + is_above_target_rms = p.norm() > ( + target_rms * (p.numel() ** 0.5) + ) p.mul_(1 - (weight_decay * is_above_target_rms)) p.addcdiv_(exp_avg, denom, value=-step_size) @@ -166,14 +176,18 @@ class LRScheduler(object): def __init__(self, optimizer: Optimizer, verbose: bool = False): # Attach optimizer if not isinstance(optimizer, Optimizer): - raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) + raise TypeError( + "{} is not an Optimizer".format(type(optimizer).__name__) + ) self.optimizer = optimizer self.verbose = verbose for group in optimizer.param_groups: group.setdefault("initial_lr", group["lr"]) - self.base_lrs = [group["initial_lr"] for group in optimizer.param_groups] + self.base_lrs = [ + group["initial_lr"] for group in optimizer.param_groups + ] self.epoch = 0 self.batch = 0 @@ -281,9 +295,10 @@ class Eden(LRScheduler): def get_lr(self): factor = ( - (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 + (self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2 ) ** -0.25 * ( - ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25 + ((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2) + ** -0.25 ) return [x * factor for x in self.base_lrs] diff --git a/egs/librispeech/ASR/pruned2_knowledge/sampling.py b/egs/librispeech/ASR/pruned2_knowledge/sampling.py index 8cc930927..7b05e2f00 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/sampling.py +++ b/egs/librispeech/ASR/pruned2_knowledge/sampling.py @@ -3,29 +3,32 @@ # This was copied from /ceph-dan/torch-sampling/torch_sampling/sampling_ref.py, # its git history is there. -import random import timeit -from typing import Optional, Tuple - import torch +from torch import Tensor +from torch import nn +from torch.cuda.amp import GradScaler, custom_fwd, custom_bwd +from typing import Tuple, Optional from scaling import ScaledLinear -from torch import Tensor, nn -from torch.cuda.amp import GradScaler, custom_bwd, custom_fwd +import random from torch_scheduled_sampling import sample_combined # The main exports of this file are the module KnowledgeBaseLookup and the # function create_knowledge_base. + + + + def create_knowledge_base(M: int, N: int, D: int) -> nn.Parameter: std = 0.1 - a = (3**0.5) * std # this sqrt(3) thing is intended to get variance of - # 0.1 from uniform distribution - ans = nn.Parameter(torch.ones(M**N, D)) + a = (3 ** 0.5) * std # this sqrt(3) thing is intended to get variance of + # 0.1 from uniform distribution + ans = nn.Parameter(torch.ones(M ** N, D)) nn.init.uniform_(ans, -a, a) return ans - def join_indexes(indexes: Tensor, M: int) -> Tensor: """ Combines N-tuples of indexes into single indexes that can be used for @@ -44,9 +47,9 @@ def join_indexes(indexes: Tensor, M: int) -> Tensor: # Note, we don't use this, we -def weighted_matrix_lookup( - weights: Tensor, indexes: Tensor, knowledge_base: Tensor -) -> Tensor: +def weighted_matrix_lookup(weights: Tensor, + indexes: Tensor, + knowledge_base: Tensor) -> Tensor: """ Weighted combination of specified rows of a matrix. weights: Tensor of shape (*, K), can contain any value but probably in [0..1]. @@ -62,9 +65,9 @@ def weighted_matrix_lookup( # simpler but less memory-efficient implementation lookup = torch.index_select(knowledge_base, dim=0, index=indexes.flatten()) D = knowledge_base.shape[-1] - weights = weights.unsqueeze(-2) # (*, 1, K) - lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) - ans = torch.matmul(weights, lookup) # ans: (*, 1, D) + weights = weights.unsqueeze(-2) # (*, 1, K) + lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) + ans = torch.matmul(weights, lookup) # ans: (*, 1, D) ans = ans.squeeze(-2) assert list(ans.shape) == list(weights.shape[:-2]) + [D] return ans @@ -73,9 +76,7 @@ def weighted_matrix_lookup( class WeightedMatrixLookupFunction(torch.autograd.Function): @staticmethod @custom_fwd - def forward( - ctx, weights: Tensor, indexes: Tensor, knowledge_base: Tensor - ) -> Tensor: + def forward(ctx, weights: Tensor, indexes: Tensor, knowledge_base: Tensor) -> Tensor: """ Weighted combination of specified rows of a matrix. weights: Tensor of shape (*, K), can contain any value but probably in [0..1]. @@ -87,16 +88,15 @@ class WeightedMatrixLookupFunction(torch.autograd.Function): """ if random.random() < 0.001: print("dtype[1] = ", weights.dtype) - ctx.save_for_backward( - weights.detach(), indexes.detach(), knowledge_base.detach() - ) + ctx.save_for_backward(weights.detach(), indexes.detach(), + knowledge_base.detach()) with torch.no_grad(): lookup = torch.index_select(knowledge_base, dim=0, index=indexes.flatten()) D = knowledge_base.shape[-1] - weights = weights.unsqueeze(-2) # (*, 1, K) - lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) - ans = torch.matmul(weights, lookup) # ans: (*, 1, D) - ans = ans.squeeze(-2) # (*, D) + weights = weights.unsqueeze(-2) # (*, 1, K) + lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) + ans = torch.matmul(weights, lookup) # ans: (*, 1, D) + ans = ans.squeeze(-2) #(*, D) return ans @staticmethod @@ -107,7 +107,7 @@ class WeightedMatrixLookupFunction(torch.autograd.Function): knowledge_base.requires_grad = True dtype = ans_grad.dtype ans_grad = ans_grad.to(weights.dtype) - assert weights.requires_grad is False + assert weights.requires_grad == False D = knowledge_base.shape[-1] with torch.enable_grad(): # we'll use torch's autograd to differentiate this operation, which @@ -115,19 +115,16 @@ class WeightedMatrixLookupFunction(torch.autograd.Function): # We don't save `lookup` because it's large, that is the reason # we override Torch autograd. lookup = torch.index_select(knowledge_base, dim=0, index=indexes.flatten()) - lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) - weights = weights.unsqueeze(-1) # (*, K, 1) + lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) + weights = weights.unsqueeze(-1) # (*, K, 1) # forward pass: was: ## ans = torch.matmul(weights, lookup) ## ans: (*, 1, D) ## ans = ans.squeeze(-2) # ans, ans_grad: (*, D) - weights_grad = torch.matmul( - lookup, ans_grad.unsqueeze(-1) # (*, K, D) - ) # (*, D, 1) - weights_grad = weights_grad.squeeze(-1) # (*, K, 1) -> (*, K) - lookup_grad = weights * ans_grad.unsqueeze( - -2 - ) # (*, K, 1) * (*, 1, D) = (*, K, D) + weights_grad = torch.matmul(lookup, # (*, K, D) + ans_grad.unsqueeze(-1)) # (*, D, 1) + weights_grad = weights_grad.squeeze(-1) # (*, K, 1) -> (*, K) + lookup_grad = weights * ans_grad.unsqueeze(-2) # (*, K, 1) * (*, 1, D) = (*, K, D) lookup.backward(gradient=lookup_grad) return weights_grad.to(dtype), None, knowledge_base.grad.to(dtype) @@ -149,7 +146,6 @@ class PenalizeNegentropyFunction(torch.autograd.Function): Returns: logprobs """ - @staticmethod def forward(ctx, logprobs: Tensor, alpha: float): ctx.save_for_backward(logprobs.detach()) @@ -158,23 +154,18 @@ class PenalizeNegentropyFunction(torch.autograd.Function): @staticmethod def backward(ctx, logprobs_grad: Tensor) -> Tuple[Tensor, None]: - (logprobs,) = ctx.saved_tensors + logprobs, = ctx.saved_tensors with torch.enable_grad(): logprobs.requires_grad = True # `negentropy` is the negative entropy of the average distribution. # distributions. It will be <= 0. - l = logprobs.reshape(-1, logprobs.shape[-1]) # noqa: E741 + l = logprobs.reshape(-1, logprobs.shape[-1]) scale = ctx.alpha * l.shape[0] avg_dist = l.exp().mean(dim=0) negentropy = (avg_dist * (avg_dist + 1.0e-20).log()).sum() if random.random() < 0.0005: negentropy_individual = (l * l.exp()).sum(dim=-1).mean() - print( - "Negentropy[individual,combined] = ", - negentropy_individual.item(), - ", ", - negentropy.item(), - ) + print("Negentropy[individual,combined] = ", negentropy_individual.item(), ", ", negentropy.item()) loss = negentropy * scale loss.backward() return logprobs_grad + logprobs.grad, None @@ -192,23 +183,18 @@ class KnowledgeBaseLookup(nn.Module): embedding_dim: the dimension to project from and to, e.g. the d_model of the conformer. """ - - def __init__( - self, - M: int, - N: int, - D: int, - K: int, - embedding_dim: int, - knowledge_base: nn.Parameter, - negentropy_penalty: float = 0.001, - ): + def __init__(self, M: int, N: int, D: int, + K: int, embedding_dim: int, + knowledge_base: nn.Parameter, + negentropy_penalty: float = 0.001): super(KnowledgeBaseLookup, self).__init__() self.knowledge_base = knowledge_base # shared! - self.in_proj = ScaledLinear(embedding_dim, M * N, initial_scale=1.0) + self.in_proj = ScaledLinear(embedding_dim, M * N, + initial_scale=1.0) # initial_scale = 4.0 because the knowlege_base activations are # quite small -- if we use our optimizer they'll have stddev <= 0.1. - self.out_proj = ScaledLinear(D, embedding_dim, initial_scale=4.0) + self.out_proj = ScaledLinear(D, embedding_dim, + initial_scale = 4.0) self.M = M self.N = N self.K = K @@ -224,14 +210,14 @@ class KnowledgeBaseLookup(nn.Module): # TODO: later we can try multiplying by a projection of x or something like that. """ - x = self.in_proj(x) # now (*, M*N) - x = x.reshape(*x.shape[:-1], self.N, self.M) # now (*, N, M) - x = x.log_softmax(dim=-1) # now normalized logprobs, dim= (*, N, M) + x = self.in_proj(x) # now (*, M*N) + x = x.reshape(*x.shape[:-1], self.N, self.M) # now (*, N, M) + x = x.log_softmax(dim=-1) # now normalized logprobs, dim= (*, N, M) x = PenalizeNegentropyFunction.apply(x, self.negentropy_penalty) _, indexes, weights = sample_combined(x, self.K, input_is_log=True) - x = weighted_matrix_lookup(weights, indexes, self.knowledge_base) # now (*, D) - x = self.out_proj(x) # now (*, self.embedding_dim) + x = weighted_matrix_lookup(weights, indexes, self.knowledge_base) # now (*, D) + x = self.out_proj(x) # now (*, self.embedding_dim) return x @@ -251,44 +237,38 @@ def _test_knowledge_base_lookup(): x.requires_grad = True y = m(x) assert y.shape == x.shape - y.sum().backward() # make sure backward doesn't crash.. + y.sum().backward() # make sure backward doesn't crash.. print("y = ", y) print("x.grad = ", x.grad) print("knowlege_base.grad norm = ", knowledge_base.grad.norm()) dtype = torch.float32 - device = torch.device("cuda") - train_pairs = [ - ( - torch.randn(B, T, E, device=device, dtype=dtype), - torch.randn(B, T, E, device=device, dtype=dtype), - ) - for _ in range(10) - ] + device = torch.device('cuda') + train_pairs = [ (torch.randn(B, T, E, device=device, dtype=dtype), torch.randn(B, T, E, device=device, dtype=dtype)) for _ in range(10) ] from optim import Eve - optimizer = Eve(m.parameters(), lr=0.005, eps=1.0e-04) m = m.to(device).to(dtype) + start = timeit.default_timer() - # Epoch 0, batch 0, loss 1.0109944343566895 - # Epoch 10, batch 0, loss 1.0146660804748535 - # Epoch 20, batch 0, loss 1.0119813680648804 - # Epoch 30, batch 0, loss 1.0105408430099487 - # Epoch 40, batch 0, loss 1.0077732801437378 - # Epoch 50, batch 0, loss 1.0050103664398193 - # Epoch 60, batch 0, loss 1.0033129453659058 - # Epoch 70, batch 0, loss 1.0014232397079468 - # Epoch 80, batch 0, loss 0.9977912306785583 - # Epoch 90, batch 0, loss 0.8274348974227905 - # Epoch 100, batch 0, loss 0.3368612825870514 - # Epoch 110, batch 0, loss 0.11323091387748718 - # Time taken: 17.591704960912466 +# Epoch 0, batch 0, loss 1.0109944343566895 +# Epoch 10, batch 0, loss 1.0146660804748535 +# Epoch 20, batch 0, loss 1.0119813680648804 +# Epoch 30, batch 0, loss 1.0105408430099487 +# Epoch 40, batch 0, loss 1.0077732801437378 +# Epoch 50, batch 0, loss 1.0050103664398193 +# Epoch 60, batch 0, loss 1.0033129453659058 +# Epoch 70, batch 0, loss 1.0014232397079468 +# Epoch 80, batch 0, loss 0.9977912306785583 +# Epoch 90, batch 0, loss 0.8274348974227905 +# Epoch 100, batch 0, loss 0.3368612825870514 +# Epoch 110, batch 0, loss 0.11323091387748718 +# Time taken: 17.591704960912466 for epoch in range(150): - for n, (x, y) in enumerate(train_pairs): + for n, (x,y) in enumerate(train_pairs): y_out = m(x) - loss = ((y_out - y) ** 2).mean() * 100.0 + loss = ((y_out - y)**2).mean() * 100.0 if n % 10 == 0 and epoch % 10 == 0: print(f"Epoch {epoch}, batch {n}, loss {loss.item()}") loss.backward() @@ -296,8 +276,7 @@ def _test_knowledge_base_lookup(): optimizer.zero_grad() stop = timeit.default_timer() - print("Time taken: ", stop - start) - + print('Time taken: ', stop - start) def _test_knowledge_base_lookup_autocast(): K = 16 @@ -315,21 +294,14 @@ def _test_knowledge_base_lookup_autocast(): x.requires_grad = True y = m(x) assert y.shape == x.shape - y.sum().backward() # make sure backward doesn't crash.. + y.sum().backward() # make sure backward doesn't crash.. print("y = ", y) print("x.grad = ", x.grad) print("knowlege_base.grad norm = ", knowledge_base.grad.norm()) - device = torch.device("cuda") - train_pairs = [ - ( - torch.randn(B, T, E, device=device), - torch.randn(B, T, E, device=device), - ) - for _ in range(10) - ] + device = torch.device('cuda') + train_pairs = [ (torch.randn(B, T, E, device=device), torch.randn(B, T, E, device=device)) for _ in range(10) ] from optim import Eve - optimizer = Eve(m.parameters(), lr=0.005, eps=1.0e-04) m = m.to(device) @@ -337,11 +309,12 @@ def _test_knowledge_base_lookup_autocast(): start = timeit.default_timer() + for epoch in range(150): - for n, (x, y) in enumerate(train_pairs): + for n, (x,y) in enumerate(train_pairs): y_out = m(x) with torch.cuda.amp.autocast(enabled=True): - loss = ((y_out - y) ** 2).mean() * 100.0 + loss = ((y_out - y)**2).mean() * 100.0 if n % 10 == 0 and epoch % 10 == 0: print(f"Epoch {epoch}, batch {n}, loss {loss.item()}") scaler.scale(loss).backward() @@ -350,9 +323,10 @@ def _test_knowledge_base_lookup_autocast(): optimizer.zero_grad() stop = timeit.default_timer() - print("Time taken: ", stop - start) + print('Time taken: ', stop - start) -if __name__ == "__main__": + +if __name__ == '__main__': _test_knowledge_base_lookup() _test_knowledge_base_lookup_autocast() diff --git a/egs/librispeech/ASR/pruned2_knowledge/scaling.py b/egs/librispeech/ASR/pruned2_knowledge/scaling.py index 527c735eb..f726c2583 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/scaling.py +++ b/egs/librispeech/ASR/pruned2_knowledge/scaling.py @@ -18,11 +18,11 @@ import collections from itertools import repeat from typing import Optional, Tuple +from torch.cuda.amp import custom_fwd, custom_bwd import torch import torch.nn as nn from torch import Tensor -from torch.cuda.amp import custom_bwd, custom_fwd def _ntuple(n): @@ -79,7 +79,9 @@ class ActivationBalancerFunction(torch.autograd.Function): below_threshold = mean_abs < min_abs above_threshold = mean_abs > max_abs - ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold) + ctx.save_for_backward( + factor, xgt0, below_threshold, above_threshold + ) ctx.max_factor = max_factor ctx.sum_dims = sum_dims return x @@ -147,7 +149,8 @@ class BasicNorm(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: assert x.shape[self.channel_dim] == self.num_channels scales = ( - torch.mean(x**2, dim=self.channel_dim, keepdim=True) + self.eps.exp() + torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + + self.eps.exp() ) ** -0.5 return x * scales @@ -179,7 +182,11 @@ class ScaledLinear(nn.Linear): """ def __init__( - self, *args, initial_scale: float = 1.0, initial_speed: float = 1.0, **kwargs + self, + *args, + initial_scale: float = 1.0, + initial_speed: float = 1.0, + **kwargs ): super(ScaledLinear, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() @@ -195,12 +202,12 @@ class ScaledLinear(nn.Linear): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3**0.5) * std + a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in**-0.5 # 1/sqrt(fan_in) + scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -211,13 +218,19 @@ class ScaledLinear(nn.Linear): return None if self.bias is None else self.bias * self.bias_scale.exp() def forward(self, input: Tensor) -> Tensor: - return torch.nn.functional.linear(input, self.get_weight(), self.get_bias()) + return torch.nn.functional.linear( + input, self.get_weight(), self.get_bias() + ) class ScaledConv1d(nn.Conv1d): # See docs for ScaledLinear def __init__( - self, *args, initial_scale: float = 1.0, initial_speed: float = 1.0, **kwargs + self, + *args, + initial_scale: float = 1.0, + initial_speed: float = 1.0, + **kwargs ): super(ScaledConv1d, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() @@ -232,12 +245,12 @@ class ScaledConv1d(nn.Conv1d): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3**0.5) * std + a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in**-0.5 # 1/sqrt(fan_in) + scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -277,7 +290,11 @@ class ScaledConv1d(nn.Conv1d): class ScaledConv2d(nn.Conv2d): # See docs for ScaledLinear def __init__( - self, *args, initial_scale: float = 1.0, initial_speed: float = 1.0, **kwargs + self, + *args, + initial_scale: float = 1.0, + initial_speed: float = 1.0, + **kwargs ): super(ScaledConv2d, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() @@ -292,12 +309,12 @@ class ScaledConv2d(nn.Conv2d): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3**0.5) * std + a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in**-0.5 # 1/sqrt(fan_in) + scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -636,7 +653,9 @@ def _test_activation_balancer_sign(): def _test_activation_balancer_magnitude(): magnitudes = torch.arange(0, 1, 0.01) N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze( + -1 + ) x = x.detach() x.requires_grad = True m = ActivationBalancer( @@ -666,8 +685,8 @@ def _test_basic_norm(): y = m(x) assert y.shape == x.shape - x_rms = (x**2).mean().sqrt() - y_rms = (y**2).mean().sqrt() + x_rms = (x ** 2).mean().sqrt() + y_rms = (y ** 2).mean().sqrt() print("x rms = ", x_rms) print("y rms = ", y_rms) assert y_rms < x_rms diff --git a/egs/librispeech/ASR/pruned2_knowledge/scaling_tmp.py b/egs/librispeech/ASR/pruned2_knowledge/scaling_tmp.py index 3f21133a0..6293e081a 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/scaling_tmp.py +++ b/egs/librispeech/ASR/pruned2_knowledge/scaling_tmp.py @@ -15,23 +15,21 @@ # limitations under the License. -from typing import Optional, Tuple - import torch import torch.nn as nn from torch import Tensor +from typing import Tuple, Optional -def _activation_balancer_loss( - mean_pos: Tensor, - mean_neg: Tensor, - min_positive: float, # e.g. 0.05 - max_positive: float, # e.g. 0.95 - max_factor: float, # e.g. 0.01 - min_abs: float, # e.g. 0.2 - max_abs: float, # e.g. 100.0 - eps: float = 1.0e-10, -): + +def _activation_balancer_loss(mean_pos: Tensor, + mean_neg: Tensor, + min_positive: float, # e.g. 0.05 + max_positive: float, # e.g. 0.95 + max_factor: float, # e.g. 0.01 + min_abs: float, # e.g. 0.2 + max_abs: float, # e.g. 100.0 + eps: float = 1.0e-10): """ Returns a loss-function for the ActivationBalancer module. This loss function is not exposed to the user but is used internally, and eventually @@ -52,32 +50,28 @@ def _activation_balancer_loss( """ loss_parts = [] - x_mean = mean_pos - mean_neg - x_mean_abs = (mean_pos + mean_neg + eps).detach() - x_rel_mean = x_mean / x_mean_abs + x_mean = mean_positive - mean_negative + x_mean_abs = (mean_positive + mean_negative + eps).detach() + x_rel_mean= x_mean / x_mean_abs if min_positive != 0.0: # e.g. x_mean_floor = -0.95 + 0.05 = -0.9 - x_rel_mean_floor = -(1 - min_positive) + min_positive - min_positive_loss = (x_rel_mean_floor - x_rel_mean).relu().sum() * ( - 1.0 / (2 * min_positive) - ) + x_rel_mean_floor = (-(1-min_positive) + min_positive) + min_positive_loss = (x_rel_mean_floor - x_rel_mean).relu().sum() * (1.0 / (2*min_positive)) # this part of the loss would be 1.0 * num_channels if all these constraints were # 100% violated. loss_parts.append(min_positive_loss) if max_positive != 1.0: # e.g. x_mean_floor = -0.05 + 0.95 = 0.8 - x_rel_mean_ceil = -(1.0 - max_positive) + max_positive - max_positive_loss = (x_rel_mean - x_rel_mean_ceil).relu().sum() * ( - 1.0 / (1 - x_rel_mean_ceil) - ) + x_rel_mean_ceil = - (1.0-max_positive) + max_positive + max_positive_loss = (x_rel_mean - x_rel_mean_ceil).relu().sum() * (1.0 / (1 - x_rel_mean_ceil)) # this part of the loss would be 1.0 * num_channels if all these constraints were # 100% violated. loss_parts.append(max_positive_loss) if min_abs != 0.0: - min_abs_loss = (min_abs - x_mean_abs).relu().sum() / min_abs + min_abs_loss = min_abs - x_mean_abs).relu().sum() / min_abs # this part of the loss would be 1.0 * num_channels if all these constraints were # 100% violated. loss_parts.append(min_abs_loss) @@ -88,53 +82,43 @@ def _activation_balancer_loss( # 100% violated. loss_parts.append(max_abs_loss) + # the min_positive and 1 - max_positive are "ballast" added to the denom = mean_pos + mean_neg + (min_positive + (1 - max_positive)) - # num + num if min_positive != 0.0: - pass + + class ActivationBalancerFunction(torch.autograd.Function): @staticmethod - def forward( - ctx, - x: Tensor, - channel_dim: int, - min_positive: float, # e.g. 0.05 - max_positive: float, # e.g. 0.95 - max_factor: float, # e.g. 0.01 - min_abs: float, # e.g. 0.2 - max_abs: float, # e.g. 100.0 + def forward(ctx, x: Tensor, + channel_dim: int, + min_positive: float, # e.g. 0.05 + max_positive: float, # e.g. 0.95 + max_factor: float, # e.g. 0.01 + min_abs: float, # e.g. 0.2 + max_abs: float, # e.g. 100.0 ) -> Tensor: if x.requires_grad: if channel_dim < 0: channel_dim += x.ndim sum_dims = [d for d in range(x.ndim) if d != channel_dim] xgt0 = x > 0 - proportion_positive = torch.mean( - xgt0.to(x.dtype), dim=sum_dims, keepdim=True - ) - factor1 = ( - (min_positive - proportion_positive).relu() - * (max_factor / min_positive) - if min_positive != 0.0 - else 0.0 - ) - factor2 = ( - (proportion_positive - max_positive).relu() - * (max_factor / (max_positive - 1.0)) - if max_positive != 1.0 - else 0.0 - ) + proportion_positive = torch.mean(xgt0.to(x.dtype), dim=sum_dims, keepdim=True) + factor1 = ((min_positive - proportion_positive).relu() * (max_factor / min_positive) + if min_positive != 0.0 else 0.0) + factor2 = ((proportion_positive - max_positive).relu() * (max_factor / (max_positive - 1.0)) + if max_positive != 1.0 else 0.0) factor = factor1 + factor2 if isinstance(factor, float): factor = torch.zeros_like(proportion_positive) mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True) - below_threshold = mean_abs < min_abs - above_threshold = mean_abs > max_abs + below_threshold = (mean_abs < min_abs) + above_threshold = (mean_abs > max_abs) ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold) ctx.max_factor = max_factor @@ -142,16 +126,11 @@ class ActivationBalancerFunction(torch.autograd.Function): return x @staticmethod - def backward( - ctx, x_grad: Tensor - ) -> Tuple[Tensor, None, None, None, None, None, None]: + def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None, None]: factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors dtype = x_grad.dtype - scale_factor = ( - (below_threshold.to(dtype) - above_threshold.to(dtype)) - * (xgt0.to(dtype) - 0.5) - * (ctx.max_factor * 2.0) - ) + scale_factor = ((below_threshold.to(dtype) - above_threshold.to(dtype)) * + (xgt0.to(dtype) - 0.5) * (ctx.max_factor * 2.0)) neg_delta_grad = x_grad.abs() * (factor + scale_factor) return x_grad - neg_delta_grad, None, None, None, None, None, None @@ -184,30 +163,29 @@ class BasicNorm(torch.nn.Module): learn_eps: if true, we learn epsilon; if false, we keep it at the initial value. """ - - def __init__( - self, - num_channels: int, - channel_dim: int = -1, # CAUTION: see documentation. - eps: float = 0.25, - learn_eps: bool = True, - ) -> None: + def __init__(self, + num_channels: int, + channel_dim: int = -1, # CAUTION: see documentation. + eps: float = 0.25, + learn_eps: bool = True) -> None: super(BasicNorm, self).__init__() self.num_channels = num_channels self.channel_dim = channel_dim if learn_eps: self.eps = nn.Parameter(torch.tensor(eps).log().detach()) else: - self.register_buffer("eps", torch.tensor(eps).log().detach()) + self.register_buffer('eps', torch.tensor(eps).log().detach()) + def forward(self, x: Tensor) -> Tensor: assert x.shape[self.channel_dim] == self.num_channels - scales = ( - torch.mean(x**2, dim=self.channel_dim, keepdim=True) + self.eps.exp() - ) ** -0.5 + scales = (torch.mean(x**2, dim=self.channel_dim, keepdim=True) + + self.eps.exp()) ** -0.5 return x * scales + + class ScaledLinear(nn.Linear): """ A modified version of nn.Linear where the parameters are scaled before @@ -229,26 +207,27 @@ class ScaledLinear(nn.Linear): inherited from nn.Linear. For modules with small fan-in, this may be larger than optimal. """ - - def __init__(self, *args, initial_scale: float = 1.0, **kwargs): + def __init__(self, *args, + initial_scale: float = 1.0, + **kwargs): super(ScaledLinear, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() self.weight_scale = nn.Parameter(initial_scale.clone().detach()) if self.bias is not None: self.bias_scale = nn.Parameter(initial_scale.clone().detach()) else: - self.register_parameter("bias_scale", None) + self.register_parameter('bias_scale', None) self._reset_parameters() # Overrides the reset_parameters in nn.Linear def _reset_parameters(self): std = 0.01 - a = (3**0.5) * std + a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in**-0.5 # 1/sqrt(fan_in) + scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() if self.bias is not None: @@ -258,67 +237,56 @@ class ScaledLinear(nn.Linear): return self.weight * self.weight_scale.exp() def get_bias(self): - return None if self.bias is None else self.bias * self.bias_scale.exp() + return (None if self.bias is None else + self.bias * self.bias_scale.exp()) def forward(self, input: Tensor) -> Tensor: - return torch.nn.functional.linear(input, self.get_weight(), self.get_bias()) + return torch.nn.functional.linear(input, self.get_weight(), + self.get_bias()) class ScaledConv1d(nn.Conv1d): - def __init__(self, *args, initial_scale=1.0, **kwargs): + def __init__(self, *args, + initial_scale=1.0, **kwargs): super(ScaledConv1d, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() self.weight_scale = nn.Parameter(initial_scale.clone().detach()) if self.bias is not None: self.bias_scale = nn.Parameter(initial_scale.clone().detach()) else: - self.register_parameter("bias_scale", None) + self.register_parameter('bias_scale', None) self._reset_parameters() # Overrides the reset_parameters in base class def _reset_parameters(self): std = 0.01 - a = (3**0.5) * std + a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in**-0.5 # 1/sqrt(fan_in) + scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() if self.bias is not None: self.bias_scale += torch.tensor(scale / std).log() + def get_weight(self): return self.weight * self.weight_scale.exp() def get_bias(self): - return None if self.bias is None else self.bias * self.bias_scale.exp() + return (None if self.bias is None else + self.bias * self.bias_scale.exp()) def forward(self, input: Tensor) -> Tensor: F = torch.nn.functional - if self.padding_mode != "zeros": - return F.conv1d( - F.pad( - input, - self._reversed_padding_repeated_twice, - mode=self.padding_mode, - ), - self.get_weight(), - self.get_bias(), - self.stride, - _single(0), # noqa: F821 - self.dilation, - self.groups, - ) - return F.conv1d( - input, - self.get_weight(), - self.get_bias(), - self.stride, - self.padding, - self.dilation, - self.groups, - ) + if self.padding_mode != 'zeros': + return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), + self.get_weight(), self.get_bias(), self.stride, + _single(0), self.dilation, self.groups) + return F.conv1d(input, self.get_weight(), self.get_bias(), self.stride, + self.padding, self.dilation, self.groups) + class ScaledConv2d(nn.Conv2d): @@ -329,58 +297,45 @@ class ScaledConv2d(nn.Conv2d): if self.bias is not None: self.bias_scale = nn.Parameter(initial_scale.clone().detach()) else: - self.register_parameter("bias_scale", None) + self.register_parameter('bias_scale', None) self._reset_parameters() # Overrides the reset_parameters in base class def _reset_parameters(self): std = 0.01 - a = (3**0.5) * std + a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in**-0.5 # 1/sqrt(fan_in) + scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() if self.bias is not None: self.bias_scale += torch.tensor(scale / std).log() + def get_weight(self): return self.weight * self.weight_scale.exp() def get_bias(self): - return None if self.bias is None else self.bias * self.bias_scale.exp() + return (None if self.bias is None else + self.bias * self.bias_scale.exp()) def _conv_forward(self, input, weight): F = torch.nn.functional - if self.padding_mode != "zeros": - return F.conv2d( - F.pad( - input, - self._reversed_padding_repeated_twice, - mode=self.padding_mode, - ), - weight, - self.get_bias(), - self.stride, - _pair(0), # noqa: F821 - self.dilation, - self.groups, - ) - return F.conv2d( - input, - weight, - self.get_bias(), - self.stride, - self.padding, - self.dilation, - self.groups, - ) + if self.padding_mode != 'zeros': + return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), + weight, self.get_bias(), self.stride, + _pair(0), self.dilation, self.groups) + return F.conv2d(input, weight, self.get_bias(), self.stride, + self.padding, self.dilation, self.groups) def forward(self, input: Tensor) -> Tensor: return self._conv_forward(input, self.get_weight()) + + class ActivationBalancer(torch.nn.Module): """ Modifies the backpropped derivatives of a function to try to encourage, for @@ -409,16 +364,12 @@ class ActivationBalancer(torch.nn.Module): we allow, before we start to modify the derivatives to prevent this. """ - - def __init__( - self, - channel_dim: int, - min_positive: float = 0.05, - max_positive: float = 0.95, - max_factor: float = 0.01, - min_abs: float = 0.2, - max_abs: float = 100.0, - ): + def __init__(self, channel_dim: int, + min_positive: float = 0.05, + max_positive: float = 0.95, + max_factor: float = 0.01, + min_abs: float = 0.2, + max_abs: float = 100.0): super(ActivationBalancer, self).__init__() self.channel_dim = channel_dim self.min_positive = min_positive @@ -428,15 +379,10 @@ class ActivationBalancer(torch.nn.Module): self.max_abs = max_abs def forward(self, x: Tensor) -> Tensor: - return ActivationBalancerFunction.apply( - x, - self.channel_dim, - self.min_positive, - self.max_positive, - self.max_factor, - self.min_abs, - self.max_abs, - ) + return ActivationBalancerFunction.apply(x, self.channel_dim, + self.min_positive, self.max_positive, + self.max_factor, self.min_abs, + self.max_abs) class DoubleSwishFunction(torch.autograd.Function): @@ -454,7 +400,6 @@ class DoubleSwishFunction(torch.autograd.Function): = double_swish(x) * (1-s(x)) + s(x) ... so we just need to remember s(x) but not x itself. """ - @staticmethod def forward(ctx, x: Tensor) -> Tensor: x = x.detach() @@ -466,17 +411,18 @@ class DoubleSwishFunction(torch.autograd.Function): @staticmethod def backward(ctx, y_grad: Tensor) -> Tensor: s, y = ctx.saved_tensors - return (y * (1 - s) + s) * y_grad - + return (y * (1-s) + s) * y_grad class DoubleSwish(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: """Return double-swish activation function which is an approximation to Swish(Swish(x)), - that we approximate closely with x * sigmoid(x-1). + that we approximate closely with x * sigmoid(x-1). """ return DoubleSwishFunction.apply(x) + + class ScaledEmbedding(nn.Module): r"""A simple lookup table that stores embeddings of a fixed dictionary and size. @@ -545,13 +491,8 @@ class ScaledEmbedding(nn.Module): [ 0.0000, 0.0000, 0.0000], [-0.1655, 0.9897, 0.0635]]]) """ - __constants__ = [ - "num_embeddings", - "embedding_dim", - "padding_idx", - "scale_grad_by_freq", - "sparse", - ] + __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', + 'scale_grad_by_freq', 'sparse'] num_embeddings: int embedding_dim: int @@ -560,40 +501,33 @@ class ScaledEmbedding(nn.Module): weight: Tensor sparse: bool - def __init__( - self, - num_embeddings: int, - embedding_dim: int, - padding_idx: Optional[int] = None, - scale_grad_by_freq: bool = False, - sparse: bool = False, - ) -> None: + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, + scale_grad_by_freq: bool = False, + sparse: bool = False) -> None: super(ScaledEmbedding, self).__init__() self.num_embeddings = num_embeddings self.embedding_dim = embedding_dim if padding_idx is not None: if padding_idx > 0: - assert ( - padding_idx < self.num_embeddings - ), "Padding_idx must be within num_embeddings" + assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings' elif padding_idx < 0: - assert ( - padding_idx >= -self.num_embeddings - ), "Padding_idx must be within num_embeddings" + assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings' padding_idx = self.num_embeddings + padding_idx self.padding_idx = padding_idx self.scale_grad_by_freq = scale_grad_by_freq - self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() + self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() self.sparse = sparse self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim)) self.reset_parameters() + + def reset_parameters(self) -> None: std = 0.01 nn.init.normal_(self.weight, std=std) - nn.init.constant_(self.scale, torch.tensor(1.0 / std).log()) + nn.init.constant_(self.scale, torch.tensor(1.0/std).log()) if self.padding_idx is not None: with torch.no_grad(): @@ -603,37 +537,24 @@ class ScaledEmbedding(nn.Module): F = torch.nn.functional scale = self.scale.exp() if input.numel() < self.num_embeddings: - return ( - F.embedding( - input, - self.weight, - self.padding_idx, - None, - 2.0, # None, 2.0 relate to normalization - self.scale_grad_by_freq, - self.sparse, - ) - * scale - ) + return F.embedding( + input, self.weight, self.padding_idx, + None, 2.0, # None, 2.0 relate to normalization + self.scale_grad_by_freq, self.sparse) * scale else: return F.embedding( - input, - self.weight * scale, - self.padding_idx, - None, - 2.0, # None, 2.0 relates to normalization - self.scale_grad_by_freq, - self.sparse, - ) + input, self.weight * scale, self.padding_idx, + None, 2.0, # None, 2.0 relates to normalization + self.scale_grad_by_freq, self.sparse) def extra_repr(self) -> str: - s = "{num_embeddings}, {embedding_dim}, scale={scale}" + s = '{num_embeddings}, {embedding_dim}, scale={scale}' if self.padding_idx is not None: - s += ", padding_idx={padding_idx}" + s += ', padding_idx={padding_idx}' if self.scale_grad_by_freq is not False: - s += ", scale_grad_by_freq={scale_grad_by_freq}" + s += ', scale_grad_by_freq={scale_grad_by_freq}' if self.sparse is not False: - s += ", sparse=True" + s += ', sparse=True' return s.format(**self.__dict__) @@ -644,13 +565,8 @@ def _test_activation_balancer_sign(): x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1)) x = x.detach() x.requires_grad = True - m = ActivationBalancer( - channel_dim=0, - min_positive=0.05, - max_positive=0.95, - max_factor=0.2, - min_abs=0.0, - ) + m = ActivationBalancer(channel_dim=0, min_positive=0.05, max_positive=0.95, + max_factor=0.2, min_abs=0.0) y_grad = torch.sign(torch.randn(probs.numel(), N)) @@ -660,22 +576,17 @@ def _test_activation_balancer_sign(): print("_test_activation_balancer_sign: y grad = ", y_grad) print("_test_activation_balancer_sign: x grad = ", x.grad) - def _test_activation_balancer_magnitude(): channel_dim = 0 magnitudes = torch.arange(0, 1, 0.01) N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) x = x.detach() x.requires_grad = True - m = ActivationBalancer( - channel_dim=0, - min_positive=0.0, - max_positive=1.0, - max_factor=0.2, - min_abs=0.2, - max_abs=0.8, - ) + m = ActivationBalancer(channel_dim=0, + min_positive=0.0, max_positive=1.0, + max_factor=0.2, + min_abs=0.2, max_abs=0.8) y_grad = torch.sign(torch.randn(magnitudes.numel(), N)) @@ -710,7 +621,7 @@ def _test_double_swish_deriv(): torch.autograd.gradcheck(m, x) -if __name__ == "__main__": +if __name__ == '__main__': _test_activation_balancer_sign() _test_activation_balancer_magnitude() _test_basic_norm() diff --git a/egs/librispeech/ASR/pruned2_knowledge/train.py b/egs/librispeech/ASR/pruned2_knowledge/train.py index a60d15c3b..2f6840166 100755 --- a/egs/librispeech/ASR/pruned2_knowledge/train.py +++ b/egs/librispeech/ASR/pruned2_knowledge/train.py @@ -78,7 +78,9 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def get_parser(): @@ -177,45 +179,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." - ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -555,16 +554,23 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -727,7 +733,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -827,7 +835,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 2 ** 22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py index 1df1650f3..2d5724d30 100755 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py @@ -123,24 +123,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=False, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -208,7 +204,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -275,7 +272,9 @@ def decode_one_batch( value=LOG_EPS, ) - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] if params.decoding_method == "fast_beam_search": @@ -290,7 +289,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -336,7 +338,11 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -409,7 +415,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -442,7 +450,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -485,7 +494,9 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -517,12 +528,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -545,12 +557,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -578,7 +591,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/emformer.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/emformer.py index 008f40fb1..318cd5094 100644 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/emformer.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/emformer.py @@ -272,9 +272,13 @@ class Emformer(EncoderInterface): # Caution: We assume the subsampling factor is 4! x_lens = (((x_lens - 1) >> 1) - 1) >> 1 - emformer_out, emformer_out_lens, states = self.model.infer(x, x_lens, states) + emformer_out, emformer_out_lens, states = self.model.infer( + x, x_lens, states + ) - if x.size(1) != (self.model.segment_length + self.model.right_context_length): + if x.size(1) != ( + self.model.segment_length + self.model.right_context_length + ): raise ValueError( "Incorrect input shape." f"{x.size(1)} vs {self.model.segment_length} + " diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py index 81afb523d..2375f5001 100755 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py @@ -89,24 +89,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=False, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -137,7 +133,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) add_model_arguments(parser) @@ -173,12 +170,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -201,12 +199,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -234,7 +233,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -274,7 +273,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/model.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/model.py index ed6848879..2f019bcdb 100644 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/model.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/model.py @@ -122,7 +122,9 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) + boundary = torch.zeros( + (x.size(0), 4), dtype=torch.int64, device=x.device + ) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py index 6b30d3be8..fed814f19 100755 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py @@ -209,45 +209,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." - ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -569,7 +566,11 @@ def compute_loss( function enables autograd during computation; when it is False, it disables autograd. """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -598,7 +599,9 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -779,7 +782,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -903,7 +908,8 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " + f"Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py index 830b37cfb..7af9cc3d7 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py @@ -509,9 +509,9 @@ def greedy_search( y = logits.argmax().item() if y not in (blank_id, unk_id): hyp.append(y) - decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape( - 1, context_size - ) + decoder_input = torch.tensor( + [hyp[-context_size:]], device=device + ).reshape(1, context_size) decoder_out = model.decoder(decoder_input, need_pad=False) @@ -670,7 +670,9 @@ class HypothesisList(object): if use_max: old_hyp.log_prob = max(old_hyp.log_prob, hyp.log_prob) else: - torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob) + torch.logaddexp( + old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob + ) else: self._data[key] = hyp @@ -686,7 +688,9 @@ class HypothesisList(object): Return the hypothesis that has the largest `log_prob`. """ if length_norm: - return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)) + return max( + self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys) + ) else: return max(self._data.values(), key=lambda hyp: hyp.log_prob) @@ -888,7 +892,9 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -1082,7 +1088,9 @@ def beam_search( t = 0 B = HypothesisList() - B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0), use_max=use_max) + B.add( + Hypothesis(ys=[blank_id] * context_size, log_prob=0.0), use_max=use_max + ) max_sym_per_utt = 20000 @@ -1122,7 +1130,9 @@ def beam_search( cached_key += f"-t-{t}" if cached_key not in joint_cache: - logits = model.joiner(current_encoder_out, decoder_out.unsqueeze(1)) + logits = model.joiner( + current_encoder_out, decoder_out.unsqueeze(1) + ) # TODO(fangjun): Scale the blank posterior diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py index 03ad45f49..7b6338948 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py @@ -128,7 +128,11 @@ from beam_search import ( ) from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, @@ -167,11 +171,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -267,7 +269,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -380,7 +383,9 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] if ( @@ -445,7 +450,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -576,7 +584,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -609,7 +619,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -667,7 +678,9 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -705,7 +718,8 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -743,7 +757,9 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py index e522943c0..386248554 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py @@ -75,7 +75,9 @@ class DecodeStream(object): # encoder.streaming_forward self.done_frames: int = 0 - self.pad_length = (params.right_context + 2) * params.subsampling_factor + 3 + self.pad_length = ( + params.right_context + 2 + ) * params.subsampling_factor + 3 if params.decoding_method == "greedy_search": self.hyp = [params.blank_id] * params.context_size @@ -89,11 +91,13 @@ class DecodeStream(object): ) elif params.decoding_method == "fast_beam_search": # The rnnt_decoding_stream for fast_beam_search. - self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream( - decoding_graph + self.rnnt_decoding_stream: k2.RnntDecodingStream = ( + k2.RnntDecodingStream(decoding_graph) ) else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) @property def done(self) -> bool: @@ -122,10 +126,13 @@ class DecodeStream(object): """Consume chunk_size frames of features""" chunk_length = chunk_size + self.pad_length - ret_length = min(self.num_frames - self.num_processed_frames, chunk_length) + ret_length = min( + self.num_frames - self.num_processed_frames, chunk_length + ) ret_features = self.features[ - self.num_processed_frames : self.num_processed_frames + ret_length # noqa + self.num_processed_frames : self.num_processed_frames # noqa + + ret_length ] self.num_processed_frames += chunk_size diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py index 72593173c..f4355e8a0 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py @@ -92,7 +92,9 @@ class Decoder(nn.Module): if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) + embedding_out = F.pad( + embedding_out, pad=(self.context_size - 1, 0) + ) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/export.py b/egs/librispeech/ASR/pruned_transducer_stateless/export.py index 64708e524..b5a151878 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/export.py @@ -64,20 +64,17 @@ def get_parser(): "--epoch", type=int, default=28, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -108,7 +105,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -194,7 +192,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/model.py b/egs/librispeech/ASR/pruned_transducer_stateless/model.py index 2cca7fa27..73b651b3f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/model.py @@ -130,7 +130,9 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) + boundary = torch.zeros( + (x.size(0), 4), dtype=torch.int64, device=x.device + ) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py index a42b63b9c..eb95827af 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py @@ -91,11 +91,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -120,12 +118,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -172,7 +168,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -224,9 +221,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -294,7 +292,9 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -381,7 +381,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_beam_search.py index 9e09200a1..dcf6dc42f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_beam_search.py @@ -166,10 +166,14 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(num_active_paths) + topk_log_probs, topk_indexes = ragged_log_probs[i].topk( + num_active_paths + ) with warnings.catch_warnings(): warnings.simplefilter("ignore") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py index a50b4d4f0..d2cae4f9f 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py @@ -51,7 +51,11 @@ from streaming_beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.utils import ( AttributeDict, setup_logger, @@ -90,11 +94,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -160,7 +162,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -266,7 +269,9 @@ def decode_one_chunk( ) if params.decoding_method == "greedy_search": - greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) + greedy_search( + model=model, encoder_out=encoder_out, streams=decode_streams + ) elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( @@ -286,7 +291,9 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -342,7 +349,9 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. decode_streams = [] - initial_states = model.encoder.get_init_state(params.left_context, device=device) + initial_states = model.encoder.get_init_state( + params.left_context, device=device + ) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. decode_stream = DecodeStream( @@ -413,7 +422,9 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) return {key: decode_results} @@ -449,7 +460,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -521,7 +533,8 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/train.py b/egs/librispeech/ASR/pruned_transducer_stateless/train.py index dd0331a60..399b11a29 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/train.py @@ -203,45 +203,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." - ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -565,7 +562,9 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): + if torch.all(~simple_loss_is_finite) or torch.all( + ~pruned_loss_is_finite + ): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -585,7 +584,9 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -776,7 +777,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -894,7 +897,8 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " + f"Duration: {c.duration}" ) return False @@ -952,7 +956,9 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 5e9428b60..b7c2010f7 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -580,9 +580,9 @@ def greedy_search( if y not in (blank_id, unk_id): hyp.append(y) timestamp.append(t) - decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape( - 1, context_size - ) + decoder_input = torch.tensor( + [hyp[-context_size:]], device=device + ).reshape(1, context_size) decoder_out = model.decoder(decoder_input, need_pad=False) decoder_out = model.joiner.decoder_proj(decoder_out) @@ -775,7 +775,9 @@ class HypothesisList(object): key = hyp.key if key in self: old_hyp = self._data[key] # shallow copy - torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob) + torch.logaddexp( + old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob + ) else: self._data[key] = hyp @@ -791,7 +793,9 @@ class HypothesisList(object): Return the hypothesis that has the largest `log_prob`. """ if length_norm: - return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)) + return max( + self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys) + ) else: return max(self._data.values(), key=lambda hyp: hyp.log_prob) @@ -986,7 +990,9 @@ def modified_beam_search( logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) + log_probs = (logits / temperature).log_softmax( + dim=-1 + ) # (num_hyps, vocab_size) log_probs.add_(ys_log_probs) @@ -998,7 +1004,9 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -1668,7 +1676,9 @@ def fast_beam_search_with_nbest_rnn_rescoring( for rnn_scale in rnn_lm_scale_list: key = f"ngram_lm_scale_{n_scale}_rnn_lm_scale_{rnn_scale}" tot_scores = ( - am_scores.values + n_scale * ngram_lm_scores + rnn_scale * rnn_lm_scores + am_scores.values + + n_scale * ngram_lm_scores + + rnn_scale * rnn_lm_scores ) ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) max_indexes = ragged_tot_scores.argmax() @@ -1794,7 +1804,9 @@ def modified_beam_search_ngram_rescoring( logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) + log_probs = (logits / temperature).log_softmax( + dim=-1 + ) # (num_hyps, vocab_size) log_probs.add_(ys_log_probs) vocab_size = log_probs.size(-1) @@ -1804,7 +1816,9 @@ def modified_beam_search_ngram_rescoring( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -1827,7 +1841,9 @@ def modified_beam_search_ngram_rescoring( state_cost = hyp.state_cost # We only keep AM scores in new_hyp.log_prob - new_log_prob = topk_log_probs[k] - hyp.state_cost.lm_score * lm_scale + new_log_prob = ( + topk_log_probs[k] - hyp.state_cost.lm_score * lm_scale + ) new_hyp = Hypothesis( ys=new_ys, log_prob=new_log_prob, state_cost=state_cost @@ -1979,7 +1995,9 @@ def modified_beam_search_rnnlm_shallow_fusion( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) """ for all hyps with a non-blank new token, score this token. It is a little confusing here because this for-loop @@ -2014,7 +2032,10 @@ def modified_beam_search_rnnlm_shallow_fusion( # forward RNNLM to get new states and scores if len(token_list) != 0: tokens_to_score = ( - torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1) + torch.tensor(token_list) + .to(torch.int64) + .to(device) + .reshape(-1, 1) ) hs = torch.cat(hs, dim=1).to(device) @@ -2046,7 +2067,9 @@ def modified_beam_search_rnnlm_shallow_fusion( ys.append(new_token) new_timestamp.append(t) - hyp_log_prob += lm_score[new_token] * lm_scale # add the lm score + hyp_log_prob += ( + lm_score[new_token] * lm_scale + ) # add the lm score lm_score = scores[count] state = ( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index 34ff0d7e2..bc273d33b 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -214,7 +214,10 @@ class Conformer(EncoderInterface): NOTE: the returned tensors are on the given device. """ - if len(self._init_state) == 2 and self._init_state[0].size(1) == left_context: + if ( + len(self._init_state) == 2 + and self._init_state[0].size(1) == left_context + ): # Note: It is OK to share the init state as it is # not going to be modified by the model return self._init_state @@ -436,7 +439,9 @@ class ConformerEncoderLayer(nn.Module): self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) + self.self_attn = RelPositionMultiheadAttention( + d_model, nhead, dropout=0.0 + ) self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), @@ -454,7 +459,9 @@ class ConformerEncoderLayer(nn.Module): ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), ) - self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal) + self.conv_module = ConvolutionModule( + d_model, cnn_module_kernel, causal=causal + ) self.norm_final = BasicNorm(d_model) @@ -520,7 +527,9 @@ class ConformerEncoderLayer(nn.Module): src = src + self.dropout(src_att) # convolution module - conv, _ = self.conv_module(src, src_key_padding_mask=src_key_padding_mask) + conv, _ = self.conv_module( + src, src_key_padding_mask=src_key_padding_mask + ) src = src + self.dropout(conv) # feed forward module @@ -776,7 +785,9 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: + def __init__( + self, d_model: int, dropout_rate: float, max_len: int = 5000 + ) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() if is_jit_tracing(): @@ -800,7 +811,9 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x_size_1 * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): + if self.pe.dtype != x.dtype or str(self.pe.device) != str( + x.device + ): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -1114,9 +1127,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( - 3, dim=-1 - ) + q, k, v = nn.functional.linear( + query, in_proj_weight, in_proj_bias + ).chunk(3, dim=-1) elif torch.equal(key, value): # encoder-decoder attention @@ -1185,25 +1198,33 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError("The size of the 2D attn_mask is not correct.") + raise RuntimeError( + "The size of the 2D attn_mask is not correct." + ) elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError("The size of the 3D attn_mask is not correct.") + raise RuntimeError( + "The size of the 3D attn_mask is not correct." + ) else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format(attn_mask.dim()) + "attn_mask's dimension {} is not supported".format( + attn_mask.dim() + ) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + if ( + key_padding_mask is not None + and key_padding_mask.dtype == torch.uint8 + ): warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor" - " instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -1243,15 +1264,23 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) + matrix_ac = torch.matmul( + q_with_bias_u, k + ) # (batch, head, time1, time2) # compute matrix b and matrix d - matrix_bd = torch.matmul(q_with_bias_v, p) # (batch, head, time1, 2*time1-1) + matrix_bd = torch.matmul( + q_with_bias_v, p + ) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd, left_context) - attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) + attn_output_weights = ( + matrix_ac + matrix_bd + ) # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, -1 + ) if not is_jit_tracing(): assert list(attn_output_weights.size()) == [ @@ -1293,17 +1322,21 @@ class RelPositionMultiheadAttention(nn.Module): ): if attn_mask.size(0) != 1: attn_mask = attn_mask.view(bsz, num_heads, tgt_len, src_len) - combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2) - else: - # attn_mask.shape == (1, tgt_len, src_len) - combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze( + combined_mask = attn_mask | key_padding_mask.unsqueeze( 1 ).unsqueeze(2) + else: + # attn_mask.shape == (1, tgt_len, src_len) + combined_mask = attn_mask.unsqueeze( + 0 + ) | key_padding_mask.unsqueeze(1).unsqueeze(2) attn_output_weights = attn_output_weights.view( bsz, num_heads, tgt_len, src_len ) - attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0) + attn_output_weights = attn_output_weights.masked_fill( + combined_mask, 0.0 + ) attn_output_weights = attn_output_weights.view( bsz * num_heads, tgt_len, src_len ) @@ -1322,9 +1355,13 @@ class RelPositionMultiheadAttention(nn.Module): ] attn_output = ( - attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn_output.transpose(0, 1) + .contiguous() + .view(tgt_len, bsz, embed_dim) + ) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -1461,12 +1498,16 @@ class ConvolutionModule(nn.Module): # manualy padding self.lorder zeros to the left x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) else: - assert not self.training, "Cache should be None in training time" + assert ( + not self.training + ), "Cache should be None in training time" assert cache.size(0) == self.lorder x = torch.cat([cache.permute(1, 2, 0), x], dim=2) if right_context > 0: cache = x.permute(2, 0, 1)[ - -(self.lorder + right_context) : (-right_context), # noqa + -(self.lorder + right_context) : ( # noqa + -right_context + ), ..., ] else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py index 32cd53be3..979a0e02e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py @@ -132,7 +132,11 @@ from beam_search import ( ) from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, @@ -173,11 +177,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -273,7 +275,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -394,7 +397,9 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] @@ -460,7 +465,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -506,7 +514,11 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps } elif "fast_beam_search" in params.decoding_method: key = f"beam_{params.beam}_" @@ -596,7 +608,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -629,7 +643,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -685,7 +700,9 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -723,7 +740,8 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -761,7 +779,9 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py index b59928103..ba91302ce 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py @@ -107,11 +107,15 @@ class Decoder(nn.Module): # This is for exporting to PNNX via ONNX embedding_out = self.embedding(y) else: - embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1) + embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze( + -1 + ) if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad: - embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) + embedding_out = F.pad( + embedding_out, pad=(self.context_size - 1, 0) + ) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py index 90367bd03..f1a8ea589 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py @@ -51,7 +51,11 @@ import sentencepiece as spm import torch from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.utils import str2bool @@ -83,11 +87,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -118,7 +120,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -170,7 +173,8 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -218,7 +222,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py index 1954f4724..6a9d08033 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py @@ -60,7 +60,9 @@ class Joiner(nn.Module): assert encoder_out.shape == decoder_out.shape if project_input: - logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) + logit = self.encoder_proj(encoder_out) + self.decoder_proj( + decoder_out + ) else: logit = encoder_out + decoder_out diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py index 272d06c37..417c391d9 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py @@ -66,7 +66,9 @@ class Transducer(nn.Module): self.decoder = decoder self.joiner = joiner - self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) + self.simple_am_proj = ScaledLinear( + encoder_dim, vocab_size, initial_speed=0.5 + ) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) def forward( @@ -150,7 +152,9 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) + boundary = torch.zeros( + (x.size(0), 4), dtype=torch.int64, device=x.device + ) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py index 2d7f557ad..041a81f45 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py @@ -72,11 +72,17 @@ class Eve(Optimizer): if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if not 0.0 <= betas[0] < 1.0: - raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + raise ValueError( + "Invalid beta parameter at index 0: {}".format(betas[0]) + ) if not 0.0 <= betas[1] < 1.0: - raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + raise ValueError( + "Invalid beta parameter at index 1: {}".format(betas[1]) + ) if not 0 <= weight_decay <= 0.1: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + raise ValueError( + "Invalid weight_decay value: {}".format(weight_decay) + ) if not 0 < target_rms <= 10.0: raise ValueError("Invalid target_rms value: {}".format(target_rms)) defaults = dict( @@ -112,7 +118,9 @@ class Eve(Optimizer): # Perform optimization step grad = p.grad if grad.is_sparse: - raise RuntimeError("AdamW does not support sparse gradients") + raise RuntimeError( + "AdamW does not support sparse gradients" + ) state = self.state[p] @@ -139,7 +147,7 @@ class Eve(Optimizer): # Decay the first and second moment running average coefficient exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_( + denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_( group["eps"] ) @@ -150,7 +158,9 @@ class Eve(Optimizer): if p.numel() > 1: # avoid applying this weight-decay on "scaling factors" # (which are scalar). - is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5)) + is_above_target_rms = p.norm() > ( + target_rms * (p.numel() ** 0.5) + ) p.mul_(1 - (weight_decay * is_above_target_rms)) p.addcdiv_(exp_avg, denom, value=-step_size) @@ -170,14 +180,18 @@ class LRScheduler(object): def __init__(self, optimizer: Optimizer, verbose: bool = False): # Attach optimizer if not isinstance(optimizer, Optimizer): - raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) + raise TypeError( + "{} is not an Optimizer".format(type(optimizer).__name__) + ) self.optimizer = optimizer self.verbose = verbose for group in optimizer.param_groups: group.setdefault("initial_lr", group["lr"]) - self.base_lrs = [group["initial_lr"] for group in optimizer.param_groups] + self.base_lrs = [ + group["initial_lr"] for group in optimizer.param_groups + ] self.epoch = 0 self.batch = 0 @@ -285,9 +299,10 @@ class Eden(LRScheduler): def get_lr(self): factor = ( - (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 + (self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2 ) ** -0.25 * ( - ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25 + ((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2) + ** -0.25 ) return [x * factor for x in self.base_lrs] diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py index 58de6875f..f52cb22ab 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py @@ -91,11 +91,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -120,12 +118,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -172,7 +168,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -225,9 +222,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -295,7 +293,9 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -382,7 +382,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index f671e97b1..8c572a9ef 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -89,7 +89,9 @@ class ActivationBalancerFunction(torch.autograd.Function): below_threshold = mean_abs < min_abs above_threshold = mean_abs > max_abs - ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold) + ctx.save_for_backward( + factor, xgt0, below_threshold, above_threshold + ) ctx.max_factor = max_factor ctx.sum_dims = sum_dims return x @@ -135,7 +137,7 @@ class GradientFilterFunction(torch.autograd.Function): eps = 1.0e-20 dim = ctx.batch_dim norm_dims = [d for d in range(x_grad.ndim) if d != dim] - norm_of_batch = (x_grad**2).mean(dim=norm_dims, keepdim=True).sqrt() + norm_of_batch = (x_grad ** 2).mean(dim=norm_dims, keepdim=True).sqrt() median_norm = norm_of_batch.median() cutoff = median_norm * ctx.threshold @@ -227,7 +229,8 @@ class BasicNorm(torch.nn.Module): if not is_jit_tracing(): assert x.shape[self.channel_dim] == self.num_channels scales = ( - torch.mean(x**2, dim=self.channel_dim, keepdim=True) + self.eps.exp() + torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + + self.eps.exp() ) ** -0.5 return x * scales @@ -279,12 +282,12 @@ class ScaledLinear(nn.Linear): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3**0.5) * std + a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in**-0.5 # 1/sqrt(fan_in) + scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -298,7 +301,9 @@ class ScaledLinear(nn.Linear): return self.bias * self.bias_scale.exp() def forward(self, input: Tensor) -> Tensor: - return torch.nn.functional.linear(input, self.get_weight(), self.get_bias()) + return torch.nn.functional.linear( + input, self.get_weight(), self.get_bias() + ) class ScaledConv1d(nn.Conv1d): @@ -326,12 +331,12 @@ class ScaledConv1d(nn.Conv1d): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3**0.5) * std + a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in**-0.5 # 1/sqrt(fan_in) + scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -395,12 +400,12 @@ class ScaledConv2d(nn.Conv2d): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3**0.5) * std + a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in**-0.5 # 1/sqrt(fan_in) + scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -471,7 +476,9 @@ class ScaledLSTM(nn.LSTM): setattr(self, scale_name, param) self._scales.append(param) - self.grad_filter = GradientFilter(batch_dim=1, threshold=grad_norm_threshold) + self.grad_filter = GradientFilter( + batch_dim=1, threshold=grad_norm_threshold + ) self._reset_parameters( initial_speed @@ -479,8 +486,8 @@ class ScaledLSTM(nn.LSTM): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3**0.5) * std - scale = self.hidden_size**-0.5 + a = (3 ** 0.5) * std + scale = self.hidden_size ** -0.5 v = scale / std for idx, name in enumerate(self._flat_weights_names): if "weight" in name: @@ -552,11 +559,15 @@ class ScaledLSTM(nn.LSTM): """Get scaled weights, and resets their data pointer.""" flat_weights = [] for idx in range(len(self._flat_weights_names)): - flat_weights.append(self._flat_weights[idx] * self._scales[idx].exp()) + flat_weights.append( + self._flat_weights[idx] * self._scales[idx].exp() + ) self._flatten_parameters(flat_weights) return flat_weights - def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None): + def forward( + self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None + ): # This function is modified from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py # noqa # The change for calling `_VF.lstm()` is: # self._flat_weights -> self._get_flat_weights() @@ -904,7 +915,9 @@ def _test_activation_balancer_sign(): def _test_activation_balancer_magnitude(): magnitudes = torch.arange(0, 1, 0.01) N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze( + -1 + ) x = x.detach() x.requires_grad = True m = ActivationBalancer( @@ -934,8 +947,8 @@ def _test_basic_norm(): y = m(x) assert y.shape == x.shape - x_rms = (x**2).mean().sqrt() - y_rms = (y**2).mean().sqrt() + x_rms = (x ** 2).mean().sqrt() + y_rms = (y ** 2).mean().sqrt() print("x rms = ", x_rms) print("y rms = ", y_rms) assert y_rms < x_rms @@ -988,18 +1001,17 @@ def _test_grad_filter(): ) print( - "_test_grad_filter: for gradient norms, the first element > median *" - " threshold ", # noqa + "_test_grad_filter: for gradient norms, the first element > median * threshold ", # noqa i % 2 == 1, ) print( "_test_grad_filter: x_out_grad norm = ", - (x_out_grad**2).mean(dim=(0, 2)).sqrt(), + (x_out_grad ** 2).mean(dim=(0, 2)).sqrt(), ) print( "_test_grad_filter: x.grad norm = ", - (x.grad**2).mean(dim=(0, 2)).sqrt(), + (x.grad ** 2).mean(dim=(0, 2)).sqrt(), ) print("_test_grad_filter: w_out_grad = ", w_out_grad) print("_test_grad_filter: w.grad = ", w.grad) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py index e6e0fb1c8..9bcd2f9f9 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py @@ -153,7 +153,9 @@ def modified_beam_search( index=hyps_shape.row_ids(1).to(torch.int64), ) # (num_hyps, encoder_out_dim) - logits = model.joiner(current_encoder_out, decoder_out, project_input=False) + logits = model.joiner( + current_encoder_out, decoder_out, project_input=False + ) # logits is of shape (num_hyps, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) @@ -170,10 +172,14 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(num_active_paths) + topk_log_probs, topk_indexes = ragged_log_probs[i].topk( + num_active_paths + ) with warnings.catch_warnings(): warnings.simplefilter("ignore") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py index 0139863a1..d76a03946 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py @@ -51,7 +51,11 @@ from streaming_beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.utils import ( AttributeDict, setup_logger, @@ -90,11 +94,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -160,7 +162,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -268,7 +271,9 @@ def decode_one_chunk( encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": - greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) + greedy_search( + model=model, encoder_out=encoder_out, streams=decode_streams + ) elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( @@ -288,7 +293,9 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -344,7 +351,9 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. decode_streams = [] - initial_states = model.encoder.get_init_state(params.left_context, device=device) + initial_states = model.encoder.get_init_state( + params.left_context, device=device + ) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. decode_stream = DecodeStream( @@ -416,7 +425,9 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) return {key: decode_results} @@ -451,7 +462,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -524,7 +536,8 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 623bdd51a..1947834bf 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -96,7 +96,9 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def add_model_arguments(parser: argparse.ArgumentParser): @@ -208,7 +210,8 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need to be changed.", + help="The initial learning rate. This value should not need to " + "be changed.", ) parser.add_argument( @@ -231,45 +234,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." - ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -634,7 +634,9 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): + if torch.all(~simple_loss_is_finite) or torch.all( + ~pruned_loss_is_finite + ): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -647,9 +649,14 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -660,7 +667,9 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -828,7 +837,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -952,7 +963,8 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " + f"Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py b/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py index 5e81aef07..1df7f9ee5 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py @@ -27,7 +27,10 @@ from lhotse.dataset import ( K2SpeechRecognitionDataset, SpecAugment, ) -from lhotse.dataset.input_strategies import OnTheFlyFeatures, PrecomputedFeatures +from lhotse.dataset.input_strategies import ( + OnTheFlyFeatures, + PrecomputedFeatures, +) from torch.utils.data import DataLoader from icefall.utils import str2bool @@ -41,69 +44,59 @@ class AsrDataModule: def add_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group( title="ASR data related options", - description=( - "These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc." - ), + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", ) group.add_argument( "--max-duration", type=int, default=200.0, - help=( - "Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM." - ), + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", ) group.add_argument( "--bucketing-sampler", type=str2bool, default=True, - help=( - "When enabled, the batches will come from buckets of " - "similar duration (saves padding frames)." - ), + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", ) group.add_argument( "--num-buckets", type=int, default=30, - help=( - "The number of buckets for the DynamicBucketingSampler. " - "(you might want to increase it for larger datasets)." - ), + help="The number of buckets for the DynamicBucketingSampler. " + "(you might want to increase it for larger datasets).", ) group.add_argument( "--shuffle", type=str2bool, default=True, - help=( - "When enabled (=default), the examples will be shuffled for each epoch." - ), + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", ) group.add_argument( "--return-cuts", type=str2bool, default=True, - help=( - "When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it." - ), + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that collect the batches.", + help="The number of training dataloader workers that " + "collect the batches.", ) group.add_argument( @@ -124,22 +117,18 @@ class AsrDataModule: "--spec-aug-time-warp-factor", type=int, default=80, - help=( - "Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp." - ), + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", ) group.add_argument( "--enable-musan", type=str2bool, default=True, - help=( - "When enabled, select noise from MUSAN and mix it" - "with training dataset. " - ), + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", ) group.add_argument( @@ -153,11 +142,9 @@ class AsrDataModule: "--on-the-fly-feats", type=str2bool, default=False, - help=( - "When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available. Used only in dev/test CutSet" - ), + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available. Used only in dev/test CutSet", ) def train_dataloaders( @@ -180,7 +167,9 @@ class AsrDataModule: if cuts_musan is not None: logging.info("Enable MUSAN") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix( + cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True + ) ) else: logging.info("Disable MUSAN") @@ -189,7 +178,9 @@ class AsrDataModule: if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + logging.info( + f"Time warp factor: {self.args.spec_aug_time_warp_factor}" + ) input_transforms.append( SpecAugment( time_warp_factor=self.args.spec_aug_time_warp_factor, @@ -259,7 +250,9 @@ class AsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), return_cuts=self.args.return_cuts, ) else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py index 66c8e30ba..5784a78ba 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py @@ -79,7 +79,11 @@ from gigaspeech import GigaSpeech from gigaspeech_scoring import asr_text_post_processing from train import get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.utils import ( AttributeDict, setup_logger, @@ -116,11 +120,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -190,7 +192,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -277,7 +280,9 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] if params.decoding_method == "fast_beam_search": @@ -307,7 +312,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -351,11 +359,21 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps } elif params.decoding_method == "fast_beam_search_nbest_oracle": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}_num_paths_{params.num_paths}_nbest_scale_{params.nbest_scale}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}_" + f"num_paths_{params.num_paths}_" + f"nbest_scale_{params.nbest_scale}" + ): hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -428,7 +446,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -461,7 +481,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -511,7 +532,9 @@ def main(): params.suffix += f"-num-paths-{params.num_paths}" params.suffix += f"-nbest-scale-{params.nbest_scale}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -544,7 +567,8 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py index d90497e26..8025d6be1 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py @@ -120,7 +120,11 @@ from beam_search import ( from librispeech import LibriSpeech from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.lexicon import Lexicon from icefall.rnn_lm.model import RnnLmModel from icefall.utils import ( @@ -163,11 +167,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -263,7 +265,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -475,7 +478,9 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] @@ -545,7 +550,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -638,11 +646,21 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}temperature_{params.temperature}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + f"temperature_{params.temperature}" + ): hyps } elif params.decoding_method == "fast_beam_search": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}temperature_{params.temperature}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + f"temperature_{params.temperature}" + ): hyps } elif params.decoding_method in [ "fast_beam_search_with_nbest_rescoring", @@ -672,7 +690,12 @@ def decode_one_batch( key += f"_ngram_lm_scale_{params.ngram_lm_scale}" return {key: hyps} else: - return {f"beam_size_{params.beam_size}_temperature_{params.temperature}": hyps} + return { + ( + f"beam_size_{params.beam_size}_" + f"temperature_{params.temperature}" + ): hyps + } def decode_dataset( @@ -756,7 +779,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -789,7 +814,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -913,7 +939,9 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) params.suffix += f"-temperature-{params.temperature}" else: params.suffix += f"-context-{params.context_size}" @@ -953,7 +981,8 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -1003,10 +1032,15 @@ def main(): word_table=word_table, device=device, ) - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) logging.info(f"G properties_str: {G.properties_str}") rnn_lm_model = None - if params.decoding_method == "fast_beam_search_with_nbest_rnn_rescoring": + if ( + params.decoding_method + == "fast_beam_search_with_nbest_rnn_rescoring" + ): rnn_lm_model = RnnLmModel( vocab_size=params.vocab_size, embedding_dim=params.rnn_lm_embedding_dim, @@ -1031,7 +1065,9 @@ def main(): rnn_lm_model.eval() else: word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) rnn_lm_model = None else: decoding_graph = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py index dcf65e937..47217ba05 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py @@ -128,7 +128,11 @@ import torch.nn as nn from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.utils import str2bool @@ -160,11 +164,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -233,7 +235,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -506,9 +509,13 @@ def export_joiner_model_onnx( - projected_decoder_out: a tensor of shape (N, joiner_dim) """ - encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx") + encoder_proj_filename = str(joiner_filename).replace( + ".onnx", "_encoder_proj.onnx" + ) - decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx") + decoder_proj_filename = str(joiner_filename).replace( + ".onnx", "_decoder_proj.onnx" + ) encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] @@ -609,7 +616,8 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -707,7 +715,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/gigaspeech.py b/egs/librispeech/ASR/pruned_transducer_stateless3/gigaspeech.py index 598434f54..36f32c6b3 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/gigaspeech.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/gigaspeech.py @@ -52,14 +52,18 @@ class GigaSpeech: ) pattern = re.compile(r"gigaspeech_cuts_XL.([0-9]+).jsonl.gz") - idx_filenames = [(int(pattern.search(f).group(1)), f) for f in filenames] + idx_filenames = [ + (int(pattern.search(f).group(1)), f) for f in filenames + ] idx_filenames = sorted(idx_filenames, key=lambda x: x[0]) sorted_filenames = [f[1] for f in idx_filenames] logging.info(f"Loading {len(sorted_filenames)} splits") - return lhotse.combine(lhotse.load_manifest_lazy(p) for p in sorted_filenames) + return lhotse.combine( + lhotse.load_manifest_lazy(p) for p in sorted_filenames + ) def train_L_cuts(self) -> CutSet: f = self.manifest_dir / "gigaspeech_cuts_L_raw.jsonl.gz" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py index 108915389..162f8c7db 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py @@ -104,12 +104,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -144,9 +142,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -331,7 +330,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/model.py b/egs/librispeech/ASR/pruned_transducer_stateless3/model.py index d45f6dadc..7852f84e9 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/model.py @@ -84,7 +84,9 @@ class Transducer(nn.Module): self.decoder_giga = decoder_giga self.joiner_giga = joiner_giga - self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) + self.simple_am_proj = ScaledLinear( + encoder_dim, vocab_size, initial_speed=0.5 + ) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) if decoder_giga is not None: @@ -188,7 +190,9 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) + boundary = torch.zeros( + (x.size(0), 4), dtype=torch.int64, device=x.device + ) boundary[:, 2] = y_lens boundary[:, 3] = encoder_out_lens diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py index 163d737e3..d03d1d7ef 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py @@ -203,7 +203,9 @@ def test_joiner( ) # Now test encoder_proj - joiner_encoder_proj_inputs = {encoder_proj_input_name: encoder_out.numpy()} + joiner_encoder_proj_inputs = { + encoder_proj_input_name: encoder_out.numpy() + } joiner_encoder_proj_out = joiner_encoder_proj_session.run( [encoder_proj_output_name], joiner_encoder_proj_inputs )[0] @@ -212,10 +214,16 @@ def test_joiner( torch_joiner_encoder_proj_out = model.joiner.encoder_proj(encoder_out) assert torch.allclose( joiner_encoder_proj_out, torch_joiner_encoder_proj_out, atol=1e-5 - ), ((joiner_encoder_proj_out - torch_joiner_encoder_proj_out).abs().max()) + ), ( + (joiner_encoder_proj_out - torch_joiner_encoder_proj_out) + .abs() + .max() + ) # Now test decoder_proj - joiner_decoder_proj_inputs = {decoder_proj_input_name: decoder_out.numpy()} + joiner_decoder_proj_inputs = { + decoder_proj_input_name: decoder_out.numpy() + } joiner_decoder_proj_out = joiner_decoder_proj_session.run( [decoder_proj_output_name], joiner_decoder_proj_inputs )[0] @@ -224,7 +232,11 @@ def test_joiner( torch_joiner_decoder_proj_out = model.joiner.decoder_proj(decoder_out) assert torch.allclose( joiner_decoder_proj_out, torch_joiner_decoder_proj_out, atol=1e-5 - ), ((joiner_decoder_proj_out - torch_joiner_decoder_proj_out).abs().max()) + ), ( + (joiner_decoder_proj_out - torch_joiner_decoder_proj_out) + .abs() + .max() + ) @torch.no_grad() @@ -276,7 +288,9 @@ def main(): if __name__ == "__main__": torch.manual_seed(20220727) - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py index 11597aa49..ea5d4e674 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py @@ -102,12 +102,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -142,9 +140,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -192,7 +191,11 @@ def greedy_search( projected_encoder_out = joiner_encoder_proj.run( [joiner_encoder_proj.get_outputs()[0].name], - {joiner_encoder_proj.get_inputs()[0].name: packed_encoder_out.data.numpy()}, + { + joiner_encoder_proj.get_inputs()[ + 0 + ].name: packed_encoder_out.data.numpy() + }, )[0] blank_id = 0 # hard-code to 0 @@ -379,7 +382,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py index 849d6cf4e..19b636a23 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py @@ -100,11 +100,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -129,12 +127,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -181,7 +177,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -234,9 +231,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -304,7 +302,9 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -391,7 +391,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py index 85d87f8f2..1e6022b57 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py @@ -234,7 +234,9 @@ def scaled_lstm_to_lstm(scaled_lstm: ScaledLSTM) -> nn.LSTM: assert lstm._flat_weights_names == scaled_lstm._flat_weights_names for idx in range(len(scaled_lstm._flat_weights_names)): - scaled_weight = scaled_lstm._flat_weights[idx] * scaled_lstm._scales[idx].exp() + scaled_weight = ( + scaled_lstm._flat_weights[idx] * scaled_lstm._scales[idx].exp() + ) lstm._flat_weights[idx].data.copy_(scaled_weight) return lstm @@ -249,10 +251,12 @@ def get_submodule(model, target): mod: torch.nn.Module = model for item in atoms: if not hasattr(mod, item): - raise AttributeError(mod._get_name() + " has no attribute `" + item + "`") + raise AttributeError( + mod._get_name() + " has no " "attribute `" + item + "`" + ) mod = getattr(mod, item) if not isinstance(mod, torch.nn.Module): - raise AttributeError("`" + item + "` is not an nn.Module") + raise AttributeError("`" + item + "` is not " "an nn.Module") return mod diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py index 41a712498..10bb44e00 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py @@ -52,7 +52,11 @@ from streaming_beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.utils import ( AttributeDict, setup_logger, @@ -91,11 +95,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -161,7 +163,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -269,7 +272,9 @@ def decode_one_chunk( encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": - greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) + greedy_search( + model=model, encoder_out=encoder_out, streams=decode_streams + ) elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( @@ -289,7 +294,9 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -345,7 +352,9 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. decode_streams = [] - initial_states = model.encoder.get_init_state(params.left_context, device=device) + initial_states = model.encoder.get_init_state( + params.left_context, device=device + ) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. decode_stream = DecodeStream( @@ -417,7 +426,9 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) return {key: decode_results} @@ -450,7 +461,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -523,7 +535,8 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py index 598fcf344..66ffbd3ec 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py @@ -90,7 +90,9 @@ def test_conv2d_subsampling(): onnx_y = torch.from_numpy(onnx_y) torch_y = jit_model(x) - assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() + assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( + (onnx_y - torch_y).abs().max() + ) os.remove(filename) @@ -145,7 +147,9 @@ def test_rel_pos(): onnx_pos_emb = torch.from_numpy(onnx_pos_emb) torch_y, torch_pos_emb = jit_model(x) - assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() + assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( + (onnx_y - torch_y).abs().max() + ) assert torch.allclose(onnx_pos_emb, torch_pos_emb, atol=1e-05), ( (onnx_pos_emb - torch_pos_emb).abs().max() @@ -193,7 +197,9 @@ def test_conformer_encoder_layer(): encoder_layer.eval() encoder_layer = convert_scaled_to_non_scaled(encoder_layer, inplace=True) - jit_model = torch.jit.trace(encoder_layer, (x, pos_emb, src_key_padding_mask)) + jit_model = torch.jit.trace( + encoder_layer, (x, pos_emb, src_key_padding_mask) + ) torch.onnx.export( encoder_layer, @@ -230,7 +236,9 @@ def test_conformer_encoder_layer(): onnx_y = torch.from_numpy(onnx_y) torch_y = jit_model(x, pos_emb, src_key_padding_mask) - assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() + assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( + (onnx_y - torch_y).abs().max() + ) print(onnx_y.abs().sum(), torch_y.abs().sum(), onnx_y.shape, torch_y.shape) @@ -314,7 +322,9 @@ def test_conformer_encoder(): onnx_y = torch.from_numpy(onnx_y) torch_y = jit_model(x, pos_emb, src_key_padding_mask) - assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() + assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( + (onnx_y - torch_y).abs().max() + ) print(onnx_y.abs().sum(), torch_y.abs().sum(), onnx_y.shape, torch_y.shape) @@ -369,7 +379,9 @@ def test_conformer(): onnx_y_lens = torch.from_numpy(onnx_y_lens) torch_y, torch_y_lens = jit_model(x, x_lens) - assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() + assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( + (onnx_y - torch_y).abs().max() + ) assert torch.allclose(onnx_y_lens, torch_y_lens, atol=1e-05), ( (onnx_y_lens - torch_y_lens).abs().max() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py index 6724343dd..44e96644a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py @@ -92,7 +92,9 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def add_model_arguments(parser: argparse.ArgumentParser): @@ -161,7 +163,8 @@ def get_parser(): "--full-libri", type=str2bool, default=True, - help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.", + help="When enabled, use 960h LibriSpeech. " + "Otherwise, use 100h subset.", ) parser.add_argument( @@ -211,7 +214,8 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need to be changed.", + help="The initial learning rate. This value should not need " + "to be changed.", ) parser.add_argument( @@ -234,45 +238,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." - ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -671,7 +672,9 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): + if torch.all(~simple_loss_is_finite) or torch.all( + ~pruned_loss_is_finite + ): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -684,9 +687,14 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -697,7 +705,9 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -909,7 +919,9 @@ def train_one_epoch( f"train/current_{prefix}_", params.batch_idx_train, ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) libri_tot_loss.write_summary( tb_writer, "train/libri_tot_", params.batch_idx_train ) @@ -955,7 +967,8 @@ def filter_short_and_long_utterances( # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " + f"Duration: {c.duration}" ) return False @@ -1096,7 +1109,9 @@ def run(rank, world_size, args): train_giga_cuts = train_giga_cuts.repeat(times=None) if args.enable_musan: - cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz") + cuts_musan = load_manifest( + Path(args.manifest_dir) / "musan_cuts.jsonl.gz" + ) else: cuts_musan = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py index 69cfcd298..4f043e5a6 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py @@ -197,24 +197,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -310,7 +306,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -430,7 +427,9 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) if ( params.decoding_method == "fast_beam_search" @@ -486,7 +485,10 @@ def decode_one_batch( nbest_scale=params.nbest_scale, return_timestamps=True, ) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): res = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -564,7 +566,9 @@ def decode_dataset( sp: spm.SentencePieceProcessor, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str], List[float], List[float]]]]: +) -> Dict[ + str, List[Tuple[str, List[str], List[str], List[float], List[float]]] +]: """Decode dataset. Args: @@ -639,7 +643,9 @@ def decode_dataset( cut_ids, hyps, texts, timestamps_hyp, timestamps_ref ): ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words, time_ref, time_hyp)) + this_batch.append( + (cut_id, ref_words, hyp_words, time_ref, time_hyp) + ) results[name].extend(this_batch) @@ -648,7 +654,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -686,7 +694,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -713,7 +722,9 @@ def save_results( note = "" logging.info(s) - s = "\nFor {}, symbol-delay of different settings are:\n".format(test_set_name) + s = "\nFor {}, symbol-delay of different settings are:\n".format( + test_set_name + ) note = "\tbest for {}".format(test_set_name) for key, val in test_set_delays: s += "{}\tmean: {}s, variance: {}{}\n".format(key, val[0], val[1], note) @@ -762,7 +773,9 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -799,12 +812,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -827,12 +841,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -860,7 +875,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -887,7 +902,9 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/export.py b/egs/librispeech/ASR/pruned_transducer_stateless4/export.py index bd5801a78..ce7518ceb 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/export.py @@ -89,24 +89,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -137,7 +133,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -186,12 +183,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -214,12 +212,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -247,7 +246,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -283,7 +282,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py index a28e52c78..7af9ea9b8 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py @@ -96,24 +96,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -179,7 +175,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -287,7 +284,9 @@ def decode_one_chunk( encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": - greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) + greedy_search( + model=model, encoder_out=encoder_out, streams=decode_streams + ) elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( @@ -307,7 +306,9 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -363,7 +364,9 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. decode_streams = [] - initial_states = model.encoder.get_init_state(params.left_context, device=device) + initial_states = model.encoder.get_init_state( + params.left_context, device=device + ) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. decode_stream = DecodeStream( @@ -435,7 +438,9 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) return {key: decode_results} @@ -468,7 +473,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -541,12 +547,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -569,12 +576,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -602,7 +610,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py index 76785a845..cf32e565b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py @@ -101,7 +101,9 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def add_model_arguments(parser: argparse.ArgumentParser): @@ -237,45 +239,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." - ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -622,7 +621,11 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -662,7 +665,9 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): + if torch.all(~simple_loss_is_finite) or torch.all( + ~pruned_loss_is_finite + ): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -675,9 +680,14 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -688,7 +698,9 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -867,7 +879,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -999,7 +1013,8 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " + f"Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py index 8499651d7..427b06294 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py @@ -214,7 +214,10 @@ class Conformer(EncoderInterface): (num_encoder_layers, cnn_module_kernel - 1, encoder_dim). NOTE: the returned tensors are on the given device. """ - if len(self._init_state) == 2 and self._init_state[0].size(1) == left_context: + if ( + len(self._init_state) == 2 + and self._init_state[0].size(1) == left_context + ): # Note: It is OK to share the init state as it is # not going to be modified by the model return self._init_state @@ -436,7 +439,9 @@ class ConformerEncoderLayer(nn.Module): self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) + self.self_attn = RelPositionMultiheadAttention( + d_model, nhead, dropout=0.0 + ) self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), @@ -454,7 +459,9 @@ class ConformerEncoderLayer(nn.Module): ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), ) - self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal) + self.conv_module = ConvolutionModule( + d_model, cnn_module_kernel, causal=causal + ) self.norm_final = BasicNorm(d_model) @@ -520,7 +527,9 @@ class ConformerEncoderLayer(nn.Module): src = src + self.dropout(src_att) # convolution module - conv, _ = self.conv_module(src, src_key_padding_mask=src_key_padding_mask) + conv, _ = self.conv_module( + src, src_key_padding_mask=src_key_padding_mask + ) src = src + self.dropout(conv) # feed forward module @@ -793,7 +802,9 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: + def __init__( + self, d_model: int, dropout_rate: float, max_len: int = 5000 + ) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -809,7 +820,9 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x_size_1 * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): + if self.pe.dtype != x.dtype or str(self.pe.device) != str( + x.device + ): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -835,7 +848,9 @@ class RelPositionalEncoding(torch.nn.Module): pe = torch.cat([pe_positive, pe_negative], dim=1) self.pe = pe.to(device=x.device, dtype=x.dtype) - def forward(self, x: torch.Tensor, left_context: int = 0) -> Tuple[Tensor, Tensor]: + def forward( + self, x: torch.Tensor, left_context: int = 0 + ) -> Tuple[Tensor, Tensor]: """Add positional encoding. Args: @@ -1103,9 +1118,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( - 3, dim=-1 - ) + q, k, v = nn.functional.linear( + query, in_proj_weight, in_proj_bias + ).chunk(3, dim=-1) elif torch.equal(key, value): # encoder-decoder attention @@ -1174,25 +1189,33 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError("The size of the 2D attn_mask is not correct.") + raise RuntimeError( + "The size of the 2D attn_mask is not correct." + ) elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError("The size of the 3D attn_mask is not correct.") + raise RuntimeError( + "The size of the 3D attn_mask is not correct." + ) else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format(attn_mask.dim()) + "attn_mask's dimension {} is not supported".format( + attn_mask.dim() + ) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + if ( + key_padding_mask is not None + and key_padding_mask.dtype == torch.uint8 + ): warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor" - " instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -1230,15 +1253,23 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) + matrix_ac = torch.matmul( + q_with_bias_u, k + ) # (batch, head, time1, time2) # compute matrix b and matrix d - matrix_bd = torch.matmul(q_with_bias_v, p) # (batch, head, time1, 2*time1-1) + matrix_bd = torch.matmul( + q_with_bias_v, p + ) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd, left_context) - attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) + attn_output_weights = ( + matrix_ac + matrix_bd + ) # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, -1 + ) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -1279,17 +1310,21 @@ class RelPositionMultiheadAttention(nn.Module): ): if attn_mask.size(0) != 1: attn_mask = attn_mask.view(bsz, num_heads, tgt_len, src_len) - combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2) - else: - # attn_mask.shape == (1, tgt_len, src_len) - combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze( + combined_mask = attn_mask | key_padding_mask.unsqueeze( 1 ).unsqueeze(2) + else: + # attn_mask.shape == (1, tgt_len, src_len) + combined_mask = attn_mask.unsqueeze( + 0 + ) | key_padding_mask.unsqueeze(1).unsqueeze(2) attn_output_weights = attn_output_weights.view( bsz, num_heads, tgt_len, src_len ) - attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0) + attn_output_weights = attn_output_weights.masked_fill( + combined_mask, 0.0 + ) attn_output_weights = attn_output_weights.view( bsz * num_heads, tgt_len, src_len ) @@ -1301,9 +1336,13 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn_output.transpose(0, 1) + .contiguous() + .view(tgt_len, bsz, embed_dim) + ) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -1442,12 +1481,16 @@ class ConvolutionModule(nn.Module): # manualy padding self.lorder zeros to the left x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) else: - assert not self.training, "Cache should be None in training time" + assert ( + not self.training + ), "Cache should be None in training time" assert cache.size(0) == self.lorder x = torch.cat([cache.permute(1, 2, 0), x], dim=2) if right_context > 0: cache = x.permute(2, 0, 1)[ - -(self.lorder + right_context) : (-right_context), # noqa + -(self.lorder + right_context) : ( # noqa + -right_context + ), ..., ] else: @@ -1623,7 +1666,9 @@ class RandomCombine(nn.Module): self.stddev = stddev self.final_log_weight = ( - torch.tensor((final_weight / (1 - final_weight)) * (self.num_inputs - 1)) + torch.tensor( + (final_weight / (1 - final_weight)) * (self.num_inputs - 1) + ) .log() .item() ) @@ -1720,14 +1765,16 @@ class RandomCombine(nn.Module): # final contains self.num_inputs - 1 in all elements final = torch.full((num_frames,), self.num_inputs - 1, device=device) # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights. - nonfinal = torch.randint(self.num_inputs - 1, (num_frames,), device=device) + nonfinal = torch.randint( + self.num_inputs - 1, (num_frames,), device=device + ) indexes = torch.where( torch.rand(num_frames, device=device) < final_prob, final, nonfinal ) - ans = torch.nn.functional.one_hot(indexes, num_classes=self.num_inputs).to( - dtype=dtype - ) + ans = torch.nn.functional.one_hot( + indexes, num_classes=self.num_inputs + ).to(dtype=dtype) return ans def _get_random_mixed_weights( @@ -1757,8 +1804,7 @@ class RandomCombine(nn.Module): def _test_random_combine(final_weight: float, pure_prob: float, stddev: float): print( - f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}," - f" stddev={stddev}" + f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}, stddev={stddev}" ) num_inputs = 3 num_channels = 50 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py index f462cc42f..22bcdd88e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py @@ -179,24 +179,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -307,7 +303,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -480,7 +477,9 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] @@ -546,7 +545,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -694,7 +696,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -727,7 +731,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -782,7 +787,9 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -821,12 +828,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -849,12 +857,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -882,7 +891,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -928,7 +937,9 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/export.py b/egs/librispeech/ASR/pruned_transducer_stateless5/export.py index a739c17bc..b2e5b430e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/export.py @@ -89,24 +89,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -137,7 +133,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -184,12 +181,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -212,12 +210,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -245,7 +244,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -281,7 +280,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py index e2da0da4c..1e100fcbd 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py @@ -89,11 +89,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -118,12 +116,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -170,7 +166,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -201,9 +198,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -266,11 +264,15 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) + encoder_out, encoder_out_lens = model.encoder( + x=features, x_lens=feature_lengths + ) num_waves = encoder_out.size(0) hyps = [] @@ -342,7 +344,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py index 59a0e8fa2..6fee9483e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py @@ -96,24 +96,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -179,7 +175,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -287,7 +284,9 @@ def decode_one_chunk( encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": - greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) + greedy_search( + model=model, encoder_out=encoder_out, streams=decode_streams + ) elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( @@ -307,7 +306,9 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -363,7 +364,9 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. decode_streams = [] - initial_states = model.encoder.get_init_state(params.left_context, device=device) + initial_states = model.encoder.get_init_state( + params.left_context, device=device + ) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. decode_stream = DecodeStream( @@ -435,7 +438,9 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) return {key: decode_results} @@ -468,7 +473,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -541,12 +547,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -569,12 +576,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -602,7 +610,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py index 75696d61b..179d9372e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py @@ -89,7 +89,9 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def add_model_arguments(parser: argparse.ArgumentParser): @@ -246,7 +248,8 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need to be changed.", + help="The initial learning rate. This value should not need " + "to be changed.", ) parser.add_argument( @@ -269,45 +272,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." - ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -645,7 +645,11 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -686,7 +690,9 @@ def compute_loss( # If the batch contains more than 10 utterances AND # if either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): + if torch.all(~simple_loss_is_finite) or torch.all( + ~pruned_loss_is_finite + ): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -699,9 +705,14 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -712,7 +723,9 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -895,7 +908,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -1008,7 +1023,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 2 ** 22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -1030,7 +1045,8 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " + f"Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py index 40ad61fd4..53788b3f7 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py @@ -90,7 +90,10 @@ class Conformer(EncoderInterface): output_layers = [] if middle_output_layer is not None: - assert middle_output_layer >= 0 and middle_output_layer < num_encoder_layers + assert ( + middle_output_layer >= 0 + and middle_output_layer < num_encoder_layers + ) output_layers.append(middle_output_layer) # The last layer is always needed. @@ -175,7 +178,9 @@ class ConformerEncoderLayer(nn.Module): self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) + self.self_attn = RelPositionMultiheadAttention( + d_model, nhead, dropout=0.0 + ) self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), @@ -357,7 +362,9 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: + def __init__( + self, d_model: int, dropout_rate: float, max_len: int = 5000 + ) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -372,7 +379,9 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): + if self.pe.dtype != x.dtype or str(self.pe.device) != str( + x.device + ): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -647,9 +656,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( - 3, dim=-1 - ) + q, k, v = nn.functional.linear( + query, in_proj_weight, in_proj_bias + ).chunk(3, dim=-1) elif torch.equal(key, value): # encoder-decoder attention @@ -718,25 +727,33 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError("The size of the 2D attn_mask is not correct.") + raise RuntimeError( + "The size of the 2D attn_mask is not correct." + ) elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError("The size of the 3D attn_mask is not correct.") + raise RuntimeError( + "The size of the 3D attn_mask is not correct." + ) else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format(attn_mask.dim()) + "attn_mask's dimension {} is not supported".format( + attn_mask.dim() + ) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + if ( + key_padding_mask is not None + and key_padding_mask.dtype == torch.uint8 + ): warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor" - " instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -773,7 +790,9 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) + matrix_ac = torch.matmul( + q_with_bias_u, k + ) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -781,9 +800,13 @@ class RelPositionMultiheadAttention(nn.Module): ) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd) - attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) + attn_output_weights = ( + matrix_ac + matrix_bd + ) # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, -1 + ) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -817,9 +840,13 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn_output.transpose(0, 1) + .contiguous() + .view(tgt_len, bsz, embed_dim) + ) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -842,7 +869,9 @@ class ConvolutionModule(nn.Module): """ - def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: + def __init__( + self, channels: int, kernel_size: int, bias: bool = True + ) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py index 600aa9b39..74df04006 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py @@ -120,24 +120,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -212,7 +208,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -270,7 +267,9 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - layer_results, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + layer_results, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) encoder_out = layer_results[-1] hyps = [] @@ -286,7 +285,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -332,7 +334,11 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -405,7 +411,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -438,7 +446,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -481,7 +490,9 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -513,12 +524,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -541,12 +553,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -574,7 +587,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/export.py b/egs/librispeech/ASR/pruned_transducer_stateless6/export.py index 17f8614dc..cff9c7377 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/export.py @@ -51,7 +51,11 @@ import sentencepiece as spm import torch from train import get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.utils import str2bool @@ -83,11 +87,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -118,7 +120,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) return parser @@ -157,7 +160,8 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -205,7 +209,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/extract_codebook_index.py b/egs/librispeech/ASR/pruned_transducer_stateless6/extract_codebook_index.py index 86cf34877..21409287c 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/extract_codebook_index.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/extract_codebook_index.py @@ -21,10 +21,9 @@ import os from pathlib import Path import torch +from vq_utils import CodebookIndexExtractor from asr_datamodule import LibriSpeechAsrDataModule from hubert_xlarge import HubertXlargeFineTuned -from vq_utils import CodebookIndexExtractor - from icefall.utils import AttributeDict, str2bool diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_decode.py index b8440f90a..49b557814 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_decode.py @@ -23,6 +23,7 @@ from pathlib import Path from typing import Dict, List, Tuple import torch + from asr_datamodule import LibriSpeechAsrDataModule from hubert_xlarge import HubertXlargeFineTuned @@ -98,7 +99,9 @@ def decode_dataset( if batch_idx % 20 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -121,7 +124,9 @@ def save_results( ) test_set_wers[key] = wer - logging.info("Wrote detailed error stats to {}".format(errs_filename)) + logging.info( + "Wrote detailed error stats to {}".format(errs_filename) + ) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.res_dir / f"wer-summary-{test_set_name}.txt" @@ -150,7 +155,9 @@ def main(): # reset some parameters needed by hubert. params.update(HubertXlargeFineTuned.get_params()) - params.res_dir = params.exp_dir / f"ctc_greedy_search-{params.teacher_model_id}" + params.res_dir = ( + params.exp_dir / f"ctc_greedy_search-{params.teacher_model_id}" + ) setup_logger(f"{params.res_dir}/log/log-ctc_greedy_search") logging.info("Decoding started") @@ -183,7 +190,9 @@ def main(): params=params, ) - save_results(params=params, test_set_name=test_set, results_dict=results_dict) + save_results( + params=params, test_set_name=test_set, results_dict=results_dict + ) logging.info("Done!") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_xlarge.py b/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_xlarge.py index 4f9417c9f..55ce7b00d 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_xlarge.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_xlarge.py @@ -22,7 +22,11 @@ from pathlib import Path from typing import Dict, List, Tuple import torch -from fairseq import checkpoint_utils, tasks, utils +from fairseq import ( + checkpoint_utils, + tasks, + utils, +) from fairseq.data.data_utils import post_process from omegaconf import OmegaConf @@ -47,7 +51,9 @@ def _load_hubert_model(params: AttributeDict): "data": str(params.hubert_model_dir), } ) - model_path = Path(params.hubert_model_dir) / (params.teacher_model_id + ".pt") + model_path = Path(params.hubert_model_dir) / ( + params.teacher_model_id + ".pt" + ) task = tasks.setup_task(cfg_task) processor = task.target_dictionary models, saved_cfg = checkpoint_utils.load_model_ensemble( @@ -145,7 +151,9 @@ class HubertXlargeFineTuned: supervisions = batch["supervisions"] num_samples = supervisions["num_samples"] B, T = features.shape - padding_mask = torch.arange(0, T).expand(B, T) > num_samples.reshape([-1, 1]) + padding_mask = torch.arange(0, T).expand(B, T) > num_samples.reshape( + [-1, 1] + ) padding_mask = padding_mask.to(self.params.device) features = features.to(self.params.device) @@ -155,7 +163,9 @@ class HubertXlargeFineTuned: features = features.transpose(1, 2) features = self.w2v_model.layer_norm(features) - padding_mask = self.w2v_model.forward_padding_mask(features, padding_mask) + padding_mask = self.w2v_model.forward_padding_mask( + features, padding_mask + ) if self.w2v_model.post_extract_proj is not None: features = self.w2v_model.post_extract_proj(features) @@ -202,7 +212,9 @@ class HubertXlargeFineTuned: toks = encoder_out.argmax(dim=-1) blank = 0 toks = [tok.unique_consecutive() for tok in toks] - hyps = [self.processor.string(tok[tok != blank].int().cpu()) for tok in toks] + hyps = [ + self.processor.string(tok[tok != blank].int().cpu()) for tok in toks + ] hyps = [post_process(hyp, "letter") for hyp in hyps] return hyps diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/model.py b/egs/librispeech/ASR/pruned_transducer_stateless6/model.py index daadb70c9..7716d19cf 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/model.py @@ -69,7 +69,9 @@ class Transducer(nn.Module): self.decoder = decoder self.joiner = joiner - self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) + self.simple_am_proj = ScaledLinear( + encoder_dim, vocab_size, initial_speed=0.5 + ) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) from icefall import is_module_available @@ -178,7 +180,9 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) + boundary = torch.zeros( + (x.size(0), 4), dtype=torch.int64, device=x.device + ) boundary[:, 2] = y_lens boundary[:, 3] = x_lens @@ -233,7 +237,9 @@ class Transducer(nn.Module): return (simple_loss, pruned_loss, codebook_loss) @staticmethod - def concat_successive_codebook_indexes(middle_layer_output, codebook_indexes): + def concat_successive_codebook_indexes( + middle_layer_output, codebook_indexes + ): # Output rate of hubert is 50 frames per second, # while that of current encoder is 25. # Following code handling two issues: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py index be54ff0ce..f717d85fb 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py @@ -101,7 +101,9 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def get_parser(): @@ -201,45 +203,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." - ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -570,7 +569,9 @@ def save_checkpoint( def extract_codebook_indexes(batch): cuts = batch["supervisions"]["cut"] # -100 is identical to ignore_value in CE loss computation. - cuts_pre_mixed = [c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts] + cuts_pre_mixed = [ + c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts + ] codebook_indexes, codebook_indexes_lens = collate_custom_field( cuts_pre_mixed, "codebook_indexes", pad_value=-100 ) @@ -603,7 +604,11 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -650,7 +655,9 @@ def compute_loss( # If the batch contains more than 10 utterances AND # if either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): + if torch.all(~simple_loss_is_finite) or torch.all( + ~pruned_loss_is_finite + ): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -663,9 +670,14 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss if is_training and params.enable_distillation: assert codebook_loss is not None loss += params.codebook_loss_scale * codebook_loss @@ -678,7 +690,9 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -859,7 +873,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -991,7 +1007,8 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " + f"Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py index 40f97f662..47cf2b14b 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py @@ -68,7 +68,9 @@ class CodebookIndexExtractor: def init_dirs(self): # vq_dir is the root dir for quantization, containing: # training data, trained quantizer, and extracted codebook indexes - self.vq_dir = self.params.exp_dir / f"vq/{self.params.teacher_model_id}/" + self.vq_dir = ( + self.params.exp_dir / f"vq/{self.params.teacher_model_id}/" + ) self.vq_dir.mkdir(parents=True, exist_ok=True) # manifest_dir contains: @@ -206,7 +208,9 @@ class CodebookIndexExtractor: start = cur_offset % (data.shape[0] + 1 - B) end = start + B cur_offset += B - yield data[start:end, :].to(self.params.device).to(dtype=torch.float) + yield data[start:end, :].to(self.params.device).to( + dtype=torch.float + ) for x in minibatch_generator(train, repeat=True): trainer.step(x) @@ -223,11 +227,10 @@ class CodebookIndexExtractor: """ for subset in self.params.subsets: logging.info(f"About to split {subset}.") - ori_manifest = f"./data/fbank/librispeech_cuts_train-{subset}.jsonl.gz" - split_cmd = ( - "lhotse split" - f" {self.params.world_size} {ori_manifest} {self.manifest_dir}" + ori_manifest = ( + f"./data/fbank/librispeech_cuts_train-{subset}.jsonl.gz" ) + split_cmd = f"lhotse split {self.params.world_size} {ori_manifest} {self.manifest_dir}" os.system(f"{split_cmd}") def join_manifests(self): @@ -237,13 +240,16 @@ class CodebookIndexExtractor: logging.info("Start to join manifest files.") for subset in self.params.subsets: vq_manifest_path = ( - self.dst_manifest_dir / f"librispeech_cuts_train-{subset}-vq.jsonl.gz" + self.dst_manifest_dir + / f"librispeech_cuts_train-{subset}-vq.jsonl.gz" ) ori_manifest_path = ( - self.ori_manifest_dir / f"librispeech_cuts_train-{subset}.jsonl.gz" + self.ori_manifest_dir + / f"librispeech_cuts_train-{subset}.jsonl.gz" ) dst_vq_manifest_path = ( - self.dst_manifest_dir / f"librispeech_cuts_train-{subset}.jsonl.gz" + self.dst_manifest_dir + / f"librispeech_cuts_train-{subset}.jsonl.gz" ) cuts_vq = load_manifest(vq_manifest_path) cuts_ori = load_manifest(ori_manifest_path) @@ -263,7 +269,8 @@ class CodebookIndexExtractor: for subset in self.params.subsets: vq_manifests = f"{self.manifest_dir}/with_codebook_indexes-librispeech-cuts_train-{subset}*.jsonl.gz" dst_vq_manifest = ( - self.dst_manifest_dir / f"librispeech_cuts_train-{subset}-vq.jsonl.gz" + self.dst_manifest_dir + / f"librispeech_cuts_train-{subset}-vq.jsonl.gz" ) if 1 == self.params.world_size: merge_cmd = f"cp {vq_manifests} {dst_vq_manifest}" @@ -323,7 +330,9 @@ class CodebookIndexExtractor: def load_ori_dl(self, subset): if self.params.world_size == 1: - ori_manifest_path = f"./data/fbank/librispeech_cuts_train-{subset}.jsonl.gz" + ori_manifest_path = ( + f"./data/fbank/librispeech_cuts_train-{subset}.jsonl.gz" + ) else: ori_manifest_path = ( self.manifest_dir diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py index fa8144935..06c5863f1 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py @@ -164,24 +164,20 @@ def get_parser(): "--avg", type=int, default=9, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -276,7 +272,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -396,7 +393,9 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] @@ -455,7 +454,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -586,7 +588,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -619,7 +623,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -674,7 +679,9 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -711,12 +718,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -739,12 +747,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -772,7 +781,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -799,7 +808,9 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py index 5f90e6375..712dc8ce1 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py @@ -69,7 +69,7 @@ class Decoder(nn.Module): out_channels=decoder_dim, kernel_size=context_size, padding=0, - groups=decoder_dim // 4, # group size == 4 + groups=decoder_dim//4, # group size == 4 bias=False, ) @@ -91,7 +91,9 @@ class Decoder(nn.Module): if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) + embedding_out = F.pad( + embedding_out, pad=(self.context_size - 1, 0) + ) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7/export.py index 43ac658e5..5744ea3ea 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/export.py @@ -129,24 +129,20 @@ def get_parser(): "--avg", type=int, default=9, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -180,7 +176,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) add_model_arguments(parser) @@ -218,12 +215,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -246,12 +244,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -279,7 +278,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -317,7 +316,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py index c94a34d58..e2405d5ef 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py @@ -69,12 +69,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) return parser @@ -95,9 +93,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -268,7 +267,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py index 3ddac2cf2..7d8de5afe 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py @@ -56,7 +56,9 @@ class Joiner(nn.Module): assert encoder_out.shape[:-1] == decoder_out.shape[:-1] if project_input: - logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) + logit = self.encoder_proj(encoder_out) + self.decoder_proj( + decoder_out + ) else: logit = encoder_out + decoder_out diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py index 0e59b0f2f..53cde6c6f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py @@ -15,15 +15,14 @@ # limitations under the License. -import random - import k2 import torch import torch.nn as nn +import random from encoder_interface import EncoderInterface -from scaling import penalize_abs_values_gt from icefall.utils import add_sos +from scaling import penalize_abs_values_gt class Transducer(nn.Module): @@ -66,8 +65,7 @@ class Transducer(nn.Module): self.joiner = joiner self.simple_am_proj = nn.Linear( - encoder_dim, - vocab_size, + encoder_dim, vocab_size, ) self.simple_lm_proj = nn.Linear(decoder_dim, vocab_size) @@ -135,16 +133,18 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) + boundary = torch.zeros( + (x.size(0), 4), dtype=torch.int64, device=x.device + ) boundary[:, 2] = y_lens boundary[:, 3] = x_lens lm = self.simple_lm_proj(decoder_out) am = self.simple_am_proj(encoder_out) - # if self.training and random.random() < 0.25: + #if self.training and random.random() < 0.25: # lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04) - # if self.training and random.random() < 0.25: + #if self.training and random.random() < 0.25: # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) with torch.cuda.amp.autocast(enabled=False): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py index 460ac2c3e..bb8b0a0e3 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py @@ -14,17 +14,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -import contextlib -import logging -import random from collections import defaultdict -from typing import List, Optional, Tuple, Union - -import torch +from typing import List, Optional, Union, Tuple, List from lhotse.utils import fix_random_seed +import torch from scaling import ActivationBalancer +import random from torch import Tensor from torch.optim import Optimizer +import logging +import contextlib + class BatchedOptimizer(Optimizer): @@ -37,10 +37,11 @@ class BatchedOptimizer(Optimizer): Args: params: """ - def __init__(self, params, defaults): super(BatchedOptimizer, self).__init__(params, defaults) + + @contextlib.contextmanager def batched_params(self, param_group): """ @@ -72,9 +73,7 @@ class BatchedOptimizer(Optimizer): group: a parameter group, which is a list of parameters; should be one of self.groups. """ - batches = defaultdict( - list - ) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter + batches = defaultdict(list) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter for p in param_group: key = (str(p.dtype), *p.shape) @@ -83,7 +82,7 @@ class BatchedOptimizer(Optimizer): stacked_params_dict = dict() # turn batches into a list, in deterministic order. - batches = [batches[key] for key in sorted(batches.keys())] + batches = [ batches[key] for key in sorted(batches.keys()) ] # pairs will contain pairs of (stacked_param, state), one for each batch # in `batches`. pairs = [] @@ -95,78 +94,77 @@ class BatchedOptimizer(Optimizer): # group. class Optimizer will take care of saving/loading state. state = self.state[p] p_stacked = torch.stack(batch) - grad = torch.stack( - [torch.zeros_like(p) if p.grad is None else p.grad for p in batch] - ) + grad = torch.stack([torch.zeros_like(p) if p.grad is None else p.grad for p in batch ]) p_stacked.grad = grad stacked_params_dict[key] = p_stacked pairs.append((p_stacked, state)) - yield pairs # <-- calling code will do the actual optimization here! + yield pairs # <-- calling code will do the actual optimization here! for ((stacked_params, _state), batch) in zip(pairs, batches): for i, p in enumerate(batch): # batch is list of Parameter p.copy_(stacked_params[i]) + class ScaledAdam(BatchedOptimizer): """ - Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update - proportional to the norm of that parameter; and also learn the scale of the parameter, - in log space, subject to upper and lower limits (as if we had factored each parameter as - param = underlying_param * log_scale.exp()) + Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update + proportional to the norm of that parameter; and also learn the scale of the parameter, + in log space, subject to upper and lower limits (as if we had factored each parameter as + param = underlying_param * log_scale.exp()) - Args: - params: The parameters or param_groups to optimize (like other Optimizer subclasses) - lr: The learning rate. We will typically use a learning rate schedule that starts - at 0.03 and decreases over time, i.e. much higher than other common - optimizers. - clipping_scale: (e.g. 2.0) - A scale for gradient-clipping: if specified, the normalized gradients - over the whole model will be clipped to have 2-norm equal to - `clipping_scale` times the median 2-norm over the most recent period - of `clipping_update_period` minibatches. By "normalized gradients", - we mean after multiplying by the rms parameter value for this tensor - [for non-scalars]; this is appropriate because our update is scaled - by this quantity. - betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad. - Must satisfy 0 < beta <= beta2 < 1. - scalar_lr_scale: A scaling factor on the learning rate, that we use to update the - scale of each parameter tensor and scalar parameters of the mode.. - If each parameter were decomposed - as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale - would be a the scaling factor on the learning rate of p_scale. - eps: A general-purpose epsilon to prevent division by zero - param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of - learning the scale on the parameters (we'll constrain the rms of each non-scalar - parameter tensor to be >= this value) - param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of - learning the scale on the parameters (we'll constrain the rms of each non-scalar - parameter tensor to be <= this value) - scalar_max: Maximum absolute value for scalar parameters (applicable if your - model has any parameters with numel() == 1). - size_update_period: The periodicity, in steps, with which we update the size (scale) - of the parameter tensor. This is provided to save a little time - in the update. - clipping_update_period: if clipping_scale is specified, this is the period + Args: + params: The parameters or param_groups to optimize (like other Optimizer subclasses) + lr: The learning rate. We will typically use a learning rate schedule that starts + at 0.03 and decreases over time, i.e. much higher than other common + optimizers. + clipping_scale: (e.g. 2.0) + A scale for gradient-clipping: if specified, the normalized gradients + over the whole model will be clipped to have 2-norm equal to + `clipping_scale` times the median 2-norm over the most recent period + of `clipping_update_period` minibatches. By "normalized gradients", + we mean after multiplying by the rms parameter value for this tensor + [for non-scalars]; this is appropriate because our update is scaled + by this quantity. + betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad. + Must satisfy 0 < beta <= beta2 < 1. + scalar_lr_scale: A scaling factor on the learning rate, that we use to update the + scale of each parameter tensor and scalar parameters of the mode.. + If each parameter were decomposed + as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale + would be a the scaling factor on the learning rate of p_scale. + eps: A general-purpose epsilon to prevent division by zero + param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of + learning the scale on the parameters (we'll constrain the rms of each non-scalar + parameter tensor to be >= this value) + param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of + learning the scale on the parameters (we'll constrain the rms of each non-scalar + parameter tensor to be <= this value) + scalar_max: Maximum absolute value for scalar parameters (applicable if your + model has any parameters with numel() == 1). + size_update_period: The periodicity, in steps, with which we update the size (scale) + of the parameter tensor. This is provided to save a little time + in the update. + clipping_update_period: if clipping_scale is specified, this is the period """ - def __init__( - self, - params, - lr=3e-02, - clipping_scale=None, - betas=(0.9, 0.98), - scalar_lr_scale=0.1, - eps=1.0e-08, - param_min_rms=1.0e-05, - param_max_rms=3.0, - scalar_max=10.0, - size_update_period=4, - clipping_update_period=100, + self, + params, + lr=3e-02, + clipping_scale=None, + betas=(0.9, 0.98), + scalar_lr_scale=0.1, + eps=1.0e-08, + param_min_rms=1.0e-05, + param_max_rms=3.0, + scalar_max=10.0, + size_update_period=4, + clipping_update_period=100, ): + defaults = dict( lr=lr, clipping_scale=clipping_scale, @@ -185,6 +183,7 @@ class ScaledAdam(BatchedOptimizer): def __setstate__(self, state): super(ScaledAdam, self).__setstate__(state) + @torch.no_grad() def step(self, closure=None): """Performs a single optimization step. @@ -207,9 +206,7 @@ class ScaledAdam(BatchedOptimizer): # a regular parameter, and will have a .grad, but the 1st dim corresponds to # a stacking dim, it is not a real dim. - if ( - len(batches[0][1]) == 0 - ): # if len(first state) == 0: not yet initialized + if len(batches[0][1]) == 0: # if len(first state) == 0: not yet initialized clipping_scale = 1 else: clipping_scale = self._get_clipping_scale(group, batches) @@ -228,9 +225,13 @@ class ScaledAdam(BatchedOptimizer): self._step_one_batch(group, p, state, clipping_scale) + return loss - def _init_state(self, group: dict, p: Tensor, state: dict): + def _init_state(self, + group: dict, + p: Tensor, + state: dict): """ Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p is actually the batch dimension, corresponding to batched-together @@ -246,7 +247,7 @@ class ScaledAdam(BatchedOptimizer): state["step"] = 0 - kwargs = {"device": p.device, "dtype": p.dtype} + kwargs = {'device':p.device, 'dtype':p.dtype} # 'delta' implements conventional momentum. There are # several different kinds of update going on, so rather than @@ -254,30 +255,36 @@ class ScaledAdam(BatchedOptimizer): # parameter-change "delta", which combines all forms of # update. this is equivalent to how it's done in Adam, # except for the first few steps. - state["delta"] = torch.zeros_like(p, memory_format=torch.preserve_format) + state["delta"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) batch_size = p.shape[0] numel = p.numel() // batch_size numel = p.numel() + if numel > 1: # "param_rms" just periodically records the scalar root-mean-square value of # the parameter tensor. # it has a shape like (batch_size, 1, 1, 1, 1) - param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() + param_rms = (p**2).mean(dim=list(range(1, p.ndim)), + keepdim=True).sqrt() state["param_rms"] = param_rms state["scale_exp_avg_sq"] = torch.zeros_like(param_rms) - state["scale_grads"] = torch.zeros( - size_update_period, *param_rms.shape, **kwargs - ) + state["scale_grads"] = torch.zeros(size_update_period, *param_rms.shape, + **kwargs) + # exp_avg_sq is the weighted sum of scaled gradients. as in Adam. - state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format) + state["exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) - def _get_clipping_scale( - self, group: dict, pairs: List[Tuple[Tensor, dict]] - ) -> float: + def _get_clipping_scale(self, + group: dict, + pairs: List[Tuple[Tensor, dict]]) -> float: """ Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients by this amount before applying the rest of the update. @@ -307,67 +314,57 @@ class ScaledAdam(BatchedOptimizer): if p.numel() == p.shape[0]: # a batch of scalars tot_sumsq += (grad**2).sum() # sum() to change shape [1] to [] else: - tot_sumsq += ((grad * state["param_rms"]) ** 2).sum() + tot_sumsq += ((grad * state["param_rms"])**2).sum() tot_norm = tot_sumsq.sqrt() - if "model_norms" not in first_state: - first_state["model_norms"] = torch.zeros( - clipping_update_period, device=p.device - ) + if not "model_norms" in first_state: + first_state["model_norms"] = torch.zeros(clipping_update_period, + device=p.device) first_state["model_norms"][step % clipping_update_period] = tot_norm if step % clipping_update_period == 0: # Print some stats. # We don't reach here if step == 0 because we would have returned # above. - sorted_norms = first_state["model_norms"].sort()[0].to("cpu") + sorted_norms = first_state["model_norms"].sort()[0].to('cpu') quartiles = [] for n in range(0, 5): - index = min( - clipping_update_period - 1, - (clipping_update_period // 4) * n, - ) + index = min(clipping_update_period - 1, + (clipping_update_period // 4) * n) quartiles.append(sorted_norms[index].item()) median = quartiles[2] threshold = clipping_scale * median first_state["model_norm_threshold"] = threshold - percent_clipped = ( - first_state["num_clipped"] * 100.0 / clipping_update_period - if "num_clipped" in first_state - else 0.0 - ) + percent_clipped = (first_state["num_clipped"] * 100.0 / clipping_update_period + if "num_clipped" in first_state else 0.0) first_state["num_clipped"] = 0 - quartiles = " ".join(["%.3e" % x for x in quartiles]) - logging.info( - f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, " - f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}" - ) + quartiles = ' '.join([ '%.3e' % x for x in quartiles ]) + logging.info(f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, " + f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}") if step < clipping_update_period: return 1.0 # We have not yet estimated a norm to clip to. else: try: model_norm_threshold = first_state["model_norm_threshold"] - except KeyError: - logging.info( - "Warning: model_norm_threshold not in state: possibly " - "you changed config when restarting, adding clipping_scale option?" - ) + except: + logging.info("Warning: model_norm_threshold not in state: possibly " + "you changed config when restarting, adding clipping_scale option?") return 1.0 - ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item()) + ans = min(1.0,(model_norm_threshold / (tot_norm + 1.0e-20)).item()) if ans < 1.0: first_state["num_clipped"] += 1 if ans < 0.1: - logging.warn( - f"Scaling gradients by {ans}," - f" model_norm_threshold={model_norm_threshold}" - ) + logging.warn(f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}") return ans - def _step_one_batch( - self, group: dict, p: Tensor, state: dict, clipping_scale: float - ): + + def _step_one_batch(self, + group: dict, + p: Tensor, + state: dict, + clipping_scale: float): """ Do the step for one parameter, which is actually going to be a batch of `real` parameters, with dim 0 as the batch dim. @@ -394,18 +391,17 @@ class ScaledAdam(BatchedOptimizer): # Update the size/scale of p, and set param_rms scale_grads = state["scale_grads"] scale_grads[step % size_update_period] = (p * grad).sum( - dim=list(range(1, p.ndim)), keepdim=True - ) + dim=list(range(1, p.ndim)), keepdim=True) if step % size_update_period == size_update_period - 1: param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..) - param_rms.copy_( - (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() - ) + param_rms.copy_((p**2).mean(dim=list(range(1, p.ndim)), + keepdim=True).sqrt()) if step > 0: # self._size_update() learns the overall scale on the # parameter, by shrinking or expanding it. self._size_update(group, scale_grads, p, state) + if numel == 1: # For parameters with 1 element we just use regular Adam. # Updates delta. @@ -415,21 +411,24 @@ class ScaledAdam(BatchedOptimizer): state["step"] = step + 1 - def _size_update( - self, group: dict, scale_grads: Tensor, p: Tensor, state: dict - ) -> None: - """ - Called only where p.numel() > 1, this updates the scale of the parameter. - If we imagine: p = underlying_param * scale.exp(), and we are doing - gradient descent on underlying param and on scale, this function does the update - on `scale`. - Args: - group: dict to look up configuration values - scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing - grads w.r.t. the scales. - p: The parameter to update - state: The state-dict of p + def _size_update(self, + group: dict, + scale_grads: Tensor, + p: Tensor, + state: dict) -> None: + """ + Called only where p.numel() > 1, this updates the scale of the parameter. + If we imagine: p = underlying_param * scale.exp(), and we are doing + gradient descent on underlying param and on scale, this function does the update + on `scale`. + + Args: + group: dict to look up configuration values + scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing + grads w.r.t. the scales. + p: The parameter to update + state: The state-dict of p """ param_rms = state["param_rms"] @@ -444,28 +443,25 @@ class ScaledAdam(BatchedOptimizer): size_update_period = scale_grads.shape[0] # correct beta2 for the size update period: we will have # faster decay at this level. - beta2_corr = beta2**size_update_period + beta2_corr = beta2 ** size_update_period scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..) scale_exp_avg_sq.mul_(beta2_corr).add_( - (scale_grads**2).mean(dim=0), # mean over dim `size_update_period` - alpha=1 - beta2_corr, - ) # shape is (batch_size, 1, 1, ...) + (scale_grads ** 2).mean(dim=0), # mean over dim `size_update_period` + alpha=1-beta2_corr) # shape is (batch_size, 1, 1, ...) # The 1st time we reach here is when size_step == 1. size_step = (step + 1) // size_update_period - bias_correction2 = 1 - beta2_corr**size_step + bias_correction2 = 1 - beta2_corr ** size_step # we don't bother with bias_correction1; this will help prevent divergence # at the start of training. denom = scale_exp_avg_sq.sqrt() + eps - scale_step = ( - -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom - ) + scale_step = -size_lr * (bias_correction2 ** 0.5) * scale_grads.sum(dim=0) / denom - is_too_small = param_rms < param_min_rms - is_too_large = param_rms > param_max_rms + is_too_small = (param_rms < param_min_rms) + is_too_large = (param_rms > param_max_rms) # when the param gets too small, just don't shrink it any further. scale_step.masked_fill_(is_too_small, 0.0) @@ -473,9 +469,13 @@ class ScaledAdam(BatchedOptimizer): scale_step.masked_fill_(is_too_large, -size_lr * size_update_period) delta = state["delta"] # the factor of (1-beta1) relates to momentum. - delta.add_(p * scale_step, alpha=(1 - beta1)) + delta.add_(p * scale_step, alpha=(1-beta1)) - def _step(self, group: dict, p: Tensor, state: dict): + + def _step(self, + group: dict, + p: Tensor, + state: dict): """ This function does the core update of self.step(), in the case where the members of the batch have more than 1 element. @@ -496,7 +496,8 @@ class ScaledAdam(BatchedOptimizer): step = state["step"] exp_avg_sq = state["exp_avg_sq"] - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2)) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, + value=(1-beta2)) this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0) bias_correction2 = 1 - beta2 ** (this_step + 1) @@ -508,13 +509,17 @@ class ScaledAdam(BatchedOptimizer): denom += eps grad = grad / denom - alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms) + alpha = -lr * (1-beta1) * state["param_rms"].clamp(min=param_min_rms) delta = state["delta"] delta.add_(grad * alpha) p.add_(delta) - def _step_scalar(self, group: dict, p: Tensor, state: dict): + + def _step_scalar(self, + group: dict, + p: Tensor, + state: dict): """ A simplified form of the core update for scalar tensors, where we cannot get a good estimate of the parameter rms. @@ -526,7 +531,8 @@ class ScaledAdam(BatchedOptimizer): grad = p.grad exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, + value=1-beta2) # bias_correction2 is like in Adam. Don't bother with bias_correction1; # slower update at the start will help stability anyway. @@ -534,11 +540,12 @@ class ScaledAdam(BatchedOptimizer): denom = (exp_avg_sq / bias_correction2).sqrt() + eps delta = state["delta"] - delta.add_(grad / denom, alpha=-lr * (1 - beta1)) + delta.add_(grad / denom, alpha=-lr*(1-beta1)) p.clamp_(min=-scalar_max, max=scalar_max) p.add_(delta) + class LRScheduler(object): """ Base-class for learning rate schedulers where the learning-rate depends on both the @@ -548,14 +555,18 @@ class LRScheduler(object): def __init__(self, optimizer: Optimizer, verbose: bool = False): # Attach optimizer if not isinstance(optimizer, Optimizer): - raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) + raise TypeError( + "{} is not an Optimizer".format(type(optimizer).__name__) + ) self.optimizer = optimizer self.verbose = verbose for group in optimizer.param_groups: group.setdefault("base_lr", group["lr"]) - self.base_lrs = [group["base_lr"] for group in optimizer.param_groups] + self.base_lrs = [ + group["base_lr"] for group in optimizer.param_groups + ] self.epoch = 0 self.batch = 0 @@ -669,15 +680,13 @@ class Eden(LRScheduler): def get_lr(self): factor = ( - (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 + (self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2 ) ** -0.25 * ( - ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25 - ) - warmup_factor = ( - 1.0 - if self.batch >= self.warmup_batches - else 0.5 + 0.5 * (self.batch / self.warmup_batches) + ((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2) + ** -0.25 ) + warmup_factor = (1.0 if self.batch >= self.warmup_batches + else 0.5 + 0.5 * (self.batch / self.warmup_batches)) return [x * factor * warmup_factor for x in self.base_lrs] @@ -736,14 +745,13 @@ class Eve(Optimizer): parameters, if they fall below this we will stop applying weight decay. - .. _Adam: A Method for Stochastic Optimization: + .. _Adam\: A Method for Stochastic Optimization: https://arxiv.org/abs/1412.6980 .. _Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101 .. _On the Convergence of Adam and Beyond: https://openreview.net/forum?id=ryQu7f-RZ """ - def __init__( self, params, @@ -758,11 +766,17 @@ class Eve(Optimizer): if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if not 0.0 <= betas[0] < 1.0: - raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + raise ValueError( + "Invalid beta parameter at index 0: {}".format(betas[0]) + ) if not 0.0 <= betas[1] < 1.0: - raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + raise ValueError( + "Invalid beta parameter at index 1: {}".format(betas[1]) + ) if not 0 <= weight_decay <= 0.1: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + raise ValueError( + "Invalid weight_decay value: {}".format(weight_decay) + ) if not 0 < target_rms <= 10.0: raise ValueError("Invalid target_rms value: {}".format(target_rms)) defaults = dict( @@ -798,7 +812,9 @@ class Eve(Optimizer): # Perform optimization step grad = p.grad if grad.is_sparse: - raise RuntimeError("AdamW does not support sparse gradients") + raise RuntimeError( + "AdamW does not support sparse gradients" + ) state = self.state[p] @@ -825,7 +841,7 @@ class Eve(Optimizer): # Decay the first and second moment running average coefficient exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_( + denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_( group["eps"] ) @@ -836,31 +852,30 @@ class Eve(Optimizer): if p.numel() > 1: # avoid applying this weight-decay on "scaling factors" # (which are scalar). - is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5)) + is_above_target_rms = p.norm() > ( + target_rms * (p.numel() ** 0.5) + ) p.mul_(1 - (weight_decay * is_above_target_rms)) p.addcdiv_(exp_avg, denom, value=-step_size) if random.random() < 0.0005: - step = (exp_avg / denom) * step_size - logging.info( - f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}" - ) + step = (exp_avg/denom) * step_size + logging.info(f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}") + return loss def _test_scaled_adam(hidden_dim: int): import timeit - from scaling import ScaledLinear - E = 100 B = 4 T = 2 logging.info("in test_eve_cain") - # device = torch.device('cuda') - device = torch.device("cpu") + #device = torch.device('cuda') + device = torch.device('cpu') dtype = torch.float32 fix_random_seed(42) @@ -874,93 +889,79 @@ def _test_scaled_adam(hidden_dim: int): fix_random_seed(42) Linear = torch.nn.Linear if iter == 0 else ScaledLinear - m = torch.nn.Sequential( - Linear(E, hidden_dim), - torch.nn.PReLU(), - Linear(hidden_dim, hidden_dim), - torch.nn.PReLU(), - Linear(hidden_dim, E), - ).to(device) + m = torch.nn.Sequential(Linear(E, hidden_dim), + torch.nn.PReLU(), + Linear(hidden_dim, hidden_dim), + torch.nn.PReLU(), + Linear(hidden_dim, E), + ).to(device) - train_pairs = [ - ( - 100.0 - * torch.randn(B, T, E, device=device, dtype=dtype) - * input_magnitudes, - torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes, - ) - for _ in range(20) - ] + train_pairs = [ (100.0 * torch.randn(B, T, E, device=device, dtype=dtype) * input_magnitudes, + torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes) for _ in range(20) ] - if iter == 0: - optim = Eve(m.parameters(), lr=0.003) - elif iter == 1: - optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0) + if iter == 0: optim = Eve(m.parameters(), lr=0.003) + elif iter == 1: optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0) scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False) + start = timeit.default_timer() avg_loss = 0.0 for epoch in range(180): scheduler.step_epoch() - # if epoch == 100 and iter in [2,3]: + #if epoch == 100 and iter in [2,3]: # optim.reset_speedup() # check it doesn't crash. - # if epoch == 130: + #if epoch == 130: # opts = diagnostics.TensorDiagnosticOptions( # 2 ** 22 # ) # allow 4 megabytes per sub-module # diagnostic = diagnostics.attach_diagnostics(m, opts) - for n, (x, y) in enumerate(train_pairs): + + for n, (x,y) in enumerate(train_pairs): y_out = m(x) - loss = ((y_out - y) ** 2).mean() * 100.0 + loss = ((y_out - y)**2).mean() * 100.0 if epoch == 0 and n == 0: avg_loss = loss.item() else: avg_loss = 0.98 * avg_loss + 0.02 * loss.item() if n == 0 and epoch % 5 == 0: - # norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item() - # norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item() - # norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item() - # norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item() - # scale1 = '%.2e' % (m[0].weight_scale.exp().item()) - # scale1b = '%.2e' % (m[0].bias_scale.exp().item()) - # scale2 = '%.2e' % (m[2].weight_scale.exp().item()) - # scale2b = '%.2e' % (m[2].bias_scale.exp().item()) + #norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item() + #norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item() + #norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item() + #norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item() + #scale1 = '%.2e' % (m[0].weight_scale.exp().item()) + #scale1b = '%.2e' % (m[0].bias_scale.exp().item()) + #scale2 = '%.2e' % (m[2].weight_scale.exp().item()) + #scale2b = '%.2e' % (m[2].bias_scale.exp().item()) lr = scheduler.get_last_lr()[0] - logging.info( - f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss" - f" {avg_loss:.4g}, lr={lr:.4e}" - ) # , norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b} + logging.info(f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}") #, norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b} loss.log().backward() optim.step() optim.zero_grad() scheduler.step_batch() - # diagnostic.print_diagnostics() + #diagnostic.print_diagnostics() stop = timeit.default_timer() logging.info(f"Iter={iter}, Time taken: {stop - start}") logging.info(f"last lr = {scheduler.get_last_lr()}") - # logging.info("state dict = ", scheduler.state_dict()) - # logging.info("optim state_dict = ", optim.state_dict()) + #logging.info("state dict = ", scheduler.state_dict()) + #logging.info("optim state_dict = ", optim.state_dict()) logging.info(f"input_magnitudes = {input_magnitudes}") logging.info(f"output_magnitudes = {output_magnitudes}") + if __name__ == "__main__": torch.set_num_threads(1) torch.set_num_interop_threads(1) logging.getLogger().setLevel(logging.INFO) import subprocess - - s = subprocess.check_output( - "git status -uno .; git log -1; git diff HEAD .", shell=True - ) + s = subprocess.check_output("git status -uno .; git log -1; git diff HEAD .", shell=True) logging.info(s) import sys - if len(sys.argv) > 1: hidden_dim = int(sys.argv[1]) else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py index 8b4d88871..7fe1e681a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py @@ -100,11 +100,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -129,12 +127,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -181,7 +177,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -212,9 +209,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -277,11 +275,15 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) + encoder_out, encoder_out_lens = model.encoder( + x=features, x_lens=feature_lengths + ) num_waves = encoder_out.size(0) hyps = [] @@ -353,7 +355,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index 4040065e1..50cedba56 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -16,12 +16,12 @@ import collections -import logging -import random -from functools import reduce from itertools import repeat from typing import Optional, Tuple, Union +from functools import reduce +import logging +import random import torch import torch.nn as nn import torch.nn.functional as F @@ -32,24 +32,27 @@ from torch.nn import Embedding as ScaledEmbedding class ActivationBalancerFunction(torch.autograd.Function): @staticmethod def forward( - ctx, - x: Tensor, - scale_factor: Tensor, - sign_factor: Optional[Tensor], - channel_dim: int, + ctx, + x: Tensor, + scale_factor: Tensor, + sign_factor: Optional[Tensor], + channel_dim: int, ) -> Tensor: if channel_dim < 0: channel_dim += x.ndim ctx.channel_dim = channel_dim - xgt0 = x > 0 + xgt0 = (x > 0) if sign_factor is None: ctx.save_for_backward(xgt0, scale_factor) else: ctx.save_for_backward(xgt0, scale_factor, sign_factor) return x + @staticmethod - def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]: + def backward( + ctx, x_grad: Tensor + ) -> Tuple[Tensor, None, None, None]: if len(ctx.saved_tensors) == 3: xgt0, scale_factor, sign_factor = ctx.saved_tensors for _ in range(ctx.channel_dim, x_grad.ndim - 1): @@ -62,22 +65,14 @@ class ActivationBalancerFunction(torch.autograd.Function): scale_factor = scale_factor.unsqueeze(-1) factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5) neg_delta_grad = x_grad.abs() * factor - return ( - x_grad - neg_delta_grad, - None, - None, - None, - ) + return x_grad - neg_delta_grad, None, None, None, - -def _compute_scale_factor( - x: Tensor, - channel_dim: int, - min_abs: float, - max_abs: float, - gain_factor: float, - max_factor: float, -) -> Tensor: +def _compute_scale_factor(x: Tensor, + channel_dim: int, + min_abs: float, + max_abs: float, + gain_factor: float, + max_factor: float) -> Tensor: if channel_dim < 0: channel_dim += x.ndim sum_dims = [d for d in range(x.ndim) if d != channel_dim] @@ -88,76 +83,71 @@ def _compute_scale_factor( else: # below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if # x_abs)_mean , min_abs. - below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp( - min=0, max=max_factor - ) + below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(min=0, max=max_factor) - above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp( - min=0, max=max_factor - ) + above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(min=0, max=max_factor) return below_threshold - above_threshold - -def _compute_sign_factor( - x: Tensor, - channel_dim: int, - min_positive: float, - max_positive: float, - gain_factor: float, - max_factor: float, -) -> Tensor: +def _compute_sign_factor(x: Tensor, + channel_dim: int, + min_positive: float, + max_positive: float, + gain_factor: float, + max_factor: float) -> Tensor: if channel_dim < 0: channel_dim += x.ndim sum_dims = [d for d in range(x.ndim) if d != channel_dim] - proportion_positive = torch.mean((x > 0).to(torch.float32), dim=sum_dims) + proportion_positive = torch.mean((x > 0).to(torch.float32), + dim=sum_dims) if min_positive == 0.0: factor1 = 0.0 else: # 0 if proportion_positive >= min_positive, else can be # as large as max_factor. - factor1 = ( - (min_positive - proportion_positive) * (gain_factor / min_positive) - ).clamp_(min=0, max=max_factor) + factor1 = ((min_positive - proportion_positive) * + (gain_factor / min_positive)).clamp_(min=0, max=max_factor) if max_positive == 1.0: factor2 = 0.0 else: # 0 if self.proportion_positive <= max_positive, else can be # as large as -max_factor. - factor2 = ( - (proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive)) - ).clamp_(min=0, max=max_factor) + factor2 = ((proportion_positive - max_positive) * + (gain_factor / (1.0 - max_positive))).clamp_(min=0, max=max_factor) sign_factor = factor1 - factor2 # require min_positive != 0 or max_positive != 1: assert not isinstance(sign_factor, float) return sign_factor + class ActivationScaleBalancerFunction(torch.autograd.Function): """ This object is used in class ActivationBalancer when the user specified min_positive=0, max_positive=1, so there are no constraints on the signs of the activations and only the absolute value has a constraint. """ - @staticmethod def forward( - ctx, - x: Tensor, - sign_factor: Tensor, - scale_factor: Tensor, - channel_dim: int, + ctx, + x: Tensor, + sign_factor: Tensor, + scale_factor: Tensor, + channel_dim: int, ) -> Tensor: if channel_dim < 0: channel_dim += x.ndim ctx.channel_dim = channel_dim - xgt0 = x > 0 + xgt0 = (x > 0) ctx.save_for_backward(xgt0, sign_factor, scale_factor) return x + @staticmethod - def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]: + def backward( + ctx, x_grad: Tensor + ) -> Tuple[Tensor, None, None, None]: xgt0, sign_factor, scale_factor = ctx.saved_tensors for _ in range(ctx.channel_dim, x_grad.ndim - 1): sign_factor = sign_factor.unsqueeze(-1) @@ -165,24 +155,18 @@ class ActivationScaleBalancerFunction(torch.autograd.Function): factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5) neg_delta_grad = x_grad.abs() * factor - return ( - x_grad - neg_delta_grad, - None, - None, - None, - ) + return x_grad - neg_delta_grad, None, None, None, class RandomClampFunction(torch.autograd.Function): @staticmethod def forward( - ctx, - x: Tensor, - min: Optional[float], - max: Optional[float], - prob: float, - reflect: float, - ) -> Tensor: + ctx, + x: Tensor, + min: Optional[float], + max: Optional[float], + prob: float, + reflect: float) -> Tensor: x_clamped = torch.clamp(x, min=min, max=max) mask = torch.rand_like(x) < prob ans = torch.where(mask, x_clamped, x) @@ -195,32 +179,30 @@ class RandomClampFunction(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None, None, None]: - (is_same,) = ctx.saved_tensors + is_same, = ctx.saved_tensors x_grad = ans_grad * is_same.to(ans_grad.dtype) reflect = ctx.reflect - if reflect != 0.0: + if reflect != 0.0: x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect) return x_grad, None, None, None, None - -def random_clamp( - x: Tensor, - min: Optional[float] = None, - max: Optional[float] = None, - prob: float = 0.5, - reflect: float = 0.0, -): +def random_clamp(x: Tensor, + min: Optional[float] = None, + max: Optional[float] = None, + prob: float = 0.5, + reflect: float = 0.0): return RandomClampFunction.apply(x, min, max, prob, reflect) -def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor: +def random_cast_to_half(x: Tensor, + min_abs: float = 5.0e-06) -> Tensor: """ A randomized way of casting a floating point value to half precision. """ if x.dtype == torch.float16: return x x_abs = x.abs() - is_too_small = x_abs < min_abs + is_too_small = (x_abs < min_abs) # for elements where is_too_small is true, random_val will contain +-min_abs with # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations, # for those elements]. @@ -233,7 +215,6 @@ class RandomGradFunction(torch.autograd.Function): Does nothing in forward pass; in backward pass, gets rid of very small grads using randomized approach that preserves expectations (intended to reduce roundoff). """ - @staticmethod def forward(ctx, x: Tensor, min_abs: float) -> Tensor: ctx.min_abs = min_abs @@ -242,37 +223,35 @@ class RandomGradFunction(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]: if ans_grad.dtype == torch.float16: - return ( - random_cast_to_half(ans_grad.to(torch.float32), min_abs=ctx.min_abs), - None, - ) + return random_cast_to_half(ans_grad.to(torch.float32), + min_abs=ctx.min_abs), None else: return ans_grad, None - class RandomGrad(torch.nn.Module): """ Gets rid of very small gradients using an expectation-preserving method, intended to increase accuracy of training when using amp (automatic mixed precision) """ - - def __init__(self, min_abs: float = 5.0e-06): + def __init__(self, + min_abs: float = 5.0e-06): super(RandomGrad, self).__init__() self.min_abs = min_abs - def forward(self, x: Tensor): + def forward(self, + x: Tensor): if torch.jit.is_scripting() or not self.training: return x else: return RandomGradFunction.apply(x, self.min_abs) + class SoftmaxFunction(torch.autograd.Function): """ Tries to handle half-precision derivatives in a randomized way that should be more accurate for training than the default behavior. """ - @staticmethod def forward(ctx, x: Tensor, dim: int): ans = x.softmax(dim=dim) @@ -288,7 +267,7 @@ class SoftmaxFunction(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor): - (ans,) = ctx.saved_tensors + ans, = ctx.saved_tensors with torch.cuda.amp.autocast(enabled=False): ans_grad = ans_grad.to(torch.float32) ans = ans.to(torch.float32) @@ -297,7 +276,9 @@ class SoftmaxFunction(torch.autograd.Function): return x_grad, None -def softmax(x: Tensor, dim: int): + +def softmax(x: Tensor, + dim: int): if torch.jit.is_scripting(): return x.softmax(dim) @@ -307,18 +288,20 @@ def softmax(x: Tensor, dim: int): class MaxEigLimiterFunction(torch.autograd.Function): @staticmethod def forward( - ctx, - x: Tensor, - coeffs: Tensor, - direction: Tensor, - channel_dim: int, - grad_scale: float, - ) -> Tensor: + ctx, + x: Tensor, + coeffs: Tensor, + direction: Tensor, + channel_dim: int, + grad_scale: float) -> Tensor: ctx.channel_dim = channel_dim ctx.grad_scale = grad_scale - ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach()) + ctx.save_for_backward(x.detach(), + coeffs.detach(), + direction.detach()) return x + @staticmethod def backward(ctx, x_grad, *args): with torch.enable_grad(): @@ -328,20 +311,15 @@ class MaxEigLimiterFunction(torch.autograd.Function): x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels) new_direction.requires_grad = False x = x - x.mean(dim=0) - x_var = (x**2).mean() + x_var = (x ** 2).mean() x_residual = x - coeffs * new_direction - x_residual_var = (x_residual**2).mean() + x_residual_var = (x_residual ** 2).mean() # `variance_proportion` is the proportion of the variance accounted for # by the top eigen-direction. This is to be minimized. variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20) variance_proportion.backward() x_orig_grad = x_orig.grad - x_extra_grad = ( - x_orig.grad - * ctx.grad_scale - * x_grad.norm() - / (x_orig_grad.norm() + 1.0e-20) - ) + x_extra_grad = x_orig.grad * ctx.grad_scale * x_grad.norm() / (x_orig_grad.norm() + 1.0e-20) return x_grad + x_extra_grad.detach(), None, None, None, None @@ -407,12 +385,15 @@ class BasicNorm(torch.nn.Module): # region if it happens to exit it. eps = eps.clamp(min=self.eps_min, max=self.eps_max) scales = ( - torch.mean(x**2, dim=self.channel_dim, keepdim=True) + eps.exp() + torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps.exp() ) ** -0.5 return x * scales -def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear: + +def ScaledLinear(*args, + initial_scale: float = 1.0, + **kwargs ) -> nn.Linear: """ Behaves like a constructor of a modified version of nn.Linear that gives an easy way to set the default initial parameter scale. @@ -431,11 +412,16 @@ def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear: with torch.no_grad(): ans.weight[:] *= initial_scale if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) + torch.nn.init.uniform_(ans.bias, + -0.1 * initial_scale, + 0.1 * initial_scale) return ans -def ScaledConv1d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv1d: + +def ScaledConv1d(*args, + initial_scale: float = 1.0, + **kwargs ) -> nn.Conv1d: """ Behaves like a constructor of a modified version of nn.Conv1d that gives an easy way to set the default initial parameter scale. @@ -454,10 +440,13 @@ def ScaledConv1d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv1d: with torch.no_grad(): ans.weight[:] *= initial_scale if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) + torch.nn.init.uniform_(ans.bias, + -0.1 * initial_scale, + 0.1 * initial_scale) return ans + class ActivationBalancer(torch.nn.Module): """ Modifies the backpropped derivatives of a function to try to encourage, for @@ -497,19 +486,18 @@ class ActivationBalancer(torch.nn.Module): from doing it at the same time. Early in training we may use higher probabilities than this; it will decay to this value. """ - def __init__( - self, - num_channels: int, - channel_dim: int, - min_positive: float = 0.05, - max_positive: float = 0.95, - max_factor: float = 0.04, - sign_gain_factor: float = 0.01, - scale_gain_factor: float = 0.02, - min_abs: float = 0.2, - max_abs: float = 100.0, - min_prob: float = 0.1, + self, + num_channels: int, + channel_dim: int, + min_positive: float = 0.05, + max_positive: float = 0.95, + max_factor: float = 0.04, + sign_gain_factor: float = 0.01, + scale_gain_factor: float = 0.02, + min_abs: float = 0.2, + max_abs: float = 100.0, + min_prob: float = 0.1, ): super(ActivationBalancer, self).__init__() self.num_channels = num_channels @@ -527,7 +515,9 @@ class ActivationBalancer(torch.nn.Module): # We occasionally sync this to a tensor called `count`, that exists to # make sure it is synced to disk when we load and save the model. self.cpu_count = 0 - self.register_buffer("count", torch.tensor(0, dtype=torch.int64)) + self.register_buffer('count', torch.tensor(0, dtype=torch.int64)) + + def forward(self, x: Tensor) -> Tensor: if torch.jit.is_scripting() or not x.requires_grad: @@ -545,35 +535,26 @@ class ActivationBalancer(torch.nn.Module): # the prob of doing some work exponentially decreases from 0.5 till it hits # a floor at min_prob (==0.1, by default) - prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0))) + prob = max(self.min_prob, 0.5 ** (1 + (count/4000.0))) if random.random() < prob: sign_gain_factor = 0.5 if self.min_positive != 0.0 or self.max_positive != 1.0: - sign_factor = _compute_sign_factor( - x, - self.channel_dim, - self.min_positive, - self.max_positive, - gain_factor=self.sign_gain_factor / prob, - max_factor=self.max_factor, - ) + sign_factor = _compute_sign_factor(x, self.channel_dim, + self.min_positive, self.max_positive, + gain_factor=self.sign_gain_factor / prob, + max_factor=self.max_factor) else: sign_factor = None - scale_factor = _compute_scale_factor( - x, - self.channel_dim, - min_abs=self.min_abs, - max_abs=self.max_abs, - gain_factor=self.scale_gain_factor / prob, - max_factor=self.max_factor, - ) + + scale_factor = _compute_scale_factor(x, self.channel_dim, + min_abs=self.min_abs, + max_abs=self.max_abs, + gain_factor=self.scale_gain_factor / prob, + max_factor=self.max_factor) return ActivationBalancerFunction.apply( - x, - scale_factor, - sign_factor, - self.channel_dim, + x, scale_factor, sign_factor, self.channel_dim, ) else: return _no_op(x) @@ -613,12 +594,13 @@ def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims. else: (batch, dim, dim) = x.shape x = x.reshape(batch, dim * dim) - x = x[:, :: dim + 1] + x = x[:, ::dim+1] assert x.shape == (batch, dim) return x -def _whitening_metric(x: Tensor, num_groups: int): +def _whitening_metric(x: Tensor, + num_groups: int): """ Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of of the centered feature covariance are the same within each group's covariance matrix @@ -648,21 +630,19 @@ def _whitening_metric(x: Tensor, num_groups: int): # the following expression is what we'd get if we took the matrix product # of each covariance and measured the mean of its trace, i.e. # the same as _diag(torch.matmul(x_covar, x_covar)).mean(). - x_covarsq_mean_diag = (x_covar**2).sum() / (num_groups * channels_per_group) + x_covarsq_mean_diag = (x_covar ** 2).sum() / (num_groups * channels_per_group) # this metric will be >= 1.0; the larger it is, the less 'white' the data was. - metric = x_covarsq_mean_diag / (x_covar_mean_diag**2 + 1.0e-20) + metric = x_covarsq_mean_diag / (x_covar_mean_diag ** 2 + 1.0e-20) return metric class WhiteningPenaltyFunction(torch.autograd.Function): @staticmethod - def forward( - ctx, - x: Tensor, - num_groups: int, - whitening_limit: float, - grad_scale: float, - ) -> Tensor: + def forward(ctx, + x: Tensor, + num_groups: int, + whitening_limit: float, + grad_scale: float) -> Tensor: ctx.save_for_backward(x) ctx.num_groups = num_groups ctx.whitening_limit = whitening_limit @@ -670,8 +650,9 @@ class WhiteningPenaltyFunction(torch.autograd.Function): return x @staticmethod - def backward(ctx, x_grad: Tensor): - (x_orig,) = ctx.saved_tensors + def backward(ctx, + x_grad: Tensor): + x_orig, = ctx.saved_tensors with torch.enable_grad(): with torch.cuda.amp.autocast(enabled=False): x_detached = x_orig.to(torch.float32).detach() @@ -680,29 +661,25 @@ class WhiteningPenaltyFunction(torch.autograd.Function): metric = _whitening_metric(x_detached, ctx.num_groups) if random.random() < 0.005 or __name__ == "__main__": - logging.info( - f"Whitening: num_groups={ctx.num_groups}," - f" num_channels={x_orig.shape[-1]}," - f" metric={metric.item():.2f} vs. limit={ctx.whitening_limit}" - ) + logging.info(f"Whitening: num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, " + f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}") (metric - ctx.whitening_limit).relu().backward() penalty_grad = x_detached.grad - scale = ctx.grad_scale * ( - x_grad.to(torch.float32).norm() / (penalty_grad.norm() + 1.0e-20) - ) + scale = ctx.grad_scale * (x_grad.to(torch.float32).norm() / + (penalty_grad.norm() + 1.0e-20)) penalty_grad = penalty_grad * scale return x_grad + penalty_grad.to(x_grad.dtype), None, None, None + class Whiten(nn.Module): def __init__( - self, - num_groups: int, - whitening_limit: float, - prob: Union[float, Tuple[float, float]], - grad_scale: float, - ): + self, + num_groups: int, + whitening_limit: float, + prob: Union[float, Tuple[float,float]], + grad_scale: float): """ Args: num_groups: the number of groups to divide the channel dim into before @@ -737,7 +714,8 @@ class Whiten(nn.Module): self.grad_scale = grad_scale - def forward(self, x: Tensor) -> Tensor: + def forward(self, + x: Tensor) -> Tensor: """ In the forward pass, this function just returns the input unmodified. In the backward pass, it will modify the gradients to ensure that the @@ -757,21 +735,19 @@ class Whiten(nn.Module): if not x.requires_grad or random.random() > self.prob or self.grad_scale == 0: return _no_op(x) else: - if hasattr(self, "min_prob") and random.random() < 0.25: + if hasattr(self, 'min_prob') and random.random() < 0.25: # occasionally switch between min_prob and max_prob, based on whether # we are above or below the threshold. - if ( - _whitening_metric(x.to(torch.float32), self.num_groups) - > self.whitening_limit - ): + if _whitening_metric(x.to(torch.float32), self.num_groups) > self.whitening_limit: # there would be a change to the grad. self.prob = self.max_prob else: self.prob = self.min_prob - return WhiteningPenaltyFunction.apply( - x, self.num_groups, self.whitening_limit, self.grad_scale - ) + return WhiteningPenaltyFunction.apply(x, + self.num_groups, + self.whitening_limit, + self.grad_scale) class WithLoss(torch.autograd.Function): @@ -779,14 +755,11 @@ class WithLoss(torch.autograd.Function): def forward(ctx, x: Tensor, y: Tensor): ctx.y_shape = y.shape return x - @staticmethod def backward(ctx, ans_grad: Tensor): - return ans_grad, torch.ones( - ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device - ) - - + return ans_grad, torch.ones(ctx.y_shape, + dtype=ans_grad.dtype, + device=ans_grad.device) def with_loss(x, y): if torch.jit.is_scripting(): return x @@ -795,7 +768,7 @@ def with_loss(x, y): def _no_op(x: Tensor) -> Tensor: - if torch.jit.is_scripting(): + if (torch.jit.is_scripting()): return x else: # a no-op function that will have a node in the autograd graph, @@ -810,7 +783,6 @@ class Identity(torch.nn.Module): def forward(self, x): return _no_op(x) - class MaxEig(torch.nn.Module): """ Modifies the backpropped derivatives of a function to try to discourage @@ -831,14 +803,13 @@ class MaxEig(torch.nn.Module): scale: determines the scale with which we modify the gradients, relative to the existing / unmodified gradients """ - def __init__( - self, - num_channels: int, - channel_dim: int, - max_var_per_eig: float = 0.2, - min_prob: float = 0.01, - scale: float = 0.01, + self, + num_channels: int, + channel_dim: int, + max_var_per_eig: float = 0.2, + min_prob: float = 0.01, + scale: float = 0.01, ): super(MaxEig, self).__init__() self.num_channels = num_channels @@ -854,7 +825,7 @@ class MaxEig(torch.nn.Module): # random parameters unchanged for comparison direction = torch.arange(num_channels).to(torch.float) direction = direction / direction.norm() - self.register_buffer("max_eig_direction", direction) + self.register_buffer('max_eig_direction', direction) self.min_prob = min_prob # cur_prob is the current probability we'll use to apply the ActivationBalancer. @@ -862,12 +833,12 @@ class MaxEig(torch.nn.Module): # active. self.cur_prob = 1.0 + + def forward(self, x: Tensor) -> Tensor: - if ( - torch.jit.is_scripting() - or self.max_var_per_eig <= 0 - or random.random() > self.cur_prob - ): + if (torch.jit.is_scripting() or + self.max_var_per_eig <= 0 or + random.random() > self.cur_prob): return _no_op(x) with torch.cuda.amp.autocast(enabled=False): @@ -877,9 +848,7 @@ class MaxEig(torch.nn.Module): with torch.no_grad(): x = x.transpose(self.channel_dim, -1).reshape(-1, self.num_channels) x = x - x.mean(dim=0) - new_direction, coeffs = self._find_direction_coeffs( - x, self.max_eig_direction - ) + new_direction, coeffs = self._find_direction_coeffs(x, self.max_eig_direction) x_var = (x**2).mean() x_residual = x - coeffs * new_direction x_residual_var = (x_residual**2).mean() @@ -892,10 +861,7 @@ class MaxEig(torch.nn.Module): self._set_direction(0.1 * self.max_eig_direction + new_direction) if random.random() < 0.01 or __name__ == "__main__": - logging.info( - f"variance_proportion = {variance_proportion.item()}," - f" shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}" - ) + logging.info(f"variance_proportion = {variance_proportion.item()}, shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}") if variance_proportion >= self.max_var_per_eig: # The constraint is active. Note, we should quite rarely @@ -903,16 +869,17 @@ class MaxEig(torch.nn.Module): # starting to diverge, should this constraint be active. cur_prob = self.cur_prob self.cur_prob = 1.0 # next time, do the update with probability 1.0. - return MaxEigLimiterFunction.apply( - orig_x, coeffs, new_direction, self.channel_dim, self.scale - ) + return MaxEigLimiterFunction.apply(orig_x, coeffs, new_direction, + self.channel_dim, self.scale) else: # let self.cur_prob exponentially approach self.min_prob, as # long as the constraint is inactive. self.cur_prob = 0.75 * self.cur_prob + 0.25 * self.min_prob return orig_x - def _set_direction(self, direction: Tensor): + + def _set_direction(self, + direction: Tensor): """ Sets self.max_eig_direction to a normalized version of `direction` """ @@ -922,39 +889,40 @@ class MaxEig(torch.nn.Module): if direction_sum - direction_sum == 0: # no inf/nan self.max_eig_direction[:] = direction else: - logging.info( - f"Warning: sum of direction in MaxEig is {direction_sum}, " - "num_channels={self.num_channels}, channel_dim={self.channel_dim}" - ) + logging.info(f"Warning: sum of direction in MaxEig is {direction_sum}, " + "num_channels={self.num_channels}, channel_dim={self.channel_dim}") - def _find_direction_coeffs( - self, x: Tensor, prev_direction: Tensor - ) -> Tuple[Tensor, Tensor, Tensor]: - """ - Figure out (an approximation to) the proportion of the variance of a set of - feature vectors that can be attributed to the top eigen-direction. - Args: - x: a Tensor of shape (num_frames, num_channels), with num_frames > 1. - prev_direction: a Tensor of shape (num_channels,), that is our previous estimate - of the top eigen-direction, or a random direction if this is the first - iteration. Does not have to be normalized, but should be nonzero. - Returns: (cur_direction, coeffs), where: - cur_direction: a Tensor of shape (num_channels,) that is the current - estimate of the top eigen-direction. - coeffs: a Tensor of shape (num_frames, 1) that minimizes, or - approximately minimizes, (x - coeffs * cur_direction).norm() + def _find_direction_coeffs(self, + x: Tensor, + prev_direction: Tensor) -> Tuple[Tensor, Tensor, Tensor]: """ + Figure out (an approximation to) the proportion of the variance of a set of + feature vectors that can be attributed to the top eigen-direction. + Args: + x: a Tensor of shape (num_frames, num_channels), with num_frames > 1. + prev_direction: a Tensor of shape (num_channels,), that is our previous estimate + of the top eigen-direction, or a random direction if this is the first + iteration. Does not have to be normalized, but should be nonzero. + + Returns: (cur_direction, coeffs), where: + cur_direction: a Tensor of shape (num_channels,) that is the current + estimate of the top eigen-direction. + coeffs: a Tensor of shape (num_frames, 1) that minimizes, or + approximately minimizes, (x - coeffs * cur_direction).norm() + """ (num_frames, num_channels) = x.shape assert num_channels > 1 and num_frames > 1 assert prev_direction.shape == (num_channels,) # `coeffs` are the coefficients of `prev_direction` in x. # actually represent the coeffs up to a constant positive factor. coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10 - cur_direction = (x * coeffs).sum(dim=0) / ((coeffs**2).sum() + 1.0e-20) + cur_direction = (x * coeffs).sum(dim=0) / ((coeffs ** 2).sum() + 1.0e-20) return cur_direction, coeffs + + class DoubleSwishFunction(torch.autograd.Function): """ double_swish(x) = x * torch.sigmoid(x-1) @@ -982,7 +950,7 @@ class DoubleSwishFunction(torch.autograd.Function): y = x * s if requires_grad: - deriv = y * (1 - s) + s + deriv = (y * (1 - s) + s) # notes on derivative of x * sigmoid(x - 1): # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29 # min \simeq -0.043638. Take floor as -0.043637 so it's a lower bund @@ -991,9 +959,7 @@ class DoubleSwishFunction(torch.autograd.Function): # floors), should be expectation-preserving. floor = -0.043637 ceil = 1.2 - d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like( - deriv - ) + d_scaled = ((deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(deriv)) if __name__ == "__main__": # for self-testing only. assert d_scaled.min() >= 0.0 @@ -1006,12 +972,12 @@ class DoubleSwishFunction(torch.autograd.Function): @staticmethod def backward(ctx, y_grad: Tensor) -> Tensor: - (d,) = ctx.saved_tensors + d, = ctx.saved_tensors # the same constants as used in forward pass. floor = -0.043637 ceil = 1.2 - d = d * ((ceil - floor) / 255.0) + floor - return y_grad * d + d = (d * ((ceil - floor) / 255.0) + floor) + return (y_grad * d) class DoubleSwish(torch.nn.Module): @@ -1024,6 +990,7 @@ class DoubleSwish(torch.nn.Module): return DoubleSwishFunction.apply(x) + def _test_max_eig(): for proportion in [0.1, 0.5, 10.0]: logging.info(f"proportion = {proportion}") @@ -1035,9 +1002,11 @@ def _test_max_eig(): x.requires_grad = True num_channels = 128 - m = MaxEig( - num_channels, 1, 0.5, scale=0.1 # channel_dim # max_var_per_eig - ) # grad_scale + m = MaxEig(num_channels, + 1, # channel_dim + 0.5, # max_var_per_eig + scale=0.1) # grad_scale + for _ in range(4): y = m(x) @@ -1062,9 +1031,11 @@ def _test_whiten(): x.requires_grad = True num_channels = 128 - m = Whiten( - 1, 5.0, prob=1.0, grad_scale=0.1 # num_groups # whitening_limit, - ) # grad_scale + m = Whiten(1, # num_groups + 5.0, # whitening_limit, + prob=1.0, + grad_scale=0.1) # grad_scale + for _ in range(4): y = m(x) @@ -1078,6 +1049,7 @@ def _test_whiten(): assert not torch.allclose(x.grad, y_grad) + def _test_activation_balancer_sign(): probs = torch.arange(0, 1, 0.01) N = 1000 @@ -1105,7 +1077,9 @@ def _test_activation_balancer_sign(): def _test_activation_balancer_magnitude(): magnitudes = torch.arange(0, 1, 0.01) N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze( + -1 + ) x = x.detach() x.requires_grad = True m = ActivationBalancer( @@ -1137,8 +1111,8 @@ def _test_basic_norm(): y = m(x) assert y.shape == x.shape - x_rms = (x**2).mean().sqrt() - y_rms = (y**2).mean().sqrt() + x_rms = (x ** 2).mean().sqrt() + y_rms = (y ** 2).mean().sqrt() print("x rms = ", x_rms) print("y rms = ", y_rms) assert y_rms < x_rms @@ -1150,27 +1124,30 @@ def _test_double_swish_deriv(): x.requires_grad = True m = DoubleSwish() - tol = (1.2 - (-0.043637)) / 255.0 + tol = ((1.2-(-0.043637))/255.0) torch.autograd.gradcheck(m, x, atol=tol) + # for self-test. x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 x.requires_grad = True y = m(x) + def _test_softmax(): a = torch.randn(2, 10, dtype=torch.float64) b = a.clone() a.requires_grad = True b.requires_grad = True - a.softmax(dim=1)[:, 0].sum().backward() + a.softmax(dim=1)[:,0].sum().backward() print("a grad = ", a.grad) - softmax(b, dim=1)[:, 0].sum().backward() + softmax(b, dim=1)[:,0].sum().backward() print("b grad = ", b.grad) assert torch.allclose(a.grad, b.grad) + if __name__ == "__main__": logging.getLogger().setLevel(logging.INFO) torch.set_num_threads(1) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py index 46e775285..8d357b15f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py @@ -26,7 +26,11 @@ from typing import List import torch import torch.nn as nn -from scaling import ActivationBalancer, BasicNorm, Whiten +from scaling import ( + ActivationBalancer, + BasicNorm, + Whiten, +) class NonScaledNorm(nn.Module): @@ -71,10 +75,12 @@ def get_submodule(model, target): mod: torch.nn.Module = model for item in atoms: if not hasattr(mod, item): - raise AttributeError(mod._get_name() + " has no attribute `" + item + "`") + raise AttributeError( + mod._get_name() + " has no " "attribute `" + item + "`" + ) mod = getattr(mod, item) if not isinstance(mod, torch.nn.Module): - raise AttributeError("`" + item + "` is not an nn.Module") + raise AttributeError("`" + item + "` is not " "an nn.Module") return mod diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index 7f9526104..3f27736b3 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -84,7 +84,9 @@ from icefall.env import get_env_info from icefall.hooks import register_inf_check_hooks from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: @@ -122,10 +124,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): "--encoder-dims", type=str, default="384,384,384,384,384", - help=( - "Embedding dimension in the 2 blocks of zipformer encoder layers, comma" - " separated" - ), + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", ) parser.add_argument( @@ -140,11 +139,9 @@ def add_model_arguments(parser: argparse.ArgumentParser): "--encoder-unmasked-dims", type=str, default="256,256,256,256,256", - help=( - "Unmasked dimensions in the encoders, relates to augmentation during" - " training. Must be <= each of encoder_dims. Empirically, less than 256" - " seems to make performance worse." - ), + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " + " worse.", ) parser.add_argument( @@ -272,45 +269,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." - ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -652,7 +646,11 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -699,7 +697,9 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -870,7 +870,9 @@ def train_one_epoch( # of the grad scaler is configurable, but we can't configure it to have different # behavior depending on the current grad scale. cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + if cur_grad_scale < 1.0 or ( + cur_grad_scale < 8.0 and batch_idx % 400 == 0 + ): scaler.update(cur_grad_scale * 2.0) if cur_grad_scale < 0.01: logging.warning(f"Grad scale is small: {cur_grad_scale}") @@ -888,7 +890,11 @@ def train_one_epoch( f"batch {batch_idx}, loss[{loss_info}], " f"tot_loss[{tot_loss}], batch size: {batch_size}, " f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + + ( + f"grad_scale: {scaler._scale.item()}" + if params.use_fp16 + else "" + ) ) if tb_writer is not None: @@ -899,7 +905,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if params.use_fp16: tb_writer.add_scalar( "train/grad_scale", @@ -907,7 +915,10 @@ def train_one_epoch( params.batch_idx_train, ) - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + if ( + batch_idx % params.valid_interval == 0 + and not params.print_diagnostics + ): logging.info("Computing validation loss") valid_info = compute_validation_loss( params=params, @@ -919,8 +930,7 @@ def train_one_epoch( model.train() logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") logging.info( - "Maximum memory allocated so far is" - f" {torch.cuda.max_memory_allocated()//1000000}MB" + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" ) if tb_writer is not None: valid_info.write_summary( @@ -999,7 +1009,9 @@ def run(rank, world_size, args): logging.info("Using DDP") model = DDP(model, device_ids=[rank], find_unused_parameters=True) - optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=2.0) + optimizer = ScaledAdam( + model.parameters(), lr=params.base_lr, clipping_scale=2.0 + ) scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) @@ -1017,7 +1029,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 2 ** 22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -1042,7 +1054,8 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " + f"Duration: {c.duration}" ) return False @@ -1216,8 +1229,7 @@ def scan_pessimistic_batches_for_oom( display_and_save_batch(batch, params=params, sp=sp) raise logging.info( - "Maximum memory allocated so far is" - f" {torch.cuda.max_memory_allocated()//1000000}MB" + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index fcd9858cd..023dec97d 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -16,35 +16,32 @@ # limitations under the License. import copy -import itertools -import logging import math -import random import warnings +import itertools from typing import List, Optional, Tuple, Union - +import logging import torch +import random from encoder_interface import EncoderInterface -from scaling import ( - ScaledLinear, # not as in other dirs.. just scales down initial parameter values. -) from scaling import ( ActivationBalancer, BasicNorm, - DoubleSwish, - Identity, MaxEig, + DoubleSwish, ScaledConv1d, + ScaledLinear, # not as in other dirs.. just scales down initial parameter values. Whiten, + Identity, _diag, - penalize_abs_values_gt, random_clamp, + penalize_abs_values_gt, softmax, ) from torch import Tensor, nn -from icefall.dist import get_rank from icefall.utils import make_pad_mask +from icefall.dist import get_rank class Zipformer(EncoderInterface): @@ -92,7 +89,7 @@ class Zipformer(EncoderInterface): self.batch_count = 0 self.warmup_end = warmup_batches - for u, d in zip(encoder_unmasked_dims, encoder_dims): + for u,d in zip(encoder_unmasked_dims, encoder_dims): assert u <= d, (u, d) # self.encoder_embed converts the input of shape (N, T, num_features) @@ -100,9 +97,9 @@ class Zipformer(EncoderInterface): # That is, it does two things simultaneously: # (1) subsampling: T -> (T - 7)//2 # (2) embedding: num_features -> encoder_dims - self.encoder_embed = Conv2dSubsampling( - num_features, encoder_dims[0], dropout=dropout - ) + self.encoder_embed = Conv2dSubsampling(num_features, encoder_dims[0], + dropout=dropout) + # each one will be ZipformerEncoder or DownsampledZipformerEncoder encoders = [] @@ -126,13 +123,13 @@ class Zipformer(EncoderInterface): num_encoder_layers[i], dropout, warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1), - warmup_end=warmup_batches * (i + 2) / (num_encoders + 1), + warmup_end=warmup_batches * (i + 2) / (num_encoders + 1) ) if zipformer_downsampling_factors[i] != 1: encoder = DownsampledZipformerEncoder( encoder, - input_dim=encoder_dims[i - 1] if i > 0 else encoder_dims[0], + input_dim=encoder_dims[i-1] if i > 0 else encoder_dims[0], output_dim=encoder_dims[i], downsample=zipformer_downsampling_factors[i], ) @@ -142,11 +139,10 @@ class Zipformer(EncoderInterface): # initializes self.skip_layers and self.skip_modules self._init_skip_modules() - self.downsample_output = AttentionDownsample( - encoder_dims[-1], - encoder_dims[-1], - downsample=output_downsampling_factor, - ) + self.downsample_output = AttentionDownsample(encoder_dims[-1], + encoder_dims[-1], + downsample=output_downsampling_factor) + def _get_layer_skip_dropout_prob(self): if not self.training: @@ -170,33 +166,27 @@ class Zipformer(EncoderInterface): skip_modules = [] z = self.zipformer_downsampling_factors for i in range(len(z)): - if i <= 1 or z[i - 1] <= z[i]: + if i <= 1 or z[i-1] <= z[i]: skip_layers.append(None) skip_modules.append(SimpleCombinerIdentity()) else: # TEMP - for j in range(i - 2, -1, -1): + for j in range(i-2, -1, -1): if z[j] <= z[i] or j == 0: # TEMP logging statement. - logging.info( - f"At encoder stack {i}, which has" - f" downsampling_factor={z[i]}, we will combine the outputs" - f" of layers {j} and {i-1}, with" - f" downsampling_factors={z[j]} and {z[i-1]}." - ) + logging.info(f"At encoder stack {i}, which has downsampling_factor={z[i]}, we will " + f"combine the outputs of layers {j} and {i-1}, with downsampling_factors={z[j]} and {z[i-1]}.") skip_layers.append(j) - skip_modules.append( - SimpleCombiner( - self.encoder_dims[j], - self.encoder_dims[i - 1], - min_weight=(0.0, 0.25), - ) - ) + skip_modules.append(SimpleCombiner(self.encoder_dims[j], + self.encoder_dims[i-1], + min_weight=(0.0,0.25))) break self.skip_layers = skip_layers self.skip_modules = nn.ModuleList(skip_modules) - def get_feature_masks(self, x: torch.Tensor) -> List[float]: + def get_feature_masks( + self, + x: torch.Tensor) -> List[float]: # Note: The actual return type is Union[List[float], List[Tensor]], # but to make torch.jit.script() work, we use List[float] """ @@ -216,56 +206,46 @@ class Zipformer(EncoderInterface): """ num_encoders = len(self.encoder_dims) if torch.jit.is_scripting() or not self.training: - return [1.0] * num_encoders + return [ 1.0 ] * num_encoders (num_frames0, batch_size, _encoder_dims0) = x.shape - assert self.encoder_dims[0] == _encoder_dims0, ( - self.encoder_dims, - _encoder_dims0, - ) + + assert self.encoder_dims[0] == _encoder_dims0, (self.encoder_dims, _encoder_dims0) max_downsampling_factor = max(self.zipformer_downsampling_factors) - num_frames_max = num_frames0 + max_downsampling_factor - 1 + num_frames_max = (num_frames0 + max_downsampling_factor - 1) + feature_mask_dropout_prob = 0.15 # frame_mask_max shape: (num_frames_max, batch_size, 1) - frame_mask_max = ( - torch.rand(num_frames_max, batch_size, 1, device=x.device) - > feature_mask_dropout_prob - ).to(x.dtype) + frame_mask_max = (torch.rand(num_frames_max, batch_size, 1, + device=x.device) > + feature_mask_dropout_prob).to(x.dtype) feature_masks = [] for i in range(num_encoders): ds = self.zipformer_downsampling_factors[i] - upsample_factor = max_downsampling_factor // ds + upsample_factor = (max_downsampling_factor // ds) - frame_mask = ( - frame_mask_max.unsqueeze(1) - .expand(num_frames_max, upsample_factor, batch_size, 1) - .reshape(num_frames_max * upsample_factor, batch_size, 1) - ) + frame_mask = (frame_mask_max.unsqueeze(1).expand(num_frames_max, upsample_factor, + batch_size, 1) + .reshape(num_frames_max * upsample_factor, batch_size, 1)) num_frames = (num_frames0 + ds - 1) // ds frame_mask = frame_mask[:num_frames] - feature_mask = torch.ones( - num_frames, - batch_size, - self.encoder_dims[i], - dtype=x.dtype, - device=x.device, - ) + feature_mask = torch.ones(num_frames, batch_size, self.encoder_dims[i], + dtype=x.dtype, device=x.device) u = self.encoder_unmasked_dims[i] feature_mask[:, :, u:] *= frame_mask feature_masks.append(feature_mask) return feature_masks + def forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, + self, x: torch.Tensor, x_lens: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: @@ -285,19 +265,13 @@ class Zipformer(EncoderInterface): x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) lengths = (x_lens - 7) >> 1 - assert x.size(0) == lengths.max().item(), ( - x.shape, - lengths, - lengths.max(), - ) + assert x.size(0) == lengths.max().item(), (x.shape, lengths, lengths.max()) mask = make_pad_mask(lengths) outputs = [] feature_masks = self.get_feature_masks(x) - for i, (module, skip_module) in enumerate( - zip(self.encoders, self.skip_modules) - ): + for i, (module, skip_module) in enumerate(zip(self.encoders, self.skip_modules)): ds = self.zipformer_downsampling_factors[i] k = self.skip_layers[i] if isinstance(k, int): @@ -306,11 +280,9 @@ class Zipformer(EncoderInterface): x = skip_module(outputs[k], x) elif (not self.training) or random.random() > layer_skip_dropout_prob: x = skip_module(outputs[k], x) - x = module( - x, - feature_mask=feature_masks[i], - src_key_padding_mask=None if mask is None else mask[..., ::ds], - ) + x = module(x, + feature_mask=feature_masks[i], + src_key_padding_mask=None if mask is None else mask[...,::ds]) outputs.append(x) x = self.downsample_output(x) @@ -340,16 +312,15 @@ class ZipformerEncoderLayer(nn.Module): >>> pos_emb = torch.rand(32, 19, 512) >>> out = encoder_layer(src, pos_emb) """ - def __init__( - self, - d_model: int, - attention_dim: int, - nhead: int, - feedforward_dim: int = 2048, - dropout: float = 0.1, - cnn_module_kernel: int = 31, - pos_dim: int = 4, + self, + d_model: int, + attention_dim: int, + nhead: int, + feedforward_dim: int = 2048, + dropout: float = 0.1, + cnn_module_kernel: int = 31, + pos_dim: int = 4, ) -> None: super(ZipformerEncoderLayer, self).__init__() @@ -359,24 +330,29 @@ class ZipformerEncoderLayer(nn.Module): self.batch_count = 0 self.self_attn = RelPositionMultiheadAttention( - d_model, - attention_dim, - nhead, - pos_dim, - dropout=0.0, + d_model, attention_dim, nhead, pos_dim, dropout=0.0, ) self.pooling = PoolingModule(d_model) - self.feed_forward1 = FeedforwardModule(d_model, feedforward_dim, dropout) + self.feed_forward1 = FeedforwardModule(d_model, + feedforward_dim, + dropout) - self.feed_forward2 = FeedforwardModule(d_model, feedforward_dim, dropout) + self.feed_forward2 = FeedforwardModule(d_model, + feedforward_dim, + dropout) - self.feed_forward3 = FeedforwardModule(d_model, feedforward_dim, dropout) + self.feed_forward3 = FeedforwardModule(d_model, + feedforward_dim, + dropout) - self.conv_module1 = ConvolutionModule(d_model, cnn_module_kernel) - self.conv_module2 = ConvolutionModule(d_model, cnn_module_kernel) + self.conv_module1 = ConvolutionModule(d_model, + cnn_module_kernel) + + self.conv_module2 = ConvolutionModule(d_model, + cnn_module_kernel) self.norm_final = BasicNorm(d_model) @@ -384,18 +360,14 @@ class ZipformerEncoderLayer(nn.Module): # try to ensure the output is close to zero-mean (or at least, zero-median). self.balancer = ActivationBalancer( - d_model, - channel_dim=-1, - min_positive=0.45, - max_positive=0.55, + d_model, channel_dim=-1, + min_positive=0.45, max_positive=0.55, max_abs=6.0, ) - self.whiten = Whiten( - num_groups=1, - whitening_limit=5.0, - prob=(0.025, 0.25), - grad_scale=0.01, - ) + self.whiten = Whiten(num_groups=1, + whitening_limit=5.0, + prob=(0.025, 0.25), + grad_scale=0.01) def get_bypass_scale(self): if torch.jit.is_scripting() or not self.training: @@ -410,9 +382,8 @@ class ZipformerEncoderLayer(nn.Module): if self.batch_count > warmup_period: clamp_min = final_clamp_min else: - clamp_min = initial_clamp_min - (self.batch_count / warmup_period) * ( - initial_clamp_min - final_clamp_min - ) + clamp_min = (initial_clamp_min - + (self.batch_count / warmup_period) * (initial_clamp_min - final_clamp_min)) return self.bypass_scale.clamp(min=clamp_min, max=1.0) def get_dynamic_dropout_rate(self): @@ -427,9 +398,8 @@ class ZipformerEncoderLayer(nn.Module): if self.batch_count > warmup_period: return final_dropout_rate else: - return initial_dropout_rate - ( - initial_dropout_rate * final_dropout_rate - ) * (self.batch_count / warmup_period) + return (initial_dropout_rate - + (initial_dropout_rate * final_dropout_rate) * (self.batch_count / warmup_period)) def forward( self, @@ -538,14 +508,13 @@ class ZipformerEncoder(nn.Module): >>> src = torch.rand(10, 32, 512) >>> out = zipformer_encoder(src) """ - def __init__( - self, - encoder_layer: nn.Module, - num_layers: int, - dropout: float, - warmup_begin: float, - warmup_end: float, + self, + encoder_layer: nn.Module, + num_layers: int, + dropout: float, + warmup_begin: float, + warmup_end: float ) -> None: super().__init__() # will be written to, see set_batch_count() Note: in inference time this @@ -559,7 +528,8 @@ class ZipformerEncoder(nn.Module): # so that we can keep this consistent across worker tasks (for efficiency). self.module_seed = torch.randint(0, 1000, ()).item() - self.encoder_pos = RelPositionalEncoding(encoder_layer.d_model, dropout) + self.encoder_pos = RelPositionalEncoding(encoder_layer.d_model, + dropout) self.layers = nn.ModuleList( [copy.deepcopy(encoder_layer) for i in range(num_layers)] @@ -568,13 +538,15 @@ class ZipformerEncoder(nn.Module): assert 0 <= warmup_begin <= warmup_end, (warmup_begin, warmup_end) - delta = (1.0 / num_layers) * (warmup_end - warmup_begin) + + delta = (1. / num_layers) * (warmup_end - warmup_begin) cur_begin = warmup_begin for i in range(num_layers): self.layers[i].warmup_begin = cur_begin cur_begin += delta self.layers[i].warmup_end = cur_begin + def get_layers_to_drop(self, rnd_seed: int): ans = set() if not self.training: @@ -607,14 +579,12 @@ class ZipformerEncoder(nn.Module): # linearly interpolate t = (batch_count - layer_warmup_begin) / layer_warmup_end assert 0.0 <= t < 1.001, t - return initial_layerdrop_prob + t * ( - final_layerdrop_prob - initial_layerdrop_prob - ) + return initial_layerdrop_prob + t * (final_layerdrop_prob - initial_layerdrop_prob) shared_rng = random.Random(batch_count + self.module_seed) independent_rng = random.Random(rnd_seed) - layerdrop_probs = [get_layerdrop_prob(i) for i in range(num_layers)] + layerdrop_probs = [ get_layerdrop_prob(i) for i in range(num_layers) ] tot = sum(layerdrop_probs) # Instead of drawing the samples independently, we first randomly decide # how many layers to drop out, using the same random number generator between @@ -634,13 +604,11 @@ class ZipformerEncoder(nn.Module): if len(ans) == num_to_drop: break if shared_rng.random() < 0.005 or __name__ == "__main__": - logging.info( - f"warmup_begin={self.warmup_begin:.1f}," - f" warmup_end={self.warmup_end:.1f}, batch_count={batch_count:.1f}," - f" num_to_drop={num_to_drop}, layers_to_drop={ans}" - ) + logging.info(f"warmup_begin={self.warmup_begin:.1f}, warmup_end={self.warmup_end:.1f}, " + f"batch_count={batch_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}") return ans + def forward( self, src: Tensor, @@ -671,6 +639,7 @@ class ZipformerEncoder(nn.Module): pos_emb = self.encoder_pos(src) output = src + if torch.jit.is_scripting(): layers_to_drop = [] else: @@ -701,31 +670,28 @@ class DownsampledZipformerEncoder(nn.Module): after convolutional downsampling, and then upsampled again at the output, and combined with the origin input, so that the output has the same shape as the input. """ - - def __init__( - self, - encoder: nn.Module, - input_dim: int, - output_dim: int, - downsample: int, - ): + def __init__(self, + encoder: nn.Module, + input_dim: int, + output_dim: int, + downsample: int): super(DownsampledZipformerEncoder, self).__init__() self.downsample_factor = downsample self.downsample = AttentionDownsample(input_dim, output_dim, downsample) self.encoder = encoder self.upsample = SimpleUpsample(output_dim, downsample) - self.out_combiner = SimpleCombiner( - input_dim, output_dim, min_weight=(0.0, 0.25) - ) + self.out_combiner = SimpleCombiner(input_dim, + output_dim, + min_weight=(0.0, 0.25)) - def forward( - self, - src: Tensor, - # Note: the type of feature_mask should be Unino[float, Tensor], - # but to make torch.jit.script() happ, we use float here - feature_mask: float = 1.0, - mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, + + def forward(self, + src: Tensor, + # Note: the type of feature_mask should be Unino[float, Tensor], + # but to make torch.jit.script() happ, we use float here + feature_mask: float = 1.0, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, ) -> Tensor: r"""Downsample, go through encoder, upsample. @@ -752,43 +718,42 @@ class DownsampledZipformerEncoder(nn.Module): src = self.downsample(src) ds = self.downsample_factor if mask is not None: - mask = mask[::ds, ::ds] + mask = mask[::ds,::ds] src = self.encoder( - src, - feature_mask=feature_mask, - mask=mask, - src_key_padding_mask=mask, + src, feature_mask=feature_mask, mask=mask, src_key_padding_mask=mask, ) src = self.upsample(src) # remove any extra frames that are not a multiple of downsample_factor - src = src[: src_orig.shape[0]] + src = src[:src_orig.shape[0]] return self.out_combiner(src_orig, src) - class AttentionDownsample(torch.nn.Module): """ Does downsampling with attention, by weighted sum, and a projection.. """ - - def __init__(self, in_channels: int, out_channels: int, downsample: int): + def __init__(self, + in_channels: int, + out_channels: int, + downsample: int): """ Require out_channels > in_channels. """ super(AttentionDownsample, self).__init__() - self.query = nn.Parameter(torch.randn(in_channels) * (in_channels**-0.5)) + self.query = nn.Parameter(torch.randn(in_channels) * (in_channels ** -0.5)) # fill in the extra dimensions with a projection of the input if out_channels > in_channels: - self.extra_proj = nn.Linear( - in_channels * downsample, out_channels - in_channels, bias=False - ) + self.extra_proj = nn.Linear(in_channels * downsample, + out_channels - in_channels, + bias=False) else: self.extra_proj = None self.downsample = downsample - def forward(self, src: Tensor) -> Tensor: + def forward(self, + src: Tensor) -> Tensor: """ x: (seq_len, batch_size, in_channels) Returns a tensor of shape @@ -802,14 +767,16 @@ class AttentionDownsample(torch.nn.Module): if seq_len != d_seq_len * ds: # right-pad src, repeating the last element. pad = d_seq_len * ds - seq_len - src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2]) + src_extra = src[src.shape[0]-1:].expand(pad, src.shape[1], src.shape[2]) src = torch.cat((src, src_extra), dim=0) assert src.shape[0] == d_seq_len * ds, (src.shape[0], d_seq_len, ds) src = src.reshape(d_seq_len, ds, batch_size, in_channels) scores = (src * self.query).sum(dim=-1, keepdim=True) - scores = penalize_abs_values_gt(scores, limit=10.0, penalty=1.0e-04) + scores = penalize_abs_values_gt(scores, + limit=10.0, + penalty=1.0e-04) weights = scores.softmax(dim=1) @@ -828,12 +795,14 @@ class SimpleUpsample(torch.nn.Module): A very simple form of upsampling that mostly just repeats the input, but also adds a position-specific bias. """ - - def __init__(self, num_channels: int, upsample: int): + def __init__(self, + num_channels: int, + upsample: int): super(SimpleUpsample, self).__init__() self.bias = nn.Parameter(torch.randn(upsample, num_channels) * 0.01) - def forward(self, src: Tensor) -> Tensor: + def forward(self, + src: Tensor) -> Tensor: """ x: (seq_len, batch_size, num_channels) Returns a tensor of shape @@ -846,7 +815,6 @@ class SimpleUpsample(torch.nn.Module): src = src.reshape(seq_len * upsample, batch_size, num_channels) return src - class SimpleCombinerIdentity(nn.Module): def __init__(self, *args, **kwargs): super().__init__() @@ -854,7 +822,6 @@ class SimpleCombinerIdentity(nn.Module): def forward(self, src1: Tensor, src2: Tensor) -> Tensor: return src1 - class SimpleCombiner(torch.nn.Module): """ A very simple way of combining 2 vectors of 2 different dims, via a @@ -864,14 +831,18 @@ class SimpleCombiner(torch.nn.Module): dim2: the dimension of the second input, e.g. 384. The output will have the same dimension as dim2. """ - - def __init__(self, dim1: int, dim2: int, min_weight: Tuple[float] = (0.0, 0.0)): + def __init__(self, + dim1: int, + dim2: int, + min_weight: Tuple[float] = (0., 0.)): super(SimpleCombiner, self).__init__() assert dim2 >= dim1, (dim2, dim1) self.weight1 = nn.Parameter(torch.zeros(())) self.min_weight = min_weight - def forward(self, src1: Tensor, src2: Tensor) -> Tensor: + def forward(self, + src1: Tensor, + src2: Tensor) -> Tensor: """ src1: (*, dim1) src2: (*, dim2) @@ -882,14 +853,10 @@ class SimpleCombiner(torch.nn.Module): weight1 = self.weight1 if not torch.jit.is_scripting(): - if ( - self.training - and random.random() < 0.25 - and self.min_weight != (0.0, 0.0) - ): - weight1 = weight1.clamp( - min=self.min_weight[0], max=1.0 - self.min_weight[1] - ) + if self.training and random.random() < 0.25 and self.min_weight != (0., 0.): + weight1 = weight1.clamp(min=self.min_weight[0], + max=1.0-self.min_weight[1]) + src1 = src1 * weight1 src2 = src2 * (1.0 - weight1) @@ -902,9 +869,12 @@ class SimpleCombiner(torch.nn.Module): else: src1 = src1[:src2_dim] + return src1 + src2 + + class RelPositionalEncoding(torch.nn.Module): """Relative positional encoding module. @@ -918,7 +888,9 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: + def __init__( + self, d_model: int, dropout_rate: float, max_len: int = 5000 + ) -> None: """Construct a PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -933,7 +905,9 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(0) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): + if self.pe.dtype != x.dtype or str(self.pe.device) != str( + x.device + ): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vecotr and `j` means the @@ -981,6 +955,7 @@ class RelPositionalEncoding(torch.nn.Module): return self.dropout(pos_emb) + class RelPositionMultiheadAttention(nn.Module): r"""Multi-Head Attention layer with relative position encoding @@ -1017,46 +992,34 @@ class RelPositionMultiheadAttention(nn.Module): self.head_dim = attention_dim // num_heads self.pos_dim = pos_dim assert self.head_dim % 2 == 0, self.head_dim - assert self.head_dim * num_heads == attention_dim, ( - self.head_dim, - num_heads, - attention_dim, - ) + assert ( + self.head_dim * num_heads == attention_dim + ), (self.head_dim, num_heads, attention_dim) # the initial_scale is supposed to take over the "scaling" factor of # head_dim ** -0.5, dividing it between the query and key. - in_proj_dim = ( - 2 * attention_dim - + attention_dim // 2 # query, key - + pos_dim * num_heads # value - ) # positional encoding query + in_proj_dim = (2 * attention_dim + # query, key + attention_dim // 2 + # value + pos_dim * num_heads) # positional encoding query - self.in_proj = ScaledLinear( - embed_dim, - in_proj_dim, - bias=True, - initial_scale=self.head_dim**-0.25, - ) + self.in_proj = ScaledLinear(embed_dim, in_proj_dim, bias=True, + initial_scale=self.head_dim**-0.25) # self.whiten_values is applied on the values in forward(); # it just copies the keys but prevents low-rank distribution by modifying grads. - self.whiten_values = Whiten( - num_groups=num_heads, - whitening_limit=2.0, - prob=(0.025, 0.25), - grad_scale=0.025, - ) - self.whiten_keys = Whiten( - num_groups=num_heads, - whitening_limit=2.0, - prob=(0.025, 0.25), - grad_scale=0.025, - ) + self.whiten_values = Whiten(num_groups=num_heads, + whitening_limit=2.0, + prob=(0.025, 0.25), + grad_scale=0.025) + self.whiten_keys = Whiten(num_groups=num_heads, + whitening_limit=2.0, + prob=(0.025, 0.25), + grad_scale=0.025) + # linear transformation for positional encoding. - self.linear_pos = ScaledLinear( - embed_dim, num_heads * pos_dim, bias=False, initial_scale=0.05 - ) + self.linear_pos = ScaledLinear(embed_dim, num_heads * pos_dim, bias=False, + initial_scale=0.05) # the following are for diagnosics only, see --print-diagnostics option. # they only copy their inputs. @@ -1068,16 +1031,14 @@ class RelPositionMultiheadAttention(nn.Module): ) self.in_proj2 = nn.Linear(embed_dim, attention_dim // 2, bias=False) - self.out_proj2 = ScaledLinear( - attention_dim // 2, embed_dim, bias=True, initial_scale=0.05 - ) + self.out_proj2 = ScaledLinear(attention_dim // 2, embed_dim, bias=True, + initial_scale=0.05) # self.whiten_values2 is applied on the values in forward2() - self.whiten_values2 = Whiten( - num_groups=num_heads, - whitening_limit=2.0, - prob=(0.025, 0.25), - grad_scale=0.025, - ) + self.whiten_values2 = Whiten(num_groups=num_heads, + whitening_limit=2.0, + prob=(0.025, 0.25), + grad_scale=0.025) + def forward( self, @@ -1137,6 +1098,7 @@ class RelPositionMultiheadAttention(nn.Module): ) return x, weights + def multi_head_attention_forward( self, x_proj: Tensor, @@ -1194,24 +1156,26 @@ class RelPositionMultiheadAttention(nn.Module): head_dim = attention_dim // num_heads pos_dim = self.pos_dim # positional-encoding dim per head - assert head_dim * num_heads == attention_dim, ( - f"attention_dim must be divisible by num_heads: {head_dim}, {num_heads}," - f" {attention_dim}" - ) + assert ( + head_dim * num_heads == attention_dim + ), f"attention_dim must be divisible by num_heads: {head_dim}, {num_heads}, {attention_dim}" + # self-attention - q = x_proj[..., 0:attention_dim] - k = x_proj[..., attention_dim : 2 * attention_dim] + q = x_proj[...,0:attention_dim] + k = x_proj[...,attention_dim:2*attention_dim] value_dim = attention_dim // 2 - v = x_proj[..., 2 * attention_dim : 2 * attention_dim + value_dim] + v = x_proj[...,2*attention_dim:2*attention_dim+value_dim] # p is the position-encoding query, its dimension is num_heads*pos_dim.. - p = x_proj[..., 2 * attention_dim + value_dim :] + p = x_proj[...,2*attention_dim+value_dim:] + k = self.whiten_keys(k) # does nothing in the forward pass. v = self.whiten_values(v) # does nothing in the forward pass. q = self.copy_query(q) # for diagnostics only, does nothing. p = self.copy_pos_query(p) # for diagnostics only, does nothing. + if attn_mask is not None: assert ( attn_mask.dtype == torch.float32 @@ -1231,25 +1195,33 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, seq_len, seq_len]: - raise RuntimeError("The size of the 2D attn_mask is not correct.") + raise RuntimeError( + "The size of the 2D attn_mask is not correct." + ) elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, seq_len, seq_len, ]: - raise RuntimeError("The size of the 3D attn_mask is not correct.") + raise RuntimeError( + "The size of the 3D attn_mask is not correct." + ) else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format(attn_mask.dim()) + "attn_mask's dimension {} is not supported".format( + attn_mask.dim() + ) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + if ( + key_padding_mask is not None + and key_padding_mask.dtype == torch.uint8 + ): warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor" - " instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -1258,6 +1230,7 @@ class RelPositionMultiheadAttention(nn.Module): k = k.reshape(seq_len, bsz, num_heads, head_dim) v = v.reshape(seq_len, bsz * num_heads, head_dim // 2).transpose(0, 1) + if key_padding_mask is not None: assert key_padding_mask.size(0) == bsz, "{} == {}".format( key_padding_mask.size(0), bsz @@ -1266,10 +1239,13 @@ class RelPositionMultiheadAttention(nn.Module): key_padding_mask.size(1), seq_len ) + + q = q.permute(1, 2, 0, 3) # (batch, head, time1, head_dim) p = p.permute(1, 2, 0, 3) # (batch, head, time1, pos_dim) k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) + seq_len2 = 2 * seq_len - 1 pos = pos.reshape(1, seq_len2, num_heads, pos_dim).permute(0, 2, 3, 1) # pos shape now: (batch, head, pos_dim, seq_len2) @@ -1280,16 +1256,13 @@ class RelPositionMultiheadAttention(nn.Module): # the following .as_strided() expression converts the last axis of pos_weights from relative # to absolute position. I don't know whether I might have got the time-offsets backwards or # not, but let this code define which way round it is supposed to be. - pos_weights = pos_weights.as_strided( - (bsz, num_heads, seq_len, seq_len), - ( - pos_weights.stride(0), - pos_weights.stride(1), - pos_weights.stride(2) - pos_weights.stride(3), - pos_weights.stride(3), - ), - storage_offset=pos_weights.stride(3) * (seq_len - 1), - ) + pos_weights = pos_weights.as_strided((bsz, num_heads, seq_len, seq_len), + (pos_weights.stride(0), + pos_weights.stride(1), + pos_weights.stride(2)-pos_weights.stride(3), + pos_weights.stride(3)), + storage_offset=pos_weights.stride(3) * (seq_len - 1)) + # caution: they are really scores at this point. attn_output_weights = torch.matmul(q, k) + pos_weights @@ -1302,9 +1275,10 @@ class RelPositionMultiheadAttention(nn.Module): # this mechanism instead of, say, a limit on entropy, because once the entropy # gets very small gradients through the softmax can become very small, and # some mechanisms like that become ineffective. - attn_output_weights = penalize_abs_values_gt( - attn_output_weights, limit=25.0, penalty=1.0e-04 - ) + attn_output_weights = penalize_abs_values_gt(attn_output_weights, + limit=25.0, + penalty=1.0e-04) + # attn_output_weights: (batch, head, time1, time2) attn_output_weights = attn_output_weights.view( @@ -1346,20 +1320,20 @@ class RelPositionMultiheadAttention(nn.Module): ) attn_output = torch.bmm(attn_output_weights, v) - assert list(attn_output.size()) == [ - bsz * num_heads, - seq_len, - head_dim // 2, - ] + assert list(attn_output.size()) == [bsz * num_heads, seq_len, + head_dim // 2] attn_output = ( attn_output.transpose(0, 1) .contiguous() .view(seq_len, bsz, attention_dim // 2) ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias + ) return attn_output, attn_output_weights + def forward2( self, x: Tensor, @@ -1398,7 +1372,11 @@ class RelPositionMultiheadAttention(nn.Module): # returned value is of shape (seq_len, bsz, embed_dim), like x. return self.out_proj2(attn_output) - def _print_attn_stats(self, attn_weights: Tensor, attn_output: Tensor): + + def _print_attn_stats( + self, + attn_weights: Tensor, + attn_output: Tensor): # attn_weights: (batch_size * num_heads, seq_len, seq_len) # attn_output: (bsz * num_heads, seq_len, head_dim) (n, seq_len, head_dim) = attn_output.shape @@ -1409,50 +1387,39 @@ class RelPositionMultiheadAttention(nn.Module): with torch.cuda.amp.autocast(enabled=False): attn_weights = attn_weights.to(torch.float32) attn_output = attn_output.to(torch.float32) - attn_weights_entropy = ( - -((attn_weights + 1.0e-20).log() * attn_weights) - .sum(dim=-1) - .reshape(bsz, num_heads, seq_len) - .mean(dim=(0, 2)) - ) + attn_weights_entropy = -((attn_weights + 1.0e-20).log() * attn_weights).sum( + dim=-1).reshape(bsz, num_heads, seq_len).mean(dim=(0,2)) attn_output = attn_output.reshape(bsz, num_heads, seq_len, head_dim) - attn_output = attn_output.permute(1, 0, 2, 3).reshape( - num_heads, bsz * seq_len, head_dim - ) + attn_output = attn_output.permute(1, 0, 2, 3).reshape(num_heads, bsz * seq_len, head_dim) attn_output_mean = attn_output.mean(dim=1, keepdim=True) attn_output = attn_output - attn_output_mean - attn_covar = torch.matmul(attn_output.transpose(1, 2), attn_output) / ( - bsz * seq_len - ) + attn_covar = torch.matmul(attn_output.transpose(1, 2), attn_output) / (bsz * seq_len) # attn_covar: (num_heads, head_dim, head_dim) - # eigs, _ = torch.symeig(attn_covar) - # logging.info(f"attn_weights_entropy = {attn_weights_entropy}, output_eigs = {eigs}") + #eigs, _ = torch.symeig(attn_covar) + #logging.info(f"attn_weights_entropy = {attn_weights_entropy}, output_eigs = {eigs}") attn_covar = _diag(attn_covar).mean(dim=1) # (num_heads,) embed_dim = self.in_proj2.weight.shape[1] - in_proj_covar = ( - self.in_proj2.weight.reshape(num_heads, head_dim, embed_dim) ** 2 - ).mean(dim=(1, 2)) - out_proj_covar = ( - self.out_proj2.weight.reshape(embed_dim, num_heads, head_dim) ** 2 - ).mean(dim=(0, 2)) - logging.info( - f"attn_weights_entropy = {attn_weights_entropy}," - f" covar={attn_covar}, in_proj_covar={in_proj_covar}," - f" out_proj_covar={out_proj_covar}" - ) + in_proj_covar = (self.in_proj2.weight.reshape(num_heads, head_dim, embed_dim) ** 2).mean(dim=(1,2)) + out_proj_covar = (self.out_proj2.weight.reshape(embed_dim, num_heads, head_dim) ** 2).mean(dim=(0,2)) + logging.info(f"attn_weights_entropy = {attn_weights_entropy}, covar={attn_covar}, in_proj_covar={in_proj_covar}, out_proj_covar={out_proj_covar}") + + class PoolingModule(nn.Module): """ Averages the input over the time dimension and project with a square matrix. """ - - def __init__(self, d_model: int): + def __init__(self, + d_model: int): super().__init__() - self.proj = ScaledLinear(d_model, d_model, initial_scale=0.1, bias=False) + self.proj = ScaledLinear(d_model, d_model, + initial_scale=0.1, bias=False) - def forward(self, x: Tensor, key_padding_mask: Optional[Tensor] = None): + def forward(self, + x: Tensor, + key_padding_mask: Optional[Tensor] = None): """ Args: x: a Tensor of shape (T, N, C) @@ -1463,7 +1430,7 @@ class PoolingModule(nn.Module): """ if key_padding_mask is not None: pooling_mask = key_padding_mask.logical_not().to(x.dtype) # (N, T) - pooling_mask = pooling_mask / pooling_mask.sum(dim=1, keepdim=True) + pooling_mask = (pooling_mask / pooling_mask.sum(dim=1, keepdim=True)) pooling_mask = pooling_mask.transpose(0, 1).contiguous().unsqueeze(-1) # now pooling_mask: (T, N, 1) x = (x * pooling_mask).sum(dim=0, keepdim=True) @@ -1477,19 +1444,24 @@ class PoolingModule(nn.Module): class FeedforwardModule(nn.Module): - """Feedforward module in Zipformer model.""" - - def __init__(self, d_model: int, feedforward_dim: int, dropout: float): + """Feedforward module in Zipformer model. + """ + def __init__(self, + d_model: int, + feedforward_dim: int, + dropout: float): super(FeedforwardModule, self).__init__() self.in_proj = nn.Linear(d_model, feedforward_dim) - self.balancer = ActivationBalancer( - feedforward_dim, channel_dim=-1, max_abs=10.0, min_prob=0.25 - ) + self.balancer = ActivationBalancer(feedforward_dim, + channel_dim=-1, max_abs=10.0, + min_prob=0.25) self.activation = DoubleSwish() self.dropout = nn.Dropout(dropout) - self.out_proj = ScaledLinear(feedforward_dim, d_model, initial_scale=0.01) + self.out_proj = ScaledLinear(feedforward_dim, d_model, + initial_scale=0.01) - def forward(self, x: Tensor): + def forward(self, + x: Tensor): x = self.in_proj(x) x = self.balancer(x) x = self.activation(x) @@ -1509,7 +1481,9 @@ class ConvolutionModule(nn.Module): """ - def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: + def __init__( + self, channels: int, kernel_size: int, bias: bool = True + ) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding @@ -1539,10 +1513,7 @@ class ConvolutionModule(nn.Module): # the correct range. self.deriv_balancer1 = ActivationBalancer( 2 * channels, - channel_dim=1, - max_abs=10.0, - min_positive=0.05, - max_positive=1.0, + channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0 ) self.depthwise_conv = nn.Conv1d( @@ -1556,10 +1527,8 @@ class ConvolutionModule(nn.Module): ) self.deriv_balancer2 = ActivationBalancer( - channels, - channel_dim=1, - min_positive=0.05, - max_positive=1.0, + channels, channel_dim=1, + min_positive=0.05, max_positive=1.0, max_abs=20.0, ) @@ -1575,10 +1544,9 @@ class ConvolutionModule(nn.Module): initial_scale=0.05, ) - def forward( - self, - x: Tensor, - src_key_padding_mask: Optional[Tensor] = None, + def forward(self, + x: Tensor, + src_key_padding_mask: Optional[Tensor] = None, ) -> Tensor: """Compute convolution module. @@ -1658,7 +1626,8 @@ class Conv2dSubsampling(nn.Module): kernel_size=3, padding=(0, 1), # (time, freq) ), - ActivationBalancer(layer1_channels, channel_dim=1), + ActivationBalancer(layer1_channels, + channel_dim=1), DoubleSwish(), nn.Conv2d( in_channels=layer1_channels, @@ -1667,21 +1636,24 @@ class Conv2dSubsampling(nn.Module): stride=2, padding=0, ), - ActivationBalancer(layer2_channels, channel_dim=1), + ActivationBalancer(layer2_channels, + channel_dim=1), DoubleSwish(), nn.Conv2d( in_channels=layer2_channels, out_channels=layer3_channels, kernel_size=3, - stride=(1, 2), # (time, freq) + stride=(1, 2), # (time, freq) ), - ActivationBalancer(layer3_channels, channel_dim=1), + ActivationBalancer(layer3_channels, + channel_dim=1), DoubleSwish(), ) out_height = (((in_channels - 1) // 2) - 1) // 2 self.out = ScaledLinear(out_height * layer3_channels, out_channels) self.dropout = nn.Dropout(dropout) + def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. @@ -1702,7 +1674,6 @@ class Conv2dSubsampling(nn.Module): x = self.dropout(x) return x - class AttentionCombine(nn.Module): """ This module combines a list of Tensors, all with the same shape, to @@ -1746,12 +1717,15 @@ class AttentionCombine(nn.Module): self.random_prob = random_prob self.single_prob = single_prob - self.weight = torch.nn.Parameter(torch.zeros(num_channels, num_inputs)) + self.weight = torch.nn.Parameter(torch.zeros(num_channels, + num_inputs)) self.bias = torch.nn.Parameter(torch.zeros(num_inputs)) assert 0 <= random_prob <= 1, random_prob assert 0 <= single_prob <= 1, single_prob + + def forward(self, inputs: List[Tensor]) -> Tensor: """Forward function. Args: @@ -1782,35 +1756,28 @@ class AttentionCombine(nn.Module): if self.training: # random masking.. - mask_start = torch.randint( - low=1, - high=int(num_inputs / self.random_prob), - size=(num_frames,), - device=scores.device, - ).unsqueeze(1) + mask_start = torch.randint(low=1, high=int(num_inputs / self.random_prob), + size=(num_frames,), device=scores.device).unsqueeze(1) # mask will have rows like: [ False, False, False, True, True, .. ] - arange = ( - torch.arange(num_inputs, device=scores.device) - .unsqueeze(0) - .expand(num_frames, num_inputs) - ) + arange = torch.arange(num_inputs, device=scores.device).unsqueeze(0).expand( + num_frames, num_inputs) mask = arange >= mask_start - apply_single_prob = torch.logical_and( - torch.rand(size=(num_frames, 1), device=scores.device) - < self.single_prob, - mask_start < num_inputs, - ) - single_prob_mask = torch.logical_and( - apply_single_prob, arange < mask_start - 1 - ) + apply_single_prob = torch.logical_and(torch.rand(size=(num_frames, 1), + device=scores.device) < self.single_prob, + mask_start < num_inputs) + single_prob_mask = torch.logical_and(apply_single_prob, + arange < mask_start - 1) - mask = torch.logical_or(mask, single_prob_mask) + mask = torch.logical_or(mask, + single_prob_mask) - scores = scores.masked_fill(mask, float("-inf")) + scores = scores.masked_fill(mask, float('-inf')) if self.training and random.random() < 0.1: - scores = penalize_abs_values_gt(scores, limit=10.0, penalty=1.0e-04) + scores = penalize_abs_values_gt(scores, + limit=10.0, + penalty=1.0e-04) weights = scores.softmax(dim=1) @@ -1825,6 +1792,7 @@ class AttentionCombine(nn.Module): return ans + def _test_random_combine(): print("_test_random_combine()") num_inputs = 3 @@ -1833,8 +1801,8 @@ def _test_random_combine(): num_channels=num_channels, num_inputs=num_inputs, random_prob=0.5, - single_prob=0.0, - ) + single_prob=0.0) + x = [torch.ones(3, 4, num_channels) for _ in range(num_inputs)] @@ -1851,10 +1819,7 @@ def _test_zipformer_main(): # Just make sure the forward pass runs. c = Zipformer( - num_features=feature_dim, - encoder_dims=(64, 96), - encoder_unmasked_dims=(48, 64), - nhead=(4, 4), + num_features=feature_dim, encoder_dims=(64,96), encoder_unmasked_dims=(48,64), nhead=(4,4) ) batch_size = 5 seq_len = 20 @@ -1872,18 +1837,19 @@ def _test_zipformer_main(): ) f # to remove flake8 warnings - def _test_conv2d_subsampling(): num_features = 80 encoder_dims = 384 dropout = 0.1 - encoder_embed = Conv2dSubsampling(num_features, encoder_dims, dropout=dropout) + encoder_embed = Conv2dSubsampling(num_features, encoder_dims, + dropout=dropout) for i in range(20, 40): x = torch.rand(2, i, num_features) y = encoder_embed(x) assert (x.shape[1] - 7) // 2 == y.shape[1], (x.shape[1], y.shape[1]) + if __name__ == "__main__": logging.getLogger().setLevel(logging.INFO) torch.set_num_threads(1) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py index 822f8e44b..9d7335e77 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py @@ -165,24 +165,20 @@ def get_parser(): "--avg", type=int, default=9, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -277,7 +273,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -397,7 +394,9 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] @@ -456,7 +455,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -587,7 +589,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -620,7 +624,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -675,7 +680,9 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -712,12 +719,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -745,12 +753,13 @@ def main(): ) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -779,7 +788,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -807,7 +816,9 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/export.py b/egs/librispeech/ASR/pruned_transducer_stateless8/export.py index 43eb0c1bc..49f469e29 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/export.py @@ -129,24 +129,20 @@ def get_parser(): "--avg", type=int, default=9, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -180,7 +176,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) add_model_arguments(parser) @@ -220,12 +217,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -254,12 +252,13 @@ def main(): ) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -288,7 +287,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -327,7 +326,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py index ed920dc03..e79a3a3aa 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py @@ -69,12 +69,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) return parser @@ -95,9 +93,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -268,7 +267,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/model.py b/egs/librispeech/ASR/pruned_transducer_stateless8/model.py index 39a360796..497b89136 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/model.py @@ -160,7 +160,9 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) + boundary = torch.zeros( + (x.size(0), 4), dtype=torch.int64, device=x.device + ) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py index 716136812..373a48fc1 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py @@ -100,11 +100,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -129,12 +127,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -181,7 +177,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -212,9 +209,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -277,11 +275,15 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) + encoder_out, encoder_out_lens = model.encoder( + x=features, x_lens=feature_lengths + ) num_waves = encoder_out.size(0) hyps = [] @@ -353,7 +355,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py index 381a86a67..2603bb854 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py @@ -92,7 +92,9 @@ from icefall.env import get_env_info from icefall.hooks import register_inf_check_hooks from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: @@ -130,10 +132,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): "--encoder-dims", type=str, default="384,384,384,384,384", - help=( - "Embedding dimension in the 2 blocks of zipformer encoder layers, comma" - " separated" - ), + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", ) parser.add_argument( @@ -148,11 +147,9 @@ def add_model_arguments(parser: argparse.ArgumentParser): "--encoder-unmasked-dims", type=str, default="256,256,256,256,256", - help=( - "Unmasked dimensions in the encoders, relates to augmentation during" - " training. Must be <= each of encoder_dims. Empirically, less than 256" - " seems to make performance worse." - ), + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " + " worse.", ) parser.add_argument( @@ -217,7 +214,8 @@ def get_parser(): "--full-libri", type=str2bool, default=True, - help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.", + help="When enabled, use 960h LibriSpeech. " + "Otherwise, use 100h subset.", ) parser.add_argument( @@ -287,45 +285,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." - ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -696,7 +691,11 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -745,7 +744,9 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -951,7 +952,9 @@ def train_one_epoch( # of the grad scaler is configurable, but we can't configure it to have different # behavior depending on the current grad scale. cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + if cur_grad_scale < 1.0 or ( + cur_grad_scale < 8.0 and batch_idx % 400 == 0 + ): scaler.update(cur_grad_scale * 2.0) if cur_grad_scale < 0.01: logging.warning(f"Grad scale is small: {cur_grad_scale}") @@ -972,7 +975,11 @@ def train_one_epoch( f"giga_tot_loss[{giga_tot_loss}], " f"batch size: {batch_size}, " f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + + ( + f"grad_scale: {scaler._scale.item()}" + if params.use_fp16 + else "" + ) ) if tb_writer is not None: @@ -985,8 +992,12 @@ def train_one_epoch( f"train/current_{prefix}_", params.batch_idx_train, ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) libri_tot_loss.write_summary( tb_writer, "train/libri_tot_", params.batch_idx_train ) @@ -1000,7 +1011,10 @@ def train_one_epoch( params.batch_idx_train, ) - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + if ( + batch_idx % params.valid_interval == 0 + and not params.print_diagnostics + ): logging.info("Computing validation loss") valid_info = compute_validation_loss( params=params, @@ -1012,8 +1026,7 @@ def train_one_epoch( model.train() logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") logging.info( - "Maximum memory allocated so far is" - f" {torch.cuda.max_memory_allocated()//1000000}MB" + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" ) if tb_writer is not None: valid_info.write_summary( @@ -1041,7 +1054,8 @@ def filter_short_and_long_utterances( # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " + f"Duration: {c.duration}" ) return False @@ -1138,7 +1152,9 @@ def run(rank, world_size, args): logging.info("Using DDP") model = DDP(model, device_ids=[rank], find_unused_parameters=True) - optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=2.0) + optimizer = ScaledAdam( + model.parameters(), lr=params.base_lr, clipping_scale=2.0 + ) scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) @@ -1156,7 +1172,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 2 ** 22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -1191,7 +1207,9 @@ def run(rank, world_size, args): train_giga_cuts = train_giga_cuts.repeat(times=None) if args.enable_musan: - cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz") + cuts_musan = load_manifest( + Path(args.manifest_dir) / "musan_cuts.jsonl.gz" + ) else: cuts_musan = None @@ -1346,8 +1364,7 @@ def scan_pessimistic_batches_for_oom( display_and_save_batch(batch, params=params, sp=sp) raise logging.info( - "Maximum memory allocated so far is" - f" {torch.cuda.max_memory_allocated()//1000000}MB" + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" ) diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/README.md b/egs/librispeech/ASR/streaming_conformer_ctc/README.md index 53f383c99..01be7090b 100644 --- a/egs/librispeech/ASR/streaming_conformer_ctc/README.md +++ b/egs/librispeech/ASR/streaming_conformer_ctc/README.md @@ -1,20 +1,20 @@ ## Train and Decode -Commands of data preparation/train/decode steps are almost the same with +Commands of data preparation/train/decode steps are almost the same with ../conformer_ctc experiment except some options. Please read the code and understand following new added options before running this experiment: For data preparation: - + Nothing new. For streaming_conformer_ctc/train.py: - + --dynamic-chunk-training --short-chunk-proportion For streaming_conformer_ctc/streaming_decode.py: - + --chunk-size --tailing-num-frames --simulate-streaming @@ -57,10 +57,10 @@ And check md5sum values again. Finally, following files will be downloaded:

-streaming_models/
-|-- lang_bpe
-|   |-- L.pt
-|   |-- Linv.pt
+streaming_models/  
+|-- lang_bpe  
+|   |-- L.pt  
+|   |-- Linv.pt  
 |   |-- bpe.model
 |   |-- tokens.txt
 |   `-- words.txt
diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py b/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py
index 4f7427c1f..ff4c91446 100644
--- a/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py
+++ b/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py
@@ -309,26 +309,36 @@ class Conformer(Transformer):
 
                 # start chunk_by_chunk decoding
                 offset = 0
-                for cur in range(0, num_frames - embed_left_context + 1, stride):
+                for cur in range(
+                    0, num_frames - embed_left_context + 1, stride
+                ):
                     end = min(cur + decoding_window, num_frames)
                     cur_feature = feature[:, cur:end, :]
                     cur_feature = self.encoder_embed(cur_feature)
-                    cur_embed, cur_pos_emb = self.encoder_pos(cur_feature, offset)
-                    cur_embed = cur_embed.permute(1, 0, 2)  # (B, T, F) -> (T, B, F)
+                    cur_embed, cur_pos_emb = self.encoder_pos(
+                        cur_feature, offset
+                    )
+                    cur_embed = cur_embed.permute(
+                        1, 0, 2
+                    )  # (B, T, F) -> (T, B, F)
 
                     cur_T = cur_feature.size(1)
                     if cur == 0:
                         # for first chunk extract the central pos embedding
-                        pos_emb_central = cur_pos_emb[0, (chunk_size - 1), :].view(
-                            1, 1, -1
-                        )
+                        pos_emb_central = cur_pos_emb[
+                            0, (chunk_size - 1), :
+                        ].view(1, 1, -1)
                         cur_T -= 1
                     pos_emb_positive.append(cur_pos_emb[0, :cur_T].flip(0))
                     pos_emb_negative.append(cur_pos_emb[0, -cur_T:])
                     assert pos_emb_positive[-1].size(0) == cur_T
 
-                    pos_emb_pos = torch.cat(pos_emb_positive, dim=0).unsqueeze(0)
-                    pos_emb_neg = torch.cat(pos_emb_negative, dim=0).unsqueeze(0)
+                    pos_emb_pos = torch.cat(pos_emb_positive, dim=0).unsqueeze(
+                        0
+                    )
+                    pos_emb_neg = torch.cat(pos_emb_negative, dim=0).unsqueeze(
+                        0
+                    )
                     cur_pos_emb = torch.cat(
                         [pos_emb_pos.flip(1), pos_emb_central, pos_emb_neg],
                         dim=1,
@@ -403,7 +413,9 @@ class ConformerEncoderLayer(nn.Module):
         causal: bool = False,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
-        self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
+        self.self_attn = RelPositionMultiheadAttention(
+            d_model, nhead, dropout=0.0
+        )
 
         self.feed_forward = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
@@ -419,16 +431,22 @@ class ConformerEncoderLayer(nn.Module):
             nn.Linear(dim_feedforward, d_model),
         )
 
-        self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal)
+        self.conv_module = ConvolutionModule(
+            d_model, cnn_module_kernel, causal=causal
+        )
 
-        self.norm_ff_macaron = nn.LayerNorm(d_model)  # for the macaron style FNN module
+        self.norm_ff_macaron = nn.LayerNorm(
+            d_model
+        )  # for the macaron style FNN module
         self.norm_ff = nn.LayerNorm(d_model)  # for the FNN module
         self.norm_mha = nn.LayerNorm(d_model)  # for the MHA module
 
         self.ff_scale = 0.5
 
         self.norm_conv = nn.LayerNorm(d_model)  # for the CNN module
-        self.norm_final = nn.LayerNorm(d_model)  # for the final output of the block
+        self.norm_final = nn.LayerNorm(
+            d_model
+        )  # for the final output of the block
 
         self.dropout = nn.Dropout(dropout)
 
@@ -462,7 +480,9 @@ class ConformerEncoderLayer(nn.Module):
         residual = src
         if self.normalize_before:
             src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
+        src = residual + self.ff_scale * self.dropout(
+            self.feed_forward_macaron(src)
+        )
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)
 
@@ -534,7 +554,9 @@ class ConformerEncoderLayer(nn.Module):
         residual = src
         if self.normalize_before:
             src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
+        src = residual + self.ff_scale * self.dropout(
+            self.feed_forward_macaron(src)
+        )
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)
 
@@ -714,7 +736,9 @@ class RelPositionalEncoding(torch.nn.Module):
 
     """
 
-    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
+    def __init__(
+        self, d_model: int, dropout_rate: float, max_len: int = 5000
+    ) -> None:
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
@@ -731,7 +755,9 @@ class RelPositionalEncoding(torch.nn.Module):
             # the length of self.pe is 2 * input_len - 1
             if self.pe.size(1) >= x_size_1 * 2 - 1:
                 # Note: TorchScript doesn't implement operator== for torch.Device
-                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(
+                    x.device
+                ):
                     self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                 return
         # Suppose `i` means to the position of query vector and `j` means the
@@ -757,7 +783,9 @@ class RelPositionalEncoding(torch.nn.Module):
         pe = torch.cat([pe_positive, pe_negative], dim=1)
         self.pe = pe.to(device=x.device, dtype=x.dtype)
 
-    def forward(self, x: torch.Tensor, offset: int = 0) -> Tuple[Tensor, Tensor]:
+    def forward(
+        self, x: torch.Tensor, offset: int = 0
+    ) -> Tuple[Tensor, Tensor]:
         """Add positional encoding.
 
         Args:
@@ -785,7 +813,9 @@ class RelPositionalEncoding(torch.nn.Module):
             pos_emb = torch.cat(
                 [
                     pos_emb[:, : (x_T - 1)],
-                    self.pe[0, self.pe.size(1) // 2].view(1, 1, self.pe.size(-1)),
+                    self.pe[0, self.pe.size(1) // 2].view(
+                        1, 1, self.pe.size(-1)
+                    ),
                     pos_emb[:, -(x_T - 1) :],  # noqa: E203
                 ],
                 dim=1,
@@ -1020,9 +1050,9 @@ class RelPositionMultiheadAttention(nn.Module):
 
         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
-                3, dim=-1
-            )
+            q, k, v = nn.functional.linear(
+                query, in_proj_weight, in_proj_bias
+            ).chunk(3, dim=-1)
 
         elif torch.equal(key, value):
             # encoder-decoder attention
@@ -1090,25 +1120,33 @@ class RelPositionMultiheadAttention(nn.Module):
             if attn_mask.dim() == 2:
                 attn_mask = attn_mask.unsqueeze(0)
                 if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
-                    raise RuntimeError("The size of the 2D attn_mask is not correct.")
+                    raise RuntimeError(
+                        "The size of the 2D attn_mask is not correct."
+                    )
             elif attn_mask.dim() == 3:
                 if list(attn_mask.size()) != [
                     bsz * num_heads,
                     query.size(0),
                     key.size(0),
                 ]:
-                    raise RuntimeError("The size of the 3D attn_mask is not correct.")
+                    raise RuntimeError(
+                        "The size of the 3D attn_mask is not correct."
+                    )
             else:
                 raise RuntimeError(
-                    "attn_mask's dimension {} is not supported".format(attn_mask.dim())
+                    "attn_mask's dimension {} is not supported".format(
+                        attn_mask.dim()
+                    )
                 )
             # attn_mask's dim is 3 now.
 
         # convert ByteTensor key_padding_mask to bool
-        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
+        if (
+            key_padding_mask is not None
+            and key_padding_mask.dtype == torch.uint8
+        ):
             warnings.warn(
-                "Byte tensor for key_padding_mask is deprecated. Use bool tensor"
-                " instead."
+                "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
             )
             key_padding_mask = key_padding_mask.to(torch.bool)
 
@@ -1147,16 +1185,24 @@ class RelPositionMultiheadAttention(nn.Module):
         # first compute matrix a and matrix c
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
-        matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(
+            q_with_bias_u, k
+        )  # (batch, head, time1, time2)
 
         # compute matrix b and matrix d
-        matrix_bd = torch.matmul(q_with_bias_v, p)  # (batch, head, time1, 2*time1-1)
-        matrix_bd = self.rel_shift(matrix_bd, offset=offset)  # [B, head, time1, time2]
+        matrix_bd = torch.matmul(
+            q_with_bias_v, p
+        )  # (batch, head, time1, 2*time1-1)
+        matrix_bd = self.rel_shift(
+            matrix_bd, offset=offset
+        )  # [B, head, time1, time2]
         attn_output_weights = (
             matrix_ac + matrix_bd
         ) * scaling  # (batch, head, time1, time2)
 
-        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
+        attn_output_weights = attn_output_weights.view(
+            bsz * num_heads, tgt_len, -1
+        )
 
         assert list(attn_output_weights.size()) == [
             bsz * num_heads,
@@ -1190,9 +1236,13 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output = torch.bmm(attn_output_weights, v)
         assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
         attn_output = (
-            attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+            attn_output.transpose(0, 1)
+            .contiguous()
+            .view(tgt_len, bsz, embed_dim)
+        )
+        attn_output = nn.functional.linear(
+            attn_output, out_proj_weight, out_proj_bias
         )
-        attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
 
         if need_weights:
             # average attention weights over heads
diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py b/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py
index 5a8149aad..a74c51836 100755
--- a/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py
+++ b/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py
@@ -28,7 +28,6 @@ import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
 from conformer import Conformer
-
 from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import average_checkpoints, load_checkpoint
 from icefall.lexicon import Lexicon
@@ -63,36 +62,32 @@ def get_parser():
         "--epoch",
         type=int,
         default=34,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=20,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
         "--chunk-size",
         type=int,
         default=8,
-        help=(
-            "Frames of right context"
-            "-1 for whole right context, i.e. non-streaming decoding"
-        ),
+        help="Frames of right context"
+        "-1 for whole right context, i.e. non-streaming decoding",
     )
 
     parser.add_argument(
         "--tailing-num-frames",
         type=int,
         default=20,
-        help="tailing dummy frames padded to the right,only used during decoding",
+        help="tailing dummy frames padded to the right,"
+        "only used during decoding",
     )
 
     parser.add_argument(
@@ -144,7 +139,8 @@ def get_parser():
         "--avg-models",
         type=str,
         default=None,
-        help="Manually select models to average, seperated by comma;e.g. 60,62,63,72",
+        help="Manually select models to average, seperated by comma;"
+        "e.g. 60,62,63,72",
     )
 
     return parser
@@ -252,9 +248,13 @@ def decode_one_batch(
     maxlen = nnet_output.size(1)
     topk_prob, topk_index = nnet_output.topk(1, dim=2)  # (B, maxlen, 1)
     topk_index = topk_index.view(batch_size, maxlen)  # (B, maxlen)
-    topk_index = topk_index.masked_fill_(memory_key_padding_mask, 0)  # (B, maxlen)
+    topk_index = topk_index.masked_fill_(
+        memory_key_padding_mask, 0
+    )  # (B, maxlen)
     token_ids = [token_id.tolist() for token_id in topk_index]
-    token_ids = [remove_duplicates_and_blank(token_id) for token_id in token_ids]
+    token_ids = [
+        remove_duplicates_and_blank(token_id) for token_id in token_ids
+    ]
     hyps = bpe_model.decode(token_ids)
     hyps = [s.split() for s in hyps]
     return {key: hyps}
@@ -337,7 +337,9 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
 
     return results
 
@@ -355,18 +357,15 @@ def save_results(
     test_set_wers = dict()
     if params.avg_models is not None:
         avg_models = params.avg_models.replace(",", "_")
-        result_file_prefix = (
-            f"epoch-avg-{avg_models}-chunksize        "
-            f" -{params.chunk_size}-tailing-num-frames-{params.tailing_num_frames}-"
-        )
+        result_file_prefix = f"epoch-avg-{avg_models}-chunksize \
+        -{params.chunk_size}-tailing-num-frames-{params.tailing_num_frames}-"
     else:
-        result_file_prefix = (
-            f"epoch-{params.epoch}-avg-{params.avg}-chunksize        "
-            f" -{params.chunk_size}-tailing-num-frames-{params.tailing_num_frames}-"
-        )
+        result_file_prefix = f"epoch-{params.epoch}-avg-{params.avg}-chunksize \
+        -{params.chunk_size}-tailing-num-frames-{params.tailing_num_frames}-"
     for key, results in results_dict.items():
         recog_path = (
-            params.exp_dir / f"{result_file_prefix}recogs-{test_set_name}-{key}.txt"
+            params.exp_dir
+            / f"{result_file_prefix}recogs-{test_set_name}-{key}.txt"
         )
         store_transcripts(filename=recog_path, texts=results)
         if enable_log:
@@ -375,7 +374,8 @@ def save_results(
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.exp_dir / f"{result_file_prefix}-errs-{test_set_name}-{key}.txt"
+            params.exp_dir
+            / f"{result_file_prefix}-errs-{test_set_name}-{key}.txt"
         )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
@@ -384,7 +384,9 @@ def save_results(
             test_set_wers[key] = wer
 
         if enable_log:
-            logging.info("Wrote detailed error stats to {}".format(errs_filename))
+            logging.info(
+                "Wrote detailed error stats to {}".format(errs_filename)
+            )
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"
@@ -472,7 +474,9 @@ def main():
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
+        torch.save(
+            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
+        )
         return
 
     model.to(device)
@@ -503,7 +507,9 @@ def main():
             simulate_streaming=params.simulate_streaming,
         )
 
-        save_results(params=params, test_set_name=test_set, results_dict=results_dict)
+        save_results(
+            params=params, test_set_name=test_set, results_dict=results_dict
+        )
 
     logging.info("Done!")
 
diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/train.py b/egs/librispeech/ASR/streaming_conformer_ctc/train.py
index 553b7d092..e41b7ea78 100755
--- a/egs/librispeech/ASR/streaming_conformer_ctc/train.py
+++ b/egs/librispeech/ASR/streaming_conformer_ctc/train.py
@@ -405,7 +405,9 @@ def compute_loss(
             #
             # See https://github.com/k2-fsa/icefall/issues/97
             # for more details
-            unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"])
+            unsorted_token_ids = graph_compiler.texts_to_ids(
+                supervisions["text"]
+            )
             att_loss = mmodel.decoder_forward(
                 encoder_memory,
                 memory_mask,
@@ -434,7 +436,9 @@ def compute_loss(
     info["utt_duration"] = supervisions["num_frames"].sum().item()
     # averaged padding proportion over utterances
     info["utt_pad_proportion"] = (
-        ((feature.size(1) - supervisions["num_frames"]) / feature.size(1)).sum().item()
+        ((feature.size(1) - supervisions["num_frames"]) / feature.size(1))
+        .sum()
+        .item()
     )
 
     return loss, info
@@ -547,7 +551,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -662,7 +668,9 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
+            tb_writer.add_scalar(
+                "train/learning_rate", cur_lr, params.batch_idx_train
+            )
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/transformer.py b/egs/librispeech/ASR/streaming_conformer_ctc/transformer.py
index 0c87fdf1b..bc78e4a41 100644
--- a/egs/librispeech/ASR/streaming_conformer_ctc/transformer.py
+++ b/egs/librispeech/ASR/streaming_conformer_ctc/transformer.py
@@ -149,7 +149,9 @@ class Transformer(nn.Module):
                 norm=decoder_norm,
             )
 
-            self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class)
+            self.decoder_output_layer = torch.nn.Linear(
+                d_model, self.decoder_num_class
+            )
 
             self.decoder_criterion = LabelSmoothingLoss()
         else:
@@ -284,17 +286,23 @@ class Transformer(nn.Module):
         """
         ys_in = add_sos(token_ids, sos_id=sos_id)
         ys_in = [torch.tensor(y) for y in ys_in]
-        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
+        ys_in_pad = pad_sequence(
+            ys_in, batch_first=True, padding_value=float(eos_id)
+        )
 
         ys_out = add_eos(token_ids, eos_id=eos_id)
         ys_out = [torch.tensor(y) for y in ys_out]
-        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
+        ys_out_pad = pad_sequence(
+            ys_out, batch_first=True, padding_value=float(-1)
+        )
 
         device = memory.device
         ys_in_pad = ys_in_pad.to(device)
         ys_out_pad = ys_out_pad.to(device)
 
-        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
+        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
+            device
+        )
 
         tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
         # TODO: Use length information to create the decoder padding mask
@@ -355,17 +363,23 @@ class Transformer(nn.Module):
 
         ys_in = add_sos(token_ids, sos_id=sos_id)
         ys_in = [torch.tensor(y) for y in ys_in]
-        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
+        ys_in_pad = pad_sequence(
+            ys_in, batch_first=True, padding_value=float(eos_id)
+        )
 
         ys_out = add_eos(token_ids, eos_id=eos_id)
         ys_out = [torch.tensor(y) for y in ys_out]
-        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
+        ys_out_pad = pad_sequence(
+            ys_out, batch_first=True, padding_value=float(-1)
+        )
 
         device = memory.device
         ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
         ys_out_pad = ys_out_pad.to(device, dtype=torch.int64)
 
-        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
+        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
+            device
+        )
 
         tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
         # TODO: Use length information to create the decoder padding mask
@@ -638,7 +652,9 @@ def _get_activation_fn(activation: str):
     elif activation == "gelu":
         return nn.functional.gelu
 
-    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
+    raise RuntimeError(
+        "activation should be relu/gelu, not {}".format(activation)
+    )
 
 
 class PositionalEncoding(nn.Module):
@@ -840,7 +856,9 @@ def encoder_padding_mask(
         1,
     ).to(torch.int32)
 
-    lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)]
+    lengths = [
+        0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)
+    ]
     for idx in range(supervision_segments.size(0)):
         # Note: TorchScript doesn't allow to unpack tensors as tuples
         sequence_idx = supervision_segments[idx, 0].item()
@@ -861,7 +879,9 @@ def encoder_padding_mask(
     return mask
 
 
-def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor:
+def decoder_padding_mask(
+    ys_pad: torch.Tensor, ignore_id: int = -1
+) -> torch.Tensor:
     """Generate a length mask for input.
 
     The masked position are filled with True,
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
index 63afd6be2..355ccc99a 100644
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -77,18 +77,17 @@ class LibriSpeechAsrDataModule:
     def add_arguments(cls, parser: argparse.ArgumentParser):
         group = parser.add_argument_group(
             title="ASR data related options",
-            description=(
-                "These options are used for the preparation of "
-                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-                "effective batch sizes, sampling strategies, applied data "
-                "augmentations, etc."
-            ),
+            description="These options are used for the preparation of "
+            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+            "effective batch sizes, sampling strategies, applied data "
+            "augmentations, etc.",
         )
         group.add_argument(
             "--full-libri",
             type=str2bool,
             default=True,
-            help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.",
+            help="When enabled, use 960h LibriSpeech. "
+            "Otherwise, use 100h subset.",
         )
         group.add_argument(
             "--manifest-dir",
@@ -100,74 +99,59 @@ class LibriSpeechAsrDataModule:
             "--max-duration",
             type=int,
             default=200.0,
-            help=(
-                "Maximum pooled recordings duration (seconds) in a "
-                "single batch. You can reduce it if it causes CUDA OOM."
-            ),
+            help="Maximum pooled recordings duration (seconds) in a "
+            "single batch. You can reduce it if it causes CUDA OOM.",
         )
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, the batches will come from buckets of "
-                "similar duration (saves padding frames)."
-            ),
+            help="When enabled, the batches will come from buckets of "
+            "similar duration (saves padding frames).",
         )
         group.add_argument(
             "--num-buckets",
             type=int,
             default=30,
-            help=(
-                "The number of buckets for the DynamicBucketingSampler"
-                "(you might want to increase it for larger datasets)."
-            ),
+            help="The number of buckets for the DynamicBucketingSampler"
+            "(you might want to increase it for larger datasets).",
         )
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, utterances (cuts) will be concatenated "
-                "to minimize the amount of padding."
-            ),
+            help="When enabled, utterances (cuts) will be concatenated "
+            "to minimize the amount of padding.",
         )
         group.add_argument(
             "--duration-factor",
             type=float,
             default=1.0,
-            help=(
-                "Determines the maximum duration of a concatenated cut "
-                "relative to the duration of the longest cut in a batch."
-            ),
+            help="Determines the maximum duration of a concatenated cut "
+            "relative to the duration of the longest cut in a batch.",
         )
         group.add_argument(
             "--gap",
             type=float,
             default=1.0,
-            help=(
-                "The amount of padding (in seconds) inserted between "
-                "concatenated cuts. This padding is filled with noise when "
-                "noise augmentation is used."
-            ),
+            help="The amount of padding (in seconds) inserted between "
+            "concatenated cuts. This padding is filled with noise when "
+            "noise augmentation is used.",
         )
         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, use on-the-fly cut mixing and feature "
-                "extraction. Will drop existing precomputed feature manifests "
-                "if available."
-            ),
+            help="When enabled, use on-the-fly cut mixing and feature "
+            "extraction. Will drop existing precomputed feature manifests "
+            "if available.",
         )
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled (=default), the examples will be shuffled for each epoch."
-            ),
+            help="When enabled (=default), the examples will be "
+            "shuffled for each epoch.",
         )
         group.add_argument(
             "--drop-last",
@@ -179,18 +163,17 @@ class LibriSpeechAsrDataModule:
             "--return-cuts",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, each batch will have the "
-                "field: batch['supervisions']['cut'] with the cuts that "
-                "were used to construct it."
-            ),
+            help="When enabled, each batch will have the "
+            "field: batch['supervisions']['cut'] with the cuts that "
+            "were used to construct it.",
         )
 
         group.add_argument(
             "--num-workers",
             type=int,
             default=2,
-            help="The number of training dataloader workers that collect the batches.",
+            help="The number of training dataloader workers that "
+            "collect the batches.",
         )
 
         group.add_argument(
@@ -204,22 +187,18 @@ class LibriSpeechAsrDataModule:
             "--spec-aug-time-warp-factor",
             type=int,
             default=80,
-            help=(
-                "Used only when --enable-spec-aug is True. "
-                "It specifies the factor for time warping in SpecAugment. "
-                "Larger values mean more warping. "
-                "A value less than 1 means to disable time warp."
-            ),
+            help="Used only when --enable-spec-aug is True. "
+            "It specifies the factor for time warping in SpecAugment. "
+            "Larger values mean more warping. "
+            "A value less than 1 means to disable time warp.",
         )
 
         group.add_argument(
             "--enable-musan",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, select noise from MUSAN and mix it"
-                "with training dataset. "
-            ),
+            help="When enabled, select noise from MUSAN and mix it"
+            "with training dataset. ",
         )
 
         group.add_argument(
@@ -245,16 +224,20 @@ class LibriSpeechAsrDataModule:
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             logging.info("About to get Musan cuts")
-            cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
+            cuts_musan = load_manifest(
+                self.args.manifest_dir / "musan_cuts.jsonl.gz"
+            )
             transforms.append(
-                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+                CutMix(
+                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
+                )
             )
         else:
             logging.info("Disable MUSAN")
 
         if self.args.concatenate_cuts:
             logging.info(
-                "Using cut concatenation with duration factor "
+                f"Using cut concatenation with duration factor "
                 f"{self.args.duration_factor} and gap {self.args.gap}."
             )
             # Cut concatenation should be the first transform in the list,
@@ -269,7 +252,9 @@ class LibriSpeechAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+            logging.info(
+                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
+            )
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -313,7 +298,9 @@ class LibriSpeechAsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -369,7 +356,9 @@ class LibriSpeechAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 return_cuts=self.args.return_cuts,
             )
         else:
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
index 94ba0a4dc..7d0cd0bf3 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
@@ -57,19 +57,16 @@ def get_parser():
         "--epoch",
         type=int,
         default=19,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=5,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
     parser.add_argument(
         "--method",
@@ -339,7 +336,9 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -401,7 +400,9 @@ def main():
 
     logging.info(f"device: {device}")
 
-    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
+    HLG = k2.Fsa.from_dict(
+        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
+    )
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
 
@@ -466,7 +467,9 @@ def main():
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
+        torch.save(
+            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
+        )
         return
 
     model.to(device)
@@ -495,7 +498,9 @@ def main():
             G=G,
         )
 
-        save_results(params=params, test_set_name=test_set, results_dict=results_dict)
+        save_results(
+            params=params, test_set_name=test_set, results_dict=results_dict
+        )
 
     logging.info("Done!")
 
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/model.py b/egs/librispeech/ASR/tdnn_lstm_ctc/model.py
index 1731e1ebe..5e04c11b4 100644
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/model.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/model.py
@@ -66,7 +66,10 @@ class TdnnLstm(nn.Module):
             nn.BatchNorm1d(num_features=500, affine=False),
         )
         self.lstms = nn.ModuleList(
-            [nn.LSTM(input_size=500, hidden_size=500, num_layers=1) for _ in range(5)]
+            [
+                nn.LSTM(input_size=500, hidden_size=500, num_layers=1)
+                for _ in range(5)
+            ]
         )
         self.lstm_bnorms = nn.ModuleList(
             [nn.BatchNorm1d(num_features=500, affine=False) for _ in range(5)]
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py
index 722e8f003..2baeb6bba 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py
@@ -29,7 +29,11 @@ import torchaudio
 from model import TdnnLstm
 from torch.nn.utils.rnn import pad_sequence
 
-from icefall.decode import get_lattice, one_best_decoding, rescore_with_whole_lattice
+from icefall.decode import (
+    get_lattice,
+    one_best_decoding,
+    rescore_with_whole_lattice,
+)
 from icefall.utils import AttributeDict, get_texts
 
 
@@ -42,11 +46,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -56,7 +58,9 @@ def get_parser():
         help="Path to words.txt",
     )
 
-    parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.")
+    parser.add_argument(
+        "--HLG", type=str, required=True, help="Path to HLG.pt."
+    )
 
     parser.add_argument(
         "--method",
@@ -99,12 +103,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     return parser
@@ -142,9 +144,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -212,7 +215,9 @@ def main():
     logging.info("Decoding started")
     features = fbank(waves)
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
     features = features.permute(0, 2, 1)  # now features is (N, C, T)
 
     with torch.no_grad():
@@ -264,7 +269,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
index 071ac792b..6b37d5c23 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
@@ -355,7 +355,9 @@ def compute_loss(
     info["utt_duration"] = supervisions["num_frames"].sum().item()
     # averaged padding proportion over utterances
     info["utt_pad_proportion"] = (
-        ((feature.size(2) - supervisions["num_frames"]) / feature.size(2)).sum().item()
+        ((feature.size(2) - supervisions["num_frames"]) / feature.size(2))
+        .sum()
+        .item()
     )
 
     return loss, info
@@ -468,7 +470,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             valid_info = compute_validation_loss(
diff --git a/egs/librispeech/ASR/transducer/beam_search.py b/egs/librispeech/ASR/transducer/beam_search.py
index b45b6a9d8..11032f31a 100644
--- a/egs/librispeech/ASR/transducer/beam_search.py
+++ b/egs/librispeech/ASR/transducer/beam_search.py
@@ -38,7 +38,9 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
     blank_id = model.decoder.blank_id
     device = model.device
 
-    sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(1, 1)
+    sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(
+        1, 1
+    )
     decoder_out, (h, c) = model.decoder(sos)
     T = encoder_out.size(1)
     t = 0
@@ -121,7 +123,9 @@ def beam_search(
     max_u = 20000  # terminate after this number of steps
     u = 0
 
-    cache: Dict[str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = {}
+    cache: Dict[
+        str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
+    ] = {}
 
     while t < T and u < max_u:
         # fmt: off
@@ -153,9 +157,9 @@ def beam_search(
             cached_key = "_".join(map(str, y_star.ys))
 
             if cached_key not in cache:
-                decoder_input = torch.tensor([y_star.ys[-1]], device=device).reshape(
-                    1, 1
-                )
+                decoder_input = torch.tensor(
+                    [y_star.ys[-1]], device=device
+                ).reshape(1, 1)
 
                 decoder_out, decoder_state = model.decoder(
                     decoder_input,
diff --git a/egs/librispeech/ASR/transducer/decode.py b/egs/librispeech/ASR/transducer/decode.py
index f30332cea..5f233df87 100755
--- a/egs/librispeech/ASR/transducer/decode.py
+++ b/egs/librispeech/ASR/transducer/decode.py
@@ -71,19 +71,16 @@ def get_parser():
         "--epoch",
         type=int,
         default=34,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=11,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -231,7 +228,9 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )
     hyps = []
     batch_size = encoder_out.size(0)
 
@@ -246,7 +245,9 @@ def decode_one_batch(
                 model=model, encoder_out=encoder_out_i, beam=params.beam_size
             )
         else:
-            raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
+            raise ValueError(
+                f"Unsupported decoding method: {params.decoding_method}"
+            )
         hyps.append(sp.decode(hyp).split())
 
     if params.decoding_method == "greedy_search":
@@ -317,7 +318,9 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -350,7 +353,8 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/librispeech/ASR/transducer/export.py b/egs/librispeech/ASR/transducer/export.py
index 4d9f937f5..5a5db30c4 100755
--- a/egs/librispeech/ASR/transducer/export.py
+++ b/egs/librispeech/ASR/transducer/export.py
@@ -67,20 +67,17 @@ def get_parser():
         "--epoch",
         type=int,
         default=34,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=11,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -241,7 +238,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer/pretrained.py b/egs/librispeech/ASR/transducer/pretrained.py
index 7aadfbcd1..1db2df648 100755
--- a/egs/librispeech/ASR/transducer/pretrained.py
+++ b/egs/librispeech/ASR/transducer/pretrained.py
@@ -60,11 +60,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -89,12 +87,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     parser.add_argument(
@@ -192,9 +188,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -252,7 +249,9 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -288,7 +287,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer/rnn.py b/egs/librispeech/ASR/transducer/rnn.py
index fe8732301..2a165b0c1 100644
--- a/egs/librispeech/ASR/transducer/rnn.py
+++ b/egs/librispeech/ASR/transducer/rnn.py
@@ -117,8 +117,12 @@ class LayerNormLSTMCell(nn.Module):
         )
 
         if bias:
-            self.bias_ih = nn.Parameter(torch.empty(4 * hidden_size, **factory_kwargs))
-            self.bias_hh = nn.Parameter(torch.empty(4 * hidden_size, **factory_kwargs))
+            self.bias_ih = nn.Parameter(
+                torch.empty(4 * hidden_size, **factory_kwargs)
+            )
+            self.bias_hh = nn.Parameter(
+                torch.empty(4 * hidden_size, **factory_kwargs)
+            )
         else:
             self.register_parameter("bias_ih", None)
             self.register_parameter("bias_hh", None)
@@ -344,7 +348,9 @@ class LayerNormLSTM(nn.Module):
             device=device,
             dtype=dtype,
         )
-        first_layer = LayerNormLSTMLayer(input_size=input_size, **factory_kwargs)
+        first_layer = LayerNormLSTMLayer(
+            input_size=input_size, **factory_kwargs
+        )
         layers = [first_layer]
         for i in range(1, num_layers):
             layers.append(
@@ -379,7 +385,9 @@ class LayerNormLSTM(nn.Module):
             - List[(next_h, next_c)] containing the hidden states for all layers
 
         """
-        output_states = torch.jit.annotate(List[Tuple[torch.Tensor, torch.Tensor]], [])
+        output_states = torch.jit.annotate(
+            List[Tuple[torch.Tensor, torch.Tensor]], []
+        )
         output = input
         for i, rnn_layer in enumerate(self.layers):
             state = states[i]
@@ -448,8 +456,12 @@ class LayerNormGRUCell(nn.Module):
         )
 
         if bias:
-            self.bias_ih = nn.Parameter(torch.empty(3 * hidden_size, **factory_kwargs))
-            self.bias_hh = nn.Parameter(torch.empty(3 * hidden_size, **factory_kwargs))
+            self.bias_ih = nn.Parameter(
+                torch.empty(3 * hidden_size, **factory_kwargs)
+            )
+            self.bias_hh = nn.Parameter(
+                torch.empty(3 * hidden_size, **factory_kwargs)
+            )
         else:
             self.register_parameter("bias_ih", None)
             self.register_parameter("bias_hh", None)
diff --git a/egs/librispeech/ASR/transducer/test_rnn.py b/egs/librispeech/ASR/transducer/test_rnn.py
index 74c94cc70..8591e2d8a 100755
--- a/egs/librispeech/ASR/transducer/test_rnn.py
+++ b/egs/librispeech/ASR/transducer/test_rnn.py
@@ -254,7 +254,9 @@ def test_layernorm_lstm_layer_with_projection_forward(device="cpu"):
         for name, self_param in self_layer.cell.named_parameters():
             getattr(torch_layer, f"{name}_l0").copy_(self_param)
 
-    torch_y, (torch_h, torch_c) = torch_layer(x_clone, (h.unsqueeze(0), c.unsqueeze(0)))
+    torch_y, (torch_h, torch_c) = torch_layer(
+        x_clone, (h.unsqueeze(0), c.unsqueeze(0))
+    )
     assert_allclose(self_y, torch_y)
     assert_allclose(self_h, torch_h)
     assert_allclose(self_c, torch_c)
@@ -301,7 +303,9 @@ def test_layernorm_lstm_layer_forward(device="cpu"):
         for name, self_param in self_layer.cell.named_parameters():
             getattr(torch_layer, f"{name}_l0").copy_(self_param)
 
-    torch_y, (torch_h, torch_c) = torch_layer(x_clone, (h.unsqueeze(0), c.unsqueeze(0)))
+    torch_y, (torch_h, torch_c) = torch_layer(
+        x_clone, (h.unsqueeze(0), c.unsqueeze(0))
+    )
     assert_allclose(self_y, torch_y)
     assert_allclose(self_h, torch_h)
     assert_allclose(self_c, torch_c)
@@ -590,7 +594,9 @@ def test_layernorm_gru_cell_forward(device="cpu"):
 
     assert_allclose(self_h, torch_h, atol=1e-5)
 
-    (self_h.reshape(-1) * torch.arange(self_h.numel(), device=device)).sum().backward()
+    (
+        self_h.reshape(-1) * torch.arange(self_h.numel(), device=device)
+    ).sum().backward()
     (
         torch_h.reshape(-1) * torch.arange(torch_h.numel(), device=device)
     ).sum().backward()
@@ -712,7 +718,9 @@ def test_layernorm_gru_forward(device="cpu"):
     T = torch.randint(low=2, high=100, size=(1,))
 
     x = torch.rand(N, T, input_size, device=device).requires_grad_()
-    states = [torch.rand(N, hidden_size, device=device) for _ in range(num_layers)]
+    states = [
+        torch.rand(N, hidden_size, device=device) for _ in range(num_layers)
+    ]
 
     x_clone = x.detach().clone().requires_grad_()
 
diff --git a/egs/librispeech/ASR/transducer/train.py b/egs/librispeech/ASR/transducer/train.py
index 674ea10a6..1dd65eddb 100755
--- a/egs/librispeech/ASR/transducer/train.py
+++ b/egs/librispeech/ASR/transducer/train.py
@@ -396,7 +396,9 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -518,7 +520,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -655,7 +659,9 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
+            tb_writer.add_scalar(
+                "train/learning_rate", cur_lr, params.batch_idx_train
+            )
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/librispeech/ASR/transducer_lstm/beam_search.py b/egs/librispeech/ASR/transducer_lstm/beam_search.py
index 5342c3e8c..3531a9633 100644
--- a/egs/librispeech/ASR/transducer_lstm/beam_search.py
+++ b/egs/librispeech/ASR/transducer_lstm/beam_search.py
@@ -38,7 +38,9 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
     blank_id = model.decoder.blank_id
     device = model.device
 
-    sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(1, 1)
+    sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(
+        1, 1
+    )
     decoder_out, (h, c) = model.decoder(sos)
     T = encoder_out.size(1)
     t = 0
@@ -122,7 +124,9 @@ def beam_search(
     max_u = 20000  # terminate after this number of steps
     u = 0
 
-    cache: Dict[str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = {}
+    cache: Dict[
+        str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
+    ] = {}
 
     while t < T and u < max_u:
         # fmt: off
@@ -154,9 +158,9 @@ def beam_search(
             cached_key = "_".join(map(str, y_star.ys))
 
             if cached_key not in cache:
-                decoder_input = torch.tensor([y_star.ys[-1]], device=device).reshape(
-                    1, 1
-                )
+                decoder_input = torch.tensor(
+                    [y_star.ys[-1]], device=device
+                ).reshape(1, 1)
 
                 decoder_out, decoder_state = model.decoder(
                     decoder_input,
diff --git a/egs/librispeech/ASR/transducer_lstm/decode.py b/egs/librispeech/ASR/transducer_lstm/decode.py
index 61b9de504..604235e2a 100755
--- a/egs/librispeech/ASR/transducer_lstm/decode.py
+++ b/egs/librispeech/ASR/transducer_lstm/decode.py
@@ -71,19 +71,16 @@ def get_parser():
         "--epoch",
         type=int,
         default=77,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=55,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -228,7 +225,9 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )
     hyps = []
     batch_size = encoder_out.size(0)
 
@@ -243,7 +242,9 @@ def decode_one_batch(
                 model=model, encoder_out=encoder_out_i, beam=params.beam_size
             )
         else:
-            raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
+            raise ValueError(
+                f"Unsupported decoding method: {params.decoding_method}"
+            )
         hyps.append(sp.decode(hyp).split())
 
     if params.decoding_method == "greedy_search":
@@ -314,7 +315,9 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -347,7 +350,8 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/librispeech/ASR/transducer_lstm/encoder.py b/egs/librispeech/ASR/transducer_lstm/encoder.py
index 038d80077..3dc992dd2 100644
--- a/egs/librispeech/ASR/transducer_lstm/encoder.py
+++ b/egs/librispeech/ASR/transducer_lstm/encoder.py
@@ -48,7 +48,9 @@ class LstmEncoder(EncoderInterface):
         if vgg_frontend:
             self.encoder_embed = VggSubsampling(num_features, real_hidden_size)
         else:
-            self.encoder_embed = Conv2dSubsampling(num_features, real_hidden_size)
+            self.encoder_embed = Conv2dSubsampling(
+                num_features, real_hidden_size
+            )
 
         self.rnn = nn.LSTM(
             input_size=hidden_size,
diff --git a/egs/librispeech/ASR/transducer_lstm/train.py b/egs/librispeech/ASR/transducer_lstm/train.py
index 57bda63fd..cdb801e79 100755
--- a/egs/librispeech/ASR/transducer_lstm/train.py
+++ b/egs/librispeech/ASR/transducer_lstm/train.py
@@ -400,7 +400,9 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -522,7 +524,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -661,7 +665,9 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
+            tb_writer.add_scalar(
+                "train/learning_rate", cur_lr, params.batch_idx_train
+            )
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/librispeech/ASR/transducer_stateless/alignment.py b/egs/librispeech/ASR/transducer_stateless/alignment.py
index 65f2c58d8..f143611ea 100644
--- a/egs/librispeech/ASR/transducer_stateless/alignment.py
+++ b/egs/librispeech/ASR/transducer_stateless/alignment.py
@@ -193,7 +193,9 @@ def force_alignment(
         decoder_out = model.decoder(decoder_input, need_pad=False)
         # decoder_output is of shape (num_active_items, 1, decoder_output_dim)
 
-        current_encoder_out = current_encoder_out.expand(decoder_out.size(0), 1, -1)
+        current_encoder_out = current_encoder_out.expand(
+            decoder_out.size(0), 1, -1
+        )
 
         logits = model.joiner(
             current_encoder_out,
diff --git a/egs/librispeech/ASR/transducer_stateless/beam_search.py b/egs/librispeech/ASR/transducer_stateless/beam_search.py
index 1d79eef9d..ea985f30d 100644
--- a/egs/librispeech/ASR/transducer_stateless/beam_search.py
+++ b/egs/librispeech/ASR/transducer_stateless/beam_search.py
@@ -316,9 +316,9 @@ def greedy_search(
         y = logits.argmax().item()
         if y != blank_id:
             hyp.append(y)
-            decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape(
-                1, context_size
-            )
+            decoder_input = torch.tensor(
+                [hyp[-context_size:]], device=device
+            ).reshape(1, context_size)
 
             decoder_out = model.decoder(decoder_input, need_pad=False)
 
@@ -478,7 +478,9 @@ class HypothesisList(object):
         key = hyp.key
         if key in self:
             old_hyp = self._data[key]  # shallow copy
-            torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob)
+            torch.logaddexp(
+                old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
+            )
         else:
             self._data[key] = hyp
 
@@ -494,7 +496,9 @@ class HypothesisList(object):
           Return the hypothesis that has the largest `log_prob`.
         """
         if length_norm:
-            return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys))
+            return max(
+                self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)
+            )
         else:
             return max(self._data.values(), key=lambda hyp: hyp.log_prob)
 
@@ -782,7 +786,9 @@ def modified_beam_search(
         log_probs_shape = k2.ragged.create_ragged_shape2(
             row_splits=row_splits, cached_tot_size=log_probs.numel()
         )
-        ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)
+        ragged_log_probs = k2.RaggedTensor(
+            shape=log_probs_shape, value=log_probs
+        )
 
         for i in range(batch_size):
             topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
@@ -881,7 +887,9 @@ def _deprecated_modified_beam_search(
         decoder_out = model.decoder(decoder_input, need_pad=False)
         # decoder_output is of shape (num_hyps, 1, decoder_output_dim)
 
-        current_encoder_out = current_encoder_out.expand(decoder_out.size(0), 1, -1)
+        current_encoder_out = current_encoder_out.expand(
+            decoder_out.size(0), 1, -1
+        )
 
         logits = model.joiner(
             current_encoder_out,
@@ -951,9 +959,9 @@ def beam_search(
 
     device = model.device
 
-    decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape(
-        1, context_size
-    )
+    decoder_input = torch.tensor(
+        [blank_id] * context_size, device=device
+    ).reshape(1, context_size)
 
     decoder_out = model.decoder(decoder_input, need_pad=False)
 
diff --git a/egs/librispeech/ASR/transducer_stateless/compute_ali.py b/egs/librispeech/ASR/transducer_stateless/compute_ali.py
index 89992856d..48769e9d1 100755
--- a/egs/librispeech/ASR/transducer_stateless/compute_ali.py
+++ b/egs/librispeech/ASR/transducer_stateless/compute_ali.py
@@ -54,19 +54,16 @@ def get_parser():
         "--epoch",
         type=int,
         default=34,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=20,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -127,7 +124,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     return parser
@@ -164,7 +162,9 @@ def compute_alignments(
 
         feature_lens = supervisions["num_frames"].to(device)
 
-        encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+        encoder_out, encoder_out_lens = model.encoder(
+            x=feature, x_lens=feature_lens
+        )
 
         batch_size = encoder_out.size(0)
 
@@ -204,7 +204,9 @@ def compute_alignments(
         if batch_idx % 2 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
 
     return CutSet.from_cuts(cuts)
 
diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py
index d279eae85..cde52c9fc 100644
--- a/egs/librispeech/ASR/transducer_stateless/conformer.py
+++ b/egs/librispeech/ASR/transducer_stateless/conformer.py
@@ -209,7 +209,10 @@ class Conformer(Transformer):
 
           NOTE: the returned tensors are on the given device.
         """
-        if len(self._init_state) == 2 and self._init_state[0].size(1) == left_context:
+        if (
+            len(self._init_state) == 2
+            and self._init_state[0].size(1) == left_context
+        ):
             # Note: It is OK to share the init state as it is
             # not going to be modified by the model
             return self._init_state
@@ -418,7 +421,9 @@ class ConformerEncoderLayer(nn.Module):
         causal: bool = False,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
-        self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
+        self.self_attn = RelPositionMultiheadAttention(
+            d_model, nhead, dropout=0.0
+        )
 
         self.feed_forward = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
@@ -434,16 +439,22 @@ class ConformerEncoderLayer(nn.Module):
             nn.Linear(dim_feedforward, d_model),
         )
 
-        self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal)
+        self.conv_module = ConvolutionModule(
+            d_model, cnn_module_kernel, causal=causal
+        )
 
-        self.norm_ff_macaron = nn.LayerNorm(d_model)  # for the macaron style FNN module
+        self.norm_ff_macaron = nn.LayerNorm(
+            d_model
+        )  # for the macaron style FNN module
         self.norm_ff = nn.LayerNorm(d_model)  # for the FNN module
         self.norm_mha = nn.LayerNorm(d_model)  # for the MHA module
 
         self.ff_scale = 0.5
 
         self.norm_conv = nn.LayerNorm(d_model)  # for the CNN module
-        self.norm_final = nn.LayerNorm(d_model)  # for the final output of the block
+        self.norm_final = nn.LayerNorm(
+            d_model
+        )  # for the final output of the block
 
         self.dropout = nn.Dropout(dropout)
 
@@ -475,7 +486,9 @@ class ConformerEncoderLayer(nn.Module):
         residual = src
         if self.normalize_before:
             src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
+        src = residual + self.ff_scale * self.dropout(
+            self.feed_forward_macaron(src)
+        )
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)
 
@@ -501,7 +514,9 @@ class ConformerEncoderLayer(nn.Module):
         if self.normalize_before:
             src = self.norm_conv(src)
 
-        src, _ = self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
+        src, _ = self.conv_module(
+            src, src_key_padding_mask=src_key_padding_mask
+        )
         src = residual + self.dropout(src)
 
         if not self.normalize_before:
@@ -566,7 +581,9 @@ class ConformerEncoderLayer(nn.Module):
         residual = src
         if self.normalize_before:
             src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
+        src = residual + self.ff_scale * self.dropout(
+            self.feed_forward_macaron(src)
+        )
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)
 
@@ -608,7 +625,9 @@ class ConformerEncoderLayer(nn.Module):
         if self.normalize_before:
             src = self.norm_conv(src)
 
-        src, conv_cache = self.conv_module(src, states[1], right_context=right_context)
+        src, conv_cache = self.conv_module(
+            src, states[1], right_context=right_context
+        )
         states[1] = conv_cache
         src = residual + self.dropout(src)
 
@@ -760,7 +779,9 @@ class RelPositionalEncoding(torch.nn.Module):
 
     """
 
-    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
+    def __init__(
+        self, d_model: int, dropout_rate: float, max_len: int = 5000
+    ) -> None:
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
@@ -777,7 +798,9 @@ class RelPositionalEncoding(torch.nn.Module):
             # the length of self.pe is 2 * input_len - 1
             if self.pe.size(1) >= x_size_1 * 2 - 1:
                 # Note: TorchScript doesn't implement operator== for torch.Device
-                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(
+                    x.device
+                ):
                     self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                 return
         # Suppose `i` means to the position of query vector and `j` means the
@@ -803,7 +826,9 @@ class RelPositionalEncoding(torch.nn.Module):
         pe = torch.cat([pe_positive, pe_negative], dim=1)
         self.pe = pe.to(device=x.device, dtype=x.dtype)
 
-    def forward(self, x: torch.Tensor, left_context: int = 0) -> Tuple[Tensor, Tensor]:
+    def forward(
+        self, x: torch.Tensor, left_context: int = 0
+    ) -> Tuple[Tensor, Tensor]:
         """Add positional encoding.
 
         Args:
@@ -1067,9 +1092,9 @@ class RelPositionMultiheadAttention(nn.Module):
 
         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
-                3, dim=-1
-            )
+            q, k, v = nn.functional.linear(
+                query, in_proj_weight, in_proj_bias
+            ).chunk(3, dim=-1)
 
         elif torch.equal(key, value):
             # encoder-decoder attention
@@ -1138,25 +1163,33 @@ class RelPositionMultiheadAttention(nn.Module):
             if attn_mask.dim() == 2:
                 attn_mask = attn_mask.unsqueeze(0)
                 if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
-                    raise RuntimeError("The size of the 2D attn_mask is not correct.")
+                    raise RuntimeError(
+                        "The size of the 2D attn_mask is not correct."
+                    )
             elif attn_mask.dim() == 3:
                 if list(attn_mask.size()) != [
                     bsz * num_heads,
                     query.size(0),
                     key.size(0),
                 ]:
-                    raise RuntimeError("The size of the 3D attn_mask is not correct.")
+                    raise RuntimeError(
+                        "The size of the 3D attn_mask is not correct."
+                    )
             else:
                 raise RuntimeError(
-                    "attn_mask's dimension {} is not supported".format(attn_mask.dim())
+                    "attn_mask's dimension {} is not supported".format(
+                        attn_mask.dim()
+                    )
                 )
             # attn_mask's dim is 3 now.
 
         # convert ByteTensor key_padding_mask to bool
-        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
+        if (
+            key_padding_mask is not None
+            and key_padding_mask.dtype == torch.uint8
+        ):
             warnings.warn(
-                "Byte tensor for key_padding_mask is deprecated. Use bool tensor"
-                " instead."
+                "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
             )
             key_padding_mask = key_padding_mask.to(torch.bool)
 
@@ -1195,10 +1228,14 @@ class RelPositionMultiheadAttention(nn.Module):
         # first compute matrix a and matrix c
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
-        matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(
+            q_with_bias_u, k
+        )  # (batch, head, time1, time2)
 
         # compute matrix b and matrix d
-        matrix_bd = torch.matmul(q_with_bias_v, p)  # (batch, head, time1, 2*time1-1)
+        matrix_bd = torch.matmul(
+            q_with_bias_v, p
+        )  # (batch, head, time1, 2*time1-1)
 
         matrix_bd = self.rel_shift(matrix_bd, left_context=left_context)
 
@@ -1206,7 +1243,9 @@ class RelPositionMultiheadAttention(nn.Module):
             matrix_ac + matrix_bd
         ) * scaling  # (batch, head, time1, time2)
 
-        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
+        attn_output_weights = attn_output_weights.view(
+            bsz * num_heads, tgt_len, -1
+        )
 
         assert list(attn_output_weights.size()) == [
             bsz * num_heads,
@@ -1251,7 +1290,9 @@ class RelPositionMultiheadAttention(nn.Module):
             attn_output_weights = attn_output_weights.view(
                 bsz, num_heads, tgt_len, src_len
             )
-            attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0)
+            attn_output_weights = attn_output_weights.masked_fill(
+                combined_mask, 0.0
+            )
             attn_output_weights = attn_output_weights.view(
                 bsz * num_heads, tgt_len, src_len
             )
@@ -1263,9 +1304,13 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output = torch.bmm(attn_output_weights, v)
         assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
         attn_output = (
-            attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+            attn_output.transpose(0, 1)
+            .contiguous()
+            .view(tgt_len, bsz, embed_dim)
+        )
+        attn_output = nn.functional.linear(
+            attn_output, out_proj_weight, out_proj_bias
         )
-        attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
 
         if need_weights:
             # average attention weights over heads
@@ -1373,12 +1418,16 @@ class ConvolutionModule(nn.Module):
                 # manualy padding self.lorder zeros to the left
                 x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
             else:
-                assert not self.training, "Cache should be None in training time"
+                assert (
+                    not self.training
+                ), "Cache should be None in training time"
                 assert cache.size(0) == self.lorder
                 x = torch.cat([cache.permute(1, 2, 0), x], dim=2)
                 if right_context > 0:
                     cache = x.permute(2, 0, 1)[
-                        -(self.lorder + right_context) : (-right_context),  # noqa
+                        -(self.lorder + right_context) : (  # noqa
+                            -right_context
+                        ),
                         ...,
                     ]
                 else:
diff --git a/egs/librispeech/ASR/transducer_stateless/decode.py b/egs/librispeech/ASR/transducer_stateless/decode.py
index 314f49154..74bba9cad 100755
--- a/egs/librispeech/ASR/transducer_stateless/decode.py
+++ b/egs/librispeech/ASR/transducer_stateless/decode.py
@@ -94,19 +94,16 @@ def get_parser():
         "--epoch",
         type=int,
         default=29,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=13,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -174,7 +171,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -232,7 +230,9 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )
 
     hyps = []
 
@@ -248,7 +248,10 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -294,7 +297,11 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}"
+            ): hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -367,7 +374,9 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -400,7 +409,8 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -440,7 +450,9 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        params.suffix += (
+            f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        )
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
diff --git a/egs/librispeech/ASR/transducer_stateless/decoder.py b/egs/librispeech/ASR/transducer_stateless/decoder.py
index a182d91e2..fbc2373a9 100644
--- a/egs/librispeech/ASR/transducer_stateless/decoder.py
+++ b/egs/librispeech/ASR/transducer_stateless/decoder.py
@@ -87,7 +87,9 @@ class Decoder(nn.Module):
         if self.context_size > 1:
             embedding_out = embedding_out.permute(0, 2, 1)
             if need_pad is True:
-                embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
+                embedding_out = F.pad(
+                    embedding_out, pad=(self.context_size - 1, 0)
+                )
             else:
                 # During inference time, there is no need to do extra padding
                 # as we only need one output
diff --git a/egs/librispeech/ASR/transducer_stateless/export.py b/egs/librispeech/ASR/transducer_stateless/export.py
index 7c10b4348..8bd0bdea1 100755
--- a/egs/librispeech/ASR/transducer_stateless/export.py
+++ b/egs/librispeech/ASR/transducer_stateless/export.py
@@ -68,20 +68,17 @@ def get_parser():
         "--epoch",
         type=int,
         default=20,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=10,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -112,7 +109,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     return parser
@@ -246,7 +244,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless/joiner.py b/egs/librispeech/ASR/transducer_stateless/joiner.py
index e1625992d..93cccbd8c 100644
--- a/egs/librispeech/ASR/transducer_stateless/joiner.py
+++ b/egs/librispeech/ASR/transducer_stateless/joiner.py
@@ -60,9 +60,13 @@ class Joiner(nn.Module):
         encoder_out_len: List[int] = encoder_out_len.tolist()
         decoder_out_len: List[int] = decoder_out_len.tolist()
 
-        encoder_out_list = [encoder_out[i, : encoder_out_len[i], :] for i in range(N)]
+        encoder_out_list = [
+            encoder_out[i, : encoder_out_len[i], :] for i in range(N)
+        ]
 
-        decoder_out_list = [decoder_out[i, : decoder_out_len[i], :] for i in range(N)]
+        decoder_out_list = [
+            decoder_out[i, : decoder_out_len[i], :] for i in range(N)
+        ]
 
         x = [
             e.unsqueeze(1) + d.unsqueeze(0)
diff --git a/egs/librispeech/ASR/transducer_stateless/pretrained.py b/egs/librispeech/ASR/transducer_stateless/pretrained.py
index bd7eeff28..b64521801 100755
--- a/egs/librispeech/ASR/transducer_stateless/pretrained.py
+++ b/egs/librispeech/ASR/transducer_stateless/pretrained.py
@@ -90,11 +90,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -119,12 +117,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     parser.add_argument(
@@ -171,7 +167,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -200,9 +197,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -261,7 +259,9 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -334,7 +334,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless/test_compute_ali.py b/egs/librispeech/ASR/transducer_stateless/test_compute_ali.py
index 9af46846a..b00fc34f1 100755
--- a/egs/librispeech/ASR/transducer_stateless/test_compute_ali.py
+++ b/egs/librispeech/ASR/transducer_stateless/test_compute_ali.py
@@ -140,13 +140,16 @@ def main():
                 token_alignment[i, : token_alignment_length[i]].tolist(), sp=sp
             )
             word_starting_time = [
-                "{:.2f}".format(i * frame_shift_in_second) for i in word_starting_frames
+                "{:.2f}".format(i * frame_shift_in_second)
+                for i in word_starting_frames
             ]
 
             words = supervisions["text"][i].split()
 
             assert len(word_starting_frames) == len(words)
-            word_starting_time_dict[cuts[i].id] = list(zip(words, word_starting_time))
+            word_starting_time_dict[cuts[i].id] = list(
+                zip(words, word_starting_time)
+            )
 
         # This is a demo script and we exit here after processing
         # one batch.
@@ -157,7 +160,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless/test_conformer.py b/egs/librispeech/ASR/transducer_stateless/test_conformer.py
index 65b08d425..d1350c8ab 100755
--- a/egs/librispeech/ASR/transducer_stateless/test_conformer.py
+++ b/egs/librispeech/ASR/transducer_stateless/test_conformer.py
@@ -29,7 +29,9 @@ from conformer import Conformer
 
 def test_conformer():
     feature_dim = 50
-    c = Conformer(num_features=feature_dim, output_dim=256, d_model=128, nhead=4)
+    c = Conformer(
+        num_features=feature_dim, output_dim=256, d_model=128, nhead=4
+    )
     batch_size = 5
     seq_len = 20
     # Just make sure the forward pass runs.
diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py
index bcb883fa5..ae93f3348 100755
--- a/egs/librispeech/ASR/transducer_stateless/train.py
+++ b/egs/librispeech/ASR/transducer_stateless/train.py
@@ -136,7 +136,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -421,7 +422,9 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -542,7 +545,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -659,9 +664,13 @@ def run(rank, world_size, args):
         num_removed = num_in_total - num_left
         removed_percent = num_removed / num_in_total * 100
 
-        logging.info(f"Before removing short and long utterances: {num_in_total}")
+        logging.info(
+            f"Before removing short and long utterances: {num_in_total}"
+        )
         logging.info(f"After removing short and long utterances: {num_left}")
-        logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
+        logging.info(
+            f"Removed {num_removed} utterances ({removed_percent:.5f}%)"
+        )
     except TypeError as e:
         # You can ignore this error as previous versions of Lhotse work fine
         # for the above code. In recent versions of Lhotse, it uses
@@ -689,7 +698,9 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
+            tb_writer.add_scalar(
+                "train/learning_rate", cur_lr, params.batch_idx_train
+            )
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/librispeech/ASR/transducer_stateless/transformer.py b/egs/librispeech/ASR/transducer_stateless/transformer.py
index b3ff153c1..e851dcc32 100644
--- a/egs/librispeech/ASR/transducer_stateless/transformer.py
+++ b/egs/librispeech/ASR/transducer_stateless/transformer.py
@@ -250,7 +250,9 @@ def _get_activation_fn(activation: str):
     elif activation == "gelu":
         return nn.functional.gelu
 
-    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
+    raise RuntimeError(
+        "activation should be relu/gelu, not {}".format(activation)
+    )
 
 
 class PositionalEncoding(nn.Module):
diff --git a/egs/librispeech/ASR/transducer_stateless2/decode.py b/egs/librispeech/ASR/transducer_stateless2/decode.py
index 86ef9e5b6..ac2807241 100755
--- a/egs/librispeech/ASR/transducer_stateless2/decode.py
+++ b/egs/librispeech/ASR/transducer_stateless2/decode.py
@@ -94,19 +94,16 @@ def get_parser():
         "--epoch",
         type=int,
         default=29,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=13,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -174,7 +171,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -232,7 +230,9 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )
 
     hyps = []
 
@@ -248,7 +248,10 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -294,7 +297,11 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}"
+            ): hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -367,7 +374,9 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -400,7 +409,8 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -440,7 +450,9 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        params.suffix += (
+            f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        )
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
diff --git a/egs/librispeech/ASR/transducer_stateless2/export.py b/egs/librispeech/ASR/transducer_stateless2/export.py
index d95eeb1f4..57c1a6094 100755
--- a/egs/librispeech/ASR/transducer_stateless2/export.py
+++ b/egs/librispeech/ASR/transducer_stateless2/export.py
@@ -63,20 +63,17 @@ def get_parser():
         "--epoch",
         type=int,
         default=20,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=10,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -107,7 +104,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     return parser
@@ -178,7 +176,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless2/pretrained.py b/egs/librispeech/ASR/transducer_stateless2/pretrained.py
index 793931e3b..292f77f03 100755
--- a/egs/librispeech/ASR/transducer_stateless2/pretrained.py
+++ b/egs/librispeech/ASR/transducer_stateless2/pretrained.py
@@ -90,11 +90,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -119,12 +117,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     parser.add_argument(
@@ -171,7 +167,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -200,9 +197,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -261,7 +259,9 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -334,7 +334,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless2/train.py b/egs/librispeech/ASR/transducer_stateless2/train.py
index 68e247f23..ea15c9040 100755
--- a/egs/librispeech/ASR/transducer_stateless2/train.py
+++ b/egs/librispeech/ASR/transducer_stateless2/train.py
@@ -136,7 +136,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -409,7 +410,9 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -530,7 +533,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -647,9 +652,13 @@ def run(rank, world_size, args):
         num_removed = num_in_total - num_left
         removed_percent = num_removed / num_in_total * 100
 
-        logging.info(f"Before removing short and long utterances: {num_in_total}")
+        logging.info(
+            f"Before removing short and long utterances: {num_in_total}"
+        )
         logging.info(f"After removing short and long utterances: {num_left}")
-        logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
+        logging.info(
+            f"Removed {num_removed} utterances ({removed_percent:.5f}%)"
+        )
     except TypeError as e:
         # You can ignore this error as previous versions of Lhotse work fine
         # for the above code. In recent versions of Lhotse, it uses
@@ -677,7 +686,9 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
+            tb_writer.add_scalar(
+                "train/learning_rate", cur_lr, params.batch_idx_train
+            )
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py
index 22b6ab911..d596e05cb 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py
@@ -95,19 +95,16 @@ def get_parser():
         "--epoch",
         type=int,
         default=29,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=13,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -175,7 +172,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -233,7 +231,9 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )
 
     hyps = []
 
@@ -249,7 +249,10 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -295,7 +298,11 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}"
+            ): hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -368,7 +375,9 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -401,7 +410,8 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -441,7 +451,9 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        params.suffix += (
+            f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        )
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py
index fad9a6977..b6b69d932 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py
@@ -69,20 +69,17 @@ def get_parser():
         "--epoch",
         type=int,
         default=20,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=10,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -113,7 +110,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     return parser
@@ -249,7 +247,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py
index efd257b5d..f297fa2b2 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py
@@ -90,11 +90,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -119,12 +117,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     parser.add_argument(
@@ -171,7 +167,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -200,9 +197,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -261,7 +259,9 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -334,7 +334,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/test_asr_datamodule.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/test_asr_datamodule.py
index 1e1188ca6..ef51a7811 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/test_asr_datamodule.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/test_asr_datamodule.py
@@ -41,7 +41,9 @@ def test_dataset():
     print(args)
 
     if args.enable_musan:
-        cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz")
+        cuts_musan = load_manifest(
+            Path(args.manifest_dir) / "musan_cuts.jsonl.gz"
+        )
     else:
         cuts_musan = None
 
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py
index 88987d91c..27912738c 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py
@@ -114,7 +114,8 @@ def get_parser():
         "--full-libri",
         type=str2bool,
         default=True,
-        help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.",
+        help="When enabled, use 960h LibriSpeech. "
+        "Otherwise, use 100h subset.",
     )
 
     parser.add_argument(
@@ -169,7 +170,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -467,7 +469,9 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -631,7 +635,9 @@ def train_one_epoch(
                     f"train/current_{prefix}_",
                     params.batch_idx_train,
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
                 libri_tot_loss.write_summary(
                     tb_writer, "train/libri_tot_", params.batch_idx_train
                 )
@@ -778,7 +784,9 @@ def run(rank, world_size, args):
     train_giga_cuts = train_giga_cuts.repeat(times=None)
 
     if args.enable_musan:
-        cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz")
+        cuts_musan = load_manifest(
+            Path(args.manifest_dir) / "musan_cuts.jsonl.gz"
+        )
     else:
         cuts_musan = None
 
@@ -817,7 +825,9 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
+            tb_writer.add_scalar(
+                "train/learning_rate", cur_lr, params.batch_idx_train
+            )
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/ptb/LM/local/sort_lm_training_data.py b/egs/ptb/LM/local/sort_lm_training_data.py
index bed3856e4..af54dbd07 100755
--- a/egs/ptb/LM/local/sort_lm_training_data.py
+++ b/egs/ptb/LM/local/sort_lm_training_data.py
@@ -135,7 +135,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/ptb/LM/local/test_prepare_lm_training_data.py b/egs/ptb/LM/local/test_prepare_lm_training_data.py
index 3790045fa..877720e7b 100755
--- a/egs/ptb/LM/local/test_prepare_lm_training_data.py
+++ b/egs/ptb/LM/local/test_prepare_lm_training_data.py
@@ -54,7 +54,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/spgispeech/ASR/local/compute_fbank_musan.py b/egs/spgispeech/ASR/local/compute_fbank_musan.py
index 9bea28a41..6cb8b65ae 100755
--- a/egs/spgispeech/ASR/local/compute_fbank_musan.py
+++ b/egs/spgispeech/ASR/local/compute_fbank_musan.py
@@ -87,7 +87,9 @@ def compute_fbank_musan():
     # create chunks of Musan with duration 5 - 10 seconds
     musan_cuts = (
         CutSet.from_manifests(
-            recordings=combine(part["recordings"] for part in manifests.values())
+            recordings=combine(
+                part["recordings"] for part in manifests.values()
+            )
         )
         .cut_into_windows(10.0)
         .filter(lambda c: c.duration > 5)
@@ -106,6 +108,8 @@ def compute_fbank_musan():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
     logging.basicConfig(format=formatter, level=logging.INFO)
     compute_fbank_musan()
diff --git a/egs/spgispeech/ASR/local/compute_fbank_spgispeech.py b/egs/spgispeech/ASR/local/compute_fbank_spgispeech.py
index 20ff6d7ab..8116e7605 100755
--- a/egs/spgispeech/ASR/local/compute_fbank_spgispeech.py
+++ b/egs/spgispeech/ASR/local/compute_fbank_spgispeech.py
@@ -103,7 +103,11 @@ def compute_fbank_spgispeech(args):
             chunk_size=chunk_size,
         )
         start = args.start
-        stop = min(args.stop, args.num_splits) if args.stop > 0 else args.num_splits
+        stop = (
+            min(args.stop, args.num_splits)
+            if args.stop > 0
+            else args.num_splits
+        )
         num_digits = len(str(args.num_splits))
         for i in range(start, stop):
             idx = f"{i + 1}".zfill(num_digits)
@@ -125,7 +129,9 @@ def compute_fbank_spgispeech(args):
                 logging.info(f"{partition} already exists - skipping.")
                 continue
             logging.info(f"Processing {partition}")
-            cut_set = load_manifest_lazy(src_dir / f"cuts_{partition}_raw.jsonl.gz")
+            cut_set = load_manifest_lazy(
+                src_dir / f"cuts_{partition}_raw.jsonl.gz"
+            )
             cut_set = cut_set.compute_and_store_features_batch(
                 extractor=extractor,
                 storage_path=output_dir / f"feats_{partition}",
@@ -138,7 +144,9 @@ def compute_fbank_spgispeech(args):
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
     logging.basicConfig(format=formatter, level=logging.INFO)
 
     args = get_args()
diff --git a/egs/spgispeech/ASR/local/prepare_splits.py b/egs/spgispeech/ASR/local/prepare_splits.py
index 508d4acd8..8c8f1c133 100755
--- a/egs/spgispeech/ASR/local/prepare_splits.py
+++ b/egs/spgispeech/ASR/local/prepare_splits.py
@@ -55,7 +55,9 @@ def split_spgispeech_train():
 
     # Add speed perturbation
     train_cuts = (
-        train_cuts + train_cuts.perturb_speed(0.9) + train_cuts.perturb_speed(1.1)
+        train_cuts
+        + train_cuts.perturb_speed(0.9)
+        + train_cuts.perturb_speed(1.1)
     )
 
     # Write the manifests to disk.
@@ -71,7 +73,9 @@ def split_spgispeech_train():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
     logging.basicConfig(format=formatter, level=logging.INFO)
 
     split_spgispeech_train()
diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
index 83f95d123..f165f6e60 100644
--- a/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
+++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
@@ -70,12 +70,10 @@ class SPGISpeechAsrDataModule:
     def add_arguments(cls, parser: argparse.ArgumentParser):
         group = parser.add_argument_group(
             title="ASR data related options",
-            description=(
-                "These options are used for the preparation of "
-                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-                "effective batch sizes, sampling strategies, applied data "
-                "augmentations, etc."
-            ),
+            description="These options are used for the preparation of "
+            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+            "effective batch sizes, sampling strategies, applied data "
+            "augmentations, etc.",
         )
         group.add_argument(
             "--manifest-dir",
@@ -87,81 +85,67 @@ class SPGISpeechAsrDataModule:
             "--enable-musan",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, select noise from MUSAN and mix it "
-                "with training dataset. "
-            ),
+            help="When enabled, select noise from MUSAN and mix it "
+            "with training dataset. ",
         )
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, utterances (cuts) will be concatenated "
-                "to minimize the amount of padding."
-            ),
+            help="When enabled, utterances (cuts) will be concatenated "
+            "to minimize the amount of padding.",
         )
         group.add_argument(
             "--duration-factor",
             type=float,
             default=1.0,
-            help=(
-                "Determines the maximum duration of a concatenated cut "
-                "relative to the duration of the longest cut in a batch."
-            ),
+            help="Determines the maximum duration of a concatenated cut "
+            "relative to the duration of the longest cut in a batch.",
         )
         group.add_argument(
             "--gap",
             type=float,
             default=1.0,
-            help=(
-                "The amount of padding (in seconds) inserted between "
-                "concatenated cuts. This padding is filled with noise when "
-                "noise augmentation is used."
-            ),
+            help="The amount of padding (in seconds) inserted between "
+            "concatenated cuts. This padding is filled with noise when "
+            "noise augmentation is used.",
         )
         group.add_argument(
             "--max-duration",
             type=int,
             default=100.0,
-            help=(
-                "Maximum pooled recordings duration (seconds) in a "
-                "single batch. You can reduce it if it causes CUDA OOM."
-            ),
+            help="Maximum pooled recordings duration (seconds) in a "
+            "single batch. You can reduce it if it causes CUDA OOM.",
         )
         group.add_argument(
             "--num-buckets",
             type=int,
             default=30,
-            help=(
-                "The number of buckets for the BucketingSampler"
-                "(you might want to increase it for larger datasets)."
-            ),
+            help="The number of buckets for the BucketingSampler"
+            "(you might want to increase it for larger datasets).",
         )
         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, use on-the-fly cut mixing and feature "
-                "extraction. Will drop existing precomputed feature manifests "
-                "if available."
-            ),
+            help="When enabled, use on-the-fly cut mixing and feature "
+            "extraction. Will drop existing precomputed feature manifests "
+            "if available.",
         )
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled (=default), the examples will be shuffled for each epoch."
-            ),
+            help="When enabled (=default), the examples will be "
+            "shuffled for each epoch.",
         )
 
         group.add_argument(
             "--num-workers",
             type=int,
             default=8,
-            help="The number of training dataloader workers that collect the batches.",
+            help="The number of training dataloader workers that "
+            "collect the batches.",
         )
         group.add_argument(
             "--enable-spec-aug",
@@ -173,12 +157,10 @@ class SPGISpeechAsrDataModule:
             "--spec-aug-time-warp-factor",
             type=int,
             default=80,
-            help=(
-                "Used only when --enable-spec-aug is True. "
-                "It specifies the factor for time warping in SpecAugment. "
-                "Larger values mean more warping. "
-                "A value less than 1 means to disable time warp."
-            ),
+            help="Used only when --enable-spec-aug is True. "
+            "It specifies the factor for time warping in SpecAugment. "
+            "Larger values mean more warping. "
+            "A value less than 1 means to disable time warp.",
         )
 
     def train_dataloaders(
@@ -194,20 +176,24 @@ class SPGISpeechAsrDataModule:
             The state dict for the training sampler.
         """
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(self.args.manifest_dir / "cuts_musan.jsonl.gz")
+        cuts_musan = load_manifest(
+            self.args.manifest_dir / "cuts_musan.jsonl.gz"
+        )
 
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+                CutMix(
+                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
+                )
             )
         else:
             logging.info("Disable MUSAN")
 
         if self.args.concatenate_cuts:
             logging.info(
-                "Using cut concatenation with duration factor "
+                f"Using cut concatenation with duration factor "
                 f"{self.args.duration_factor} and gap {self.args.gap}."
             )
             # Cut concatenation should be the first transform in the list,
@@ -222,7 +208,9 @@ class SPGISpeechAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+            logging.info(
+                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
+            )
             input_transforms.append(
                 SpecAugment(
                     time_warp_factor=self.args.spec_aug_time_warp_factor,
@@ -239,7 +227,9 @@ class SPGISpeechAsrDataModule:
         if self.args.on_the_fly_feats:
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 input_transforms=input_transforms,
             )
         else:
@@ -292,7 +282,9 @@ class SPGISpeechAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
             )
         else:
             validate = K2SpeechRecognitionDataset(
@@ -336,7 +328,9 @@ class SPGISpeechAsrDataModule:
     @lru_cache()
     def train_cuts(self) -> CutSet:
         logging.info("About to get SPGISpeech train cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "cuts_train_shuf.jsonl.gz")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "cuts_train_shuf.jsonl.gz"
+        )
 
     @lru_cache()
     def dev_cuts(self) -> CutSet:
diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py
index 72a7cd1c1..c39bd0530 100755
--- a/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py
@@ -76,7 +76,11 @@ from beam_search import (
 )
 from train import get_params, get_transducer_model
 
-from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
+from icefall.checkpoint import (
+    average_checkpoints,
+    find_checkpoints,
+    load_checkpoint,
+)
 from icefall.utils import (
     AttributeDict,
     setup_logger,
@@ -113,11 +117,9 @@ def get_parser():
         "--avg",
         type=int,
         default=10,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch' and '--iter'"
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
     )
 
     parser.add_argument(
@@ -185,7 +187,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -243,7 +246,9 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
@@ -258,7 +263,10 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -304,7 +312,11 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}"
+            ): hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -377,7 +389,9 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -410,7 +424,9 @@ def save_results(
         # we also compute CER for spgispeech dataset.
         results_char = []
         for res in results:
-            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
+            results_char.append(
+                (res[0], list("".join(res[1])), list("".join(res[2])))
+            )
         cers_filename = (
             params.res_dir / f"cers-{test_set_name}-{key}-{params.suffix}.txt"
         )
@@ -422,23 +438,32 @@ def save_results(
 
         logging.info("Wrote detailed error stats to {}".format(wers_filename))
 
-    test_set_wers = {k: v for k, v in sorted(test_set_wers.items(), key=lambda x: x[1])}
-    test_set_cers = {k: v for k, v in sorted(test_set_cers.items(), key=lambda x: x[1])}
+    test_set_wers = {
+        k: v for k, v in sorted(test_set_wers.items(), key=lambda x: x[1])
+    }
+    test_set_cers = {
+        k: v for k, v in sorted(test_set_cers.items(), key=lambda x: x[1])
+    }
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER\tCER", file=f)
         for key in test_set_wers:
             print(
-                "{}\t{}\t{}".format(key, test_set_wers[key], test_set_cers[key]),
+                "{}\t{}\t{}".format(
+                    key, test_set_wers[key], test_set_cers[key]
+                ),
                 file=f,
             )
 
     s = "\nFor {}, WER/CER of different settings are:\n".format(test_set_name)
     note = "\tbest for {}".format(test_set_name)
     for key in test_set_wers:
-        s += "{}\t{}\t{}{}\n".format(key, test_set_wers[key], test_set_cers[key], note)
+        s += "{}\t{}\t{}{}\n".format(
+            key, test_set_wers[key], test_set_cers[key], note
+        )
         note = ""
     logging.info(s)
 
@@ -471,7 +496,9 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        params.suffix += (
+            f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        )
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -503,7 +530,8 @@ def main():
         ]
         if len(filenames) == 0:
             raise ValueError(
-                f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                f"No checkpoints found for"
+                f" --iter {params.iter}, --avg {params.avg}"
             )
         elif len(filenames) < params.avg:
             raise ValueError(
diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py
index 1f18ae2f3..77faa3c0e 100755
--- a/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py
@@ -50,7 +50,11 @@ import sentencepiece as spm
 import torch
 from train import get_params, get_transducer_model
 
-from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
+from icefall.checkpoint import (
+    average_checkpoints,
+    find_checkpoints,
+    load_checkpoint,
+)
 from icefall.utils import str2bool
 
 
@@ -63,20 +67,17 @@ def get_parser():
         "--epoch",
         type=int,
         default=28,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=15,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -118,7 +119,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     return parser
@@ -194,7 +196,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py
index cd835a7b4..dda29b3e5 100755
--- a/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py
@@ -77,7 +77,9 @@ from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+LRSchedulerType = Union[
+    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
+]
 
 
 def get_parser():
@@ -153,7 +155,8 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need to be changed.",
+        help="The initial learning rate.  This value should not need to be "
+        "changed.",
     )
 
     parser.add_argument(
@@ -176,45 +179,42 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
         "--prune-range",
         type=int,
         default=5,
-        help=(
-            "The prune range for rnnt loss, it means how many symbols(context)"
-            "we are using to compute the loss"
-        ),
+        help="The prune range for rnnt loss, it means how many symbols(context)"
+        "we are using to compute the loss",
     )
 
     parser.add_argument(
         "--lm-scale",
         type=float,
         default=0.25,
-        help=(
-            "The scale to smooth the loss with lm (output of prediction network) part."
-        ),
+        help="The scale to smooth the loss with lm "
+        "(output of prediction network) part.",
     )
 
     parser.add_argument(
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)part.",
+        help="The scale to smooth the loss with am (output of encoder network)"
+        "part.",
     )
 
     parser.add_argument(
         "--simple-loss-scale",
         type=float,
         default=0.5,
-        help=(
-            "To get pruning ranges, we will calculate a simple version"
-            "loss(joiner is just addition), this simple loss also uses for"
-            "training (as a regularization item). We will scale the simple loss"
-            "with this parameter before adding to the final loss."
-        ),
+        help="To get pruning ranges, we will calculate a simple version"
+        "loss(joiner is just addition), this simple loss also uses for"
+        "training (as a regularization item). We will scale the simple loss"
+        "with this parameter before adding to the final loss.",
     )
 
     parser.add_argument(
@@ -554,16 +554,23 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+            0.0
+            if warmup < 1.0
+            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+        )
+        loss = (
+            params.simple_loss_scale * simple_loss
+            + pruned_loss_scale * pruned_loss
         )
-        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
 
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -726,7 +733,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
diff --git a/egs/tal_csasr/ASR/local/compute_fbank_tal_csasr.py b/egs/tal_csasr/ASR/local/compute_fbank_tal_csasr.py
index 602e50d29..4582609ac 100755
--- a/egs/tal_csasr/ASR/local/compute_fbank_tal_csasr.py
+++ b/egs/tal_csasr/ASR/local/compute_fbank_tal_csasr.py
@@ -84,7 +84,9 @@ def compute_fbank_tal_csasr(num_mel_bins: int = 80):
             )
             if "train" in partition:
                 cut_set = (
-                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+                    cut_set
+                    + cut_set.perturb_speed(0.9)
+                    + cut_set.perturb_speed(1.1)
                 )
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
@@ -110,7 +112,9 @@ def get_args():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/tal_csasr/ASR/local/prepare_char.py b/egs/tal_csasr/ASR/local/prepare_char.py
index 1262baf63..2c5b8b8b3 100755
--- a/egs/tal_csasr/ASR/local/prepare_char.py
+++ b/egs/tal_csasr/ASR/local/prepare_char.py
@@ -87,7 +87,9 @@ def lexicon_to_fst_no_sil(
         cur_state = loop_state
 
         word = word2id[word]
-        pieces = [token2id[i] if i in token2id else token2id[""] for i in pieces]
+        pieces = [
+            token2id[i] if i in token2id else token2id[""] for i in pieces
+        ]
 
         for i in range(len(pieces) - 1):
             w = word if i == 0 else eps
diff --git a/egs/tal_csasr/ASR/local/prepare_lang.py b/egs/tal_csasr/ASR/local/prepare_lang.py
index c8cf9b881..e5ae89ec4 100755
--- a/egs/tal_csasr/ASR/local/prepare_lang.py
+++ b/egs/tal_csasr/ASR/local/prepare_lang.py
@@ -317,7 +317,9 @@ def lexicon_to_fst(
 
 def get_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone")
+    parser.add_argument(
+        "--lang-dir", type=str, help="The lang dir, data/lang_phone"
+    )
     return parser.parse_args()
 
 
diff --git a/egs/tal_csasr/ASR/local/test_prepare_lang.py b/egs/tal_csasr/ASR/local/test_prepare_lang.py
index 74e025ad7..d4cf62bba 100755
--- a/egs/tal_csasr/ASR/local/test_prepare_lang.py
+++ b/egs/tal_csasr/ASR/local/test_prepare_lang.py
@@ -88,7 +88,9 @@ def test_read_lexicon(filename: str):
     fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa.draw("L.pdf", title="L")
 
-    fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id)
+    fsa_disambig = lexicon_to_fst(
+        lexicon_disambig, phone2id=phone2id, word2id=word2id
+    )
     fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
     fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa_disambig.draw("L_disambig.pdf", title="L_disambig")
diff --git a/egs/tal_csasr/ASR/local/text2token.py b/egs/tal_csasr/ASR/local/text2token.py
index 2be639b7a..71be2a613 100755
--- a/egs/tal_csasr/ASR/local/text2token.py
+++ b/egs/tal_csasr/ASR/local/text2token.py
@@ -50,15 +50,15 @@ def get_parser():
         "-n",
         default=1,
         type=int,
-        help=(
-            "number of characters to split, i.e.,                         aabb -> a a b"
-            " b with -n 1 and aa bb with -n 2"
-        ),
+        help="number of characters to split, i.e., \
+                        aabb -> a a b b with -n 1 and aa bb with -n 2",
     )
     parser.add_argument(
         "--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
     )
-    parser.add_argument("--space", default="", type=str, help="space symbol")
+    parser.add_argument(
+        "--space", default="", type=str, help="space symbol"
+    )
     parser.add_argument(
         "--non-lang-syms",
         "-l",
@@ -66,7 +66,9 @@ def get_parser():
         type=str,
         help="list of non-linguistic symobles, e.g.,  etc.",
     )
-    parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
+    parser.add_argument(
+        "text", type=str, default=False, nargs="?", help="input text"
+    )
     parser.add_argument(
         "--trans_type",
         "-t",
@@ -106,7 +108,8 @@ def token2id(
             if token_type == "lazy_pinyin":
                 text = lazy_pinyin(chars_list)
                 sub_ids = [
-                    token_table[txt] if txt in token_table else oov_id for txt in text
+                    token_table[txt] if txt in token_table else oov_id
+                    for txt in text
                 ]
                 ids.append(sub_ids)
             else:  # token_type = "pinyin"
@@ -132,7 +135,9 @@ def main():
     if args.text:
         f = codecs.open(args.text, encoding="utf-8")
     else:
-        f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
+        f = codecs.getreader("utf-8")(
+            sys.stdin if is_python2 else sys.stdin.buffer
+        )
 
     sys.stdout = codecs.getwriter("utf-8")(
         sys.stdout if is_python2 else sys.stdout.buffer
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py
index 02bd6e2cc..49bfb148b 100644
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py
@@ -74,12 +74,10 @@ class TAL_CSASRAsrDataModule:
     def add_arguments(cls, parser: argparse.ArgumentParser):
         group = parser.add_argument_group(
             title="ASR data related options",
-            description=(
-                "These options are used for the preparation of "
-                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-                "effective batch sizes, sampling strategies, applied data "
-                "augmentations, etc."
-            ),
+            description="These options are used for the preparation of "
+            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+            "effective batch sizes, sampling strategies, applied data "
+            "augmentations, etc.",
         )
 
         group.add_argument(
@@ -93,81 +91,66 @@ class TAL_CSASRAsrDataModule:
             "--max-duration",
             type=int,
             default=200.0,
-            help=(
-                "Maximum pooled recordings duration (seconds) in a "
-                "single batch. You can reduce it if it causes CUDA OOM."
-            ),
+            help="Maximum pooled recordings duration (seconds) in a "
+            "single batch. You can reduce it if it causes CUDA OOM.",
         )
 
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, the batches will come from buckets of "
-                "similar duration (saves padding frames)."
-            ),
+            help="When enabled, the batches will come from buckets of "
+            "similar duration (saves padding frames).",
         )
 
         group.add_argument(
             "--num-buckets",
             type=int,
             default=300,
-            help=(
-                "The number of buckets for the DynamicBucketingSampler"
-                "(you might want to increase it for larger datasets)."
-            ),
+            help="The number of buckets for the DynamicBucketingSampler"
+            "(you might want to increase it for larger datasets).",
         )
 
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, utterances (cuts) will be concatenated "
-                "to minimize the amount of padding."
-            ),
+            help="When enabled, utterances (cuts) will be concatenated "
+            "to minimize the amount of padding.",
         )
 
         group.add_argument(
             "--duration-factor",
             type=float,
             default=1.0,
-            help=(
-                "Determines the maximum duration of a concatenated cut "
-                "relative to the duration of the longest cut in a batch."
-            ),
+            help="Determines the maximum duration of a concatenated cut "
+            "relative to the duration of the longest cut in a batch.",
         )
 
         group.add_argument(
             "--gap",
             type=float,
             default=1.0,
-            help=(
-                "The amount of padding (in seconds) inserted between "
-                "concatenated cuts. This padding is filled with noise when "
-                "noise augmentation is used."
-            ),
+            help="The amount of padding (in seconds) inserted between "
+            "concatenated cuts. This padding is filled with noise when "
+            "noise augmentation is used.",
         )
 
         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, use on-the-fly cut mixing and feature "
-                "extraction. Will drop existing precomputed feature manifests "
-                "if available."
-            ),
+            help="When enabled, use on-the-fly cut mixing and feature "
+            "extraction. Will drop existing precomputed feature manifests "
+            "if available.",
         )
 
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled (=default), the examples will be shuffled for each epoch."
-            ),
+            help="When enabled (=default), the examples will be "
+            "shuffled for each epoch.",
         )
 
         group.add_argument(
@@ -181,18 +164,17 @@ class TAL_CSASRAsrDataModule:
             "--return-cuts",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, each batch will have the "
-                "field: batch['supervisions']['cut'] with the cuts that "
-                "were used to construct it."
-            ),
+            help="When enabled, each batch will have the "
+            "field: batch['supervisions']['cut'] with the cuts that "
+            "were used to construct it.",
         )
 
         group.add_argument(
             "--num-workers",
             type=int,
             default=2,
-            help="The number of training dataloader workers that collect the batches.",
+            help="The number of training dataloader workers that "
+            "collect the batches.",
         )
 
         group.add_argument(
@@ -206,22 +188,18 @@ class TAL_CSASRAsrDataModule:
             "--spec-aug-time-warp-factor",
             type=int,
             default=80,
-            help=(
-                "Used only when --enable-spec-aug is True. "
-                "It specifies the factor for time warping in SpecAugment. "
-                "Larger values mean more warping. "
-                "A value less than 1 means to disable time warp."
-            ),
+            help="Used only when --enable-spec-aug is True. "
+            "It specifies the factor for time warping in SpecAugment. "
+            "Larger values mean more warping. "
+            "A value less than 1 means to disable time warp.",
         )
 
         group.add_argument(
             "--enable-musan",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, select noise from MUSAN and mix it"
-                "with training dataset. "
-            ),
+            help="When enabled, select noise from MUSAN and mix it"
+            "with training dataset. ",
         )
 
         group.add_argument(
@@ -244,20 +222,24 @@ class TAL_CSASRAsrDataModule:
             The state dict for the training sampler.
         """
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
+        cuts_musan = load_manifest(
+            self.args.manifest_dir / "musan_cuts.jsonl.gz"
+        )
 
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+                CutMix(
+                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
+                )
             )
         else:
             logging.info("Disable MUSAN")
 
         if self.args.concatenate_cuts:
             logging.info(
-                "Using cut concatenation with duration factor "
+                f"Using cut concatenation with duration factor "
                 f"{self.args.duration_factor} and gap {self.args.gap}."
             )
             # Cut concatenation should be the first transform in the list,
@@ -272,7 +254,9 @@ class TAL_CSASRAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+            logging.info(
+                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
+            )
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -316,7 +300,9 @@ class TAL_CSASRAsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -374,7 +360,9 @@ class TAL_CSASRAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 return_cuts=self.args.return_cuts,
             )
         else:
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py
index b2aef7e86..b624913f5 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py
@@ -124,24 +124,20 @@ def get_parser():
         "--avg",
         type=int,
         default=15,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch' and '--iter'"
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
     )
 
     parser.add_argument(
         "--use-averaged-model",
         type=str2bool,
         default=False,
-        help=(
-            "Whether to load averaged model. Currently it only supports "
-            "using --epoch. If True, it would decode with the averaged model "
-            "over the epoch range from `epoch-avg` (excluded) to `epoch`."
-            "Actually only the models with epoch number of `epoch-avg` and "
-            "`epoch` are loaded for averaging. "
-        ),
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
     )
 
     parser.add_argument(
@@ -212,7 +208,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -271,7 +268,9 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )
     hyps = []
     zh_hyps = []
     en_hyps = []
@@ -304,7 +303,10 @@ def decode_one_batch(
             hyps.append(chars_new)
             zh_hyps.append(zh_text)
             en_hyps.append(en_text)
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -373,7 +375,9 @@ def decode_one_batch(
                     f"Unsupported decoding method: {params.decoding_method}"
                 )
             for i in range(encoder_out.size(0)):
-                hyp = sp.decode([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+                hyp = sp.decode(
+                    [lexicon.token_table[idx] for idx in hyp_tokens[i]]
+                )
                 chars = pattern.split(hyp.upper())
                 chars_new = []
                 zh_text = []
@@ -392,11 +396,11 @@ def decode_one_batch(
         return {"greedy_search": (hyps, zh_hyps, en_hyps)}
     elif params.decoding_method == "fast_beam_search":
         return {
-            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": (
-                hyps,
-                zh_hyps,
-                en_hyps,
-            )
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}"
+            ): (hyps, zh_hyps, en_hyps)
         }
     else:
         return {f"beam_size_{params.beam_size}": (hyps, zh_hyps, en_hyps)}
@@ -502,7 +506,9 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results, zh_results, en_results
 
 
@@ -535,7 +541,8 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -578,7 +585,9 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        params.suffix += (
+            f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        )
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -610,12 +619,13 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg
-            ]
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg:
                 raise ValueError(
@@ -638,12 +648,13 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg + 1
-            ]
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg + 1]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg + 1:
                 raise ValueError(
@@ -671,7 +682,7 @@ def main():
             filename_start = f"{params.exp_dir}/epoch-{start}.pt"
             filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
             logging.info(
-                "Calculating the averaged model over epoch range from "
+                f"Calculating the averaged model over epoch range from "
                 f"{start} (excluded) to {params.epoch}"
             )
             model.to(device)
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py
index 94a4c7a2e..8f900208a 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py
@@ -92,24 +92,20 @@ def get_parser():
         "--avg",
         type=int,
         default=15,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch' and '--iter'"
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
     )
 
     parser.add_argument(
         "--use-averaged-model",
         type=str2bool,
         default=False,
-        help=(
-            "Whether to load averaged model. Currently it only supports "
-            "using --epoch. If True, it would decode with the averaged model "
-            "over the epoch range from `epoch-avg` (excluded) to `epoch`."
-            "Actually only the models with epoch number of `epoch-avg` and "
-            "`epoch` are loaded for averaging. "
-        ),
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
     )
 
     parser.add_argument(
@@ -143,7 +139,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     add_model_arguments(parser)
@@ -179,12 +176,13 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg
-            ]
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg:
                 raise ValueError(
@@ -207,12 +205,13 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg + 1
-            ]
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg + 1]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg + 1:
                 raise ValueError(
@@ -240,7 +239,7 @@ def main():
             filename_start = f"{params.exp_dir}/epoch-{start}.pt"
             filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
             logging.info(
-                "Calculating the averaged model over epoch range from "
+                f"Calculating the averaged model over epoch range from "
                 f"{start} (excluded) to {params.epoch}"
             )
             model.to(device)
@@ -278,7 +277,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py
index 198242129..dbe213b24 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py
@@ -84,11 +84,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -117,12 +115,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     parser.add_argument(
@@ -169,7 +165,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -200,9 +197,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -265,11 +263,15 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
-    encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=features, x_lens=feature_lengths
+    )
 
     num_waves = encoder_out.size(0)
     hyps = []
@@ -365,7 +367,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py
index 676e8c904..ca35eba45 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py
@@ -86,7 +86,9 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+LRSchedulerType = Union[
+    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
+]
 
 
 def add_model_arguments(parser: argparse.ArgumentParser):
@@ -212,7 +214,8 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need to be changed.",
+        help="The initial learning rate.  This value should not need "
+        "to be changed.",
     )
 
     parser.add_argument(
@@ -235,45 +238,42 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
         "--prune-range",
         type=int,
         default=5,
-        help=(
-            "The prune range for rnnt loss, it means how many symbols(context)"
-            "we are using to compute the loss"
-        ),
+        help="The prune range for rnnt loss, it means how many symbols(context)"
+        "we are using to compute the loss",
     )
 
     parser.add_argument(
         "--lm-scale",
         type=float,
         default=0.25,
-        help=(
-            "The scale to smooth the loss with lm (output of prediction network) part."
-        ),
+        help="The scale to smooth the loss with lm "
+        "(output of prediction network) part.",
     )
 
     parser.add_argument(
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)part.",
+        help="The scale to smooth the loss with am (output of encoder network)"
+        "part.",
     )
 
     parser.add_argument(
         "--simple-loss-scale",
         type=float,
         default=0.5,
-        help=(
-            "To get pruning ranges, we will calculate a simple version"
-            "loss(joiner is just addition), this simple loss also uses for"
-            "training (as a regularization item). We will scale the simple loss"
-            "with this parameter before adding to the final loss."
-        ),
+        help="To get pruning ranges, we will calculate a simple version"
+        "loss(joiner is just addition), this simple loss also uses for"
+        "training (as a regularization item). We will scale the simple loss"
+        "with this parameter before adding to the final loss.",
     )
 
     parser.add_argument(
@@ -600,7 +600,11 @@ def compute_loss(
      warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
-    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+    device = (
+        model.device
+        if isinstance(model, DDP)
+        else next(model.parameters()).device
+    )
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
@@ -630,15 +634,22 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+            0.0
+            if warmup < 1.0
+            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+        )
+        loss = (
+            params.simple_loss_scale * simple_loss
+            + pruned_loss_scale * pruned_loss
         )
-        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -817,7 +828,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -931,7 +944,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            2 ** 22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
diff --git a/egs/tedlium3/ASR/local/compute_fbank_tedlium.py b/egs/tedlium3/ASR/local/compute_fbank_tedlium.py
index 733ebf235..327962a79 100755
--- a/egs/tedlium3/ASR/local/compute_fbank_tedlium.py
+++ b/egs/tedlium3/ASR/local/compute_fbank_tedlium.py
@@ -83,7 +83,9 @@ def compute_fbank_tedlium():
             )
             if "train" in partition:
                 cut_set = (
-                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+                    cut_set
+                    + cut_set.perturb_speed(0.9)
+                    + cut_set.perturb_speed(1.1)
                 )
             cur_num_jobs = num_jobs if ex is None else 80
             cur_num_jobs = min(cur_num_jobs, len(cut_set))
@@ -102,7 +104,9 @@ def compute_fbank_tedlium():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py b/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py
index 9dbcc9d9e..49544ccb3 100644
--- a/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py
+++ b/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py
@@ -25,7 +25,9 @@ import sentencepiece as spm
 
 def get_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--texts", type=List[str], help="The input transcripts list.")
+    parser.add_argument(
+        "--texts", type=List[str], help="The input transcripts list."
+    )
     parser.add_argument(
         "--bpe-model",
         type=str,
diff --git a/egs/tedlium3/ASR/local/prepare_lexicon.py b/egs/tedlium3/ASR/local/prepare_lexicon.py
index b9160b6d4..35dd332e8 100755
--- a/egs/tedlium3/ASR/local/prepare_lexicon.py
+++ b/egs/tedlium3/ASR/local/prepare_lexicon.py
@@ -23,12 +23,11 @@ consisting of supervisions_train.json and does the following:
 1. Generate lexicon_words.txt.
 
 """
+import lhotse
 import argparse
 import logging
 from pathlib import Path
 
-import lhotse
-
 
 def get_args():
     parser = argparse.ArgumentParser()
@@ -62,7 +61,9 @@ def prepare_lexicon(manifests_dir: str, lang_dir: str):
     words = set()
 
     lexicon = Path(lang_dir) / "lexicon_words.txt"
-    sups = lhotse.load_manifest(f"{manifests_dir}/tedlium_supervisions_train.jsonl.gz")
+    sups = lhotse.load_manifest(
+        f"{manifests_dir}/tedlium_supervisions_train.jsonl.gz"
+    )
     for s in sups:
         # list the words units and filter the empty item
         words_list = list(filter(None, s.text.split()))
@@ -87,7 +88,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/tedlium3/ASR/local/prepare_transcripts.py b/egs/tedlium3/ASR/local/prepare_transcripts.py
index 7ea4e89a4..1039ac5bb 100755
--- a/egs/tedlium3/ASR/local/prepare_transcripts.py
+++ b/egs/tedlium3/ASR/local/prepare_transcripts.py
@@ -23,12 +23,11 @@ consisting of supervisions_train.json and does the following:
 1. Generate train.text.
 
 """
+import lhotse
 import argparse
 import logging
 from pathlib import Path
 
-import lhotse
-
 
 def get_args():
     parser = argparse.ArgumentParser()
@@ -62,7 +61,9 @@ def prepare_transcripts(manifests_dir: str, lang_dir: str):
     texts = []
 
     train_text = Path(lang_dir) / "train.text"
-    sups = lhotse.load_manifest(f"{manifests_dir}/tedlium_supervisions_train.jsonl.gz")
+    sups = lhotse.load_manifest(
+        f"{manifests_dir}/tedlium_supervisions_train.jsonl.gz"
+    )
     for s in sups:
         texts.append(s.text)
 
@@ -82,7 +83,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py b/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py
index 6bae33e65..2b294e601 100755
--- a/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py
@@ -94,20 +94,17 @@ def get_parser():
         "--epoch",
         type=int,
         default=29,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=13,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -175,7 +172,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -233,7 +231,9 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
@@ -248,7 +248,10 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -294,7 +297,11 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}"
+            ): hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -367,7 +374,9 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -400,7 +409,8 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/export.py b/egs/tedlium3/ASR/pruned_transducer_stateless/export.py
index 244740932..a1c3bcea3 100644
--- a/egs/tedlium3/ASR/pruned_transducer_stateless/export.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless/export.py
@@ -65,20 +65,17 @@ def get_parser():
         "--epoch",
         type=int,
         default=30,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=13,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -109,7 +106,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     return parser
@@ -181,7 +179,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py b/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py
index 00545f107..8480ac029 100644
--- a/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py
@@ -93,11 +93,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -124,12 +122,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     parser.add_argument(
@@ -169,7 +165,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -206,9 +203,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -273,7 +271,9 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -298,7 +298,10 @@ def main():
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -350,7 +353,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/train.py b/egs/tedlium3/ASR/pruned_transducer_stateless/train.py
index 70c5e290f..8d5cdf683 100755
--- a/egs/tedlium3/ASR/pruned_transducer_stateless/train.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless/train.py
@@ -133,45 +133,42 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
         "--prune-range",
         type=int,
         default=5,
-        help=(
-            "The prune range for rnnt loss, it means how many symbols(context)"
-            "we are using to compute the loss"
-        ),
+        help="The prune range for rnnt loss, it means how many symbols(context)"
+        "we are using to compute the loss",
     )
 
     parser.add_argument(
         "--lm-scale",
         type=float,
         default=0.25,
-        help=(
-            "The scale to smooth the loss with lm (output of prediction network) part."
-        ),
+        help="The scale to smooth the loss with lm "
+        "(output of prediction network) part.",
     )
 
     parser.add_argument(
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)part.",
+        help="The scale to smooth the loss with am (output of encoder network)"
+        "part.",
     )
 
     parser.add_argument(
         "--simple-loss-scale",
         type=float,
         default=0.5,
-        help=(
-            "To get pruning ranges, we will calculate a simple version"
-            "loss(joiner is just addition), this simple loss also uses for"
-            "training (as a regularization item). We will scale the simple loss"
-            "with this parameter before adding to the final loss."
-        ),
+        help="To get pruning ranges, we will calculate a simple version"
+        "loss(joiner is just addition), this simple loss also uses for"
+        "training (as a regularization item). We will scale the simple loss"
+        "with this parameter before adding to the final loss.",
     )
 
     parser.add_argument(
@@ -559,7 +556,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -679,7 +678,9 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
+            tb_writer.add_scalar(
+                "train/learning_rate", cur_lr, params.batch_idx_train
+            )
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py b/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py
index f90f79d8c..94784c4c4 100644
--- a/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py
+++ b/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py
@@ -18,6 +18,7 @@
 
 import argparse
 import logging
+
 from functools import lru_cache
 from pathlib import Path
 from typing import Any, Dict, Optional
@@ -62,12 +63,10 @@ class TedLiumAsrDataModule:
     def add_arguments(cls, parser: argparse.ArgumentParser):
         group = parser.add_argument_group(
             title="ASR data related options",
-            description=(
-                "These options are used for the preparation of "
-                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-                "effective batch sizes, sampling strategies, applied data "
-                "augmentations, etc."
-            ),
+            description="These options are used for the preparation of "
+            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+            "effective batch sizes, sampling strategies, applied data "
+            "augmentations, etc.",
         )
         group.add_argument(
             "--manifest-dir",
@@ -79,90 +78,74 @@ class TedLiumAsrDataModule:
             "--max-duration",
             type=int,
             default=200.0,
-            help=(
-                "Maximum pooled recordings duration (seconds) in a "
-                "single batch. You can reduce it if it causes CUDA OOM."
-            ),
+            help="Maximum pooled recordings duration (seconds) in a "
+            "single batch. You can reduce it if it causes CUDA OOM.",
         )
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, the batches will come from buckets of "
-                "similar duration (saves padding frames)."
-            ),
+            help="When enabled, the batches will come from buckets of "
+            "similar duration (saves padding frames).",
         )
         group.add_argument(
             "--num-buckets",
             type=int,
             default=30,
-            help=(
-                "The number of buckets for the DynamicBucketingSampler"
-                "(you might want to increase it for larger datasets)."
-            ),
+            help="The number of buckets for the DynamicBucketingSampler"
+            "(you might want to increase it for larger datasets).",
         )
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, utterances (cuts) will be concatenated "
-                "to minimize the amount of padding."
-            ),
+            help="When enabled, utterances (cuts) will be concatenated "
+            "to minimize the amount of padding.",
         )
         group.add_argument(
             "--duration-factor",
             type=float,
             default=1.0,
-            help=(
-                "Determines the maximum duration of a concatenated cut "
-                "relative to the duration of the longest cut in a batch."
-            ),
+            help="Determines the maximum duration of a concatenated cut "
+            "relative to the duration of the longest cut in a batch.",
         )
         group.add_argument(
             "--gap",
             type=float,
             default=1.0,
-            help=(
-                "The amount of padding (in seconds) inserted between "
-                "concatenated cuts. This padding is filled with noise when "
-                "noise augmentation is used."
-            ),
+            help="The amount of padding (in seconds) inserted between "
+            "concatenated cuts. This padding is filled with noise when "
+            "noise augmentation is used.",
         )
         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, use on-the-fly cut mixing and feature "
-                "extraction. Will drop existing precomputed feature manifests "
-                "if available."
-            ),
+            help="When enabled, use on-the-fly cut mixing and feature "
+            "extraction. Will drop existing precomputed feature manifests "
+            "if available.",
         )
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled (=default), the examples will be shuffled for each epoch."
-            ),
+            help="When enabled (=default), the examples will be "
+            "shuffled for each epoch.",
         )
         group.add_argument(
             "--return-cuts",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, each batch will have the "
-                "field: batch['supervisions']['cut'] with the cuts that "
-                "were used to construct it."
-            ),
+            help="When enabled, each batch will have the "
+            "field: batch['supervisions']['cut'] with the cuts that "
+            "were used to construct it.",
         )
         group.add_argument(
             "--num-workers",
             type=int,
             default=2,
-            help="The number of training dataloader workers that collect the batches.",
+            help="The number of training dataloader workers that "
+            "collect the batches.",
         )
         group.add_argument(
             "--enable-spec-aug",
@@ -174,25 +157,23 @@ class TedLiumAsrDataModule:
             "--spec-aug-time-warp-factor",
             type=int,
             default=80,
-            help=(
-                "Used only when --enable-spec-aug is True. "
-                "It specifies the factor for time warping in SpecAugment. "
-                "Larger values mean more warping. "
-                "A value less than 1 means to disable time warp."
-            ),
+            help="Used only when --enable-spec-aug is True. "
+            "It specifies the factor for time warping in SpecAugment. "
+            "Larger values mean more warping. "
+            "A value less than 1 means to disable time warp.",
         )
         group.add_argument(
             "--enable-musan",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, select noise from MUSAN and mix it"
-                "with training dataset. "
-            ),
+            help="When enabled, select noise from MUSAN and mix it"
+            "with training dataset.",
         )
 
     def train_dataloaders(
-        self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None
+        self,
+        cuts_train: CutSet,
+        sampler_state_dict: Optional[Dict[str, Any]] = None
     ) -> DataLoader:
         """
         Args:
@@ -205,7 +186,9 @@ class TedLiumAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+            logging.info(
+                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
+            )
 
             input_transforms.append(
                 SpecAugment(
@@ -225,16 +208,20 @@ class TedLiumAsrDataModule:
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
-            cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
+            cuts_musan = load_manifest(
+                self.args.manifest_dir / "musan_cuts.jsonl.gz"
+            )
             transforms.append(
-                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+                CutMix(
+                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
+                )
             )
         else:
             logging.info("Disable MUSAN")
 
         if self.args.concatenate_cuts:
             logging.info(
-                "Using cut concatenation with duration factor "
+                f"Using cut concatenation with duration factor "
                 f"{self.args.duration_factor} and gap {self.args.gap}."
             )
             # Cut concatenation should be the first transform in the list,
@@ -260,7 +247,9 @@ class TedLiumAsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -317,7 +306,9 @@ class TedLiumAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 return_cuts=self.args.return_cuts,
             )
         else:
@@ -348,7 +339,9 @@ class TedLiumAsrDataModule:
         logging.debug("About to create test dataset")
         if self.args.on_the_fly_feats:
             test = K2SpeechRecognitionDataset(
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 return_cuts=self.args.return_cuts,
             )
         else:
@@ -382,9 +375,13 @@ class TedLiumAsrDataModule:
     @lru_cache()
     def dev_cuts(self) -> CutSet:
         logging.info("About to get dev cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "tedlium_cuts_dev.jsonl.gz")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "tedlium_cuts_dev.jsonl.gz"
+        )
 
     @lru_cache()
     def test_cuts(self) -> CutSet:
         logging.info("About to get test cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "tedlium_cuts_test.jsonl.gz")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "tedlium_cuts_test.jsonl.gz"
+        )
diff --git a/egs/tedlium3/ASR/transducer_stateless/beam_search.py b/egs/tedlium3/ASR/transducer_stateless/beam_search.py
index 1f99edaf3..77caf6460 100644
--- a/egs/tedlium3/ASR/transducer_stateless/beam_search.py
+++ b/egs/tedlium3/ASR/transducer_stateless/beam_search.py
@@ -87,9 +87,9 @@ def greedy_search(
         y = logits.argmax().item()
         if y != blank_id and y != unk_id:
             hyp.append(y)
-            decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape(
-                1, context_size
-            )
+            decoder_input = torch.tensor(
+                [hyp[-context_size:]], device=device
+            ).reshape(1, context_size)
 
             decoder_out = model.decoder(decoder_input, need_pad=False)
 
@@ -148,7 +148,9 @@ class HypothesisList(object):
         key = hyp.key
         if key in self:
             old_hyp = self._data[key]  # shallow copy
-            torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob)
+            torch.logaddexp(
+                old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
+            )
         else:
             self._data[key] = hyp
 
@@ -164,7 +166,9 @@ class HypothesisList(object):
           Return the hypothesis that has the largest `log_prob`.
         """
         if length_norm:
-            return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys))
+            return max(
+                self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)
+            )
         else:
             return max(self._data.values(), key=lambda hyp: hyp.log_prob)
 
@@ -340,9 +344,9 @@ def modified_beam_search(
 
     device = model.device
 
-    decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape(
-        1, context_size
-    )
+    decoder_input = torch.tensor(
+        [blank_id] * context_size, device=device
+    ).reshape(1, context_size)
 
     decoder_out = model.decoder(decoder_input, need_pad=False)
 
@@ -379,7 +383,9 @@ def modified_beam_search(
         decoder_out = model.decoder(decoder_input, need_pad=False)
         # decoder_output is of shape (num_hyps, 1, decoder_output_dim)
 
-        current_encoder_out = current_encoder_out.expand(decoder_out.size(0), 1, -1)
+        current_encoder_out = current_encoder_out.expand(
+            decoder_out.size(0), 1, -1
+        )
 
         logits = model.joiner(
             current_encoder_out,
@@ -448,9 +454,9 @@ def beam_search(
 
     device = model.device
 
-    decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape(
-        1, context_size
-    )
+    decoder_input = torch.tensor(
+        [blank_id] * context_size, device=device
+    ).reshape(1, context_size)
 
     decoder_out = model.decoder(decoder_input, need_pad=False)
 
diff --git a/egs/tedlium3/ASR/transducer_stateless/decode.py b/egs/tedlium3/ASR/transducer_stateless/decode.py
index 12d0e2652..d3e9e55e7 100755
--- a/egs/tedlium3/ASR/transducer_stateless/decode.py
+++ b/egs/tedlium3/ASR/transducer_stateless/decode.py
@@ -81,19 +81,16 @@ def get_parser():
         "--epoch",
         type=int,
         default=29,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=13,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -133,7 +130,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -252,7 +250,9 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )
     hyps = []
     batch_size = encoder_out.size(0)
 
@@ -275,7 +275,9 @@ def decode_one_batch(
                 model=model, encoder_out=encoder_out_i, beam=params.beam_size
             )
         else:
-            raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
+            raise ValueError(
+                f"Unsupported decoding method: {params.decoding_method}"
+            )
         hyps.append(sp.decode(hyp).split())
 
     if params.decoding_method == "greedy_search":
@@ -346,7 +348,9 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -379,7 +383,8 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/tedlium3/ASR/transducer_stateless/decoder.py b/egs/tedlium3/ASR/transducer_stateless/decoder.py
index f9a3814c6..f0c6f32b6 100644
--- a/egs/tedlium3/ASR/transducer_stateless/decoder.py
+++ b/egs/tedlium3/ASR/transducer_stateless/decoder.py
@@ -90,7 +90,9 @@ class Decoder(nn.Module):
         if self.context_size > 1:
             embedding_out = embedding_out.permute(0, 2, 1)
             if need_pad is True:
-                embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
+                embedding_out = F.pad(
+                    embedding_out, pad=(self.context_size - 1, 0)
+                )
             else:
                 # During inference time, there is no need to do extra padding
                 # as we only need one output
diff --git a/egs/tedlium3/ASR/transducer_stateless/export.py b/egs/tedlium3/ASR/transducer_stateless/export.py
index 0b2ae970b..c32b1d002 100644
--- a/egs/tedlium3/ASR/transducer_stateless/export.py
+++ b/egs/tedlium3/ASR/transducer_stateless/export.py
@@ -69,20 +69,17 @@ def get_parser():
         "--epoch",
         type=int,
         default=20,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=10,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -113,7 +110,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     return parser
@@ -249,7 +247,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/tedlium3/ASR/transducer_stateless/pretrained.py b/egs/tedlium3/ASR/transducer_stateless/pretrained.py
index 912d65497..c0e3bb844 100644
--- a/egs/tedlium3/ASR/transducer_stateless/pretrained.py
+++ b/egs/tedlium3/ASR/transducer_stateless/pretrained.py
@@ -82,11 +82,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -112,12 +110,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     parser.add_argument(
@@ -131,7 +127,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -225,9 +222,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -287,7 +285,9 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -335,7 +335,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/tedlium3/ASR/transducer_stateless/train.py b/egs/tedlium3/ASR/transducer_stateless/train.py
index 6fed32e81..09cbf4a00 100755
--- a/egs/tedlium3/ASR/transducer_stateless/train.py
+++ b/egs/tedlium3/ASR/transducer_stateless/train.py
@@ -133,7 +133,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -524,7 +525,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -644,7 +647,9 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
+            tb_writer.add_scalar(
+                "train/learning_rate", cur_lr, params.batch_idx_train
+            )
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/timit/ASR/RESULTS.md b/egs/timit/ASR/RESULTS.md
index d8ceb82b6..b78c16b88 100644
--- a/egs/timit/ASR/RESULTS.md
+++ b/egs/timit/ASR/RESULTS.md
@@ -71,4 +71,4 @@ python tdnn_ligru_ctc/decode.py --epoch 25 \
                                --avg 17 \
                                --max-duration 20 \
                                --lang-dir data/lang_phone
-```
+```
\ No newline at end of file
diff --git a/egs/timit/ASR/local/compile_hlg.py b/egs/timit/ASR/local/compile_hlg.py
index 32c248d7e..58cab4cf2 100644
--- a/egs/timit/ASR/local/compile_hlg.py
+++ b/egs/timit/ASR/local/compile_hlg.py
@@ -146,7 +146,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/timit/ASR/local/compute_fbank_timit.py b/egs/timit/ASR/local/compute_fbank_timit.py
index ecdf10ba9..f25786a0c 100644
--- a/egs/timit/ASR/local/compute_fbank_timit.py
+++ b/egs/timit/ASR/local/compute_fbank_timit.py
@@ -85,7 +85,9 @@ def compute_fbank_timit():
             )
             if partition == "TRAIN":
                 cut_set = (
-                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+                    cut_set
+                    + cut_set.perturb_speed(0.9)
+                    + cut_set.perturb_speed(1.1)
                 )
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
@@ -99,7 +101,9 @@ def compute_fbank_timit():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/timit/ASR/local/prepare_lexicon.py b/egs/timit/ASR/local/prepare_lexicon.py
index 0cf0f0deb..04023a9ab 100644
--- a/egs/timit/ASR/local/prepare_lexicon.py
+++ b/egs/timit/ASR/local/prepare_lexicon.py
@@ -62,7 +62,9 @@ def prepare_lexicon(manifests_dir: str, lang_dir: str):
 
     phones = set()
 
-    supervisions_train = Path(manifests_dir) / "timit_supervisions_TRAIN.jsonl.gz"
+    supervisions_train = (
+        Path(manifests_dir) / "timit_supervisions_TRAIN.jsonl.gz"
+    )
     lexicon = Path(lang_dir) / "lexicon.txt"
 
     logging.info(f"Loading {supervisions_train}!")
@@ -95,7 +97,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/timit/ASR/prepare.sh b/egs/timit/ASR/prepare.sh
index d11cd3a05..ae1b96a68 100644
--- a/egs/timit/ASR/prepare.sh
+++ b/egs/timit/ASR/prepare.sh
@@ -20,9 +20,9 @@ stop_stage=100
 #  - $dl_dir/lm
 #      This directory contains the language model(LM) downloaded from
 #      https://huggingface.co/luomingshuang/timit_lm, and the LM is based
-#	     on 39 phones. About how to get these LM files, you can know it
+#	     on 39 phones. About how to get these LM files, you can know it 
 #      from https://github.com/luomingshuang/Train_LM_with_kaldilm.
-#
+#	
 #	    - lm_3_gram.arpa
 #     - lm_4_gram.arpa
 #
diff --git a/egs/timit/ASR/tdnn_ligru_ctc/decode.py b/egs/timit/ASR/tdnn_ligru_ctc/decode.py
index 5a59a13ce..4f2aa2340 100644
--- a/egs/timit/ASR/tdnn_ligru_ctc/decode.py
+++ b/egs/timit/ASR/tdnn_ligru_ctc/decode.py
@@ -57,19 +57,16 @@ def get_parser():
         "--epoch",
         type=int,
         default=19,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=5,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
     parser.add_argument(
         "--method",
@@ -339,7 +336,9 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -401,7 +400,9 @@ def main():
 
     logging.info(f"device: {device}")
 
-    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
+    HLG = k2.Fsa.from_dict(
+        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
+    )
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
 
@@ -461,7 +462,9 @@ def main():
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
+        torch.save(
+            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
+        )
         return
 
     model.to(device)
@@ -482,7 +485,9 @@ def main():
         G=G,
     )
 
-    save_results(params=params, test_set_name=test_set, results_dict=results_dict)
+    save_results(
+        params=params, test_set_name=test_set, results_dict=results_dict
+    )
 
     logging.info("Done!")
 
diff --git a/egs/timit/ASR/tdnn_ligru_ctc/model.py b/egs/timit/ASR/tdnn_ligru_ctc/model.py
index 9a594a969..4d2199ace 100644
--- a/egs/timit/ASR/tdnn_ligru_ctc/model.py
+++ b/egs/timit/ASR/tdnn_ligru_ctc/model.py
@@ -16,11 +16,11 @@
 # limitations under the License.
 
 
-from typing import Optional
-
 import torch
 import torch.nn as nn
+
 from torch import Tensor
+from typing import Optional
 
 
 class TdnnLiGRU(nn.Module):
@@ -261,7 +261,9 @@ class LiGRU(torch.nn.Module):
         h = []
         if hx is not None:
             if self.bidirectional:
-                hx = hx.reshape(self.num_layers, self.batch_size * 2, self.hidden_size)
+                hx = hx.reshape(
+                    self.num_layers, self.batch_size * 2, self.hidden_size
+                )
         # Processing the different layers
         for i, ligru_lay in enumerate(self.rnn):
             if hx is not None:
@@ -443,7 +445,9 @@ class LiGRU_Layer(torch.nn.Module):
             if self.drop_mask_cnt + self.batch_size > self.N_drop_masks:
                 self.drop_mask_cnt = 0
                 self.drop_masks = self.drop(
-                    torch.ones(self.N_drop_masks, self.hidden_size, device=w.device)
+                    torch.ones(
+                        self.N_drop_masks, self.hidden_size, device=w.device
+                    )
                 ).data
 
             # Sampling the mask
diff --git a/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py b/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py
index da669bc39..7da285944 100644
--- a/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py
+++ b/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py
@@ -29,7 +29,11 @@ import torchaudio
 from model import TdnnLiGRU
 from torch.nn.utils.rnn import pad_sequence
 
-from icefall.decode import get_lattice, one_best_decoding, rescore_with_whole_lattice
+from icefall.decode import (
+    get_lattice,
+    one_best_decoding,
+    rescore_with_whole_lattice,
+)
 from icefall.utils import AttributeDict, get_texts
 
 
@@ -42,11 +46,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -56,7 +58,9 @@ def get_parser():
         help="Path to words.txt",
     )
 
-    parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.")
+    parser.add_argument(
+        "--HLG", type=str, required=True, help="Path to HLG.pt."
+    )
 
     parser.add_argument(
         "--method",
@@ -99,12 +103,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     return parser
@@ -142,9 +144,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -212,7 +215,9 @@ def main():
     logging.info("Decoding started")
     features = fbank(waves)
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
     features = features.permute(0, 2, 1)  # now features is (N, C, T)
 
     with torch.no_grad():
@@ -264,7 +269,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/timit/ASR/tdnn_ligru_ctc/train.py b/egs/timit/ASR/tdnn_ligru_ctc/train.py
index 48b7feda0..452c2a7cb 100644
--- a/egs/timit/ASR/tdnn_ligru_ctc/train.py
+++ b/egs/timit/ASR/tdnn_ligru_ctc/train.py
@@ -449,7 +449,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             valid_info = compute_validation_loss(
diff --git a/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py
index d957c22e1..1554e987f 100644
--- a/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -63,12 +63,10 @@ class TimitAsrDataModule(DataModule):
         super().add_arguments(parser)
         group = parser.add_argument_group(
             title="ASR data related options",
-            description=(
-                "These options are used for the preparation of "
-                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-                "effective batch sizes, sampling strategies, applied data "
-                "augmentations, etc."
-            ),
+            description="These options are used for the preparation of "
+            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+            "effective batch sizes, sampling strategies, applied data "
+            "augmentations, etc.",
         )
         group.add_argument(
             "--feature-dir",
@@ -80,91 +78,75 @@ class TimitAsrDataModule(DataModule):
             "--max-duration",
             type=int,
             default=200.0,
-            help=(
-                "Maximum pooled recordings duration (seconds) in a "
-                "single batch. You can reduce it if it causes CUDA OOM."
-            ),
+            help="Maximum pooled recordings duration (seconds) in a "
+            "single batch. You can reduce it if it causes CUDA OOM.",
         )
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, the batches will come from buckets of "
-                "similar duration (saves padding frames)."
-            ),
+            help="When enabled, the batches will come from buckets of "
+            "similar duration (saves padding frames).",
         )
         group.add_argument(
             "--num-buckets",
             type=int,
             default=30,
-            help=(
-                "The number of buckets for the DynamicBucketingSampler"
-                "(you might want to increase it for larger datasets)."
-            ),
+            help="The number of buckets for the DynamicBucketingSampler"
+            "(you might want to increase it for larger datasets).",
         )
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, utterances (cuts) will be concatenated "
-                "to minimize the amount of padding."
-            ),
+            help="When enabled, utterances (cuts) will be concatenated "
+            "to minimize the amount of padding.",
         )
         group.add_argument(
             "--duration-factor",
             type=float,
             default=1.0,
-            help=(
-                "Determines the maximum duration of a concatenated cut "
-                "relative to the duration of the longest cut in a batch."
-            ),
+            help="Determines the maximum duration of a concatenated cut "
+            "relative to the duration of the longest cut in a batch.",
         )
         group.add_argument(
             "--gap",
             type=float,
             default=1.0,
-            help=(
-                "The amount of padding (in seconds) inserted between "
-                "concatenated cuts. This padding is filled with noise when "
-                "noise augmentation is used."
-            ),
+            help="The amount of padding (in seconds) inserted between "
+            "concatenated cuts. This padding is filled with noise when "
+            "noise augmentation is used.",
         )
         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, use on-the-fly cut mixing and feature "
-                "extraction. Will drop existing precomputed feature manifests "
-                "if available."
-            ),
+            help="When enabled, use on-the-fly cut mixing and feature "
+            "extraction. Will drop existing precomputed feature manifests "
+            "if available.",
         )
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled (=default), the examples will be shuffled for each epoch."
-            ),
+            help="When enabled (=default), the examples will be "
+            "shuffled for each epoch.",
         )
         group.add_argument(
             "--return-cuts",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, each batch will have the "
-                "field: batch['supervisions']['cut'] with the cuts that "
-                "were used to construct it."
-            ),
+            help="When enabled, each batch will have the "
+            "field: batch['supervisions']['cut'] with the cuts that "
+            "were used to construct it.",
         )
 
         group.add_argument(
             "--num-workers",
             type=int,
             default=2,
-            help="The number of training dataloader workers that collect the batches.",
+            help="The number of training dataloader workers that "
+            "collect the batches.",
         )
 
     def train_dataloaders(self) -> DataLoader:
@@ -172,13 +154,15 @@ class TimitAsrDataModule(DataModule):
         cuts_train = self.train_cuts()
 
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(self.args.feature_dir / "musan_cuts.jsonl.gz")
+        cuts_musan = load_manifest(
+            self.args.feature_dir / "musan_cuts.jsonl.gz"
+        )
 
         logging.info("About to create train dataset")
         transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))]
         if self.args.concatenate_cuts:
             logging.info(
-                "Using cut concatenation with duration factor "
+                f"Using cut concatenation with duration factor "
                 f"{self.args.duration_factor} and gap {self.args.gap}."
             )
             # Cut concatenation should be the first transform in the list,
@@ -194,9 +178,9 @@ class TimitAsrDataModule(DataModule):
         # In different Lhotse's versions, the default of num_frame_masks is
         # different.
         num_frame_masks = 10
-        num_frame_masks_parameter = inspect.signature(SpecAugment.__init__).parameters[
-            "num_frame_masks"
-        ]
+        num_frame_masks_parameter = inspect.signature(
+            SpecAugment.__init__
+        ).parameters["num_frame_masks"]
         if num_frame_masks_parameter.default == 1:
             num_frame_masks = 2
         logging.info(f"Num frame mask: {num_frame_masks}")
@@ -228,7 +212,9 @@ class TimitAsrDataModule(DataModule):
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -277,7 +263,9 @@ class TimitAsrDataModule(DataModule):
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 return_cuts=self.args.return_cuts,
             )
         else:
@@ -311,14 +299,20 @@ class TimitAsrDataModule(DataModule):
         for cuts_test in cuts:
             logging.debug("About to create test dataset")
             test = K2SpeechRecognitionDataset(
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                )
                 if self.args.on_the_fly_feats
                 else PrecomputedFeatures(),
                 return_cuts=self.args.return_cuts,
             )
-            sampler = SingleCutSampler(cuts_test, max_duration=self.args.max_duration)
+            sampler = SingleCutSampler(
+                cuts_test, max_duration=self.args.max_duration
+            )
             logging.debug("About to create test dataloader")
-            test_dl = DataLoader(test, batch_size=None, sampler=sampler, num_workers=1)
+            test_dl = DataLoader(
+                test, batch_size=None, sampler=sampler, num_workers=1
+            )
             test_loaders.append(test_dl)
 
         if is_list:
diff --git a/egs/timit/ASR/tdnn_lstm_ctc/decode.py b/egs/timit/ASR/tdnn_lstm_ctc/decode.py
index 319ee5515..5e7300cf2 100644
--- a/egs/timit/ASR/tdnn_lstm_ctc/decode.py
+++ b/egs/timit/ASR/tdnn_lstm_ctc/decode.py
@@ -56,19 +56,16 @@ def get_parser():
         "--epoch",
         type=int,
         default=25,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=5,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
     parser.add_argument(
         "--method",
@@ -338,7 +335,9 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -400,7 +399,9 @@ def main():
 
     logging.info(f"device: {device}")
 
-    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
+    HLG = k2.Fsa.from_dict(
+        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
+    )
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
 
@@ -460,7 +461,9 @@ def main():
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
+        torch.save(
+            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
+        )
         return
 
     model.to(device)
@@ -480,7 +483,9 @@ def main():
         G=G,
     )
 
-    save_results(params=params, test_set_name=test_set, results_dict=results_dict)
+    save_results(
+        params=params, test_set_name=test_set, results_dict=results_dict
+    )
 
     logging.info("Done!")
 
diff --git a/egs/timit/ASR/tdnn_lstm_ctc/model.py b/egs/timit/ASR/tdnn_lstm_ctc/model.py
index e211ad80d..51edb97e2 100644
--- a/egs/timit/ASR/tdnn_lstm_ctc/model.py
+++ b/egs/timit/ASR/tdnn_lstm_ctc/model.py
@@ -74,7 +74,10 @@ class TdnnLstm(nn.Module):
             nn.BatchNorm1d(num_features=512, affine=False),
         )
         self.lstms = nn.ModuleList(
-            [nn.LSTM(input_size=512, hidden_size=512, num_layers=1) for _ in range(4)]
+            [
+                nn.LSTM(input_size=512, hidden_size=512, num_layers=1)
+                for _ in range(4)
+            ]
         )
         self.lstm_bnorms = nn.ModuleList(
             [nn.BatchNorm1d(num_features=512, affine=False) for _ in range(5)]
diff --git a/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py b/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py
index 0c72c973b..5f478da1c 100644
--- a/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py
+++ b/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py
@@ -29,7 +29,11 @@ import torchaudio
 from model import TdnnLstm
 from torch.nn.utils.rnn import pad_sequence
 
-from icefall.decode import get_lattice, one_best_decoding, rescore_with_whole_lattice
+from icefall.decode import (
+    get_lattice,
+    one_best_decoding,
+    rescore_with_whole_lattice,
+)
 from icefall.utils import AttributeDict, get_texts
 
 
@@ -42,11 +46,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -56,7 +58,9 @@ def get_parser():
         help="Path to words.txt",
     )
 
-    parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.")
+    parser.add_argument(
+        "--HLG", type=str, required=True, help="Path to HLG.pt."
+    )
 
     parser.add_argument(
         "--method",
@@ -99,12 +103,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     return parser
@@ -142,9 +144,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -212,7 +215,9 @@ def main():
     logging.info("Decoding started")
     features = fbank(waves)
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
     features = features.permute(0, 2, 1)  # now features is (N, C, T)
 
     with torch.no_grad():
@@ -264,7 +269,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/timit/ASR/tdnn_lstm_ctc/train.py b/egs/timit/ASR/tdnn_lstm_ctc/train.py
index be1ecffaa..849256b98 100644
--- a/egs/timit/ASR/tdnn_lstm_ctc/train.py
+++ b/egs/timit/ASR/tdnn_lstm_ctc/train.py
@@ -449,7 +449,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             valid_info = compute_validation_loss(
diff --git a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py
index bd73e520e..8a9f6ed30 100755
--- a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py
+++ b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py
@@ -20,7 +20,12 @@ import logging
 from pathlib import Path
 
 import torch
-from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, LilcomHdf5Writer
+from lhotse import (
+    CutSet,
+    KaldifeatFbank,
+    KaldifeatFbankConfig,
+    LilcomHdf5Writer,
+)
 
 # Torch's multithreaded behavior needs to be disabled or
 # it wastes a lot of CPU and slow things down.
@@ -78,7 +83,9 @@ def compute_fbank_wenetspeech_dev_test():
 
 
 def main():
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
     logging.basicConfig(format=formatter, level=logging.INFO)
 
     compute_fbank_wenetspeech_dev_test()
diff --git a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py
index c228597b8..a882b6113 100755
--- a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py
+++ b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py
@@ -62,10 +62,8 @@ def get_parser():
         "--batch-duration",
         type=float,
         default=600.0,
-        help=(
-            "The maximum number of audio seconds in a batch."
-            "Determines batch size dynamically."
-        ),
+        help="The maximum number of audio seconds in a batch."
+        "Determines batch size dynamically.",
     )
 
     parser.add_argument(
@@ -154,7 +152,9 @@ def main():
     date_time = now.strftime("%Y-%m-%d-%H-%M-%S")
 
     log_filename = "log-compute_fbank_wenetspeech_splits"
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
     log_filename = f"{log_filename}-{date_time}"
 
     logging.basicConfig(
diff --git a/egs/wenetspeech/ASR/local/prepare_char.py b/egs/wenetspeech/ASR/local/prepare_char.py
index d8622842f..8bc073c75 100755
--- a/egs/wenetspeech/ASR/local/prepare_char.py
+++ b/egs/wenetspeech/ASR/local/prepare_char.py
@@ -83,7 +83,9 @@ def lexicon_to_fst_no_sil(
         cur_state = loop_state
 
         word = word2id[word]
-        pieces = [token2id[i] if i in token2id else token2id[""] for i in pieces]
+        pieces = [
+            token2id[i] if i in token2id else token2id[""] for i in pieces
+        ]
 
         for i in range(len(pieces) - 1):
             w = word if i == 0 else eps
@@ -136,7 +138,9 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
     return False
 
 
-def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon:
+def generate_lexicon(
+    token_sym_table: Dict[str, int], words: List[str]
+) -> Lexicon:
     """Generate a lexicon from a word list and token_sym_table.
     Args:
       token_sym_table:
diff --git a/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py b/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py
index 93ce750f8..817969c47 100755
--- a/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py
+++ b/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py
@@ -115,7 +115,11 @@ def preprocess_wenet_speech():
                 f"Speed perturb for {partition} with factors 0.9 and 1.1 "
                 "(Perturbing may take 8 minutes and saving may take 20 minutes)"
             )
-            cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+            cut_set = (
+                cut_set
+                + cut_set.perturb_speed(0.9)
+                + cut_set.perturb_speed(1.1)
+            )
         logging.info(f"Saving to {raw_cuts_path}")
         cut_set.to_file(raw_cuts_path)
 
diff --git a/egs/wenetspeech/ASR/local/text2token.py b/egs/wenetspeech/ASR/local/text2token.py
index e121d842c..1c463cf1c 100755
--- a/egs/wenetspeech/ASR/local/text2token.py
+++ b/egs/wenetspeech/ASR/local/text2token.py
@@ -50,15 +50,15 @@ def get_parser():
         "-n",
         default=1,
         type=int,
-        help=(
-            "number of characters to split, i.e.,                         aabb -> a a b"
-            " b with -n 1 and aa bb with -n 2"
-        ),
+        help="number of characters to split, i.e., \
+                        aabb -> a a b b with -n 1 and aa bb with -n 2",
     )
     parser.add_argument(
         "--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
     )
-    parser.add_argument("--space", default="", type=str, help="space symbol")
+    parser.add_argument(
+        "--space", default="", type=str, help="space symbol"
+    )
     parser.add_argument(
         "--non-lang-syms",
         "-l",
@@ -66,7 +66,9 @@ def get_parser():
         type=str,
         help="list of non-linguistic symobles, e.g.,  etc.",
     )
-    parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
+    parser.add_argument(
+        "text", type=str, default=False, nargs="?", help="input text"
+    )
     parser.add_argument(
         "--trans_type",
         "-t",
@@ -106,7 +108,8 @@ def token2id(
             if token_type == "lazy_pinyin":
                 text = lazy_pinyin(chars_list)
                 sub_ids = [
-                    token_table[txt] if txt in token_table else oov_id for txt in text
+                    token_table[txt] if txt in token_table else oov_id
+                    for txt in text
                 ]
                 ids.append(sub_ids)
             else:  # token_type = "pinyin"
@@ -132,7 +135,9 @@ def main():
     if args.text:
         f = codecs.open(args.text, encoding="utf-8")
     else:
-        f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
+        f = codecs.getreader("utf-8")(
+            sys.stdin if is_python2 else sys.stdin.buffer
+        )
 
     sys.stdout = codecs.getwriter("utf-8")(
         sys.stdout if is_python2 else sys.stdout.buffer
diff --git a/egs/wenetspeech/ASR/prepare.sh b/egs/wenetspeech/ASR/prepare.sh
index da7d7e061..755fbb2d7 100755
--- a/egs/wenetspeech/ASR/prepare.sh
+++ b/egs/wenetspeech/ASR/prepare.sh
@@ -190,7 +190,7 @@ if [ $stage -le 15 ] && [ $stop_stage -ge 15 ]; then
   mkdir -p $lang_char_dir
 
   if ! which jq; then
-      echo "This script is intended to be used with jq but you have not installed jq
+      echo "This script is intended to be used with jq but you have not installed jq 
       Note: in Linux, you can install jq with the following command:
       1. wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64
       2. chmod +x ./jq
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
index bd92ac115..10c953e3b 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
@@ -81,12 +81,10 @@ class WenetSpeechAsrDataModule:
     def add_arguments(cls, parser: argparse.ArgumentParser):
         group = parser.add_argument_group(
             title="ASR data related options",
-            description=(
-                "These options are used for the preparation of "
-                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-                "effective batch sizes, sampling strategies, applied data "
-                "augmentations, etc."
-            ),
+            description="These options are used for the preparation of "
+            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+            "effective batch sizes, sampling strategies, applied data "
+            "augmentations, etc.",
         )
         group.add_argument(
             "--manifest-dir",
@@ -98,91 +96,75 @@ class WenetSpeechAsrDataModule:
             "--max-duration",
             type=int,
             default=200.0,
-            help=(
-                "Maximum pooled recordings duration (seconds) in a "
-                "single batch. You can reduce it if it causes CUDA OOM."
-            ),
+            help="Maximum pooled recordings duration (seconds) in a "
+            "single batch. You can reduce it if it causes CUDA OOM.",
         )
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, the batches will come from buckets of "
-                "similar duration (saves padding frames)."
-            ),
+            help="When enabled, the batches will come from buckets of "
+            "similar duration (saves padding frames).",
         )
         group.add_argument(
             "--num-buckets",
             type=int,
             default=300,
-            help=(
-                "The number of buckets for the DynamicBucketingSampler"
-                "(you might want to increase it for larger datasets)."
-            ),
+            help="The number of buckets for the DynamicBucketingSampler"
+            "(you might want to increase it for larger datasets).",
         )
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, utterances (cuts) will be concatenated "
-                "to minimize the amount of padding."
-            ),
+            help="When enabled, utterances (cuts) will be concatenated "
+            "to minimize the amount of padding.",
         )
         group.add_argument(
             "--duration-factor",
             type=float,
             default=1.0,
-            help=(
-                "Determines the maximum duration of a concatenated cut "
-                "relative to the duration of the longest cut in a batch."
-            ),
+            help="Determines the maximum duration of a concatenated cut "
+            "relative to the duration of the longest cut in a batch.",
         )
         group.add_argument(
             "--gap",
             type=float,
             default=1.0,
-            help=(
-                "The amount of padding (in seconds) inserted between "
-                "concatenated cuts. This padding is filled with noise when "
-                "noise augmentation is used."
-            ),
+            help="The amount of padding (in seconds) inserted between "
+            "concatenated cuts. This padding is filled with noise when "
+            "noise augmentation is used.",
         )
         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, use on-the-fly cut mixing and feature "
-                "extraction. Will drop existing precomputed feature manifests "
-                "if available."
-            ),
+            help="When enabled, use on-the-fly cut mixing and feature "
+            "extraction. Will drop existing precomputed feature manifests "
+            "if available.",
         )
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled (=default), the examples will be shuffled for each epoch."
-            ),
+            help="When enabled (=default), the examples will be "
+            "shuffled for each epoch.",
         )
         group.add_argument(
             "--return-cuts",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, each batch will have the "
-                "field: batch['supervisions']['cut'] with the cuts that "
-                "were used to construct it."
-            ),
+            help="When enabled, each batch will have the "
+            "field: batch['supervisions']['cut'] with the cuts that "
+            "were used to construct it.",
         )
 
         group.add_argument(
             "--num-workers",
             type=int,
             default=2,
-            help="The number of training dataloader workers that collect the batches.",
+            help="The number of training dataloader workers that "
+            "collect the batches.",
         )
 
         group.add_argument(
@@ -196,22 +178,18 @@ class WenetSpeechAsrDataModule:
             "--spec-aug-time-warp-factor",
             type=int,
             default=80,
-            help=(
-                "Used only when --enable-spec-aug is True. "
-                "It specifies the factor for time warping in SpecAugment. "
-                "Larger values mean more warping. "
-                "A value less than 1 means to disable time warp."
-            ),
+            help="Used only when --enable-spec-aug is True. "
+            "It specifies the factor for time warping in SpecAugment. "
+            "Larger values mean more warping. "
+            "A value less than 1 means to disable time warp.",
         )
 
         group.add_argument(
             "--enable-musan",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, select noise from MUSAN and mix it"
-                "with training dataset. "
-            ),
+            help="When enabled, select noise from MUSAN and mix it"
+            "with training dataset. ",
         )
 
         group.add_argument(
@@ -234,20 +212,24 @@ class WenetSpeechAsrDataModule:
             The state dict for the training sampler.
         """
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
+        cuts_musan = load_manifest(
+            self.args.manifest_dir / "musan_cuts.jsonl.gz"
+        )
 
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+                CutMix(
+                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
+                )
             )
         else:
             logging.info("Disable MUSAN")
 
         if self.args.concatenate_cuts:
             logging.info(
-                "Using cut concatenation with duration factor "
+                f"Using cut concatenation with duration factor "
                 f"{self.args.duration_factor} and gap {self.args.gap}."
             )
             # Cut concatenation should be the first transform in the list,
@@ -262,7 +244,9 @@ class WenetSpeechAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+            logging.info(
+                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
+            )
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -305,7 +289,9 @@ class WenetSpeechAsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -362,7 +348,9 @@ class WenetSpeechAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 return_cuts=self.args.return_cuts,
             )
         else:
@@ -426,7 +414,8 @@ class WenetSpeechAsrDataModule:
     def train_cuts(self) -> CutSet:
         logging.info("About to get train cuts")
         cuts_train = load_manifest_lazy(
-            self.args.manifest_dir / f"cuts_{self.args.training_subset}.jsonl.gz"
+            self.args.manifest_dir
+            / f"cuts_{self.args.training_subset}.jsonl.gz"
         )
         return cuts_train
 
@@ -438,9 +427,13 @@ class WenetSpeechAsrDataModule:
     @lru_cache()
     def test_net_cuts(self) -> List[CutSet]:
         logging.info("About to get TEST_NET cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_NET.jsonl.gz")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "cuts_TEST_NET.jsonl.gz"
+        )
 
     @lru_cache()
     def test_meeting_cuts(self) -> List[CutSet]:
         logging.info("About to get TEST_MEETING cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_MEETING.jsonl.gz")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "cuts_TEST_MEETING.jsonl.gz"
+        )
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py
index 6e856248c..f0c9bebec 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py
@@ -114,7 +114,11 @@ from beam_search import (
 from train import get_params, get_transducer_model
 
 from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
-from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
+from icefall.checkpoint import (
+    average_checkpoints,
+    find_checkpoints,
+    load_checkpoint,
+)
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
@@ -133,30 +137,25 @@ def get_parser():
         "--epoch",
         type=int,
         default=28,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--batch",
         type=int,
         default=None,
-        help=(
-            "It specifies the batch checkpoint to use for decoding."
-            "Note: Epoch counts from 0."
-        ),
+        help="It specifies the batch checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=15,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -253,7 +252,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -328,7 +328,9 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
@@ -387,7 +389,10 @@ def decode_one_batch(
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -433,7 +438,11 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}"
+            ): hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -506,7 +515,9 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -539,7 +550,8 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -651,7 +663,9 @@ def main():
             )
             decoding_graph.scores *= params.ngram_lm_scale
         else:
-            decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+            decoding_graph = k2.trivial_graph(
+                params.vocab_size - 1, device=device
+            )
     else:
         decoding_graph = None
 
@@ -702,7 +716,8 @@ def main():
         )
 
     dev_shards = [
-        str(path) for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
+        str(path)
+        for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
     ]
     cuts_dev_webdataset = CutSet.from_webdataset(
         dev_shards,
@@ -712,7 +727,8 @@ def main():
     )
 
     test_net_shards = [
-        str(path) for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
+        str(path)
+        for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
     ]
     cuts_test_net_webdataset = CutSet.from_webdataset(
         test_net_shards,
@@ -723,7 +739,9 @@ def main():
 
     test_meeting_shards = [
         str(path)
-        for path in sorted(glob.glob(os.path.join(test_meeting, "shared-*.tar")))
+        for path in sorted(
+            glob.glob(os.path.join(test_meeting, "shared-*.tar"))
+        )
     ]
     cuts_test_meeting_webdataset = CutSet.from_webdataset(
         test_meeting_shards,
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py
index c742593df..933642a0f 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py
@@ -126,20 +126,17 @@ def get_parser():
         "--epoch",
         type=int,
         default=28,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=15,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -208,7 +205,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     return parser
@@ -470,9 +468,13 @@ def export_joiner_model_onnx(
 
         - projected_decoder_out: a tensor of shape (N, joiner_dim)
     """
-    encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx")
+    encoder_proj_filename = str(joiner_filename).replace(
+        ".onnx", "_encoder_proj.onnx"
+    )
 
-    decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx")
+    decoder_proj_filename = str(joiner_filename).replace(
+        ".onnx", "_decoder_proj.onnx"
+    )
 
     encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
     decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
@@ -643,7 +645,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py
index ed9020c67..e5cc47bfe 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py
@@ -107,12 +107,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     parser.add_argument(
@@ -147,9 +145,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -332,7 +331,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py
index a46ff5a07..c396c50ef 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py
@@ -219,7 +219,9 @@ def test_joiner(
         )
 
         # Now test encoder_proj
-        joiner_encoder_proj_inputs = {encoder_proj_input_name: encoder_out.numpy()}
+        joiner_encoder_proj_inputs = {
+            encoder_proj_input_name: encoder_out.numpy()
+        }
         joiner_encoder_proj_out = joiner_encoder_proj_session.run(
             [encoder_proj_output_name], joiner_encoder_proj_inputs
         )[0]
@@ -228,10 +230,16 @@ def test_joiner(
         torch_joiner_encoder_proj_out = model.joiner.encoder_proj(encoder_out)
         assert torch.allclose(
             joiner_encoder_proj_out, torch_joiner_encoder_proj_out, atol=1e-5
-        ), ((joiner_encoder_proj_out - torch_joiner_encoder_proj_out).abs().max())
+        ), (
+            (joiner_encoder_proj_out - torch_joiner_encoder_proj_out)
+            .abs()
+            .max()
+        )
 
         # Now test decoder_proj
-        joiner_decoder_proj_inputs = {decoder_proj_input_name: decoder_out.numpy()}
+        joiner_decoder_proj_inputs = {
+            decoder_proj_input_name: decoder_out.numpy()
+        }
         joiner_decoder_proj_out = joiner_decoder_proj_session.run(
             [decoder_proj_output_name], joiner_decoder_proj_inputs
         )[0]
@@ -240,7 +248,11 @@ def test_joiner(
         torch_joiner_decoder_proj_out = model.joiner.decoder_proj(decoder_out)
         assert torch.allclose(
             joiner_decoder_proj_out, torch_joiner_decoder_proj_out, atol=1e-5
-        ), ((joiner_decoder_proj_out - torch_joiner_decoder_proj_out).abs().max())
+        ), (
+            (joiner_decoder_proj_out - torch_joiner_decoder_proj_out)
+            .abs()
+            .max()
+        )
 
 
 @torch.no_grad()
@@ -292,7 +304,9 @@ def main():
 
 if __name__ == "__main__":
     torch.manual_seed(20220727)
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py
index f7d962008..3770fbbb4 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py
@@ -111,12 +111,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     parser.add_argument(
@@ -151,9 +149,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -201,7 +200,11 @@ def greedy_search(
 
     projected_encoder_out = joiner_encoder_proj.run(
         [joiner_encoder_proj.get_outputs()[0].name],
-        {joiner_encoder_proj.get_inputs()[0].name: packed_encoder_out.data.numpy()},
+        {
+            joiner_encoder_proj.get_inputs()[
+                0
+            ].name: packed_encoder_out.data.numpy()
+        },
     )[0]
 
     blank_id = 0  # hard-code to 0
@@ -386,7 +389,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py
index 26c9c2b8c..9a549efd9 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py
@@ -80,11 +80,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -109,12 +107,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     parser.add_argument(
@@ -162,7 +158,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -192,9 +189,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -255,7 +253,9 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -280,7 +280,10 @@ def main():
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -332,7 +335,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
index e020c4c05..d3cc7c9c9 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
@@ -115,7 +115,9 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+LRSchedulerType = Union[
+    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
+]
 
 
 def get_parser():
@@ -217,45 +219,42 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
         "--prune-range",
         type=int,
         default=5,
-        help=(
-            "The prune range for rnnt loss, it means how many symbols(context)"
-            "we are using to compute the loss"
-        ),
+        help="The prune range for rnnt loss, it means how many symbols(context)"
+        "we are using to compute the loss",
     )
 
     parser.add_argument(
         "--lm-scale",
         type=float,
         default=0.25,
-        help=(
-            "The scale to smooth the loss with lm (output of prediction network) part."
-        ),
+        help="The scale to smooth the loss with lm "
+        "(output of prediction network) part.",
     )
 
     parser.add_argument(
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)part.",
+        help="The scale to smooth the loss with am (output of encoder network)"
+        "part.",
     )
 
     parser.add_argument(
         "--simple-loss-scale",
         type=float,
         default=0.5,
-        help=(
-            "To get pruning ranges, we will calculate a simple version"
-            "loss(joiner is just addition), this simple loss also uses for"
-            "training (as a regularization item). We will scale the simple loss"
-            "with this parameter before adding to the final loss."
-        ),
+        help="To get pruning ranges, we will calculate a simple version"
+        "loss(joiner is just addition), this simple loss also uses for"
+        "training (as a regularization item). We will scale the simple loss"
+        "with this parameter before adding to the final loss.",
     )
 
     parser.add_argument(
@@ -591,15 +590,22 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+            0.0
+            if warmup < 1.0
+            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+        )
+        loss = (
+            params.simple_loss_scale * simple_loss
+            + pruned_loss_scale * pruned_loss
         )
-        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -756,7 +762,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -856,7 +864,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            2 ** 22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py
index 1023c931a..dd27c17f0 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py
@@ -210,7 +210,10 @@ class Conformer(EncoderInterface):
           (num_encoder_layers, cnn_module_kernel - 1, encoder_dim).
           NOTE: the returned tensors are on the given device.
         """
-        if len(self._init_state) == 2 and self._init_state[0].size(1) == left_context:
+        if (
+            len(self._init_state) == 2
+            and self._init_state[0].size(1) == left_context
+        ):
             # Note: It is OK to share the init state as it is
             # not going to be modified by the model
             return self._init_state
@@ -430,7 +433,9 @@ class ConformerEncoderLayer(nn.Module):
 
         self.d_model = d_model
 
-        self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
+        self.self_attn = RelPositionMultiheadAttention(
+            d_model, nhead, dropout=0.0
+        )
 
         self.feed_forward = nn.Sequential(
             ScaledLinear(d_model, dim_feedforward),
@@ -448,7 +453,9 @@ class ConformerEncoderLayer(nn.Module):
             ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
         )
 
-        self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal)
+        self.conv_module = ConvolutionModule(
+            d_model, cnn_module_kernel, causal=causal
+        )
 
         self.norm_final = BasicNorm(d_model)
 
@@ -513,7 +520,9 @@ class ConformerEncoderLayer(nn.Module):
         src = src + self.dropout(src_att)
 
         # convolution module
-        conv, _ = self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
+        conv, _ = self.conv_module(
+            src, src_key_padding_mask=src_key_padding_mask
+        )
         src = src + self.dropout(conv)
 
         # feed forward module
@@ -757,7 +766,9 @@ class RelPositionalEncoding(torch.nn.Module):
         max_len: Maximum input length.
     """
 
-    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
+    def __init__(
+        self, d_model: int, dropout_rate: float, max_len: int = 5000
+    ) -> None:
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
@@ -773,7 +784,9 @@ class RelPositionalEncoding(torch.nn.Module):
             # the length of self.pe is 2 * input_len - 1
             if self.pe.size(1) >= x_size_1 * 2 - 1:
                 # Note: TorchScript doesn't implement operator== for torch.Device
-                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(
+                    x.device
+                ):
                     self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                 return
         # Suppose `i` means to the position of query vector and `j` means the
@@ -1060,9 +1073,9 @@ class RelPositionMultiheadAttention(nn.Module):
 
         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
-                3, dim=-1
-            )
+            q, k, v = nn.functional.linear(
+                query, in_proj_weight, in_proj_bias
+            ).chunk(3, dim=-1)
 
         elif torch.equal(key, value):
             # encoder-decoder attention
@@ -1131,25 +1144,33 @@ class RelPositionMultiheadAttention(nn.Module):
             if attn_mask.dim() == 2:
                 attn_mask = attn_mask.unsqueeze(0)
                 if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
-                    raise RuntimeError("The size of the 2D attn_mask is not correct.")
+                    raise RuntimeError(
+                        "The size of the 2D attn_mask is not correct."
+                    )
             elif attn_mask.dim() == 3:
                 if list(attn_mask.size()) != [
                     bsz * num_heads,
                     query.size(0),
                     key.size(0),
                 ]:
-                    raise RuntimeError("The size of the 3D attn_mask is not correct.")
+                    raise RuntimeError(
+                        "The size of the 3D attn_mask is not correct."
+                    )
             else:
                 raise RuntimeError(
-                    "attn_mask's dimension {} is not supported".format(attn_mask.dim())
+                    "attn_mask's dimension {} is not supported".format(
+                        attn_mask.dim()
+                    )
                 )
             # attn_mask's dim is 3 now.
 
         # convert ByteTensor key_padding_mask to bool
-        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
+        if (
+            key_padding_mask is not None
+            and key_padding_mask.dtype == torch.uint8
+        ):
             warnings.warn(
-                "Byte tensor for key_padding_mask is deprecated. Use bool tensor"
-                " instead."
+                "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
             )
             key_padding_mask = key_padding_mask.to(torch.bool)
 
@@ -1187,15 +1208,23 @@ class RelPositionMultiheadAttention(nn.Module):
         # first compute matrix a and matrix c
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
-        matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(
+            q_with_bias_u, k
+        )  # (batch, head, time1, time2)
 
         # compute matrix b and matrix d
-        matrix_bd = torch.matmul(q_with_bias_v, p)  # (batch, head, time1, 2*time1-1)
+        matrix_bd = torch.matmul(
+            q_with_bias_v, p
+        )  # (batch, head, time1, 2*time1-1)
         matrix_bd = self.rel_shift(matrix_bd, left_context)
 
-        attn_output_weights = matrix_ac + matrix_bd  # (batch, head, time1, time2)
+        attn_output_weights = (
+            matrix_ac + matrix_bd
+        )  # (batch, head, time1, time2)
 
-        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
+        attn_output_weights = attn_output_weights.view(
+            bsz * num_heads, tgt_len, -1
+        )
 
         assert list(attn_output_weights.size()) == [
             bsz * num_heads,
@@ -1236,17 +1265,21 @@ class RelPositionMultiheadAttention(nn.Module):
         ):
             if attn_mask.size(0) != 1:
                 attn_mask = attn_mask.view(bsz, num_heads, tgt_len, src_len)
-                combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2)
-            else:
-                # attn_mask.shape == (1, tgt_len, src_len)
-                combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze(
+                combined_mask = attn_mask | key_padding_mask.unsqueeze(
                     1
                 ).unsqueeze(2)
+            else:
+                # attn_mask.shape == (1, tgt_len, src_len)
+                combined_mask = attn_mask.unsqueeze(
+                    0
+                ) | key_padding_mask.unsqueeze(1).unsqueeze(2)
 
             attn_output_weights = attn_output_weights.view(
                 bsz, num_heads, tgt_len, src_len
             )
-            attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0)
+            attn_output_weights = attn_output_weights.masked_fill(
+                combined_mask, 0.0
+            )
             attn_output_weights = attn_output_weights.view(
                 bsz * num_heads, tgt_len, src_len
             )
@@ -1258,9 +1291,13 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output = torch.bmm(attn_output_weights, v)
         assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
         attn_output = (
-            attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+            attn_output.transpose(0, 1)
+            .contiguous()
+            .view(tgt_len, bsz, embed_dim)
+        )
+        attn_output = nn.functional.linear(
+            attn_output, out_proj_weight, out_proj_bias
         )
-        attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
 
         if need_weights:
             # average attention weights over heads
@@ -1393,12 +1430,16 @@ class ConvolutionModule(nn.Module):
                 # manualy padding self.lorder zeros to the left
                 x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
             else:
-                assert not self.training, "Cache should be None in training time"
+                assert (
+                    not self.training
+                ), "Cache should be None in training time"
                 assert cache.size(0) == self.lorder
                 x = torch.cat([cache.permute(1, 2, 0), x], dim=2)
                 if right_context > 0:
                     cache = x.permute(2, 0, 1)[
-                        -(self.lorder + right_context) : (-right_context),  # noqa
+                        -(self.lorder + right_context) : (  # noqa
+                            -right_context
+                        ),
                         ...,
                     ]
                 else:
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
index 3d66f9dc9..344e31283 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
@@ -160,24 +160,20 @@ def get_parser():
         "--avg",
         type=int,
         default=15,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch' and '--iter'"
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
     )
 
     parser.add_argument(
         "--use-averaged-model",
         type=str2bool,
         default=True,
-        help=(
-            "Whether to load averaged model. Currently it only supports "
-            "using --epoch. If True, it would decode with the averaged model "
-            "over the epoch range from `epoch-avg` (excluded) to `epoch`."
-            "Actually only the models with epoch number of `epoch-avg` and "
-            "`epoch` are loaded for averaging. "
-        ),
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
     )
 
     parser.add_argument(
@@ -248,7 +244,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -345,7 +342,9 @@ def decode_one_batch(
             simulate_streaming=True,
         )
     else:
-        encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+        encoder_out, encoder_out_lens = model.encoder(
+            x=feature, x_lens=feature_lens
+        )
 
     hyps = []
 
@@ -361,7 +360,10 @@ def decode_one_batch(
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -407,7 +409,11 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}"
+            ): hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -478,7 +484,9 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -511,7 +519,8 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -580,12 +589,13 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg
-            ]
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg:
                 raise ValueError(
@@ -608,12 +618,13 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg + 1
-            ]
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg + 1]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg + 1:
                 raise ValueError(
@@ -641,7 +652,7 @@ def main():
             filename_start = f"{params.exp_dir}/epoch-{start}.pt"
             filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
             logging.info(
-                "Calculating the averaged model over epoch range from "
+                f"Calculating the averaged model over epoch range from "
                 f"{start} (excluded) to {params.epoch}"
             )
             model.to(device)
@@ -709,7 +720,8 @@ def main():
         )
 
     dev_shards = [
-        str(path) for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
+        str(path)
+        for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
     ]
     cuts_dev_webdataset = CutSet.from_webdataset(
         dev_shards,
@@ -719,7 +731,8 @@ def main():
     )
 
     test_net_shards = [
-        str(path) for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
+        str(path)
+        for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
     ]
     cuts_test_net_webdataset = CutSet.from_webdataset(
         test_net_shards,
@@ -730,7 +743,9 @@ def main():
 
     test_meeting_shards = [
         str(path)
-        for path in sorted(glob.glob(os.path.join(test_meeting, "shared-*.tar")))
+        for path in sorted(
+            glob.glob(os.path.join(test_meeting, "shared-*.tar"))
+        )
     ]
     cuts_test_meeting_webdataset = CutSet.from_webdataset(
         test_meeting_shards,
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode_stream.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode_stream.py
index e522943c0..386248554 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode_stream.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode_stream.py
@@ -75,7 +75,9 @@ class DecodeStream(object):
         # encoder.streaming_forward
         self.done_frames: int = 0
 
-        self.pad_length = (params.right_context + 2) * params.subsampling_factor + 3
+        self.pad_length = (
+            params.right_context + 2
+        ) * params.subsampling_factor + 3
 
         if params.decoding_method == "greedy_search":
             self.hyp = [params.blank_id] * params.context_size
@@ -89,11 +91,13 @@ class DecodeStream(object):
             )
         elif params.decoding_method == "fast_beam_search":
             # The rnnt_decoding_stream for fast_beam_search.
-            self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream(
-                decoding_graph
+            self.rnnt_decoding_stream: k2.RnntDecodingStream = (
+                k2.RnntDecodingStream(decoding_graph)
             )
         else:
-            raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
+            raise ValueError(
+                f"Unsupported decoding method: {params.decoding_method}"
+            )
 
     @property
     def done(self) -> bool:
@@ -122,10 +126,13 @@ class DecodeStream(object):
         """Consume chunk_size frames of features"""
         chunk_length = chunk_size + self.pad_length
 
-        ret_length = min(self.num_frames - self.num_processed_frames, chunk_length)
+        ret_length = min(
+            self.num_frames - self.num_processed_frames, chunk_length
+        )
 
         ret_features = self.features[
-            self.num_processed_frames : self.num_processed_frames + ret_length  # noqa
+            self.num_processed_frames : self.num_processed_frames  # noqa
+            + ret_length
         ]
 
         self.num_processed_frames += chunk_size
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py
index fb53f70ab..d0a7fd69f 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py
@@ -90,20 +90,17 @@ def get_parser():
         "--epoch",
         type=int,
         default=28,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=15,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -134,7 +131,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     add_model_arguments(parser)
 
@@ -203,7 +201,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py
index 9834189d8..1b064c874 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py
@@ -80,11 +80,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -109,12 +107,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     parser.add_argument(
@@ -161,7 +157,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -192,9 +189,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -255,7 +253,9 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -280,7 +280,10 @@ def main():
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -332,7 +335,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_beam_search.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_beam_search.py
index 810d94135..651aff6c9 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_beam_search.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_beam_search.py
@@ -173,10 +173,14 @@ def modified_beam_search(
         log_probs_shape = k2.ragged.create_ragged_shape2(
             row_splits=row_splits, cached_tot_size=log_probs.numel()
         )
-        ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)
+        ragged_log_probs = k2.RaggedTensor(
+            shape=log_probs_shape, value=log_probs
+        )
 
         for i in range(batch_size):
-            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(num_active_paths)
+            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(
+                num_active_paths
+            )
 
             with warnings.catch_warnings():
                 warnings.simplefilter("ignore")
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py
index 31a7fe605..ff96c6487 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py
@@ -119,24 +119,20 @@ def get_parser():
         "--avg",
         type=int,
         default=15,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch' and '--iter'"
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
     )
 
     parser.add_argument(
         "--use-averaged-model",
         type=str2bool,
         default=True,
-        help=(
-            "Whether to load averaged model. Currently it only supports "
-            "using --epoch. If True, it would decode with the averaged model "
-            "over the epoch range from `epoch-avg` (excluded) to `epoch`."
-            "Actually only the models with epoch number of `epoch-avg` and "
-            "`epoch` are loaded for averaging. "
-        ),
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
     )
 
     parser.add_argument(
@@ -205,7 +201,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -314,7 +311,9 @@ def decode_one_chunk(
     encoder_out = model.joiner.encoder_proj(encoder_out)
 
     if params.decoding_method == "greedy_search":
-        greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams)
+        greedy_search(
+            model=model, encoder_out=encoder_out, streams=decode_streams
+        )
     elif params.decoding_method == "fast_beam_search":
         processed_lens = processed_lens + encoder_out_lens
         fast_beam_search_one_best(
@@ -334,7 +333,9 @@ def decode_one_chunk(
             num_active_paths=params.num_active_paths,
         )
     else:
-        raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
+        raise ValueError(
+            f"Unsupported decoding method: {params.decoding_method}"
+        )
 
     states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)]
 
@@ -388,7 +389,9 @@ def decode_dataset(
     decode_results = []
     # Contain decode streams currently running.
     decode_streams = []
-    initial_states = model.encoder.get_init_state(params.left_context, device=device)
+    initial_states = model.encoder.get_init_state(
+        params.left_context, device=device
+    )
     for num, cut in enumerate(cuts):
         # each utterance has a DecodeStream.
         decode_stream = DecodeStream(
@@ -458,7 +461,9 @@ def decode_dataset(
     elif params.decoding_method == "modified_beam_search":
         key = f"num_active_paths_{params.num_active_paths}"
     else:
-        raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
+        raise ValueError(
+            f"Unsupported decoding method: {params.decoding_method}"
+        )
 
     return {key: decode_results}
 
@@ -494,7 +499,8 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -559,12 +565,13 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg
-            ]
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg:
                 raise ValueError(
@@ -587,12 +594,13 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg + 1
-            ]
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg + 1]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg + 1:
                 raise ValueError(
@@ -620,7 +628,7 @@ def main():
             filename_start = f"{params.exp_dir}/epoch-{start}.pt"
             filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
             logging.info(
-                "Calculating the averaged model over epoch range from "
+                f"Calculating the averaged model over epoch range from "
                 f"{start} (excluded) to {params.epoch}"
             )
             model.to(device)
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py
index 40c9665f7..2052e9da7 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py
@@ -98,7 +98,9 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+LRSchedulerType = Union[
+    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
+]
 
 
 def add_model_arguments(parser: argparse.ArgumentParser):
@@ -258,7 +260,8 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need to be changed.",
+        help="The initial learning rate.  This value should not need "
+        "to be changed.",
     )
 
     parser.add_argument(
@@ -281,45 +284,42 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
         "--prune-range",
         type=int,
         default=5,
-        help=(
-            "The prune range for rnnt loss, it means how many symbols(context)"
-            "we are using to compute the loss"
-        ),
+        help="The prune range for rnnt loss, it means how many symbols(context)"
+        "we are using to compute the loss",
     )
 
     parser.add_argument(
         "--lm-scale",
         type=float,
         default=0.25,
-        help=(
-            "The scale to smooth the loss with lm (output of prediction network) part."
-        ),
+        help="The scale to smooth the loss with lm "
+        "(output of prediction network) part.",
     )
 
     parser.add_argument(
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)part.",
+        help="The scale to smooth the loss with am (output of encoder network)"
+        "part.",
     )
 
     parser.add_argument(
         "--simple-loss-scale",
         type=float,
         default=0.5,
-        help=(
-            "To get pruning ranges, we will calculate a simple version"
-            "loss(joiner is just addition), this simple loss also uses for"
-            "training (as a regularization item). We will scale the simple loss"
-            "with this parameter before adding to the final loss."
-        ),
+        help="To get pruning ranges, we will calculate a simple version"
+        "loss(joiner is just addition), this simple loss also uses for"
+        "training (as a regularization item). We will scale the simple loss"
+        "with this parameter before adding to the final loss.",
     )
 
     parser.add_argument(
@@ -665,7 +665,11 @@ def compute_loss(
      warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
-    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+    device = (
+        model.device
+        if isinstance(model, DDP)
+        else next(model.parameters()).device
+    )
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
@@ -697,16 +701,23 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+            0.0
+            if warmup < 1.0
+            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+        )
+        loss = (
+            params.simple_loss_scale * simple_loss
+            + pruned_loss_scale * pruned_loss
         )
-        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
 
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -830,7 +841,9 @@ def train_one_epoch(
             scaler.update()
             optimizer.zero_grad()
         except:  # noqa
-            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
+            display_and_save_batch(
+                batch, params=params, graph_compiler=graph_compiler
+            )
             raise
 
         if params.print_diagnostics and batch_idx == 5:
@@ -888,7 +901,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -1001,7 +1016,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            2 ** 22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
@@ -1169,7 +1184,9 @@ def scan_pessimistic_batches_for_oom(
                     f"Failing criterion: {criterion} "
                     f"(={crit_values[criterion]}) ..."
                 )
-            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
+            display_and_save_batch(
+                batch, params=params, graph_compiler=graph_compiler
+            )
             raise
 
 
diff --git a/egs/yesno/ASR/local/compile_hlg.py b/egs/yesno/ASR/local/compile_hlg.py
index 7234ca929..f83be05cf 100755
--- a/egs/yesno/ASR/local/compile_hlg.py
+++ b/egs/yesno/ASR/local/compile_hlg.py
@@ -128,7 +128,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/yesno/ASR/local/compute_fbank_yesno.py b/egs/yesno/ASR/local/compute_fbank_yesno.py
index 75d95df68..9a4e8a36f 100755
--- a/egs/yesno/ASR/local/compute_fbank_yesno.py
+++ b/egs/yesno/ASR/local/compute_fbank_yesno.py
@@ -54,7 +54,9 @@ def compute_fbank_yesno():
         dataset_parts,
     )
 
-    extractor = Fbank(FbankConfig(sampling_rate=8000, num_mel_bins=num_mel_bins))
+    extractor = Fbank(
+        FbankConfig(sampling_rate=8000, num_mel_bins=num_mel_bins)
+    )
 
     with get_executor() as ex:  # Initialize the executor only once.
         for partition, m in manifests.items():
@@ -69,7 +71,9 @@ def compute_fbank_yesno():
             )
             if "train" in partition:
                 cut_set = (
-                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+                    cut_set
+                    + cut_set.perturb_speed(0.9)
+                    + cut_set.perturb_speed(1.1)
                 )
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
@@ -83,7 +87,9 @@ def compute_fbank_yesno():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/yesno/ASR/tdnn/asr_datamodule.py b/egs/yesno/ASR/tdnn/asr_datamodule.py
index 21860d2f5..85e5f1358 100644
--- a/egs/yesno/ASR/tdnn/asr_datamodule.py
+++ b/egs/yesno/ASR/tdnn/asr_datamodule.py
@@ -56,12 +56,10 @@ class YesNoAsrDataModule(DataModule):
         super().add_arguments(parser)
         group = parser.add_argument_group(
             title="ASR data related options",
-            description=(
-                "These options are used for the preparation of "
-                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-                "effective batch sizes, sampling strategies, applied data "
-                "augmentations, etc."
-            ),
+            description="These options are used for the preparation of "
+            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+            "effective batch sizes, sampling strategies, applied data "
+            "augmentations, etc.",
         )
         group.add_argument(
             "--feature-dir",
@@ -73,91 +71,75 @@ class YesNoAsrDataModule(DataModule):
             "--max-duration",
             type=int,
             default=30.0,
-            help=(
-                "Maximum pooled recordings duration (seconds) in a "
-                "single batch. You can reduce it if it causes CUDA OOM."
-            ),
+            help="Maximum pooled recordings duration (seconds) in a "
+            "single batch. You can reduce it if it causes CUDA OOM.",
         )
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, the batches will come from buckets of "
-                "similar duration (saves padding frames)."
-            ),
+            help="When enabled, the batches will come from buckets of "
+            "similar duration (saves padding frames).",
         )
         group.add_argument(
             "--num-buckets",
             type=int,
             default=10,
-            help=(
-                "The number of buckets for the DynamicBucketingSampler"
-                "(you might want to increase it for larger datasets)."
-            ),
+            help="The number of buckets for the DynamicBucketingSampler"
+            "(you might want to increase it for larger datasets).",
         )
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, utterances (cuts) will be concatenated "
-                "to minimize the amount of padding."
-            ),
+            help="When enabled, utterances (cuts) will be concatenated "
+            "to minimize the amount of padding.",
         )
         group.add_argument(
             "--duration-factor",
             type=float,
             default=1.0,
-            help=(
-                "Determines the maximum duration of a concatenated cut "
-                "relative to the duration of the longest cut in a batch."
-            ),
+            help="Determines the maximum duration of a concatenated cut "
+            "relative to the duration of the longest cut in a batch.",
         )
         group.add_argument(
             "--gap",
             type=float,
             default=1.0,
-            help=(
-                "The amount of padding (in seconds) inserted between "
-                "concatenated cuts. This padding is filled with noise when "
-                "noise augmentation is used."
-            ),
+            help="The amount of padding (in seconds) inserted between "
+            "concatenated cuts. This padding is filled with noise when "
+            "noise augmentation is used.",
         )
         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, use on-the-fly cut mixing and feature "
-                "extraction. Will drop existing precomputed feature manifests "
-                "if available."
-            ),
+            help="When enabled, use on-the-fly cut mixing and feature "
+            "extraction. Will drop existing precomputed feature manifests "
+            "if available.",
         )
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled (=default), the examples will be shuffled for each epoch."
-            ),
+            help="When enabled (=default), the examples will be "
+            "shuffled for each epoch.",
         )
         group.add_argument(
             "--return-cuts",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, each batch will have the "
-                "field: batch['supervisions']['cut'] with the cuts that "
-                "were used to construct it."
-            ),
+            help="When enabled, each batch will have the "
+            "field: batch['supervisions']['cut'] with the cuts that "
+            "were used to construct it.",
         )
 
         group.add_argument(
             "--num-workers",
             type=int,
             default=2,
-            help="The number of training dataloader workers that collect the batches.",
+            help="The number of training dataloader workers that "
+            "collect the batches.",
         )
 
     def train_dataloaders(self) -> DataLoader:
@@ -168,7 +150,7 @@ class YesNoAsrDataModule(DataModule):
         transforms = []
         if self.args.concatenate_cuts:
             logging.info(
-                "Using cut concatenation with duration factor "
+                f"Using cut concatenation with duration factor "
                 f"{self.args.duration_factor} and gap {self.args.gap}."
             )
             # Cut concatenation should be the first transform in the list,
diff --git a/egs/yesno/ASR/tdnn/decode.py b/egs/yesno/ASR/tdnn/decode.py
index 41afe0404..9d4ab4b61 100755
--- a/egs/yesno/ASR/tdnn/decode.py
+++ b/egs/yesno/ASR/tdnn/decode.py
@@ -35,19 +35,16 @@ def get_parser():
         "--epoch",
         type=int,
         default=14,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=2,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -204,7 +201,9 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -275,7 +274,9 @@ def main():
 
     logging.info(f"device: {device}")
 
-    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
+    HLG = k2.Fsa.from_dict(
+        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
+    )
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
 
@@ -296,7 +297,9 @@ def main():
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
+        torch.save(
+            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
+        )
         return
 
     model.to(device)
@@ -314,7 +317,9 @@ def main():
         word_table=lexicon.word_table,
     )
 
-    save_results(exp_dir=params.exp_dir, test_set_name="test_set", results=results)
+    save_results(
+        exp_dir=params.exp_dir, test_set_name="test_set", results=results
+    )
 
     logging.info("Done!")
 
diff --git a/egs/yesno/ASR/tdnn/pretrained.py b/egs/yesno/ASR/tdnn/pretrained.py
index 09a8672ae..14220be19 100755
--- a/egs/yesno/ASR/tdnn/pretrained.py
+++ b/egs/yesno/ASR/tdnn/pretrained.py
@@ -41,11 +41,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -55,18 +53,18 @@ def get_parser():
         help="Path to words.txt",
     )
 
-    parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.")
+    parser.add_argument(
+        "--HLG", type=str, required=True, help="Path to HLG.pt."
+    )
 
     parser.add_argument(
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     return parser
@@ -103,9 +101,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -160,7 +159,9 @@ def main():
     logging.info("Decoding started")
     features = fbank(waves)
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
 
     # Note: We don't use key padding mask for attention during decoding
     with torch.no_grad():
@@ -200,7 +201,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/yesno/ASR/tdnn/train.py b/egs/yesno/ASR/tdnn/train.py
index 335493491..f32a27f35 100755
--- a/egs/yesno/ASR/tdnn/train.py
+++ b/egs/yesno/ASR/tdnn/train.py
@@ -430,7 +430,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             valid_info = compute_validation_loss(
diff --git a/egs/yesno/ASR/transducer/decode.py b/egs/yesno/ASR/transducer/decode.py
index de478334e..6714180db 100755
--- a/egs/yesno/ASR/transducer/decode.py
+++ b/egs/yesno/ASR/transducer/decode.py
@@ -48,19 +48,16 @@ def get_parser():
         "--epoch",
         type=int,
         default=125,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=20,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
     parser.add_argument(
         "--exp-dir",
@@ -119,7 +116,9 @@ def decode_one_batch(
     # at entry, feature is (N, T, C)
     feature_lens = batch["supervisions"]["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )
 
     hyps = []
     batch_size = encoder_out.size(0)
@@ -187,7 +186,9 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -302,7 +303,9 @@ def main():
         model=model,
     )
 
-    save_results(exp_dir=params.exp_dir, test_set_name="test_set", results=results)
+    save_results(
+        exp_dir=params.exp_dir, test_set_name="test_set", results=results
+    )
 
     logging.info("Done!")
 
diff --git a/egs/yesno/ASR/transducer/train.py b/egs/yesno/ASR/transducer/train.py
index 88866ae81..deb92107d 100755
--- a/egs/yesno/ASR/transducer/train.py
+++ b/egs/yesno/ASR/transducer/train.py
@@ -430,7 +430,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             valid_info = compute_validation_loss(
diff --git a/icefall/char_graph_compiler.py b/icefall/char_graph_compiler.py
index c31db6e4c..235160e14 100644
--- a/icefall/char_graph_compiler.py
+++ b/icefall/char_graph_compiler.py
@@ -71,7 +71,9 @@ class CharCtcTrainingGraphCompiler(object):
         for text in texts:
             text = re.sub(whitespace, "", text)
             sub_ids = [
-                self.token_table[txt] if txt in self.token_table else self.oov_id
+                self.token_table[txt]
+                if txt in self.token_table
+                else self.oov_id
                 for txt in text
             ]
             ids.append(sub_ids)
@@ -94,7 +96,9 @@ class CharCtcTrainingGraphCompiler(object):
         for text in texts:
             text = text.split("/")
             sub_ids = [
-                self.token_table[txt] if txt in self.token_table else self.oov_id
+                self.token_table[txt]
+                if txt in self.token_table
+                else self.oov_id
                 for txt in text
             ]
             ids.append(sub_ids)
diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py
index 8aa0a8eeb..5069b78e8 100644
--- a/icefall/checkpoint.py
+++ b/icefall/checkpoint.py
@@ -292,11 +292,15 @@ def find_checkpoints(out_dir: Path, iteration: int = 0) -> List[str]:
     """
     checkpoints = list(glob.glob(f"{out_dir}/checkpoint-[0-9]*.pt"))
     pattern = re.compile(r"checkpoint-([0-9]+).pt")
-    iter_checkpoints = [(int(pattern.search(c).group(1)), c) for c in checkpoints]
+    iter_checkpoints = [
+        (int(pattern.search(c).group(1)), c) for c in checkpoints
+    ]
     # iter_checkpoints is a list of tuples. Each tuple contains
     # two elements: (iteration_number, checkpoint-iteration_number.pt)
 
-    iter_checkpoints = sorted(iter_checkpoints, reverse=True, key=lambda x: x[0])
+    iter_checkpoints = sorted(
+        iter_checkpoints, reverse=True, key=lambda x: x[0]
+    )
     if iteration >= 0:
         ans = [ic[1] for ic in iter_checkpoints if ic[0] >= iteration]
     else:
@@ -465,5 +469,7 @@ def average_state_dict(
         v = state_dict_1[k]
         if torch.is_floating_point(v):
             v *= weight_1
-            v += state_dict_2[k].to(device=state_dict_1[k].device) * weight_2
+            v += (
+                state_dict_2[k].to(device=state_dict_1[k].device) * weight_2
+            )
             v *= scaling_factor
diff --git a/icefall/decode.py b/icefall/decode.py
index e4c614c4e..099e2d171 100644
--- a/icefall/decode.py
+++ b/icefall/decode.py
@@ -334,9 +334,13 @@ class Nbest(object):
         if hasattr(lattice, "aux_labels"):
             # delete token IDs as it is not needed
             del word_fsa.aux_labels
-            word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa)
+            word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(
+                word_fsa
+            )
         else:
-            word_fsa_with_epsilon_loops = k2.linear_fst_with_self_loops(word_fsa)
+            word_fsa_with_epsilon_loops = k2.linear_fst_with_self_loops(
+                word_fsa
+            )
 
         path_to_utt_map = self.shape.row_ids(1)
 
@@ -366,7 +370,9 @@ class Nbest(object):
         # path_lattice has word IDs as labels and token IDs as aux_labels
         path_lattice = k2.top_sort(k2.connect(path_lattice))
 
-        one_best = k2.shortest_path(path_lattice, use_double_scores=use_double_scores)
+        one_best = k2.shortest_path(
+            path_lattice, use_double_scores=use_double_scores
+        )
 
         one_best = k2.invert(one_best)
         # Now one_best has token IDs as labels and word IDs as aux_labels
@@ -436,7 +442,9 @@ class Nbest(object):
         scores_shape = self.fsa.arcs.shape().remove_axis(1)
         # scores_shape has axes [path][arc]
 
-        ragged_scores = k2.RaggedTensor(scores_shape, self.fsa.scores.contiguous())
+        ragged_scores = k2.RaggedTensor(
+            scores_shape, self.fsa.scores.contiguous()
+        )
 
         tot_scores = ragged_scores.sum()
 
@@ -475,7 +483,9 @@ def one_best_decoding(
             am_scores = saved_am_scores / lm_scale
             lattice.scores = am_scores + lattice.lm_scores
 
-            best_path = k2.shortest_path(lattice, use_double_scores=use_double_scores)
+            best_path = k2.shortest_path(
+                lattice, use_double_scores=use_double_scores
+            )
             key = f"lm_scale_{lm_scale}"
             ans[key] = best_path
         return ans
@@ -686,7 +696,9 @@ def rescore_with_n_best_list(
             logging.info(f"num_paths before decreasing: {num_paths}")
             num_paths = int(num_paths / 2)
             if loop_count >= max_loop_count or num_paths <= 0:
-                logging.info("Return None as the resulting lattice is too large.")
+                logging.info(
+                    "Return None as the resulting lattice is too large."
+                )
                 return None
             logging.info(
                 "This OOM is not an error. You can ignore it. "
@@ -793,9 +805,13 @@ def rescore_with_whole_lattice(
         except RuntimeError as e:
             logging.info(f"Caught exception:\n{e}\n")
             if loop_count >= max_loop_count:
-                logging.info("Return None as the resulting lattice is too large.")
+                logging.info(
+                    "Return None as the resulting lattice is too large."
+                )
                 return None
-            logging.info(f"num_arcs before pruning: {inv_lattice.arcs.num_elements()}")
+            logging.info(
+                f"num_arcs before pruning: {inv_lattice.arcs.num_elements()}"
+            )
             logging.info(
                 "This OOM is not an error. You can ignore it. "
                 "If your model does not converge well, or --max-duration "
@@ -807,7 +823,9 @@ def rescore_with_whole_lattice(
                 prune_th_list[loop_count],
                 True,
             )
-            logging.info(f"num_arcs after pruning: {inv_lattice.arcs.num_elements()}")
+            logging.info(
+                f"num_arcs after pruning: {inv_lattice.arcs.num_elements()}"
+            )
         loop_count += 1
 
     # lat has token IDs as labels
@@ -894,7 +912,9 @@ def rescore_with_attention_decoder(
             logging.info(f"num_paths before decreasing: {num_paths}")
             num_paths = int(num_paths / 2)
             if loop_count >= max_loop_count or num_paths <= 0:
-                logging.info("Return None as the resulting lattice is too large.")
+                logging.info(
+                    "Return None as the resulting lattice is too large."
+                )
                 return None
             logging.info(
                 "This OOM is not an error. You can ignore it. "
diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py
index 7b58ffbd4..b075aceac 100644
--- a/icefall/diagnostics.py
+++ b/icefall/diagnostics.py
@@ -19,7 +19,7 @@
 
 import random
 from dataclasses import dataclass
-from typing import List, Optional, Tuple
+from typing import Optional, Tuple, List
 
 import torch
 from torch import Tensor, nn
@@ -78,11 +78,11 @@ def get_tensor_stats(
     elif stats_type == "abs":
         x = x.abs()
     elif stats_type == "rms":
-        x = x**2
+        x = x ** 2
     elif stats_type == "positive":
         x = (x > 0).to(dtype=torch.float)
     else:
-        assert stats_type in ["value", "max", "min"]
+        assert stats_type in [ "value", "max", "min" ]
 
     sum_dims = [d for d in range(x.ndim) if d != dim]
     if len(sum_dims) > 0:
@@ -121,9 +121,7 @@ class TensorDiagnostic(object):
         self.name = name
         self.class_name = None  # will assign in accumulate()
 
-        self.stats = (
-            None  # we'll later assign a list to this data member.  It's a list of dict.
-        )
+        self.stats = None  # we'll later assign a list to this data member.  It's a list of dict.
 
         # the keys into self.stats[dim] are strings, whose values can be
         # "abs", "max", "min" ,"value", "positive", "rms", "value".
@@ -135,6 +133,7 @@ class TensorDiagnostic(object):
         # only adding a new element to the list if there was a different dim.
         # if the string in the key is "eigs", if we detect a length mismatch we put None as the value.
 
+
     def accumulate(self, x, class_name: Optional[str] = None):
         """
         Accumulate tensors.
@@ -186,12 +185,17 @@ class TensorDiagnostic(object):
                         done = True
                         break
                 if not done:
-                    if this_dim_stats[stats_type] != [] and stats_type == "eigs":
+                    if (
+                        this_dim_stats[stats_type] != []
+                        and stats_type == "eigs"
+                    ):
                         # >1 size encountered on this dim, e.g. it's a batch or time dimension,
                         # don't accumulat "eigs" stats type, it uses too much memory
                         this_dim_stats[stats_type] = None
                     else:
-                        this_dim_stats[stats_type].append(TensorAndCount(stats, count))
+                        this_dim_stats[stats_type].append(
+                            TensorAndCount(stats, count)
+                        )
 
     def print_diagnostics(self):
         """Print diagnostics for each dimension of the tensor."""
@@ -207,6 +211,7 @@ class TensorDiagnostic(object):
                     assert stats_type == "eigs"
                     continue
 
+
                 def get_count(count):
                     return 1 if stats_type in ["max", "min"] else count
 
@@ -216,8 +221,7 @@ class TensorDiagnostic(object):
                     # a dimension that has variable size in different nnet
                     # forwards, e.g. a time dimension in an ASR model.
                     stats = torch.cat(
-                        [x.tensor / get_count(x.count) for x in stats_list],
-                        dim=0,
+                        [x.tensor / get_count(x.count) for x in stats_list], dim=0
                     )
 
                 if stats_type == "eigs":
@@ -225,7 +229,9 @@ class TensorDiagnostic(object):
                         eigs, _ = torch.symeig(stats)
                         stats = eigs.abs().sqrt()
                     except:  # noqa
-                        print("Error getting eigenvalues, trying another method.")
+                        print(
+                            "Error getting eigenvalues, trying another method."
+                        )
                         eigs, _ = torch.eig(stats)
                         stats = eigs.abs().sqrt()
                         # sqrt so it reflects data magnitude, like stddev- not variance
@@ -236,9 +242,9 @@ class TensorDiagnostic(object):
 
                 # if `summarize` we print percentiles of the stats; else,
                 # we print out individual elements.
-                summarize = (len(stats_list) > 1) or self.opts.dim_is_summarized(
-                    stats.numel()
-                )
+                summarize = (
+                    len(stats_list) > 1
+                ) or self.opts.dim_is_summarized(stats.numel())
                 if summarize:  # usually `summarize` will be true
                     # print out percentiles.
                     stats = stats.sort()[0]
@@ -255,15 +261,15 @@ class TensorDiagnostic(object):
                     ans = stats.tolist()
                     ans = ["%.2g" % x for x in ans]
                     ans = "[" + " ".join(ans) + "]"
-                if stats_type in ["value", "rms", "eigs"]:
+                if stats_type in [ "value", "rms", "eigs" ]:
                     # This norm is useful because it is strictly less than the largest
                     # sqrt(eigenvalue) of the variance, which we print out, and shows,
                     # speaking in an approximate way, how much of that largest eigenvalue
                     # can be attributed to the mean of the distribution.
-                    norm = (stats**2).sum().sqrt().item()
+                    norm = (stats ** 2).sum().sqrt().item()
                     ans += f", norm={norm:.2g}"
                 mean = stats.mean().item()
-                rms = (stats**2).mean().sqrt().item()
+                rms = (stats ** 2).mean().sqrt().item()
                 ans += f", mean={mean:.3g}, rms={rms:.3g}"
 
                 # OK, "ans" contains the actual stats, e.g.
@@ -271,17 +277,17 @@ class TensorDiagnostic(object):
 
                 sizes = [x.tensor.shape[0] for x in stats_list]
                 size_str = (
-                    f"{sizes[0]}" if len(sizes) == 1 else f"{min(sizes)}..{max(sizes)}"
-                )
-                maybe_class_name = (
-                    f" type={self.class_name}," if self.class_name is not None else ""
+                    f"{sizes[0]}"
+                    if len(sizes) == 1
+                    else f"{min(sizes)}..{max(sizes)}"
                 )
+                maybe_class_name = f" type={self.class_name}," if self.class_name is not None else ""
                 print(
-                    f"module={self.name},{maybe_class_name} dim={dim}, size={size_str},"
-                    f" {stats_type} {ans}"
+                    f"module={self.name},{maybe_class_name} dim={dim}, size={size_str}, {stats_type} {ans}"
                 )
 
 
+
 class ModelDiagnostic(object):
     """This class stores diagnostics for all tensors in the torch.nn.Module.
 
@@ -339,32 +345,32 @@ def attach_diagnostics(
         # (matters for name, since the variable gets overwritten).
         # These closures don't really capture by value, only by
         # "the final value the variable got in the function" :-(
-        def forward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name):
+        def forward_hook(
+            _module, _input, _output, _model_diagnostic=ans, _name=name
+        ):
             if isinstance(_output, tuple) and len(_output) == 1:
                 _output = _output[0]
 
             if isinstance(_output, Tensor):
-                _model_diagnostic[f"{_name}.output"].accumulate(
-                    _output, class_name=type(_module).__name__
-                )
+                _model_diagnostic[f"{_name}.output"].accumulate(_output,
+                                                                class_name=type(_module).__name__)
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
-                    _model_diagnostic[f"{_name}.output[{i}]"].accumulate(
-                        o, class_name=type(_module).__name__
-                    )
+                    _model_diagnostic[f"{_name}.output[{i}]"].accumulate(o,
+                                                                         class_name=type(_module).__name__)
 
-        def backward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name):
+        def backward_hook(
+            _module, _input, _output, _model_diagnostic=ans, _name=name
+        ):
             if isinstance(_output, tuple) and len(_output) == 1:
                 _output = _output[0]
             if isinstance(_output, Tensor):
-                _model_diagnostic[f"{_name}.grad"].accumulate(
-                    _output, class_name=type(_module).__name__
-                )
+                _model_diagnostic[f"{_name}.grad"].accumulate(_output,
+                                                              class_name=type(_module).__name__)
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
-                    _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(
-                        o, class_name=type(_module).__name__
-                    )
+                    _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o,
+                                                                       class_name=type(_module).__name__)
 
         module.register_forward_hook(forward_hook)
         module.register_backward_hook(backward_hook)
diff --git a/icefall/dist.py b/icefall/dist.py
index 9df1c5bd1..7016beafb 100644
--- a/icefall/dist.py
+++ b/icefall/dist.py
@@ -29,7 +29,9 @@ def setup_dist(rank, world_size, master_port=None, use_ddp_launch=False):
         os.environ["MASTER_ADDR"] = "localhost"
 
     if "MASTER_PORT" not in os.environ:
-        os.environ["MASTER_PORT"] = "12354" if master_port is None else str(master_port)
+        os.environ["MASTER_PORT"] = (
+            "12354" if master_port is None else str(master_port)
+        )
 
     if use_ddp_launch is False:
         dist.init_process_group("nccl", rank=rank, world_size=world_size)
diff --git a/icefall/env.py b/icefall/env.py
index 373e9a9ff..8aeda6be2 100644
--- a/icefall/env.py
+++ b/icefall/env.py
@@ -53,7 +53,9 @@ def get_git_sha1():
             )
             > 0
         )
-        git_commit = git_commit + "-dirty" if dirty_commit else git_commit + "-clean"
+        git_commit = (
+            git_commit + "-dirty" if dirty_commit else git_commit + "-clean"
+        )
     except:  # noqa
         return None
 
diff --git a/icefall/graph_compiler.py b/icefall/graph_compiler.py
index e2ff03f61..570ed7d7a 100644
--- a/icefall/graph_compiler.py
+++ b/icefall/graph_compiler.py
@@ -75,7 +75,9 @@ class CtcTrainingGraphCompiler(object):
 
         # NOTE: k2.compose runs on CUDA only when treat_epsilons_specially
         # is False, so we add epsilon self-loops here
-        fsa_with_self_loops = k2.remove_epsilon_and_add_self_loops(transcript_fsa)
+        fsa_with_self_loops = k2.remove_epsilon_and_add_self_loops(
+            transcript_fsa
+        )
 
         fsa_with_self_loops = k2.arc_sort(fsa_with_self_loops)
 
diff --git a/icefall/hooks.py b/icefall/hooks.py
index 398a5f689..fbcf5e148 100644
--- a/icefall/hooks.py
+++ b/icefall/hooks.py
@@ -14,11 +14,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import logging
 import random
-
 import torch
 from torch import Tensor, nn
+import logging
 
 
 def register_inf_check_hooks(model: nn.Module) -> None:
@@ -57,7 +56,7 @@ def register_inf_check_hooks(model: nn.Module) -> None:
             if isinstance(_output, Tensor):
                 if not torch.isfinite(_output.to(torch.float32).sum()):
                     logging.warning(
-                        f"The sum of {_name}.grad is not finite"  # ": {_output}"
+                        f"The sum of {_name}.grad is not finite" # ": {_output}"
                     )
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
@@ -66,20 +65,28 @@ def register_inf_check_hooks(model: nn.Module) -> None:
                     if not isinstance(o, Tensor):
                         continue
                     if not torch.isfinite(o.to(torch.float32).sum()):
-                        logging.warning(f"The sum of {_name}.grad[{i}] is not finite")
+                        logging.warning(
+                            f"The sum of {_name}.grad[{i}] is not finite"
+                        )
 
         module.register_forward_hook(forward_hook)
         module.register_backward_hook(backward_hook)
 
+
     for name, parameter in model.named_parameters():
 
-        def param_backward_hook(grad, _name=name):
+        def param_backward_hook(
+                grad, _name=name
+        ):
             if not torch.isfinite(grad.to(torch.float32).sum()):
-                logging.warning(f"The sum of {_name}.param_grad is not finite")
+                logging.warning(
+                    f"The sum of {_name}.param_grad is not finite"
+                )
 
         parameter.register_hook(param_backward_hook)
 
 
+
 def _test_inf_check_hooks():
     model = nn.Sequential(nn.Linear(100, 50), nn.Linear(50, 80))
 
diff --git a/icefall/lexicon.py b/icefall/lexicon.py
index 22e1b78bb..80bd7c1ee 100644
--- a/icefall/lexicon.py
+++ b/icefall/lexicon.py
@@ -49,12 +49,18 @@ def read_lexicon(filename: str) -> List[Tuple[str, List[str]]]:
                 continue
 
             if len(a) < 2:
-                logging.info(f"Found bad line {line} in lexicon file {filename}")
-                logging.info("Every line is expected to contain at least 2 fields")
+                logging.info(
+                    f"Found bad line {line} in lexicon file {filename}"
+                )
+                logging.info(
+                    "Every line is expected to contain at least 2 fields"
+                )
                 sys.exit(1)
             word = a[0]
             if word == "":
-                logging.info(f"Found bad line {line} in lexicon file {filename}")
+                logging.info(
+                    f"Found bad line {line} in lexicon file {filename}"
+                )
                 logging.info(" should not be a valid word")
                 sys.exit(1)
 
@@ -113,7 +119,9 @@ def convert_lexicon_to_ragged(
     lexicon_tmp = read_lexicon(filename)
     lexicon = dict(lexicon_tmp)
     if len(lexicon_tmp) != len(lexicon):
-        raise RuntimeError("It's assumed that each word has a unique pronunciation")
+        raise RuntimeError(
+            "It's assumed that each word has a unique pronunciation"
+        )
 
     for i in range(disambig_id):
         w = word_table[i]
diff --git a/icefall/mmi.py b/icefall/mmi.py
index 16ed6e032..2c479fc2c 100644
--- a/icefall/mmi.py
+++ b/icefall/mmi.py
@@ -63,7 +63,10 @@ def _compute_mmi_loss_exact_optimized(
 
     # [0, num_fsas, 1, num_fsas, 2, num_fsas, ... ]
     num_den_graphs_indexes = (
-        torch.stack([num_graphs_indexes, den_graphs_indexes]).t().reshape(-1).to(device)
+        torch.stack([num_graphs_indexes, den_graphs_indexes])
+        .t()
+        .reshape(-1)
+        .to(device)
     )
 
     num_den_reordered_graphs = k2.index(num_den_graphs, num_den_graphs_indexes)
@@ -112,12 +115,20 @@ def _compute_mmi_loss_exact_non_optimized(
     num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=True)
 
     # TODO: pass output_beam as function argument
-    num_lats = k2.intersect_dense(num_graphs, dense_fsa_vec, output_beam=beam_size)
-    den_lats = k2.intersect_dense(den_graphs, dense_fsa_vec, output_beam=beam_size)
+    num_lats = k2.intersect_dense(
+        num_graphs, dense_fsa_vec, output_beam=beam_size
+    )
+    den_lats = k2.intersect_dense(
+        den_graphs, dense_fsa_vec, output_beam=beam_size
+    )
 
-    num_tot_scores = num_lats.get_tot_scores(log_semiring=True, use_double_scores=True)
+    num_tot_scores = num_lats.get_tot_scores(
+        log_semiring=True, use_double_scores=True
+    )
 
-    den_tot_scores = den_lats.get_tot_scores(log_semiring=True, use_double_scores=True)
+    den_tot_scores = den_lats.get_tot_scores(
+        log_semiring=True, use_double_scores=True
+    )
 
     tot_scores = num_tot_scores - den_scale * den_tot_scores
 
@@ -157,9 +168,13 @@ def _compute_mmi_loss_pruned(
         max_active_states=10000,
     )
 
-    num_tot_scores = num_lats.get_tot_scores(log_semiring=True, use_double_scores=True)
+    num_tot_scores = num_lats.get_tot_scores(
+        log_semiring=True, use_double_scores=True
+    )
 
-    den_tot_scores = den_lats.get_tot_scores(log_semiring=True, use_double_scores=True)
+    den_tot_scores = den_lats.get_tot_scores(
+        log_semiring=True, use_double_scores=True
+    )
 
     tot_scores = num_tot_scores - den_scale * den_tot_scores
 
diff --git a/icefall/mmi_graph_compiler.py b/icefall/mmi_graph_compiler.py
index 9f680f83d..0d901227d 100644
--- a/icefall/mmi_graph_compiler.py
+++ b/icefall/mmi_graph_compiler.py
@@ -137,7 +137,9 @@ class MmiTrainingGraphCompiler(object):
             transcript_fsa
         )
 
-        transcript_fsa_with_self_loops = k2.arc_sort(transcript_fsa_with_self_loops)
+        transcript_fsa_with_self_loops = k2.arc_sort(
+            transcript_fsa_with_self_loops
+        )
 
         num = k2.compose(
             self.ctc_topo_P,
@@ -153,7 +155,9 @@ class MmiTrainingGraphCompiler(object):
 
         ctc_topo_P_vec = k2.create_fsa_vec([self.ctc_topo_P])
         if replicate_den:
-            indexes = torch.zeros(len(texts), dtype=torch.int32, device=self.device)
+            indexes = torch.zeros(
+                len(texts), dtype=torch.int32, device=self.device
+            )
             den = k2.index_fsa(ctc_topo_P_vec, indexes)
         else:
             den = ctc_topo_P_vec
diff --git a/icefall/rnn_lm/compute_perplexity.py b/icefall/rnn_lm/compute_perplexity.py
index 9a275bf28..550801a8f 100755
--- a/icefall/rnn_lm/compute_perplexity.py
+++ b/icefall/rnn_lm/compute_perplexity.py
@@ -46,19 +46,16 @@ def get_parser():
         "--epoch",
         type=int,
         default=49,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=20,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -197,7 +194,7 @@ def main():
 
     logging.info(f"Number of model parameters: {num_param}")
     logging.info(
-        "Number of model parameters (requires_grad): "
+        f"Number of model parameters (requires_grad): "
         f"{num_param_requires_grad} "
         f"({num_param_requires_grad/num_param_requires_grad*100}%)"
     )
diff --git a/icefall/rnn_lm/dataset.py b/icefall/rnn_lm/dataset.py
index 4bf982503..598e329c4 100644
--- a/icefall/rnn_lm/dataset.py
+++ b/icefall/rnn_lm/dataset.py
@@ -155,8 +155,12 @@ class LmDatasetCollate:
         sentence_tokens_with_sos = add_sos(sentence_tokens, self.sos_id)
         sentence_tokens_with_eos = add_eos(sentence_tokens, self.eos_id)
 
-        x = sentence_tokens_with_sos.pad(mode="constant", padding_value=self.blank_id)
-        y = sentence_tokens_with_eos.pad(mode="constant", padding_value=self.blank_id)
+        x = sentence_tokens_with_sos.pad(
+            mode="constant", padding_value=self.blank_id
+        )
+        y = sentence_tokens_with_eos.pad(
+            mode="constant", padding_value=self.blank_id
+        )
         sentence_token_lengths += 1  # plus 1 since we added a SOS
 
         return x.to(torch.int64), y.to(torch.int64), sentence_token_lengths
diff --git a/icefall/rnn_lm/export.py b/icefall/rnn_lm/export.py
index 2e878f5c8..094035fce 100644
--- a/icefall/rnn_lm/export.py
+++ b/icefall/rnn_lm/export.py
@@ -38,20 +38,17 @@ def get_parser():
         "--epoch",
         type=int,
         default=29,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=5,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -162,7 +159,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/icefall/rnn_lm/model.py b/icefall/rnn_lm/model.py
index 9eef88840..a6144727a 100644
--- a/icefall/rnn_lm/model.py
+++ b/icefall/rnn_lm/model.py
@@ -129,7 +129,9 @@ class RnnLmModel(torch.nn.Module):
         tokens_eos = add_eos(tokens, eos_id)
         sos_tokens_row_splits = sos_tokens.shape.row_splits(1)
 
-        sentence_lengths = sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]
+        sentence_lengths = (
+            sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]
+        )
 
         x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id)
         y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id)
@@ -159,12 +161,12 @@ class RnnLmModel(torch.nn.Module):
         if state:
             h, c = state
         else:
-            h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size).to(
-                device
-            )
-            c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size).to(
-                device
-            )
+            h = torch.zeros(
+                self.rnn.num_layers, batch_size, self.rnn.input_size
+            ).to(device)
+            c = torch.zeros(
+                self.rnn.num_layers, batch_size, self.rnn.input_size
+            ).to(device)
 
         embedding = self.input_embedding(tokens)
         rnn_out, states = self.rnn(embedding, (h, c))
@@ -179,8 +181,12 @@ class RnnLmModel(torch.nn.Module):
         if state:
             h, c = state
         else:
-            h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size)
-            c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size)
+            h = torch.zeros(
+                self.rnn.num_layers, batch_size, self.rnn.input_size
+            )
+            c = torch.zeros(
+                self.rnn.num_layers, batch_size, self.rnn.input_size
+            )
 
         device = next(self.parameters()).device
 
@@ -188,7 +194,9 @@ class RnnLmModel(torch.nn.Module):
         tokens_eos = add_eos(tokens, eos_id)
         sos_tokens_row_splits = sos_tokens.shape.row_splits(1)
 
-        sentence_lengths = sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]
+        sentence_lengths = (
+            sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]
+        )
 
         x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id)
         y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id)
diff --git a/icefall/rnn_lm/train.py b/icefall/rnn_lm/train.py
index e17b50332..bb5f03fb9 100755
--- a/icefall/rnn_lm/train.py
+++ b/icefall/rnn_lm/train.py
@@ -446,13 +446,17 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
                 tb_writer.add_scalar(
                     "train/current_ppl", this_batch_ppl, params.batch_idx_train
                 )
 
-                tb_writer.add_scalar("train/tot_ppl", tot_ppl, params.batch_idx_train)
+                tb_writer.add_scalar(
+                    "train/tot_ppl", tot_ppl, params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -467,7 +471,8 @@ def train_one_epoch(
 
             valid_ppl = math.exp(valid_info["loss"] / valid_info["frames"])
             logging.info(
-                f"Epoch {params.cur_epoch}, validation: {valid_info}, ppl: {valid_ppl}"
+                f"Epoch {params.cur_epoch}, validation: {valid_info}, "
+                f"ppl: {valid_ppl}"
             )
 
             if tb_writer is not None:
diff --git a/icefall/shared/make_kn_lm.py b/icefall/shared/make_kn_lm.py
index a3bf1ef4c..c2edd823e 100755
--- a/icefall/shared/make_kn_lm.py
+++ b/icefall/shared/make_kn_lm.py
@@ -15,50 +15,30 @@
 # The data structure is based on: kaldi/egs/wsj/s5/utils/lang/make_phone_lm.py
 # The smoothing algorithm is based on: http://www.speech.sri.com/projects/srilm/manpages/ngram-discount.7.html
 
-import argparse
-import io
-import math
+import sys
 import os
 import re
-import sys
+import io
+import math
+import argparse
 from collections import Counter, defaultdict
 
-parser = argparse.ArgumentParser(
-    description="""
+
+parser = argparse.ArgumentParser(description="""
     Generate kneser-ney language model as arpa format. By default,
     it will read the corpus from standard input, and output to standard output.
-    """
-)
-parser.add_argument(
-    "-ngram-order",
-    type=int,
-    default=4,
-    choices=[2, 3, 4, 5, 6, 7],
-    help="Order of n-gram",
-)
+    """)
+parser.add_argument("-ngram-order", type=int, default=4, choices=[2, 3, 4, 5, 6, 7], help="Order of n-gram")
 parser.add_argument("-text", type=str, default=None, help="Path to the corpus file")
-parser.add_argument(
-    "-lm",
-    type=str,
-    default=None,
-    help="Path to output arpa file for language models",
-)
-parser.add_argument(
-    "-verbose",
-    type=int,
-    default=0,
-    choices=[0, 1, 2, 3, 4, 5],
-    help="Verbose level",
-)
+parser.add_argument("-lm", type=str, default=None, help="Path to output arpa file for language models")
+parser.add_argument("-verbose", type=int, default=0, choices=[0, 1, 2, 3, 4, 5], help="Verbose level")
 args = parser.parse_args()
 
-default_encoding = (
-    "latin-1"  # For encoding-agnostic scripts, we assume byte stream as input.
-)
-# Need to be very careful about the use of strip() and split()
-# in this case, because there is a latin-1 whitespace character
-# (nbsp) which is part of the unicode encoding range.
-# Ref: kaldi/egs/wsj/s5/utils/lang/bpe/prepend_words.py @ 69cd717
+default_encoding = "latin-1"  # For encoding-agnostic scripts, we assume byte stream as input.
+                              # Need to be very careful about the use of strip() and split()
+                              # in this case, because there is a latin-1 whitespace character
+                              # (nbsp) which is part of the unicode encoding range.
+                              # Ref: kaldi/egs/wsj/s5/utils/lang/bpe/prepend_words.py @ 69cd717
 strip_chars = " \t\r\n"
 whitespace = re.compile("[ \t]+")
 
@@ -72,9 +52,7 @@ class CountsForHistory:
         # The 'lambda: defaultdict(float)' is an anonymous function taking no
         # arguments that returns a new defaultdict(float).
         self.word_to_count = defaultdict(int)
-        self.word_to_context = defaultdict(
-            set
-        )  # using a set to count the number of unique contexts
+        self.word_to_context = defaultdict(set)  # using a set to count the number of unique contexts
         self.word_to_f = dict()  # discounted probability
         self.word_to_bow = dict()  # back-off weight
         self.total_count = 0
@@ -84,15 +62,10 @@ class CountsForHistory:
 
     def __str__(self):
         # e.g. returns ' total=12: 3->4, 4->6, -1->2'
-        return " total={0}: {1}".format(
+        return ' total={0}: {1}'.format(
             str(self.total_count),
-            ", ".join(
-                [
-                    "{0} -> {1}".format(word, count)
-                    for word, count in self.word_to_count.items()
-                ]
-            ),
-        )
+            ', '.join(['{0} -> {1}'.format(word, count)
+                      for word, count in self.word_to_count.items()]))
 
     def add_count(self, predicted_word, context_word, count):
         assert count >= 0
@@ -112,7 +85,7 @@ class NgramCounts:
     # accumulating the 4-gram count for the '8' in the sequence '5 6 7 8', we'd
     # do as follows: self.counts[3][[5,6,7]][8] += 1.0 where the [3] indexes an
     # array, the [[5,6,7]] indexes a dict, and the [8] indexes a dict.
-    def __init__(self, ngram_order, bos_symbol="", eos_symbol=""):
+    def __init__(self, ngram_order, bos_symbol='', eos_symbol=''):
         assert ngram_order >= 2
 
         self.ngram_order = ngram_order
@@ -130,48 +103,39 @@ class NgramCounts:
     # would be (6,7,8) and 'predicted_word' would be 9; 'count' would be
     # 1.
     def add_count(self, history, predicted_word, context_word, count):
-        self.counts[len(history)][history].add_count(
-            predicted_word, context_word, count
-        )
+        self.counts[len(history)][history].add_count(predicted_word, context_word, count)
 
     # 'line' is a string containing a sequence of integer word-ids.
     # This function adds the un-smoothed counts from this line of text.
     def add_raw_counts_from_line(self, line):
-        if line == "":
+        if line == '':
             words = [self.bos_symbol, self.eos_symbol]
         else:
             words = [self.bos_symbol] + whitespace.split(line) + [self.eos_symbol]
 
         for i in range(len(words)):
-            for n in range(1, self.ngram_order + 1):
+            for n in range(1, self.ngram_order+1):
                 if i + n > len(words):
                     break
-                ngram = words[i : i + n]
+                ngram = words[i: i + n]
                 predicted_word = ngram[-1]
-                history = tuple(ngram[:-1])
+                history = tuple(ngram[: -1])
                 if i == 0 or n == self.ngram_order:
                     context_word = None
                 else:
-                    context_word = words[i - 1]
+                    context_word = words[i-1]
 
                 self.add_count(history, predicted_word, context_word, 1)
 
     def add_raw_counts_from_standard_input(self):
         lines_processed = 0
-        infile = io.TextIOWrapper(
-            sys.stdin.buffer, encoding=default_encoding
-        )  # byte stream as input
+        infile = io.TextIOWrapper(sys.stdin.buffer, encoding=default_encoding)  # byte stream as input
         for line in infile:
             line = line.strip(strip_chars)
             self.add_raw_counts_from_line(line)
             lines_processed += 1
         if lines_processed == 0 or args.verbose > 0:
-            print(
-                "make_phone_lm.py: processed {0} lines of input".format(
-                    lines_processed
-                ),
-                file=sys.stderr,
-            )
+            print("make_phone_lm.py: processed {0} lines of input".format(lines_processed), file=sys.stderr)
 
     def add_raw_counts_from_file(self, filename):
         lines_processed = 0
@@ -181,12 +145,7 @@ class NgramCounts:
                 self.add_raw_counts_from_line(line)
                 lines_processed += 1
         if lines_processed == 0 or args.verbose > 0:
-            print(
-                "make_phone_lm.py: processed {0} lines of input".format(
-                    lines_processed
-                ),
-                file=sys.stderr,
-            )
+            print("make_phone_lm.py: processed {0} lines of input".format(lines_processed), file=sys.stderr)
 
     def cal_discounting_constants(self):
         # For each order N of N-grams, we calculate discounting constant D_N = n1_N / (n1_N + 2 * n2_N),
@@ -194,11 +153,9 @@ class NgramCounts:
         # This constant is used similarly to absolute discounting.
         # Return value: d is a list of floats, where d[N+1] = D_N
 
-        self.d = [
-            0
-        ]  # for the lowest order, i.e., 1-gram, we do not need to discount, thus the constant is 0
-        # This is a special case: as we currently assumed having seen all vocabularies in the dictionary,
-        # but perhaps this is not the case for some other scenarios.
+        self.d = [0]  # for the lowest order, i.e., 1-gram, we do not need to discount, thus the constant is 0
+                      # This is a special case: as we currently assumed having seen all vocabularies in the dictionary,
+                      # but perhaps this is not the case for some other scenarios.
         for n in range(1, self.ngram_order):
             this_order_counts = self.counts[n]
             n1 = 0
@@ -208,11 +165,9 @@ class NgramCounts:
                 n1 += stat[1]
                 n2 += stat[2]
             assert n1 + 2 * n2 > 0
-            self.d.append(
-                max(0.1, n1 * 1.0) / (n1 + 2 * n2)
-            )  # We are doing this max(0.001, xxx) to avoid zero discounting constant D due to n1=0,
-            # which could happen if the number of symbols is small.
-            # Otherwise, zero discounting constant can cause division by zero in computing BOW.
+            self.d.append(max(0.1, n1 * 1.0) / (n1 + 2 * n2))   # We are doing this max(0.001, xxx) to avoid zero discounting constant D due to n1=0, 
+                                                                # which could happen if the number of symbols is small.
+                                                                # Otherwise, zero discounting constant can cause division by zero in computing BOW.
 
     def cal_f(self):
         # f(a_z) is a probability distribution of word sequence a_z.
@@ -227,9 +182,7 @@ class NgramCounts:
         this_order_counts = self.counts[n]
         for hist, counts_for_hist in this_order_counts.items():
             for w, c in counts_for_hist.word_to_count.items():
-                counts_for_hist.word_to_f[w] = (
-                    max((c - self.d[n]), 0) * 1.0 / counts_for_hist.total_count
-                )
+                counts_for_hist.word_to_f[w] = max((c - self.d[n]), 0) * 1.0 / counts_for_hist.total_count
 
         # lower order N-grams
         for n in range(0, self.ngram_order - 1):
@@ -243,17 +196,11 @@ class NgramCounts:
                 if n_star_star != 0:
                     for w in counts_for_hist.word_to_count.keys():
                         n_star_z = len(counts_for_hist.word_to_context[w])
-                        counts_for_hist.word_to_f[w] = (
-                            max((n_star_z - self.d[n]), 0) * 1.0 / n_star_star
-                        )
+                        counts_for_hist.word_to_f[w] = max((n_star_z - self.d[n]), 0) * 1.0 / n_star_star
                 else:  # patterns begin with , they do not have "modified count", so use raw count instead
                     for w in counts_for_hist.word_to_count.keys():
                         n_star_z = counts_for_hist.word_to_count[w]
-                        counts_for_hist.word_to_f[w] = (
-                            max((n_star_z - self.d[n]), 0)
-                            * 1.0
-                            / counts_for_hist.total_count
-                        )
+                        counts_for_hist.word_to_f[w] = max((n_star_z - self.d[n]), 0) * 1.0 / counts_for_hist.total_count
 
     def cal_bow(self):
         # Backoff weights are only necessary for ngrams which form a prefix of a longer ngram.
@@ -293,18 +240,12 @@ class NgramCounts:
                         sum_z1_f_z = 0
                         _ = a_[1:]
                         _counts_for_hist = self.counts[len(_)][_]
-                        for (
-                            u
-                        ) in (
-                            a_counts_for_hist.word_to_count.keys()
-                        ):  # Should be careful here: what is Z1
+                        for u in a_counts_for_hist.word_to_count.keys():  # Should be careful here: what is Z1
                             sum_z1_f_z += _counts_for_hist.word_to_f[u]
 
                         if sum_z1_f_z < 1:
                             # assert sum_z1_f_a_z < 1
-                            counts_for_hist.word_to_bow[w] = (1.0 - sum_z1_f_a_z) / (
-                                1.0 - sum_z1_f_z
-                            )
+                            counts_for_hist.word_to_bow[w] = (1.0 - sum_z1_f_a_z) / (1.0 - sum_z1_f_z)
                         else:
                             counts_for_hist.word_to_bow[w] = None
 
@@ -318,9 +259,7 @@ class NgramCounts:
                     ngram = " ".join(hist) + " " + w
                     ngram = ngram.strip(strip_chars)
 
-                    res.append(
-                        "{0}\t{1}".format(ngram, counts_for_hist.word_to_count[w])
-                    )
+                    res.append("{0}\t{1}".format(ngram, counts_for_hist.word_to_count[w]))
         res.sort(reverse=True)
         for r in res:
             print(r)
@@ -383,40 +322,27 @@ class NgramCounts:
                     if bow is None:
                         res.append("{1}\t{0}".format(ngram, math.log(f, 10)))
                     else:
-                        res.append(
-                            "{1}\t{0}\t{2}".format(
-                                ngram, math.log(f, 10), math.log(bow, 10)
-                            )
-                        )
+                        res.append("{1}\t{0}\t{2}".format(ngram, math.log(f, 10), math.log(bow, 10)))
         res.sort(reverse=True)
         for r in res:
             print(r)
 
-    def print_as_arpa(
-        self, fout=io.TextIOWrapper(sys.stdout.buffer, encoding="latin-1")
-    ):
+    def print_as_arpa(self, fout=io.TextIOWrapper(sys.stdout.buffer, encoding='latin-1')):
         # print as ARPA format.
 
-        print("\\data\\", file=fout)
+        print('\\data\\', file=fout)
         for hist_len in range(self.ngram_order):
             # print the number of n-grams.
-            print(
-                "ngram {0}={1}".format(
-                    hist_len + 1,
-                    sum(
-                        [
-                            len(counts_for_hist.word_to_f)
-                            for counts_for_hist in self.counts[hist_len].values()
-                        ]
-                    ),
-                ),
-                file=fout,
+            print('ngram {0}={1}'.format(
+                hist_len + 1,
+                sum([len(counts_for_hist.word_to_f) for counts_for_hist in self.counts[hist_len].values()])),
+                file=fout
             )
 
-        print("", file=fout)
+        print('', file=fout)
 
         for hist_len in range(self.ngram_order):
-            print("\\{0}-grams:".format(hist_len + 1), file=fout)
+            print('\\{0}-grams:'.format(hist_len + 1), file=fout)
 
             this_order_counts = self.counts[hist_len]
             for hist, counts_for_hist in this_order_counts.items():
@@ -428,12 +354,12 @@ class NgramCounts:
                     if prob == 0:  # f() is always 0
                         prob = 1e-99
 
-                    line = "{0}\t{1}".format("%.7f" % math.log10(prob), " ".join(ngram))
+                    line = '{0}\t{1}'.format('%.7f' % math.log10(prob), ' '.join(ngram))
                     if bow is not None:
-                        line += "\t{0}".format("%.7f" % math.log10(bow))
+                        line += '\t{0}'.format('%.7f' % math.log10(bow))
                     print(line, file=fout)
-            print("", file=fout)
-        print("\\end\\", file=fout)
+            print('', file=fout)
+        print('\\end\\', file=fout)
 
 
 if __name__ == "__main__":
@@ -453,5 +379,5 @@ if __name__ == "__main__":
     if args.lm is None:
         ngram_counts.print_as_arpa()
     else:
-        with open(args.lm, "w", encoding=default_encoding) as f:
+        with open(args.lm, 'w', encoding=default_encoding) as f:
             ngram_counts.print_as_arpa(fout=f)
diff --git a/icefall/utils.py b/icefall/utils.py
index 785bd80f9..143c79497 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -130,7 +130,9 @@ def setup_logger(
         formatter = f"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] ({rank}/{world_size}) %(message)s"  # noqa
         log_filename = f"{log_filename}-{date_time}-{rank}"
     else:
-        formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+        formatter = (
+            "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+        )
         log_filename = f"{log_filename}-{date_time}"
 
     os.makedirs(os.path.dirname(log_filename), exist_ok=True)
@@ -201,7 +203,7 @@ def encode_supervisions(
                 supervisions["num_frames"],
                 subsampling_factor,
                 rounding_mode="floor",
-            ),
+            )
         ),
         1,
     ).to(torch.int32)
@@ -286,9 +288,13 @@ def get_texts_with_timestamp(
     """
     if isinstance(best_paths.aux_labels, k2.RaggedTensor):
         all_aux_shape = (
-            best_paths.arcs.shape().remove_axis(1).compose(best_paths.aux_labels.shape)
+            best_paths.arcs.shape()
+            .remove_axis(1)
+            .compose(best_paths.aux_labels.shape)
+        )
+        all_aux_labels = k2.RaggedTensor(
+            all_aux_shape, best_paths.aux_labels.values
         )
-        all_aux_labels = k2.RaggedTensor(all_aux_shape, best_paths.aux_labels.values)
         # remove 0's and -1's.
         aux_labels = best_paths.aux_labels.remove_values_leq(0)
         # TODO: change arcs.shape() to arcs.shape
@@ -357,7 +363,9 @@ def get_alignments(best_paths: k2.Fsa, kind: str) -> List[List[int]]:
     # arc.shape() has axes [fsa][state][arc], we remove "state"-axis here
     token_shape = best_paths.arcs.shape().remove_axis(1)
     # token_shape has axes [fsa][arc]
-    tokens = k2.RaggedTensor(token_shape, getattr(best_paths, kind).contiguous())
+    tokens = k2.RaggedTensor(
+        token_shape, getattr(best_paths, kind).contiguous()
+    )
     tokens = tokens.remove_values_eq(-1)
     return tokens.tolist()
 
@@ -578,7 +586,9 @@ def write_error_stats(
             f"{cut_id}:\t"
             + " ".join(
                 (
-                    ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
+                    ref_word
+                    if ref_word == hyp_word
+                    else f"({ref_word}->{hyp_word})"
                     for ref_word, hyp_word in ali
                 )
             ),
@@ -588,7 +598,9 @@ def write_error_stats(
     print("", file=f)
     print("SUBSTITUTIONS: count ref -> hyp", file=f)
 
-    for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
+    for count, (ref, hyp) in sorted(
+        [(v, k) for k, v in subs.items()], reverse=True
+    ):
         print(f"{count}   {ref} -> {hyp}", file=f)
 
     print("", file=f)
@@ -602,7 +614,9 @@ def write_error_stats(
         print(f"{count}   {hyp}", file=f)
 
     print("", file=f)
-    print("PER-WORD STATS: word  corr tot_errs count_in_ref count_in_hyp", file=f)
+    print(
+        "PER-WORD STATS: word  corr tot_errs count_in_ref count_in_hyp", file=f
+    )
     for _, word, counts in sorted(
         [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
     ):
@@ -777,7 +791,9 @@ def write_error_stats_with_timestamps(
             f"{cut_id}:\t"
             + " ".join(
                 (
-                    ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
+                    ref_word
+                    if ref_word == hyp_word
+                    else f"({ref_word}->{hyp_word})"
                     for ref_word, hyp_word in ali
                 )
             ),
@@ -787,7 +803,9 @@ def write_error_stats_with_timestamps(
     print("", file=f)
     print("SUBSTITUTIONS: count ref -> hyp", file=f)
 
-    for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
+    for count, (ref, hyp) in sorted(
+        [(v, k) for k, v in subs.items()], reverse=True
+    ):
         print(f"{count}   {ref} -> {hyp}", file=f)
 
     print("", file=f)
@@ -801,7 +819,9 @@ def write_error_stats_with_timestamps(
         print(f"{count}   {hyp}", file=f)
 
     print("", file=f)
-    print("PER-WORD STATS: word  corr tot_errs count_in_ref count_in_hyp", file=f)
+    print(
+        "PER-WORD STATS: word  corr tot_errs count_in_ref count_in_hyp", file=f
+    )
     for _, word, counts in sorted(
         [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
     ):
@@ -871,7 +891,9 @@ class MetricsTracker(collections.defaultdict):
             if k == "frames" or k == "utterances":
                 continue
             norm_value = (
-                float(v) / num_frames if "utt_" not in k else float(v) / num_utterances
+                float(v) / num_frames
+                if "utt_" not in k
+                else float(v) / num_utterances
             )
             ans.append((k, norm_value))
         return ans
@@ -905,7 +927,9 @@ class MetricsTracker(collections.defaultdict):
             tb_writer.add_scalar(prefix + k, v, batch_idx)
 
 
-def concat(ragged: k2.RaggedTensor, value: int, direction: str) -> k2.RaggedTensor:
+def concat(
+    ragged: k2.RaggedTensor, value: int, direction: str
+) -> k2.RaggedTensor:
     """Prepend a value to the beginning of each sublist or append a value.
     to the end of each sublist.
 
@@ -951,8 +975,8 @@ def concat(ragged: k2.RaggedTensor, value: int, direction: str) -> k2.RaggedTens
         ans = k2.ragged.cat([ragged, pad], axis=1)
     else:
         raise ValueError(
-            f'Unsupported direction: {direction}. "             "Expect either "left"'
-            ' or "right"'
+            f'Unsupported direction: {direction}. " \
+            "Expect either "left" or "right"'
         )
     return ans
 
@@ -1077,7 +1101,9 @@ def linf_norm(x):
     return torch.max(torch.abs(x))
 
 
-def measure_weight_norms(model: nn.Module, norm: str = "l2") -> Dict[str, float]:
+def measure_weight_norms(
+    model: nn.Module, norm: str = "l2"
+) -> Dict[str, float]:
     """
     Compute the norms of the model's parameters.
 
@@ -1100,7 +1126,9 @@ def measure_weight_norms(model: nn.Module, norm: str = "l2") -> Dict[str, float]
         return norms
 
 
-def measure_gradient_norms(model: nn.Module, norm: str = "l1") -> Dict[str, float]:
+def measure_gradient_norms(
+    model: nn.Module, norm: str = "l1"
+) -> Dict[str, float]:
     """
     Compute the norms of the gradients for each of model's parameters.
 
@@ -1385,7 +1413,9 @@ def parse_hyp_and_timestamp(
         use_word_table = True
 
     for i in range(N):
-        time = convert_timestamp(res.timestamps[i], subsampling_factor, frame_shift_ms)
+        time = convert_timestamp(
+            res.timestamps[i], subsampling_factor, frame_shift_ms
+        )
         if use_word_table:
             words = [word_table[i] for i in res.hyps[i]]
         else:
diff --git a/pyproject.toml b/pyproject.toml
index 3183055d4..b4f8c3377 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -3,7 +3,7 @@ profile = "black"
 skip = ["icefall/__init__.py"]
 
 [tool.black]
-line-length = 88
+line-length = 80
 exclude = '''
 /(
     \.git
diff --git a/setup.py b/setup.py
index ccd2503ff..6c720e121 100644
--- a/setup.py
+++ b/setup.py
@@ -1,8 +1,7 @@
 #!/usr/bin/env python3
 
-from pathlib import Path
-
 from setuptools import find_packages, setup
+from pathlib import Path
 
 icefall_dir = Path(__file__).parent
 install_requires = (icefall_dir / "requirements.txt").read_text().splitlines()
diff --git a/test/test_checkpoint.py b/test/test_checkpoint.py
index 34e829642..511a11c23 100644
--- a/test/test_checkpoint.py
+++ b/test/test_checkpoint.py
@@ -20,7 +20,11 @@ import pytest
 import torch
 import torch.nn as nn
 
-from icefall.checkpoint import average_checkpoints, load_checkpoint, save_checkpoint
+from icefall.checkpoint import (
+    average_checkpoints,
+    load_checkpoint,
+    save_checkpoint,
+)
 
 
 @pytest.fixture
diff --git a/test/test_decode.py b/test/test_decode.py
index 4c2e192a7..97964ac67 100644
--- a/test/test_decode.py
+++ b/test/test_decode.py
@@ -23,7 +23,6 @@ You can run this file in one of the two ways:
 """
 
 import k2
-
 from icefall.decode import Nbest
 
 
diff --git a/test/test_graph_compiler.py b/test/test_graph_compiler.py
index 10443cf22..ccfb57d49 100644
--- a/test/test_graph_compiler.py
+++ b/test/test_graph_compiler.py
@@ -154,7 +154,9 @@ class TestCtcTrainingGraphCompiler(object):
         fsas = k2.Fsa.from_fsas([fsa1, fsa2])
 
         decoding_graph = k2.arc_sort(decoding_graph)
-        lattice = k2.intersect(decoding_graph, fsas, treat_epsilons_specially=False)
+        lattice = k2.intersect(
+            decoding_graph, fsas, treat_epsilons_specially=False
+        )
         lattice = k2.connect(lattice)
 
         aux_labels0 = lattice[0].aux_labels[:-1]
diff --git a/test/test_utils.py b/test/test_utils.py
index 31f06bd51..6a9ce7853 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -50,7 +50,9 @@ def test_encode_supervisions(sup):
     assert torch.all(
         torch.eq(
             supervision_segments,
-            torch.tensor([[1, 0, 30 // 4], [0, 0, 20 // 4], [2, 9 // 4, 10 // 4]]),
+            torch.tensor(
+                [[1, 0, 30 // 4], [0, 0, 20 // 4], [2, 9 // 4, 10 // 4]]
+            ),
         )
     )
     assert texts == ["two", "one", "three"]

From 107df3b115a58f1b68a6458c3f94a130004be34c Mon Sep 17 00:00:00 2001
From: Desh Raj 
Date: Thu, 17 Nov 2022 09:42:17 -0500
Subject: [PATCH 011/174] apply black on all files

---
 .github/workflows/style_check.yml             |  11 +-
 .pre-commit-config.yaml                       |  28 +-
 docker/README.md                              |  24 +-
 .../Dockerfile                                |  14 +-
 .../Dockerfile                                |  17 +-
 .../images/k2-gt-v1.9-blueviolet.svg          |   2 +-
 .../images/python-gt-v3.6-blue.svg            |   2 +-
 .../images/torch-gt-v1.6.0-green.svg          |   2 +-
 docs/source/recipes/aishell/index.rst         |   1 -
 docs/source/recipes/timit/index.rst           |   1 -
 docs/source/recipes/timit/tdnn_ligru_ctc.rst  |  28 +-
 docs/source/recipes/timit/tdnn_lstm_ctc.rst   |  24 +-
 .../local/compute_fbank_aidatatang_200zh.py   |   8 +-
 .../ASR/local/prepare_char.py                 |   8 +-
 .../ASR/local/prepare_lang.py                 |   4 +-
 .../ASR/local/test_prepare_lang.py            |   4 +-
 egs/aidatatang_200zh/ASR/local/text2token.py  |  15 +-
 egs/aidatatang_200zh/ASR/prepare.sh           |   3 +-
 .../asr_datamodule.py                         |  20 +-
 .../pruned_transducer_stateless2/decode.py    |  25 +-
 .../pruned_transducer_stateless2/export.py    |   7 +-
 .../pretrained.py                             |  19 +-
 .../ASR/pruned_transducer_stateless2/train.py |  29 +-
 egs/aishell/ASR/conformer_ctc/conformer.py    |  67 +-
 egs/aishell/ASR/conformer_ctc/decode.py       |  16 +-
 egs/aishell/ASR/conformer_ctc/export.py       |   4 +-
 egs/aishell/ASR/conformer_ctc/pretrained.py   |  11 +-
 egs/aishell/ASR/conformer_ctc/subsampling.py  |  16 +-
 .../ASR/conformer_ctc/test_subsampling.py     |   3 +-
 egs/aishell/ASR/conformer_ctc/train.py        |  12 +-
 egs/aishell/ASR/conformer_ctc/transformer.py  |  44 +-
 egs/aishell/ASR/conformer_mmi/conformer.py    |  67 +-
 egs/aishell/ASR/conformer_mmi/decode.py       |  20 +-
 egs/aishell/ASR/conformer_mmi/subsampling.py  |  16 +-
 egs/aishell/ASR/conformer_mmi/train.py        |   8 +-
 egs/aishell/ASR/conformer_mmi/transformer.py  |  44 +-
 .../local/compute_fbank_aidatatang_200zh.py   |   8 +-
 .../ASR/local/compute_fbank_aishell.py        |   8 +-
 egs/aishell/ASR/local/prepare_char.py         |   8 +-
 egs/aishell/ASR/local/prepare_lang.py         |   4 +-
 egs/aishell/ASR/local/test_prepare_lang.py    |   4 +-
 .../pruned_transducer_stateless2/decode.py    |  36 +-
 .../pruned_transducer_stateless2/export.py    |  23 +-
 .../pretrained.py                             |  22 +-
 .../ASR/pruned_transducer_stateless2/train.py |  43 +-
 .../pruned_transducer_stateless3/decode.py    |  39 +-
 .../pruned_transducer_stateless3/export.py    |  26 +-
 .../ASR/pruned_transducer_stateless3/model.py |   8 +-
 .../pretrained.py                             |  22 +-
 .../ASR/pruned_transducer_stateless3/train.py |  58 +-
 .../ASR/tdnn_lstm_ctc/asr_datamodule.py       |  28 +-
 egs/aishell/ASR/tdnn_lstm_ctc/decode.py       |  20 +-
 egs/aishell/ASR/tdnn_lstm_ctc/model.py        |   5 +-
 egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py   |  15 +-
 egs/aishell/ASR/tdnn_lstm_ctc/train.py        |   7 +-
 .../ASR/transducer_stateless/beam_search.py   |  22 +-
 .../ASR/transducer_stateless/conformer.py     |  67 +-
 .../ASR/transducer_stateless/decode.py        |  26 +-
 .../ASR/transducer_stateless/decoder.py       |   4 +-
 .../ASR/transducer_stateless/export.py        |   7 +-
 egs/aishell/ASR/transducer_stateless/model.py |   4 +-
 .../ASR/transducer_stateless/pretrained.py    |  14 +-
 egs/aishell/ASR/transducer_stateless/train.py |  15 +-
 .../ASR/transducer_stateless/transformer.py   |   4 +-
 .../asr_datamodule.py                         |  17 +-
 .../transducer_stateless_modified-2/decode.py |  27 +-
 .../transducer_stateless_modified-2/export.py |   7 +-
 .../pretrained.py                             |  22 +-
 .../transducer_stateless_modified-2/train.py  |  22 +-
 .../transducer_stateless_modified/decode.py   |  27 +-
 .../transducer_stateless_modified/export.py   |   7 +-
 .../pretrained.py                             |  22 +-
 .../transducer_stateless_modified/train.py    |  15 +-
 egs/aishell2/ASR/local/__init__.py            |   0
 .../ASR/local/compute_fbank_aishell2.py       |   8 +-
 .../pruned_transducer_stateless5/__init__.py  |   0
 .../asr_datamodule.py                         |  24 +-
 .../pruned_transducer_stateless5/decode.py    |  39 +-
 .../pruned_transducer_stateless5/export.py    |  19 +-
 .../pretrained.py                             |  18 +-
 .../ASR/pruned_transducer_stateless5/train.py |  46 +-
 .../ASR/local/compute_fbank_aishell4.py       |   8 +-
 egs/aishell4/ASR/local/prepare_char.py        |   8 +-
 egs/aishell4/ASR/local/prepare_lang.py        |   4 +-
 egs/aishell4/ASR/local/test_prepare_lang.py   |   4 +-
 egs/aishell4/ASR/local/text2token.py          |  15 +-
 .../asr_datamodule.py                         |  20 +-
 .../pruned_transducer_stateless5/decode.py    |  35 +-
 .../pruned_transducer_stateless5/export.py    |  19 +-
 .../pretrained.py                             |  23 +-
 .../ASR/pruned_transducer_stateless5/train.py |  38 +-
 .../ASR/local/compute_fbank_alimeeting.py     |   8 +-
 egs/alimeeting/ASR/local/prepare_char.py      |   8 +-
 egs/alimeeting/ASR/local/prepare_lang.py      |   4 +-
 egs/alimeeting/ASR/local/test_prepare_lang.py |   4 +-
 egs/alimeeting/ASR/local/text2segments.py     |   2 +-
 egs/alimeeting/ASR/local/text2token.py        |  15 +-
 .../asr_datamodule.py                         |  20 +-
 .../pruned_transducer_stateless2/decode.py    |  35 +-
 .../pruned_transducer_stateless2/export.py    |   7 +-
 .../pretrained.py                             |  19 +-
 .../ASR/pruned_transducer_stateless2/train.py |  29 +-
 egs/csj/ASR/.gitignore                        |   2 +-
 egs/csj/ASR/local/compute_fbank_csj.py        |  38 +-
 egs/csj/ASR/local/compute_fbank_musan.py      |  17 +-
 egs/csj/ASR/local/conf/disfluent.ini          |  55 +-
 egs/csj/ASR/local/conf/fluent.ini             |  55 +-
 egs/csj/ASR/local/conf/number.ini             |  55 +-
 egs/csj/ASR/local/conf/symbol.ini             |  55 +-
 .../ASR/local/display_manifest_statistics.py  |   4 +-
 egs/csj/ASR/local/prepare_lang_char.py        |  14 +-
 egs/csj/ASR/local/validate_manifest.py        |   4 +-
 .../ASR/conformer_ctc/asr_datamodule.py       |  27 +-
 egs/gigaspeech/ASR/conformer_ctc/conformer.py |  63 +-
 egs/gigaspeech/ASR/conformer_ctc/decode.py    |  16 +-
 .../ASR/conformer_ctc/label_smoothing.py      |   7 +-
 .../ASR/conformer_ctc/subsampling.py          |  16 +-
 egs/gigaspeech/ASR/conformer_ctc/train.py     |  12 +-
 .../ASR/conformer_ctc/transformer.py          |  49 +-
 .../compute_fbank_gigaspeech_dev_test.py      |   4 +-
 .../local/compute_fbank_gigaspeech_splits.py  |   4 +-
 .../ASR/local/preprocess_gigaspeech.py        |  10 +-
 .../asr_datamodule.py                         |  27 +-
 .../pruned_transducer_stateless2/decode.py    |  28 +-
 .../pruned_transducer_stateless2/export.py    |  16 +-
 .../ASR/pruned_transducer_stateless2/train.py |  27 +-
 egs/librispeech/ASR/conformer_ctc/ali.py      |  12 +-
 .../ASR/conformer_ctc/conformer.py            |  63 +-
 egs/librispeech/ASR/conformer_ctc/decode.py   |  16 +-
 egs/librispeech/ASR/conformer_ctc/export.py   |   4 +-
 .../ASR/conformer_ctc/label_smoothing.py      |   7 +-
 .../ASR/conformer_ctc/pretrained.py           |  11 +-
 .../ASR/conformer_ctc/subsampling.py          |  16 +-
 egs/librispeech/ASR/conformer_ctc/train.py    |  20 +-
 .../ASR/conformer_ctc/transformer.py          |  49 +-
 .../ASR/conformer_ctc2/attention.py           |  19 +-
 .../ASR/conformer_ctc2/conformer.py           |  62 +-
 egs/librispeech/ASR/conformer_ctc2/decode.py  |  28 +-
 egs/librispeech/ASR/conformer_ctc2/export.py  |  21 +-
 egs/librispeech/ASR/conformer_ctc2/train.py   |  34 +-
 .../ASR/conformer_ctc2/transformer.py         |  46 +-
 .../ASR/conformer_mmi/conformer.py            |  67 +-
 egs/librispeech/ASR/conformer_mmi/decode.py   |  16 +-
 .../ASR/conformer_mmi/subsampling.py          |  16 +-
 .../ASR/conformer_mmi/test_subsampling.py     |   3 +-
 .../ASR/conformer_mmi/test_transformer.py     |   9 +-
 .../ASR/conformer_mmi/train-with-attention.py |  27 +-
 egs/librispeech/ASR/conformer_mmi/train.py    |  27 +-
 .../ASR/conformer_mmi/transformer.py          |  28 +-
 .../decode.py                                 |  35 +-
 .../emformer.py                               | 119 +---
 .../export.py                                 |  19 +-
 .../stream.py                                 |   8 +-
 .../streaming_decode.py                       |  42 +-
 .../train.py                                  |  35 +-
 .../decode.py                                 |  35 +-
 .../emformer.py                               | 108 +--
 .../export.py                                 |  19 +-
 .../streaming_decode.py                       |  42 +-
 .../train.py                                  |  35 +-
 .../ASR/local/add_alignment_librispeech.py    |  12 +-
 egs/librispeech/ASR/local/compile_hlg.py      |   6 +-
 egs/librispeech/ASR/local/compile_lg.py       |   4 +-
 .../compute_fbank_gigaspeech_dev_test.py      |   4 +-
 .../local/compute_fbank_gigaspeech_splits.py  |   4 +-
 .../ASR/local/compute_fbank_librispeech.py    |   8 +-
 .../ASR/local/compute_fbank_musan.py          |   8 +-
 .../convert_transcript_words_to_tokens.py     |   8 +-
 egs/librispeech/ASR/local/download_lm.py      |   4 +-
 egs/librispeech/ASR/local/filter_cuts.py      |  10 +-
 .../ASR/local/generate_unique_lexicon.py      |   4 +-
 egs/librispeech/ASR/local/prepare_lang_bpe.py |   4 +-
 .../ASR/local/prepare_lm_training_data.py     |  11 +-
 .../ASR/local/preprocess_gigaspeech.py        |   4 +-
 .../ASR/local/test_prepare_lang.py            |   4 +-
 .../ASR/local/validate_manifest.py            |   4 +-
 .../ASR/lstm_transducer_stateless/decode.py   |  39 +-
 .../ASR/lstm_transducer_stateless/export.py   |  19 +-
 .../jit_pretrained.py                         |   7 +-
 .../ASR/lstm_transducer_stateless/lstm.py     |  14 +-
 .../ASR/lstm_transducer_stateless/model.py    |   8 +-
 .../lstm_transducer_stateless/pretrained.py   |  18 +-
 .../ASR/lstm_transducer_stateless/stream.py   |   8 +-
 .../streaming_decode.py                       |  41 +-
 .../ASR/lstm_transducer_stateless/train.py    |  40 +-
 .../ASR/lstm_transducer_stateless2/decode.py  |  39 +-
 .../ASR/lstm_transducer_stateless2/export.py  |  31 +-
 .../jit_pretrained.py                         |   7 +-
 .../ASR/lstm_transducer_stateless2/model.py   |   8 +-
 .../lstm_transducer_stateless2/ncnn-decode.py |  11 +-
 .../lstm_transducer_stateless2/pretrained.py  |  18 +-
 .../streaming-ncnn-decode.py                  |  23 +-
 .../streaming-onnx-decode.py                  |  31 +-
 .../ASR/lstm_transducer_stateless2/train.py   |  47 +-
 .../ASR/lstm_transducer_stateless3/decode.py  |  51 +-
 .../ASR/lstm_transducer_stateless3/export.py  |  19 +-
 .../jit_pretrained.py                         |   7 +-
 .../ASR/lstm_transducer_stateless3/lstm.py    |  14 +-
 .../lstm_transducer_stateless3/pretrained.py  |  18 +-
 .../streaming_decode.py                       |  41 +-
 .../ASR/lstm_transducer_stateless3/train.py   |  45 +-
 .../ASR/pruned2_knowledge/asr_datamodule.py   |  35 +-
 .../ASR/pruned2_knowledge/beam_search.py      |  18 +-
 .../ASR/pruned2_knowledge/conformer.py        |  82 +--
 .../ASR/pruned2_knowledge/decode.py           |  25 +-
 .../ASR/pruned2_knowledge/decoder.py          |   4 +-
 .../ASR/pruned2_knowledge/decoder2.py         |  81 ++-
 .../ASR/pruned2_knowledge/export.py           |   7 +-
 .../ASR/pruned2_knowledge/joiner.py           |   4 +-
 .../ASR/pruned2_knowledge/model.py            |   8 +-
 .../ASR/pruned2_knowledge/optim.py            |  35 +-
 .../ASR/pruned2_knowledge/sampling.py         | 181 ++---
 .../ASR/pruned2_knowledge/scaling.py          |  51 +-
 .../ASR/pruned2_knowledge/scaling_tmp.py      | 355 ++++++----
 .../ASR/pruned2_knowledge/train.py            |  29 +-
 .../pruned_stateless_emformer_rnnt2/decode.py |  35 +-
 .../emformer.py                               |   8 +-
 .../pruned_stateless_emformer_rnnt2/export.py |  19 +-
 .../pruned_stateless_emformer_rnnt2/model.py  |   4 +-
 .../pruned_stateless_emformer_rnnt2/train.py  |  23 +-
 .../beam_search.py                            |  26 +-
 .../ASR/pruned_transducer_stateless/decode.py |  36 +-
 .../decode_stream.py                          |  19 +-
 .../pruned_transducer_stateless/decoder.py    |   4 +-
 .../ASR/pruned_transducer_stateless/export.py |   7 +-
 .../ASR/pruned_transducer_stateless/model.py  |   4 +-
 .../pruned_transducer_stateless/pretrained.py |  14 +-
 .../streaming_beam_search.py                  |   8 +-
 .../streaming_decode.py                       |  31 +-
 .../ASR/pruned_transducer_stateless/train.py  |  25 +-
 .../beam_search.py                            |  51 +-
 .../pruned_transducer_stateless2/conformer.py |  94 +--
 .../pruned_transducer_stateless2/decode.py    |  36 +-
 .../pruned_transducer_stateless2/decoder.py   |   8 +-
 .../pruned_transducer_stateless2/export.py    |  16 +-
 .../pruned_transducer_stateless2/joiner.py    |   4 +-
 .../ASR/pruned_transducer_stateless2/model.py |   8 +-
 .../ASR/pruned_transducer_stateless2/optim.py |  35 +-
 .../pretrained.py                             |  14 +-
 .../pruned_transducer_stateless2/scaling.py   |  53 +-
 .../streaming_beam_search.py                  |  12 +-
 .../streaming_decode.py                       |  31 +-
 .../ASR/pruned_transducer_stateless2/train.py |  37 +-
 .../asr_datamodule.py                         |  17 +-
 .../decode-giga.py                            |  32 +-
 .../pruned_transducer_stateless3/decode.py    |  50 +-
 .../pruned_transducer_stateless3/export.py    |  24 +-
 .../gigaspeech.py                             |   8 +-
 .../jit_pretrained.py                         |   7 +-
 .../ASR/pruned_transducer_stateless3/model.py |   8 +-
 .../onnx_check.py                             |  24 +-
 .../onnx_pretrained.py                        |  13 +-
 .../pretrained.py                             |  14 +-
 .../scaling_converter.py                      |   4 +-
 .../streaming_decode.py                       |  31 +-
 .../pruned_transducer_stateless3/test_onnx.py |  24 +-
 .../ASR/pruned_transducer_stateless3/train.py |  44 +-
 .../pruned_transducer_stateless4/decode.py    |  51 +-
 .../pruned_transducer_stateless4/export.py    |  19 +-
 .../streaming_decode.py                       |  34 +-
 .../ASR/pruned_transducer_stateless4/train.py |  40 +-
 .../pruned_transducer_stateless5/conformer.py | 112 +---
 .../pruned_transducer_stateless5/decode.py    |  39 +-
 .../pruned_transducer_stateless5/export.py    |  19 +-
 .../pretrained.py                             |  18 +-
 .../streaming_decode.py                       |  34 +-
 .../ASR/pruned_transducer_stateless5/train.py |  45 +-
 .../pruned_transducer_stateless6/conformer.py |  64 +-
 .../pruned_transducer_stateless6/decode.py    |  35 +-
 .../pruned_transducer_stateless6/export.py    |  16 +-
 .../extract_codebook_index.py                 |   3 +-
 .../hubert_decode.py                          |  17 +-
 .../hubert_xlarge.py                          |  22 +-
 .../ASR/pruned_transducer_stateless6/model.py |  12 +-
 .../ASR/pruned_transducer_stateless6/train.py |  44 +-
 .../pruned_transducer_stateless6/vq_utils.py  |  28 +-
 .../pruned_transducer_stateless7/decode.py    |  39 +-
 .../pruned_transducer_stateless7/decoder.py   |   6 +-
 .../pruned_transducer_stateless7/export.py    |  19 +-
 .../jit_pretrained.py                         |   7 +-
 .../pruned_transducer_stateless7/joiner.py    |   4 +-
 .../ASR/pruned_transducer_stateless7/model.py |  16 +-
 .../ASR/pruned_transducer_stateless7/optim.py | 436 ++++++------
 .../pretrained.py                             |  18 +-
 .../pruned_transducer_stateless7/scaling.py   | 481 +++++++-------
 .../scaling_converter.py                      |   6 +-
 .../ASR/pruned_transducer_stateless7/train.py |  48 +-
 .../pruned_transducer_stateless7/zipformer.py | 625 +++++++++---------
 .../pruned_transducer_stateless8/decode.py    |  39 +-
 .../pruned_transducer_stateless8/export.py    |  19 +-
 .../jit_pretrained.py                         |   7 +-
 .../ASR/pruned_transducer_stateless8/model.py |   4 +-
 .../pretrained.py                             |  18 +-
 .../ASR/pruned_transducer_stateless8/train.py |  59 +-
 .../ASR/streaming_conformer_ctc/README.md     |  16 +-
 .../ASR/streaming_conformer_ctc/conformer.py  | 113 +---
 .../streaming_decode.py                       |  34 +-
 .../ASR/streaming_conformer_ctc/train.py      |  16 +-
 .../streaming_conformer_ctc/transformer.py    |  40 +-
 .../ASR/tdnn_lstm_ctc/asr_datamodule.py       |  23 +-
 egs/librispeech/ASR/tdnn_lstm_ctc/decode.py   |  16 +-
 egs/librispeech/ASR/tdnn_lstm_ctc/model.py    |   5 +-
 .../ASR/tdnn_lstm_ctc/pretrained.py           |  21 +-
 egs/librispeech/ASR/tdnn_lstm_ctc/train.py    |   8 +-
 egs/librispeech/ASR/transducer/beam_search.py |  14 +-
 egs/librispeech/ASR/transducer/decode.py      |  15 +-
 egs/librispeech/ASR/transducer/export.py      |   4 +-
 egs/librispeech/ASR/transducer/pretrained.py  |  11 +-
 egs/librispeech/ASR/transducer/rnn.py         |  24 +-
 egs/librispeech/ASR/transducer/test_rnn.py    |  16 +-
 egs/librispeech/ASR/transducer/train.py       |  12 +-
 .../ASR/transducer_lstm/beam_search.py        |  14 +-
 egs/librispeech/ASR/transducer_lstm/decode.py |  15 +-
 .../ASR/transducer_lstm/encoder.py            |   4 +-
 egs/librispeech/ASR/transducer_lstm/train.py  |  12 +-
 .../ASR/transducer_stateless/alignment.py     |   4 +-
 .../ASR/transducer_stateless/beam_search.py   |  28 +-
 .../ASR/transducer_stateless/compute_ali.py   |  11 +-
 .../ASR/transducer_stateless/conformer.py     | 104 +--
 .../ASR/transducer_stateless/decode.py        |  23 +-
 .../ASR/transducer_stateless/decoder.py       |   4 +-
 .../ASR/transducer_stateless/export.py        |   7 +-
 .../ASR/transducer_stateless/joiner.py        |   8 +-
 .../ASR/transducer_stateless/pretrained.py    |  14 +-
 .../transducer_stateless/test_compute_ali.py  |  11 +-
 .../transducer_stateless/test_conformer.py    |   4 +-
 .../ASR/transducer_stateless/train.py         |  23 +-
 .../ASR/transducer_stateless/transformer.py   |   4 +-
 .../ASR/transducer_stateless2/decode.py       |  23 +-
 .../ASR/transducer_stateless2/export.py       |   7 +-
 .../ASR/transducer_stateless2/pretrained.py   |  14 +-
 .../ASR/transducer_stateless2/train.py        |  23 +-
 .../decode.py                                 |  23 +-
 .../export.py                                 |   7 +-
 .../pretrained.py                             |  14 +-
 .../test_asr_datamodule.py                    |   4 +-
 .../train.py                                  |  22 +-
 egs/ptb/LM/local/sort_lm_training_data.py     |   4 +-
 .../LM/local/test_prepare_lm_training_data.py |   4 +-
 .../ASR/local/compute_fbank_musan.py          |   8 +-
 .../ASR/local/compute_fbank_spgispeech.py     |  14 +-
 egs/spgispeech/ASR/local/prepare_splits.py    |   8 +-
 .../asr_datamodule.py                         |  24 +-
 .../pruned_transducer_stateless2/decode.py    |  52 +-
 .../pruned_transducer_stateless2/export.py    |  13 +-
 .../ASR/pruned_transducer_stateless2/train.py |  30 +-
 .../ASR/local/compute_fbank_tal_csasr.py      |   8 +-
 egs/tal_csasr/ASR/local/prepare_char.py       |   4 +-
 egs/tal_csasr/ASR/local/prepare_lang.py       |   4 +-
 egs/tal_csasr/ASR/local/test_prepare_lang.py  |   4 +-
 egs/tal_csasr/ASR/local/text2token.py         |  15 +-
 .../asr_datamodule.py                         |  20 +-
 .../pruned_transducer_stateless5/decode.py    |  39 +-
 .../pruned_transducer_stateless5/export.py    |  19 +-
 .../pretrained.py                             |  18 +-
 .../ASR/pruned_transducer_stateless5/train.py |  38 +-
 .../ASR/local/compute_fbank_tedlium.py        |   8 +-
 .../convert_transcript_words_to_bpe_ids.py    |   4 +-
 egs/tedlium3/ASR/local/prepare_lexicon.py     |  11 +-
 egs/tedlium3/ASR/local/prepare_transcripts.py |  11 +-
 .../ASR/pruned_transducer_stateless/decode.py |  19 +-
 .../ASR/pruned_transducer_stateless/export.py |   7 +-
 .../pruned_transducer_stateless/pretrained.py |  19 +-
 .../ASR/pruned_transducer_stateless/train.py  |  14 +-
 .../transducer_stateless/asr_datamodule.py    |  37 +-
 .../ASR/transducer_stateless/beam_search.py   |  30 +-
 .../ASR/transducer_stateless/decode.py        |  18 +-
 .../ASR/transducer_stateless/decoder.py       |   4 +-
 .../ASR/transducer_stateless/export.py        |   7 +-
 .../ASR/transducer_stateless/pretrained.py    |  14 +-
 .../ASR/transducer_stateless/train.py         |  11 +-
 egs/timit/ASR/RESULTS.md                      |   2 +-
 egs/timit/ASR/local/compile_hlg.py            |   4 +-
 egs/timit/ASR/local/compute_fbank_timit.py    |   8 +-
 egs/timit/ASR/local/prepare_lexicon.py        |   8 +-
 egs/timit/ASR/prepare.sh                      |   4 +-
 egs/timit/ASR/tdnn_ligru_ctc/decode.py        |  16 +-
 egs/timit/ASR/tdnn_ligru_ctc/model.py         |  12 +-
 egs/timit/ASR/tdnn_ligru_ctc/pretrained.py    |  21 +-
 egs/timit/ASR/tdnn_ligru_ctc/train.py         |   4 +-
 egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py |  30 +-
 egs/timit/ASR/tdnn_lstm_ctc/decode.py         |  16 +-
 egs/timit/ASR/tdnn_lstm_ctc/model.py          |   5 +-
 egs/timit/ASR/tdnn_lstm_ctc/pretrained.py     |  21 +-
 egs/timit/ASR/tdnn_lstm_ctc/train.py          |   4 +-
 .../compute_fbank_wenetspeech_dev_test.py     |  11 +-
 .../local/compute_fbank_wenetspeech_splits.py |   4 +-
 egs/wenetspeech/ASR/local/prepare_char.py     |   8 +-
 .../ASR/local/preprocess_wenetspeech.py       |   6 +-
 egs/wenetspeech/ASR/local/text2token.py       |  15 +-
 egs/wenetspeech/ASR/prepare.sh                |   2 +-
 .../asr_datamodule.py                         |  31 +-
 .../pruned_transducer_stateless2/decode.py    |  39 +-
 .../pruned_transducer_stateless2/export.py    |  15 +-
 .../jit_pretrained.py                         |   7 +-
 .../onnx_check.py                             |  24 +-
 .../onnx_pretrained.py                        |  13 +-
 .../pretrained.py                             |  19 +-
 .../ASR/pruned_transducer_stateless2/train.py |  29 +-
 .../pruned_transducer_stateless5/conformer.py |  94 +--
 .../pruned_transducer_stateless5/decode.py    |  41 +-
 .../decode_stream.py                          |  19 +-
 .../pruned_transducer_stateless5/export.py    |   7 +-
 .../pretrained.py                             |  19 +-
 .../streaming_beam_search.py                  |   8 +-
 .../streaming_decode.py                       |  34 +-
 .../ASR/pruned_transducer_stateless5/train.py |  46 +-
 egs/yesno/ASR/local/compile_hlg.py            |   4 +-
 egs/yesno/ASR/local/compute_fbank_yesno.py    |  12 +-
 egs/yesno/ASR/tdnn/decode.py                  |  16 +-
 egs/yesno/ASR/tdnn/pretrained.py              |  15 +-
 egs/yesno/ASR/tdnn/train.py                   |   4 +-
 egs/yesno/ASR/transducer/decode.py            |  12 +-
 egs/yesno/ASR/transducer/train.py             |   4 +-
 icefall/char_graph_compiler.py                |   8 +-
 icefall/checkpoint.py                         |  12 +-
 icefall/decode.py                             |  40 +-
 icefall/diagnostics.py                        |  74 +--
 icefall/dist.py                               |   4 +-
 icefall/env.py                                |   4 +-
 icefall/graph_compiler.py                     |   4 +-
 icefall/hooks.py                              |  19 +-
 icefall/lexicon.py                            |  16 +-
 icefall/mmi.py                                |  29 +-
 icefall/mmi_graph_compiler.py                 |   8 +-
 icefall/rnn_lm/dataset.py                     |   8 +-
 icefall/rnn_lm/export.py                      |   4 +-
 icefall/rnn_lm/model.py                       |  28 +-
 icefall/rnn_lm/train.py                       |   8 +-
 icefall/shared/make_kn_lm.py                  | 177 +++--
 icefall/utils.py                              |  62 +-
 pyproject.toml                                |   2 +-
 setup.py                                      |   3 +-
 test/test_checkpoint.py                       |   6 +-
 test/test_decode.py                           |   1 +
 test/test_graph_compiler.py                   |   4 +-
 test/test_utils.py                            |   4 +-
 437 files changed, 3861 insertions(+), 7334 deletions(-)
 mode change 100755 => 100644 egs/aishell2/ASR/local/__init__.py
 mode change 100755 => 100644 egs/aishell2/ASR/pruned_transducer_stateless5/__init__.py
 mode change 100755 => 100644 egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py

diff --git a/.github/workflows/style_check.yml b/.github/workflows/style_check.yml
index 90459bc1c..45d261ccc 100644
--- a/.github/workflows/style_check.yml
+++ b/.github/workflows/style_check.yml
@@ -45,17 +45,18 @@ jobs:
 
       - name: Install Python dependencies
         run: |
-          python3 -m pip install --upgrade pip black==21.6b0 flake8==3.9.2 click==8.0.4
-          # See https://github.com/psf/black/issues/2964
-          # The version of click should be selected from 8.0.0, 8.0.1, 8.0.2, 8.0.3, and 8.0.4
+          python3 -m pip install --upgrade pip black==22.3.0 flake8==5.0.4 click==8.1.0
+          # Click issue fixed in https://github.com/psf/black/pull/2966
 
       - name: Run flake8
         shell: bash
         working-directory: ${{github.workspace}}
         run: |
           # stop the build if there are Python syntax errors or undefined names
-          flake8 . --count --show-source --statistics
-          flake8 .
+          flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
+          # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
+          flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 \
+            --statistics --extend-ignore=E203,E266,E501,F401,E402,F403,F841,W503
 
       - name: Run black
         shell: bash
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 446ba0fe7..5cb213327 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,26 +1,38 @@
 repos:
   - repo: https://github.com/psf/black
-    rev: 21.6b0
+    rev: 22.3.0
     hooks:
       - id: black
-        args: [--line-length=80]
-        additional_dependencies: ['click==8.0.1']
+        args: ["--line-length=88"]
+        additional_dependencies: ['click==8.1.0']
         exclude: icefall\/__init__\.py
 
   - repo: https://github.com/PyCQA/flake8
-    rev: 3.9.2
+    rev: 5.0.4
     hooks:
       - id: flake8
-        args: [--max-line-length=80]
+        args: ["--max-line-length=88", "--extend-ignore=E203,E266,E501,F401,E402,F403,F841,W503"]
+
+      # What are we ignoring here?
+      # E203: whitespace before ':'
+      # E266: too many leading '#' for block comment
+      # E501: line too long
+      # F401: module imported but unused
+      # E402: module level import not at top of file
+      # F403: 'from module import *' used; unable to detect undefined names
+      # F841: local variable is assigned to but never used
+      # W503: line break before binary operator
+      # In addition, the default ignore list is:
+      # E121,E123,E126,E226,E24,E704,W503,W504
 
   - repo: https://github.com/pycqa/isort
-    rev: 5.9.2
+    rev: 5.10.1
     hooks:
       - id: isort
-        args: [--profile=black, --line-length=80]
+        args: ["--profile=black"]
 
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.0.1
+    rev: v4.2.0
     hooks:
       - id: check-executables-have-shebangs
       - id: end-of-file-fixer
diff --git a/docker/README.md b/docker/README.md
index 6f2314e96..c14b9bf75 100644
--- a/docker/README.md
+++ b/docker/README.md
@@ -2,7 +2,7 @@
 
 2 sets of configuration are provided - (a) Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8, and (b) Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8.
 
-If your NVIDIA driver supports CUDA Version: 11.3, please go for case (a) Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8. 
+If your NVIDIA driver supports CUDA Version: 11.3, please go for case (a) Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8.
 
 Otherwise, since the older PyTorch images are not updated with the [apt-key rotation by NVIDIA](https://developer.nvidia.com/blog/updating-the-cuda-linux-gpg-repository-key), you have to go for case (b) Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8. Ensure that your NVDIA driver supports at least CUDA 11.0.
 
@@ -10,7 +10,7 @@ You can check the highest CUDA version within your NVIDIA driver's support with
 
 ```bash
 $ nvidia-smi
-Tue Sep 20 00:26:13 2022       
+Tue Sep 20 00:26:13 2022
 +-----------------------------------------------------------------------------+
 | NVIDIA-SMI 450.119.03   Driver Version: 450.119.03   CUDA Version: 11.0     |
 |-------------------------------+----------------------+----------------------+
@@ -26,7 +26,7 @@ Tue Sep 20 00:26:13 2022
 | 41%   30C    P8    11W / 280W |      6MiB / 24220MiB |      0%      Default |
 |                               |                      |                  N/A |
 +-------------------------------+----------------------+----------------------+
-                                                                               
+
 +-----------------------------------------------------------------------------+
 | Processes:                                                                  |
 |  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
@@ -40,15 +40,15 @@ Tue Sep 20 00:26:13 2022
 ```
 
 ## Building images locally
-If your environment requires a proxy to access the Internet, remember to add those information into the Dockerfile directly. 
-For most cases, you can uncomment these lines in the Dockerfile and add in your proxy details. 
+If your environment requires a proxy to access the Internet, remember to add those information into the Dockerfile directly.
+For most cases, you can uncomment these lines in the Dockerfile and add in your proxy details.
 
 ```dockerfile
 ENV http_proxy=http://aaa.bb.cc.net:8080 \
     https_proxy=http://aaa.bb.cc.net:8080
 ```
 
-Then, proceed with these commands. 
+Then, proceed with these commands.
 
 ### If you are case (a), i.e. your NVIDIA driver supports CUDA version >= 11.3:
 
@@ -72,11 +72,11 @@ docker run -it --runtime=nvidia --shm-size=2gb --name=icefall --gpus all icefall
 ```
 
 ### Tips:
-1. Since your data and models most probably won't be in the docker, you must use the -v flag to access the host machine. Do this by specifying `-v {/path/in/host/machine}:{/path/in/docker}`. 
+1. Since your data and models most probably won't be in the docker, you must use the -v flag to access the host machine. Do this by specifying `-v {/path/in/host/machine}:{/path/in/docker}`.
 
 2. Also, if your environment requires a proxy, this would be a good time to add it in too: `-e http_proxy=http://aaa.bb.cc.net:8080 -e https_proxy=http://aaa.bb.cc.net:8080`.
 
-Overall, your docker run command should look like this. 
+Overall, your docker run command should look like this.
 
 ```bash
 docker run -it --runtime=nvidia --shm-size=2gb --name=icefall --gpus all -v {/path/in/host/machine}:{/path/in/docker} -e http_proxy=http://aaa.bb.cc.net:8080 -e https_proxy=http://aaa.bb.cc.net:8080 icefall/pytorch1.12.1
@@ -86,9 +86,9 @@ You can explore more docker run options [here](https://docs.docker.com/engine/re
 
 ### Linking to icefall in your host machine
 
-If you already have icefall downloaded onto your host machine, you can use that repository instead so that changes in your code are visible inside and outside of the container. 
+If you already have icefall downloaded onto your host machine, you can use that repository instead so that changes in your code are visible inside and outside of the container.
 
-Note: Remember to set the -v flag above during the first run of the container, as that is the only way for your container to access your host machine. 
+Note: Remember to set the -v flag above during the first run of the container, as that is the only way for your container to access your host machine.
 Warning: Check that the icefall in your host machine is visible from within your container before proceeding to the commands below.
 
 Use these commands once you are inside the container.
@@ -103,7 +103,7 @@ ln -s {/path/in/docker/to/icefall} /workspace/icefall
 docker exec -it icefall /bin/bash
 ```
 
-## Restarting a killed container that has been run before. 
+## Restarting a killed container that has been run before.
 ```bash
 docker start -ai icefall
 ```
@@ -111,4 +111,4 @@ docker start -ai icefall
 ## Sample usage of the CPU based images:
 ```bash
 docker run -it icefall /bin/bash
-``` 
+```
diff --git a/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile b/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile
index 3637d2f11..ff9e40604 100644
--- a/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile
+++ b/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile
@@ -1,7 +1,7 @@
 FROM pytorch/pytorch:1.12.1-cuda11.3-cudnn8-devel
 
 # ENV http_proxy=http://aaa.bbb.cc.net:8080 \
-#	https_proxy=http://aaa.bbb.cc.net:8080 
+#	https_proxy=http://aaa.bbb.cc.net:8080
 
 # install normal source
 RUN apt-get update && \
@@ -38,10 +38,10 @@ RUN wget -P /opt https://cmake.org/files/v3.18/cmake-3.18.0.tar.gz && \
     rm -rf cmake-3.18.0.tar.gz && \
     find /opt/cmake-3.18.0 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \
     cd -
-	
-# flac 
+
+# flac
 RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz  && \
-    cd /opt && \ 
+    cd /opt && \
     xz -d flac-1.3.2.tar.xz && \
     tar -xvf flac-1.3.2.tar && \
     cd flac-1.3.2 && \
@@ -49,11 +49,11 @@ RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz  &&
     make && make install && \
     rm -rf flac-1.3.2.tar && \
     find /opt/flac-1.3.2  -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \
-    cd - 
+    cd -
 
 RUN conda install -y -c pytorch torchaudio=0.12 && \
     pip install graphviz
-	
+
 
 #install k2 from source
 RUN git clone https://github.com/k2-fsa/k2.git /opt/k2 && \
@@ -68,7 +68,7 @@ RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
 	cd /workspace/icefall && \
 	pip install -r requirements.txt
 
-RUN pip install kaldifeat 
+RUN pip install kaldifeat
 ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
 
 WORKDIR /workspace/icefall
diff --git a/docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile b/docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile
index 17a8215f9..5c7423fa5 100644
--- a/docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile
+++ b/docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile
@@ -1,12 +1,12 @@
 FROM pytorch/pytorch:1.7.1-cuda11.0-cudnn8-devel
 
 # ENV http_proxy=http://aaa.bbb.cc.net:8080 \
-#	https_proxy=http://aaa.bbb.cc.net:8080 
+#	https_proxy=http://aaa.bbb.cc.net:8080
 
 RUN rm /etc/apt/sources.list.d/cuda.list && \
 	rm /etc/apt/sources.list.d/nvidia-ml.list && \
 	apt-key del 7fa2af80
-	
+
 # install normal source
 RUN apt-get update && \
     apt-get install -y --no-install-recommends \
@@ -36,7 +36,7 @@ RUN curl -fsSL https://developer.download.nvidia.com/compute/cuda/repos/ubuntu18
 	curl -fsSL https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub | apt-key add - && \
 	echo "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64 /" > /etc/apt/sources.list.d/cuda.list && \
 	echo "deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64 /" > /etc/apt/sources.list.d/nvidia-ml.list && \
-	rm -rf /var/lib/apt/lists/* && \ 
+	rm -rf /var/lib/apt/lists/* && \
 	mv /opt/conda/lib/libcufft.so.10 /opt/libcufft.so.10.bak && \
     mv /opt/conda/lib/libcurand.so.10 /opt/libcurand.so.10.bak && \
     mv /opt/conda/lib/libcublas.so.11 /opt/libcublas.so.11.bak && \
@@ -56,10 +56,10 @@ RUN wget -P /opt https://cmake.org/files/v3.18/cmake-3.18.0.tar.gz && \
     rm -rf cmake-3.18.0.tar.gz && \
     find /opt/cmake-3.18.0 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \
     cd -
-	
-# flac 
+
+# flac
 RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz  && \
-    cd /opt && \ 
+    cd /opt && \
     xz -d flac-1.3.2.tar.xz && \
     tar -xvf flac-1.3.2.tar && \
     cd flac-1.3.2 && \
@@ -67,7 +67,7 @@ RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz  &&
     make && make install && \
     rm -rf flac-1.3.2.tar && \
     find /opt/flac-1.3.2  -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \
-    cd - 
+    cd -
 
 RUN conda install -y -c pytorch torchaudio=0.7.1 && \
     pip install graphviz
@@ -79,7 +79,7 @@ RUN git clone https://github.com/k2-fsa/k2.git /opt/k2 && \
     cd -
 
 # install  lhotse
-RUN pip install git+https://github.com/lhotse-speech/lhotse 
+RUN pip install git+https://github.com/lhotse-speech/lhotse
 
 RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
 	cd /workspace/icefall && \
@@ -88,4 +88,3 @@ RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
 ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
 
 WORKDIR /workspace/icefall
-
diff --git a/docs/source/installation/images/k2-gt-v1.9-blueviolet.svg b/docs/source/installation/images/k2-gt-v1.9-blueviolet.svg
index 534b2e534..3019ff03d 100644
--- a/docs/source/installation/images/k2-gt-v1.9-blueviolet.svg
+++ b/docs/source/installation/images/k2-gt-v1.9-blueviolet.svg
@@ -1 +1 @@
-k2: >= v1.9k2>= v1.9
\ No newline at end of file
+k2: >= v1.9k2>= v1.9
diff --git a/docs/source/installation/images/python-gt-v3.6-blue.svg b/docs/source/installation/images/python-gt-v3.6-blue.svg
index 4254dc58a..df677ad09 100644
--- a/docs/source/installation/images/python-gt-v3.6-blue.svg
+++ b/docs/source/installation/images/python-gt-v3.6-blue.svg
@@ -1 +1 @@
-python: >= 3.6python>= 3.6
\ No newline at end of file
+python: >= 3.6python>= 3.6
diff --git a/docs/source/installation/images/torch-gt-v1.6.0-green.svg b/docs/source/installation/images/torch-gt-v1.6.0-green.svg
index d3ece9a17..d7007d742 100644
--- a/docs/source/installation/images/torch-gt-v1.6.0-green.svg
+++ b/docs/source/installation/images/torch-gt-v1.6.0-green.svg
@@ -1 +1 @@
-torch: >= 1.6.0torch>= 1.6.0
\ No newline at end of file
+torch: >= 1.6.0torch>= 1.6.0
diff --git a/docs/source/recipes/aishell/index.rst b/docs/source/recipes/aishell/index.rst
index d072d6e9c..b77d59bca 100644
--- a/docs/source/recipes/aishell/index.rst
+++ b/docs/source/recipes/aishell/index.rst
@@ -19,4 +19,3 @@ It can be downloaded from ``_
    tdnn_lstm_ctc
    conformer_ctc
    stateless_transducer
-
diff --git a/docs/source/recipes/timit/index.rst b/docs/source/recipes/timit/index.rst
index 17f40cdb7..5ee147be7 100644
--- a/docs/source/recipes/timit/index.rst
+++ b/docs/source/recipes/timit/index.rst
@@ -6,4 +6,3 @@ TIMIT
 
    tdnn_ligru_ctc
    tdnn_lstm_ctc
-
diff --git a/docs/source/recipes/timit/tdnn_ligru_ctc.rst b/docs/source/recipes/timit/tdnn_ligru_ctc.rst
index 186420ee7..3d7aefe02 100644
--- a/docs/source/recipes/timit/tdnn_ligru_ctc.rst
+++ b/docs/source/recipes/timit/tdnn_ligru_ctc.rst
@@ -148,10 +148,10 @@ Some commonly used options are:
 
         $ ./tdnn_ligru_ctc/decode.py --epoch 25 --avg 17
 
-    uses the average of ``epoch-9.pt``, ``epoch-10.pt``, ``epoch-11.pt``, 
-    ``epoch-12.pt``, ``epoch-13.pt``, ``epoch-14.pt``, ``epoch-15.pt``, 
-    ``epoch-16.pt``, ``epoch-17.pt``, ``epoch-18.pt``, ``epoch-19.pt``, 
-    ``epoch-20.pt``, ``epoch-21.pt``, ``epoch-22.pt``, ``epoch-23.pt``, 
+    uses the average of ``epoch-9.pt``, ``epoch-10.pt``, ``epoch-11.pt``,
+    ``epoch-12.pt``, ``epoch-13.pt``, ``epoch-14.pt``, ``epoch-15.pt``,
+    ``epoch-16.pt``, ``epoch-17.pt``, ``epoch-18.pt``, ``epoch-19.pt``,
+    ``epoch-20.pt``, ``epoch-21.pt``, ``epoch-22.pt``, ``epoch-23.pt``,
     ``epoch-24.pt`` and ``epoch-25.pt``
     for decoding.
 
@@ -317,13 +317,13 @@ To decode with ``1best`` method, we can use:
 
 .. code-block:: bash
 
-  ./tdnn_ligru_ctc/pretrained.py 
+  ./tdnn_ligru_ctc/pretrained.py
     --method 1best
-    --checkpoint ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/exp/pretrained_average_9_25.pt 
-    --words-file ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/words.txt 
-    --HLG ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/HLG.pt 
-    ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV 
-    ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV 
+    --checkpoint ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/exp/pretrained_average_9_25.pt
+    --words-file ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/words.txt
+    --HLG ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/HLG.pt
+    ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV
+    ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV
     ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FMGD0_SI1564.WAV
 
 The output is:
@@ -337,7 +337,7 @@ The output is:
   2021-11-08 20:41:38,697 INFO [pretrained.py:210] Reading sound files: ['./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV', './tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV', './tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FMGD0_SI1564.WAV']
   2021-11-08 20:41:38,704 INFO [pretrained.py:216] Decoding started
   2021-11-08 20:41:39,819 INFO [pretrained.py:246] Use HLG decoding
-  2021-11-08 20:41:39,829 INFO [pretrained.py:267] 
+  2021-11-08 20:41:39,829 INFO [pretrained.py:267]
   ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV:
   sil dh ih sh uw ah l iy v iy z ih sil p r aa sil k s ih m ey dx ih sil d w uh dx ih w ih s f iy l ih ng w ih th ih n ih m s eh l f sil jh
 
@@ -362,8 +362,8 @@ To decode with ``whole-lattice-rescoring`` methond, you can use
     --HLG ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/HLG.pt \
     --G ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lm/G_4_gram.pt \
     --ngram-lm-scale 0.1 \
-    ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV 
-    ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV 
+    ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV
+    ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV
     ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FMGD0_SI1564.WAV
 
 The decoding output is:
@@ -378,7 +378,7 @@ The decoding output is:
   2021-11-08 20:37:54,715 INFO [pretrained.py:210] Reading sound files: ['./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV', './tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV', './tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FMGD0_SI1564.WAV']
   2021-11-08 20:37:54,720 INFO [pretrained.py:216] Decoding started
   2021-11-08 20:37:55,808 INFO [pretrained.py:251] Use HLG decoding + LM rescoring
-  2021-11-08 20:37:56,348 INFO [pretrained.py:267] 
+  2021-11-08 20:37:56,348 INFO [pretrained.py:267]
   ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV:
   sil dh ih sh uw ah l iy v iy z ah sil p r aa sil k s ih m ey dx ih sil d w uh dx iy w ih s f iy l iy ng w ih th ih n ih m s eh l f sil jh
 
diff --git a/docs/source/recipes/timit/tdnn_lstm_ctc.rst b/docs/source/recipes/timit/tdnn_lstm_ctc.rst
index 6f760a9ce..ee67a6edc 100644
--- a/docs/source/recipes/timit/tdnn_lstm_ctc.rst
+++ b/docs/source/recipes/timit/tdnn_lstm_ctc.rst
@@ -148,8 +148,8 @@ Some commonly used options are:
 
         $ ./tdnn_lstm_ctc/decode.py --epoch 25 --avg 10
 
-    uses the average of ``epoch-16.pt``, ``epoch-17.pt``, ``epoch-18.pt``, 
-    ``epoch-19.pt``, ``epoch-20.pt``, ``epoch-21.pt``, ``epoch-22.pt``, 
+    uses the average of ``epoch-16.pt``, ``epoch-17.pt``, ``epoch-18.pt``,
+    ``epoch-19.pt``, ``epoch-20.pt``, ``epoch-21.pt``, ``epoch-22.pt``,
     ``epoch-23.pt``, ``epoch-24.pt`` and ``epoch-25.pt``
     for decoding.
 
@@ -315,13 +315,13 @@ To decode with ``1best`` method, we can use:
 
 .. code-block:: bash
 
-  ./tdnn_lstm_ctc/pretrained.py 
+  ./tdnn_lstm_ctc/pretrained.py
     --method 1best
-    --checkpoint ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/exp/pretrained_average_16_25.pt 
-    --words-file ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/words.txt 
-    --HLG ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/HLG.pt 
-    ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV 
-    ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV 
+    --checkpoint ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/exp/pretrained_average_16_25.pt
+    --words-file ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/words.txt
+    --HLG ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/HLG.pt
+    ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV
+    ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV
     ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV
 
 The output is:
@@ -335,7 +335,7 @@ The output is:
   2021-11-08 21:02:53,827 INFO [pretrained.py:210] Reading sound files: ['./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV', './tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV', './tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV']
   2021-11-08 21:02:53,831 INFO [pretrained.py:216] Decoding started
   2021-11-08 21:02:54,380 INFO [pretrained.py:246] Use HLG decoding
-  2021-11-08 21:02:54,387 INFO [pretrained.py:267] 
+  2021-11-08 21:02:54,387 INFO [pretrained.py:267]
   ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV:
   sil dh ih sh uw ah l iy v iy z ih sil p r aa sil k s ih m ey dx ih sil d w uh dx iy w ih s f iy l iy w ih th ih n ih m s eh l f sil jh
 
@@ -360,8 +360,8 @@ To decode with ``whole-lattice-rescoring`` methond, you can use
     --HLG ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/HLG.pt \
     --G ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lm/G_4_gram.pt \
     --ngram-lm-scale 0.08 \
-    ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV 
-    ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV 
+    ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV
+    ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV
     ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV
 
 The decoding output is:
@@ -376,7 +376,7 @@ The decoding output is:
   2021-11-08 20:05:26,978 INFO [pretrained.py:210] Reading sound files: ['./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV', './tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV', './tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV']
   2021-11-08 20:05:26,981 INFO [pretrained.py:216] Decoding started
   2021-11-08 20:05:27,519 INFO [pretrained.py:251] Use HLG decoding + LM rescoring
-  2021-11-08 20:05:27,878 INFO [pretrained.py:267] 
+  2021-11-08 20:05:27,878 INFO [pretrained.py:267]
   ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV:
   sil dh ih sh uw l iy v iy z ih sil p r aa sil k s ah m ey dx ih sil w uh dx iy w ih s f iy l ih ng w ih th ih n ih m s eh l f sil jh
 
diff --git a/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py b/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py
index fb2751c0f..387c14acf 100755
--- a/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py
+++ b/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py
@@ -87,9 +87,7 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
             )
             if "train" in partition:
                 cut_set = (
-                    cut_set
-                    + cut_set.perturb_speed(0.9)
-                    + cut_set.perturb_speed(1.1)
+                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                 )
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
@@ -116,9 +114,7 @@ def get_args():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/aidatatang_200zh/ASR/local/prepare_char.py b/egs/aidatatang_200zh/ASR/local/prepare_char.py
index d9e47d17a..6b440dfb3 100755
--- a/egs/aidatatang_200zh/ASR/local/prepare_char.py
+++ b/egs/aidatatang_200zh/ASR/local/prepare_char.py
@@ -86,9 +86,7 @@ def lexicon_to_fst_no_sil(
         cur_state = loop_state
 
         word = word2id[word]
-        pieces = [
-            token2id[i] if i in token2id else token2id[""] for i in pieces
-        ]
+        pieces = [token2id[i] if i in token2id else token2id[""] for i in pieces]
 
         for i in range(len(pieces) - 1):
             w = word if i == 0 else eps
@@ -142,9 +140,7 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
     return False
 
 
-def generate_lexicon(
-    token_sym_table: Dict[str, int], words: List[str]
-) -> Lexicon:
+def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon:
     """Generate a lexicon from a word list and token_sym_table.
 
     Args:
diff --git a/egs/aidatatang_200zh/ASR/local/prepare_lang.py b/egs/aidatatang_200zh/ASR/local/prepare_lang.py
index e5ae89ec4..c8cf9b881 100755
--- a/egs/aidatatang_200zh/ASR/local/prepare_lang.py
+++ b/egs/aidatatang_200zh/ASR/local/prepare_lang.py
@@ -317,9 +317,7 @@ def lexicon_to_fst(
 
 def get_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--lang-dir", type=str, help="The lang dir, data/lang_phone"
-    )
+    parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone")
     return parser.parse_args()
 
 
diff --git a/egs/aidatatang_200zh/ASR/local/test_prepare_lang.py b/egs/aidatatang_200zh/ASR/local/test_prepare_lang.py
index d4cf62bba..74e025ad7 100755
--- a/egs/aidatatang_200zh/ASR/local/test_prepare_lang.py
+++ b/egs/aidatatang_200zh/ASR/local/test_prepare_lang.py
@@ -88,9 +88,7 @@ def test_read_lexicon(filename: str):
     fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa.draw("L.pdf", title="L")
 
-    fsa_disambig = lexicon_to_fst(
-        lexicon_disambig, phone2id=phone2id, word2id=word2id
-    )
+    fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id)
     fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
     fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa_disambig.draw("L_disambig.pdf", title="L_disambig")
diff --git a/egs/aidatatang_200zh/ASR/local/text2token.py b/egs/aidatatang_200zh/ASR/local/text2token.py
index 71be2a613..85047c367 100755
--- a/egs/aidatatang_200zh/ASR/local/text2token.py
+++ b/egs/aidatatang_200zh/ASR/local/text2token.py
@@ -56,9 +56,7 @@ def get_parser():
     parser.add_argument(
         "--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
     )
-    parser.add_argument(
-        "--space", default="", type=str, help="space symbol"
-    )
+    parser.add_argument("--space", default="", type=str, help="space symbol")
     parser.add_argument(
         "--non-lang-syms",
         "-l",
@@ -66,9 +64,7 @@ def get_parser():
         type=str,
         help="list of non-linguistic symobles, e.g.,  etc.",
     )
-    parser.add_argument(
-        "text", type=str, default=False, nargs="?", help="input text"
-    )
+    parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
     parser.add_argument(
         "--trans_type",
         "-t",
@@ -108,8 +104,7 @@ def token2id(
             if token_type == "lazy_pinyin":
                 text = lazy_pinyin(chars_list)
                 sub_ids = [
-                    token_table[txt] if txt in token_table else oov_id
-                    for txt in text
+                    token_table[txt] if txt in token_table else oov_id for txt in text
                 ]
                 ids.append(sub_ids)
             else:  # token_type = "pinyin"
@@ -135,9 +130,7 @@ def main():
     if args.text:
         f = codecs.open(args.text, encoding="utf-8")
     else:
-        f = codecs.getreader("utf-8")(
-            sys.stdin if is_python2 else sys.stdin.buffer
-        )
+        f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
 
     sys.stdout = codecs.getwriter("utf-8")(
         sys.stdout if is_python2 else sys.stdout.buffer
diff --git a/egs/aidatatang_200zh/ASR/prepare.sh b/egs/aidatatang_200zh/ASR/prepare.sh
index 039951354..4749e1b7f 100755
--- a/egs/aidatatang_200zh/ASR/prepare.sh
+++ b/egs/aidatatang_200zh/ASR/prepare.sh
@@ -106,11 +106,10 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
   if [ ! -f $lang_char_dir/words.txt ]; then
     ./local/prepare_words.py \
       --input-file $lang_char_dir/words_no_ids.txt \
-      --output-file $lang_char_dir/words.txt 
+      --output-file $lang_char_dir/words.txt
   fi
 
   if [ ! -f $lang_char_dir/L_disambig.pt ]; then
     ./local/prepare_char.py
   fi
 fi
-
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py
index 6a5b57e24..167d5e15e 100644
--- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py
@@ -205,17 +205,13 @@ class Aidatatang_200zhAsrDataModule:
             The state dict for the training sampler.
         """
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(
-            self.args.manifest_dir / "musan_cuts.jsonl.gz"
-        )
+        cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
 
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(
-                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
-                )
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")
@@ -237,9 +233,7 @@ class Aidatatang_200zhAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(
-                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
-            )
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -282,9 +276,7 @@ class Aidatatang_200zhAsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -340,9 +332,7 @@ class Aidatatang_200zhAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 return_cuts=self.args.return_cuts,
             )
         else:
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py
index f0407f429..b1c7c2839 100755
--- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py
@@ -69,11 +69,7 @@ from beam_search import (
 )
 from train import get_params, get_transducer_model
 
-from icefall.checkpoint import (
-    average_checkpoints,
-    find_checkpoints,
-    load_checkpoint,
-)
+from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
@@ -192,8 +188,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -249,9 +244,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
@@ -266,10 +259,7 @@ def decode_one_batch(
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -390,9 +380,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -425,8 +413,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py
index 00b54c39f..de37ec7e4 100644
--- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py
@@ -103,8 +103,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     return parser
@@ -173,9 +172,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py
index eb5e6b0d4..548b7263c 100644
--- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py
@@ -162,8 +162,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -194,8 +193,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])
@@ -257,9 +255,7 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -284,10 +280,7 @@ def main():
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -339,9 +332,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py
index d46838b68..322fa6b00 100644
--- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py
@@ -81,9 +81,7 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[
-    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
-]
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
 os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
 
@@ -187,8 +185,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -211,8 +208,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
+        help="The scale to smooth the loss with am (output of encoder network)" "part.",
     )
 
     parser.add_argument(
@@ -542,22 +538,15 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0
-            if warmup < 1.0
-            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
-        )
-        loss = (
-            params.simple_loss_scale * simple_loss
-            + pruned_loss_scale * pruned_loss
+            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
         )
+        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -711,9 +700,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -813,7 +800,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
+            2**22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
diff --git a/egs/aishell/ASR/conformer_ctc/conformer.py b/egs/aishell/ASR/conformer_ctc/conformer.py
index cb7205e51..ab1cbbae4 100644
--- a/egs/aishell/ASR/conformer_ctc/conformer.py
+++ b/egs/aishell/ASR/conformer_ctc/conformer.py
@@ -157,9 +157,7 @@ class ConformerEncoderLayer(nn.Module):
         normalize_before: bool = True,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
-        self.self_attn = RelPositionMultiheadAttention(
-            d_model, nhead, dropout=0.0
-        )
+        self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
 
         self.feed_forward = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
@@ -177,18 +175,14 @@ class ConformerEncoderLayer(nn.Module):
 
         self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
 
-        self.norm_ff_macaron = nn.LayerNorm(
-            d_model
-        )  # for the macaron style FNN module
+        self.norm_ff_macaron = nn.LayerNorm(d_model)  # for the macaron style FNN module
         self.norm_ff = nn.LayerNorm(d_model)  # for the FNN module
         self.norm_mha = nn.LayerNorm(d_model)  # for the MHA module
 
         self.ff_scale = 0.5
 
         self.norm_conv = nn.LayerNorm(d_model)  # for the CNN module
-        self.norm_final = nn.LayerNorm(
-            d_model
-        )  # for the final output of the block
+        self.norm_final = nn.LayerNorm(d_model)  # for the final output of the block
 
         self.dropout = nn.Dropout(dropout)
 
@@ -222,9 +216,7 @@ class ConformerEncoderLayer(nn.Module):
         residual = src
         if self.normalize_before:
             src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(
-            self.feed_forward_macaron(src)
-        )
+        src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)
 
@@ -343,9 +335,7 @@ class RelPositionalEncoding(torch.nn.Module):
 
     """
 
-    def __init__(
-        self, d_model: int, dropout_rate: float, max_len: int = 5000
-    ) -> None:
+    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
@@ -361,9 +351,7 @@ class RelPositionalEncoding(torch.nn.Module):
             # the length of self.pe is 2 * input_len - 1
             if self.pe.size(1) >= x.size(1) * 2 - 1:
                 # Note: TorchScript doesn't implement operator== for torch.Device
-                if self.pe.dtype != x.dtype or str(self.pe.device) != str(
-                    x.device
-                ):
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
                     self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                 return
         # Suppose `i` means to the position of query vector and `j` means the
@@ -633,9 +621,9 @@ class RelPositionMultiheadAttention(nn.Module):
 
         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(
-                query, in_proj_weight, in_proj_bias
-            ).chunk(3, dim=-1)
+            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
+                3, dim=-1
+            )
 
         elif torch.equal(key, value):
             # encoder-decoder attention
@@ -703,31 +691,22 @@ class RelPositionMultiheadAttention(nn.Module):
             if attn_mask.dim() == 2:
                 attn_mask = attn_mask.unsqueeze(0)
                 if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
-                    raise RuntimeError(
-                        "The size of the 2D attn_mask is not correct."
-                    )
+                    raise RuntimeError("The size of the 2D attn_mask is not correct.")
             elif attn_mask.dim() == 3:
                 if list(attn_mask.size()) != [
                     bsz * num_heads,
                     query.size(0),
                     key.size(0),
                 ]:
-                    raise RuntimeError(
-                        "The size of the 3D attn_mask is not correct."
-                    )
+                    raise RuntimeError("The size of the 3D attn_mask is not correct.")
             else:
                 raise RuntimeError(
-                    "attn_mask's dimension {} is not supported".format(
-                        attn_mask.dim()
-                    )
+                    "attn_mask's dimension {} is not supported".format(attn_mask.dim())
                 )
             # attn_mask's dim is 3 now.
 
         # convert ByteTensor key_padding_mask to bool
-        if (
-            key_padding_mask is not None
-            and key_padding_mask.dtype == torch.uint8
-        ):
+        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
             warnings.warn(
                 "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
             )
@@ -766,9 +745,7 @@ class RelPositionMultiheadAttention(nn.Module):
         # first compute matrix a and matrix c
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
-        matrix_ac = torch.matmul(
-            q_with_bias_u, k
-        )  # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch, head, time1, time2)
 
         # compute matrix b and matrix d
         matrix_bd = torch.matmul(
@@ -780,9 +757,7 @@ class RelPositionMultiheadAttention(nn.Module):
             matrix_ac + matrix_bd
         ) * scaling  # (batch, head, time1, time2)
 
-        attn_output_weights = attn_output_weights.view(
-            bsz * num_heads, tgt_len, -1
-        )
+        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
 
         assert list(attn_output_weights.size()) == [
             bsz * num_heads,
@@ -816,13 +791,9 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output = torch.bmm(attn_output_weights, v)
         assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
         attn_output = (
-            attn_output.transpose(0, 1)
-            .contiguous()
-            .view(tgt_len, bsz, embed_dim)
-        )
-        attn_output = nn.functional.linear(
-            attn_output, out_proj_weight, out_proj_bias
+            attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
         )
+        attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
 
         if need_weights:
             # average attention weights over heads
@@ -845,9 +816,7 @@ class ConvolutionModule(nn.Module):
 
     """
 
-    def __init__(
-        self, channels: int, kernel_size: int, bias: bool = True
-    ) -> None:
+    def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None:
         """Construct an ConvolutionModule object."""
         super(ConvolutionModule, self).__init__()
         # kernerl_size should be a odd number for 'SAME' padding
diff --git a/egs/aishell/ASR/conformer_ctc/decode.py b/egs/aishell/ASR/conformer_ctc/decode.py
index 751b7d5b5..74a7b5933 100755
--- a/egs/aishell/ASR/conformer_ctc/decode.py
+++ b/egs/aishell/ASR/conformer_ctc/decode.py
@@ -401,9 +401,7 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -431,9 +429,7 @@ def save_results(
         # we compute CER for aishell dataset.
         results_char = []
         for res in results:
-            results_char.append(
-                (res[0], list("".join(res[1])), list("".join(res[2])))
-            )
+            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
@@ -441,9 +437,7 @@ def save_results(
             test_set_wers[key] = wer
 
         if enable_log:
-            logging.info(
-                "Wrote detailed error stats to {}".format(errs_filename)
-            )
+            logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = params.exp_dir / f"cer-summary-{test_set_name}.txt"
@@ -562,9 +556,7 @@ def main():
             eos_id=eos_id,
         )
 
-        save_results(
-            params=params, test_set_name=test_set, results_dict=results_dict
-        )
+        save_results(params=params, test_set_name=test_set, results_dict=results_dict)
 
     logging.info("Done!")
 
diff --git a/egs/aishell/ASR/conformer_ctc/export.py b/egs/aishell/ASR/conformer_ctc/export.py
index 42b8c29e7..1df3cfdc2 100644
--- a/egs/aishell/ASR/conformer_ctc/export.py
+++ b/egs/aishell/ASR/conformer_ctc/export.py
@@ -157,9 +157,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell/ASR/conformer_ctc/pretrained.py b/egs/aishell/ASR/conformer_ctc/pretrained.py
index 27776bc24..e0dcb8ad4 100755
--- a/egs/aishell/ASR/conformer_ctc/pretrained.py
+++ b/egs/aishell/ASR/conformer_ctc/pretrained.py
@@ -211,8 +211,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])
@@ -274,9 +273,7 @@ def main():
     logging.info("Decoding started")
     features = fbank(waves)
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     # Note: We don't use key padding mask for attention during decoding
     with torch.no_grad():
@@ -371,9 +368,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell/ASR/conformer_ctc/subsampling.py b/egs/aishell/ASR/conformer_ctc/subsampling.py
index 542fb0364..8e0f73d05 100644
--- a/egs/aishell/ASR/conformer_ctc/subsampling.py
+++ b/egs/aishell/ASR/conformer_ctc/subsampling.py
@@ -42,13 +42,9 @@ class Conv2dSubsampling(nn.Module):
         assert idim >= 7
         super().__init__()
         self.conv = nn.Sequential(
-            nn.Conv2d(
-                in_channels=1, out_channels=odim, kernel_size=3, stride=2
-            ),
+            nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2),
             nn.ReLU(),
-            nn.Conv2d(
-                in_channels=odim, out_channels=odim, kernel_size=3, stride=2
-            ),
+            nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2),
             nn.ReLU(),
         )
         self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
@@ -132,17 +128,13 @@ class VggSubsampling(nn.Module):
                 )
             )
             layers.append(
-                torch.nn.MaxPool2d(
-                    kernel_size=2, stride=2, padding=0, ceil_mode=True
-                )
+                torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True)
             )
             cur_channels = block_dim
 
         self.layers = nn.Sequential(*layers)
 
-        self.out = nn.Linear(
-            block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim
-        )
+        self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Subsample x.
diff --git a/egs/aishell/ASR/conformer_ctc/test_subsampling.py b/egs/aishell/ASR/conformer_ctc/test_subsampling.py
index e3361d0c9..81fa234dd 100755
--- a/egs/aishell/ASR/conformer_ctc/test_subsampling.py
+++ b/egs/aishell/ASR/conformer_ctc/test_subsampling.py
@@ -16,9 +16,8 @@
 # limitations under the License.
 
 
-from subsampling import Conv2dSubsampling
-from subsampling import VggSubsampling
 import torch
+from subsampling import Conv2dSubsampling, VggSubsampling
 
 
 def test_conv2d_subsampling():
diff --git a/egs/aishell/ASR/conformer_ctc/train.py b/egs/aishell/ASR/conformer_ctc/train.py
index a228cc1fe..c2cbe6e3b 100755
--- a/egs/aishell/ASR/conformer_ctc/train.py
+++ b/egs/aishell/ASR/conformer_ctc/train.py
@@ -382,9 +382,7 @@ def compute_loss(
             #
             # See https://github.com/k2-fsa/icefall/issues/97
             # for more details
-            unsorted_token_ids = graph_compiler.texts_to_ids(
-                supervisions["text"]
-            )
+            unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"])
             att_loss = mmodel.decoder_forward(
                 encoder_memory,
                 memory_mask,
@@ -520,9 +518,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -630,9 +626,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/aishell/ASR/conformer_ctc/transformer.py b/egs/aishell/ASR/conformer_ctc/transformer.py
index f93914aaa..a3e50e385 100644
--- a/egs/aishell/ASR/conformer_ctc/transformer.py
+++ b/egs/aishell/ASR/conformer_ctc/transformer.py
@@ -149,9 +149,7 @@ class Transformer(nn.Module):
                 norm=decoder_norm,
             )
 
-            self.decoder_output_layer = torch.nn.Linear(
-                d_model, self.decoder_num_class
-            )
+            self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class)
 
             self.decoder_criterion = LabelSmoothingLoss()
         else:
@@ -183,9 +181,7 @@ class Transformer(nn.Module):
             x = x.permute(0, 2, 1)  # (N, T, C) -> (N, C, T)
             x = self.feat_batchnorm(x)
             x = x.permute(0, 2, 1)  # (N, C, T) -> (N, T, C)
-        encoder_memory, memory_key_padding_mask = self.run_encoder(
-            x, supervision
-        )
+        encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision)
         x = self.ctc_output(encoder_memory)
         return x, encoder_memory, memory_key_padding_mask
 
@@ -266,23 +262,17 @@ class Transformer(nn.Module):
         """
         ys_in = add_sos(token_ids, sos_id=sos_id)
         ys_in = [torch.tensor(y) for y in ys_in]
-        ys_in_pad = pad_sequence(
-            ys_in, batch_first=True, padding_value=float(eos_id)
-        )
+        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
 
         ys_out = add_eos(token_ids, eos_id=eos_id)
         ys_out = [torch.tensor(y) for y in ys_out]
-        ys_out_pad = pad_sequence(
-            ys_out, batch_first=True, padding_value=float(-1)
-        )
+        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
 
         device = memory.device
         ys_in_pad = ys_in_pad.to(device)
         ys_out_pad = ys_out_pad.to(device)
 
-        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
-            device
-        )
+        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
 
         tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
         # TODO: Use length information to create the decoder padding mask
@@ -343,23 +333,17 @@ class Transformer(nn.Module):
 
         ys_in = add_sos(token_ids, sos_id=sos_id)
         ys_in = [torch.tensor(y) for y in ys_in]
-        ys_in_pad = pad_sequence(
-            ys_in, batch_first=True, padding_value=float(eos_id)
-        )
+        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
 
         ys_out = add_eos(token_ids, eos_id=eos_id)
         ys_out = [torch.tensor(y) for y in ys_out]
-        ys_out_pad = pad_sequence(
-            ys_out, batch_first=True, padding_value=float(-1)
-        )
+        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
 
         device = memory.device
         ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
         ys_out_pad = ys_out_pad.to(device, dtype=torch.int64)
 
-        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
-            device
-        )
+        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
 
         tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
         # TODO: Use length information to create the decoder padding mask
@@ -632,9 +616,7 @@ def _get_activation_fn(activation: str):
     elif activation == "gelu":
         return nn.functional.gelu
 
-    raise RuntimeError(
-        "activation should be relu/gelu, not {}".format(activation)
-    )
+    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
 
 
 class PositionalEncoding(nn.Module):
@@ -836,9 +818,7 @@ def encoder_padding_mask(
         1,
     ).to(torch.int32)
 
-    lengths = [
-        0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)
-    ]
+    lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)]
     for idx in range(supervision_segments.size(0)):
         # Note: TorchScript doesn't allow to unpack tensors as tuples
         sequence_idx = supervision_segments[idx, 0].item()
@@ -859,9 +839,7 @@ def encoder_padding_mask(
     return mask
 
 
-def decoder_padding_mask(
-    ys_pad: torch.Tensor, ignore_id: int = -1
-) -> torch.Tensor:
+def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor:
     """Generate a length mask for input.
 
     The masked position are filled with True,
diff --git a/egs/aishell/ASR/conformer_mmi/conformer.py b/egs/aishell/ASR/conformer_mmi/conformer.py
index cb7205e51..ab1cbbae4 100644
--- a/egs/aishell/ASR/conformer_mmi/conformer.py
+++ b/egs/aishell/ASR/conformer_mmi/conformer.py
@@ -157,9 +157,7 @@ class ConformerEncoderLayer(nn.Module):
         normalize_before: bool = True,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
-        self.self_attn = RelPositionMultiheadAttention(
-            d_model, nhead, dropout=0.0
-        )
+        self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
 
         self.feed_forward = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
@@ -177,18 +175,14 @@ class ConformerEncoderLayer(nn.Module):
 
         self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
 
-        self.norm_ff_macaron = nn.LayerNorm(
-            d_model
-        )  # for the macaron style FNN module
+        self.norm_ff_macaron = nn.LayerNorm(d_model)  # for the macaron style FNN module
         self.norm_ff = nn.LayerNorm(d_model)  # for the FNN module
         self.norm_mha = nn.LayerNorm(d_model)  # for the MHA module
 
         self.ff_scale = 0.5
 
         self.norm_conv = nn.LayerNorm(d_model)  # for the CNN module
-        self.norm_final = nn.LayerNorm(
-            d_model
-        )  # for the final output of the block
+        self.norm_final = nn.LayerNorm(d_model)  # for the final output of the block
 
         self.dropout = nn.Dropout(dropout)
 
@@ -222,9 +216,7 @@ class ConformerEncoderLayer(nn.Module):
         residual = src
         if self.normalize_before:
             src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(
-            self.feed_forward_macaron(src)
-        )
+        src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)
 
@@ -343,9 +335,7 @@ class RelPositionalEncoding(torch.nn.Module):
 
     """
 
-    def __init__(
-        self, d_model: int, dropout_rate: float, max_len: int = 5000
-    ) -> None:
+    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
@@ -361,9 +351,7 @@ class RelPositionalEncoding(torch.nn.Module):
             # the length of self.pe is 2 * input_len - 1
             if self.pe.size(1) >= x.size(1) * 2 - 1:
                 # Note: TorchScript doesn't implement operator== for torch.Device
-                if self.pe.dtype != x.dtype or str(self.pe.device) != str(
-                    x.device
-                ):
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
                     self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                 return
         # Suppose `i` means to the position of query vector and `j` means the
@@ -633,9 +621,9 @@ class RelPositionMultiheadAttention(nn.Module):
 
         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(
-                query, in_proj_weight, in_proj_bias
-            ).chunk(3, dim=-1)
+            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
+                3, dim=-1
+            )
 
         elif torch.equal(key, value):
             # encoder-decoder attention
@@ -703,31 +691,22 @@ class RelPositionMultiheadAttention(nn.Module):
             if attn_mask.dim() == 2:
                 attn_mask = attn_mask.unsqueeze(0)
                 if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
-                    raise RuntimeError(
-                        "The size of the 2D attn_mask is not correct."
-                    )
+                    raise RuntimeError("The size of the 2D attn_mask is not correct.")
             elif attn_mask.dim() == 3:
                 if list(attn_mask.size()) != [
                     bsz * num_heads,
                     query.size(0),
                     key.size(0),
                 ]:
-                    raise RuntimeError(
-                        "The size of the 3D attn_mask is not correct."
-                    )
+                    raise RuntimeError("The size of the 3D attn_mask is not correct.")
             else:
                 raise RuntimeError(
-                    "attn_mask's dimension {} is not supported".format(
-                        attn_mask.dim()
-                    )
+                    "attn_mask's dimension {} is not supported".format(attn_mask.dim())
                 )
             # attn_mask's dim is 3 now.
 
         # convert ByteTensor key_padding_mask to bool
-        if (
-            key_padding_mask is not None
-            and key_padding_mask.dtype == torch.uint8
-        ):
+        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
             warnings.warn(
                 "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
             )
@@ -766,9 +745,7 @@ class RelPositionMultiheadAttention(nn.Module):
         # first compute matrix a and matrix c
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
-        matrix_ac = torch.matmul(
-            q_with_bias_u, k
-        )  # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch, head, time1, time2)
 
         # compute matrix b and matrix d
         matrix_bd = torch.matmul(
@@ -780,9 +757,7 @@ class RelPositionMultiheadAttention(nn.Module):
             matrix_ac + matrix_bd
         ) * scaling  # (batch, head, time1, time2)
 
-        attn_output_weights = attn_output_weights.view(
-            bsz * num_heads, tgt_len, -1
-        )
+        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
 
         assert list(attn_output_weights.size()) == [
             bsz * num_heads,
@@ -816,13 +791,9 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output = torch.bmm(attn_output_weights, v)
         assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
         attn_output = (
-            attn_output.transpose(0, 1)
-            .contiguous()
-            .view(tgt_len, bsz, embed_dim)
-        )
-        attn_output = nn.functional.linear(
-            attn_output, out_proj_weight, out_proj_bias
+            attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
         )
+        attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
 
         if need_weights:
             # average attention weights over heads
@@ -845,9 +816,7 @@ class ConvolutionModule(nn.Module):
 
     """
 
-    def __init__(
-        self, channels: int, kernel_size: int, bias: bool = True
-    ) -> None:
+    def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None:
         """Construct an ConvolutionModule object."""
         super(ConvolutionModule, self).__init__()
         # kernerl_size should be a odd number for 'SAME' padding
diff --git a/egs/aishell/ASR/conformer_mmi/decode.py b/egs/aishell/ASR/conformer_mmi/decode.py
index 4db367e36..20a855e7f 100755
--- a/egs/aishell/ASR/conformer_mmi/decode.py
+++ b/egs/aishell/ASR/conformer_mmi/decode.py
@@ -413,9 +413,7 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -443,9 +441,7 @@ def save_results(
         # we compute CER for aishell dataset.
         results_char = []
         for res in results:
-            results_char.append(
-                (res[0], list("".join(res[1])), list("".join(res[2])))
-            )
+            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
@@ -453,9 +449,7 @@ def save_results(
             test_set_wers[key] = wer
 
         if enable_log:
-            logging.info(
-                "Wrote detailed error stats to {}".format(errs_filename)
-            )
+            logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = params.exp_dir / f"cer-summary-{test_set_name}.txt"
@@ -550,9 +544,7 @@ def main():
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save(
-            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
-        )
+        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
         return
 
     model.to(device)
@@ -581,9 +573,7 @@ def main():
             eos_id=eos_id,
         )
 
-        save_results(
-            params=params, test_set_name=test_set, results_dict=results_dict
-        )
+        save_results(params=params, test_set_name=test_set, results_dict=results_dict)
 
     logging.info("Done!")
 
diff --git a/egs/aishell/ASR/conformer_mmi/subsampling.py b/egs/aishell/ASR/conformer_mmi/subsampling.py
index 720ed6c22..398837a46 100644
--- a/egs/aishell/ASR/conformer_mmi/subsampling.py
+++ b/egs/aishell/ASR/conformer_mmi/subsampling.py
@@ -42,13 +42,9 @@ class Conv2dSubsampling(nn.Module):
         assert idim >= 7
         super().__init__()
         self.conv = nn.Sequential(
-            nn.Conv2d(
-                in_channels=1, out_channels=odim, kernel_size=3, stride=2
-            ),
+            nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2),
             nn.ReLU(),
-            nn.Conv2d(
-                in_channels=odim, out_channels=odim, kernel_size=3, stride=2
-            ),
+            nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2),
             nn.ReLU(),
         )
         self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
@@ -132,17 +128,13 @@ class VggSubsampling(nn.Module):
                 )
             )
             layers.append(
-                torch.nn.MaxPool2d(
-                    kernel_size=2, stride=2, padding=0, ceil_mode=True
-                )
+                torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True)
             )
             cur_channels = block_dim
 
         self.layers = nn.Sequential(*layers)
 
-        self.out = nn.Linear(
-            block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim
-        )
+        self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Subsample x.
diff --git a/egs/aishell/ASR/conformer_mmi/train.py b/egs/aishell/ASR/conformer_mmi/train.py
index 685831d09..09cd6e60c 100755
--- a/egs/aishell/ASR/conformer_mmi/train.py
+++ b/egs/aishell/ASR/conformer_mmi/train.py
@@ -511,9 +511,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -625,9 +623,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/aishell/ASR/conformer_mmi/transformer.py b/egs/aishell/ASR/conformer_mmi/transformer.py
index f93914aaa..a3e50e385 100644
--- a/egs/aishell/ASR/conformer_mmi/transformer.py
+++ b/egs/aishell/ASR/conformer_mmi/transformer.py
@@ -149,9 +149,7 @@ class Transformer(nn.Module):
                 norm=decoder_norm,
             )
 
-            self.decoder_output_layer = torch.nn.Linear(
-                d_model, self.decoder_num_class
-            )
+            self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class)
 
             self.decoder_criterion = LabelSmoothingLoss()
         else:
@@ -183,9 +181,7 @@ class Transformer(nn.Module):
             x = x.permute(0, 2, 1)  # (N, T, C) -> (N, C, T)
             x = self.feat_batchnorm(x)
             x = x.permute(0, 2, 1)  # (N, C, T) -> (N, T, C)
-        encoder_memory, memory_key_padding_mask = self.run_encoder(
-            x, supervision
-        )
+        encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision)
         x = self.ctc_output(encoder_memory)
         return x, encoder_memory, memory_key_padding_mask
 
@@ -266,23 +262,17 @@ class Transformer(nn.Module):
         """
         ys_in = add_sos(token_ids, sos_id=sos_id)
         ys_in = [torch.tensor(y) for y in ys_in]
-        ys_in_pad = pad_sequence(
-            ys_in, batch_first=True, padding_value=float(eos_id)
-        )
+        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
 
         ys_out = add_eos(token_ids, eos_id=eos_id)
         ys_out = [torch.tensor(y) for y in ys_out]
-        ys_out_pad = pad_sequence(
-            ys_out, batch_first=True, padding_value=float(-1)
-        )
+        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
 
         device = memory.device
         ys_in_pad = ys_in_pad.to(device)
         ys_out_pad = ys_out_pad.to(device)
 
-        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
-            device
-        )
+        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
 
         tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
         # TODO: Use length information to create the decoder padding mask
@@ -343,23 +333,17 @@ class Transformer(nn.Module):
 
         ys_in = add_sos(token_ids, sos_id=sos_id)
         ys_in = [torch.tensor(y) for y in ys_in]
-        ys_in_pad = pad_sequence(
-            ys_in, batch_first=True, padding_value=float(eos_id)
-        )
+        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
 
         ys_out = add_eos(token_ids, eos_id=eos_id)
         ys_out = [torch.tensor(y) for y in ys_out]
-        ys_out_pad = pad_sequence(
-            ys_out, batch_first=True, padding_value=float(-1)
-        )
+        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
 
         device = memory.device
         ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
         ys_out_pad = ys_out_pad.to(device, dtype=torch.int64)
 
-        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
-            device
-        )
+        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
 
         tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
         # TODO: Use length information to create the decoder padding mask
@@ -632,9 +616,7 @@ def _get_activation_fn(activation: str):
     elif activation == "gelu":
         return nn.functional.gelu
 
-    raise RuntimeError(
-        "activation should be relu/gelu, not {}".format(activation)
-    )
+    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
 
 
 class PositionalEncoding(nn.Module):
@@ -836,9 +818,7 @@ def encoder_padding_mask(
         1,
     ).to(torch.int32)
 
-    lengths = [
-        0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)
-    ]
+    lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)]
     for idx in range(supervision_segments.size(0)):
         # Note: TorchScript doesn't allow to unpack tensors as tuples
         sequence_idx = supervision_segments[idx, 0].item()
@@ -859,9 +839,7 @@ def encoder_padding_mask(
     return mask
 
 
-def decoder_padding_mask(
-    ys_pad: torch.Tensor, ignore_id: int = -1
-) -> torch.Tensor:
+def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor:
     """Generate a length mask for input.
 
     The masked position are filled with True,
diff --git a/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py b/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py
index 42700a972..037971927 100755
--- a/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py
+++ b/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py
@@ -87,9 +87,7 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
             )
             if "train" in partition:
                 cut_set = (
-                    cut_set
-                    + cut_set.perturb_speed(0.9)
-                    + cut_set.perturb_speed(1.1)
+                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                 )
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
@@ -116,9 +114,7 @@ def get_args():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/aishell/ASR/local/compute_fbank_aishell.py b/egs/aishell/ASR/local/compute_fbank_aishell.py
index deab6c809..115ca1031 100755
--- a/egs/aishell/ASR/local/compute_fbank_aishell.py
+++ b/egs/aishell/ASR/local/compute_fbank_aishell.py
@@ -83,9 +83,7 @@ def compute_fbank_aishell(num_mel_bins: int = 80):
             )
             if "train" in partition:
                 cut_set = (
-                    cut_set
-                    + cut_set.perturb_speed(0.9)
-                    + cut_set.perturb_speed(1.1)
+                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                 )
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
@@ -111,9 +109,7 @@ def get_args():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/aishell/ASR/local/prepare_char.py b/egs/aishell/ASR/local/prepare_char.py
index d9e47d17a..6b440dfb3 100755
--- a/egs/aishell/ASR/local/prepare_char.py
+++ b/egs/aishell/ASR/local/prepare_char.py
@@ -86,9 +86,7 @@ def lexicon_to_fst_no_sil(
         cur_state = loop_state
 
         word = word2id[word]
-        pieces = [
-            token2id[i] if i in token2id else token2id[""] for i in pieces
-        ]
+        pieces = [token2id[i] if i in token2id else token2id[""] for i in pieces]
 
         for i in range(len(pieces) - 1):
             w = word if i == 0 else eps
@@ -142,9 +140,7 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
     return False
 
 
-def generate_lexicon(
-    token_sym_table: Dict[str, int], words: List[str]
-) -> Lexicon:
+def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon:
     """Generate a lexicon from a word list and token_sym_table.
 
     Args:
diff --git a/egs/aishell/ASR/local/prepare_lang.py b/egs/aishell/ASR/local/prepare_lang.py
index e5ae89ec4..c8cf9b881 100755
--- a/egs/aishell/ASR/local/prepare_lang.py
+++ b/egs/aishell/ASR/local/prepare_lang.py
@@ -317,9 +317,7 @@ def lexicon_to_fst(
 
 def get_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--lang-dir", type=str, help="The lang dir, data/lang_phone"
-    )
+    parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone")
     return parser.parse_args()
 
 
diff --git a/egs/aishell/ASR/local/test_prepare_lang.py b/egs/aishell/ASR/local/test_prepare_lang.py
index d4cf62bba..74e025ad7 100755
--- a/egs/aishell/ASR/local/test_prepare_lang.py
+++ b/egs/aishell/ASR/local/test_prepare_lang.py
@@ -88,9 +88,7 @@ def test_read_lexicon(filename: str):
     fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa.draw("L.pdf", title="L")
 
-    fsa_disambig = lexicon_to_fst(
-        lexicon_disambig, phone2id=phone2id, word2id=word2id
-    )
+    fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id)
     fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
     fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa_disambig.draw("L_disambig.pdf", title="L_disambig")
diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/decode.py b/egs/aishell/ASR/pruned_transducer_stateless2/decode.py
index a12934d55..199acf6c3 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless2/decode.py
@@ -76,11 +76,7 @@ from beam_search import (
 )
 from train import add_model_arguments, get_params, get_transducer_model
 
-from icefall.checkpoint import (
-    average_checkpoints,
-    find_checkpoints,
-    load_checkpoint,
-)
+from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
@@ -188,8 +184,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -249,9 +244,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
 
     if params.decoding_method == "fast_beam_search":
         hyp_tokens = fast_beam_search_one_best(
@@ -263,10 +256,7 @@ def decode_one_batch(
             max_contexts=params.max_contexts,
             max_states=params.max_states,
         )
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -387,9 +377,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -415,9 +403,7 @@ def save_results(
         # we compute CER for aishell dataset.
         results_char = []
         for res in results:
-            results_char.append(
-                (res[0], list("".join(res[1])), list("".join(res[2])))
-            )
+            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results_char, enable_log=True
@@ -428,8 +414,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -473,9 +458,7 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += (
-            f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -504,8 +487,7 @@ def main():
         ]
         if len(filenames) == 0:
             raise ValueError(
-                f"No checkpoints found for"
-                f" --iter {params.iter}, --avg {params.avg}"
+                f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}"
             )
         elif len(filenames) < params.avg:
             raise ValueError(
diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/export.py b/egs/aishell/ASR/pruned_transducer_stateless2/export.py
index feababdd2..4d41e425c 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless2/export.py
@@ -50,11 +50,7 @@ from pathlib import Path
 import torch
 from train import add_model_arguments, get_params, get_transducer_model
 
-from icefall.checkpoint import (
-    average_checkpoints,
-    find_checkpoints,
-    load_checkpoint,
-)
+from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
 from icefall.lexicon import Lexicon
 from icefall.utils import str2bool
 
@@ -120,8 +116,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     add_model_arguments(parser)
@@ -157,8 +152,7 @@ def main():
         ]
         if len(filenames) == 0:
             raise ValueError(
-                f"No checkpoints found for"
-                f" --iter {params.iter}, --avg {params.avg}"
+                f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}"
             )
         elif len(filenames) < params.avg:
             raise ValueError(
@@ -191,9 +185,7 @@ def main():
         model.__class__.forward = torch.jit.ignore(model.__class__.forward)
         logging.info("Using torch.jit.script")
         model = torch.jit.script(model)
-        filename = (
-            params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt"
-        )
+        filename = params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt"
         model.save(str(filename))
         logging.info(f"Saved to {filename}")
     else:
@@ -201,17 +193,14 @@ def main():
         # Save it using a format so that it can be loaded
         # by :func:`load_checkpoint`
         filename = (
-            params.exp_dir
-            / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt"
+            params.exp_dir / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt"
         )
         torch.save({"model": model.state_dict()}, str(filename))
         logging.info(f"Saved to {filename}")
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py
index 3c38e5db7..8aa0fbdd7 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py
@@ -165,8 +165,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -197,8 +196,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])
@@ -256,13 +254,9 @@ def main():
     feature_lens = [f.size(0) for f in features]
     feature_lens = torch.tensor(feature_lens, device=device)
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=features, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens)
 
     num_waves = encoder_out.size(0)
     hyp_list = []
@@ -310,9 +304,7 @@ def main():
                     beam=params.beam_size,
                 )
             else:
-                raise ValueError(
-                    f"Unsupported decoding method: {params.method}"
-                )
+                raise ValueError(f"Unsupported decoding method: {params.method}")
             hyp_list.append(hyp)
 
     hyps = []
@@ -329,9 +321,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/train.py b/egs/aishell/ASR/pruned_transducer_stateless2/train.py
index 97d892754..f81ab2568 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless2/train.py
@@ -49,7 +49,6 @@ import optim
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
-
 from asr_datamodule import AishellAsrDataModule
 from conformer import Conformer
 from decoder import Decoder
@@ -75,9 +74,7 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[
-    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
-]
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
 
 def add_model_arguments(parser: argparse.ArgumentParser):
@@ -203,8 +200,7 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need "
-        "to be changed.",
+        help="The initial learning rate.  This value should not need " "to be changed.",
     )
 
     parser.add_argument(
@@ -227,8 +223,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -251,8 +246,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
+        help="The scale to smooth the loss with am (output of encoder network)" "part.",
     )
 
     parser.add_argument(
@@ -561,11 +555,7 @@ def compute_loss(
      warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
-    device = (
-        model.device
-        if isinstance(model, DDP)
-        else next(model.parameters()).device
-    )
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
@@ -593,23 +583,16 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0
-            if warmup < 1.0
-            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
-        )
-        loss = (
-            params.simple_loss_scale * simple_loss
-            + pruned_loss_scale * pruned_loss
+            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
         )
+        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
 
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -725,9 +708,7 @@ def train_one_epoch(
             scaler.update()
             optimizer.zero_grad()
         except:  # noqa
-            display_and_save_batch(
-                batch, params=params, graph_compiler=graph_compiler
-            )
+            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
             raise
 
         if params.print_diagnostics and batch_idx == 5:
@@ -891,7 +872,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
+            2**22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
@@ -1029,9 +1010,7 @@ def scan_pessimistic_batches_for_oom(
                     f"Failing criterion: {criterion} "
                     f"(={crit_values[criterion]}) ..."
                 )
-            display_and_save_batch(
-                batch, params=params, graph_compiler=graph_compiler
-            )
+            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
             raise
 
 
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/decode.py b/egs/aishell/ASR/pruned_transducer_stateless3/decode.py
index d159e420b..f6c919e9d 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless3/decode.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/decode.py
@@ -202,8 +202,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -263,9 +262,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
 
     if params.decoding_method == "fast_beam_search":
         hyp_tokens = fast_beam_search_one_best(
@@ -277,10 +274,7 @@ def decode_one_batch(
             max_contexts=params.max_contexts,
             max_states=params.max_states,
         )
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -401,9 +395,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -429,9 +421,7 @@ def save_results(
         # we compute CER for aishell dataset.
         results_char = []
         for res in results:
-            results_char.append(
-                (res[0], list("".join(res[1])), list("".join(res[2])))
-            )
+            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results_char, enable_log=True
@@ -442,8 +432,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tCER", file=f)
@@ -488,9 +477,7 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += (
-            f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -518,9 +505,9 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
@@ -551,9 +538,9 @@ def main():
             )
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg + 1]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/export.py b/egs/aishell/ASR/pruned_transducer_stateless3/export.py
index 566902a85..5e701c121 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless3/export.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/export.py
@@ -132,8 +132,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     add_model_arguments(parser)
@@ -166,9 +165,9 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
@@ -195,9 +194,9 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg + 1]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
@@ -252,9 +251,7 @@ def main():
         model.__class__.forward = torch.jit.ignore(model.__class__.forward)
         logging.info("Using torch.jit.script")
         model = torch.jit.script(model)
-        filename = (
-            params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt"
-        )
+        filename = params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt"
         model.save(str(filename))
         logging.info(f"Saved to {filename}")
     else:
@@ -262,17 +259,14 @@ def main():
         # Save it using a format so that it can be loaded
         # by :func:`load_checkpoint`
         filename = (
-            params.exp_dir
-            / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt"
+            params.exp_dir / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt"
         )
         torch.save({"model": model.state_dict()}, str(filename))
         logging.info(f"Saved to {filename}")
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/model.py b/egs/aishell/ASR/pruned_transducer_stateless3/model.py
index e150e8230..a4dda0d6d 100644
--- a/egs/aishell/ASR/pruned_transducer_stateless3/model.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/model.py
@@ -84,9 +84,7 @@ class Transducer(nn.Module):
         self.decoder_datatang = decoder_datatang
         self.joiner_datatang = joiner_datatang
 
-        self.simple_am_proj = ScaledLinear(
-            encoder_dim, vocab_size, initial_speed=0.5
-        )
+        self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5)
         self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)
 
         if decoder_datatang is not None:
@@ -179,9 +177,7 @@ class Transducer(nn.Module):
         y_padded = y.pad(mode="constant", padding_value=0)
 
         y_padded = y_padded.to(torch.int64)
-        boundary = torch.zeros(
-            (x.size(0), 4), dtype=torch.int64, device=x.device
-        )
+        boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device)
         boundary[:, 2] = y_lens
         boundary[:, 3] = encoder_out_lens
 
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py
index 04a0a882a..40926173c 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py
@@ -165,8 +165,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -197,8 +196,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])
@@ -257,13 +255,9 @@ def main():
     feature_lens = [f.size(0) for f in features]
     feature_lens = torch.tensor(feature_lens, device=device)
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=features, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens)
 
     num_waves = encoder_out.size(0)
     hyp_list = []
@@ -311,9 +305,7 @@ def main():
                     beam=params.beam_size,
                 )
             else:
-                raise ValueError(
-                    f"Unsupported decoding method: {params.method}"
-                )
+                raise ValueError(f"Unsupported decoding method: {params.method}")
             hyp_list.append(hyp)
 
     hyps = []
@@ -330,9 +322,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/train.py b/egs/aishell/ASR/pruned_transducer_stateless3/train.py
index feaef5cf6..680986ee9 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless3/train.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/train.py
@@ -96,9 +96,7 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[
-    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
-]
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
 
 def add_model_arguments(parser: argparse.ArgumentParser):
@@ -224,8 +222,7 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need "
-        "to be changed.",
+        help="The initial learning rate.  This value should not need " "to be changed.",
     )
 
     parser.add_argument(
@@ -248,8 +245,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -272,8 +268,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
+        help="The scale to smooth the loss with am (output of encoder network)" "part.",
     )
 
     parser.add_argument(
@@ -635,11 +630,7 @@ def compute_loss(
      warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
-    device = (
-        model.device
-        if isinstance(model, DDP)
-        else next(model.parameters()).device
-    )
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
@@ -670,23 +661,16 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0
-            if warmup < 1.0
-            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
-        )
-        loss = (
-            params.simple_loss_scale * simple_loss
-            + pruned_loss_scale * pruned_loss
+            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
         )
+        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
 
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -824,9 +808,7 @@ def train_one_epoch(
                 )
             # summary stats
             if datatang_train_dl is not None:
-                tot_loss = (
-                    tot_loss * (1 - 1 / params.reset_interval)
-                ) + loss_info
+                tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
 
             if aishell:
                 aishell_tot_loss = (
@@ -847,9 +829,7 @@ def train_one_epoch(
             scaler.update()
             optimizer.zero_grad()
         except:  # noqa
-            display_and_save_batch(
-                batch, params=params, graph_compiler=graph_compiler
-            )
+            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
             raise
 
         if params.print_diagnostics and batch_idx == 5:
@@ -892,9 +872,7 @@ def train_one_epoch(
             cur_lr = scheduler.get_last_lr()[0]
             if datatang_train_dl is not None:
                 datatang_str = f"datatang_tot_loss[{datatang_tot_loss}], "
-                tot_loss_str = (
-                    f"tot_loss[{tot_loss}], batch size: {batch_size}, "
-                )
+                tot_loss_str = f"tot_loss[{tot_loss}], batch size: {batch_size}, "
             else:
                 tot_loss_str = ""
                 datatang_str = ""
@@ -1067,7 +1045,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
+            2**22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
@@ -1076,9 +1054,7 @@ def run(rank, world_size, args):
     train_cuts = filter_short_and_long_utterances(train_cuts)
 
     if args.enable_musan:
-        cuts_musan = load_manifest(
-            Path(args.manifest_dir) / "musan_cuts.jsonl.gz"
-        )
+        cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz")
     else:
         cuts_musan = None
 
@@ -1093,9 +1069,7 @@ def run(rank, world_size, args):
     if params.datatang_prob > 0:
         datatang = AIDatatang200zh(manifest_dir=args.manifest_dir)
         train_datatang_cuts = datatang.train_cuts()
-        train_datatang_cuts = filter_short_and_long_utterances(
-            train_datatang_cuts
-        )
+        train_datatang_cuts = filter_short_and_long_utterances(train_datatang_cuts)
         train_datatang_cuts = train_datatang_cuts.repeat(times=None)
         datatang_train_dl = asr_datamodule.train_dataloaders(
             train_datatang_cuts,
@@ -1249,9 +1223,7 @@ def scan_pessimistic_batches_for_oom(
                     f"Failing criterion: {criterion} "
                     f"(={crit_values[criterion]}) ..."
                 )
-            display_and_save_batch(
-                batch, params=params, graph_compiler=graph_compiler
-            )
+            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
             raise
 
 
diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py
index d24ba6bb7..fc28e8dbc 100644
--- a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -183,17 +183,13 @@ class AishellAsrDataModule:
 
     def train_dataloaders(self, cuts_train: CutSet) -> DataLoader:
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(
-            self.args.manifest_dir / "musan_cuts.jsonl.gz"
-        )
+        cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
 
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(
-                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
-                )
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")
@@ -215,9 +211,7 @@ class AishellAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(
-                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
-            )
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -260,9 +254,7 @@ class AishellAsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -308,9 +300,7 @@ class AishellAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 return_cuts=self.args.return_cuts,
             )
         else:
@@ -366,13 +356,9 @@ class AishellAsrDataModule:
     @lru_cache()
     def valid_cuts(self) -> CutSet:
         logging.info("About to get dev cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "aishell_cuts_dev.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "aishell_cuts_dev.jsonl.gz")
 
     @lru_cache()
     def test_cuts(self) -> List[CutSet]:
         logging.info("About to get test cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "aishell_cuts_test.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "aishell_cuts_test.jsonl.gz")
diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/decode.py b/egs/aishell/ASR/tdnn_lstm_ctc/decode.py
index 66b734fc4..824ca2a92 100755
--- a/egs/aishell/ASR/tdnn_lstm_ctc/decode.py
+++ b/egs/aishell/ASR/tdnn_lstm_ctc/decode.py
@@ -265,9 +265,7 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -289,9 +287,7 @@ def save_results(
         # We compute CER for aishell dataset.
         results_char = []
         for res in results:
-            results_char.append(
-                (res[0], list("".join(res[1])), list("".join(res[2])))
-            )
+            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
         with open(errs_filename, "w") as f:
             wer = write_error_stats(f, f"{test_set_name}-{key}", results_char)
             test_set_wers[key] = wer
@@ -335,9 +331,7 @@ def main():
 
     logging.info(f"device: {device}")
 
-    HLG = k2.Fsa.from_dict(
-        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
-    )
+    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
 
@@ -362,9 +356,7 @@ def main():
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save(
-            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
-        )
+        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
 
     model.to(device)
     model.eval()
@@ -392,9 +384,7 @@ def main():
             lexicon=lexicon,
         )
 
-        save_results(
-            params=params, test_set_name=test_set, results_dict=results_dict
-        )
+        save_results(params=params, test_set_name=test_set, results_dict=results_dict)
 
     logging.info("Done!")
 
diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/model.py b/egs/aishell/ASR/tdnn_lstm_ctc/model.py
index 5e04c11b4..1731e1ebe 100644
--- a/egs/aishell/ASR/tdnn_lstm_ctc/model.py
+++ b/egs/aishell/ASR/tdnn_lstm_ctc/model.py
@@ -66,10 +66,7 @@ class TdnnLstm(nn.Module):
             nn.BatchNorm1d(num_features=500, affine=False),
         )
         self.lstms = nn.ModuleList(
-            [
-                nn.LSTM(input_size=500, hidden_size=500, num_layers=1)
-                for _ in range(5)
-            ]
+            [nn.LSTM(input_size=500, hidden_size=500, num_layers=1) for _ in range(5)]
         )
         self.lstm_bnorms = nn.ModuleList(
             [nn.BatchNorm1d(num_features=500, affine=False) for _ in range(5)]
diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py b/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py
index 9bd810809..fe197a9f9 100644
--- a/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py
+++ b/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py
@@ -53,9 +53,7 @@ def get_parser():
         help="Path to words.txt",
     )
 
-    parser.add_argument(
-        "--HLG", type=str, required=True, help="Path to HLG.pt."
-    )
+    parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.")
 
     parser.add_argument(
         "--method",
@@ -113,8 +111,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])
@@ -173,9 +170,7 @@ def main():
     logging.info("Decoding started")
     features = fbank(waves)
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
     features = features.permute(0, 2, 1)  # now features is [N, C, T]
 
     with torch.no_grad():
@@ -219,9 +214,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/train.py b/egs/aishell/ASR/tdnn_lstm_ctc/train.py
index 7619b0551..e574cf89b 100755
--- a/egs/aishell/ASR/tdnn_lstm_ctc/train.py
+++ b/egs/aishell/ASR/tdnn_lstm_ctc/train.py
@@ -49,12 +49,7 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.graph_compiler import CtcTrainingGraphCompiler
 from icefall.lexicon import Lexicon
-from icefall.utils import (
-    AttributeDict,
-    encode_supervisions,
-    setup_logger,
-    str2bool,
-)
+from icefall.utils import AttributeDict, encode_supervisions, setup_logger, str2bool
 
 
 def get_parser():
diff --git a/egs/aishell/ASR/transducer_stateless/beam_search.py b/egs/aishell/ASR/transducer_stateless/beam_search.py
index 9ed9b2ad1..de0a8d0f5 100644
--- a/egs/aishell/ASR/transducer_stateless/beam_search.py
+++ b/egs/aishell/ASR/transducer_stateless/beam_search.py
@@ -47,9 +47,9 @@ def greedy_search(
 
     device = model.device
 
-    decoder_input = torch.tensor(
-        [blank_id] * context_size, device=device
-    ).reshape(1, context_size)
+    decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape(
+        1, context_size
+    )
 
     decoder_out = model.decoder(decoder_input, need_pad=False)
 
@@ -81,9 +81,9 @@ def greedy_search(
         y = logits.argmax().item()
         if y != blank_id:
             hyp.append(y)
-            decoder_input = torch.tensor(
-                [hyp[-context_size:]], device=device
-            ).reshape(1, context_size)
+            decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape(
+                1, context_size
+            )
 
             decoder_out = model.decoder(decoder_input, need_pad=False)
 
@@ -157,9 +157,7 @@ class HypothesisList(object):
 
         """
         if length_norm:
-            return max(
-                self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)
-            )
+            return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys))
         else:
             return max(self._data.values(), key=lambda hyp: hyp.log_prob)
 
@@ -246,9 +244,9 @@ def beam_search(
 
     device = model.device
 
-    decoder_input = torch.tensor(
-        [blank_id] * context_size, device=device
-    ).reshape(1, context_size)
+    decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape(
+        1, context_size
+    )
 
     decoder_out = model.decoder(decoder_input, need_pad=False)
 
diff --git a/egs/aishell/ASR/transducer_stateless/conformer.py b/egs/aishell/ASR/transducer_stateless/conformer.py
index 64114253d..78424aea2 100644
--- a/egs/aishell/ASR/transducer_stateless/conformer.py
+++ b/egs/aishell/ASR/transducer_stateless/conformer.py
@@ -155,9 +155,7 @@ class ConformerEncoderLayer(nn.Module):
         normalize_before: bool = True,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
-        self.self_attn = RelPositionMultiheadAttention(
-            d_model, nhead, dropout=0.0
-        )
+        self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
 
         self.feed_forward = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
@@ -175,18 +173,14 @@ class ConformerEncoderLayer(nn.Module):
 
         self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
 
-        self.norm_ff_macaron = nn.LayerNorm(
-            d_model
-        )  # for the macaron style FNN module
+        self.norm_ff_macaron = nn.LayerNorm(d_model)  # for the macaron style FNN module
         self.norm_ff = nn.LayerNorm(d_model)  # for the FNN module
         self.norm_mha = nn.LayerNorm(d_model)  # for the MHA module
 
         self.ff_scale = 0.5
 
         self.norm_conv = nn.LayerNorm(d_model)  # for the CNN module
-        self.norm_final = nn.LayerNorm(
-            d_model
-        )  # for the final output of the block
+        self.norm_final = nn.LayerNorm(d_model)  # for the final output of the block
 
         self.dropout = nn.Dropout(dropout)
 
@@ -220,9 +214,7 @@ class ConformerEncoderLayer(nn.Module):
         residual = src
         if self.normalize_before:
             src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(
-            self.feed_forward_macaron(src)
-        )
+        src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)
 
@@ -341,9 +333,7 @@ class RelPositionalEncoding(torch.nn.Module):
 
     """
 
-    def __init__(
-        self, d_model: int, dropout_rate: float, max_len: int = 5000
-    ) -> None:
+    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
@@ -359,9 +349,7 @@ class RelPositionalEncoding(torch.nn.Module):
             # the length of self.pe is 2 * input_len - 1
             if self.pe.size(1) >= x.size(1) * 2 - 1:
                 # Note: TorchScript doesn't implement operator== for torch.Device
-                if self.pe.dtype != x.dtype or str(self.pe.device) != str(
-                    x.device
-                ):
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
                     self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                 return
         # Suppose `i` means to the position of query vector and `j` means the
@@ -631,9 +619,9 @@ class RelPositionMultiheadAttention(nn.Module):
 
         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(
-                query, in_proj_weight, in_proj_bias
-            ).chunk(3, dim=-1)
+            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
+                3, dim=-1
+            )
 
         elif torch.equal(key, value):
             # encoder-decoder attention
@@ -701,31 +689,22 @@ class RelPositionMultiheadAttention(nn.Module):
             if attn_mask.dim() == 2:
                 attn_mask = attn_mask.unsqueeze(0)
                 if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
-                    raise RuntimeError(
-                        "The size of the 2D attn_mask is not correct."
-                    )
+                    raise RuntimeError("The size of the 2D attn_mask is not correct.")
             elif attn_mask.dim() == 3:
                 if list(attn_mask.size()) != [
                     bsz * num_heads,
                     query.size(0),
                     key.size(0),
                 ]:
-                    raise RuntimeError(
-                        "The size of the 3D attn_mask is not correct."
-                    )
+                    raise RuntimeError("The size of the 3D attn_mask is not correct.")
             else:
                 raise RuntimeError(
-                    "attn_mask's dimension {} is not supported".format(
-                        attn_mask.dim()
-                    )
+                    "attn_mask's dimension {} is not supported".format(attn_mask.dim())
                 )
             # attn_mask's dim is 3 now.
 
         # convert ByteTensor key_padding_mask to bool
-        if (
-            key_padding_mask is not None
-            and key_padding_mask.dtype == torch.uint8
-        ):
+        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
             warnings.warn(
                 "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
             )
@@ -764,9 +743,7 @@ class RelPositionMultiheadAttention(nn.Module):
         # first compute matrix a and matrix c
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
-        matrix_ac = torch.matmul(
-            q_with_bias_u, k
-        )  # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch, head, time1, time2)
 
         # compute matrix b and matrix d
         matrix_bd = torch.matmul(
@@ -778,9 +755,7 @@ class RelPositionMultiheadAttention(nn.Module):
             matrix_ac + matrix_bd
         ) * scaling  # (batch, head, time1, time2)
 
-        attn_output_weights = attn_output_weights.view(
-            bsz * num_heads, tgt_len, -1
-        )
+        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
 
         assert list(attn_output_weights.size()) == [
             bsz * num_heads,
@@ -814,13 +789,9 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output = torch.bmm(attn_output_weights, v)
         assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
         attn_output = (
-            attn_output.transpose(0, 1)
-            .contiguous()
-            .view(tgt_len, bsz, embed_dim)
-        )
-        attn_output = nn.functional.linear(
-            attn_output, out_proj_weight, out_proj_bias
+            attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
         )
+        attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
 
         if need_weights:
             # average attention weights over heads
@@ -843,9 +814,7 @@ class ConvolutionModule(nn.Module):
 
     """
 
-    def __init__(
-        self, channels: int, kernel_size: int, bias: bool = True
-    ) -> None:
+    def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None:
         """Construct an ConvolutionModule object."""
         super(ConvolutionModule, self).__init__()
         # kernerl_size should be a odd number for 'SAME' padding
diff --git a/egs/aishell/ASR/transducer_stateless/decode.py b/egs/aishell/ASR/transducer_stateless/decode.py
index 780b0c4bb..fbc54f68b 100755
--- a/egs/aishell/ASR/transducer_stateless/decode.py
+++ b/egs/aishell/ASR/transducer_stateless/decode.py
@@ -99,8 +99,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -227,9 +226,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []
     batch_size = encoder_out.size(0)
 
@@ -248,9 +245,7 @@ def decode_one_batch(
                 model=model, encoder_out=encoder_out_i, beam=params.beam_size
             )
         else:
-            raise ValueError(
-                f"Unsupported decoding method: {params.decoding_method}"
-            )
+            raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
         hyps.append([lexicon.token_table[i] for i in hyp])
 
     if params.decoding_method == "greedy_search":
@@ -319,9 +314,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -346,9 +339,7 @@ def save_results(
         # we compute CER for aishell dataset.
         results_char = []
         for res in results:
-            results_char.append(
-                (res[0], list("".join(res[1])), list("".join(res[2])))
-            )
+            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results_char, enable_log=True
@@ -359,8 +350,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tCER", file=f)
@@ -430,9 +420,7 @@ def main():
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save(
-            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
-        )
+        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
         return
 
     model.to(device)
diff --git a/egs/aishell/ASR/transducer_stateless/decoder.py b/egs/aishell/ASR/transducer_stateless/decoder.py
index c2c6552a9..70e9e6c96 100644
--- a/egs/aishell/ASR/transducer_stateless/decoder.py
+++ b/egs/aishell/ASR/transducer_stateless/decoder.py
@@ -86,9 +86,7 @@ class Decoder(nn.Module):
         if self.context_size > 1:
             embedding_out = embedding_out.permute(0, 2, 1)
             if need_pad is True:
-                embedding_out = F.pad(
-                    embedding_out, pad=(self.context_size - 1, 0)
-                )
+                embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
             else:
                 # During inference time, there is no need to do extra padding
                 # as we only need one output
diff --git a/egs/aishell/ASR/transducer_stateless/export.py b/egs/aishell/ASR/transducer_stateless/export.py
index 4c6519b96..eea9b6883 100755
--- a/egs/aishell/ASR/transducer_stateless/export.py
+++ b/egs/aishell/ASR/transducer_stateless/export.py
@@ -110,8 +110,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     return parser
@@ -243,9 +242,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell/ASR/transducer_stateless/model.py b/egs/aishell/ASR/transducer_stateless/model.py
index 994305fc1..591bbe44f 100644
--- a/egs/aishell/ASR/transducer_stateless/model.py
+++ b/egs/aishell/ASR/transducer_stateless/model.py
@@ -103,9 +103,7 @@ class Transducer(nn.Module):
         y_padded = y.pad(mode="constant", padding_value=0)
 
         y_padded = y_padded.to(torch.int64)
-        boundary = torch.zeros(
-            (x.size(0), 4), dtype=torch.int64, device=x.device
-        )
+        boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device)
         boundary[:, 2] = y_lens
         boundary[:, 3] = x_lens
 
diff --git a/egs/aishell/ASR/transducer_stateless/pretrained.py b/egs/aishell/ASR/transducer_stateless/pretrained.py
index db89c4d67..b03a2643a 100755
--- a/egs/aishell/ASR/transducer_stateless/pretrained.py
+++ b/egs/aishell/ASR/transducer_stateless/pretrained.py
@@ -117,8 +117,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -212,8 +211,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])
@@ -273,9 +271,7 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -319,9 +315,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell/ASR/transducer_stateless/train.py b/egs/aishell/ASR/transducer_stateless/train.py
index d54157709..4ea902507 100755
--- a/egs/aishell/ASR/transducer_stateless/train.py
+++ b/egs/aishell/ASR/transducer_stateless/train.py
@@ -126,8 +126,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -389,9 +388,7 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -504,9 +501,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -625,9 +620,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/aishell/ASR/transducer_stateless/transformer.py b/egs/aishell/ASR/transducer_stateless/transformer.py
index e851dcc32..b3ff153c1 100644
--- a/egs/aishell/ASR/transducer_stateless/transformer.py
+++ b/egs/aishell/ASR/transducer_stateless/transformer.py
@@ -250,9 +250,7 @@ def _get_activation_fn(activation: str):
     elif activation == "gelu":
         return nn.functional.gelu
 
-    raise RuntimeError(
-        "activation should be relu/gelu, not {}".format(activation)
-    )
+    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
 
 
 class PositionalEncoding(nn.Module):
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py b/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py
index 838e53658..5d49d7338 100644
--- a/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py
@@ -29,10 +29,7 @@ from lhotse.dataset import (
     K2SpeechRecognitionDataset,
     SpecAugment,
 )
-from lhotse.dataset.input_strategies import (
-    OnTheFlyFeatures,
-    PrecomputedFeatures,
-)
+from lhotse.dataset.input_strategies import OnTheFlyFeatures, PrecomputedFeatures
 from torch.utils.data import DataLoader
 
 from icefall.utils import str2bool
@@ -162,9 +159,7 @@ class AsrDataModule:
         if cuts_musan is not None:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(
-                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
-                )
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")
@@ -173,9 +168,7 @@ class AsrDataModule:
 
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(
-                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
-            )
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -252,9 +245,7 @@ class AsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 return_cuts=self.args.return_cuts,
             )
         else:
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/decode.py b/egs/aishell/ASR/transducer_stateless_modified-2/decode.py
index ea3f94fd8..cb206af6d 100755
--- a/egs/aishell/ASR/transducer_stateless_modified-2/decode.py
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/decode.py
@@ -170,8 +170,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -227,9 +226,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
 
     if params.decoding_method == "fast_beam_search":
         hyp_tokens = fast_beam_search_one_best(
@@ -241,10 +238,7 @@ def decode_one_batch(
             max_contexts=params.max_contexts,
             max_states=params.max_states,
         )
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -365,9 +359,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -393,9 +385,7 @@ def save_results(
         # we compute CER for aishell dataset.
         results_char = []
         for res in results:
-            results_char.append(
-                (res[0], list("".join(res[1])), list("".join(res[2])))
-            )
+            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results_char, enable_log=True
@@ -406,8 +396,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tCER", file=f)
@@ -448,9 +437,7 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += (
-            f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/export.py b/egs/aishell/ASR/transducer_stateless_modified-2/export.py
index 3bd2ceb11..3c56d4a01 100755
--- a/egs/aishell/ASR/transducer_stateless_modified-2/export.py
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/export.py
@@ -109,8 +109,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     return parser
@@ -241,9 +240,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py b/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py
index a95a4bc52..d8c0c5fcd 100755
--- a/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py
@@ -165,8 +165,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -195,8 +194,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])
@@ -254,13 +252,9 @@ def main():
     feature_lens = [f.size(0) for f in features]
     feature_lens = torch.tensor(feature_lens, device=device)
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=features, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens)
 
     num_waves = encoder_out.size(0)
     hyp_list = []
@@ -308,9 +302,7 @@ def main():
                     beam=params.beam_size,
                 )
             else:
-                raise ValueError(
-                    f"Unsupported decoding method: {params.method}"
-                )
+                raise ValueError(f"Unsupported decoding method: {params.method}")
             hyp_list.append(hyp)
 
     hyps = []
@@ -327,9 +319,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/train.py b/egs/aishell/ASR/transducer_stateless_modified-2/train.py
index 225d0d709..a9a30d7f7 100755
--- a/egs/aishell/ASR/transducer_stateless_modified-2/train.py
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/train.py
@@ -149,8 +149,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -168,8 +167,7 @@ def get_parser():
         "--datatang-prob",
         type=float,
         default=0.2,
-        help="The probability to select a batch from the "
-        "aidatatang_200zh dataset",
+        help="The probability to select a batch from the " "aidatatang_200zh dataset",
     )
 
     return parser
@@ -449,9 +447,7 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -605,9 +601,7 @@ def train_one_epoch(
                     f"train/current_{prefix}_",
                     params.batch_idx_train,
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
                 aishell_tot_loss.write_summary(
                     tb_writer, "train/aishell_tot_", params.batch_idx_train
                 )
@@ -735,9 +729,7 @@ def run(rank, world_size, args):
     train_datatang_cuts = train_datatang_cuts.repeat(times=None)
 
     if args.enable_musan:
-        cuts_musan = load_manifest(
-            Path(args.manifest_dir) / "musan_cuts.jsonl.gz"
-        )
+        cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz")
     else:
         cuts_musan = None
 
@@ -776,9 +768,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/aishell/ASR/transducer_stateless_modified/decode.py b/egs/aishell/ASR/transducer_stateless_modified/decode.py
index 65fcda873..ba3cb3218 100755
--- a/egs/aishell/ASR/transducer_stateless_modified/decode.py
+++ b/egs/aishell/ASR/transducer_stateless_modified/decode.py
@@ -171,8 +171,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -231,9 +230,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
 
     if params.decoding_method == "fast_beam_search":
         hyp_tokens = fast_beam_search_one_best(
@@ -245,10 +242,7 @@ def decode_one_batch(
             max_contexts=params.max_contexts,
             max_states=params.max_states,
         )
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -369,9 +363,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -397,9 +389,7 @@ def save_results(
         # we compute CER for aishell dataset.
         results_char = []
         for res in results:
-            results_char.append(
-                (res[0], list("".join(res[1])), list("".join(res[2])))
-            )
+            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results_char, enable_log=True
@@ -410,8 +400,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tCER", file=f)
@@ -452,9 +441,7 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += (
-            f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
diff --git a/egs/aishell/ASR/transducer_stateless_modified/export.py b/egs/aishell/ASR/transducer_stateless_modified/export.py
index 11335a834..cbdbdbeb6 100755
--- a/egs/aishell/ASR/transducer_stateless_modified/export.py
+++ b/egs/aishell/ASR/transducer_stateless_modified/export.py
@@ -109,8 +109,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     return parser
@@ -241,9 +240,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell/ASR/transducer_stateless_modified/pretrained.py b/egs/aishell/ASR/transducer_stateless_modified/pretrained.py
index 262e822c2..7dfa92a3c 100755
--- a/egs/aishell/ASR/transducer_stateless_modified/pretrained.py
+++ b/egs/aishell/ASR/transducer_stateless_modified/pretrained.py
@@ -165,8 +165,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -195,8 +194,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])
@@ -254,13 +252,9 @@ def main():
     feature_lens = [f.size(0) for f in features]
     feature_lens = torch.tensor(feature_lens, device=device)
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=features, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens)
 
     num_waves = encoder_out.size(0)
     hyp_list = []
@@ -308,9 +302,7 @@ def main():
                     beam=params.beam_size,
                 )
             else:
-                raise ValueError(
-                    f"Unsupported decoding method: {params.method}"
-                )
+                raise ValueError(f"Unsupported decoding method: {params.method}")
             hyp_list.append(hyp)
 
     hyps = []
@@ -327,9 +319,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell/ASR/transducer_stateless_modified/train.py b/egs/aishell/ASR/transducer_stateless_modified/train.py
index d3ffccafa..c4bf4dd56 100755
--- a/egs/aishell/ASR/transducer_stateless_modified/train.py
+++ b/egs/aishell/ASR/transducer_stateless_modified/train.py
@@ -142,8 +142,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -414,9 +413,7 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -529,9 +526,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -657,9 +652,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/aishell2/ASR/local/__init__.py b/egs/aishell2/ASR/local/__init__.py
old mode 100755
new mode 100644
diff --git a/egs/aishell2/ASR/local/compute_fbank_aishell2.py b/egs/aishell2/ASR/local/compute_fbank_aishell2.py
index d8d3622bd..ec0c584ca 100755
--- a/egs/aishell2/ASR/local/compute_fbank_aishell2.py
+++ b/egs/aishell2/ASR/local/compute_fbank_aishell2.py
@@ -83,9 +83,7 @@ def compute_fbank_aishell2(num_mel_bins: int = 80):
             )
             if "train" in partition:
                 cut_set = (
-                    cut_set
-                    + cut_set.perturb_speed(0.9)
-                    + cut_set.perturb_speed(1.1)
+                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                 )
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
@@ -111,9 +109,7 @@ def get_args():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/__init__.py b/egs/aishell2/ASR/pruned_transducer_stateless5/__init__.py
old mode 100755
new mode 100644
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py
old mode 100755
new mode 100644
index b7a21f579..0f383a244
--- a/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py
+++ b/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py
@@ -216,13 +216,9 @@ class AiShell2AsrDataModule:
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             logging.info("About to get Musan cuts")
-            cuts_musan = load_manifest(
-                self.args.manifest_dir / "musan_cuts.jsonl.gz"
-            )
+            cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
             transforms.append(
-                CutMix(
-                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
-                )
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")
@@ -244,9 +240,7 @@ class AiShell2AsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(
-                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
-            )
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -290,9 +284,7 @@ class AiShell2AsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -348,9 +340,7 @@ class AiShell2AsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 return_cuts=self.args.return_cuts,
             )
         else:
@@ -406,9 +396,7 @@ class AiShell2AsrDataModule:
     @lru_cache()
     def valid_cuts(self) -> CutSet:
         logging.info("About to gen cuts from aishell2_cuts_dev.jsonl.gz")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "aishell2_cuts_dev.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "aishell2_cuts_dev.jsonl.gz")
 
     @lru_cache()
     def test_cuts(self) -> CutSet:
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py b/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py
index 915737f4a..7900c5883 100755
--- a/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py
@@ -269,8 +269,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -348,9 +347,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
@@ -409,10 +406,7 @@ def decode_one_batch(
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -538,9 +532,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -573,8 +565,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -625,9 +616,7 @@ def main():
             if "LG" in params.decoding_method:
                 params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += (
-            f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -661,9 +650,9 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
@@ -690,9 +679,9 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg + 1]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
@@ -749,9 +738,7 @@ def main():
             )
             decoding_graph.scores *= params.ngram_lm_scale
         else:
-            decoding_graph = k2.trivial_graph(
-                params.vocab_size - 1, device=device
-            )
+            decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
     else:
         decoding_graph = None
 
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/export.py b/egs/aishell2/ASR/pruned_transducer_stateless5/export.py
index bc7bd71cb..ea4a8d4f9 100755
--- a/egs/aishell2/ASR/pruned_transducer_stateless5/export.py
+++ b/egs/aishell2/ASR/pruned_transducer_stateless5/export.py
@@ -133,8 +133,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     add_model_arguments(parser)
@@ -167,9 +166,9 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
@@ -196,9 +195,9 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg + 1]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
@@ -266,9 +265,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py b/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py
index 09de1bece..94536fa6f 100755
--- a/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py
+++ b/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py
@@ -159,8 +159,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -192,8 +191,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])
@@ -254,15 +252,11 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=features, x_lens=feature_lengths
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
 
     num_waves = encoder_out.size(0)
     hyps = []
@@ -334,9 +328,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/train.py b/egs/aishell2/ASR/pruned_transducer_stateless5/train.py
index 838a0497f..4a228113d 100755
--- a/egs/aishell2/ASR/pruned_transducer_stateless5/train.py
+++ b/egs/aishell2/ASR/pruned_transducer_stateless5/train.py
@@ -92,9 +92,7 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[
-    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
-]
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
 
 def add_model_arguments(parser: argparse.ArgumentParser):
@@ -220,8 +218,7 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need "
-        "to be changed.",
+        help="The initial learning rate.  This value should not need " "to be changed.",
     )
 
     parser.add_argument(
@@ -244,8 +241,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -268,8 +264,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
+        help="The scale to smooth the loss with am (output of encoder network)" "part.",
     )
 
     parser.add_argument(
@@ -603,11 +598,7 @@ def compute_loss(
      warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
-    device = (
-        model.device
-        if isinstance(model, DDP)
-        else next(model.parameters()).device
-    )
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
@@ -636,23 +627,16 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0
-            if warmup < 1.0
-            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
-        )
-        loss = (
-            params.simple_loss_scale * simple_loss
-            + pruned_loss_scale * pruned_loss
+            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
         )
+        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
 
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -771,9 +755,7 @@ def train_one_epoch(
             scaler.update()
             optimizer.zero_grad()
         except:  # noqa
-            display_and_save_batch(
-                batch, params=params, graph_compiler=graph_compiler
-            )
+            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
             raise
 
         if params.print_diagnostics and batch_idx == 5:
@@ -829,9 +811,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -939,7 +919,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
+            2**22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
@@ -1104,9 +1084,7 @@ def scan_pessimistic_batches_for_oom(
                     f"Failing criterion: {criterion} "
                     f"(={crit_values[criterion]}) ..."
                 )
-            display_and_save_batch(
-                batch, params=params, graph_compiler=graph_compiler
-            )
+            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
             raise
 
 
diff --git a/egs/aishell4/ASR/local/compute_fbank_aishell4.py b/egs/aishell4/ASR/local/compute_fbank_aishell4.py
index 3f50d9e3e..400c406f0 100755
--- a/egs/aishell4/ASR/local/compute_fbank_aishell4.py
+++ b/egs/aishell4/ASR/local/compute_fbank_aishell4.py
@@ -85,9 +85,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80):
             )
             if "train" in partition:
                 cut_set = (
-                    cut_set
-                    + cut_set.perturb_speed(0.9)
-                    + cut_set.perturb_speed(1.1)
+                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                 )
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
@@ -120,9 +118,7 @@ def get_args():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/aishell4/ASR/local/prepare_char.py b/egs/aishell4/ASR/local/prepare_char.py
index d9e47d17a..6b440dfb3 100755
--- a/egs/aishell4/ASR/local/prepare_char.py
+++ b/egs/aishell4/ASR/local/prepare_char.py
@@ -86,9 +86,7 @@ def lexicon_to_fst_no_sil(
         cur_state = loop_state
 
         word = word2id[word]
-        pieces = [
-            token2id[i] if i in token2id else token2id[""] for i in pieces
-        ]
+        pieces = [token2id[i] if i in token2id else token2id[""] for i in pieces]
 
         for i in range(len(pieces) - 1):
             w = word if i == 0 else eps
@@ -142,9 +140,7 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
     return False
 
 
-def generate_lexicon(
-    token_sym_table: Dict[str, int], words: List[str]
-) -> Lexicon:
+def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon:
     """Generate a lexicon from a word list and token_sym_table.
 
     Args:
diff --git a/egs/aishell4/ASR/local/prepare_lang.py b/egs/aishell4/ASR/local/prepare_lang.py
index e5ae89ec4..c8cf9b881 100755
--- a/egs/aishell4/ASR/local/prepare_lang.py
+++ b/egs/aishell4/ASR/local/prepare_lang.py
@@ -317,9 +317,7 @@ def lexicon_to_fst(
 
 def get_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--lang-dir", type=str, help="The lang dir, data/lang_phone"
-    )
+    parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone")
     return parser.parse_args()
 
 
diff --git a/egs/aishell4/ASR/local/test_prepare_lang.py b/egs/aishell4/ASR/local/test_prepare_lang.py
index d4cf62bba..74e025ad7 100755
--- a/egs/aishell4/ASR/local/test_prepare_lang.py
+++ b/egs/aishell4/ASR/local/test_prepare_lang.py
@@ -88,9 +88,7 @@ def test_read_lexicon(filename: str):
     fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa.draw("L.pdf", title="L")
 
-    fsa_disambig = lexicon_to_fst(
-        lexicon_disambig, phone2id=phone2id, word2id=word2id
-    )
+    fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id)
     fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
     fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa_disambig.draw("L_disambig.pdf", title="L_disambig")
diff --git a/egs/aishell4/ASR/local/text2token.py b/egs/aishell4/ASR/local/text2token.py
index 71be2a613..85047c367 100755
--- a/egs/aishell4/ASR/local/text2token.py
+++ b/egs/aishell4/ASR/local/text2token.py
@@ -56,9 +56,7 @@ def get_parser():
     parser.add_argument(
         "--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
     )
-    parser.add_argument(
-        "--space", default="", type=str, help="space symbol"
-    )
+    parser.add_argument("--space", default="", type=str, help="space symbol")
     parser.add_argument(
         "--non-lang-syms",
         "-l",
@@ -66,9 +64,7 @@ def get_parser():
         type=str,
         help="list of non-linguistic symobles, e.g.,  etc.",
     )
-    parser.add_argument(
-        "text", type=str, default=False, nargs="?", help="input text"
-    )
+    parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
     parser.add_argument(
         "--trans_type",
         "-t",
@@ -108,8 +104,7 @@ def token2id(
             if token_type == "lazy_pinyin":
                 text = lazy_pinyin(chars_list)
                 sub_ids = [
-                    token_table[txt] if txt in token_table else oov_id
-                    for txt in text
+                    token_table[txt] if txt in token_table else oov_id for txt in text
                 ]
                 ids.append(sub_ids)
             else:  # token_type = "pinyin"
@@ -135,9 +130,7 @@ def main():
     if args.text:
         f = codecs.open(args.text, encoding="utf-8")
     else:
-        f = codecs.getreader("utf-8")(
-            sys.stdin if is_python2 else sys.stdin.buffer
-        )
+        f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
 
     sys.stdout = codecs.getwriter("utf-8")(
         sys.stdout if is_python2 else sys.stdout.buffer
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py
index 7aa53ddda..d980a857f 100644
--- a/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py
@@ -222,17 +222,13 @@ class Aishell4AsrDataModule:
             The state dict for the training sampler.
         """
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(
-            self.args.manifest_dir / "musan_cuts.jsonl.gz"
-        )
+        cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
 
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(
-                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
-                )
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")
@@ -254,9 +250,7 @@ class Aishell4AsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(
-                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
-            )
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -300,9 +294,7 @@ class Aishell4AsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -359,9 +351,7 @@ class Aishell4AsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 return_cuts=self.args.return_cuts,
             )
         else:
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py b/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py
index 14e44c7d9..cb533df35 100755
--- a/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py
@@ -201,8 +201,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -260,9 +259,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
@@ -277,10 +274,7 @@ def decode_one_batch(
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -401,9 +395,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -436,8 +428,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -480,9 +471,7 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += (
-            f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -510,9 +499,9 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
@@ -543,9 +532,9 @@ def main():
             )
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg + 1]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/export.py b/egs/aishell4/ASR/pruned_transducer_stateless5/export.py
index 993341131..cc9b7b444 100755
--- a/egs/aishell4/ASR/pruned_transducer_stateless5/export.py
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/export.py
@@ -136,8 +136,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     add_model_arguments(parser)
@@ -169,9 +168,9 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
@@ -202,9 +201,9 @@ def main():
             )
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg + 1]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
@@ -276,9 +275,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py b/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py
index 1fa893637..a234f9d65 100755
--- a/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py
@@ -172,8 +172,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -205,8 +204,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])
@@ -266,15 +264,11 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=features, x_lens=feature_lengths
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
 
     num_waves = encoder_out.size(0)
     hyps = []
@@ -306,10 +300,7 @@ def main():
 
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -350,9 +341,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/train.py b/egs/aishell4/ASR/pruned_transducer_stateless5/train.py
index 0a48b9059..73ee34284 100755
--- a/egs/aishell4/ASR/pruned_transducer_stateless5/train.py
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/train.py
@@ -85,9 +85,7 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[
-    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
-]
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
 
 def add_model_arguments(parser: argparse.ArgumentParser):
@@ -213,8 +211,7 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need "
-        "to be changed.",
+        help="The initial learning rate.  This value should not need " "to be changed.",
     )
 
     parser.add_argument(
@@ -237,8 +234,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -261,8 +257,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
+        help="The scale to smooth the loss with am (output of encoder network)" "part.",
     )
 
     parser.add_argument(
@@ -599,11 +594,7 @@ def compute_loss(
      warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
-    device = (
-        model.device
-        if isinstance(model, DDP)
-        else next(model.parameters()).device
-    )
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
@@ -633,22 +624,15 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0
-            if warmup < 1.0
-            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
-        )
-        loss = (
-            params.simple_loss_scale * simple_loss
-            + pruned_loss_scale * pruned_loss
+            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
         )
+        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -827,9 +811,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -937,7 +919,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
+            2**22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
diff --git a/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py b/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py
index af926aa53..96115a230 100755
--- a/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py
+++ b/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py
@@ -84,9 +84,7 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80):
             )
             if "train" in partition:
                 cut_set = (
-                    cut_set
-                    + cut_set.perturb_speed(0.9)
-                    + cut_set.perturb_speed(1.1)
+                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                 )
             cur_num_jobs = num_jobs if ex is None else 80
             cur_num_jobs = min(cur_num_jobs, len(cut_set))
@@ -121,9 +119,7 @@ def get_args():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/alimeeting/ASR/local/prepare_char.py b/egs/alimeeting/ASR/local/prepare_char.py
index d9e47d17a..6b440dfb3 100755
--- a/egs/alimeeting/ASR/local/prepare_char.py
+++ b/egs/alimeeting/ASR/local/prepare_char.py
@@ -86,9 +86,7 @@ def lexicon_to_fst_no_sil(
         cur_state = loop_state
 
         word = word2id[word]
-        pieces = [
-            token2id[i] if i in token2id else token2id[""] for i in pieces
-        ]
+        pieces = [token2id[i] if i in token2id else token2id[""] for i in pieces]
 
         for i in range(len(pieces) - 1):
             w = word if i == 0 else eps
@@ -142,9 +140,7 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
     return False
 
 
-def generate_lexicon(
-    token_sym_table: Dict[str, int], words: List[str]
-) -> Lexicon:
+def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon:
     """Generate a lexicon from a word list and token_sym_table.
 
     Args:
diff --git a/egs/alimeeting/ASR/local/prepare_lang.py b/egs/alimeeting/ASR/local/prepare_lang.py
index e5ae89ec4..c8cf9b881 100755
--- a/egs/alimeeting/ASR/local/prepare_lang.py
+++ b/egs/alimeeting/ASR/local/prepare_lang.py
@@ -317,9 +317,7 @@ def lexicon_to_fst(
 
 def get_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--lang-dir", type=str, help="The lang dir, data/lang_phone"
-    )
+    parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone")
     return parser.parse_args()
 
 
diff --git a/egs/alimeeting/ASR/local/test_prepare_lang.py b/egs/alimeeting/ASR/local/test_prepare_lang.py
index d4cf62bba..74e025ad7 100755
--- a/egs/alimeeting/ASR/local/test_prepare_lang.py
+++ b/egs/alimeeting/ASR/local/test_prepare_lang.py
@@ -88,9 +88,7 @@ def test_read_lexicon(filename: str):
     fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa.draw("L.pdf", title="L")
 
-    fsa_disambig = lexicon_to_fst(
-        lexicon_disambig, phone2id=phone2id, word2id=word2id
-    )
+    fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id)
     fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
     fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa_disambig.draw("L_disambig.pdf", title="L_disambig")
diff --git a/egs/alimeeting/ASR/local/text2segments.py b/egs/alimeeting/ASR/local/text2segments.py
index 7c1019aa8..27b904fc8 100644
--- a/egs/alimeeting/ASR/local/text2segments.py
+++ b/egs/alimeeting/ASR/local/text2segments.py
@@ -30,8 +30,8 @@ with word segmenting:
 
 import argparse
 
-import paddle
 import jieba
+import paddle
 from tqdm import tqdm
 
 paddle.enable_static()
diff --git a/egs/alimeeting/ASR/local/text2token.py b/egs/alimeeting/ASR/local/text2token.py
index 71be2a613..85047c367 100755
--- a/egs/alimeeting/ASR/local/text2token.py
+++ b/egs/alimeeting/ASR/local/text2token.py
@@ -56,9 +56,7 @@ def get_parser():
     parser.add_argument(
         "--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
     )
-    parser.add_argument(
-        "--space", default="", type=str, help="space symbol"
-    )
+    parser.add_argument("--space", default="", type=str, help="space symbol")
     parser.add_argument(
         "--non-lang-syms",
         "-l",
@@ -66,9 +64,7 @@ def get_parser():
         type=str,
         help="list of non-linguistic symobles, e.g.,  etc.",
     )
-    parser.add_argument(
-        "text", type=str, default=False, nargs="?", help="input text"
-    )
+    parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
     parser.add_argument(
         "--trans_type",
         "-t",
@@ -108,8 +104,7 @@ def token2id(
             if token_type == "lazy_pinyin":
                 text = lazy_pinyin(chars_list)
                 sub_ids = [
-                    token_table[txt] if txt in token_table else oov_id
-                    for txt in text
+                    token_table[txt] if txt in token_table else oov_id for txt in text
                 ]
                 ids.append(sub_ids)
             else:  # token_type = "pinyin"
@@ -135,9 +130,7 @@ def main():
     if args.text:
         f = codecs.open(args.text, encoding="utf-8")
     else:
-        f = codecs.getreader("utf-8")(
-            sys.stdin if is_python2 else sys.stdin.buffer
-        )
+        f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
 
     sys.stdout = codecs.getwriter("utf-8")(
         sys.stdout if is_python2 else sys.stdout.buffer
diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py
index bf6faad7a..a9a4675a9 100644
--- a/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py
+++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py
@@ -205,17 +205,13 @@ class AlimeetingAsrDataModule:
             The state dict for the training sampler.
         """
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(
-            self.args.manifest_dir / "musan_cuts.jsonl.gz"
-        )
+        cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
 
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(
-                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
-                )
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")
@@ -237,9 +233,7 @@ class AlimeetingAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(
-                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
-            )
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -282,9 +276,7 @@ class AlimeetingAsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -341,9 +333,7 @@ class AlimeetingAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 return_cuts=self.args.return_cuts,
             )
         else:
diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py
index 6358fe970..f3b63b222 100755
--- a/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py
@@ -70,11 +70,7 @@ from beam_search import (
 from lhotse.cut import Cut
 from train import get_params, get_transducer_model
 
-from icefall.checkpoint import (
-    average_checkpoints,
-    find_checkpoints,
-    load_checkpoint,
-)
+from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
@@ -193,8 +189,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -249,9 +244,7 @@ def decode_one_batch(
 
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
@@ -266,10 +259,7 @@ def decode_one_batch(
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -390,9 +380,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -425,8 +413,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -563,8 +550,7 @@ def main():
         )
 
     dev_shards = [
-        str(path)
-        for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
+        str(path) for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
     ]
     cuts_dev_webdataset = CutSet.from_webdataset(
         dev_shards,
@@ -574,8 +560,7 @@ def main():
     )
 
     test_shards = [
-        str(path)
-        for path in sorted(glob.glob(os.path.join(test, "shared-*.tar")))
+        str(path) for path in sorted(glob.glob(os.path.join(test, "shared-*.tar")))
     ]
     cuts_test_webdataset = CutSet.from_webdataset(
         test_shards,
@@ -588,9 +573,7 @@ def main():
         return 1.0 <= c.duration
 
     cuts_dev_webdataset = cuts_dev_webdataset.filter(remove_short_and_long_utt)
-    cuts_test_webdataset = cuts_test_webdataset.filter(
-        remove_short_and_long_utt
-    )
+    cuts_test_webdataset = cuts_test_webdataset.filter(remove_short_and_long_utt)
 
     dev_dl = alimeeting.valid_dataloaders(cuts_dev_webdataset)
     test_dl = alimeeting.test_dataloaders(cuts_test_webdataset)
diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py
index 8beec1b8a..538853f67 100644
--- a/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py
@@ -103,8 +103,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     return parser
@@ -173,9 +172,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py
index 93b1e1f57..4da8d8e14 100644
--- a/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py
+++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py
@@ -162,8 +162,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -194,8 +193,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])
@@ -257,9 +255,7 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -284,10 +280,7 @@ def main():
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -339,9 +332,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py
index 81a0ede7f..c9d2f3cb9 100644
--- a/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py
@@ -81,9 +81,7 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[
-    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
-]
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
 os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
 
@@ -187,8 +185,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -211,8 +208,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
+        help="The scale to smooth the loss with am (output of encoder network)" "part.",
     )
 
     parser.add_argument(
@@ -542,22 +538,15 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0
-            if warmup < 1.0
-            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
-        )
-        loss = (
-            params.simple_loss_scale * simple_loss
-            + pruned_loss_scale * pruned_loss
+            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
         )
+        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -711,9 +700,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -813,7 +800,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
+            2**22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
diff --git a/egs/csj/ASR/.gitignore b/egs/csj/ASR/.gitignore
index 5d965832e..cd0e20c4c 100644
--- a/egs/csj/ASR/.gitignore
+++ b/egs/csj/ASR/.gitignore
@@ -5,4 +5,4 @@ notify_tg.py
 finetune_*
 misc.ini
 .vscode/*
-offline/*
\ No newline at end of file
+offline/*
diff --git a/egs/csj/ASR/local/compute_fbank_csj.py b/egs/csj/ASR/local/compute_fbank_csj.py
index 994dedbdd..c248aa668 100644
--- a/egs/csj/ASR/local/compute_fbank_csj.py
+++ b/egs/csj/ASR/local/compute_fbank_csj.py
@@ -25,15 +25,10 @@ from random import Random
 from typing import List, Tuple
 
 import torch
-from lhotse import (
+from lhotse import (  # fmt: off; See the following for why LilcomChunkyWriter is preferred; https://github.com/k2-fsa/icefall/pull/404; https://github.com/lhotse-speech/lhotse/pull/527; fmt: on
     CutSet,
     Fbank,
     FbankConfig,
-    # fmt: off
-    # See the following for why LilcomChunkyWriter is preferred
-    # https://github.com/k2-fsa/icefall/pull/404
-    # https://github.com/lhotse-speech/lhotse/pull/527
-    # fmt: on
     LilcomChunkyWriter,
     RecordingSet,
     SupervisionSet,
@@ -81,17 +76,13 @@ def make_cutset_blueprints(
         cut_sets.append((f"eval{i}", cut_set))
 
     # Create train and valid cuts
-    logging.info(
-        "Loading, trimming, and shuffling the remaining core+noncore cuts."
-    )
+    logging.info("Loading, trimming, and shuffling the remaining core+noncore cuts.")
     recording_set = RecordingSet.from_file(
         manifest_dir / "csj_recordings_core.jsonl.gz"
     ) + RecordingSet.from_file(manifest_dir / "csj_recordings_noncore.jsonl.gz")
     supervision_set = SupervisionSet.from_file(
         manifest_dir / "csj_supervisions_core.jsonl.gz"
-    ) + SupervisionSet.from_file(
-        manifest_dir / "csj_supervisions_noncore.jsonl.gz"
-    )
+    ) + SupervisionSet.from_file(manifest_dir / "csj_supervisions_noncore.jsonl.gz")
 
     cut_set = CutSet.from_manifests(
         recordings=recording_set,
@@ -101,15 +92,12 @@ def make_cutset_blueprints(
     cut_set = cut_set.shuffle(Random(RNG_SEED))
 
     logging.info(
-        "Creating valid and train cuts from core and noncore,"
-        f"split at {split}."
+        "Creating valid and train cuts from core and noncore," f"split at {split}."
     )
     valid_set = CutSet.from_cuts(islice(cut_set, 0, split))
 
     train_set = CutSet.from_cuts(islice(cut_set, split, None))
-    train_set = (
-        train_set + train_set.perturb_speed(0.9) + train_set.perturb_speed(1.1)
-    )
+    train_set = train_set + train_set.perturb_speed(0.9) + train_set.perturb_speed(1.1)
 
     cut_sets.extend([("valid", valid_set), ("train", train_set)])
 
@@ -122,15 +110,9 @@ def get_args():
         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
     )
 
-    parser.add_argument(
-        "--manifest-dir", type=Path, help="Path to save manifests"
-    )
-    parser.add_argument(
-        "--fbank-dir", type=Path, help="Path to save fbank features"
-    )
-    parser.add_argument(
-        "--split", type=int, default=4000, help="Split at this index"
-    )
+    parser.add_argument("--manifest-dir", type=Path, help="Path to save manifests")
+    parser.add_argument("--fbank-dir", type=Path, help="Path to save fbank features")
+    parser.add_argument("--split", type=int, default=4000, help="Split at this index")
 
     return parser.parse_args()
 
@@ -141,9 +123,7 @@ def main():
     extractor = Fbank(FbankConfig(num_mel_bins=80))
     num_jobs = min(16, os.cpu_count())
 
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/csj/ASR/local/compute_fbank_musan.py b/egs/csj/ASR/local/compute_fbank_musan.py
index 44a33c4eb..f60e62c85 100644
--- a/egs/csj/ASR/local/compute_fbank_musan.py
+++ b/egs/csj/ASR/local/compute_fbank_musan.py
@@ -26,7 +26,6 @@ from lhotse.recipes.utils import read_manifests_if_cached
 
 from icefall.utils import get_executor
 
-
 ARGPARSE_DESCRIPTION = """
 This file computes fbank features of the musan dataset.
 It looks for manifests in the directory data/manifests.
@@ -84,9 +83,7 @@ def compute_fbank_musan(manifest_dir: Path, fbank_dir: Path):
         # create chunks of Musan with duration 5 - 10 seconds
         musan_cuts = (
             CutSet.from_manifests(
-                recordings=combine(
-                    part["recordings"] for part in manifests.values()
-                )
+                recordings=combine(part["recordings"] for part in manifests.values())
             )
             .cut_into_windows(10.0)
             .filter(lambda c: c.duration > 5)
@@ -107,21 +104,15 @@ def get_args():
         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
     )
 
-    parser.add_argument(
-        "--manifest-dir", type=Path, help="Path to save manifests"
-    )
-    parser.add_argument(
-        "--fbank-dir", type=Path, help="Path to save fbank features"
-    )
+    parser.add_argument("--manifest-dir", type=Path, help="Path to save manifests")
+    parser.add_argument("--fbank-dir", type=Path, help="Path to save fbank features")
 
     return parser.parse_args()
 
 
 if __name__ == "__main__":
     args = get_args()
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     compute_fbank_musan(args.manifest_dir, args.fbank_dir)
diff --git a/egs/csj/ASR/local/conf/disfluent.ini b/egs/csj/ASR/local/conf/disfluent.ini
index eb70673de..c987e72c5 100644
--- a/egs/csj/ASR/local/conf/disfluent.ini
+++ b/egs/csj/ASR/local/conf/disfluent.ini
@@ -1,17 +1,17 @@
 ; # This section is ignored if this file is not supplied as the first config file to
-; # lhotse prepare csj  
+; # lhotse prepare csj
 [SEGMENTS]
 ; # Allowed period of nonverbal noise. If exceeded, a new segment is created.
 gap = 0.5
 ; # Maximum length of segment (s).
 maxlen = 10
-; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently.  
+; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently.
 minlen = 0.02
-; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. 
-; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. 
-; # If you intend to use a multicharacter string for gap_sym, remember to register the 
-; # multicharacter string as part of userdef-string in prepare_lang_char.py. 
-gap_sym = 
+; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`.
+; # Pass an empty string to avoid adding any symbol. It was "" in kaldi.
+; # If you intend to use a multicharacter string for gap_sym, remember to register the
+; # multicharacter string as part of userdef-string in prepare_lang_char.py.
+gap_sym =
 
 [CONSTANTS]
 ; # Name of this mode
@@ -115,59 +115,59 @@ B^ = 0
 ; # 0 to remain, 1 to delete
 ; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)'
 笑 = 0
-; # Example: 'コク(笑 サイ+(D オン))', 
+; # Example: 'コク(笑 サイ+(D オン))',
 笑^ = 0
 ; # 泣きながら発話
 ; # 0 to remain, 1 to delete
-; # Example: '(泣 ドンナニ)' 
+; # Example: '(泣 ドンナニ)'
 泣 = 0
 泣^ = 0
 ; # 咳をしながら発話
 ; # 0 to remain, 1 to delete
-; # Example: 'シャ(咳 リン) ノ' 
+; # Example: 'シャ(咳 リン) ノ'
 咳 = 0
 ; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)'
 咳^ = 0
 ; # ささやき声や独り言などの小さな声
 ; # 0 to remain, 1 to delete
-; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' 
+; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))'
 L = 0
 ; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト'
 L^ = 0
 
 [REPLACEMENTS]
 ; # ボーカルフライなどで母音が同定できない場合
- = 
+ =
 ; # 「うん/うーん/ふーん」の音の特定が困難な場合
- = 
+ =
 ; # 非語彙的な母音の引き延ばし
- = 
+ =
 ; # 非語彙的な子音の引き延ばし
- = 
+ =
 ; # 言語音と独立に講演者の笑いが生じている場合
-<笑> = 
+<笑> =
 ; # 言語音と独立に講演者の咳が生じている場合
-<咳> = 
+<咳> =
 ; # 言語音と独立に講演者の息が生じている場合
-<息> = 
+<息> =
 ; # 講演者の泣き声
-<泣> = 
+<泣> =
 ; # 聴衆(司会者なども含む)の発話
-<フロア発話> = 
+<フロア発話> =
 ; # 聴衆の笑い
-<フロア笑> = 
+<フロア笑> =
 ; # 聴衆の拍手
-<拍手> = 
+<拍手> =
 ; # 講演者が発表中に用いたデモンストレーションの音声
-<デモ> = 
+<デモ> =
 ; # 学会講演に発表時間を知らせるためにならすベルの音
-<ベル> = 
+<ベル> =
 ; # 転記単位全体が再度読み直された場合
-<朗読間違い> = 
+<朗読間違い> =
 ; # 上記以外の音で特に目立った音
-<雑音> = 
+<雑音> =
 ; # 0.2秒以上のポーズ
-

= +

= ; # Redacted information, for R ; # It is \x00D7 multiplication sign, not your normal 'x' × = × @@ -318,4 +318,3 @@ spk_id = 2 ャ = ǐa ュ = ǐu ョ = ǐo - diff --git a/egs/csj/ASR/local/conf/fluent.ini b/egs/csj/ASR/local/conf/fluent.ini index 5d22f9eb8..f7f27f5bc 100644 --- a/egs/csj/ASR/local/conf/fluent.ini +++ b/egs/csj/ASR/local/conf/fluent.ini @@ -1,17 +1,17 @@ ; # This section is ignored if this file is not supplied as the first config file to -; # lhotse prepare csj +; # lhotse prepare csj [SEGMENTS] ; # Allowed period of nonverbal noise. If exceeded, a new segment is created. gap = 0.5 ; # Maximum length of segment (s). maxlen = 10 -; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. +; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. minlen = 0.02 -; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. -; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. -; # If you intend to use a multicharacter string for gap_sym, remember to register the -; # multicharacter string as part of userdef-string in prepare_lang_char.py. -gap_sym = +; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. +; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. +; # If you intend to use a multicharacter string for gap_sym, remember to register the +; # multicharacter string as part of userdef-string in prepare_lang_char.py. +gap_sym = [CONSTANTS] ; # Name of this mode @@ -115,59 +115,59 @@ B^ = 0 ; # 0 to remain, 1 to delete ; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)' 笑 = 0 -; # Example: 'コク(笑 サイ+(D オン))', +; # Example: 'コク(笑 サイ+(D オン))', 笑^ = 0 ; # 泣きながら発話 ; # 0 to remain, 1 to delete -; # Example: '(泣 ドンナニ)' +; # Example: '(泣 ドンナニ)' 泣 = 0 泣^ = 0 ; # 咳をしながら発話 ; # 0 to remain, 1 to delete -; # Example: 'シャ(咳 リン) ノ' +; # Example: 'シャ(咳 リン) ノ' 咳 = 0 ; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)' 咳^ = 0 ; # ささやき声や独り言などの小さな声 ; # 0 to remain, 1 to delete -; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' +; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' L = 0 ; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト' L^ = 0 [REPLACEMENTS] ; # ボーカルフライなどで母音が同定できない場合 - = + = ; # 「うん/うーん/ふーん」の音の特定が困難な場合 - = + = ; # 非語彙的な母音の引き延ばし - = + = ; # 非語彙的な子音の引き延ばし - = + = ; # 言語音と独立に講演者の笑いが生じている場合 -<笑> = +<笑> = ; # 言語音と独立に講演者の咳が生じている場合 -<咳> = +<咳> = ; # 言語音と独立に講演者の息が生じている場合 -<息> = +<息> = ; # 講演者の泣き声 -<泣> = +<泣> = ; # 聴衆(司会者なども含む)の発話 -<フロア発話> = +<フロア発話> = ; # 聴衆の笑い -<フロア笑> = +<フロア笑> = ; # 聴衆の拍手 -<拍手> = +<拍手> = ; # 講演者が発表中に用いたデモンストレーションの音声 -<デモ> = +<デモ> = ; # 学会講演に発表時間を知らせるためにならすベルの音 -<ベル> = +<ベル> = ; # 転記単位全体が再度読み直された場合 -<朗読間違い> = +<朗読間違い> = ; # 上記以外の音で特に目立った音 -<雑音> = +<雑音> = ; # 0.2秒以上のポーズ -

= +

= ; # Redacted information, for R ; # It is \x00D7 multiplication sign, not your normal 'x' × = × @@ -318,4 +318,3 @@ spk_id = 2 ャ = ǐa ュ = ǐu ョ = ǐo - diff --git a/egs/csj/ASR/local/conf/number.ini b/egs/csj/ASR/local/conf/number.ini index 2613c3409..cf9038f62 100644 --- a/egs/csj/ASR/local/conf/number.ini +++ b/egs/csj/ASR/local/conf/number.ini @@ -1,17 +1,17 @@ ; # This section is ignored if this file is not supplied as the first config file to -; # lhotse prepare csj +; # lhotse prepare csj [SEGMENTS] ; # Allowed period of nonverbal noise. If exceeded, a new segment is created. gap = 0.5 ; # Maximum length of segment (s). maxlen = 10 -; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. +; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. minlen = 0.02 -; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. -; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. -; # If you intend to use a multicharacter string for gap_sym, remember to register the -; # multicharacter string as part of userdef-string in prepare_lang_char.py. -gap_sym = +; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. +; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. +; # If you intend to use a multicharacter string for gap_sym, remember to register the +; # multicharacter string as part of userdef-string in prepare_lang_char.py. +gap_sym = [CONSTANTS] ; # Name of this mode @@ -115,59 +115,59 @@ B^ = 0 ; # 0 to remain, 1 to delete ; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)' 笑 = 0 -; # Example: 'コク(笑 サイ+(D オン))', +; # Example: 'コク(笑 サイ+(D オン))', 笑^ = 0 ; # 泣きながら発話 ; # 0 to remain, 1 to delete -; # Example: '(泣 ドンナニ)' +; # Example: '(泣 ドンナニ)' 泣 = 0 泣^ = 0 ; # 咳をしながら発話 ; # 0 to remain, 1 to delete -; # Example: 'シャ(咳 リン) ノ' +; # Example: 'シャ(咳 リン) ノ' 咳 = 0 ; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)' 咳^ = 0 ; # ささやき声や独り言などの小さな声 ; # 0 to remain, 1 to delete -; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' +; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' L = 0 ; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト' L^ = 0 [REPLACEMENTS] ; # ボーカルフライなどで母音が同定できない場合 - = + = ; # 「うん/うーん/ふーん」の音の特定が困難な場合 - = + = ; # 非語彙的な母音の引き延ばし - = + = ; # 非語彙的な子音の引き延ばし - = + = ; # 言語音と独立に講演者の笑いが生じている場合 -<笑> = +<笑> = ; # 言語音と独立に講演者の咳が生じている場合 -<咳> = +<咳> = ; # 言語音と独立に講演者の息が生じている場合 -<息> = +<息> = ; # 講演者の泣き声 -<泣> = +<泣> = ; # 聴衆(司会者なども含む)の発話 -<フロア発話> = +<フロア発話> = ; # 聴衆の笑い -<フロア笑> = +<フロア笑> = ; # 聴衆の拍手 -<拍手> = +<拍手> = ; # 講演者が発表中に用いたデモンストレーションの音声 -<デモ> = +<デモ> = ; # 学会講演に発表時間を知らせるためにならすベルの音 -<ベル> = +<ベル> = ; # 転記単位全体が再度読み直された場合 -<朗読間違い> = +<朗読間違い> = ; # 上記以外の音で特に目立った音 -<雑音> = +<雑音> = ; # 0.2秒以上のポーズ -

= +

= ; # Redacted information, for R ; # It is \x00D7 multiplication sign, not your normal 'x' × = × @@ -318,4 +318,3 @@ spk_id = 2 ャ = ǐa ュ = ǐu ョ = ǐo - diff --git a/egs/csj/ASR/local/conf/symbol.ini b/egs/csj/ASR/local/conf/symbol.ini index 8ba451dd5..f9801284b 100644 --- a/egs/csj/ASR/local/conf/symbol.ini +++ b/egs/csj/ASR/local/conf/symbol.ini @@ -1,17 +1,17 @@ ; # This section is ignored if this file is not supplied as the first config file to -; # lhotse prepare csj +; # lhotse prepare csj [SEGMENTS] ; # Allowed period of nonverbal noise. If exceeded, a new segment is created. gap = 0.5 ; # Maximum length of segment (s). maxlen = 10 -; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. +; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. minlen = 0.02 -; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. -; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. -; # If you intend to use a multicharacter string for gap_sym, remember to register the -; # multicharacter string as part of userdef-string in prepare_lang_char.py. -gap_sym = +; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. +; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. +; # If you intend to use a multicharacter string for gap_sym, remember to register the +; # multicharacter string as part of userdef-string in prepare_lang_char.py. +gap_sym = [CONSTANTS] ; # Name of this mode @@ -116,59 +116,59 @@ B^ = 0 ; # 0 to remain, 1 to delete ; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)' 笑 = 0 -; # Example: 'コク(笑 サイ+(D オン))', +; # Example: 'コク(笑 サイ+(D オン))', 笑^ = 0 ; # 泣きながら発話 ; # 0 to remain, 1 to delete -; # Example: '(泣 ドンナニ)' +; # Example: '(泣 ドンナニ)' 泣 = 0 泣^ = 0 ; # 咳をしながら発話 ; # 0 to remain, 1 to delete -; # Example: 'シャ(咳 リン) ノ' +; # Example: 'シャ(咳 リン) ノ' 咳 = 0 ; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)' 咳^ = 0 ; # ささやき声や独り言などの小さな声 ; # 0 to remain, 1 to delete -; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' +; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' L = 0 ; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト' L^ = 0 [REPLACEMENTS] ; # ボーカルフライなどで母音が同定できない場合 - = + = ; # 「うん/うーん/ふーん」の音の特定が困難な場合 - = + = ; # 非語彙的な母音の引き延ばし - = + = ; # 非語彙的な子音の引き延ばし - = + = ; # 言語音と独立に講演者の笑いが生じている場合 -<笑> = +<笑> = ; # 言語音と独立に講演者の咳が生じている場合 -<咳> = +<咳> = ; # 言語音と独立に講演者の息が生じている場合 -<息> = +<息> = ; # 講演者の泣き声 -<泣> = +<泣> = ; # 聴衆(司会者なども含む)の発話 -<フロア発話> = +<フロア発話> = ; # 聴衆の笑い -<フロア笑> = +<フロア笑> = ; # 聴衆の拍手 -<拍手> = +<拍手> = ; # 講演者が発表中に用いたデモンストレーションの音声 -<デモ> = +<デモ> = ; # 学会講演に発表時間を知らせるためにならすベルの音 -<ベル> = +<ベル> = ; # 転記単位全体が再度読み直された場合 -<朗読間違い> = +<朗読間違い> = ; # 上記以外の音で特に目立った音 -<雑音> = +<雑音> = ; # 0.2秒以上のポーズ -

= +

= ; # Redacted information, for R ; # It is \x00D7 multiplication sign, not your normal 'x' × = × @@ -319,4 +319,3 @@ spk_id = 2 ャ = ǐa ュ = ǐu ョ = ǐo - diff --git a/egs/csj/ASR/local/display_manifest_statistics.py b/egs/csj/ASR/local/display_manifest_statistics.py index c9de21073..c043cf853 100644 --- a/egs/csj/ASR/local/display_manifest_statistics.py +++ b/egs/csj/ASR/local/display_manifest_statistics.py @@ -37,9 +37,7 @@ def get_parser(): formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) - parser.add_argument( - "--manifest-dir", type=Path, help="Path to cutset manifests" - ) + parser.add_argument("--manifest-dir", type=Path, help="Path to cutset manifests") return parser.parse_args() diff --git a/egs/csj/ASR/local/prepare_lang_char.py b/egs/csj/ASR/local/prepare_lang_char.py index e4d996871..ef91f6e43 100644 --- a/egs/csj/ASR/local/prepare_lang_char.py +++ b/egs/csj/ASR/local/prepare_lang_char.py @@ -87,9 +87,7 @@ def main(): args = get_args() logging.basicConfig( - format=( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] " "%(message)s" - ), + format=("%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] " "%(message)s"), level=logging.INFO, ) @@ -111,8 +109,7 @@ def main(): words = set() logging.info( - f"Creating vocabulary from {args.train_cut.name}" - f" at {args.trans_mode} mode." + f"Creating vocabulary from {args.train_cut.name}" f" at {args.trans_mode} mode." ) for cut in train_set: try: @@ -123,8 +120,7 @@ def main(): ) except KeyError: raise KeyError( - f"Could not find {args.trans_mode} in " - f"{cut.supervisions[0].custom}" + f"Could not find {args.trans_mode} in " f"{cut.supervisions[0].custom}" ) for t in text.split(): if t in args.userdef_string: @@ -143,9 +139,7 @@ def main(): (args.lang_dir / "words_len").write_text(f"{len(words)}") - (args.lang_dir / "userdef_string").write_text( - "\n".join(args.userdef_string) - ) + (args.lang_dir / "userdef_string").write_text("\n".join(args.userdef_string)) (args.lang_dir / "trans_mode").write_text(args.trans_mode) logging.info("Done.") diff --git a/egs/csj/ASR/local/validate_manifest.py b/egs/csj/ASR/local/validate_manifest.py index 0c4c6c1ea..7f67c64b6 100644 --- a/egs/csj/ASR/local/validate_manifest.py +++ b/egs/csj/ASR/local/validate_manifest.py @@ -89,9 +89,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py b/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py index d78e26240..72dcd772a 100644 --- a/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py +++ b/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py @@ -183,23 +183,18 @@ class GigaSpeechAsrDataModule: "--small-dev", type=str2bool, default=False, - help="Should we use only 1000 utterances for dev " - "(speeds up training)", + help="Should we use only 1000 utterances for dev " "(speeds up training)", ) def train_dataloaders(self, cuts_train: CutSet) -> DataLoader: logging.info("About to get Musan cuts") - cuts_musan = load_manifest( - self.args.manifest_dir / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms = [] if self.args.enable_musan: logging.info("Enable MUSAN") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") @@ -221,9 +216,7 @@ class GigaSpeechAsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") input_transforms.append( SpecAugment( time_warp_factor=self.args.spec_aug_time_warp_factor, @@ -256,9 +249,7 @@ class GigaSpeechAsrDataModule: # Drop feats to be on the safe side. train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -304,9 +295,7 @@ class GigaSpeechAsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: @@ -362,9 +351,7 @@ class GigaSpeechAsrDataModule: @lru_cache() def dev_cuts(self) -> CutSet: logging.info("About to get dev cuts") - cuts_valid = load_manifest_lazy( - self.args.manifest_dir / "cuts_DEV.jsonl.gz" - ) + cuts_valid = load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz") if self.args.small_dev: return cuts_valid.subset(first=1000) else: diff --git a/egs/gigaspeech/ASR/conformer_ctc/conformer.py b/egs/gigaspeech/ASR/conformer_ctc/conformer.py index 6fac07f93..a1cfe6e75 100644 --- a/egs/gigaspeech/ASR/conformer_ctc/conformer.py +++ b/egs/gigaspeech/ASR/conformer_ctc/conformer.py @@ -160,9 +160,7 @@ class ConformerEncoderLayer(nn.Module): use_conv_batchnorm: bool = False, ) -> None: super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), @@ -182,18 +180,14 @@ class ConformerEncoderLayer(nn.Module): d_model, cnn_module_kernel, use_batchnorm=use_conv_batchnorm ) - self.norm_ff_macaron = nn.LayerNorm( - d_model - ) # for the macaron style FNN module + self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.ff_scale = 0.5 self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm( - d_model - ) # for the final output of the block + self.norm_final = nn.LayerNorm(d_model) # for the final output of the block self.dropout = nn.Dropout(dropout) @@ -227,9 +221,7 @@ class ConformerEncoderLayer(nn.Module): residual = src if self.normalize_before: src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout( - self.feed_forward_macaron(src) - ) + src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) if not self.normalize_before: src = self.norm_ff_macaron(src) @@ -348,9 +340,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -366,9 +356,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -638,9 +626,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -708,31 +696,22 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) @@ -771,9 +750,7 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -785,9 +762,7 @@ class RelPositionMultiheadAttention(nn.Module): matrix_ac + matrix_bd ) * scaling # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -821,13 +796,9 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads diff --git a/egs/gigaspeech/ASR/conformer_ctc/decode.py b/egs/gigaspeech/ASR/conformer_ctc/decode.py index 9c1418baa..d7035a1f8 100755 --- a/egs/gigaspeech/ASR/conformer_ctc/decode.py +++ b/egs/gigaspeech/ASR/conformer_ctc/decode.py @@ -476,9 +476,7 @@ def decode_dataset( results[lm_scale].extend(this_batch) else: - assert ( - len(results) > 0 - ), "It should not decode to empty in the first batch!" + assert len(results) > 0, "It should not decode to empty in the first batch!" this_batch = [] hyp_words = [] for cut_id, ref_text in zip(cut_ids, texts): @@ -493,9 +491,7 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -528,9 +524,7 @@ def save_results( test_set_wers[key] = wer if enable_log: - logging.info( - "Wrote detailed error stats to {}".format(errs_filename) - ) + logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" @@ -705,9 +699,7 @@ def main(): eos_id=eos_id, ) - save_results( - params=params, test_set_name=test_set, results_dict=results_dict - ) + save_results(params=params, test_set_name=test_set, results_dict=results_dict) logging.info("Done!") diff --git a/egs/gigaspeech/ASR/conformer_ctc/label_smoothing.py b/egs/gigaspeech/ASR/conformer_ctc/label_smoothing.py index cdc85ce9a..3b94f0c4b 100644 --- a/egs/gigaspeech/ASR/conformer_ctc/label_smoothing.py +++ b/egs/gigaspeech/ASR/conformer_ctc/label_smoothing.py @@ -78,13 +78,10 @@ class LabelSmoothingLoss(torch.nn.Module): ignored = target == self.ignore_index target[ignored] = 0 - true_dist = torch.nn.functional.one_hot( - target, num_classes=num_classes - ).to(x) + true_dist = torch.nn.functional.one_hot(target, num_classes=num_classes).to(x) true_dist = ( - true_dist * (1 - self.label_smoothing) - + self.label_smoothing / num_classes + true_dist * (1 - self.label_smoothing) + self.label_smoothing / num_classes ) # Set the value of ignored indexes to 0 true_dist[ignored] = 0 diff --git a/egs/gigaspeech/ASR/conformer_ctc/subsampling.py b/egs/gigaspeech/ASR/conformer_ctc/subsampling.py index 542fb0364..8e0f73d05 100644 --- a/egs/gigaspeech/ASR/conformer_ctc/subsampling.py +++ b/egs/gigaspeech/ASR/conformer_ctc/subsampling.py @@ -42,13 +42,9 @@ class Conv2dSubsampling(nn.Module): assert idim >= 7 super().__init__() self.conv = nn.Sequential( - nn.Conv2d( - in_channels=1, out_channels=odim, kernel_size=3, stride=2 - ), + nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2), nn.ReLU(), - nn.Conv2d( - in_channels=odim, out_channels=odim, kernel_size=3, stride=2 - ), + nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2), nn.ReLU(), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) @@ -132,17 +128,13 @@ class VggSubsampling(nn.Module): ) ) layers.append( - torch.nn.MaxPool2d( - kernel_size=2, stride=2, padding=0, ceil_mode=True - ) + torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True) ) cur_channels = block_dim self.layers = nn.Sequential(*layers) - self.out = nn.Linear( - block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim - ) + self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim) def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. diff --git a/egs/gigaspeech/ASR/conformer_ctc/train.py b/egs/gigaspeech/ASR/conformer_ctc/train.py index 2965cde18..4883d04d8 100755 --- a/egs/gigaspeech/ASR/conformer_ctc/train.py +++ b/egs/gigaspeech/ASR/conformer_ctc/train.py @@ -386,9 +386,7 @@ def compute_loss( # # See https://github.com/k2-fsa/icefall/issues/97 # for more details - unsorted_token_ids = graph_compiler.texts_to_ids( - supervisions["text"] - ) + unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) att_loss = mmodel.decoder_forward( encoder_memory, memory_mask, @@ -521,9 +519,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -641,9 +637,7 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/gigaspeech/ASR/conformer_ctc/transformer.py b/egs/gigaspeech/ASR/conformer_ctc/transformer.py index 00ca027a7..0566cfc81 100644 --- a/egs/gigaspeech/ASR/conformer_ctc/transformer.py +++ b/egs/gigaspeech/ASR/conformer_ctc/transformer.py @@ -151,9 +151,7 @@ class Transformer(nn.Module): norm=decoder_norm, ) - self.decoder_output_layer = torch.nn.Linear( - d_model, self.decoder_num_class - ) + self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class) self.decoder_criterion = LabelSmoothingLoss() else: @@ -181,18 +179,13 @@ class Transformer(nn.Module): memory_key_padding_mask for the decoder. Its shape is (N, T). It is None if `supervision` is None. """ - if ( - isinstance(self.use_feat_batchnorm, bool) - and self.use_feat_batchnorm - ): + if isinstance(self.use_feat_batchnorm, bool) and self.use_feat_batchnorm: x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) x = self.feat_batchnorm(x) x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) if isinstance(self.use_feat_batchnorm, float): x *= self.use_feat_batchnorm - encoder_memory, memory_key_padding_mask = self.run_encoder( - x, supervision - ) + encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision) x = self.ctc_output(encoder_memory) return x, encoder_memory, memory_key_padding_mask @@ -273,23 +266,17 @@ class Transformer(nn.Module): """ ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, padding_value=float(eos_id) - ) + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) device = memory.device ys_in_pad = ys_in_pad.to(device) ys_out_pad = ys_out_pad.to(device) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -350,23 +337,17 @@ class Transformer(nn.Module): ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, padding_value=float(eos_id) - ) + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) device = memory.device ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -639,9 +620,7 @@ def _get_activation_fn(activation: str): elif activation == "gelu": return nn.functional.gelu - raise RuntimeError( - "activation should be relu/gelu, not {}".format(activation) - ) + raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) class PositionalEncoding(nn.Module): @@ -843,9 +822,7 @@ def encoder_padding_mask( 1, ).to(torch.int32) - lengths = [ - 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) - ] + lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] for idx in range(supervision_segments.size(0)): # Note: TorchScript doesn't allow to unpack tensors as tuples sequence_idx = supervision_segments[idx, 0].item() @@ -866,9 +843,7 @@ def encoder_padding_mask( return mask -def decoder_padding_mask( - ys_pad: torch.Tensor, ignore_id: int = -1 -) -> torch.Tensor: +def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: """Generate a length mask for input. The masked position are filled with True, diff --git a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py index 8209ee3ec..07beeb1f0 100755 --- a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py +++ b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py @@ -77,9 +77,7 @@ def compute_fbank_gigaspeech_dev_test(): def main(): - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) compute_fbank_gigaspeech_dev_test() diff --git a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py index 6410249db..1c71be0f9 100755 --- a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py +++ b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py @@ -134,9 +134,7 @@ def main(): date_time = now.strftime("%Y-%m-%d-%H-%M-%S") log_filename = "log-compute_fbank_gigaspeech_splits" - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" log_filename = f"{log_filename}-{date_time}" logging.basicConfig( diff --git a/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py b/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py index 48d10a157..31abe7fff 100755 --- a/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py +++ b/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py @@ -98,19 +98,13 @@ def preprocess_giga_speech(): f"Speed perturb for {partition} with factors 0.9 and 1.1 " "(Perturbing may take 8 minutes and saving may take 20 minutes)" ) - cut_set = ( - cut_set - + cut_set.perturb_speed(0.9) - + cut_set.perturb_speed(1.1) - ) + cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) logging.info(f"Saving to {raw_cuts_path}") cut_set.to_file(raw_cuts_path) def main(): - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) preprocess_giga_speech() diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py index c87686e1e..7f114fba6 100644 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -195,8 +195,7 @@ class GigaSpeechAsrDataModule: "--small-dev", type=str2bool, default=False, - help="Should we use only 1000 utterances for dev " - "(speeds up training)", + help="Should we use only 1000 utterances for dev " "(speeds up training)", ) def train_dataloaders( @@ -216,13 +215,9 @@ class GigaSpeechAsrDataModule: if self.args.enable_musan: logging.info("Enable MUSAN") logging.info("About to get Musan cuts") - cuts_musan = load_manifest( - self.args.manifest_dir / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") @@ -244,9 +239,7 @@ class GigaSpeechAsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -289,9 +282,7 @@ class GigaSpeechAsrDataModule: # Drop feats to be on the safe side. train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -347,9 +338,7 @@ class GigaSpeechAsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: @@ -405,9 +394,7 @@ class GigaSpeechAsrDataModule: @lru_cache() def dev_cuts(self) -> CutSet: logging.info("About to get dev cuts") - cuts_valid = load_manifest_lazy( - self.args.manifest_dir / "cuts_DEV.jsonl.gz" - ) + cuts_valid = load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz") if self.args.small_dev: return cuts_valid.subset(first=1000) else: diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py index 5849a3471..c0b17750e 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py @@ -77,11 +77,7 @@ from beam_search import ( from gigaspeech_scoring import asr_text_post_processing from train import get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import ( AttributeDict, setup_logger, @@ -188,8 +184,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -258,9 +253,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -275,10 +268,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -398,9 +388,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -434,8 +422,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -511,8 +498,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py index cff9c7377..3d1e7bc18 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py @@ -51,11 +51,7 @@ import sentencepiece as spm import torch from train import get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import str2bool @@ -120,8 +116,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) return parser @@ -160,8 +155,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -209,9 +203,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py index 83ae25561..f51584120 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py @@ -77,9 +77,7 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def get_parser(): @@ -178,8 +176,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -202,8 +199,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -553,23 +549,16 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -732,9 +721,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") diff --git a/egs/librispeech/ASR/conformer_ctc/ali.py b/egs/librispeech/ASR/conformer_ctc/ali.py index 2828e309e..42e14abac 100755 --- a/egs/librispeech/ASR/conformer_ctc/ali.py +++ b/egs/librispeech/ASR/conformer_ctc/ali.py @@ -231,9 +231,7 @@ def compute_alignments( labels_ali = get_alignments(best_path, kind="labels") aux_labels_ali = get_alignments(best_path, kind="aux_labels") assert len(labels_ali) == len(aux_labels_ali) == len(cut_list) - for cut, labels, aux_labels in zip( - cut_list, labels_ali, aux_labels_ali - ): + for cut, labels, aux_labels in zip(cut_list, labels_ali, aux_labels_ali): cut.labels_alignment = labels_writer.store_array( key=cut.id, value=np.asarray(labels, dtype=np.int32), @@ -258,9 +256,7 @@ def compute_alignments( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return CutSet.from_cuts(cuts) @@ -289,9 +285,7 @@ def main(): out_labels_ali_filename = out_dir / f"labels_{params.dataset}.h5" out_aux_labels_ali_filename = out_dir / f"aux_labels_{params.dataset}.h5" - out_manifest_filename = ( - out_dir / f"librispeech_cuts_{params.dataset}.jsonl.gz" - ) + out_manifest_filename = out_dir / f"librispeech_cuts_{params.dataset}.jsonl.gz" for f in ( out_labels_ali_filename, diff --git a/egs/librispeech/ASR/conformer_ctc/conformer.py b/egs/librispeech/ASR/conformer_ctc/conformer.py index 6fac07f93..a1cfe6e75 100644 --- a/egs/librispeech/ASR/conformer_ctc/conformer.py +++ b/egs/librispeech/ASR/conformer_ctc/conformer.py @@ -160,9 +160,7 @@ class ConformerEncoderLayer(nn.Module): use_conv_batchnorm: bool = False, ) -> None: super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), @@ -182,18 +180,14 @@ class ConformerEncoderLayer(nn.Module): d_model, cnn_module_kernel, use_batchnorm=use_conv_batchnorm ) - self.norm_ff_macaron = nn.LayerNorm( - d_model - ) # for the macaron style FNN module + self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.ff_scale = 0.5 self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm( - d_model - ) # for the final output of the block + self.norm_final = nn.LayerNorm(d_model) # for the final output of the block self.dropout = nn.Dropout(dropout) @@ -227,9 +221,7 @@ class ConformerEncoderLayer(nn.Module): residual = src if self.normalize_before: src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout( - self.feed_forward_macaron(src) - ) + src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) if not self.normalize_before: src = self.norm_ff_macaron(src) @@ -348,9 +340,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -366,9 +356,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -638,9 +626,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -708,31 +696,22 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) @@ -771,9 +750,7 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -785,9 +762,7 @@ class RelPositionMultiheadAttention(nn.Module): matrix_ac + matrix_bd ) * scaling # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -821,13 +796,9 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py index 3f3b1acda..7e0bf5b7b 100755 --- a/egs/librispeech/ASR/conformer_ctc/decode.py +++ b/egs/librispeech/ASR/conformer_ctc/decode.py @@ -551,9 +551,7 @@ def decode_dataset( results[lm_scale].extend(this_batch) else: - assert ( - len(results) > 0 - ), "It should not decode to empty in the first batch!" + assert len(results) > 0, "It should not decode to empty in the first batch!" this_batch = [] hyp_words = [] for ref_text in texts: @@ -568,9 +566,7 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -602,9 +598,7 @@ def save_results( test_set_wers[key] = wer if enable_log: - logging.info( - "Wrote detailed error stats to {}".format(errs_filename) - ) + logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" @@ -809,9 +803,7 @@ def main(): eos_id=eos_id, ) - save_results( - params=params, test_set_name=test_set, results_dict=results_dict - ) + save_results(params=params, test_set_name=test_set, results_dict=results_dict) logging.info("Done!") diff --git a/egs/librispeech/ASR/conformer_ctc/export.py b/egs/librispeech/ASR/conformer_ctc/export.py index 28c28df01..fbcbd7b29 100755 --- a/egs/librispeech/ASR/conformer_ctc/export.py +++ b/egs/librispeech/ASR/conformer_ctc/export.py @@ -157,9 +157,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/conformer_ctc/label_smoothing.py b/egs/librispeech/ASR/conformer_ctc/label_smoothing.py index 1f2f3b137..cb0d6e04d 100644 --- a/egs/librispeech/ASR/conformer_ctc/label_smoothing.py +++ b/egs/librispeech/ASR/conformer_ctc/label_smoothing.py @@ -82,13 +82,10 @@ class LabelSmoothingLoss(torch.nn.Module): # for why we don't use target[ignored] = 0 here target = torch.where(ignored, torch.zeros_like(target), target) - true_dist = torch.nn.functional.one_hot( - target, num_classes=num_classes - ).to(x) + true_dist = torch.nn.functional.one_hot(target, num_classes=num_classes).to(x) true_dist = ( - true_dist * (1 - self.label_smoothing) - + self.label_smoothing / num_classes + true_dist * (1 - self.label_smoothing) + self.label_smoothing / num_classes ) # Set the value of ignored indexes to 0 diff --git a/egs/librispeech/ASR/conformer_ctc/pretrained.py b/egs/librispeech/ASR/conformer_ctc/pretrained.py index a2c0a5486..8200af866 100755 --- a/egs/librispeech/ASR/conformer_ctc/pretrained.py +++ b/egs/librispeech/ASR/conformer_ctc/pretrained.py @@ -237,8 +237,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -300,9 +299,7 @@ def main(): logging.info("Decoding started") features = fbank(waves) - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) # Note: We don't use key padding mask for attention during decoding with torch.no_grad(): @@ -427,9 +424,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 542fb0364..8e0f73d05 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -42,13 +42,9 @@ class Conv2dSubsampling(nn.Module): assert idim >= 7 super().__init__() self.conv = nn.Sequential( - nn.Conv2d( - in_channels=1, out_channels=odim, kernel_size=3, stride=2 - ), + nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2), nn.ReLU(), - nn.Conv2d( - in_channels=odim, out_channels=odim, kernel_size=3, stride=2 - ), + nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2), nn.ReLU(), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) @@ -132,17 +128,13 @@ class VggSubsampling(nn.Module): ) ) layers.append( - torch.nn.MaxPool2d( - kernel_size=2, stride=2, padding=0, ceil_mode=True - ) + torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True) ) cur_channels = block_dim self.layers = nn.Sequential(*layers) - self.out = nn.Linear( - block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim - ) + self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim) def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py index 6419f6816..1449bc310 100755 --- a/egs/librispeech/ASR/conformer_ctc/train.py +++ b/egs/librispeech/ASR/conformer_ctc/train.py @@ -393,9 +393,7 @@ def compute_loss( # Works with a phone lexicon decoding_graph = graph_compiler.compile(texts) else: - raise ValueError( - f"Unsupported type of graph compiler: {type(graph_compiler)}" - ) + raise ValueError(f"Unsupported type of graph compiler: {type(graph_compiler)}") dense_fsa_vec = k2.DenseFsaVec( nnet_output, @@ -422,9 +420,7 @@ def compute_loss( # # See https://github.com/k2-fsa/icefall/issues/97 # for more details - unsorted_token_ids = graph_compiler.texts_to_ids( - supervisions["text"] - ) + unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) att_loss = mmodel.decoder_forward( encoder_memory, memory_mask, @@ -453,9 +449,7 @@ def compute_loss( info["utt_duration"] = supervisions["num_frames"].sum().item() # averaged padding proportion over utterances info["utt_pad_proportion"] = ( - ((feature.size(1) - supervisions["num_frames"]) / feature.size(1)) - .sum() - .item() + ((feature.size(1) - supervisions["num_frames"]) / feature.size(1)).sum().item() ) return loss, info @@ -568,9 +562,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -733,9 +725,7 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/librispeech/ASR/conformer_ctc/transformer.py b/egs/librispeech/ASR/conformer_ctc/transformer.py index 00ca027a7..0566cfc81 100644 --- a/egs/librispeech/ASR/conformer_ctc/transformer.py +++ b/egs/librispeech/ASR/conformer_ctc/transformer.py @@ -151,9 +151,7 @@ class Transformer(nn.Module): norm=decoder_norm, ) - self.decoder_output_layer = torch.nn.Linear( - d_model, self.decoder_num_class - ) + self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class) self.decoder_criterion = LabelSmoothingLoss() else: @@ -181,18 +179,13 @@ class Transformer(nn.Module): memory_key_padding_mask for the decoder. Its shape is (N, T). It is None if `supervision` is None. """ - if ( - isinstance(self.use_feat_batchnorm, bool) - and self.use_feat_batchnorm - ): + if isinstance(self.use_feat_batchnorm, bool) and self.use_feat_batchnorm: x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) x = self.feat_batchnorm(x) x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) if isinstance(self.use_feat_batchnorm, float): x *= self.use_feat_batchnorm - encoder_memory, memory_key_padding_mask = self.run_encoder( - x, supervision - ) + encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision) x = self.ctc_output(encoder_memory) return x, encoder_memory, memory_key_padding_mask @@ -273,23 +266,17 @@ class Transformer(nn.Module): """ ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, padding_value=float(eos_id) - ) + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) device = memory.device ys_in_pad = ys_in_pad.to(device) ys_out_pad = ys_out_pad.to(device) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -350,23 +337,17 @@ class Transformer(nn.Module): ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, padding_value=float(eos_id) - ) + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) device = memory.device ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -639,9 +620,7 @@ def _get_activation_fn(activation: str): elif activation == "gelu": return nn.functional.gelu - raise RuntimeError( - "activation should be relu/gelu, not {}".format(activation) - ) + raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) class PositionalEncoding(nn.Module): @@ -843,9 +822,7 @@ def encoder_padding_mask( 1, ).to(torch.int32) - lengths = [ - 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) - ] + lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] for idx in range(supervision_segments.size(0)): # Note: TorchScript doesn't allow to unpack tensors as tuples sequence_idx = supervision_segments[idx, 0].item() @@ -866,9 +843,7 @@ def encoder_padding_mask( return mask -def decoder_padding_mask( - ys_pad: torch.Tensor, ignore_id: int = -1 -) -> torch.Tensor: +def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: """Generate a length mask for input. The masked position are filled with True, diff --git a/egs/librispeech/ASR/conformer_ctc2/attention.py b/egs/librispeech/ASR/conformer_ctc2/attention.py index 1375d7245..356d3f21b 100644 --- a/egs/librispeech/ASR/conformer_ctc2/attention.py +++ b/egs/librispeech/ASR/conformer_ctc2/attention.py @@ -18,11 +18,10 @@ from typing import Optional, Tuple import torch import torch.nn as nn +from scaling import ScaledLinear from torch import Tensor from torch.nn.init import xavier_normal_ -from scaling import ScaledLinear - class MultiheadAttention(nn.Module): r"""Allows the model to jointly attend to information @@ -76,9 +75,7 @@ class MultiheadAttention(nn.Module): self.embed_dim = embed_dim self.kdim = kdim if kdim is not None else embed_dim self.vdim = vdim if vdim is not None else embed_dim - self._qkv_same_embed_dim = ( - self.kdim == embed_dim and self.vdim == embed_dim - ) + self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim self.num_heads = num_heads self.dropout = dropout @@ -94,9 +91,7 @@ class MultiheadAttention(nn.Module): self.v_proj_weight = ScaledLinear(self.vdim, embed_dim, bias=bias) self.register_parameter("in_proj_weight", None) else: - self.in_proj_weight = ScaledLinear( - embed_dim, 3 * embed_dim, bias=bias - ) + self.in_proj_weight = ScaledLinear(embed_dim, 3 * embed_dim, bias=bias) self.register_parameter("q_proj_weight", None) self.register_parameter("k_proj_weight", None) self.register_parameter("v_proj_weight", None) @@ -107,12 +102,8 @@ class MultiheadAttention(nn.Module): self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=bias) if add_bias_kv: - self.bias_k = nn.Parameter( - torch.empty((1, 1, embed_dim), **factory_kwargs) - ) - self.bias_v = nn.Parameter( - torch.empty((1, 1, embed_dim), **factory_kwargs) - ) + self.bias_k = nn.Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) + self.bias_v = nn.Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) else: self.bias_k = self.bias_v = None diff --git a/egs/librispeech/ASR/conformer_ctc2/conformer.py b/egs/librispeech/ASR/conformer_ctc2/conformer.py index b906d2650..09f1eb000 100644 --- a/egs/librispeech/ASR/conformer_ctc2/conformer.py +++ b/egs/librispeech/ASR/conformer_ctc2/conformer.py @@ -29,9 +29,8 @@ from scaling import ( ScaledConv1d, ScaledLinear, ) -from torch import Tensor, nn from subsampling import Conv2dSubsampling - +from torch import Tensor, nn from transformer import Supervisions, Transformer, encoder_padding_mask @@ -182,9 +181,7 @@ class ConformerEncoderLayer(nn.Module): self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), @@ -356,9 +353,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -373,9 +368,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vecotr and `j` means the @@ -650,9 +643,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -721,31 +714,22 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) @@ -784,9 +768,7 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -794,13 +776,9 @@ class RelPositionMultiheadAttention(nn.Module): ) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd) - attn_output_weights = ( - matrix_ac + matrix_bd - ) # (batch, head, time1, time2) + attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -834,13 +812,9 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -863,9 +837,7 @@ class ConvolutionModule(nn.Module): """ - def __init__( - self, channels: int, kernel_size: int, bias: bool = True - ) -> None: + def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding diff --git a/egs/librispeech/ASR/conformer_ctc2/decode.py b/egs/librispeech/ASR/conformer_ctc2/decode.py index 97f2f2d39..0b271a51c 100755 --- a/egs/librispeech/ASR/conformer_ctc2/decode.py +++ b/egs/librispeech/ASR/conformer_ctc2/decode.py @@ -658,9 +658,7 @@ def decode_dataset( results[lm_scale].extend(this_batch) else: - assert ( - len(results) > 0 - ), "It should not decode to empty in the first batch!" + assert len(results) > 0, "It should not decode to empty in the first batch!" this_batch = [] hyp_words = [] for ref_text in texts: @@ -675,9 +673,7 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -709,9 +705,7 @@ def save_results( test_set_wers[key] = wer if enable_log: - logging.info( - "Wrote detailed error stats to {}".format(errs_filename) - ) + logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" @@ -852,9 +846,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -881,9 +875,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -985,9 +979,7 @@ def main(): eos_id=eos_id, ) - save_results( - params=params, test_set_name=test_set, results_dict=results_dict - ) + save_results(params=params, test_set_name=test_set, results_dict=results_dict) logging.info("Done!") diff --git a/egs/librispeech/ASR/conformer_ctc2/export.py b/egs/librispeech/ASR/conformer_ctc2/export.py index 584b3c3fc..7892b03c6 100755 --- a/egs/librispeech/ASR/conformer_ctc2/export.py +++ b/egs/librispeech/ASR/conformer_ctc2/export.py @@ -47,6 +47,7 @@ import logging from pathlib import Path import torch +from conformer import Conformer from decode import get_params from icefall.checkpoint import ( @@ -55,10 +56,8 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from conformer import Conformer - -from icefall.utils import str2bool from icefall.lexicon import Lexicon +from icefall.utils import str2bool def get_parser(): @@ -177,9 +176,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -206,9 +205,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -273,9 +272,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/conformer_ctc2/train.py b/egs/librispeech/ASR/conformer_ctc2/train.py index 18fa3e69f..ceea0c22c 100755 --- a/egs/librispeech/ASR/conformer_ctc2/train.py +++ b/egs/librispeech/ASR/conformer_ctc2/train.py @@ -69,8 +69,8 @@ from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter -from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler from icefall import diagnostics +from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler from icefall.checkpoint import load_checkpoint, remove_checkpoints from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.checkpoint import ( @@ -89,9 +89,7 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def get_parser(): @@ -498,11 +496,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -531,9 +525,7 @@ def compute_loss( # Works with a phone lexicon decoding_graph = graph_compiler.compile(texts) else: - raise ValueError( - f"Unsupported type of graph compiler: {type(graph_compiler)}" - ) + raise ValueError(f"Unsupported type of graph compiler: {type(graph_compiler)}") dense_fsa_vec = k2.DenseFsaVec( nnet_output, @@ -560,9 +552,7 @@ def compute_loss( # # See https://github.com/k2-fsa/icefall/issues/97 # for more details - unsorted_token_ids = graph_compiler.texts_to_ids( - supervisions["text"] - ) + unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) att_loss = mmodel.decoder_forward( encoder_memory, memory_mask, @@ -580,9 +570,7 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() info["ctc_loss"] = ctc_loss.detach().cpu().item() if params.att_rate != 0.0: info["att_loss"] = att_loss.detach().cpu().item() @@ -776,9 +764,9 @@ def train_one_epoch( f"tot_loss[{tot_loss}], batch size: {batch_size}, " f"lr: {cur_lr:.2e}" ) - if loss_info["ctc_loss"] == float("inf") or loss_info[ - "att_loss" - ] == float("inf"): + if loss_info["ctc_loss"] == float("inf") or loss_info["att_loss"] == float( + "inf" + ): logging.error( "Your loss contains inf, something goes wrong" f"failing batch names {batch_name}" @@ -791,9 +779,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") diff --git a/egs/librispeech/ASR/conformer_ctc2/transformer.py b/egs/librispeech/ASR/conformer_ctc2/transformer.py index 3ef7edc23..d3443dc94 100644 --- a/egs/librispeech/ASR/conformer_ctc2/transformer.py +++ b/egs/librispeech/ASR/conformer_ctc2/transformer.py @@ -21,19 +21,17 @@ from typing import Dict, List, Optional, Tuple import torch import torch.nn as nn -from label_smoothing import LabelSmoothingLoss -from subsampling import Conv2dSubsampling from attention import MultiheadAttention -from torch.nn.utils.rnn import pad_sequence - +from label_smoothing import LabelSmoothingLoss from scaling import ( ActivationBalancer, BasicNorm, DoubleSwish, - ScaledLinear, ScaledEmbedding, + ScaledLinear, ) - +from subsampling import Conv2dSubsampling +from torch.nn.utils.rnn import pad_sequence # Note: TorchScript requires Dict/List/etc. to be fully typed. Supervisions = Dict[str, torch.Tensor] @@ -210,9 +208,7 @@ class Transformer(nn.Module): x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) mask = encoder_padding_mask(x.size(0), supervisions) mask = mask.to(x.device) if mask is not None else None - x = self.encoder( - x, src_key_padding_mask=mask, warmup=warmup - ) # (T, N, C) + x = self.encoder(x, src_key_padding_mask=mask, warmup=warmup) # (T, N, C) return x, mask @@ -261,23 +257,17 @@ class Transformer(nn.Module): """ ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, padding_value=float(eos_id) - ) + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) device = memory.device ys_in_pad = ys_in_pad.to(device) ys_out_pad = ys_out_pad.to(device) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -338,23 +328,17 @@ class Transformer(nn.Module): ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, padding_value=float(eos_id) - ) + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) device = memory.device ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -959,9 +943,7 @@ def encoder_padding_mask( 1, ).to(torch.int32) - lengths = [ - 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) - ] + lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] for idx in range(supervision_segments.size(0)): # Note: TorchScript doesn't allow to unpack tensors as tuples sequence_idx = supervision_segments[idx, 0].item() @@ -982,9 +964,7 @@ def encoder_padding_mask( return mask -def decoder_padding_mask( - ys_pad: torch.Tensor, ignore_id: int = -1 -) -> torch.Tensor: +def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: """Generate a length mask for input. The masked position are filled with True, diff --git a/egs/librispeech/ASR/conformer_mmi/conformer.py b/egs/librispeech/ASR/conformer_mmi/conformer.py index 97c8d83a2..53e48eb13 100644 --- a/egs/librispeech/ASR/conformer_mmi/conformer.py +++ b/egs/librispeech/ASR/conformer_mmi/conformer.py @@ -156,9 +156,7 @@ class ConformerEncoderLayer(nn.Module): normalize_before: bool = True, ) -> None: super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), @@ -176,18 +174,14 @@ class ConformerEncoderLayer(nn.Module): self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) - self.norm_ff_macaron = nn.LayerNorm( - d_model - ) # for the macaron style FNN module + self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.ff_scale = 0.5 self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm( - d_model - ) # for the final output of the block + self.norm_final = nn.LayerNorm(d_model) # for the final output of the block self.dropout = nn.Dropout(dropout) @@ -221,9 +215,7 @@ class ConformerEncoderLayer(nn.Module): residual = src if self.normalize_before: src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout( - self.feed_forward_macaron(src) - ) + src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) if not self.normalize_before: src = self.norm_ff_macaron(src) @@ -342,9 +334,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -360,9 +350,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -632,9 +620,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -702,31 +690,22 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) @@ -765,9 +744,7 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -779,9 +756,7 @@ class RelPositionMultiheadAttention(nn.Module): matrix_ac + matrix_bd ) * scaling # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -815,13 +790,9 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -844,9 +815,7 @@ class ConvolutionModule(nn.Module): """ - def __init__( - self, channels: int, kernel_size: int, bias: bool = True - ) -> None: + def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding diff --git a/egs/librispeech/ASR/conformer_mmi/decode.py b/egs/librispeech/ASR/conformer_mmi/decode.py index fc9861489..e3c7b685f 100755 --- a/egs/librispeech/ASR/conformer_mmi/decode.py +++ b/egs/librispeech/ASR/conformer_mmi/decode.py @@ -478,9 +478,7 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -512,9 +510,7 @@ def save_results( test_set_wers[key] = wer if enable_log: - logging.info( - "Wrote detailed error stats to {}".format(errs_filename) - ) + logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" @@ -653,9 +649,7 @@ def main(): if params.export: logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") - torch.save( - {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" - ) + torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt") return model.to(device) @@ -687,9 +681,7 @@ def main(): eos_id=eos_id, ) - save_results( - params=params, test_set_name=test_set, results_dict=results_dict - ) + save_results(params=params, test_set_name=test_set, results_dict=results_dict) logging.info("Done!") diff --git a/egs/librispeech/ASR/conformer_mmi/subsampling.py b/egs/librispeech/ASR/conformer_mmi/subsampling.py index 5c3e1222e..ad9415987 100644 --- a/egs/librispeech/ASR/conformer_mmi/subsampling.py +++ b/egs/librispeech/ASR/conformer_mmi/subsampling.py @@ -25,13 +25,9 @@ class Conv2dSubsampling(nn.Module): assert idim >= 7 super().__init__() self.conv = nn.Sequential( - nn.Conv2d( - in_channels=1, out_channels=odim, kernel_size=3, stride=2 - ), + nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2), nn.ReLU(), - nn.Conv2d( - in_channels=odim, out_channels=odim, kernel_size=3, stride=2 - ), + nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2), nn.ReLU(), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) @@ -115,17 +111,13 @@ class VggSubsampling(nn.Module): ) ) layers.append( - torch.nn.MaxPool2d( - kernel_size=2, stride=2, padding=0, ceil_mode=True - ) + torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True) ) cur_channels = block_dim self.layers = nn.Sequential(*layers) - self.out = nn.Linear( - block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim - ) + self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim) def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. diff --git a/egs/librispeech/ASR/conformer_mmi/test_subsampling.py b/egs/librispeech/ASR/conformer_mmi/test_subsampling.py index 937845d77..d0bb017dd 100755 --- a/egs/librispeech/ASR/conformer_mmi/test_subsampling.py +++ b/egs/librispeech/ASR/conformer_mmi/test_subsampling.py @@ -1,8 +1,7 @@ #!/usr/bin/env python3 -from subsampling import Conv2dSubsampling -from subsampling import VggSubsampling import torch +from subsampling import Conv2dSubsampling, VggSubsampling def test_conv2d_subsampling(): diff --git a/egs/librispeech/ASR/conformer_mmi/test_transformer.py b/egs/librispeech/ASR/conformer_mmi/test_transformer.py index 08e680607..25d18076d 100644 --- a/egs/librispeech/ASR/conformer_mmi/test_transformer.py +++ b/egs/librispeech/ASR/conformer_mmi/test_transformer.py @@ -1,17 +1,16 @@ #!/usr/bin/env python3 import torch +from torch.nn.utils.rnn import pad_sequence from transformer import ( Transformer, + add_eos, + add_sos, + decoder_padding_mask, encoder_padding_mask, generate_square_subsequent_mask, - decoder_padding_mask, - add_sos, - add_eos, ) -from torch.nn.utils.rnn import pad_sequence - def test_encoder_padding_mask(): supervisions = { diff --git a/egs/librispeech/ASR/conformer_mmi/train-with-attention.py b/egs/librispeech/ASR/conformer_mmi/train-with-attention.py index 011dadd73..f8c94cff9 100755 --- a/egs/librispeech/ASR/conformer_mmi/train-with-attention.py +++ b/egs/librispeech/ASR/conformer_mmi/train-with-attention.py @@ -36,23 +36,14 @@ from torch.nn.utils import clip_grad_norm_ from torch.utils.tensorboard import SummaryWriter from transformer import Noam -from icefall.ali import ( - convert_alignments_to_tensor, - load_alignments, - lookup_alignments, -) +from icefall.ali import convert_alignments_to_tensor, load_alignments, lookup_alignments from icefall.checkpoint import load_checkpoint from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.dist import cleanup_dist, setup_dist from icefall.lexicon import Lexicon from icefall.mmi import LFMMILoss from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler -from icefall.utils import ( - AttributeDict, - encode_supervisions, - setup_logger, - str2bool, -) +from icefall.utils import AttributeDict, encode_supervisions, setup_logger, str2bool def get_parser(): @@ -370,10 +361,7 @@ def compute_loss( nnet_output = nnet_output.clone() nnet_output[:, :min_len, :] += ali_scale * mask[:, :min_len, :] - if ( - params.batch_idx_train > params.use_ali_until - and params.beam_size < 8 - ): + if params.batch_idx_train > params.use_ali_until and params.beam_size < 8: # logging.info("Change beam size to 8") params.beam_size = 8 else: @@ -762,19 +750,14 @@ def run(rank, world_size, args): for epoch in range(params.start_epoch, params.num_epochs): train_dl.sampler.set_epoch(epoch) - if ( - params.batch_idx_train >= params.use_ali_until - and train_ali is not None - ): + if params.batch_idx_train >= params.use_ali_until and train_ali is not None: # Delete the alignments to save memory train_ali = None valid_ali = None cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/librispeech/ASR/conformer_mmi/train.py b/egs/librispeech/ASR/conformer_mmi/train.py index 9a5bdcce2..5cfb2bfc7 100755 --- a/egs/librispeech/ASR/conformer_mmi/train.py +++ b/egs/librispeech/ASR/conformer_mmi/train.py @@ -36,23 +36,14 @@ from torch.nn.utils import clip_grad_norm_ from torch.utils.tensorboard import SummaryWriter from transformer import Noam -from icefall.ali import ( - convert_alignments_to_tensor, - load_alignments, - lookup_alignments, -) +from icefall.ali import convert_alignments_to_tensor, load_alignments, lookup_alignments from icefall.checkpoint import load_checkpoint from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.dist import cleanup_dist, setup_dist from icefall.lexicon import Lexicon from icefall.mmi import LFMMILoss from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler -from icefall.utils import ( - AttributeDict, - encode_supervisions, - setup_logger, - str2bool, -) +from icefall.utils import AttributeDict, encode_supervisions, setup_logger, str2bool def get_parser(): @@ -377,10 +368,7 @@ def compute_loss( nnet_output = nnet_output.clone() nnet_output[:, :min_len, :] += ali_scale * mask[:, :min_len, :] - if ( - params.batch_idx_train > params.use_ali_until - and params.beam_size < 8 - ): + if params.batch_idx_train > params.use_ali_until and params.beam_size < 8: logging.info("Change beam size to 8") params.beam_size = 8 else: @@ -770,19 +758,14 @@ def run(rank, world_size, args): for epoch in range(params.start_epoch, params.num_epochs): fix_random_seed(params.seed + epoch) train_dl.sampler.set_epoch(epoch) - if ( - params.batch_idx_train >= params.use_ali_until - and train_ali is not None - ): + if params.batch_idx_train >= params.use_ali_until and train_ali is not None: # Delete the alignments to save memory train_ali = None valid_ali = None cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/librispeech/ASR/conformer_mmi/transformer.py b/egs/librispeech/ASR/conformer_mmi/transformer.py index 68a4ff65c..2542d9abe 100644 --- a/egs/librispeech/ASR/conformer_mmi/transformer.py +++ b/egs/librispeech/ASR/conformer_mmi/transformer.py @@ -148,9 +148,7 @@ class Transformer(nn.Module): norm=decoder_norm, ) - self.decoder_output_layer = torch.nn.Linear( - d_model, self.decoder_num_class - ) + self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class) self.decoder_criterion = LabelSmoothingLoss(self.decoder_num_class) else: @@ -182,9 +180,7 @@ class Transformer(nn.Module): x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) x = self.feat_batchnorm(x) x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) - encoder_memory, memory_key_padding_mask = self.run_encoder( - x, supervision - ) + encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision) x = self.ctc_output(encoder_memory) return x, encoder_memory, memory_key_padding_mask @@ -274,9 +270,7 @@ class Transformer(nn.Module): ys_in_pad = ys_in_pad.to(device) ys_out_pad = ys_out_pad.to(device) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -341,9 +335,7 @@ class Transformer(nn.Module): ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -616,9 +608,7 @@ def _get_activation_fn(activation: str): elif activation == "gelu": return nn.functional.gelu - raise RuntimeError( - "activation should be relu/gelu, not {}".format(activation) - ) + raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) class PositionalEncoding(nn.Module): @@ -887,9 +877,7 @@ def encoder_padding_mask( 1, ).to(torch.int32) - lengths = [ - 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) - ] + lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] for idx in range(supervision_segments.size(0)): # Note: TorchScript doesn't allow to unpack tensors as tuples sequence_idx = supervision_segments[idx, 0].item() @@ -910,9 +898,7 @@ def encoder_padding_mask( return mask -def decoder_padding_mask( - ys_pad: torch.Tensor, ignore_id: int = -1 -) -> torch.Tensor: +def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: """Generate a length mask for input. The masked position are filled with True, diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py index 620d69a19..6854c82d8 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py @@ -215,8 +215,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -284,9 +283,7 @@ def decode_one_batch( value=LOG_EPS, ) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -301,10 +298,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -427,9 +421,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -462,8 +454,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -506,9 +497,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -540,9 +529,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -569,9 +558,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py index 8ca7d5568..1aaa3b9cb 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py @@ -35,7 +35,6 @@ from scaling import ( from icefall.utils import make_pad_mask - LOG_EPSILON = math.log(1e-10) @@ -127,9 +126,7 @@ def stack_states( for si, s in enumerate(layer): attn_caches[li][si].append(s) if b == batch_size - 1: - attn_caches[li][si] = torch.stack( - attn_caches[li][si], dim=1 - ) + attn_caches[li][si] = torch.stack(attn_caches[li][si], dim=1) conv_caches = [] for layer in state_list[0][1]: @@ -268,9 +265,7 @@ class ConvolutionModule(nn.Module): intervals = torch.arange( 0, self.chunk_length * (num_chunks - 1), self.chunk_length ) - first = torch.arange( - self.chunk_length, self.chunk_length + self.cache_size - ) + first = torch.arange(self.chunk_length, self.chunk_length + self.cache_size) indexes = intervals.unsqueeze(1) + first.unsqueeze(0) indexes = torch.cat( [indexes, torch.arange(U_ - self.cache_size, U_).unsqueeze(0)] @@ -284,9 +279,7 @@ class ConvolutionModule(nn.Module): # (num_chunks * B, cache_size + right_context_length, D) return pad_right_context.permute(0, 2, 1) - def _merge_right_context( - self, right_context: torch.Tensor, B: int - ) -> torch.Tensor: + def _merge_right_context(self, right_context: torch.Tensor, B: int) -> torch.Tensor: """ Args: right_context: @@ -337,12 +330,8 @@ class ConvolutionModule(nn.Module): right_context = x[:, :, :R] # (B, D, R) # make causal convolution - cache = torch.zeros( - B, D, self.cache_size, device=x.device, dtype=x.dtype - ) - pad_utterance = torch.cat( - [cache, utterance], dim=2 - ) # (B, D, cache + U) + cache = torch.zeros(B, D, self.cache_size, device=x.device, dtype=x.dtype) + pad_utterance = torch.cat([cache, utterance], dim=2) # (B, D, cache + U) # depth-wise conv on utterance utterance = self.depthwise_conv(pad_utterance) # (B, D, U) @@ -355,9 +344,7 @@ class ConvolutionModule(nn.Module): right_context = self.depthwise_conv( pad_right_context ) # (num_segs * B, D, right_context_length) - right_context = self._merge_right_context( - right_context, B - ) # (B, D, R) + right_context = self._merge_right_context(right_context, B) # (B, D, R) x = torch.cat([right_context, utterance], dim=2) # (B, D, R + U) x = self.deriv_balancer2(x) @@ -458,8 +445,7 @@ class EmformerAttention(nn.Module): if embed_dim % nhead != 0: raise ValueError( - f"embed_dim ({embed_dim}) is not a multiple of" - f"nhead ({nhead})." + f"embed_dim ({embed_dim}) is not a multiple of" f"nhead ({nhead})." ) self.embed_dim = embed_dim @@ -469,9 +455,7 @@ class EmformerAttention(nn.Module): self.head_dim = embed_dim // nhead self.dropout = dropout - self.emb_to_key_value = ScaledLinear( - embed_dim, 2 * embed_dim, bias=True - ) + self.emb_to_key_value = ScaledLinear(embed_dim, 2 * embed_dim, bias=True) self.emb_to_query = ScaledLinear(embed_dim, embed_dim, bias=True) self.out_proj = ScaledLinear( embed_dim, embed_dim, bias=True, initial_scale=0.25 @@ -513,9 +497,7 @@ class EmformerAttention(nn.Module): if padding_mask is not None: Q = attention_weights.size(1) B = attention_weights.size(0) // self.nhead - attention_weights_float = attention_weights_float.view( - B, self.nhead, Q, -1 - ) + attention_weights_float = attention_weights_float.view(B, self.nhead, Q, -1) attention_weights_float = attention_weights_float.masked_fill( padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), self.negative_inf, @@ -551,9 +533,7 @@ class EmformerAttention(nn.Module): scaling = float(self.head_dim) ** -0.5 # compute query with [right_context, utterance, summary]. - query = self.emb_to_query( - torch.cat([right_context, utterance, summary]) - ) + query = self.emb_to_query(torch.cat([right_context, utterance, summary])) # compute key and value with [memory, right_context, utterance]. key, value = self.emb_to_key_value( torch.cat([memory, right_context, utterance]) @@ -564,16 +544,12 @@ class EmformerAttention(nn.Module): # [memory, right context, left context, uttrance] # this is used in inference mode key = torch.cat([key[: M + R], left_context_key, key[M + R :]]) - value = torch.cat( - [value[: M + R], left_context_val, value[M + R :]] - ) + value = torch.cat([value[: M + R], left_context_val, value[M + R :]]) Q = query.size(0) # KV = key.size(0) reshaped_query, reshaped_key, reshaped_value = [ - tensor.contiguous() - .view(-1, B * self.nhead, self.head_dim) - .transpose(0, 1) + tensor.contiguous().view(-1, B * self.nhead, self.head_dim).transpose(0, 1) for tensor in [query, key, value] ] # (B * nhead, Q or KV, head_dim) attention_weights = torch.bmm( @@ -588,9 +564,7 @@ class EmformerAttention(nn.Module): # compute attention outputs attention = torch.bmm(attention_probs, reshaped_value) assert attention.shape == (B * self.nhead, Q, self.head_dim) - attention = ( - attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim) - ) + attention = attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim) # apply output projection outputs = self.out_proj(attention) @@ -672,12 +646,7 @@ class EmformerAttention(nn.Module): - output of right context and utterance, with shape (R + U, B, D). - memory output, with shape (M, B, D), where M = S - 1 or M = 0. """ - ( - output_right_context_utterance, - output_memory, - _, - _, - ) = self._forward_impl( + (output_right_context_utterance, output_memory, _, _,) = self._forward_impl( utterance, right_context, summary, @@ -947,13 +916,9 @@ class EmformerEncoderLayer(nn.Module): right_context = right_context_utterance[:R] if self.use_memory: - summary = self.summary_op(utterance.permute(1, 2, 0)).permute( - 2, 0, 1 - ) + summary = self.summary_op(utterance.permute(1, 2, 0)).permute(2, 0, 1) else: - summary = torch.empty(0).to( - dtype=utterance.dtype, device=utterance.device - ) + summary = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device) output_right_context_utterance, output_memory = self.attention( utterance=utterance, right_context=right_context, @@ -992,14 +957,10 @@ class EmformerEncoderLayer(nn.Module): left_context_val = attn_cache[2] if self.use_memory: - summary = self.summary_op(utterance.permute(1, 2, 0)).permute( - 2, 0, 1 - ) + summary = self.summary_op(utterance.permute(1, 2, 0)).permute(2, 0, 1) summary = summary[:1] else: - summary = torch.empty(0).to( - dtype=utterance.dtype, device=utterance.device - ) + summary = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device) ( output_right_context_utterance, output_memory, @@ -1014,9 +975,7 @@ class EmformerEncoderLayer(nn.Module): left_context_val=left_context_val, padding_mask=padding_mask, ) - attn_cache = self._update_attn_cache( - next_key, next_val, memory, attn_cache - ) + attn_cache = self._update_attn_cache(next_key, next_val, memory, attn_cache) return output_right_context_utterance, output_memory, attn_cache def forward( @@ -1151,11 +1110,7 @@ class EmformerEncoderLayer(nn.Module): src = src + self.dropout(self.feed_forward_macaron(src)) # emformer attention module - ( - src_att, - output_memory, - attn_cache, - ) = self._apply_attention_module_infer( + (src_att, output_memory, attn_cache,) = self._apply_attention_module_infer( src, R, memory, attn_cache, padding_mask=padding_mask ) src = src + self.dropout(src_att) @@ -1295,9 +1250,7 @@ class EmformerEncoder(nn.Module): def _gen_right_context(self, x: torch.Tensor) -> torch.Tensor: """Hard copy each chunk's right context and concat them.""" T = x.shape[0] - num_chunks = math.ceil( - (T - self.right_context_length) / self.chunk_length - ) + num_chunks = math.ceil((T - self.right_context_length) / self.chunk_length) # first (num_chunks - 1) right context block intervals = torch.arange( 0, self.chunk_length * (num_chunks - 1), self.chunk_length @@ -1316,9 +1269,7 @@ class EmformerEncoder(nn.Module): right_context_blocks = x[indexes.reshape(-1)] return right_context_blocks - def _gen_attention_mask_col_widths( - self, chunk_idx: int, U: int - ) -> List[int]: + def _gen_attention_mask_col_widths(self, chunk_idx: int, U: int) -> List[int]: """Calculate column widths (key, value) in attention mask for the chunk_idx chunk.""" num_chunks = math.ceil(U / self.chunk_length) @@ -1479,9 +1430,7 @@ class EmformerEncoder(nn.Module): output_lengths = torch.clamp(lengths - self.right_context_length, min=0) attention_mask = self._gen_attention_mask(utterance) memory = ( - self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[ - :-1 - ] + self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[:-1] if self.use_memory else torch.empty(0).to(dtype=x.dtype, device=x.device) ) @@ -1643,12 +1592,8 @@ class EmformerEncoder(nn.Module): attn_caches = [ [ torch.zeros(self.memory_size, self.d_model, device=device), - torch.zeros( - self.left_context_length, self.d_model, device=device - ), - torch.zeros( - self.left_context_length, self.d_model, device=device - ), + torch.zeros(self.left_context_length, self.d_model, device=device), + torch.zeros(self.left_context_length, self.d_model, device=device), ] for _ in range(self.num_encoder_layers) ] @@ -1693,17 +1638,11 @@ class Emformer(EncoderInterface): raise NotImplementedError( "chunk_length must be a mutiple of subsampling_factor." ) - if ( - left_context_length != 0 - and left_context_length % subsampling_factor != 0 - ): + if left_context_length != 0 and left_context_length % subsampling_factor != 0: raise NotImplementedError( "left_context_length must be 0 or a mutiple of subsampling_factor." # noqa ) - if ( - right_context_length != 0 - and right_context_length % subsampling_factor != 0 - ): + if right_context_length != 0 and right_context_length % subsampling_factor != 0: raise NotImplementedError( "right_context_length must be 0 or a mutiple of subsampling_factor." # noqa ) @@ -1766,9 +1705,7 @@ class Emformer(EncoderInterface): x_lens = (((x_lens - 1) >> 1) - 1) >> 1 assert x.size(0) == x_lens.max().item() - output, output_lengths = self.encoder( - x, x_lens, warmup=warmup - ) # (T, N, C) + output, output_lengths = self.encoder(x, x_lens, warmup=warmup) # (T, N, C) output = output.permute(1, 0, 2) # (T, N, C) -> (N, T, C) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py index 4930881ea..334682ad6 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py @@ -136,8 +136,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -181,9 +180,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -210,9 +209,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -279,9 +278,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py index 9494e1fc1..c211b215e 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py @@ -68,14 +68,12 @@ class Stream(object): elif params.decoding_method == "fast_beam_search": # feature_len is needed to get partial results. # The rnnt_decoding_stream for fast_beam_search. - self.rnnt_decoding_stream: k2.RnntDecodingStream = ( - k2.RnntDecodingStream(decoding_graph) + self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream( + decoding_graph ) self.hyp: Optional[List[int]] = None else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") self.ground_truth: str = "" diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py index 61dbe8658..621eeb952 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py @@ -211,8 +211,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -371,9 +370,7 @@ def modified_beam_search( index=hyps_shape.row_ids(1).to(torch.int64), ) # (num_hyps, encoder_out_dim) - logits = model.joiner( - current_encoder_out, decoder_out, project_input=False - ) + logits = model.joiner(current_encoder_out, decoder_out, project_input=False) # logits is of shape (num_hyps, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) @@ -390,9 +387,7 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -551,14 +546,10 @@ def decode_one_chunk( feature_list, batch_first=True, padding_value=LOG_EPSILON ).to(device) feature_lens = torch.tensor(feature_len_list, device=device) - num_processed_frames = torch.tensor( - num_processed_frames_list, device=device - ) + num_processed_frames = torch.tensor(num_processed_frames_list, device=device) # Make sure it has at least 1 frame after subsampling, first-and-last-frame cutting, and right context cutting # noqa - tail_length = ( - 3 * params.subsampling_factor + params.right_context_length + 3 - ) + tail_length = 3 * params.subsampling_factor + params.right_context_length + 3 if features.size(1) < tail_length: pad_length = tail_length - features.size(1) feature_lens += pad_length @@ -605,9 +596,7 @@ def decode_one_chunk( max_states=params.max_states, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") # Update cached states of each stream state_list = unstack_states(states) @@ -782,8 +771,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -831,9 +819,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -867,9 +853,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -896,9 +882,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py index c07d8f76b..3d8d4a18a 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py @@ -95,9 +95,7 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -265,8 +263,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -289,8 +286,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -636,11 +632,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -668,23 +660,16 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -871,9 +856,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -981,7 +964,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py index 98b8290b5..d3c001942 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py @@ -215,8 +215,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -284,9 +283,7 @@ def decode_one_batch( value=LOG_EPS, ) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -301,10 +298,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -427,9 +421,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -462,8 +454,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -506,9 +497,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -540,9 +529,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -569,9 +558,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py index f16f5acc7..c3739566f 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py @@ -35,7 +35,6 @@ from scaling import ( from icefall.utils import make_pad_mask - LOG_EPSILON = math.log(1e-10) @@ -127,9 +126,7 @@ def stack_states( for si, s in enumerate(layer): attn_caches[li][si].append(s) if b == batch_size - 1: - attn_caches[li][si] = torch.stack( - attn_caches[li][si], dim=1 - ) + attn_caches[li][si] = torch.stack(attn_caches[li][si], dim=1) conv_caches = [] for layer in state_list[0][1]: @@ -268,9 +265,7 @@ class ConvolutionModule(nn.Module): intervals = torch.arange( 0, self.chunk_length * (num_chunks - 1), self.chunk_length ) - first = torch.arange( - self.chunk_length, self.chunk_length + self.cache_size - ) + first = torch.arange(self.chunk_length, self.chunk_length + self.cache_size) indexes = intervals.unsqueeze(1) + first.unsqueeze(0) indexes = torch.cat( [indexes, torch.arange(U_ - self.cache_size, U_).unsqueeze(0)] @@ -284,9 +279,7 @@ class ConvolutionModule(nn.Module): # (num_chunks * B, cache_size + right_context_length, D) return pad_right_context.permute(0, 2, 1) - def _merge_right_context( - self, right_context: torch.Tensor, B: int - ) -> torch.Tensor: + def _merge_right_context(self, right_context: torch.Tensor, B: int) -> torch.Tensor: """ Args: right_context: @@ -337,12 +330,8 @@ class ConvolutionModule(nn.Module): right_context = x[:, :, :R] # (B, D, R) # make causal convolution - cache = torch.zeros( - B, D, self.cache_size, device=x.device, dtype=x.dtype - ) - pad_utterance = torch.cat( - [cache, utterance], dim=2 - ) # (B, D, cache + U) + cache = torch.zeros(B, D, self.cache_size, device=x.device, dtype=x.dtype) + pad_utterance = torch.cat([cache, utterance], dim=2) # (B, D, cache + U) # depth-wise conv on utterance utterance = self.depthwise_conv(pad_utterance) # (B, D, U) @@ -355,9 +344,7 @@ class ConvolutionModule(nn.Module): right_context = self.depthwise_conv( pad_right_context ) # (num_segs * B, D, right_context_length) - right_context = self._merge_right_context( - right_context, B - ) # (B, D, R) + right_context = self._merge_right_context(right_context, B) # (B, D, R) x = torch.cat([right_context, utterance], dim=2) # (B, D, R + U) x = self.deriv_balancer2(x) @@ -458,8 +445,7 @@ class EmformerAttention(nn.Module): if embed_dim % nhead != 0: raise ValueError( - f"embed_dim ({embed_dim}) is not a multiple of" - f"nhead ({nhead})." + f"embed_dim ({embed_dim}) is not a multiple of" f"nhead ({nhead})." ) self.embed_dim = embed_dim @@ -469,9 +455,7 @@ class EmformerAttention(nn.Module): self.head_dim = embed_dim // nhead self.dropout = dropout - self.emb_to_key_value = ScaledLinear( - embed_dim, 2 * embed_dim, bias=True - ) + self.emb_to_key_value = ScaledLinear(embed_dim, 2 * embed_dim, bias=True) self.emb_to_query = ScaledLinear(embed_dim, embed_dim, bias=True) self.out_proj = ScaledLinear( embed_dim, embed_dim, bias=True, initial_scale=0.25 @@ -513,9 +497,7 @@ class EmformerAttention(nn.Module): if padding_mask is not None: Q = attention_weights.size(1) B = attention_weights.size(0) // self.nhead - attention_weights_float = attention_weights_float.view( - B, self.nhead, Q, -1 - ) + attention_weights_float = attention_weights_float.view(B, self.nhead, Q, -1) attention_weights_float = attention_weights_float.masked_fill( padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), self.negative_inf, @@ -561,16 +543,12 @@ class EmformerAttention(nn.Module): # [memory, right context, left context, uttrance] # this is used in inference mode key = torch.cat([key[: M + R], left_context_key, key[M + R :]]) - value = torch.cat( - [value[: M + R], left_context_val, value[M + R :]] - ) + value = torch.cat([value[: M + R], left_context_val, value[M + R :]]) Q = query.size(0) # KV = key.size(0) reshaped_query, reshaped_key, reshaped_value = [ - tensor.contiguous() - .view(-1, B * self.nhead, self.head_dim) - .transpose(0, 1) + tensor.contiguous().view(-1, B * self.nhead, self.head_dim).transpose(0, 1) for tensor in [query, key, value] ] # (B * nhead, Q or KV, head_dim) attention_weights = torch.bmm( @@ -585,9 +563,7 @@ class EmformerAttention(nn.Module): # compute attention outputs attention = torch.bmm(attention_probs, reshaped_value) assert attention.shape == (B * self.nhead, Q, self.head_dim) - attention = ( - attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim) - ) + attention = attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim) # apply output projection output_right_context_utterance = self.out_proj(attention) @@ -905,13 +881,11 @@ class EmformerEncoderLayer(nn.Module): right_context = right_context_utterance[:R] if self.use_memory: - memory = self.summary_op(utterance.permute(1, 2, 0)).permute( - 2, 0, 1 - )[:-1, :, :] + memory = self.summary_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[ + :-1, :, : + ] else: - memory = torch.empty(0).to( - dtype=utterance.dtype, device=utterance.device - ) + memory = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device) output_right_context_utterance = self.attention( utterance=utterance, right_context=right_context, @@ -948,18 +922,12 @@ class EmformerEncoderLayer(nn.Module): left_context_val = attn_cache[2] if self.use_memory: - memory = self.summary_op(utterance.permute(1, 2, 0)).permute( - 2, 0, 1 - )[:1, :, :] + memory = self.summary_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[ + :1, :, : + ] else: - memory = torch.empty(0).to( - dtype=utterance.dtype, device=utterance.device - ) - ( - output_right_context_utterance, - next_key, - next_val, - ) = self.attention.infer( + memory = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device) + (output_right_context_utterance, next_key, next_val,) = self.attention.infer( utterance=utterance, right_context=right_context, memory=pre_memory, @@ -967,9 +935,7 @@ class EmformerEncoderLayer(nn.Module): left_context_val=left_context_val, padding_mask=padding_mask, ) - attn_cache = self._update_attn_cache( - next_key, next_val, memory, attn_cache - ) + attn_cache = self._update_attn_cache(next_key, next_val, memory, attn_cache) return output_right_context_utterance, attn_cache def forward( @@ -1226,9 +1192,7 @@ class EmformerEncoder(nn.Module): def _gen_right_context(self, x: torch.Tensor) -> torch.Tensor: """Hard copy each chunk's right context and concat them.""" T = x.shape[0] - num_chunks = math.ceil( - (T - self.right_context_length) / self.chunk_length - ) + num_chunks = math.ceil((T - self.right_context_length) / self.chunk_length) # first (num_chunks - 1) right context block intervals = torch.arange( 0, self.chunk_length * (num_chunks - 1), self.chunk_length @@ -1247,9 +1211,7 @@ class EmformerEncoder(nn.Module): right_context_blocks = x[indexes.reshape(-1)] return right_context_blocks - def _gen_attention_mask_col_widths( - self, chunk_idx: int, U: int - ) -> List[int]: + def _gen_attention_mask_col_widths(self, chunk_idx: int, U: int) -> List[int]: """Calculate column widths (key, value) in attention mask for the chunk_idx chunk.""" num_chunks = math.ceil(U / self.chunk_length) @@ -1549,12 +1511,8 @@ class EmformerEncoder(nn.Module): attn_caches = [ [ torch.zeros(self.memory_size, self.d_model, device=device), - torch.zeros( - self.left_context_length, self.d_model, device=device - ), - torch.zeros( - self.left_context_length, self.d_model, device=device - ), + torch.zeros(self.left_context_length, self.d_model, device=device), + torch.zeros(self.left_context_length, self.d_model, device=device), ] for _ in range(self.num_encoder_layers) ] @@ -1599,17 +1557,11 @@ class Emformer(EncoderInterface): raise NotImplementedError( "chunk_length must be a mutiple of subsampling_factor." ) - if ( - left_context_length != 0 - and left_context_length % subsampling_factor != 0 - ): + if left_context_length != 0 and left_context_length % subsampling_factor != 0: raise NotImplementedError( "left_context_length must be 0 or a mutiple of subsampling_factor." # noqa ) - if ( - right_context_length != 0 - and right_context_length % subsampling_factor != 0 - ): + if right_context_length != 0 and right_context_length % subsampling_factor != 0: raise NotImplementedError( "right_context_length must be 0 or a mutiple of subsampling_factor." # noqa ) @@ -1672,9 +1624,7 @@ class Emformer(EncoderInterface): x_lens = (((x_lens - 1) >> 1) - 1) >> 1 assert x.size(0) == x_lens.max().item() - output, output_lengths = self.encoder( - x, x_lens, warmup=warmup - ) # (T, N, C) + output, output_lengths = self.encoder(x, x_lens, warmup=warmup) # (T, N, C) output = output.permute(1, 0, 2) # (T, N, C) -> (N, T, C) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py index ab15e0241..998fb6e81 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py @@ -136,8 +136,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -181,9 +180,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -210,9 +209,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -279,9 +278,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py index 71150392d..618d8bb63 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py @@ -211,8 +211,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -371,9 +370,7 @@ def modified_beam_search( index=hyps_shape.row_ids(1).to(torch.int64), ) # (num_hyps, encoder_out_dim) - logits = model.joiner( - current_encoder_out, decoder_out, project_input=False - ) + logits = model.joiner(current_encoder_out, decoder_out, project_input=False) # logits is of shape (num_hyps, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) @@ -390,9 +387,7 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -551,14 +546,10 @@ def decode_one_chunk( feature_list, batch_first=True, padding_value=LOG_EPSILON ).to(device) feature_lens = torch.tensor(feature_len_list, device=device) - num_processed_frames = torch.tensor( - num_processed_frames_list, device=device - ) + num_processed_frames = torch.tensor(num_processed_frames_list, device=device) # Make sure it has at least 1 frame after subsampling, first-and-last-frame cutting, and right context cutting # noqa - tail_length = ( - 3 * params.subsampling_factor + params.right_context_length + 3 - ) + tail_length = 3 * params.subsampling_factor + params.right_context_length + 3 if features.size(1) < tail_length: pad_length = tail_length - features.size(1) feature_lens += pad_length @@ -605,9 +596,7 @@ def decode_one_chunk( max_states=params.max_states, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") # Update cached states of each stream state_list = unstack_states(states) @@ -782,8 +771,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -831,9 +819,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -867,9 +853,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -896,9 +882,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py index 2bbc45d78..542f524a9 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py @@ -95,9 +95,7 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -265,8 +263,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -289,8 +286,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -636,11 +632,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -668,23 +660,16 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -871,9 +856,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -981,7 +964,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/local/add_alignment_librispeech.py b/egs/librispeech/ASR/local/add_alignment_librispeech.py index fe6a26c51..cc34a72d8 100755 --- a/egs/librispeech/ASR/local/add_alignment_librispeech.py +++ b/egs/librispeech/ASR/local/add_alignment_librispeech.py @@ -157,9 +157,7 @@ def add_alignment( for ali_path in part_ali_dir.rglob("*.alignment.txt"): ali = parse_alignments(ali_path) alignments.update(ali) - logging.info( - f"{part} has {len(alignments.keys())} cuts with alignments." - ) + logging.info(f"{part} has {len(alignments.keys())} cuts with alignments.") # add alignment attribute and write out cuts_in = load_manifest_lazy(cuts_in_path) @@ -170,18 +168,14 @@ def add_alignment( if origin_id in alignments: ali = alignments[origin_id] else: - logging.info( - f"Warning: {origin_id} does not have alignment." - ) + logging.info(f"Warning: {origin_id} does not have alignment.") ali = [] subcut.alignment = {"word": ali} writer.write(cut, flush=True) def main(): - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) parser = get_parser() diff --git a/egs/librispeech/ASR/local/compile_hlg.py b/egs/librispeech/ASR/local/compile_hlg.py index c628dfd53..df6c609bb 100755 --- a/egs/librispeech/ASR/local/compile_hlg.py +++ b/egs/librispeech/ASR/local/compile_hlg.py @@ -57,7 +57,7 @@ def get_args(): return parser.parse_args() -def compile_HLG(lang_dir: str, lm: str="G_3_gram") -> k2.Fsa: +def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa: """ Args: lang_dir: @@ -159,9 +159,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/compile_lg.py b/egs/librispeech/ASR/local/compile_lg.py index 45c4b7f5f..19bf3bff4 100755 --- a/egs/librispeech/ASR/local/compile_lg.py +++ b/egs/librispeech/ASR/local/compile_lg.py @@ -132,9 +132,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/compute_fbank_gigaspeech_dev_test.py b/egs/librispeech/ASR/local/compute_fbank_gigaspeech_dev_test.py index c0c7ef8c5..97750f3ea 100644 --- a/egs/librispeech/ASR/local/compute_fbank_gigaspeech_dev_test.py +++ b/egs/librispeech/ASR/local/compute_fbank_gigaspeech_dev_test.py @@ -80,9 +80,7 @@ def compute_fbank_gigaspeech_dev_test(): def main(): - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) compute_fbank_gigaspeech_dev_test() diff --git a/egs/librispeech/ASR/local/compute_fbank_gigaspeech_splits.py b/egs/librispeech/ASR/local/compute_fbank_gigaspeech_splits.py index 5587106e5..ce0ef24e7 100644 --- a/egs/librispeech/ASR/local/compute_fbank_gigaspeech_splits.py +++ b/egs/librispeech/ASR/local/compute_fbank_gigaspeech_splits.py @@ -144,9 +144,7 @@ def main(): date_time = now.strftime("%Y-%m-%d-%H-%M-%S") log_filename = "log-compute_fbank_gigaspeech_splits" - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" log_filename = f"{log_filename}-{date_time}" logging.basicConfig( diff --git a/egs/librispeech/ASR/local/compute_fbank_librispeech.py b/egs/librispeech/ASR/local/compute_fbank_librispeech.py index ce7d087f0..9f8503814 100755 --- a/egs/librispeech/ASR/local/compute_fbank_librispeech.py +++ b/egs/librispeech/ASR/local/compute_fbank_librispeech.py @@ -112,9 +112,7 @@ def compute_fbank_librispeech(bpe_model: Optional[str] = None): if "train" in partition: cut_set = ( - cut_set - + cut_set.perturb_speed(0.9) - + cut_set.perturb_speed(1.1) + cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) ) cut_set = cut_set.compute_and_store_features( extractor=extractor, @@ -128,9 +126,7 @@ def compute_fbank_librispeech(bpe_model: Optional[str] = None): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) args = get_args() diff --git a/egs/librispeech/ASR/local/compute_fbank_musan.py b/egs/librispeech/ASR/local/compute_fbank_musan.py index 056da29e5..4a4093ae4 100755 --- a/egs/librispeech/ASR/local/compute_fbank_musan.py +++ b/egs/librispeech/ASR/local/compute_fbank_musan.py @@ -83,9 +83,7 @@ def compute_fbank_musan(): # create chunks of Musan with duration 5 - 10 seconds musan_cuts = ( CutSet.from_manifests( - recordings=combine( - part["recordings"] for part in manifests.values() - ) + recordings=combine(part["recordings"] for part in manifests.values()) ) .cut_into_windows(10.0) .filter(lambda c: c.duration > 5) @@ -101,9 +99,7 @@ def compute_fbank_musan(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) compute_fbank_musan() diff --git a/egs/librispeech/ASR/local/convert_transcript_words_to_tokens.py b/egs/librispeech/ASR/local/convert_transcript_words_to_tokens.py index 133499c8b..a8d5117c9 100755 --- a/egs/librispeech/ASR/local/convert_transcript_words_to_tokens.py +++ b/egs/librispeech/ASR/local/convert_transcript_words_to_tokens.py @@ -51,16 +51,12 @@ def get_args(): "lines. Each line consists of space separated words.", ) parser.add_argument("--lexicon", type=str, help="The input lexicon file.") - parser.add_argument( - "--oov", type=str, default="", help="The OOV word." - ) + parser.add_argument("--oov", type=str, default="", help="The OOV word.") return parser.parse_args() -def process_line( - lexicon: Dict[str, List[str]], line: str, oov_token: str -) -> None: +def process_line(lexicon: Dict[str, List[str]], line: str, oov_token: str) -> None: """ Args: lexicon: diff --git a/egs/librispeech/ASR/local/download_lm.py b/egs/librispeech/ASR/local/download_lm.py index 030122aa7..3518db524 100755 --- a/egs/librispeech/ASR/local/download_lm.py +++ b/egs/librispeech/ASR/local/download_lm.py @@ -87,9 +87,7 @@ def main(out_dir: str): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/filter_cuts.py b/egs/librispeech/ASR/local/filter_cuts.py index dff98a954..b3f0956c3 100644 --- a/egs/librispeech/ASR/local/filter_cuts.py +++ b/egs/librispeech/ASR/local/filter_cuts.py @@ -79,8 +79,7 @@ def filter_cuts(cut_set: CutSet, sp: spm.SentencePieceProcessor): total += 1 if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}" ) removed += 1 return False @@ -125,8 +124,7 @@ def filter_cuts(cut_set: CutSet, sp: spm.SentencePieceProcessor): ans = cut_set.filter(remove_short_and_long_utterances).to_eager() ratio = removed / total * 100 logging.info( - f"Removed {removed} cuts from {total} cuts. " - f"{ratio:.3f}% data is removed." + f"Removed {removed} cuts from {total} cuts. " f"{ratio:.3f}% data is removed." ) return ans @@ -155,9 +153,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/generate_unique_lexicon.py b/egs/librispeech/ASR/local/generate_unique_lexicon.py index 566c0743d..3459c2f5a 100755 --- a/egs/librispeech/ASR/local/generate_unique_lexicon.py +++ b/egs/librispeech/ASR/local/generate_unique_lexicon.py @@ -91,9 +91,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/prepare_lang_bpe.py b/egs/librispeech/ASR/local/prepare_lang_bpe.py index dec8a7442..e121aefa9 100755 --- a/egs/librispeech/ASR/local/prepare_lang_bpe.py +++ b/egs/librispeech/ASR/local/prepare_lang_bpe.py @@ -150,9 +150,7 @@ def generate_lexicon( words_pieces_ids: List[List[int]] = sp.encode(words, out_type=int) # Now convert word piece IDs back to word piece strings. - words_pieces: List[List[str]] = [ - sp.id_to_piece(ids) for ids in words_pieces_ids - ] + words_pieces: List[List[str]] = [sp.id_to_piece(ids) for ids in words_pieces_ids] lexicon = [] for word, pieces in zip(words, words_pieces): diff --git a/egs/librispeech/ASR/local/prepare_lm_training_data.py b/egs/librispeech/ASR/local/prepare_lm_training_data.py index 5070341f1..32ae8c580 100755 --- a/egs/librispeech/ASR/local/prepare_lm_training_data.py +++ b/egs/librispeech/ASR/local/prepare_lm_training_data.py @@ -137,8 +137,7 @@ def main(): for i in range(num_sentences): if step and i % step == 0: logging.info( - f"Processed number of lines: {i} " - f"({i/num_sentences*100: .3f}%)" + f"Processed number of lines: {i} " f"({i/num_sentences*100: .3f}%)" ) word_ids = sentences[i] @@ -154,18 +153,14 @@ def main(): sentence_lengths[i] = token_ids.numel() - output["sentence_lengths"] = torch.tensor( - sentence_lengths, dtype=torch.int32 - ) + output["sentence_lengths"] = torch.tensor(sentence_lengths, dtype=torch.int32) torch.save(output, args.lm_archive) logging.info(f"Saved to {args.lm_archive}") if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/preprocess_gigaspeech.py b/egs/librispeech/ASR/local/preprocess_gigaspeech.py index 077f23039..8aa5e461d 100644 --- a/egs/librispeech/ASR/local/preprocess_gigaspeech.py +++ b/egs/librispeech/ASR/local/preprocess_gigaspeech.py @@ -119,9 +119,7 @@ def preprocess_giga_speech(): def main(): - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) preprocess_giga_speech() diff --git a/egs/librispeech/ASR/local/test_prepare_lang.py b/egs/librispeech/ASR/local/test_prepare_lang.py index d4cf62bba..74e025ad7 100755 --- a/egs/librispeech/ASR/local/test_prepare_lang.py +++ b/egs/librispeech/ASR/local/test_prepare_lang.py @@ -88,9 +88,7 @@ def test_read_lexicon(filename: str): fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa.draw("L.pdf", title="L") - fsa_disambig = lexicon_to_fst( - lexicon_disambig, phone2id=phone2id, word2id=word2id - ) + fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id) fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt") fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa_disambig.draw("L_disambig.pdf", title="L_disambig") diff --git a/egs/librispeech/ASR/local/validate_manifest.py b/egs/librispeech/ASR/local/validate_manifest.py index 7c57d629a..f620b91ea 100755 --- a/egs/librispeech/ASR/local/validate_manifest.py +++ b/egs/librispeech/ASR/local/validate_manifest.py @@ -85,9 +85,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless/decode.py index 27414d717..79b21fab1 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/decode.py @@ -272,8 +272,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -366,9 +365,7 @@ def decode_one_batch( ) feature_lens += num_tail_padded_frames - encoder_out, encoder_out_lens, _ = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens, _ = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -427,10 +424,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -561,9 +555,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -596,8 +588,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -648,9 +639,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -682,9 +671,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -711,9 +700,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -772,9 +761,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/export.py b/egs/librispeech/ASR/lstm_transducer_stateless/export.py index 13dac6009..45fa6d662 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/export.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/export.py @@ -172,8 +172,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) add_model_arguments(parser) @@ -281,9 +280,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -310,9 +309,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -380,9 +379,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py index 594c33e4f..51f4a2e8a 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py @@ -124,8 +124,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -314,9 +313,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py b/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py index c54a4c478..bbab16af7 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py @@ -672,9 +672,7 @@ class RandomCombine(nn.Module): self.stddev = stddev self.final_log_weight = ( - torch.tensor( - (final_weight / (1 - final_weight)) * (self.num_inputs - 1) - ) + torch.tensor((final_weight / (1 - final_weight)) * (self.num_inputs - 1)) .log() .item() ) @@ -771,16 +769,14 @@ class RandomCombine(nn.Module): # final contains self.num_inputs - 1 in all elements final = torch.full((num_frames,), self.num_inputs - 1, device=device) # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights. # noqa - nonfinal = torch.randint( - self.num_inputs - 1, (num_frames,), device=device - ) + nonfinal = torch.randint(self.num_inputs - 1, (num_frames,), device=device) indexes = torch.where( torch.rand(num_frames, device=device) < final_prob, final, nonfinal ) - ans = torch.nn.functional.one_hot( - indexes, num_classes=self.num_inputs - ).to(dtype=dtype) + ans = torch.nn.functional.one_hot(indexes, num_classes=self.num_inputs).to( + dtype=dtype + ) return ans def _get_random_mixed_weights( diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/model.py b/egs/librispeech/ASR/lstm_transducer_stateless/model.py index d71132b4a..e7bad7ed8 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless/model.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/model.py @@ -66,9 +66,7 @@ class Transducer(nn.Module): self.decoder = decoder self.joiner = joiner - self.simple_am_proj = ScaledLinear( - encoder_dim, vocab_size, initial_speed=0.5 - ) + self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) def forward( @@ -151,9 +149,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py index 2a6e2adc6..9263b41b2 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py @@ -166,8 +166,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -199,8 +198,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -264,15 +262,11 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens, _ = model.encoder( - x=features, x_lens=feature_lengths - ) + encoder_out, encoder_out_lens, _ = model.encoder(x=features, x_lens=feature_lengths) num_waves = encoder_out.size(0) hyps = [] @@ -344,9 +338,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/stream.py b/egs/librispeech/ASR/lstm_transducer_stateless/stream.py index 97d890c82..d8f7fd960 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless/stream.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/stream.py @@ -70,14 +70,12 @@ class Stream(object): elif params.decoding_method == "fast_beam_search": # feature_len is needed to get partial results. # The rnnt_decoding_stream for fast_beam_search. - self.rnnt_decoding_stream: k2.RnntDecodingStream = ( - k2.RnntDecodingStream(decoding_graph) + self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream( + decoding_graph ) self.hyp: Optional[List[int]] = None else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") self.ground_truth: str = "" diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py index d6376bdc0..4cc2aabb2 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py @@ -199,8 +199,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -359,9 +358,7 @@ def modified_beam_search( index=hyps_shape.row_ids(1).to(torch.int64), ) # (num_hyps, encoder_out_dim) - logits = model.joiner( - current_encoder_out, decoder_out, project_input=False - ) + logits = model.joiner(current_encoder_out, decoder_out, project_input=False) # logits is of shape (num_hyps, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) @@ -378,9 +375,7 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -539,9 +534,7 @@ def decode_one_chunk( feature_list, batch_first=True, padding_value=LOG_EPSILON ).to(device) feature_lens = torch.tensor(feature_len_list, device=device) - num_processed_frames = torch.tensor( - num_processed_frames_list, device=device - ) + num_processed_frames = torch.tensor(num_processed_frames_list, device=device) # Make sure it has at least 1 frame after subsampling tail_length = params.subsampling_factor + 5 @@ -583,8 +576,7 @@ def decode_one_chunk( with warnings.catch_warnings(): warnings.simplefilter("ignore") processed_lens = ( - num_processed_frames // params.subsampling_factor - + encoder_out_lens + num_processed_frames // params.subsampling_factor + encoder_out_lens ) fast_beam_search_one_best( model=model, @@ -596,9 +588,7 @@ def decode_one_chunk( max_states=params.max_states, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") # Update cached states of each stream state_list = unstack_states(states) @@ -773,8 +763,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -816,9 +805,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -852,9 +839,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -881,9 +868,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/train.py b/egs/librispeech/ASR/lstm_transducer_stateless/train.py index d30fc260a..b9a68753e 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/train.py @@ -87,9 +87,7 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -222,8 +220,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -246,8 +243,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -594,11 +590,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -638,9 +630,7 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -653,14 +643,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -671,9 +656,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -856,9 +839,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -989,8 +970,7 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py index bad4e243e..41602d207 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py @@ -295,8 +295,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -474,9 +473,7 @@ def decode_one_batch( ) feature_lens += num_tail_padded_frames - encoder_out, encoder_out_lens, _ = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens, _ = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -535,10 +532,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -700,9 +694,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -735,8 +727,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -789,9 +780,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -826,9 +815,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -860,9 +849,9 @@ def main(): ) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -961,9 +950,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export.py index 190673638..2a25cb46a 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export.py @@ -225,8 +225,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) add_model_arguments(parser) @@ -342,9 +341,7 @@ def export_encoder_model_onnx( x = torch.zeros(N, 9, 80, dtype=torch.float32) x_lens = torch.tensor([9], dtype=torch.int64) h = torch.rand(encoder_model.num_encoder_layers, N, encoder_model.d_model) - c = torch.rand( - encoder_model.num_encoder_layers, N, encoder_model.rnn_hidden_size - ) + c = torch.rand(encoder_model.num_encoder_layers, N, encoder_model.rnn_hidden_size) warmup = 1.0 torch.onnx.export( @@ -445,13 +442,9 @@ def export_joiner_model_onnx( - projected_decoder_out: a tensor of shape (N, joiner_dim) """ - encoder_proj_filename = str(joiner_filename).replace( - ".onnx", "_encoder_proj.onnx" - ) + encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx") - decoder_proj_filename = str(joiner_filename).replace( - ".onnx", "_decoder_proj.onnx" - ) + decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx") encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] @@ -550,9 +543,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -585,9 +578,9 @@ def main(): ) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -694,9 +687,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py index da184b76f..40f11018f 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py @@ -125,8 +125,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -315,9 +314,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/model.py b/egs/librispeech/ASR/lstm_transducer_stateless2/model.py index fadeb4ac2..4957d14b1 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/model.py @@ -84,9 +84,7 @@ class Transducer(nn.Module): self.decoder_giga = decoder_giga self.joiner_giga = joiner_giga - self.simple_am_proj = ScaledLinear( - encoder_dim, vocab_size, initial_speed=0.5 - ) + self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) if decoder_giga is not None: @@ -190,9 +188,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py index 410de8d3d..ab2f17480 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py @@ -156,9 +156,7 @@ class Model: assert ret == 0, ret encoder_out = torch.from_numpy(ncnn_out0.numpy()).clone() - encoder_out_lens = torch.from_numpy(ncnn_out1.numpy()).to( - torch.int32 - ) + encoder_out_lens = torch.from_numpy(ncnn_out1.numpy()).to(torch.int32) hx = torch.from_numpy(ncnn_out2.numpy()).clone() cx = torch.from_numpy(ncnn_out3.numpy()).clone() return encoder_out, encoder_out_lens, hx, cx @@ -201,8 +199,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -286,9 +283,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py index bef0ad760..2983328bf 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py @@ -169,8 +169,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -202,8 +201,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -267,15 +265,11 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens, _ = model.encoder( - x=features, x_lens=feature_lengths - ) + encoder_out, encoder_out_lens, _ = model.encoder(x=features, x_lens=feature_lengths) num_waves = encoder_out.size(0) hyps = [] @@ -347,9 +341,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py index e47a05a9e..a787a00e6 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py @@ -144,9 +144,7 @@ class Model: assert ret == 0, ret encoder_out = torch.from_numpy(ncnn_out0.numpy()).clone() - encoder_out_lens = torch.from_numpy(ncnn_out1.numpy()).to( - torch.int32 - ) + encoder_out_lens = torch.from_numpy(ncnn_out1.numpy()).to(torch.int32) hx = torch.from_numpy(ncnn_out2.numpy()).clone() cx = torch.from_numpy(ncnn_out3.numpy()).clone() return encoder_out, encoder_out_lens, hx, cx @@ -189,8 +187,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -229,9 +226,7 @@ def greedy_search( if decoder_out is None: assert hyp is None, hyp hyp = [blank_id] * context_size - decoder_input = torch.tensor( - hyp, dtype=torch.int32 - ) # (1, context_size) + decoder_input = torch.tensor(hyp, dtype=torch.int32) # (1, context_size) decoder_out = model.run_decoder(decoder_input).squeeze(0) else: assert decoder_out.ndim == 1 @@ -310,9 +305,7 @@ def main(): frames.append(online_fbank.get_frame(num_processed_frames + i)) num_processed_frames += offset frames = torch.cat(frames, dim=0) - encoder_out, encoder_out_lens, hx, cx = model.run_encoder( - frames, states - ) + encoder_out, encoder_out_lens, hx, cx = model.run_encoder(frames, states) states = (hx, cx) hyp, decoder_out = greedy_search( model, encoder_out.squeeze(0), decoder_out, hyp @@ -328,9 +321,7 @@ def main(): frames.append(online_fbank.get_frame(num_processed_frames + i)) num_processed_frames += offset frames = torch.cat(frames, dim=0) - encoder_out, encoder_out_lens, hx, cx = model.run_encoder( - frames, states - ) + encoder_out, encoder_out_lens, hx, cx = model.run_encoder(frames, states) states = (hx, cx) hyp, decoder_out = greedy_search( model, encoder_out.squeeze(0), decoder_out, hyp @@ -343,9 +334,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py index 232d3dd18..e896fd510 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py @@ -148,8 +148,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -199,9 +198,7 @@ class Model: sess_options=self.session_opts, ) - def run_encoder( - self, x, h0, c0 - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def run_encoder(self, x, h0, c0) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Args: x: @@ -258,9 +255,7 @@ class Model: }, )[0] - return self.run_joiner_decoder_proj( - torch.from_numpy(decoder_out).squeeze(1) - ) + return self.run_joiner_decoder_proj(torch.from_numpy(decoder_out).squeeze(1)) def run_joiner( self, @@ -303,11 +298,7 @@ class Model: projected_encoder_out = self.joiner_encoder_proj.run( [self.joiner_encoder_proj.get_outputs()[0].name], - { - self.joiner_encoder_proj.get_inputs()[ - 0 - ].name: encoder_out.numpy() - }, + {self.joiner_encoder_proj.get_inputs()[0].name: encoder_out.numpy()}, )[0] return torch.from_numpy(projected_encoder_out) @@ -326,11 +317,7 @@ class Model: projected_decoder_out = self.joiner_decoder_proj.run( [self.joiner_decoder_proj.get_outputs()[0].name], - { - self.joiner_decoder_proj.get_inputs()[ - 0 - ].name: decoder_out.numpy() - }, + {self.joiner_decoder_proj.get_inputs()[0].name: decoder_out.numpy()}, )[0] return torch.from_numpy(projected_decoder_out) @@ -369,9 +356,7 @@ def greedy_search( if decoder_out is None: assert hyp is None, hyp hyp = [blank_id] * context_size - decoder_input = torch.tensor( - [hyp], dtype=torch.int64 - ) # (1, context_size) + decoder_input = torch.tensor([hyp], dtype=torch.int64) # (1, context_size) decoder_out = model.run_decoder(decoder_input) else: assert decoder_out.shape[0] == 1 @@ -474,9 +459,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py index 5eaaf321f..056285c64 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py @@ -95,9 +95,7 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -163,8 +161,7 @@ def get_parser(): "--full-libri", type=str2bool, default=True, - help="When enabled, use 960h LibriSpeech. " - "Otherwise, use 100h subset.", + help="When enabled, use 960h LibriSpeech. " "Otherwise, use 100h subset.", ) parser.add_argument( @@ -238,8 +235,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -262,8 +258,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -645,11 +640,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -692,9 +683,7 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -707,14 +696,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -725,9 +709,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -958,9 +940,7 @@ def train_one_epoch( f"train/current_{prefix}_", params.batch_idx_train, ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) libri_tot_loss.write_summary( tb_writer, "train/libri_tot_", params.batch_idx_train ) @@ -1006,8 +986,7 @@ def filter_short_and_long_utterances( # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}" ) return False @@ -1155,9 +1134,7 @@ def run(rank, world_size, args): train_giga_cuts = train_giga_cuts.repeat(times=None) if args.enable_musan: - cuts_musan = load_manifest( - Path(args.manifest_dir) / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz") else: cuts_musan = None diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py index 9eee19379..cba1ac689 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py @@ -290,8 +290,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -386,9 +385,7 @@ def decode_one_batch( ) feature_lens += num_tail_padded_frames - encoder_out, encoder_out_lens, _ = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens, _ = model.encoder(x=feature, x_lens=feature_lens) if params.decoding_method == "fast_beam_search": res = fast_beam_search_one_best( @@ -441,10 +438,7 @@ def decode_one_batch( nbest_scale=params.nbest_scale, return_timestamps=True, ) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: res = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -522,9 +516,7 @@ def decode_dataset( sp: spm.SentencePieceProcessor, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[ - str, List[Tuple[str, List[str], List[str], List[float], List[float]]] -]: +) -> Dict[str, List[Tuple[str, List[str], List[str], List[float], List[float]]]]: """Decode dataset. Args: @@ -599,9 +591,7 @@ def decode_dataset( cut_ids, hyps, texts, timestamps_hyp, timestamps_ref ): ref_words = ref_text.split() - this_batch.append( - (cut_id, ref_words, hyp_words, time_ref, time_hyp) - ) + this_batch.append((cut_id, ref_words, hyp_words, time_ref, time_hyp)) results[name].extend(this_batch) @@ -610,9 +600,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -650,8 +638,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -678,9 +665,7 @@ def save_results( note = "" logging.info(s) - s = "\nFor {}, symbol-delay of different settings are:\n".format( - test_set_name - ) + s = "\nFor {}, symbol-delay of different settings are:\n".format(test_set_name) note = "\tbest for {}".format(test_set_name) for key, val in test_set_delays: s += "{}\tmean: {}s, variance: {}{}\n".format(key, val[0], val[1], note) @@ -724,9 +709,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -758,9 +741,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -787,9 +770,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -848,9 +831,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/export.py b/egs/librispeech/ASR/lstm_transducer_stateless3/export.py index 212c7bad6..457bd472f 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/export.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/export.py @@ -172,8 +172,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) add_model_arguments(parser) @@ -281,9 +280,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -310,9 +309,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -380,9 +379,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py index a3443cf0a..71b37ac55 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py @@ -124,8 +124,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -314,9 +313,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py b/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py index 90bc351f4..6e51b85e4 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py @@ -661,9 +661,7 @@ class RandomCombine(nn.Module): self.stddev = stddev self.final_log_weight = ( - torch.tensor( - (final_weight / (1 - final_weight)) * (self.num_inputs - 1) - ) + torch.tensor((final_weight / (1 - final_weight)) * (self.num_inputs - 1)) .log() .item() ) @@ -760,16 +758,14 @@ class RandomCombine(nn.Module): # final contains self.num_inputs - 1 in all elements final = torch.full((num_frames,), self.num_inputs - 1, device=device) # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights. # noqa - nonfinal = torch.randint( - self.num_inputs - 1, (num_frames,), device=device - ) + nonfinal = torch.randint(self.num_inputs - 1, (num_frames,), device=device) indexes = torch.where( torch.rand(num_frames, device=device) < final_prob, final, nonfinal ) - ans = torch.nn.functional.one_hot( - indexes, num_classes=self.num_inputs - ).to(dtype=dtype) + ans = torch.nn.functional.one_hot(indexes, num_classes=self.num_inputs).to( + dtype=dtype + ) return ans def _get_random_mixed_weights( diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py index 0e48fef04..e72f4ee42 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py @@ -166,8 +166,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -199,8 +198,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -264,15 +262,11 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens, _ = model.encoder( - x=features, x_lens=feature_lengths - ) + encoder_out, encoder_out_lens, _ = model.encoder(x=features, x_lens=feature_lengths) num_waves = encoder_out.size(0) hyps = [] @@ -344,9 +338,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py b/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py index cfa918ed5..dad6b905f 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py @@ -199,8 +199,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -359,9 +358,7 @@ def modified_beam_search( index=hyps_shape.row_ids(1).to(torch.int64), ) # (num_hyps, encoder_out_dim) - logits = model.joiner( - current_encoder_out, decoder_out, project_input=False - ) + logits = model.joiner(current_encoder_out, decoder_out, project_input=False) # logits is of shape (num_hyps, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) @@ -378,9 +375,7 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -539,9 +534,7 @@ def decode_one_chunk( feature_list, batch_first=True, padding_value=LOG_EPSILON ).to(device) feature_lens = torch.tensor(feature_len_list, device=device) - num_processed_frames = torch.tensor( - num_processed_frames_list, device=device - ) + num_processed_frames = torch.tensor(num_processed_frames_list, device=device) # Make sure it has at least 1 frame after subsampling tail_length = params.subsampling_factor + 5 @@ -583,8 +576,7 @@ def decode_one_chunk( with warnings.catch_warnings(): warnings.simplefilter("ignore") processed_lens = ( - num_processed_frames // params.subsampling_factor - + encoder_out_lens + num_processed_frames // params.subsampling_factor + encoder_out_lens ) fast_beam_search_one_best( model=model, @@ -596,9 +588,7 @@ def decode_one_chunk( max_states=params.max_states, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") # Update cached states of each stream state_list = unstack_states(states) @@ -773,8 +763,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -816,9 +805,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -852,9 +839,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -881,9 +868,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py index 60a5a2be7..97ca4b94c 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py @@ -87,9 +87,7 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -232,8 +230,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -256,8 +253,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -606,11 +602,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -650,9 +642,7 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -665,14 +655,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -683,9 +668,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -852,10 +835,7 @@ def train_one_epoch( rank=rank, ) - if ( - batch_idx % params.log_interval == 0 - and not params.print_diagnostics - ): + if batch_idx % params.log_interval == 0 and not params.print_diagnostics: cur_lr = scheduler.get_last_lr()[0] logging.info( f"Epoch {params.cur_epoch}, " @@ -872,9 +852,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if ( batch_idx > 0 @@ -1009,8 +987,7 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py b/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py index 8dd1459ca..3dc9164f8 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py +++ b/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py @@ -83,8 +83,7 @@ class LibriSpeechAsrDataModule: "--full-libri", type=str2bool, default=True, - help="When enabled, use 960h LibriSpeech. " - "Otherwise, use 100h subset.", + help="When enabled, use 960h LibriSpeech. " "Otherwise, use 100h subset.", ) group.add_argument( "--manifest-dir", @@ -208,13 +207,9 @@ class LibriSpeechAsrDataModule: if self.args.enable_musan: logging.info("Enable MUSAN") logging.info("About to get Musan cuts") - cuts_musan = load_manifest( - self.args.manifest_dir / "cuts_musan.json.gz" - ) + cuts_musan = load_manifest(self.args.manifest_dir / "cuts_musan.json.gz") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") @@ -236,9 +231,7 @@ class LibriSpeechAsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -281,9 +274,7 @@ class LibriSpeechAsrDataModule: # Drop feats to be on the safe side. train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -340,9 +331,7 @@ class LibriSpeechAsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: @@ -389,23 +378,17 @@ class LibriSpeechAsrDataModule: @lru_cache() def train_clean_100_cuts(self) -> CutSet: logging.info("About to get train-clean-100 cuts") - return load_manifest( - self.args.manifest_dir / "cuts_train-clean-100.json.gz" - ) + return load_manifest(self.args.manifest_dir / "cuts_train-clean-100.json.gz") @lru_cache() def train_clean_360_cuts(self) -> CutSet: logging.info("About to get train-clean-360 cuts") - return load_manifest( - self.args.manifest_dir / "cuts_train-clean-360.json.gz" - ) + return load_manifest(self.args.manifest_dir / "cuts_train-clean-360.json.gz") @lru_cache() def train_other_500_cuts(self) -> CutSet: logging.info("About to get train-other-500 cuts") - return load_manifest( - self.args.manifest_dir / "cuts_train-other-500.json.gz" - ) + return load_manifest(self.args.manifest_dir / "cuts_train-other-500.json.gz") @lru_cache() def dev_clean_cuts(self) -> CutSet: diff --git a/egs/librispeech/ASR/pruned2_knowledge/beam_search.py b/egs/librispeech/ASR/pruned2_knowledge/beam_search.py index 2e9bf3e0b..785a8f097 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/beam_search.py +++ b/egs/librispeech/ASR/pruned2_knowledge/beam_search.py @@ -172,9 +172,9 @@ def greedy_search( y = logits.argmax().item() if y != blank_id: hyp.append(y) - decoder_input = torch.tensor( - [hyp[-context_size:]], device=device - ).reshape(1, context_size) + decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape( + 1, context_size + ) decoder_out = model.decoder(decoder_input, need_pad=False) decoder_out = model.joiner.decoder_proj(decoder_out) @@ -302,9 +302,7 @@ class HypothesisList(object): key = hyp.key if key in self: old_hyp = self._data[key] # shallow copy - torch.logaddexp( - old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob - ) + torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob) else: self._data[key] = hyp @@ -320,9 +318,7 @@ class HypothesisList(object): Return the hypothesis that has the largest `log_prob`. """ if length_norm: - return max( - self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys) - ) + return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)) else: return max(self._data.values(), key=lambda hyp: hyp.log_prob) @@ -496,9 +492,7 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) diff --git a/egs/librispeech/ASR/pruned2_knowledge/conformer.py b/egs/librispeech/ASR/pruned2_knowledge/conformer.py index 295a35204..de367c234 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/conformer.py +++ b/egs/librispeech/ASR/pruned2_knowledge/conformer.py @@ -18,10 +18,10 @@ import math import warnings from typing import Optional, Tuple -from sampling import create_knowledge_base, KnowledgeBaseLookup import torch from encoder_interface import EncoderInterface +from sampling import KnowledgeBaseLookup, create_knowledge_base from scaling import ( ActivationBalancer, BasicNorm, @@ -73,9 +73,9 @@ class Conformer(EncoderInterface): if subsampling_factor != 4: raise NotImplementedError("Support only 'subsampling_factor=4'.") - - self.knowledge_base = create_knowledge_base(knowledge_M, knowledge_N, - knowledge_D) + self.knowledge_base = create_knowledge_base( + knowledge_M, knowledge_N, knowledge_D + ) # self.encoder_embed converts the input of shape (N, T, num_features) # to the shape (N, T//subsampling_factor, d_model). @@ -89,7 +89,7 @@ class Conformer(EncoderInterface): # Pass in a lambda that creates a new ConformerEncoderLayer with these # args. Don't use deepcopy because we need the knowledge_base # to be shared. - encoder_layer_fn = lambda: ConformerEncoderLayer( + encoder_layer_fn = lambda: ConformerEncoderLayer( # noqa: E731 self.knowledge_base, d_model, nhead, @@ -100,7 +100,7 @@ class Conformer(EncoderInterface): knowledge_M, knowledge_N, knowledge_D, - knowledge_K + knowledge_K, ) self.encoder = ConformerEncoder(encoder_layer_fn, num_encoder_layers) @@ -187,9 +187,7 @@ class ConformerEncoderLayer(nn.Module): self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), @@ -209,10 +207,9 @@ class ConformerEncoderLayer(nn.Module): self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) - self.lookup = KnowledgeBaseLookup(knowledge_M, knowledge_N, - knowledge_D, knowledge_K, - d_model, - knowledge_base) + self.lookup = KnowledgeBaseLookup( + knowledge_M, knowledge_N, knowledge_D, knowledge_K, d_model, knowledge_base + ) self.norm_final = BasicNorm(d_model) @@ -311,9 +308,7 @@ class ConformerEncoder(nn.Module): def __init__(self, encoder_layer_fn, num_layers: int) -> None: super().__init__() - self.layers = nn.ModuleList( - [encoder_layer_fn() for i in range(num_layers)] - ) + self.layers = nn.ModuleList([encoder_layer_fn() for i in range(num_layers)]) self.num_layers = num_layers def forward( @@ -367,9 +362,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -384,9 +377,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vecotr and `j` means the @@ -661,9 +652,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -732,31 +723,22 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) @@ -795,9 +777,7 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -805,13 +785,9 @@ class RelPositionMultiheadAttention(nn.Module): ) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd) - attn_output_weights = ( - matrix_ac + matrix_bd - ) # (batch, head, time1, time2) + attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -845,13 +821,9 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -874,9 +846,7 @@ class ConvolutionModule(nn.Module): """ - def __init__( - self, channels: int, kernel_size: int, bias: bool = True - ) -> None: + def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding diff --git a/egs/librispeech/ASR/pruned2_knowledge/decode.py b/egs/librispeech/ASR/pruned2_knowledge/decode.py index b4a9af55a..c3e7b01ab 100755 --- a/egs/librispeech/ASR/pruned2_knowledge/decode.py +++ b/egs/librispeech/ASR/pruned2_knowledge/decode.py @@ -76,11 +76,7 @@ from beam_search import ( ) from train import get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import ( AttributeDict, setup_logger, @@ -186,8 +182,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -245,9 +240,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -262,10 +255,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -385,9 +375,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -419,8 +407,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/librispeech/ASR/pruned2_knowledge/decoder.py b/egs/librispeech/ASR/pruned2_knowledge/decoder.py index b6d94aaf1..0b9c886c7 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/decoder.py +++ b/egs/librispeech/ASR/pruned2_knowledge/decoder.py @@ -90,9 +90,7 @@ class Decoder(nn.Module): if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad( - embedding_out, pad=(self.context_size - 1, 0) - ) + embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/librispeech/ASR/pruned2_knowledge/decoder2.py b/egs/librispeech/ASR/pruned2_knowledge/decoder2.py index db51fb1cd..0c9cee431 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/decoder2.py +++ b/egs/librispeech/ASR/pruned2_knowledge/decoder2.py @@ -14,12 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional + import torch import torch.nn as nn import torch.nn.functional as F -from torch import Tensor -from typing import Optional from subsampling import ScaledConv1d +from torch import Tensor class Decoder(nn.Module): @@ -90,9 +91,7 @@ class Decoder(nn.Module): if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad( - embedding_out, pad=(self.context_size - 1, 0) - ) + embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) else: # During inference time, there is no need to do extra padding # as we only need one output @@ -102,7 +101,6 @@ class Decoder(nn.Module): return embedding_out - class ScaledEmbedding(nn.Module): r"""A simple lookup table that stores embeddings of a fixed dictionary and size. @@ -171,8 +169,13 @@ class ScaledEmbedding(nn.Module): [ 0.0000, 0.0000, 0.0000], [-0.1655, 0.9897, 0.0635]]]) """ - __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', - 'scale_grad_by_freq', 'sparse'] + __constants__ = [ + "num_embeddings", + "embedding_dim", + "padding_idx", + "scale_grad_by_freq", + "sparse", + ] num_embeddings: int embedding_dim: int @@ -181,34 +184,41 @@ class ScaledEmbedding(nn.Module): weight: Tensor sparse: bool - def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, - scale_grad_by_freq: bool = False, - sparse: bool = False, - scale_speed: float = 5.0) -> None: + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + scale_grad_by_freq: bool = False, + sparse: bool = False, + scale_speed: float = 5.0, + ) -> None: super(ScaledEmbedding, self).__init__() self.num_embeddings = num_embeddings self.embedding_dim = embedding_dim if padding_idx is not None: if padding_idx > 0: - assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings' + assert ( + padding_idx < self.num_embeddings + ), "Padding_idx must be within num_embeddings" elif padding_idx < 0: - assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings' + assert ( + padding_idx >= -self.num_embeddings + ), "Padding_idx must be within num_embeddings" padding_idx = self.num_embeddings + padding_idx self.padding_idx = padding_idx self.scale_grad_by_freq = scale_grad_by_freq self.scale_speed = scale_speed - self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() + self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() self.sparse = sparse self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim)) self.reset_parameters() - - def reset_parameters(self) -> None: nn.init.normal_(self.weight, std=0.05) - nn.init.constant_(self.scale, torch.tensor(1.0/0.05).log() / self.scale_speed) + nn.init.constant_(self.scale, torch.tensor(1.0 / 0.05).log() / self.scale_speed) if self.padding_idx is not None: with torch.no_grad(): @@ -217,22 +227,35 @@ class ScaledEmbedding(nn.Module): def forward(self, input: Tensor) -> Tensor: scale = (self.scale * self.scale_speed).exp() if input.numel() < self.num_embeddings: - return F.embedding( - input, self.weight, self.padding_idx, - None, 2.0, # None, 2.0 relate to normalization - self.scale_grad_by_freq, self.sparse) * scale + return ( + F.embedding( + input, + self.weight, + self.padding_idx, + None, + 2.0, # None, 2.0 relate to normalization + self.scale_grad_by_freq, + self.sparse, + ) + * scale + ) else: return F.embedding( - input, self.weight * scale, self.padding_idx, - None, 2.0, # None, 2.0 relates to normalization - self.scale_grad_by_freq, self.sparse) + input, + self.weight * scale, + self.padding_idx, + None, + 2.0, # None, 2.0 relates to normalization + self.scale_grad_by_freq, + self.sparse, + ) def extra_repr(self) -> str: - s = '{num_embeddings}, {embedding_dim}, scale_speed={scale_speed}, scale={scale}' + s = "{num_embeddings}, {embedding_dim}, scale_speed={scale_speed}, scale={scale}" if self.padding_idx is not None: - s += ', padding_idx={padding_idx}' + s += ", padding_idx={padding_idx}" if self.scale_grad_by_freq is not False: - s += ', scale_grad_by_freq={scale_grad_by_freq}' + s += ", scale_grad_by_freq={scale_grad_by_freq}" if self.sparse is not False: - s += ', sparse=True' + s += ", sparse=True" return s.format(**self.__dict__) diff --git a/egs/librispeech/ASR/pruned2_knowledge/export.py b/egs/librispeech/ASR/pruned2_knowledge/export.py index 96d1a30fb..ce5f162bf 100755 --- a/egs/librispeech/ASR/pruned2_knowledge/export.py +++ b/egs/librispeech/ASR/pruned2_knowledge/export.py @@ -105,8 +105,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) return parser @@ -174,9 +173,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned2_knowledge/joiner.py b/egs/librispeech/ASR/pruned2_knowledge/joiner.py index 35f75ed2a..68c663b66 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/joiner.py +++ b/egs/librispeech/ASR/pruned2_knowledge/joiner.py @@ -56,9 +56,7 @@ class Joiner(nn.Module): assert encoder_out.shape[:-1] == decoder_out.shape[:-1] if project_input: - logit = self.encoder_proj(encoder_out) + self.decoder_proj( - decoder_out - ) + logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) else: logit = encoder_out + decoder_out diff --git a/egs/librispeech/ASR/pruned2_knowledge/model.py b/egs/librispeech/ASR/pruned2_knowledge/model.py index 599bf2506..ca8c28af1 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/model.py +++ b/egs/librispeech/ASR/pruned2_knowledge/model.py @@ -63,9 +63,7 @@ class Transducer(nn.Module): self.decoder = decoder self.joiner = joiner - self.simple_am_proj = ScaledLinear( - encoder_dim, vocab_size, initial_speed=0.5 - ) + self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) def forward( @@ -136,9 +134,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/pruned2_knowledge/optim.py b/egs/librispeech/ASR/pruned2_knowledge/optim.py index 432bf8220..76cd4e11e 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/optim.py +++ b/egs/librispeech/ASR/pruned2_knowledge/optim.py @@ -72,17 +72,11 @@ class Eve(Optimizer): if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if not 0.0 <= betas[0] < 1.0: - raise ValueError( - "Invalid beta parameter at index 0: {}".format(betas[0]) - ) + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) if not 0.0 <= betas[1] < 1.0: - raise ValueError( - "Invalid beta parameter at index 1: {}".format(betas[1]) - ) + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) if not 0 <= weight_decay <= 0.1: - raise ValueError( - "Invalid weight_decay value: {}".format(weight_decay) - ) + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) if not 0 < target_rms <= 10.0: raise ValueError("Invalid target_rms value: {}".format(target_rms)) defaults = dict( @@ -118,9 +112,7 @@ class Eve(Optimizer): # Perform optimization step grad = p.grad if grad.is_sparse: - raise RuntimeError( - "AdamW does not support sparse gradients" - ) + raise RuntimeError("AdamW does not support sparse gradients") state = self.state[p] @@ -147,7 +139,7 @@ class Eve(Optimizer): # Decay the first and second moment running average coefficient exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_( + denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_( group["eps"] ) @@ -158,9 +150,7 @@ class Eve(Optimizer): if p.numel() > 1: # avoid applying this weight-decay on "scaling factors" # (which are scalar). - is_above_target_rms = p.norm() > ( - target_rms * (p.numel() ** 0.5) - ) + is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5)) p.mul_(1 - (weight_decay * is_above_target_rms)) p.addcdiv_(exp_avg, denom, value=-step_size) @@ -176,18 +166,14 @@ class LRScheduler(object): def __init__(self, optimizer: Optimizer, verbose: bool = False): # Attach optimizer if not isinstance(optimizer, Optimizer): - raise TypeError( - "{} is not an Optimizer".format(type(optimizer).__name__) - ) + raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) self.optimizer = optimizer self.verbose = verbose for group in optimizer.param_groups: group.setdefault("initial_lr", group["lr"]) - self.base_lrs = [ - group["initial_lr"] for group in optimizer.param_groups - ] + self.base_lrs = [group["initial_lr"] for group in optimizer.param_groups] self.epoch = 0 self.batch = 0 @@ -295,10 +281,9 @@ class Eden(LRScheduler): def get_lr(self): factor = ( - (self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2 + (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 ) ** -0.25 * ( - ((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2) - ** -0.25 + ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25 ) return [x * factor for x in self.base_lrs] diff --git a/egs/librispeech/ASR/pruned2_knowledge/sampling.py b/egs/librispeech/ASR/pruned2_knowledge/sampling.py index 7b05e2f00..5b595c76c 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/sampling.py +++ b/egs/librispeech/ASR/pruned2_knowledge/sampling.py @@ -3,32 +3,29 @@ # This was copied from /ceph-dan/torch-sampling/torch_sampling/sampling_ref.py, # its git history is there. -import timeit -import torch -from torch import Tensor -from torch import nn -from torch.cuda.amp import GradScaler, custom_fwd, custom_bwd -from typing import Tuple, Optional -from scaling import ScaledLinear import random +import timeit +from typing import Optional, Tuple + +import torch +from scaling import ScaledLinear +from torch import Tensor, nn +from torch.cuda.amp import GradScaler, custom_bwd, custom_fwd from torch_scheduled_sampling import sample_combined # The main exports of this file are the module KnowledgeBaseLookup and the # function create_knowledge_base. - - - - def create_knowledge_base(M: int, N: int, D: int) -> nn.Parameter: std = 0.1 - a = (3 ** 0.5) * std # this sqrt(3) thing is intended to get variance of - # 0.1 from uniform distribution - ans = nn.Parameter(torch.ones(M ** N, D)) + a = (3**0.5) * std # this sqrt(3) thing is intended to get variance of + # 0.1 from uniform distribution + ans = nn.Parameter(torch.ones(M**N, D)) nn.init.uniform_(ans, -a, a) return ans + def join_indexes(indexes: Tensor, M: int) -> Tensor: """ Combines N-tuples of indexes into single indexes that can be used for @@ -47,9 +44,9 @@ def join_indexes(indexes: Tensor, M: int) -> Tensor: # Note, we don't use this, we -def weighted_matrix_lookup(weights: Tensor, - indexes: Tensor, - knowledge_base: Tensor) -> Tensor: +def weighted_matrix_lookup( + weights: Tensor, indexes: Tensor, knowledge_base: Tensor +) -> Tensor: """ Weighted combination of specified rows of a matrix. weights: Tensor of shape (*, K), can contain any value but probably in [0..1]. @@ -65,9 +62,9 @@ def weighted_matrix_lookup(weights: Tensor, # simpler but less memory-efficient implementation lookup = torch.index_select(knowledge_base, dim=0, index=indexes.flatten()) D = knowledge_base.shape[-1] - weights = weights.unsqueeze(-2) # (*, 1, K) - lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) - ans = torch.matmul(weights, lookup) # ans: (*, 1, D) + weights = weights.unsqueeze(-2) # (*, 1, K) + lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) + ans = torch.matmul(weights, lookup) # ans: (*, 1, D) ans = ans.squeeze(-2) assert list(ans.shape) == list(weights.shape[:-2]) + [D] return ans @@ -76,7 +73,9 @@ def weighted_matrix_lookup(weights: Tensor, class WeightedMatrixLookupFunction(torch.autograd.Function): @staticmethod @custom_fwd - def forward(ctx, weights: Tensor, indexes: Tensor, knowledge_base: Tensor) -> Tensor: + def forward( + ctx, weights: Tensor, indexes: Tensor, knowledge_base: Tensor + ) -> Tensor: """ Weighted combination of specified rows of a matrix. weights: Tensor of shape (*, K), can contain any value but probably in [0..1]. @@ -88,15 +87,16 @@ class WeightedMatrixLookupFunction(torch.autograd.Function): """ if random.random() < 0.001: print("dtype[1] = ", weights.dtype) - ctx.save_for_backward(weights.detach(), indexes.detach(), - knowledge_base.detach()) + ctx.save_for_backward( + weights.detach(), indexes.detach(), knowledge_base.detach() + ) with torch.no_grad(): lookup = torch.index_select(knowledge_base, dim=0, index=indexes.flatten()) D = knowledge_base.shape[-1] - weights = weights.unsqueeze(-2) # (*, 1, K) - lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) - ans = torch.matmul(weights, lookup) # ans: (*, 1, D) - ans = ans.squeeze(-2) #(*, D) + weights = weights.unsqueeze(-2) # (*, 1, K) + lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) + ans = torch.matmul(weights, lookup) # ans: (*, 1, D) + ans = ans.squeeze(-2) # (*, D) return ans @staticmethod @@ -107,7 +107,7 @@ class WeightedMatrixLookupFunction(torch.autograd.Function): knowledge_base.requires_grad = True dtype = ans_grad.dtype ans_grad = ans_grad.to(weights.dtype) - assert weights.requires_grad == False + assert weights.requires_grad is False D = knowledge_base.shape[-1] with torch.enable_grad(): # we'll use torch's autograd to differentiate this operation, which @@ -115,16 +115,19 @@ class WeightedMatrixLookupFunction(torch.autograd.Function): # We don't save `lookup` because it's large, that is the reason # we override Torch autograd. lookup = torch.index_select(knowledge_base, dim=0, index=indexes.flatten()) - lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) - weights = weights.unsqueeze(-1) # (*, K, 1) + lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) + weights = weights.unsqueeze(-1) # (*, K, 1) # forward pass: was: ## ans = torch.matmul(weights, lookup) ## ans: (*, 1, D) ## ans = ans.squeeze(-2) # ans, ans_grad: (*, D) - weights_grad = torch.matmul(lookup, # (*, K, D) - ans_grad.unsqueeze(-1)) # (*, D, 1) - weights_grad = weights_grad.squeeze(-1) # (*, K, 1) -> (*, K) - lookup_grad = weights * ans_grad.unsqueeze(-2) # (*, K, 1) * (*, 1, D) = (*, K, D) + weights_grad = torch.matmul( + lookup, ans_grad.unsqueeze(-1) # (*, K, D) + ) # (*, D, 1) + weights_grad = weights_grad.squeeze(-1) # (*, K, 1) -> (*, K) + lookup_grad = weights * ans_grad.unsqueeze( + -2 + ) # (*, K, 1) * (*, 1, D) = (*, K, D) lookup.backward(gradient=lookup_grad) return weights_grad.to(dtype), None, knowledge_base.grad.to(dtype) @@ -146,6 +149,7 @@ class PenalizeNegentropyFunction(torch.autograd.Function): Returns: logprobs """ + @staticmethod def forward(ctx, logprobs: Tensor, alpha: float): ctx.save_for_backward(logprobs.detach()) @@ -154,18 +158,23 @@ class PenalizeNegentropyFunction(torch.autograd.Function): @staticmethod def backward(ctx, logprobs_grad: Tensor) -> Tuple[Tensor, None]: - logprobs, = ctx.saved_tensors + (logprobs,) = ctx.saved_tensors with torch.enable_grad(): logprobs.requires_grad = True # `negentropy` is the negative entropy of the average distribution. # distributions. It will be <= 0. - l = logprobs.reshape(-1, logprobs.shape[-1]) + l = logprobs.reshape(-1, logprobs.shape[-1]) # noqa: E741 scale = ctx.alpha * l.shape[0] avg_dist = l.exp().mean(dim=0) negentropy = (avg_dist * (avg_dist + 1.0e-20).log()).sum() if random.random() < 0.0005: negentropy_individual = (l * l.exp()).sum(dim=-1).mean() - print("Negentropy[individual,combined] = ", negentropy_individual.item(), ", ", negentropy.item()) + print( + "Negentropy[individual,combined] = ", + negentropy_individual.item(), + ", ", + negentropy.item(), + ) loss = negentropy * scale loss.backward() return logprobs_grad + logprobs.grad, None @@ -183,18 +192,23 @@ class KnowledgeBaseLookup(nn.Module): embedding_dim: the dimension to project from and to, e.g. the d_model of the conformer. """ - def __init__(self, M: int, N: int, D: int, - K: int, embedding_dim: int, - knowledge_base: nn.Parameter, - negentropy_penalty: float = 0.001): + + def __init__( + self, + M: int, + N: int, + D: int, + K: int, + embedding_dim: int, + knowledge_base: nn.Parameter, + negentropy_penalty: float = 0.001, + ): super(KnowledgeBaseLookup, self).__init__() self.knowledge_base = knowledge_base # shared! - self.in_proj = ScaledLinear(embedding_dim, M * N, - initial_scale=1.0) + self.in_proj = ScaledLinear(embedding_dim, M * N, initial_scale=1.0) # initial_scale = 4.0 because the knowlege_base activations are # quite small -- if we use our optimizer they'll have stddev <= 0.1. - self.out_proj = ScaledLinear(D, embedding_dim, - initial_scale = 4.0) + self.out_proj = ScaledLinear(D, embedding_dim, initial_scale=4.0) self.M = M self.N = N self.K = K @@ -210,14 +224,14 @@ class KnowledgeBaseLookup(nn.Module): # TODO: later we can try multiplying by a projection of x or something like that. """ - x = self.in_proj(x) # now (*, M*N) - x = x.reshape(*x.shape[:-1], self.N, self.M) # now (*, N, M) - x = x.log_softmax(dim=-1) # now normalized logprobs, dim= (*, N, M) + x = self.in_proj(x) # now (*, M*N) + x = x.reshape(*x.shape[:-1], self.N, self.M) # now (*, N, M) + x = x.log_softmax(dim=-1) # now normalized logprobs, dim= (*, N, M) x = PenalizeNegentropyFunction.apply(x, self.negentropy_penalty) _, indexes, weights = sample_combined(x, self.K, input_is_log=True) - x = weighted_matrix_lookup(weights, indexes, self.knowledge_base) # now (*, D) - x = self.out_proj(x) # now (*, self.embedding_dim) + x = weighted_matrix_lookup(weights, indexes, self.knowledge_base) # now (*, D) + x = self.out_proj(x) # now (*, self.embedding_dim) return x @@ -237,38 +251,44 @@ def _test_knowledge_base_lookup(): x.requires_grad = True y = m(x) assert y.shape == x.shape - y.sum().backward() # make sure backward doesn't crash.. + y.sum().backward() # make sure backward doesn't crash.. print("y = ", y) print("x.grad = ", x.grad) print("knowlege_base.grad norm = ", knowledge_base.grad.norm()) dtype = torch.float32 - device = torch.device('cuda') - train_pairs = [ (torch.randn(B, T, E, device=device, dtype=dtype), torch.randn(B, T, E, device=device, dtype=dtype)) for _ in range(10) ] + device = torch.device("cuda") + train_pairs = [ + ( + torch.randn(B, T, E, device=device, dtype=dtype), + torch.randn(B, T, E, device=device, dtype=dtype), + ) + for _ in range(10) + ] from optim import Eve + optimizer = Eve(m.parameters(), lr=0.005, eps=1.0e-04) m = m.to(device).to(dtype) - start = timeit.default_timer() -# Epoch 0, batch 0, loss 1.0109944343566895 -# Epoch 10, batch 0, loss 1.0146660804748535 -# Epoch 20, batch 0, loss 1.0119813680648804 -# Epoch 30, batch 0, loss 1.0105408430099487 -# Epoch 40, batch 0, loss 1.0077732801437378 -# Epoch 50, batch 0, loss 1.0050103664398193 -# Epoch 60, batch 0, loss 1.0033129453659058 -# Epoch 70, batch 0, loss 1.0014232397079468 -# Epoch 80, batch 0, loss 0.9977912306785583 -# Epoch 90, batch 0, loss 0.8274348974227905 -# Epoch 100, batch 0, loss 0.3368612825870514 -# Epoch 110, batch 0, loss 0.11323091387748718 -# Time taken: 17.591704960912466 + # Epoch 0, batch 0, loss 1.0109944343566895 + # Epoch 10, batch 0, loss 1.0146660804748535 + # Epoch 20, batch 0, loss 1.0119813680648804 + # Epoch 30, batch 0, loss 1.0105408430099487 + # Epoch 40, batch 0, loss 1.0077732801437378 + # Epoch 50, batch 0, loss 1.0050103664398193 + # Epoch 60, batch 0, loss 1.0033129453659058 + # Epoch 70, batch 0, loss 1.0014232397079468 + # Epoch 80, batch 0, loss 0.9977912306785583 + # Epoch 90, batch 0, loss 0.8274348974227905 + # Epoch 100, batch 0, loss 0.3368612825870514 + # Epoch 110, batch 0, loss 0.11323091387748718 + # Time taken: 17.591704960912466 for epoch in range(150): - for n, (x,y) in enumerate(train_pairs): + for n, (x, y) in enumerate(train_pairs): y_out = m(x) - loss = ((y_out - y)**2).mean() * 100.0 + loss = ((y_out - y) ** 2).mean() * 100.0 if n % 10 == 0 and epoch % 10 == 0: print(f"Epoch {epoch}, batch {n}, loss {loss.item()}") loss.backward() @@ -276,7 +296,8 @@ def _test_knowledge_base_lookup(): optimizer.zero_grad() stop = timeit.default_timer() - print('Time taken: ', stop - start) + print("Time taken: ", stop - start) + def _test_knowledge_base_lookup_autocast(): K = 16 @@ -294,14 +315,18 @@ def _test_knowledge_base_lookup_autocast(): x.requires_grad = True y = m(x) assert y.shape == x.shape - y.sum().backward() # make sure backward doesn't crash.. + y.sum().backward() # make sure backward doesn't crash.. print("y = ", y) print("x.grad = ", x.grad) print("knowlege_base.grad norm = ", knowledge_base.grad.norm()) - device = torch.device('cuda') - train_pairs = [ (torch.randn(B, T, E, device=device), torch.randn(B, T, E, device=device)) for _ in range(10) ] + device = torch.device("cuda") + train_pairs = [ + (torch.randn(B, T, E, device=device), torch.randn(B, T, E, device=device)) + for _ in range(10) + ] from optim import Eve + optimizer = Eve(m.parameters(), lr=0.005, eps=1.0e-04) m = m.to(device) @@ -309,12 +334,11 @@ def _test_knowledge_base_lookup_autocast(): start = timeit.default_timer() - for epoch in range(150): - for n, (x,y) in enumerate(train_pairs): + for n, (x, y) in enumerate(train_pairs): y_out = m(x) with torch.cuda.amp.autocast(enabled=True): - loss = ((y_out - y)**2).mean() * 100.0 + loss = ((y_out - y) ** 2).mean() * 100.0 if n % 10 == 0 and epoch % 10 == 0: print(f"Epoch {epoch}, batch {n}, loss {loss.item()}") scaler.scale(loss).backward() @@ -323,10 +347,9 @@ def _test_knowledge_base_lookup_autocast(): optimizer.zero_grad() stop = timeit.default_timer() - print('Time taken: ', stop - start) + print("Time taken: ", stop - start) - -if __name__ == '__main__': +if __name__ == "__main__": _test_knowledge_base_lookup() _test_knowledge_base_lookup_autocast() diff --git a/egs/librispeech/ASR/pruned2_knowledge/scaling.py b/egs/librispeech/ASR/pruned2_knowledge/scaling.py index f726c2583..527c735eb 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/scaling.py +++ b/egs/librispeech/ASR/pruned2_knowledge/scaling.py @@ -18,11 +18,11 @@ import collections from itertools import repeat from typing import Optional, Tuple -from torch.cuda.amp import custom_fwd, custom_bwd import torch import torch.nn as nn from torch import Tensor +from torch.cuda.amp import custom_bwd, custom_fwd def _ntuple(n): @@ -79,9 +79,7 @@ class ActivationBalancerFunction(torch.autograd.Function): below_threshold = mean_abs < min_abs above_threshold = mean_abs > max_abs - ctx.save_for_backward( - factor, xgt0, below_threshold, above_threshold - ) + ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold) ctx.max_factor = max_factor ctx.sum_dims = sum_dims return x @@ -149,8 +147,7 @@ class BasicNorm(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: assert x.shape[self.channel_dim] == self.num_channels scales = ( - torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) - + self.eps.exp() + torch.mean(x**2, dim=self.channel_dim, keepdim=True) + self.eps.exp() ) ** -0.5 return x * scales @@ -182,11 +179,7 @@ class ScaledLinear(nn.Linear): """ def __init__( - self, - *args, - initial_scale: float = 1.0, - initial_speed: float = 1.0, - **kwargs + self, *args, initial_scale: float = 1.0, initial_speed: float = 1.0, **kwargs ): super(ScaledLinear, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() @@ -202,12 +195,12 @@ class ScaledLinear(nn.Linear): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -218,19 +211,13 @@ class ScaledLinear(nn.Linear): return None if self.bias is None else self.bias * self.bias_scale.exp() def forward(self, input: Tensor) -> Tensor: - return torch.nn.functional.linear( - input, self.get_weight(), self.get_bias() - ) + return torch.nn.functional.linear(input, self.get_weight(), self.get_bias()) class ScaledConv1d(nn.Conv1d): # See docs for ScaledLinear def __init__( - self, - *args, - initial_scale: float = 1.0, - initial_speed: float = 1.0, - **kwargs + self, *args, initial_scale: float = 1.0, initial_speed: float = 1.0, **kwargs ): super(ScaledConv1d, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() @@ -245,12 +232,12 @@ class ScaledConv1d(nn.Conv1d): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -290,11 +277,7 @@ class ScaledConv1d(nn.Conv1d): class ScaledConv2d(nn.Conv2d): # See docs for ScaledLinear def __init__( - self, - *args, - initial_scale: float = 1.0, - initial_speed: float = 1.0, - **kwargs + self, *args, initial_scale: float = 1.0, initial_speed: float = 1.0, **kwargs ): super(ScaledConv2d, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() @@ -309,12 +292,12 @@ class ScaledConv2d(nn.Conv2d): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -653,9 +636,7 @@ def _test_activation_balancer_sign(): def _test_activation_balancer_magnitude(): magnitudes = torch.arange(0, 1, 0.01) N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze( - -1 - ) + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) x = x.detach() x.requires_grad = True m = ActivationBalancer( @@ -685,8 +666,8 @@ def _test_basic_norm(): y = m(x) assert y.shape == x.shape - x_rms = (x ** 2).mean().sqrt() - y_rms = (y ** 2).mean().sqrt() + x_rms = (x**2).mean().sqrt() + y_rms = (y**2).mean().sqrt() print("x rms = ", x_rms) print("y rms = ", y_rms) assert y_rms < x_rms diff --git a/egs/librispeech/ASR/pruned2_knowledge/scaling_tmp.py b/egs/librispeech/ASR/pruned2_knowledge/scaling_tmp.py index 6293e081a..3f21133a0 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/scaling_tmp.py +++ b/egs/librispeech/ASR/pruned2_knowledge/scaling_tmp.py @@ -15,21 +15,23 @@ # limitations under the License. +from typing import Optional, Tuple + import torch import torch.nn as nn from torch import Tensor -from typing import Tuple, Optional - -def _activation_balancer_loss(mean_pos: Tensor, - mean_neg: Tensor, - min_positive: float, # e.g. 0.05 - max_positive: float, # e.g. 0.95 - max_factor: float, # e.g. 0.01 - min_abs: float, # e.g. 0.2 - max_abs: float, # e.g. 100.0 - eps: float = 1.0e-10): +def _activation_balancer_loss( + mean_pos: Tensor, + mean_neg: Tensor, + min_positive: float, # e.g. 0.05 + max_positive: float, # e.g. 0.95 + max_factor: float, # e.g. 0.01 + min_abs: float, # e.g. 0.2 + max_abs: float, # e.g. 100.0 + eps: float = 1.0e-10, +): """ Returns a loss-function for the ActivationBalancer module. This loss function is not exposed to the user but is used internally, and eventually @@ -50,28 +52,32 @@ def _activation_balancer_loss(mean_pos: Tensor, """ loss_parts = [] - x_mean = mean_positive - mean_negative - x_mean_abs = (mean_positive + mean_negative + eps).detach() - x_rel_mean= x_mean / x_mean_abs + x_mean = mean_pos - mean_neg + x_mean_abs = (mean_pos + mean_neg + eps).detach() + x_rel_mean = x_mean / x_mean_abs if min_positive != 0.0: # e.g. x_mean_floor = -0.95 + 0.05 = -0.9 - x_rel_mean_floor = (-(1-min_positive) + min_positive) - min_positive_loss = (x_rel_mean_floor - x_rel_mean).relu().sum() * (1.0 / (2*min_positive)) + x_rel_mean_floor = -(1 - min_positive) + min_positive + min_positive_loss = (x_rel_mean_floor - x_rel_mean).relu().sum() * ( + 1.0 / (2 * min_positive) + ) # this part of the loss would be 1.0 * num_channels if all these constraints were # 100% violated. loss_parts.append(min_positive_loss) if max_positive != 1.0: # e.g. x_mean_floor = -0.05 + 0.95 = 0.8 - x_rel_mean_ceil = - (1.0-max_positive) + max_positive - max_positive_loss = (x_rel_mean - x_rel_mean_ceil).relu().sum() * (1.0 / (1 - x_rel_mean_ceil)) + x_rel_mean_ceil = -(1.0 - max_positive) + max_positive + max_positive_loss = (x_rel_mean - x_rel_mean_ceil).relu().sum() * ( + 1.0 / (1 - x_rel_mean_ceil) + ) # this part of the loss would be 1.0 * num_channels if all these constraints were # 100% violated. loss_parts.append(max_positive_loss) if min_abs != 0.0: - min_abs_loss = min_abs - x_mean_abs).relu().sum() / min_abs + min_abs_loss = (min_abs - x_mean_abs).relu().sum() / min_abs # this part of the loss would be 1.0 * num_channels if all these constraints were # 100% violated. loss_parts.append(min_abs_loss) @@ -82,43 +88,53 @@ def _activation_balancer_loss(mean_pos: Tensor, # 100% violated. loss_parts.append(max_abs_loss) - # the min_positive and 1 - max_positive are "ballast" added to the denom = mean_pos + mean_neg + (min_positive + (1 - max_positive)) - num + # num if min_positive != 0.0: - - + pass class ActivationBalancerFunction(torch.autograd.Function): @staticmethod - def forward(ctx, x: Tensor, - channel_dim: int, - min_positive: float, # e.g. 0.05 - max_positive: float, # e.g. 0.95 - max_factor: float, # e.g. 0.01 - min_abs: float, # e.g. 0.2 - max_abs: float, # e.g. 100.0 + def forward( + ctx, + x: Tensor, + channel_dim: int, + min_positive: float, # e.g. 0.05 + max_positive: float, # e.g. 0.95 + max_factor: float, # e.g. 0.01 + min_abs: float, # e.g. 0.2 + max_abs: float, # e.g. 100.0 ) -> Tensor: if x.requires_grad: if channel_dim < 0: channel_dim += x.ndim sum_dims = [d for d in range(x.ndim) if d != channel_dim] xgt0 = x > 0 - proportion_positive = torch.mean(xgt0.to(x.dtype), dim=sum_dims, keepdim=True) - factor1 = ((min_positive - proportion_positive).relu() * (max_factor / min_positive) - if min_positive != 0.0 else 0.0) - factor2 = ((proportion_positive - max_positive).relu() * (max_factor / (max_positive - 1.0)) - if max_positive != 1.0 else 0.0) + proportion_positive = torch.mean( + xgt0.to(x.dtype), dim=sum_dims, keepdim=True + ) + factor1 = ( + (min_positive - proportion_positive).relu() + * (max_factor / min_positive) + if min_positive != 0.0 + else 0.0 + ) + factor2 = ( + (proportion_positive - max_positive).relu() + * (max_factor / (max_positive - 1.0)) + if max_positive != 1.0 + else 0.0 + ) factor = factor1 + factor2 if isinstance(factor, float): factor = torch.zeros_like(proportion_positive) mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True) - below_threshold = (mean_abs < min_abs) - above_threshold = (mean_abs > max_abs) + below_threshold = mean_abs < min_abs + above_threshold = mean_abs > max_abs ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold) ctx.max_factor = max_factor @@ -126,11 +142,16 @@ class ActivationBalancerFunction(torch.autograd.Function): return x @staticmethod - def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None, None]: + def backward( + ctx, x_grad: Tensor + ) -> Tuple[Tensor, None, None, None, None, None, None]: factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors dtype = x_grad.dtype - scale_factor = ((below_threshold.to(dtype) - above_threshold.to(dtype)) * - (xgt0.to(dtype) - 0.5) * (ctx.max_factor * 2.0)) + scale_factor = ( + (below_threshold.to(dtype) - above_threshold.to(dtype)) + * (xgt0.to(dtype) - 0.5) + * (ctx.max_factor * 2.0) + ) neg_delta_grad = x_grad.abs() * (factor + scale_factor) return x_grad - neg_delta_grad, None, None, None, None, None, None @@ -163,29 +184,30 @@ class BasicNorm(torch.nn.Module): learn_eps: if true, we learn epsilon; if false, we keep it at the initial value. """ - def __init__(self, - num_channels: int, - channel_dim: int = -1, # CAUTION: see documentation. - eps: float = 0.25, - learn_eps: bool = True) -> None: + + def __init__( + self, + num_channels: int, + channel_dim: int = -1, # CAUTION: see documentation. + eps: float = 0.25, + learn_eps: bool = True, + ) -> None: super(BasicNorm, self).__init__() self.num_channels = num_channels self.channel_dim = channel_dim if learn_eps: self.eps = nn.Parameter(torch.tensor(eps).log().detach()) else: - self.register_buffer('eps', torch.tensor(eps).log().detach()) - + self.register_buffer("eps", torch.tensor(eps).log().detach()) def forward(self, x: Tensor) -> Tensor: assert x.shape[self.channel_dim] == self.num_channels - scales = (torch.mean(x**2, dim=self.channel_dim, keepdim=True) + - self.eps.exp()) ** -0.5 + scales = ( + torch.mean(x**2, dim=self.channel_dim, keepdim=True) + self.eps.exp() + ) ** -0.5 return x * scales - - class ScaledLinear(nn.Linear): """ A modified version of nn.Linear where the parameters are scaled before @@ -207,27 +229,26 @@ class ScaledLinear(nn.Linear): inherited from nn.Linear. For modules with small fan-in, this may be larger than optimal. """ - def __init__(self, *args, - initial_scale: float = 1.0, - **kwargs): + + def __init__(self, *args, initial_scale: float = 1.0, **kwargs): super(ScaledLinear, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() self.weight_scale = nn.Parameter(initial_scale.clone().detach()) if self.bias is not None: self.bias_scale = nn.Parameter(initial_scale.clone().detach()) else: - self.register_parameter('bias_scale', None) + self.register_parameter("bias_scale", None) self._reset_parameters() # Overrides the reset_parameters in nn.Linear def _reset_parameters(self): std = 0.01 - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() if self.bias is not None: @@ -237,56 +258,67 @@ class ScaledLinear(nn.Linear): return self.weight * self.weight_scale.exp() def get_bias(self): - return (None if self.bias is None else - self.bias * self.bias_scale.exp()) + return None if self.bias is None else self.bias * self.bias_scale.exp() def forward(self, input: Tensor) -> Tensor: - return torch.nn.functional.linear(input, self.get_weight(), - self.get_bias()) + return torch.nn.functional.linear(input, self.get_weight(), self.get_bias()) class ScaledConv1d(nn.Conv1d): - def __init__(self, *args, - initial_scale=1.0, **kwargs): + def __init__(self, *args, initial_scale=1.0, **kwargs): super(ScaledConv1d, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() self.weight_scale = nn.Parameter(initial_scale.clone().detach()) if self.bias is not None: self.bias_scale = nn.Parameter(initial_scale.clone().detach()) else: - self.register_parameter('bias_scale', None) + self.register_parameter("bias_scale", None) self._reset_parameters() # Overrides the reset_parameters in base class def _reset_parameters(self): std = 0.01 - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() if self.bias is not None: self.bias_scale += torch.tensor(scale / std).log() - def get_weight(self): return self.weight * self.weight_scale.exp() def get_bias(self): - return (None if self.bias is None else - self.bias * self.bias_scale.exp()) + return None if self.bias is None else self.bias * self.bias_scale.exp() def forward(self, input: Tensor) -> Tensor: F = torch.nn.functional - if self.padding_mode != 'zeros': - return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), - self.get_weight(), self.get_bias(), self.stride, - _single(0), self.dilation, self.groups) - return F.conv1d(input, self.get_weight(), self.get_bias(), self.stride, - self.padding, self.dilation, self.groups) - + if self.padding_mode != "zeros": + return F.conv1d( + F.pad( + input, + self._reversed_padding_repeated_twice, + mode=self.padding_mode, + ), + self.get_weight(), + self.get_bias(), + self.stride, + _single(0), # noqa: F821 + self.dilation, + self.groups, + ) + return F.conv1d( + input, + self.get_weight(), + self.get_bias(), + self.stride, + self.padding, + self.dilation, + self.groups, + ) class ScaledConv2d(nn.Conv2d): @@ -297,45 +329,58 @@ class ScaledConv2d(nn.Conv2d): if self.bias is not None: self.bias_scale = nn.Parameter(initial_scale.clone().detach()) else: - self.register_parameter('bias_scale', None) + self.register_parameter("bias_scale", None) self._reset_parameters() # Overrides the reset_parameters in base class def _reset_parameters(self): std = 0.01 - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() if self.bias is not None: self.bias_scale += torch.tensor(scale / std).log() - def get_weight(self): return self.weight * self.weight_scale.exp() def get_bias(self): - return (None if self.bias is None else - self.bias * self.bias_scale.exp()) + return None if self.bias is None else self.bias * self.bias_scale.exp() def _conv_forward(self, input, weight): F = torch.nn.functional - if self.padding_mode != 'zeros': - return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), - weight, self.get_bias(), self.stride, - _pair(0), self.dilation, self.groups) - return F.conv2d(input, weight, self.get_bias(), self.stride, - self.padding, self.dilation, self.groups) + if self.padding_mode != "zeros": + return F.conv2d( + F.pad( + input, + self._reversed_padding_repeated_twice, + mode=self.padding_mode, + ), + weight, + self.get_bias(), + self.stride, + _pair(0), # noqa: F821 + self.dilation, + self.groups, + ) + return F.conv2d( + input, + weight, + self.get_bias(), + self.stride, + self.padding, + self.dilation, + self.groups, + ) def forward(self, input: Tensor) -> Tensor: return self._conv_forward(input, self.get_weight()) - - class ActivationBalancer(torch.nn.Module): """ Modifies the backpropped derivatives of a function to try to encourage, for @@ -364,12 +409,16 @@ class ActivationBalancer(torch.nn.Module): we allow, before we start to modify the derivatives to prevent this. """ - def __init__(self, channel_dim: int, - min_positive: float = 0.05, - max_positive: float = 0.95, - max_factor: float = 0.01, - min_abs: float = 0.2, - max_abs: float = 100.0): + + def __init__( + self, + channel_dim: int, + min_positive: float = 0.05, + max_positive: float = 0.95, + max_factor: float = 0.01, + min_abs: float = 0.2, + max_abs: float = 100.0, + ): super(ActivationBalancer, self).__init__() self.channel_dim = channel_dim self.min_positive = min_positive @@ -379,10 +428,15 @@ class ActivationBalancer(torch.nn.Module): self.max_abs = max_abs def forward(self, x: Tensor) -> Tensor: - return ActivationBalancerFunction.apply(x, self.channel_dim, - self.min_positive, self.max_positive, - self.max_factor, self.min_abs, - self.max_abs) + return ActivationBalancerFunction.apply( + x, + self.channel_dim, + self.min_positive, + self.max_positive, + self.max_factor, + self.min_abs, + self.max_abs, + ) class DoubleSwishFunction(torch.autograd.Function): @@ -400,6 +454,7 @@ class DoubleSwishFunction(torch.autograd.Function): = double_swish(x) * (1-s(x)) + s(x) ... so we just need to remember s(x) but not x itself. """ + @staticmethod def forward(ctx, x: Tensor) -> Tensor: x = x.detach() @@ -411,18 +466,17 @@ class DoubleSwishFunction(torch.autograd.Function): @staticmethod def backward(ctx, y_grad: Tensor) -> Tensor: s, y = ctx.saved_tensors - return (y * (1-s) + s) * y_grad + return (y * (1 - s) + s) * y_grad + class DoubleSwish(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: """Return double-swish activation function which is an approximation to Swish(Swish(x)), - that we approximate closely with x * sigmoid(x-1). + that we approximate closely with x * sigmoid(x-1). """ return DoubleSwishFunction.apply(x) - - class ScaledEmbedding(nn.Module): r"""A simple lookup table that stores embeddings of a fixed dictionary and size. @@ -491,8 +545,13 @@ class ScaledEmbedding(nn.Module): [ 0.0000, 0.0000, 0.0000], [-0.1655, 0.9897, 0.0635]]]) """ - __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', - 'scale_grad_by_freq', 'sparse'] + __constants__ = [ + "num_embeddings", + "embedding_dim", + "padding_idx", + "scale_grad_by_freq", + "sparse", + ] num_embeddings: int embedding_dim: int @@ -501,33 +560,40 @@ class ScaledEmbedding(nn.Module): weight: Tensor sparse: bool - def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, - scale_grad_by_freq: bool = False, - sparse: bool = False) -> None: + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + scale_grad_by_freq: bool = False, + sparse: bool = False, + ) -> None: super(ScaledEmbedding, self).__init__() self.num_embeddings = num_embeddings self.embedding_dim = embedding_dim if padding_idx is not None: if padding_idx > 0: - assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings' + assert ( + padding_idx < self.num_embeddings + ), "Padding_idx must be within num_embeddings" elif padding_idx < 0: - assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings' + assert ( + padding_idx >= -self.num_embeddings + ), "Padding_idx must be within num_embeddings" padding_idx = self.num_embeddings + padding_idx self.padding_idx = padding_idx self.scale_grad_by_freq = scale_grad_by_freq - self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() + self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() self.sparse = sparse self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim)) self.reset_parameters() - - def reset_parameters(self) -> None: std = 0.01 nn.init.normal_(self.weight, std=std) - nn.init.constant_(self.scale, torch.tensor(1.0/std).log()) + nn.init.constant_(self.scale, torch.tensor(1.0 / std).log()) if self.padding_idx is not None: with torch.no_grad(): @@ -537,24 +603,37 @@ class ScaledEmbedding(nn.Module): F = torch.nn.functional scale = self.scale.exp() if input.numel() < self.num_embeddings: - return F.embedding( - input, self.weight, self.padding_idx, - None, 2.0, # None, 2.0 relate to normalization - self.scale_grad_by_freq, self.sparse) * scale + return ( + F.embedding( + input, + self.weight, + self.padding_idx, + None, + 2.0, # None, 2.0 relate to normalization + self.scale_grad_by_freq, + self.sparse, + ) + * scale + ) else: return F.embedding( - input, self.weight * scale, self.padding_idx, - None, 2.0, # None, 2.0 relates to normalization - self.scale_grad_by_freq, self.sparse) + input, + self.weight * scale, + self.padding_idx, + None, + 2.0, # None, 2.0 relates to normalization + self.scale_grad_by_freq, + self.sparse, + ) def extra_repr(self) -> str: - s = '{num_embeddings}, {embedding_dim}, scale={scale}' + s = "{num_embeddings}, {embedding_dim}, scale={scale}" if self.padding_idx is not None: - s += ', padding_idx={padding_idx}' + s += ", padding_idx={padding_idx}" if self.scale_grad_by_freq is not False: - s += ', scale_grad_by_freq={scale_grad_by_freq}' + s += ", scale_grad_by_freq={scale_grad_by_freq}" if self.sparse is not False: - s += ', sparse=True' + s += ", sparse=True" return s.format(**self.__dict__) @@ -565,8 +644,13 @@ def _test_activation_balancer_sign(): x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1)) x = x.detach() x.requires_grad = True - m = ActivationBalancer(channel_dim=0, min_positive=0.05, max_positive=0.95, - max_factor=0.2, min_abs=0.0) + m = ActivationBalancer( + channel_dim=0, + min_positive=0.05, + max_positive=0.95, + max_factor=0.2, + min_abs=0.0, + ) y_grad = torch.sign(torch.randn(probs.numel(), N)) @@ -576,17 +660,22 @@ def _test_activation_balancer_sign(): print("_test_activation_balancer_sign: y grad = ", y_grad) print("_test_activation_balancer_sign: x grad = ", x.grad) + def _test_activation_balancer_magnitude(): channel_dim = 0 magnitudes = torch.arange(0, 1, 0.01) N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) x = x.detach() x.requires_grad = True - m = ActivationBalancer(channel_dim=0, - min_positive=0.0, max_positive=1.0, - max_factor=0.2, - min_abs=0.2, max_abs=0.8) + m = ActivationBalancer( + channel_dim=0, + min_positive=0.0, + max_positive=1.0, + max_factor=0.2, + min_abs=0.2, + max_abs=0.8, + ) y_grad = torch.sign(torch.randn(magnitudes.numel(), N)) @@ -621,7 +710,7 @@ def _test_double_swish_deriv(): torch.autograd.gradcheck(m, x) -if __name__ == '__main__': +if __name__ == "__main__": _test_activation_balancer_sign() _test_activation_balancer_magnitude() _test_basic_norm() diff --git a/egs/librispeech/ASR/pruned2_knowledge/train.py b/egs/librispeech/ASR/pruned2_knowledge/train.py index 2f6840166..c322abaf8 100755 --- a/egs/librispeech/ASR/pruned2_knowledge/train.py +++ b/egs/librispeech/ASR/pruned2_knowledge/train.py @@ -78,9 +78,7 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def get_parser(): @@ -179,8 +177,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -203,8 +200,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -554,23 +550,16 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -733,9 +722,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -835,7 +822,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py index 2d5724d30..891719f3d 100755 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py @@ -204,8 +204,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -272,9 +271,7 @@ def decode_one_batch( value=LOG_EPS, ) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -289,10 +286,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -415,9 +409,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -450,8 +442,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -494,9 +485,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -528,9 +517,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -557,9 +546,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/emformer.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/emformer.py index 318cd5094..008f40fb1 100644 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/emformer.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/emformer.py @@ -272,13 +272,9 @@ class Emformer(EncoderInterface): # Caution: We assume the subsampling factor is 4! x_lens = (((x_lens - 1) >> 1) - 1) >> 1 - emformer_out, emformer_out_lens, states = self.model.infer( - x, x_lens, states - ) + emformer_out, emformer_out_lens, states = self.model.infer(x, x_lens, states) - if x.size(1) != ( - self.model.segment_length + self.model.right_context_length - ): + if x.size(1) != (self.model.segment_length + self.model.right_context_length): raise ValueError( "Incorrect input shape." f"{x.size(1)} vs {self.model.segment_length} + " diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py index 2375f5001..047a1d476 100755 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py @@ -133,8 +133,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) add_model_arguments(parser) @@ -170,9 +169,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -199,9 +198,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -273,9 +272,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/model.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/model.py index 2f019bcdb..ed6848879 100644 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/model.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/model.py @@ -122,9 +122,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py index fed814f19..69e74cc57 100755 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py @@ -209,8 +209,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -233,8 +232,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -566,11 +564,7 @@ def compute_loss( function enables autograd during computation; when it is False, it disables autograd. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -599,9 +593,7 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -782,9 +774,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -908,8 +898,7 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py index 7af9cc3d7..830b37cfb 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py @@ -509,9 +509,9 @@ def greedy_search( y = logits.argmax().item() if y not in (blank_id, unk_id): hyp.append(y) - decoder_input = torch.tensor( - [hyp[-context_size:]], device=device - ).reshape(1, context_size) + decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape( + 1, context_size + ) decoder_out = model.decoder(decoder_input, need_pad=False) @@ -670,9 +670,7 @@ class HypothesisList(object): if use_max: old_hyp.log_prob = max(old_hyp.log_prob, hyp.log_prob) else: - torch.logaddexp( - old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob - ) + torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob) else: self._data[key] = hyp @@ -688,9 +686,7 @@ class HypothesisList(object): Return the hypothesis that has the largest `log_prob`. """ if length_norm: - return max( - self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys) - ) + return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)) else: return max(self._data.values(), key=lambda hyp: hyp.log_prob) @@ -892,9 +888,7 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -1088,9 +1082,7 @@ def beam_search( t = 0 B = HypothesisList() - B.add( - Hypothesis(ys=[blank_id] * context_size, log_prob=0.0), use_max=use_max - ) + B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0), use_max=use_max) max_sym_per_utt = 20000 @@ -1130,9 +1122,7 @@ def beam_search( cached_key += f"-t-{t}" if cached_key not in joint_cache: - logits = model.joiner( - current_encoder_out, decoder_out.unsqueeze(1) - ) + logits = model.joiner(current_encoder_out, decoder_out.unsqueeze(1)) # TODO(fangjun): Scale the blank posterior diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py index 7b6338948..12bd7f9bb 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py @@ -128,11 +128,7 @@ from beam_search import ( ) from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, @@ -269,8 +265,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -383,9 +378,7 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if ( @@ -450,10 +443,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -584,9 +574,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -619,8 +607,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -678,9 +665,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -718,8 +703,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -757,9 +741,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py index 386248554..e522943c0 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py @@ -75,9 +75,7 @@ class DecodeStream(object): # encoder.streaming_forward self.done_frames: int = 0 - self.pad_length = ( - params.right_context + 2 - ) * params.subsampling_factor + 3 + self.pad_length = (params.right_context + 2) * params.subsampling_factor + 3 if params.decoding_method == "greedy_search": self.hyp = [params.blank_id] * params.context_size @@ -91,13 +89,11 @@ class DecodeStream(object): ) elif params.decoding_method == "fast_beam_search": # The rnnt_decoding_stream for fast_beam_search. - self.rnnt_decoding_stream: k2.RnntDecodingStream = ( - k2.RnntDecodingStream(decoding_graph) + self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream( + decoding_graph ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") @property def done(self) -> bool: @@ -126,13 +122,10 @@ class DecodeStream(object): """Consume chunk_size frames of features""" chunk_length = chunk_size + self.pad_length - ret_length = min( - self.num_frames - self.num_processed_frames, chunk_length - ) + ret_length = min(self.num_frames - self.num_processed_frames, chunk_length) ret_features = self.features[ - self.num_processed_frames : self.num_processed_frames # noqa - + ret_length + self.num_processed_frames : self.num_processed_frames + ret_length # noqa ] self.num_processed_frames += chunk_size diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py index f4355e8a0..72593173c 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py @@ -92,9 +92,7 @@ class Decoder(nn.Module): if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad( - embedding_out, pad=(self.context_size - 1, 0) - ) + embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/export.py b/egs/librispeech/ASR/pruned_transducer_stateless/export.py index b5a151878..be45536d8 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/export.py @@ -105,8 +105,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -192,9 +191,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/model.py b/egs/librispeech/ASR/pruned_transducer_stateless/model.py index 73b651b3f..2cca7fa27 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/model.py @@ -130,9 +130,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py index eb95827af..6e91e0501 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py @@ -168,8 +168,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -222,8 +221,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -292,9 +290,7 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -381,9 +377,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_beam_search.py index dcf6dc42f..9e09200a1 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_beam_search.py @@ -166,14 +166,10 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk( - num_active_paths - ) + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(num_active_paths) with warnings.catch_warnings(): warnings.simplefilter("ignore") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py index d2cae4f9f..ce8e2f348 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py @@ -51,11 +51,7 @@ from streaming_beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import ( AttributeDict, setup_logger, @@ -162,8 +158,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -269,9 +264,7 @@ def decode_one_chunk( ) if params.decoding_method == "greedy_search": - greedy_search( - model=model, encoder_out=encoder_out, streams=decode_streams - ) + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( @@ -291,9 +284,7 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -349,9 +340,7 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. decode_streams = [] - initial_states = model.encoder.get_init_state( - params.left_context, device=device - ) + initial_states = model.encoder.get_init_state(params.left_context, device=device) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. decode_stream = DecodeStream( @@ -422,9 +411,7 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") return {key: decode_results} @@ -460,8 +447,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -533,8 +519,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/train.py b/egs/librispeech/ASR/pruned_transducer_stateless/train.py index 399b11a29..7861df874 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/train.py @@ -203,8 +203,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -227,8 +226,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -562,9 +560,7 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -584,9 +580,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -777,9 +771,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -897,8 +889,7 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}" ) return False @@ -956,9 +947,7 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index b7c2010f7..5e9428b60 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -580,9 +580,9 @@ def greedy_search( if y not in (blank_id, unk_id): hyp.append(y) timestamp.append(t) - decoder_input = torch.tensor( - [hyp[-context_size:]], device=device - ).reshape(1, context_size) + decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape( + 1, context_size + ) decoder_out = model.decoder(decoder_input, need_pad=False) decoder_out = model.joiner.decoder_proj(decoder_out) @@ -775,9 +775,7 @@ class HypothesisList(object): key = hyp.key if key in self: old_hyp = self._data[key] # shallow copy - torch.logaddexp( - old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob - ) + torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob) else: self._data[key] = hyp @@ -793,9 +791,7 @@ class HypothesisList(object): Return the hypothesis that has the largest `log_prob`. """ if length_norm: - return max( - self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys) - ) + return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)) else: return max(self._data.values(), key=lambda hyp: hyp.log_prob) @@ -990,9 +986,7 @@ def modified_beam_search( logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - log_probs = (logits / temperature).log_softmax( - dim=-1 - ) # (num_hyps, vocab_size) + log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) log_probs.add_(ys_log_probs) @@ -1004,9 +998,7 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -1676,9 +1668,7 @@ def fast_beam_search_with_nbest_rnn_rescoring( for rnn_scale in rnn_lm_scale_list: key = f"ngram_lm_scale_{n_scale}_rnn_lm_scale_{rnn_scale}" tot_scores = ( - am_scores.values - + n_scale * ngram_lm_scores - + rnn_scale * rnn_lm_scores + am_scores.values + n_scale * ngram_lm_scores + rnn_scale * rnn_lm_scores ) ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) max_indexes = ragged_tot_scores.argmax() @@ -1804,9 +1794,7 @@ def modified_beam_search_ngram_rescoring( logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - log_probs = (logits / temperature).log_softmax( - dim=-1 - ) # (num_hyps, vocab_size) + log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) log_probs.add_(ys_log_probs) vocab_size = log_probs.size(-1) @@ -1816,9 +1804,7 @@ def modified_beam_search_ngram_rescoring( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -1841,9 +1827,7 @@ def modified_beam_search_ngram_rescoring( state_cost = hyp.state_cost # We only keep AM scores in new_hyp.log_prob - new_log_prob = ( - topk_log_probs[k] - hyp.state_cost.lm_score * lm_scale - ) + new_log_prob = topk_log_probs[k] - hyp.state_cost.lm_score * lm_scale new_hyp = Hypothesis( ys=new_ys, log_prob=new_log_prob, state_cost=state_cost @@ -1995,9 +1979,7 @@ def modified_beam_search_rnnlm_shallow_fusion( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) """ for all hyps with a non-blank new token, score this token. It is a little confusing here because this for-loop @@ -2032,10 +2014,7 @@ def modified_beam_search_rnnlm_shallow_fusion( # forward RNNLM to get new states and scores if len(token_list) != 0: tokens_to_score = ( - torch.tensor(token_list) - .to(torch.int64) - .to(device) - .reshape(-1, 1) + torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1) ) hs = torch.cat(hs, dim=1).to(device) @@ -2067,9 +2046,7 @@ def modified_beam_search_rnnlm_shallow_fusion( ys.append(new_token) new_timestamp.append(t) - hyp_log_prob += ( - lm_score[new_token] * lm_scale - ) # add the lm score + hyp_log_prob += lm_score[new_token] * lm_scale # add the lm score lm_score = scores[count] state = ( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index bc273d33b..f94ffef59 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -214,10 +214,7 @@ class Conformer(EncoderInterface): NOTE: the returned tensors are on the given device. """ - if ( - len(self._init_state) == 2 - and self._init_state[0].size(1) == left_context - ): + if len(self._init_state) == 2 and self._init_state[0].size(1) == left_context: # Note: It is OK to share the init state as it is # not going to be modified by the model return self._init_state @@ -439,9 +436,7 @@ class ConformerEncoderLayer(nn.Module): self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), @@ -459,9 +454,7 @@ class ConformerEncoderLayer(nn.Module): ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), ) - self.conv_module = ConvolutionModule( - d_model, cnn_module_kernel, causal=causal - ) + self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal) self.norm_final = BasicNorm(d_model) @@ -527,9 +520,7 @@ class ConformerEncoderLayer(nn.Module): src = src + self.dropout(src_att) # convolution module - conv, _ = self.conv_module( - src, src_key_padding_mask=src_key_padding_mask - ) + conv, _ = self.conv_module(src, src_key_padding_mask=src_key_padding_mask) src = src + self.dropout(conv) # feed forward module @@ -785,9 +776,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() if is_jit_tracing(): @@ -811,9 +800,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x_size_1 * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -1127,9 +1114,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -1198,31 +1185,22 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) @@ -1264,23 +1242,15 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d - matrix_bd = torch.matmul( - q_with_bias_v, p - ) # (batch, head, time1, 2*time1-1) + matrix_bd = torch.matmul(q_with_bias_v, p) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd, left_context) - attn_output_weights = ( - matrix_ac + matrix_bd - ) # (batch, head, time1, time2) + attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) if not is_jit_tracing(): assert list(attn_output_weights.size()) == [ @@ -1322,21 +1292,17 @@ class RelPositionMultiheadAttention(nn.Module): ): if attn_mask.size(0) != 1: attn_mask = attn_mask.view(bsz, num_heads, tgt_len, src_len) - combined_mask = attn_mask | key_padding_mask.unsqueeze( - 1 - ).unsqueeze(2) + combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2) else: # attn_mask.shape == (1, tgt_len, src_len) - combined_mask = attn_mask.unsqueeze( - 0 - ) | key_padding_mask.unsqueeze(1).unsqueeze(2) + combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze( + 1 + ).unsqueeze(2) attn_output_weights = attn_output_weights.view( bsz, num_heads, tgt_len, src_len ) - attn_output_weights = attn_output_weights.masked_fill( - combined_mask, 0.0 - ) + attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0) attn_output_weights = attn_output_weights.view( bsz * num_heads, tgt_len, src_len ) @@ -1355,13 +1321,9 @@ class RelPositionMultiheadAttention(nn.Module): ] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -1498,16 +1460,12 @@ class ConvolutionModule(nn.Module): # manualy padding self.lorder zeros to the left x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) else: - assert ( - not self.training - ), "Cache should be None in training time" + assert not self.training, "Cache should be None in training time" assert cache.size(0) == self.lorder x = torch.cat([cache.permute(1, 2, 0), x], dim=2) if right_context > 0: cache = x.permute(2, 0, 1)[ - -(self.lorder + right_context) : ( # noqa - -right_context - ), + -(self.lorder + right_context) : (-right_context), # noqa ..., ] else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py index 979a0e02e..92138a5ea 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py @@ -132,11 +132,7 @@ from beam_search import ( ) from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, @@ -275,8 +271,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -397,9 +392,7 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -465,10 +458,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -608,9 +598,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -643,8 +631,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -700,9 +687,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -740,8 +725,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -779,9 +763,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py index ba91302ce..b59928103 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py @@ -107,15 +107,11 @@ class Decoder(nn.Module): # This is for exporting to PNNX via ONNX embedding_out = self.embedding(y) else: - embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze( - -1 - ) + embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1) if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad: - embedding_out = F.pad( - embedding_out, pad=(self.context_size - 1, 0) - ) + embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py index f1a8ea589..4f1170bbc 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py @@ -51,11 +51,7 @@ import sentencepiece as spm import torch from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import str2bool @@ -120,8 +116,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -173,8 +168,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -222,9 +216,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py index 6a9d08033..1954f4724 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py @@ -60,9 +60,7 @@ class Joiner(nn.Module): assert encoder_out.shape == decoder_out.shape if project_input: - logit = self.encoder_proj(encoder_out) + self.decoder_proj( - decoder_out - ) + logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) else: logit = encoder_out + decoder_out diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py index 417c391d9..272d06c37 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py @@ -66,9 +66,7 @@ class Transducer(nn.Module): self.decoder = decoder self.joiner = joiner - self.simple_am_proj = ScaledLinear( - encoder_dim, vocab_size, initial_speed=0.5 - ) + self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) def forward( @@ -152,9 +150,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py index 041a81f45..2d7f557ad 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py @@ -72,17 +72,11 @@ class Eve(Optimizer): if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if not 0.0 <= betas[0] < 1.0: - raise ValueError( - "Invalid beta parameter at index 0: {}".format(betas[0]) - ) + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) if not 0.0 <= betas[1] < 1.0: - raise ValueError( - "Invalid beta parameter at index 1: {}".format(betas[1]) - ) + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) if not 0 <= weight_decay <= 0.1: - raise ValueError( - "Invalid weight_decay value: {}".format(weight_decay) - ) + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) if not 0 < target_rms <= 10.0: raise ValueError("Invalid target_rms value: {}".format(target_rms)) defaults = dict( @@ -118,9 +112,7 @@ class Eve(Optimizer): # Perform optimization step grad = p.grad if grad.is_sparse: - raise RuntimeError( - "AdamW does not support sparse gradients" - ) + raise RuntimeError("AdamW does not support sparse gradients") state = self.state[p] @@ -147,7 +139,7 @@ class Eve(Optimizer): # Decay the first and second moment running average coefficient exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_( + denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_( group["eps"] ) @@ -158,9 +150,7 @@ class Eve(Optimizer): if p.numel() > 1: # avoid applying this weight-decay on "scaling factors" # (which are scalar). - is_above_target_rms = p.norm() > ( - target_rms * (p.numel() ** 0.5) - ) + is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5)) p.mul_(1 - (weight_decay * is_above_target_rms)) p.addcdiv_(exp_avg, denom, value=-step_size) @@ -180,18 +170,14 @@ class LRScheduler(object): def __init__(self, optimizer: Optimizer, verbose: bool = False): # Attach optimizer if not isinstance(optimizer, Optimizer): - raise TypeError( - "{} is not an Optimizer".format(type(optimizer).__name__) - ) + raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) self.optimizer = optimizer self.verbose = verbose for group in optimizer.param_groups: group.setdefault("initial_lr", group["lr"]) - self.base_lrs = [ - group["initial_lr"] for group in optimizer.param_groups - ] + self.base_lrs = [group["initial_lr"] for group in optimizer.param_groups] self.epoch = 0 self.batch = 0 @@ -299,10 +285,9 @@ class Eden(LRScheduler): def get_lr(self): factor = ( - (self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2 + (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 ) ** -0.25 * ( - ((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2) - ** -0.25 + ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25 ) return [x * factor for x in self.base_lrs] diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py index f52cb22ab..e5b5aeba5 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py @@ -168,8 +168,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -223,8 +222,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -293,9 +291,7 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -382,9 +378,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index 8c572a9ef..c802ecf89 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -89,9 +89,7 @@ class ActivationBalancerFunction(torch.autograd.Function): below_threshold = mean_abs < min_abs above_threshold = mean_abs > max_abs - ctx.save_for_backward( - factor, xgt0, below_threshold, above_threshold - ) + ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold) ctx.max_factor = max_factor ctx.sum_dims = sum_dims return x @@ -137,7 +135,7 @@ class GradientFilterFunction(torch.autograd.Function): eps = 1.0e-20 dim = ctx.batch_dim norm_dims = [d for d in range(x_grad.ndim) if d != dim] - norm_of_batch = (x_grad ** 2).mean(dim=norm_dims, keepdim=True).sqrt() + norm_of_batch = (x_grad**2).mean(dim=norm_dims, keepdim=True).sqrt() median_norm = norm_of_batch.median() cutoff = median_norm * ctx.threshold @@ -229,8 +227,7 @@ class BasicNorm(torch.nn.Module): if not is_jit_tracing(): assert x.shape[self.channel_dim] == self.num_channels scales = ( - torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) - + self.eps.exp() + torch.mean(x**2, dim=self.channel_dim, keepdim=True) + self.eps.exp() ) ** -0.5 return x * scales @@ -282,12 +279,12 @@ class ScaledLinear(nn.Linear): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -301,9 +298,7 @@ class ScaledLinear(nn.Linear): return self.bias * self.bias_scale.exp() def forward(self, input: Tensor) -> Tensor: - return torch.nn.functional.linear( - input, self.get_weight(), self.get_bias() - ) + return torch.nn.functional.linear(input, self.get_weight(), self.get_bias()) class ScaledConv1d(nn.Conv1d): @@ -331,12 +326,12 @@ class ScaledConv1d(nn.Conv1d): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -400,12 +395,12 @@ class ScaledConv2d(nn.Conv2d): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -476,9 +471,7 @@ class ScaledLSTM(nn.LSTM): setattr(self, scale_name, param) self._scales.append(param) - self.grad_filter = GradientFilter( - batch_dim=1, threshold=grad_norm_threshold - ) + self.grad_filter = GradientFilter(batch_dim=1, threshold=grad_norm_threshold) self._reset_parameters( initial_speed @@ -486,8 +479,8 @@ class ScaledLSTM(nn.LSTM): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3 ** 0.5) * std - scale = self.hidden_size ** -0.5 + a = (3**0.5) * std + scale = self.hidden_size**-0.5 v = scale / std for idx, name in enumerate(self._flat_weights_names): if "weight" in name: @@ -559,15 +552,11 @@ class ScaledLSTM(nn.LSTM): """Get scaled weights, and resets their data pointer.""" flat_weights = [] for idx in range(len(self._flat_weights_names)): - flat_weights.append( - self._flat_weights[idx] * self._scales[idx].exp() - ) + flat_weights.append(self._flat_weights[idx] * self._scales[idx].exp()) self._flatten_parameters(flat_weights) return flat_weights - def forward( - self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None - ): + def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None): # This function is modified from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py # noqa # The change for calling `_VF.lstm()` is: # self._flat_weights -> self._get_flat_weights() @@ -915,9 +904,7 @@ def _test_activation_balancer_sign(): def _test_activation_balancer_magnitude(): magnitudes = torch.arange(0, 1, 0.01) N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze( - -1 - ) + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) x = x.detach() x.requires_grad = True m = ActivationBalancer( @@ -947,8 +934,8 @@ def _test_basic_norm(): y = m(x) assert y.shape == x.shape - x_rms = (x ** 2).mean().sqrt() - y_rms = (y ** 2).mean().sqrt() + x_rms = (x**2).mean().sqrt() + y_rms = (y**2).mean().sqrt() print("x rms = ", x_rms) print("y rms = ", y_rms) assert y_rms < x_rms @@ -1007,11 +994,11 @@ def _test_grad_filter(): print( "_test_grad_filter: x_out_grad norm = ", - (x_out_grad ** 2).mean(dim=(0, 2)).sqrt(), + (x_out_grad**2).mean(dim=(0, 2)).sqrt(), ) print( "_test_grad_filter: x.grad norm = ", - (x.grad ** 2).mean(dim=(0, 2)).sqrt(), + (x.grad**2).mean(dim=(0, 2)).sqrt(), ) print("_test_grad_filter: w_out_grad = ", w_out_grad) print("_test_grad_filter: w.grad = ", w.grad) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py index 9bcd2f9f9..e6e0fb1c8 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py @@ -153,9 +153,7 @@ def modified_beam_search( index=hyps_shape.row_ids(1).to(torch.int64), ) # (num_hyps, encoder_out_dim) - logits = model.joiner( - current_encoder_out, decoder_out, project_input=False - ) + logits = model.joiner(current_encoder_out, decoder_out, project_input=False) # logits is of shape (num_hyps, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) @@ -172,14 +170,10 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk( - num_active_paths - ) + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(num_active_paths) with warnings.catch_warnings(): warnings.simplefilter("ignore") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py index d76a03946..0eea3a782 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py @@ -51,11 +51,7 @@ from streaming_beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import ( AttributeDict, setup_logger, @@ -162,8 +158,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -271,9 +266,7 @@ def decode_one_chunk( encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": - greedy_search( - model=model, encoder_out=encoder_out, streams=decode_streams - ) + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( @@ -293,9 +286,7 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -351,9 +342,7 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. decode_streams = [] - initial_states = model.encoder.get_init_state( - params.left_context, device=device - ) + initial_states = model.encoder.get_init_state(params.left_context, device=device) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. decode_stream = DecodeStream( @@ -425,9 +414,7 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") return {key: decode_results} @@ -462,8 +449,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -536,8 +522,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 1947834bf..f6702ef16 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -96,9 +96,7 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -210,8 +208,7 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need to " - "be changed.", + help="The initial learning rate. This value should not need to " "be changed.", ) parser.add_argument( @@ -234,8 +231,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -258,8 +254,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -634,9 +629,7 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -649,14 +642,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -667,9 +655,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -837,9 +823,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -963,8 +947,7 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py b/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py index 1df7f9ee5..b7735be85 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py @@ -27,10 +27,7 @@ from lhotse.dataset import ( K2SpeechRecognitionDataset, SpecAugment, ) -from lhotse.dataset.input_strategies import ( - OnTheFlyFeatures, - PrecomputedFeatures, -) +from lhotse.dataset.input_strategies import OnTheFlyFeatures, PrecomputedFeatures from torch.utils.data import DataLoader from icefall.utils import str2bool @@ -167,9 +164,7 @@ class AsrDataModule: if cuts_musan is not None: logging.info("Enable MUSAN") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") @@ -178,9 +173,7 @@ class AsrDataModule: if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") input_transforms.append( SpecAugment( time_warp_factor=self.args.spec_aug_time_warp_factor, @@ -250,9 +243,7 @@ class AsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py index 5784a78ba..df24d9585 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py @@ -79,11 +79,7 @@ from gigaspeech import GigaSpeech from gigaspeech_scoring import asr_text_post_processing from train import get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import ( AttributeDict, setup_logger, @@ -192,8 +188,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -280,9 +275,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -312,10 +305,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -446,9 +436,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -481,8 +469,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -532,9 +519,7 @@ def main(): params.suffix += f"-num-paths-{params.num_paths}" params.suffix += f"-nbest-scale-{params.nbest_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -567,8 +552,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py index 8025d6be1..55585e08c 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py @@ -120,11 +120,7 @@ from beam_search import ( from librispeech import LibriSpeech from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.lexicon import Lexicon from icefall.rnn_lm.model import RnnLmModel from icefall.utils import ( @@ -265,8 +261,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -478,9 +473,7 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -550,10 +543,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -691,10 +681,7 @@ def decode_one_batch( return {key: hyps} else: return { - ( - f"beam_size_{params.beam_size}_" - f"temperature_{params.temperature}" - ): hyps + (f"beam_size_{params.beam_size}_" f"temperature_{params.temperature}"): hyps } @@ -779,9 +766,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -814,8 +799,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -939,9 +923,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" params.suffix += f"-temperature-{params.temperature}" else: params.suffix += f"-context-{params.context_size}" @@ -981,8 +963,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -1032,15 +1013,10 @@ def main(): word_table=word_table, device=device, ) - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) logging.info(f"G properties_str: {G.properties_str}") rnn_lm_model = None - if ( - params.decoding_method - == "fast_beam_search_with_nbest_rnn_rescoring" - ): + if params.decoding_method == "fast_beam_search_with_nbest_rnn_rescoring": rnn_lm_model = RnnLmModel( vocab_size=params.vocab_size, embedding_dim=params.rnn_lm_embedding_dim, @@ -1065,9 +1041,7 @@ def main(): rnn_lm_model.eval() else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) rnn_lm_model = None else: decoding_graph = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py index 47217ba05..2e444353c 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py @@ -128,11 +128,7 @@ import torch.nn as nn from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import str2bool @@ -235,8 +231,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -509,13 +504,9 @@ def export_joiner_model_onnx( - projected_decoder_out: a tensor of shape (N, joiner_dim) """ - encoder_proj_filename = str(joiner_filename).replace( - ".onnx", "_encoder_proj.onnx" - ) + encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx") - decoder_proj_filename = str(joiner_filename).replace( - ".onnx", "_decoder_proj.onnx" - ) + decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx") encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] @@ -616,8 +607,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -715,9 +705,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/gigaspeech.py b/egs/librispeech/ASR/pruned_transducer_stateless3/gigaspeech.py index 36f32c6b3..598434f54 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/gigaspeech.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/gigaspeech.py @@ -52,18 +52,14 @@ class GigaSpeech: ) pattern = re.compile(r"gigaspeech_cuts_XL.([0-9]+).jsonl.gz") - idx_filenames = [ - (int(pattern.search(f).group(1)), f) for f in filenames - ] + idx_filenames = [(int(pattern.search(f).group(1)), f) for f in filenames] idx_filenames = sorted(idx_filenames, key=lambda x: x[0]) sorted_filenames = [f[1] for f in idx_filenames] logging.info(f"Loading {len(sorted_filenames)} splits") - return lhotse.combine( - lhotse.load_manifest_lazy(p) for p in sorted_filenames - ) + return lhotse.combine(lhotse.load_manifest_lazy(p) for p in sorted_filenames) def train_L_cuts(self) -> CutSet: f = self.manifest_dir / "gigaspeech_cuts_L_raw.jsonl.gz" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py index 162f8c7db..86cb45c09 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py @@ -143,8 +143,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -330,9 +329,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/model.py b/egs/librispeech/ASR/pruned_transducer_stateless3/model.py index 7852f84e9..d45f6dadc 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/model.py @@ -84,9 +84,7 @@ class Transducer(nn.Module): self.decoder_giga = decoder_giga self.joiner_giga = joiner_giga - self.simple_am_proj = ScaledLinear( - encoder_dim, vocab_size, initial_speed=0.5 - ) + self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) if decoder_giga is not None: @@ -190,9 +188,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = encoder_out_lens diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py index d03d1d7ef..163d737e3 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py @@ -203,9 +203,7 @@ def test_joiner( ) # Now test encoder_proj - joiner_encoder_proj_inputs = { - encoder_proj_input_name: encoder_out.numpy() - } + joiner_encoder_proj_inputs = {encoder_proj_input_name: encoder_out.numpy()} joiner_encoder_proj_out = joiner_encoder_proj_session.run( [encoder_proj_output_name], joiner_encoder_proj_inputs )[0] @@ -214,16 +212,10 @@ def test_joiner( torch_joiner_encoder_proj_out = model.joiner.encoder_proj(encoder_out) assert torch.allclose( joiner_encoder_proj_out, torch_joiner_encoder_proj_out, atol=1e-5 - ), ( - (joiner_encoder_proj_out - torch_joiner_encoder_proj_out) - .abs() - .max() - ) + ), ((joiner_encoder_proj_out - torch_joiner_encoder_proj_out).abs().max()) # Now test decoder_proj - joiner_decoder_proj_inputs = { - decoder_proj_input_name: decoder_out.numpy() - } + joiner_decoder_proj_inputs = {decoder_proj_input_name: decoder_out.numpy()} joiner_decoder_proj_out = joiner_decoder_proj_session.run( [decoder_proj_output_name], joiner_decoder_proj_inputs )[0] @@ -232,11 +224,7 @@ def test_joiner( torch_joiner_decoder_proj_out = model.joiner.decoder_proj(decoder_out) assert torch.allclose( joiner_decoder_proj_out, torch_joiner_decoder_proj_out, atol=1e-5 - ), ( - (joiner_decoder_proj_out - torch_joiner_decoder_proj_out) - .abs() - .max() - ) + ), ((joiner_decoder_proj_out - torch_joiner_decoder_proj_out).abs().max()) @torch.no_grad() @@ -288,9 +276,7 @@ def main(): if __name__ == "__main__": torch.manual_seed(20220727) - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py index ea5d4e674..825c6510b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py @@ -141,8 +141,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -191,11 +190,7 @@ def greedy_search( projected_encoder_out = joiner_encoder_proj.run( [joiner_encoder_proj.get_outputs()[0].name], - { - joiner_encoder_proj.get_inputs()[ - 0 - ].name: packed_encoder_out.data.numpy() - }, + {joiner_encoder_proj.get_inputs()[0].name: packed_encoder_out.data.numpy()}, )[0] blank_id = 0 # hard-code to 0 @@ -382,9 +377,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py index 19b636a23..77bd6d13d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py @@ -177,8 +177,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -232,8 +231,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -302,9 +300,7 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -391,9 +387,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py index 1e6022b57..b712eeda0 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py @@ -234,9 +234,7 @@ def scaled_lstm_to_lstm(scaled_lstm: ScaledLSTM) -> nn.LSTM: assert lstm._flat_weights_names == scaled_lstm._flat_weights_names for idx in range(len(scaled_lstm._flat_weights_names)): - scaled_weight = ( - scaled_lstm._flat_weights[idx] * scaled_lstm._scales[idx].exp() - ) + scaled_weight = scaled_lstm._flat_weights[idx] * scaled_lstm._scales[idx].exp() lstm._flat_weights[idx].data.copy_(scaled_weight) return lstm diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py index 10bb44e00..e85d2060a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py @@ -52,11 +52,7 @@ from streaming_beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import ( AttributeDict, setup_logger, @@ -163,8 +159,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -272,9 +267,7 @@ def decode_one_chunk( encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": - greedy_search( - model=model, encoder_out=encoder_out, streams=decode_streams - ) + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( @@ -294,9 +287,7 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -352,9 +343,7 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. decode_streams = [] - initial_states = model.encoder.get_init_state( - params.left_context, device=device - ) + initial_states = model.encoder.get_init_state(params.left_context, device=device) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. decode_stream = DecodeStream( @@ -426,9 +415,7 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") return {key: decode_results} @@ -461,8 +448,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -535,8 +521,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py index 66ffbd3ec..598fcf344 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py @@ -90,9 +90,7 @@ def test_conv2d_subsampling(): onnx_y = torch.from_numpy(onnx_y) torch_y = jit_model(x) - assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( - (onnx_y - torch_y).abs().max() - ) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() os.remove(filename) @@ -147,9 +145,7 @@ def test_rel_pos(): onnx_pos_emb = torch.from_numpy(onnx_pos_emb) torch_y, torch_pos_emb = jit_model(x) - assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( - (onnx_y - torch_y).abs().max() - ) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() assert torch.allclose(onnx_pos_emb, torch_pos_emb, atol=1e-05), ( (onnx_pos_emb - torch_pos_emb).abs().max() @@ -197,9 +193,7 @@ def test_conformer_encoder_layer(): encoder_layer.eval() encoder_layer = convert_scaled_to_non_scaled(encoder_layer, inplace=True) - jit_model = torch.jit.trace( - encoder_layer, (x, pos_emb, src_key_padding_mask) - ) + jit_model = torch.jit.trace(encoder_layer, (x, pos_emb, src_key_padding_mask)) torch.onnx.export( encoder_layer, @@ -236,9 +230,7 @@ def test_conformer_encoder_layer(): onnx_y = torch.from_numpy(onnx_y) torch_y = jit_model(x, pos_emb, src_key_padding_mask) - assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( - (onnx_y - torch_y).abs().max() - ) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() print(onnx_y.abs().sum(), torch_y.abs().sum(), onnx_y.shape, torch_y.shape) @@ -322,9 +314,7 @@ def test_conformer_encoder(): onnx_y = torch.from_numpy(onnx_y) torch_y = jit_model(x, pos_emb, src_key_padding_mask) - assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( - (onnx_y - torch_y).abs().max() - ) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() print(onnx_y.abs().sum(), torch_y.abs().sum(), onnx_y.shape, torch_y.shape) @@ -379,9 +369,7 @@ def test_conformer(): onnx_y_lens = torch.from_numpy(onnx_y_lens) torch_y, torch_y_lens = jit_model(x, x_lens) - assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( - (onnx_y - torch_y).abs().max() - ) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() assert torch.allclose(onnx_y_lens, torch_y_lens, atol=1e-05), ( (onnx_y_lens - torch_y_lens).abs().max() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py index 44e96644a..e9ceb60de 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py @@ -92,9 +92,7 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -163,8 +161,7 @@ def get_parser(): "--full-libri", type=str2bool, default=True, - help="When enabled, use 960h LibriSpeech. " - "Otherwise, use 100h subset.", + help="When enabled, use 960h LibriSpeech. " "Otherwise, use 100h subset.", ) parser.add_argument( @@ -214,8 +211,7 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need " - "to be changed.", + help="The initial learning rate. This value should not need " "to be changed.", ) parser.add_argument( @@ -238,8 +234,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -262,8 +257,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -672,9 +666,7 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -687,14 +679,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -705,9 +692,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -919,9 +904,7 @@ def train_one_epoch( f"train/current_{prefix}_", params.batch_idx_train, ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) libri_tot_loss.write_summary( tb_writer, "train/libri_tot_", params.batch_idx_train ) @@ -967,8 +950,7 @@ def filter_short_and_long_utterances( # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}" ) return False @@ -1109,9 +1091,7 @@ def run(rank, world_size, args): train_giga_cuts = train_giga_cuts.repeat(times=None) if args.enable_musan: - cuts_musan = load_manifest( - Path(args.manifest_dir) / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz") else: cuts_musan = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py index 4f043e5a6..2f9a60f13 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py @@ -306,8 +306,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -427,9 +426,7 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) if ( params.decoding_method == "fast_beam_search" @@ -485,10 +482,7 @@ def decode_one_batch( nbest_scale=params.nbest_scale, return_timestamps=True, ) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: res = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -566,9 +560,7 @@ def decode_dataset( sp: spm.SentencePieceProcessor, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[ - str, List[Tuple[str, List[str], List[str], List[float], List[float]]] -]: +) -> Dict[str, List[Tuple[str, List[str], List[str], List[float], List[float]]]]: """Decode dataset. Args: @@ -643,9 +635,7 @@ def decode_dataset( cut_ids, hyps, texts, timestamps_hyp, timestamps_ref ): ref_words = ref_text.split() - this_batch.append( - (cut_id, ref_words, hyp_words, time_ref, time_hyp) - ) + this_batch.append((cut_id, ref_words, hyp_words, time_ref, time_hyp)) results[name].extend(this_batch) @@ -654,9 +644,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -694,8 +682,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -722,9 +709,7 @@ def save_results( note = "" logging.info(s) - s = "\nFor {}, symbol-delay of different settings are:\n".format( - test_set_name - ) + s = "\nFor {}, symbol-delay of different settings are:\n".format(test_set_name) note = "\tbest for {}".format(test_set_name) for key, val in test_set_delays: s += "{}\tmean: {}s, variance: {}{}\n".format(key, val[0], val[1], note) @@ -773,9 +758,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -812,9 +795,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -841,9 +824,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -902,9 +885,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/export.py b/egs/librispeech/ASR/pruned_transducer_stateless4/export.py index ce7518ceb..64ef89733 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/export.py @@ -133,8 +133,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -183,9 +182,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -212,9 +211,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -282,9 +281,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py index 7af9ea9b8..d74d1c89d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py @@ -175,8 +175,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -284,9 +283,7 @@ def decode_one_chunk( encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": - greedy_search( - model=model, encoder_out=encoder_out, streams=decode_streams - ) + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( @@ -306,9 +303,7 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -364,9 +359,7 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. decode_streams = [] - initial_states = model.encoder.get_init_state( - params.left_context, device=device - ) + initial_states = model.encoder.get_init_state(params.left_context, device=device) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. decode_stream = DecodeStream( @@ -438,9 +431,7 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") return {key: decode_results} @@ -473,8 +464,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -547,9 +537,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -576,9 +566,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py index cf32e565b..97f3e56a9 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py @@ -101,9 +101,7 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -239,8 +237,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -263,8 +260,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -621,11 +617,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -665,9 +657,7 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -680,14 +670,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -698,9 +683,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -879,9 +862,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -1013,8 +994,7 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py index 427b06294..b3a7d71bc 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py @@ -214,10 +214,7 @@ class Conformer(EncoderInterface): (num_encoder_layers, cnn_module_kernel - 1, encoder_dim). NOTE: the returned tensors are on the given device. """ - if ( - len(self._init_state) == 2 - and self._init_state[0].size(1) == left_context - ): + if len(self._init_state) == 2 and self._init_state[0].size(1) == left_context: # Note: It is OK to share the init state as it is # not going to be modified by the model return self._init_state @@ -439,9 +436,7 @@ class ConformerEncoderLayer(nn.Module): self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), @@ -459,9 +454,7 @@ class ConformerEncoderLayer(nn.Module): ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), ) - self.conv_module = ConvolutionModule( - d_model, cnn_module_kernel, causal=causal - ) + self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal) self.norm_final = BasicNorm(d_model) @@ -527,9 +520,7 @@ class ConformerEncoderLayer(nn.Module): src = src + self.dropout(src_att) # convolution module - conv, _ = self.conv_module( - src, src_key_padding_mask=src_key_padding_mask - ) + conv, _ = self.conv_module(src, src_key_padding_mask=src_key_padding_mask) src = src + self.dropout(conv) # feed forward module @@ -802,9 +793,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -820,9 +809,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x_size_1 * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -848,9 +835,7 @@ class RelPositionalEncoding(torch.nn.Module): pe = torch.cat([pe_positive, pe_negative], dim=1) self.pe = pe.to(device=x.device, dtype=x.dtype) - def forward( - self, x: torch.Tensor, left_context: int = 0 - ) -> Tuple[Tensor, Tensor]: + def forward(self, x: torch.Tensor, left_context: int = 0) -> Tuple[Tensor, Tensor]: """Add positional encoding. Args: @@ -1118,9 +1103,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -1189,31 +1174,22 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) @@ -1253,23 +1229,15 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d - matrix_bd = torch.matmul( - q_with_bias_v, p - ) # (batch, head, time1, 2*time1-1) + matrix_bd = torch.matmul(q_with_bias_v, p) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd, left_context) - attn_output_weights = ( - matrix_ac + matrix_bd - ) # (batch, head, time1, time2) + attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -1310,21 +1278,17 @@ class RelPositionMultiheadAttention(nn.Module): ): if attn_mask.size(0) != 1: attn_mask = attn_mask.view(bsz, num_heads, tgt_len, src_len) - combined_mask = attn_mask | key_padding_mask.unsqueeze( - 1 - ).unsqueeze(2) + combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2) else: # attn_mask.shape == (1, tgt_len, src_len) - combined_mask = attn_mask.unsqueeze( - 0 - ) | key_padding_mask.unsqueeze(1).unsqueeze(2) + combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze( + 1 + ).unsqueeze(2) attn_output_weights = attn_output_weights.view( bsz, num_heads, tgt_len, src_len ) - attn_output_weights = attn_output_weights.masked_fill( - combined_mask, 0.0 - ) + attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0) attn_output_weights = attn_output_weights.view( bsz * num_heads, tgt_len, src_len ) @@ -1336,13 +1300,9 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -1481,16 +1441,12 @@ class ConvolutionModule(nn.Module): # manualy padding self.lorder zeros to the left x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) else: - assert ( - not self.training - ), "Cache should be None in training time" + assert not self.training, "Cache should be None in training time" assert cache.size(0) == self.lorder x = torch.cat([cache.permute(1, 2, 0), x], dim=2) if right_context > 0: cache = x.permute(2, 0, 1)[ - -(self.lorder + right_context) : ( # noqa - -right_context - ), + -(self.lorder + right_context) : (-right_context), # noqa ..., ] else: @@ -1666,9 +1622,7 @@ class RandomCombine(nn.Module): self.stddev = stddev self.final_log_weight = ( - torch.tensor( - (final_weight / (1 - final_weight)) * (self.num_inputs - 1) - ) + torch.tensor((final_weight / (1 - final_weight)) * (self.num_inputs - 1)) .log() .item() ) @@ -1765,16 +1719,14 @@ class RandomCombine(nn.Module): # final contains self.num_inputs - 1 in all elements final = torch.full((num_frames,), self.num_inputs - 1, device=device) # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights. - nonfinal = torch.randint( - self.num_inputs - 1, (num_frames,), device=device - ) + nonfinal = torch.randint(self.num_inputs - 1, (num_frames,), device=device) indexes = torch.where( torch.rand(num_frames, device=device) < final_prob, final, nonfinal ) - ans = torch.nn.functional.one_hot( - indexes, num_classes=self.num_inputs - ).to(dtype=dtype) + ans = torch.nn.functional.one_hot(indexes, num_classes=self.num_inputs).to( + dtype=dtype + ) return ans def _get_random_mixed_weights( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py index 22bcdd88e..5c76afde6 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py @@ -303,8 +303,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -477,9 +476,7 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -545,10 +542,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -696,9 +690,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -731,8 +723,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -787,9 +778,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -828,9 +817,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -857,9 +846,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -937,9 +926,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/export.py b/egs/librispeech/ASR/pruned_transducer_stateless5/export.py index b2e5b430e..f0bfd3b4c 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/export.py @@ -133,8 +133,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -181,9 +180,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -210,9 +209,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -280,9 +279,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py index 1e100fcbd..77ba0873b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py @@ -166,8 +166,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -199,8 +198,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -264,15 +262,11 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lengths - ) + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) num_waves = encoder_out.size(0) hyps = [] @@ -344,9 +338,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py index 6fee9483e..e750f5554 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py @@ -175,8 +175,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -284,9 +283,7 @@ def decode_one_chunk( encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": - greedy_search( - model=model, encoder_out=encoder_out, streams=decode_streams - ) + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( @@ -306,9 +303,7 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -364,9 +359,7 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. decode_streams = [] - initial_states = model.encoder.get_init_state( - params.left_context, device=device - ) + initial_states = model.encoder.get_init_state(params.left_context, device=device) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. decode_stream = DecodeStream( @@ -438,9 +431,7 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") return {key: decode_results} @@ -473,8 +464,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -547,9 +537,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -576,9 +566,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py index 179d9372e..a1a810d3e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py @@ -89,9 +89,7 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -248,8 +246,7 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need " - "to be changed.", + help="The initial learning rate. This value should not need " "to be changed.", ) parser.add_argument( @@ -272,8 +269,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -296,8 +292,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -645,11 +640,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -690,9 +681,7 @@ def compute_loss( # If the batch contains more than 10 utterances AND # if either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -705,14 +694,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -723,9 +707,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -908,9 +890,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -1023,7 +1003,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -1045,8 +1025,7 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py index 53788b3f7..0667e7f61 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py @@ -90,10 +90,7 @@ class Conformer(EncoderInterface): output_layers = [] if middle_output_layer is not None: - assert ( - middle_output_layer >= 0 - and middle_output_layer < num_encoder_layers - ) + assert middle_output_layer >= 0 and middle_output_layer < num_encoder_layers output_layers.append(middle_output_layer) # The last layer is always needed. @@ -178,9 +175,7 @@ class ConformerEncoderLayer(nn.Module): self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), @@ -362,9 +357,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -379,9 +372,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -656,9 +647,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -727,31 +718,22 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) @@ -790,9 +772,7 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -800,13 +780,9 @@ class RelPositionMultiheadAttention(nn.Module): ) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd) - attn_output_weights = ( - matrix_ac + matrix_bd - ) # (batch, head, time1, time2) + attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -840,13 +816,9 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -869,9 +841,7 @@ class ConvolutionModule(nn.Module): """ - def __init__( - self, channels: int, kernel_size: int, bias: bool = True - ) -> None: + def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py index 74df04006..3734564fe 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py @@ -208,8 +208,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -267,9 +266,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - layer_results, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + layer_results, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) encoder_out = layer_results[-1] hyps = [] @@ -285,10 +282,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -411,9 +405,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -446,8 +438,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -490,9 +481,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -524,9 +513,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -553,9 +542,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/export.py b/egs/librispeech/ASR/pruned_transducer_stateless6/export.py index cff9c7377..3d1e7bc18 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/export.py @@ -51,11 +51,7 @@ import sentencepiece as spm import torch from train import get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import str2bool @@ -120,8 +116,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) return parser @@ -160,8 +155,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -209,9 +203,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/extract_codebook_index.py b/egs/librispeech/ASR/pruned_transducer_stateless6/extract_codebook_index.py index 21409287c..86cf34877 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/extract_codebook_index.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/extract_codebook_index.py @@ -21,9 +21,10 @@ import os from pathlib import Path import torch -from vq_utils import CodebookIndexExtractor from asr_datamodule import LibriSpeechAsrDataModule from hubert_xlarge import HubertXlargeFineTuned +from vq_utils import CodebookIndexExtractor + from icefall.utils import AttributeDict, str2bool diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_decode.py index 49b557814..b8440f90a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_decode.py @@ -23,7 +23,6 @@ from pathlib import Path from typing import Dict, List, Tuple import torch - from asr_datamodule import LibriSpeechAsrDataModule from hubert_xlarge import HubertXlargeFineTuned @@ -99,9 +98,7 @@ def decode_dataset( if batch_idx % 20 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -124,9 +121,7 @@ def save_results( ) test_set_wers[key] = wer - logging.info( - "Wrote detailed error stats to {}".format(errs_filename) - ) + logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.res_dir / f"wer-summary-{test_set_name}.txt" @@ -155,9 +150,7 @@ def main(): # reset some parameters needed by hubert. params.update(HubertXlargeFineTuned.get_params()) - params.res_dir = ( - params.exp_dir / f"ctc_greedy_search-{params.teacher_model_id}" - ) + params.res_dir = params.exp_dir / f"ctc_greedy_search-{params.teacher_model_id}" setup_logger(f"{params.res_dir}/log/log-ctc_greedy_search") logging.info("Decoding started") @@ -190,9 +183,7 @@ def main(): params=params, ) - save_results( - params=params, test_set_name=test_set, results_dict=results_dict - ) + save_results(params=params, test_set_name=test_set, results_dict=results_dict) logging.info("Done!") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_xlarge.py b/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_xlarge.py index 55ce7b00d..4f9417c9f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_xlarge.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_xlarge.py @@ -22,11 +22,7 @@ from pathlib import Path from typing import Dict, List, Tuple import torch -from fairseq import ( - checkpoint_utils, - tasks, - utils, -) +from fairseq import checkpoint_utils, tasks, utils from fairseq.data.data_utils import post_process from omegaconf import OmegaConf @@ -51,9 +47,7 @@ def _load_hubert_model(params: AttributeDict): "data": str(params.hubert_model_dir), } ) - model_path = Path(params.hubert_model_dir) / ( - params.teacher_model_id + ".pt" - ) + model_path = Path(params.hubert_model_dir) / (params.teacher_model_id + ".pt") task = tasks.setup_task(cfg_task) processor = task.target_dictionary models, saved_cfg = checkpoint_utils.load_model_ensemble( @@ -151,9 +145,7 @@ class HubertXlargeFineTuned: supervisions = batch["supervisions"] num_samples = supervisions["num_samples"] B, T = features.shape - padding_mask = torch.arange(0, T).expand(B, T) > num_samples.reshape( - [-1, 1] - ) + padding_mask = torch.arange(0, T).expand(B, T) > num_samples.reshape([-1, 1]) padding_mask = padding_mask.to(self.params.device) features = features.to(self.params.device) @@ -163,9 +155,7 @@ class HubertXlargeFineTuned: features = features.transpose(1, 2) features = self.w2v_model.layer_norm(features) - padding_mask = self.w2v_model.forward_padding_mask( - features, padding_mask - ) + padding_mask = self.w2v_model.forward_padding_mask(features, padding_mask) if self.w2v_model.post_extract_proj is not None: features = self.w2v_model.post_extract_proj(features) @@ -212,9 +202,7 @@ class HubertXlargeFineTuned: toks = encoder_out.argmax(dim=-1) blank = 0 toks = [tok.unique_consecutive() for tok in toks] - hyps = [ - self.processor.string(tok[tok != blank].int().cpu()) for tok in toks - ] + hyps = [self.processor.string(tok[tok != blank].int().cpu()) for tok in toks] hyps = [post_process(hyp, "letter") for hyp in hyps] return hyps diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/model.py b/egs/librispeech/ASR/pruned_transducer_stateless6/model.py index 7716d19cf..daadb70c9 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/model.py @@ -69,9 +69,7 @@ class Transducer(nn.Module): self.decoder = decoder self.joiner = joiner - self.simple_am_proj = ScaledLinear( - encoder_dim, vocab_size, initial_speed=0.5 - ) + self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) from icefall import is_module_available @@ -180,9 +178,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens @@ -237,9 +233,7 @@ class Transducer(nn.Module): return (simple_loss, pruned_loss, codebook_loss) @staticmethod - def concat_successive_codebook_indexes( - middle_layer_output, codebook_indexes - ): + def concat_successive_codebook_indexes(middle_layer_output, codebook_indexes): # Output rate of hubert is 50 frames per second, # while that of current encoder is 25. # Following code handling two issues: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py index f717d85fb..a24becb14 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py @@ -101,9 +101,7 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def get_parser(): @@ -203,8 +201,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -227,8 +224,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -569,9 +565,7 @@ def save_checkpoint( def extract_codebook_indexes(batch): cuts = batch["supervisions"]["cut"] # -100 is identical to ignore_value in CE loss computation. - cuts_pre_mixed = [ - c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts - ] + cuts_pre_mixed = [c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts] codebook_indexes, codebook_indexes_lens = collate_custom_field( cuts_pre_mixed, "codebook_indexes", pad_value=-100 ) @@ -604,11 +598,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -655,9 +645,7 @@ def compute_loss( # If the batch contains more than 10 utterances AND # if either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -670,14 +658,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss if is_training and params.enable_distillation: assert codebook_loss is not None loss += params.codebook_loss_scale * codebook_loss @@ -690,9 +673,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -873,9 +854,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -1007,8 +986,7 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py index 47cf2b14b..97a83b974 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py @@ -68,9 +68,7 @@ class CodebookIndexExtractor: def init_dirs(self): # vq_dir is the root dir for quantization, containing: # training data, trained quantizer, and extracted codebook indexes - self.vq_dir = ( - self.params.exp_dir / f"vq/{self.params.teacher_model_id}/" - ) + self.vq_dir = self.params.exp_dir / f"vq/{self.params.teacher_model_id}/" self.vq_dir.mkdir(parents=True, exist_ok=True) # manifest_dir contains: @@ -208,9 +206,7 @@ class CodebookIndexExtractor: start = cur_offset % (data.shape[0] + 1 - B) end = start + B cur_offset += B - yield data[start:end, :].to(self.params.device).to( - dtype=torch.float - ) + yield data[start:end, :].to(self.params.device).to(dtype=torch.float) for x in minibatch_generator(train, repeat=True): trainer.step(x) @@ -227,9 +223,7 @@ class CodebookIndexExtractor: """ for subset in self.params.subsets: logging.info(f"About to split {subset}.") - ori_manifest = ( - f"./data/fbank/librispeech_cuts_train-{subset}.jsonl.gz" - ) + ori_manifest = f"./data/fbank/librispeech_cuts_train-{subset}.jsonl.gz" split_cmd = f"lhotse split {self.params.world_size} {ori_manifest} {self.manifest_dir}" os.system(f"{split_cmd}") @@ -240,16 +234,13 @@ class CodebookIndexExtractor: logging.info("Start to join manifest files.") for subset in self.params.subsets: vq_manifest_path = ( - self.dst_manifest_dir - / f"librispeech_cuts_train-{subset}-vq.jsonl.gz" + self.dst_manifest_dir / f"librispeech_cuts_train-{subset}-vq.jsonl.gz" ) ori_manifest_path = ( - self.ori_manifest_dir - / f"librispeech_cuts_train-{subset}.jsonl.gz" + self.ori_manifest_dir / f"librispeech_cuts_train-{subset}.jsonl.gz" ) dst_vq_manifest_path = ( - self.dst_manifest_dir - / f"librispeech_cuts_train-{subset}.jsonl.gz" + self.dst_manifest_dir / f"librispeech_cuts_train-{subset}.jsonl.gz" ) cuts_vq = load_manifest(vq_manifest_path) cuts_ori = load_manifest(ori_manifest_path) @@ -269,8 +260,7 @@ class CodebookIndexExtractor: for subset in self.params.subsets: vq_manifests = f"{self.manifest_dir}/with_codebook_indexes-librispeech-cuts_train-{subset}*.jsonl.gz" dst_vq_manifest = ( - self.dst_manifest_dir - / f"librispeech_cuts_train-{subset}-vq.jsonl.gz" + self.dst_manifest_dir / f"librispeech_cuts_train-{subset}-vq.jsonl.gz" ) if 1 == self.params.world_size: merge_cmd = f"cp {vq_manifests} {dst_vq_manifest}" @@ -330,9 +320,7 @@ class CodebookIndexExtractor: def load_ori_dl(self, subset): if self.params.world_size == 1: - ori_manifest_path = ( - f"./data/fbank/librispeech_cuts_train-{subset}.jsonl.gz" - ) + ori_manifest_path = f"./data/fbank/librispeech_cuts_train-{subset}.jsonl.gz" else: ori_manifest_path = ( self.manifest_dir diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py index 06c5863f1..162966df8 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py @@ -272,8 +272,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -393,9 +392,7 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -454,10 +451,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -588,9 +582,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -623,8 +615,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -679,9 +670,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -718,9 +707,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -747,9 +736,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -808,9 +797,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py index 712dc8ce1..5f90e6375 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py @@ -69,7 +69,7 @@ class Decoder(nn.Module): out_channels=decoder_dim, kernel_size=context_size, padding=0, - groups=decoder_dim//4, # group size == 4 + groups=decoder_dim // 4, # group size == 4 bias=False, ) @@ -91,9 +91,7 @@ class Decoder(nn.Module): if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad( - embedding_out, pad=(self.context_size - 1, 0) - ) + embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7/export.py index 5744ea3ea..57af52fb1 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/export.py @@ -176,8 +176,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) add_model_arguments(parser) @@ -215,9 +214,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -244,9 +243,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -316,9 +315,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py index e2405d5ef..f469442ed 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py @@ -94,8 +94,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -267,9 +266,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py index 7d8de5afe..3ddac2cf2 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py @@ -56,9 +56,7 @@ class Joiner(nn.Module): assert encoder_out.shape[:-1] == decoder_out.shape[:-1] if project_input: - logit = self.encoder_proj(encoder_out) + self.decoder_proj( - decoder_out - ) + logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) else: logit = encoder_out + decoder_out diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py index 53cde6c6f..0e59b0f2f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py @@ -15,14 +15,15 @@ # limitations under the License. +import random + import k2 import torch import torch.nn as nn -import random from encoder_interface import EncoderInterface +from scaling import penalize_abs_values_gt from icefall.utils import add_sos -from scaling import penalize_abs_values_gt class Transducer(nn.Module): @@ -65,7 +66,8 @@ class Transducer(nn.Module): self.joiner = joiner self.simple_am_proj = nn.Linear( - encoder_dim, vocab_size, + encoder_dim, + vocab_size, ) self.simple_lm_proj = nn.Linear(decoder_dim, vocab_size) @@ -133,18 +135,16 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens lm = self.simple_lm_proj(decoder_out) am = self.simple_am_proj(encoder_out) - #if self.training and random.random() < 0.25: + # if self.training and random.random() < 0.25: # lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04) - #if self.training and random.random() < 0.25: + # if self.training and random.random() < 0.25: # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) with torch.cuda.amp.autocast(enabled=False): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py index bb8b0a0e3..8b90c9a0d 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py @@ -14,17 +14,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections import defaultdict -from typing import List, Optional, Union, Tuple, List -from lhotse.utils import fix_random_seed -import torch -from scaling import ActivationBalancer +import contextlib +import logging import random +from collections import defaultdict +from typing import List, Optional, Tuple, Union + +import torch +from lhotse.utils import fix_random_seed +from scaling import ActivationBalancer from torch import Tensor from torch.optim import Optimizer -import logging -import contextlib - class BatchedOptimizer(Optimizer): @@ -37,11 +37,10 @@ class BatchedOptimizer(Optimizer): Args: params: """ + def __init__(self, params, defaults): super(BatchedOptimizer, self).__init__(params, defaults) - - @contextlib.contextmanager def batched_params(self, param_group): """ @@ -73,7 +72,9 @@ class BatchedOptimizer(Optimizer): group: a parameter group, which is a list of parameters; should be one of self.groups. """ - batches = defaultdict(list) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter + batches = defaultdict( + list + ) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter for p in param_group: key = (str(p.dtype), *p.shape) @@ -82,7 +83,7 @@ class BatchedOptimizer(Optimizer): stacked_params_dict = dict() # turn batches into a list, in deterministic order. - batches = [ batches[key] for key in sorted(batches.keys()) ] + batches = [batches[key] for key in sorted(batches.keys())] # pairs will contain pairs of (stacked_param, state), one for each batch # in `batches`. pairs = [] @@ -94,76 +95,77 @@ class BatchedOptimizer(Optimizer): # group. class Optimizer will take care of saving/loading state. state = self.state[p] p_stacked = torch.stack(batch) - grad = torch.stack([torch.zeros_like(p) if p.grad is None else p.grad for p in batch ]) + grad = torch.stack( + [torch.zeros_like(p) if p.grad is None else p.grad for p in batch] + ) p_stacked.grad = grad stacked_params_dict[key] = p_stacked pairs.append((p_stacked, state)) - yield pairs # <-- calling code will do the actual optimization here! + yield pairs # <-- calling code will do the actual optimization here! for ((stacked_params, _state), batch) in zip(pairs, batches): for i, p in enumerate(batch): # batch is list of Parameter p.copy_(stacked_params[i]) - class ScaledAdam(BatchedOptimizer): """ - Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update - proportional to the norm of that parameter; and also learn the scale of the parameter, - in log space, subject to upper and lower limits (as if we had factored each parameter as - param = underlying_param * log_scale.exp()) + Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update + proportional to the norm of that parameter; and also learn the scale of the parameter, + in log space, subject to upper and lower limits (as if we had factored each parameter as + param = underlying_param * log_scale.exp()) - Args: - params: The parameters or param_groups to optimize (like other Optimizer subclasses) - lr: The learning rate. We will typically use a learning rate schedule that starts - at 0.03 and decreases over time, i.e. much higher than other common - optimizers. - clipping_scale: (e.g. 2.0) - A scale for gradient-clipping: if specified, the normalized gradients - over the whole model will be clipped to have 2-norm equal to - `clipping_scale` times the median 2-norm over the most recent period - of `clipping_update_period` minibatches. By "normalized gradients", - we mean after multiplying by the rms parameter value for this tensor - [for non-scalars]; this is appropriate because our update is scaled - by this quantity. - betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad. - Must satisfy 0 < beta <= beta2 < 1. - scalar_lr_scale: A scaling factor on the learning rate, that we use to update the - scale of each parameter tensor and scalar parameters of the mode.. - If each parameter were decomposed - as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale - would be a the scaling factor on the learning rate of p_scale. - eps: A general-purpose epsilon to prevent division by zero - param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of - learning the scale on the parameters (we'll constrain the rms of each non-scalar - parameter tensor to be >= this value) - param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of - learning the scale on the parameters (we'll constrain the rms of each non-scalar - parameter tensor to be <= this value) - scalar_max: Maximum absolute value for scalar parameters (applicable if your - model has any parameters with numel() == 1). - size_update_period: The periodicity, in steps, with which we update the size (scale) - of the parameter tensor. This is provided to save a little time - in the update. - clipping_update_period: if clipping_scale is specified, this is the period + Args: + params: The parameters or param_groups to optimize (like other Optimizer subclasses) + lr: The learning rate. We will typically use a learning rate schedule that starts + at 0.03 and decreases over time, i.e. much higher than other common + optimizers. + clipping_scale: (e.g. 2.0) + A scale for gradient-clipping: if specified, the normalized gradients + over the whole model will be clipped to have 2-norm equal to + `clipping_scale` times the median 2-norm over the most recent period + of `clipping_update_period` minibatches. By "normalized gradients", + we mean after multiplying by the rms parameter value for this tensor + [for non-scalars]; this is appropriate because our update is scaled + by this quantity. + betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad. + Must satisfy 0 < beta <= beta2 < 1. + scalar_lr_scale: A scaling factor on the learning rate, that we use to update the + scale of each parameter tensor and scalar parameters of the mode.. + If each parameter were decomposed + as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale + would be a the scaling factor on the learning rate of p_scale. + eps: A general-purpose epsilon to prevent division by zero + param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of + learning the scale on the parameters (we'll constrain the rms of each non-scalar + parameter tensor to be >= this value) + param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of + learning the scale on the parameters (we'll constrain the rms of each non-scalar + parameter tensor to be <= this value) + scalar_max: Maximum absolute value for scalar parameters (applicable if your + model has any parameters with numel() == 1). + size_update_period: The periodicity, in steps, with which we update the size (scale) + of the parameter tensor. This is provided to save a little time + in the update. + clipping_update_period: if clipping_scale is specified, this is the period """ - def __init__( - self, - params, - lr=3e-02, - clipping_scale=None, - betas=(0.9, 0.98), - scalar_lr_scale=0.1, - eps=1.0e-08, - param_min_rms=1.0e-05, - param_max_rms=3.0, - scalar_max=10.0, - size_update_period=4, - clipping_update_period=100, - ): + def __init__( + self, + params, + lr=3e-02, + clipping_scale=None, + betas=(0.9, 0.98), + scalar_lr_scale=0.1, + eps=1.0e-08, + param_min_rms=1.0e-05, + param_max_rms=3.0, + scalar_max=10.0, + size_update_period=4, + clipping_update_period=100, + ): defaults = dict( lr=lr, @@ -183,7 +185,6 @@ class ScaledAdam(BatchedOptimizer): def __setstate__(self, state): super(ScaledAdam, self).__setstate__(state) - @torch.no_grad() def step(self, closure=None): """Performs a single optimization step. @@ -206,7 +207,9 @@ class ScaledAdam(BatchedOptimizer): # a regular parameter, and will have a .grad, but the 1st dim corresponds to # a stacking dim, it is not a real dim. - if len(batches[0][1]) == 0: # if len(first state) == 0: not yet initialized + if ( + len(batches[0][1]) == 0 + ): # if len(first state) == 0: not yet initialized clipping_scale = 1 else: clipping_scale = self._get_clipping_scale(group, batches) @@ -225,13 +228,9 @@ class ScaledAdam(BatchedOptimizer): self._step_one_batch(group, p, state, clipping_scale) - return loss - def _init_state(self, - group: dict, - p: Tensor, - state: dict): + def _init_state(self, group: dict, p: Tensor, state: dict): """ Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p is actually the batch dimension, corresponding to batched-together @@ -247,7 +246,7 @@ class ScaledAdam(BatchedOptimizer): state["step"] = 0 - kwargs = {'device':p.device, 'dtype':p.dtype} + kwargs = {"device": p.device, "dtype": p.dtype} # 'delta' implements conventional momentum. There are # several different kinds of update going on, so rather than @@ -255,36 +254,30 @@ class ScaledAdam(BatchedOptimizer): # parameter-change "delta", which combines all forms of # update. this is equivalent to how it's done in Adam, # except for the first few steps. - state["delta"] = torch.zeros_like( - p, memory_format=torch.preserve_format - ) + state["delta"] = torch.zeros_like(p, memory_format=torch.preserve_format) batch_size = p.shape[0] numel = p.numel() // batch_size numel = p.numel() - if numel > 1: # "param_rms" just periodically records the scalar root-mean-square value of # the parameter tensor. # it has a shape like (batch_size, 1, 1, 1, 1) - param_rms = (p**2).mean(dim=list(range(1, p.ndim)), - keepdim=True).sqrt() + param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() state["param_rms"] = param_rms state["scale_exp_avg_sq"] = torch.zeros_like(param_rms) - state["scale_grads"] = torch.zeros(size_update_period, *param_rms.shape, - **kwargs) - + state["scale_grads"] = torch.zeros( + size_update_period, *param_rms.shape, **kwargs + ) # exp_avg_sq is the weighted sum of scaled gradients. as in Adam. - state["exp_avg_sq"] = torch.zeros_like( - p, memory_format=torch.preserve_format - ) + state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format) - def _get_clipping_scale(self, - group: dict, - pairs: List[Tuple[Tensor, dict]]) -> float: + def _get_clipping_scale( + self, group: dict, pairs: List[Tuple[Tensor, dict]] + ) -> float: """ Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients by this amount before applying the rest of the update. @@ -314,57 +307,65 @@ class ScaledAdam(BatchedOptimizer): if p.numel() == p.shape[0]: # a batch of scalars tot_sumsq += (grad**2).sum() # sum() to change shape [1] to [] else: - tot_sumsq += ((grad * state["param_rms"])**2).sum() + tot_sumsq += ((grad * state["param_rms"]) ** 2).sum() tot_norm = tot_sumsq.sqrt() - if not "model_norms" in first_state: - first_state["model_norms"] = torch.zeros(clipping_update_period, - device=p.device) + if "model_norms" not in first_state: + first_state["model_norms"] = torch.zeros( + clipping_update_period, device=p.device + ) first_state["model_norms"][step % clipping_update_period] = tot_norm if step % clipping_update_period == 0: # Print some stats. # We don't reach here if step == 0 because we would have returned # above. - sorted_norms = first_state["model_norms"].sort()[0].to('cpu') + sorted_norms = first_state["model_norms"].sort()[0].to("cpu") quartiles = [] for n in range(0, 5): - index = min(clipping_update_period - 1, - (clipping_update_period // 4) * n) + index = min( + clipping_update_period - 1, (clipping_update_period // 4) * n + ) quartiles.append(sorted_norms[index].item()) median = quartiles[2] threshold = clipping_scale * median first_state["model_norm_threshold"] = threshold - percent_clipped = (first_state["num_clipped"] * 100.0 / clipping_update_period - if "num_clipped" in first_state else 0.0) + percent_clipped = ( + first_state["num_clipped"] * 100.0 / clipping_update_period + if "num_clipped" in first_state + else 0.0 + ) first_state["num_clipped"] = 0 - quartiles = ' '.join([ '%.3e' % x for x in quartiles ]) - logging.info(f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, " - f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}") + quartiles = " ".join(["%.3e" % x for x in quartiles]) + logging.info( + f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, " + f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}" + ) if step < clipping_update_period: return 1.0 # We have not yet estimated a norm to clip to. else: try: model_norm_threshold = first_state["model_norm_threshold"] - except: - logging.info("Warning: model_norm_threshold not in state: possibly " - "you changed config when restarting, adding clipping_scale option?") + except KeyError: + logging.info( + "Warning: model_norm_threshold not in state: possibly " + "you changed config when restarting, adding clipping_scale option?" + ) return 1.0 - ans = min(1.0,(model_norm_threshold / (tot_norm + 1.0e-20)).item()) + ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item()) if ans < 1.0: first_state["num_clipped"] += 1 if ans < 0.1: - logging.warn(f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}") + logging.warn( + f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}" + ) return ans - - def _step_one_batch(self, - group: dict, - p: Tensor, - state: dict, - clipping_scale: float): + def _step_one_batch( + self, group: dict, p: Tensor, state: dict, clipping_scale: float + ): """ Do the step for one parameter, which is actually going to be a batch of `real` parameters, with dim 0 as the batch dim. @@ -391,17 +392,18 @@ class ScaledAdam(BatchedOptimizer): # Update the size/scale of p, and set param_rms scale_grads = state["scale_grads"] scale_grads[step % size_update_period] = (p * grad).sum( - dim=list(range(1, p.ndim)), keepdim=True) + dim=list(range(1, p.ndim)), keepdim=True + ) if step % size_update_period == size_update_period - 1: param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..) - param_rms.copy_((p**2).mean(dim=list(range(1, p.ndim)), - keepdim=True).sqrt()) + param_rms.copy_( + (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() + ) if step > 0: # self._size_update() learns the overall scale on the # parameter, by shrinking or expanding it. self._size_update(group, scale_grads, p, state) - if numel == 1: # For parameters with 1 element we just use regular Adam. # Updates delta. @@ -411,24 +413,21 @@ class ScaledAdam(BatchedOptimizer): state["step"] = step + 1 - - def _size_update(self, - group: dict, - scale_grads: Tensor, - p: Tensor, - state: dict) -> None: + def _size_update( + self, group: dict, scale_grads: Tensor, p: Tensor, state: dict + ) -> None: """ - Called only where p.numel() > 1, this updates the scale of the parameter. - If we imagine: p = underlying_param * scale.exp(), and we are doing - gradient descent on underlying param and on scale, this function does the update - on `scale`. + Called only where p.numel() > 1, this updates the scale of the parameter. + If we imagine: p = underlying_param * scale.exp(), and we are doing + gradient descent on underlying param and on scale, this function does the update + on `scale`. - Args: - group: dict to look up configuration values - scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing - grads w.r.t. the scales. - p: The parameter to update - state: The state-dict of p + Args: + group: dict to look up configuration values + scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing + grads w.r.t. the scales. + p: The parameter to update + state: The state-dict of p """ param_rms = state["param_rms"] @@ -443,25 +442,28 @@ class ScaledAdam(BatchedOptimizer): size_update_period = scale_grads.shape[0] # correct beta2 for the size update period: we will have # faster decay at this level. - beta2_corr = beta2 ** size_update_period + beta2_corr = beta2**size_update_period scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..) scale_exp_avg_sq.mul_(beta2_corr).add_( - (scale_grads ** 2).mean(dim=0), # mean over dim `size_update_period` - alpha=1-beta2_corr) # shape is (batch_size, 1, 1, ...) + (scale_grads**2).mean(dim=0), # mean over dim `size_update_period` + alpha=1 - beta2_corr, + ) # shape is (batch_size, 1, 1, ...) # The 1st time we reach here is when size_step == 1. size_step = (step + 1) // size_update_period - bias_correction2 = 1 - beta2_corr ** size_step + bias_correction2 = 1 - beta2_corr**size_step # we don't bother with bias_correction1; this will help prevent divergence # at the start of training. denom = scale_exp_avg_sq.sqrt() + eps - scale_step = -size_lr * (bias_correction2 ** 0.5) * scale_grads.sum(dim=0) / denom + scale_step = ( + -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom + ) - is_too_small = (param_rms < param_min_rms) - is_too_large = (param_rms > param_max_rms) + is_too_small = param_rms < param_min_rms + is_too_large = param_rms > param_max_rms # when the param gets too small, just don't shrink it any further. scale_step.masked_fill_(is_too_small, 0.0) @@ -469,13 +471,9 @@ class ScaledAdam(BatchedOptimizer): scale_step.masked_fill_(is_too_large, -size_lr * size_update_period) delta = state["delta"] # the factor of (1-beta1) relates to momentum. - delta.add_(p * scale_step, alpha=(1-beta1)) + delta.add_(p * scale_step, alpha=(1 - beta1)) - - def _step(self, - group: dict, - p: Tensor, - state: dict): + def _step(self, group: dict, p: Tensor, state: dict): """ This function does the core update of self.step(), in the case where the members of the batch have more than 1 element. @@ -496,8 +494,7 @@ class ScaledAdam(BatchedOptimizer): step = state["step"] exp_avg_sq = state["exp_avg_sq"] - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, - value=(1-beta2)) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2)) this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0) bias_correction2 = 1 - beta2 ** (this_step + 1) @@ -509,17 +506,13 @@ class ScaledAdam(BatchedOptimizer): denom += eps grad = grad / denom - alpha = -lr * (1-beta1) * state["param_rms"].clamp(min=param_min_rms) + alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms) delta = state["delta"] delta.add_(grad * alpha) p.add_(delta) - - def _step_scalar(self, - group: dict, - p: Tensor, - state: dict): + def _step_scalar(self, group: dict, p: Tensor, state: dict): """ A simplified form of the core update for scalar tensors, where we cannot get a good estimate of the parameter rms. @@ -531,8 +524,7 @@ class ScaledAdam(BatchedOptimizer): grad = p.grad exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, - value=1-beta2) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) # bias_correction2 is like in Adam. Don't bother with bias_correction1; # slower update at the start will help stability anyway. @@ -540,12 +532,11 @@ class ScaledAdam(BatchedOptimizer): denom = (exp_avg_sq / bias_correction2).sqrt() + eps delta = state["delta"] - delta.add_(grad / denom, alpha=-lr*(1-beta1)) + delta.add_(grad / denom, alpha=-lr * (1 - beta1)) p.clamp_(min=-scalar_max, max=scalar_max) p.add_(delta) - class LRScheduler(object): """ Base-class for learning rate schedulers where the learning-rate depends on both the @@ -555,18 +546,14 @@ class LRScheduler(object): def __init__(self, optimizer: Optimizer, verbose: bool = False): # Attach optimizer if not isinstance(optimizer, Optimizer): - raise TypeError( - "{} is not an Optimizer".format(type(optimizer).__name__) - ) + raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) self.optimizer = optimizer self.verbose = verbose for group in optimizer.param_groups: group.setdefault("base_lr", group["lr"]) - self.base_lrs = [ - group["base_lr"] for group in optimizer.param_groups - ] + self.base_lrs = [group["base_lr"] for group in optimizer.param_groups] self.epoch = 0 self.batch = 0 @@ -680,13 +667,15 @@ class Eden(LRScheduler): def get_lr(self): factor = ( - (self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2 + (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 ) ** -0.25 * ( - ((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2) - ** -0.25 + ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25 + ) + warmup_factor = ( + 1.0 + if self.batch >= self.warmup_batches + else 0.5 + 0.5 * (self.batch / self.warmup_batches) ) - warmup_factor = (1.0 if self.batch >= self.warmup_batches - else 0.5 + 0.5 * (self.batch / self.warmup_batches)) return [x * factor * warmup_factor for x in self.base_lrs] @@ -745,13 +734,14 @@ class Eve(Optimizer): parameters, if they fall below this we will stop applying weight decay. - .. _Adam\: A Method for Stochastic Optimization: + .. _Adam: A Method for Stochastic Optimization: https://arxiv.org/abs/1412.6980 .. _Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101 .. _On the Convergence of Adam and Beyond: https://openreview.net/forum?id=ryQu7f-RZ """ + def __init__( self, params, @@ -766,17 +756,11 @@ class Eve(Optimizer): if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if not 0.0 <= betas[0] < 1.0: - raise ValueError( - "Invalid beta parameter at index 0: {}".format(betas[0]) - ) + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) if not 0.0 <= betas[1] < 1.0: - raise ValueError( - "Invalid beta parameter at index 1: {}".format(betas[1]) - ) + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) if not 0 <= weight_decay <= 0.1: - raise ValueError( - "Invalid weight_decay value: {}".format(weight_decay) - ) + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) if not 0 < target_rms <= 10.0: raise ValueError("Invalid target_rms value: {}".format(target_rms)) defaults = dict( @@ -812,9 +796,7 @@ class Eve(Optimizer): # Perform optimization step grad = p.grad if grad.is_sparse: - raise RuntimeError( - "AdamW does not support sparse gradients" - ) + raise RuntimeError("AdamW does not support sparse gradients") state = self.state[p] @@ -841,7 +823,7 @@ class Eve(Optimizer): # Decay the first and second moment running average coefficient exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_( + denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_( group["eps"] ) @@ -852,30 +834,31 @@ class Eve(Optimizer): if p.numel() > 1: # avoid applying this weight-decay on "scaling factors" # (which are scalar). - is_above_target_rms = p.norm() > ( - target_rms * (p.numel() ** 0.5) - ) + is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5)) p.mul_(1 - (weight_decay * is_above_target_rms)) p.addcdiv_(exp_avg, denom, value=-step_size) if random.random() < 0.0005: - step = (exp_avg/denom) * step_size - logging.info(f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}") - + step = (exp_avg / denom) * step_size + logging.info( + f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}" + ) return loss def _test_scaled_adam(hidden_dim: int): import timeit + from scaling import ScaledLinear + E = 100 B = 4 T = 2 logging.info("in test_eve_cain") - #device = torch.device('cuda') - device = torch.device('cpu') + # device = torch.device('cuda') + device = torch.device("cpu") dtype = torch.float32 fix_random_seed(42) @@ -889,79 +872,92 @@ def _test_scaled_adam(hidden_dim: int): fix_random_seed(42) Linear = torch.nn.Linear if iter == 0 else ScaledLinear - m = torch.nn.Sequential(Linear(E, hidden_dim), - torch.nn.PReLU(), - Linear(hidden_dim, hidden_dim), - torch.nn.PReLU(), - Linear(hidden_dim, E), - ).to(device) + m = torch.nn.Sequential( + Linear(E, hidden_dim), + torch.nn.PReLU(), + Linear(hidden_dim, hidden_dim), + torch.nn.PReLU(), + Linear(hidden_dim, E), + ).to(device) - train_pairs = [ (100.0 * torch.randn(B, T, E, device=device, dtype=dtype) * input_magnitudes, - torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes) for _ in range(20) ] + train_pairs = [ + ( + 100.0 + * torch.randn(B, T, E, device=device, dtype=dtype) + * input_magnitudes, + torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes, + ) + for _ in range(20) + ] - if iter == 0: optim = Eve(m.parameters(), lr=0.003) - elif iter == 1: optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0) + if iter == 0: + optim = Eve(m.parameters(), lr=0.003) + elif iter == 1: + optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0) scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False) - start = timeit.default_timer() avg_loss = 0.0 for epoch in range(180): scheduler.step_epoch() - #if epoch == 100 and iter in [2,3]: + # if epoch == 100 and iter in [2,3]: # optim.reset_speedup() # check it doesn't crash. - #if epoch == 130: + # if epoch == 130: # opts = diagnostics.TensorDiagnosticOptions( # 2 ** 22 # ) # allow 4 megabytes per sub-module # diagnostic = diagnostics.attach_diagnostics(m, opts) - - for n, (x,y) in enumerate(train_pairs): + for n, (x, y) in enumerate(train_pairs): y_out = m(x) - loss = ((y_out - y)**2).mean() * 100.0 + loss = ((y_out - y) ** 2).mean() * 100.0 if epoch == 0 and n == 0: avg_loss = loss.item() else: avg_loss = 0.98 * avg_loss + 0.02 * loss.item() if n == 0 and epoch % 5 == 0: - #norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item() - #norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item() - #norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item() - #norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item() - #scale1 = '%.2e' % (m[0].weight_scale.exp().item()) - #scale1b = '%.2e' % (m[0].bias_scale.exp().item()) - #scale2 = '%.2e' % (m[2].weight_scale.exp().item()) - #scale2b = '%.2e' % (m[2].bias_scale.exp().item()) + # norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item() + # norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item() + # norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item() + # norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item() + # scale1 = '%.2e' % (m[0].weight_scale.exp().item()) + # scale1b = '%.2e' % (m[0].bias_scale.exp().item()) + # scale2 = '%.2e' % (m[2].weight_scale.exp().item()) + # scale2b = '%.2e' % (m[2].bias_scale.exp().item()) lr = scheduler.get_last_lr()[0] - logging.info(f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}") #, norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b} + logging.info( + f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}" + ) # , norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b} loss.log().backward() optim.step() optim.zero_grad() scheduler.step_batch() - #diagnostic.print_diagnostics() + # diagnostic.print_diagnostics() stop = timeit.default_timer() logging.info(f"Iter={iter}, Time taken: {stop - start}") logging.info(f"last lr = {scheduler.get_last_lr()}") - #logging.info("state dict = ", scheduler.state_dict()) - #logging.info("optim state_dict = ", optim.state_dict()) + # logging.info("state dict = ", scheduler.state_dict()) + # logging.info("optim state_dict = ", optim.state_dict()) logging.info(f"input_magnitudes = {input_magnitudes}") logging.info(f"output_magnitudes = {output_magnitudes}") - if __name__ == "__main__": torch.set_num_threads(1) torch.set_num_interop_threads(1) logging.getLogger().setLevel(logging.INFO) import subprocess - s = subprocess.check_output("git status -uno .; git log -1; git diff HEAD .", shell=True) + + s = subprocess.check_output( + "git status -uno .; git log -1; git diff HEAD .", shell=True + ) logging.info(s) import sys + if len(sys.argv) > 1: hidden_dim = int(sys.argv[1]) else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py index 7fe1e681a..758e0c036 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py @@ -177,8 +177,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -210,8 +209,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -275,15 +273,11 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lengths - ) + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) num_waves = encoder_out.size(0) hyps = [] @@ -355,9 +349,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index 50cedba56..6f63e0629 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -16,12 +16,12 @@ import collections +import logging +import random +from functools import reduce from itertools import repeat from typing import Optional, Tuple, Union -from functools import reduce -import logging -import random import torch import torch.nn as nn import torch.nn.functional as F @@ -32,27 +32,24 @@ from torch.nn import Embedding as ScaledEmbedding class ActivationBalancerFunction(torch.autograd.Function): @staticmethod def forward( - ctx, - x: Tensor, - scale_factor: Tensor, - sign_factor: Optional[Tensor], - channel_dim: int, + ctx, + x: Tensor, + scale_factor: Tensor, + sign_factor: Optional[Tensor], + channel_dim: int, ) -> Tensor: if channel_dim < 0: channel_dim += x.ndim ctx.channel_dim = channel_dim - xgt0 = (x > 0) + xgt0 = x > 0 if sign_factor is None: ctx.save_for_backward(xgt0, scale_factor) else: ctx.save_for_backward(xgt0, scale_factor, sign_factor) return x - @staticmethod - def backward( - ctx, x_grad: Tensor - ) -> Tuple[Tensor, None, None, None]: + def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]: if len(ctx.saved_tensors) == 3: xgt0, scale_factor, sign_factor = ctx.saved_tensors for _ in range(ctx.channel_dim, x_grad.ndim - 1): @@ -65,14 +62,22 @@ class ActivationBalancerFunction(torch.autograd.Function): scale_factor = scale_factor.unsqueeze(-1) factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5) neg_delta_grad = x_grad.abs() * factor - return x_grad - neg_delta_grad, None, None, None, + return ( + x_grad - neg_delta_grad, + None, + None, + None, + ) -def _compute_scale_factor(x: Tensor, - channel_dim: int, - min_abs: float, - max_abs: float, - gain_factor: float, - max_factor: float) -> Tensor: + +def _compute_scale_factor( + x: Tensor, + channel_dim: int, + min_abs: float, + max_abs: float, + gain_factor: float, + max_factor: float, +) -> Tensor: if channel_dim < 0: channel_dim += x.ndim sum_dims = [d for d in range(x.ndim) if d != channel_dim] @@ -83,71 +88,76 @@ def _compute_scale_factor(x: Tensor, else: # below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if # x_abs)_mean , min_abs. - below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(min=0, max=max_factor) + below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp( + min=0, max=max_factor + ) - above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(min=0, max=max_factor) + above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp( + min=0, max=max_factor + ) return below_threshold - above_threshold -def _compute_sign_factor(x: Tensor, - channel_dim: int, - min_positive: float, - max_positive: float, - gain_factor: float, - max_factor: float) -> Tensor: + +def _compute_sign_factor( + x: Tensor, + channel_dim: int, + min_positive: float, + max_positive: float, + gain_factor: float, + max_factor: float, +) -> Tensor: if channel_dim < 0: channel_dim += x.ndim sum_dims = [d for d in range(x.ndim) if d != channel_dim] - proportion_positive = torch.mean((x > 0).to(torch.float32), - dim=sum_dims) + proportion_positive = torch.mean((x > 0).to(torch.float32), dim=sum_dims) if min_positive == 0.0: factor1 = 0.0 else: # 0 if proportion_positive >= min_positive, else can be # as large as max_factor. - factor1 = ((min_positive - proportion_positive) * - (gain_factor / min_positive)).clamp_(min=0, max=max_factor) + factor1 = ( + (min_positive - proportion_positive) * (gain_factor / min_positive) + ).clamp_(min=0, max=max_factor) if max_positive == 1.0: factor2 = 0.0 else: # 0 if self.proportion_positive <= max_positive, else can be # as large as -max_factor. - factor2 = ((proportion_positive - max_positive) * - (gain_factor / (1.0 - max_positive))).clamp_(min=0, max=max_factor) + factor2 = ( + (proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive)) + ).clamp_(min=0, max=max_factor) sign_factor = factor1 - factor2 # require min_positive != 0 or max_positive != 1: assert not isinstance(sign_factor, float) return sign_factor - class ActivationScaleBalancerFunction(torch.autograd.Function): """ This object is used in class ActivationBalancer when the user specified min_positive=0, max_positive=1, so there are no constraints on the signs of the activations and only the absolute value has a constraint. """ + @staticmethod def forward( - ctx, - x: Tensor, - sign_factor: Tensor, - scale_factor: Tensor, - channel_dim: int, + ctx, + x: Tensor, + sign_factor: Tensor, + scale_factor: Tensor, + channel_dim: int, ) -> Tensor: if channel_dim < 0: channel_dim += x.ndim ctx.channel_dim = channel_dim - xgt0 = (x > 0) + xgt0 = x > 0 ctx.save_for_backward(xgt0, sign_factor, scale_factor) return x - @staticmethod - def backward( - ctx, x_grad: Tensor - ) -> Tuple[Tensor, None, None, None]: + def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]: xgt0, sign_factor, scale_factor = ctx.saved_tensors for _ in range(ctx.channel_dim, x_grad.ndim - 1): sign_factor = sign_factor.unsqueeze(-1) @@ -155,18 +165,24 @@ class ActivationScaleBalancerFunction(torch.autograd.Function): factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5) neg_delta_grad = x_grad.abs() * factor - return x_grad - neg_delta_grad, None, None, None, + return ( + x_grad - neg_delta_grad, + None, + None, + None, + ) class RandomClampFunction(torch.autograd.Function): @staticmethod def forward( - ctx, - x: Tensor, - min: Optional[float], - max: Optional[float], - prob: float, - reflect: float) -> Tensor: + ctx, + x: Tensor, + min: Optional[float], + max: Optional[float], + prob: float, + reflect: float, + ) -> Tensor: x_clamped = torch.clamp(x, min=min, max=max) mask = torch.rand_like(x) < prob ans = torch.where(mask, x_clamped, x) @@ -179,30 +195,32 @@ class RandomClampFunction(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None, None, None]: - is_same, = ctx.saved_tensors + (is_same,) = ctx.saved_tensors x_grad = ans_grad * is_same.to(ans_grad.dtype) reflect = ctx.reflect - if reflect != 0.0: + if reflect != 0.0: x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect) return x_grad, None, None, None, None -def random_clamp(x: Tensor, - min: Optional[float] = None, - max: Optional[float] = None, - prob: float = 0.5, - reflect: float = 0.0): + +def random_clamp( + x: Tensor, + min: Optional[float] = None, + max: Optional[float] = None, + prob: float = 0.5, + reflect: float = 0.0, +): return RandomClampFunction.apply(x, min, max, prob, reflect) -def random_cast_to_half(x: Tensor, - min_abs: float = 5.0e-06) -> Tensor: +def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor: """ A randomized way of casting a floating point value to half precision. """ if x.dtype == torch.float16: return x x_abs = x.abs() - is_too_small = (x_abs < min_abs) + is_too_small = x_abs < min_abs # for elements where is_too_small is true, random_val will contain +-min_abs with # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations, # for those elements]. @@ -215,6 +233,7 @@ class RandomGradFunction(torch.autograd.Function): Does nothing in forward pass; in backward pass, gets rid of very small grads using randomized approach that preserves expectations (intended to reduce roundoff). """ + @staticmethod def forward(ctx, x: Tensor, min_abs: float) -> Tensor: ctx.min_abs = min_abs @@ -223,35 +242,37 @@ class RandomGradFunction(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]: if ans_grad.dtype == torch.float16: - return random_cast_to_half(ans_grad.to(torch.float32), - min_abs=ctx.min_abs), None + return ( + random_cast_to_half(ans_grad.to(torch.float32), min_abs=ctx.min_abs), + None, + ) else: return ans_grad, None + class RandomGrad(torch.nn.Module): """ Gets rid of very small gradients using an expectation-preserving method, intended to increase accuracy of training when using amp (automatic mixed precision) """ - def __init__(self, - min_abs: float = 5.0e-06): + + def __init__(self, min_abs: float = 5.0e-06): super(RandomGrad, self).__init__() self.min_abs = min_abs - def forward(self, - x: Tensor): + def forward(self, x: Tensor): if torch.jit.is_scripting() or not self.training: return x else: return RandomGradFunction.apply(x, self.min_abs) - class SoftmaxFunction(torch.autograd.Function): """ Tries to handle half-precision derivatives in a randomized way that should be more accurate for training than the default behavior. """ + @staticmethod def forward(ctx, x: Tensor, dim: int): ans = x.softmax(dim=dim) @@ -267,7 +288,7 @@ class SoftmaxFunction(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor): - ans, = ctx.saved_tensors + (ans,) = ctx.saved_tensors with torch.cuda.amp.autocast(enabled=False): ans_grad = ans_grad.to(torch.float32) ans = ans.to(torch.float32) @@ -276,9 +297,7 @@ class SoftmaxFunction(torch.autograd.Function): return x_grad, None - -def softmax(x: Tensor, - dim: int): +def softmax(x: Tensor, dim: int): if torch.jit.is_scripting(): return x.softmax(dim) @@ -288,20 +307,18 @@ def softmax(x: Tensor, class MaxEigLimiterFunction(torch.autograd.Function): @staticmethod def forward( - ctx, - x: Tensor, - coeffs: Tensor, - direction: Tensor, - channel_dim: int, - grad_scale: float) -> Tensor: + ctx, + x: Tensor, + coeffs: Tensor, + direction: Tensor, + channel_dim: int, + grad_scale: float, + ) -> Tensor: ctx.channel_dim = channel_dim ctx.grad_scale = grad_scale - ctx.save_for_backward(x.detach(), - coeffs.detach(), - direction.detach()) + ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach()) return x - @staticmethod def backward(ctx, x_grad, *args): with torch.enable_grad(): @@ -311,15 +328,20 @@ class MaxEigLimiterFunction(torch.autograd.Function): x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels) new_direction.requires_grad = False x = x - x.mean(dim=0) - x_var = (x ** 2).mean() + x_var = (x**2).mean() x_residual = x - coeffs * new_direction - x_residual_var = (x_residual ** 2).mean() + x_residual_var = (x_residual**2).mean() # `variance_proportion` is the proportion of the variance accounted for # by the top eigen-direction. This is to be minimized. variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20) variance_proportion.backward() x_orig_grad = x_orig.grad - x_extra_grad = x_orig.grad * ctx.grad_scale * x_grad.norm() / (x_orig_grad.norm() + 1.0e-20) + x_extra_grad = ( + x_orig.grad + * ctx.grad_scale + * x_grad.norm() + / (x_orig_grad.norm() + 1.0e-20) + ) return x_grad + x_extra_grad.detach(), None, None, None, None @@ -385,15 +407,12 @@ class BasicNorm(torch.nn.Module): # region if it happens to exit it. eps = eps.clamp(min=self.eps_min, max=self.eps_max) scales = ( - torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps.exp() + torch.mean(x**2, dim=self.channel_dim, keepdim=True) + eps.exp() ) ** -0.5 return x * scales - -def ScaledLinear(*args, - initial_scale: float = 1.0, - **kwargs ) -> nn.Linear: +def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear: """ Behaves like a constructor of a modified version of nn.Linear that gives an easy way to set the default initial parameter scale. @@ -412,16 +431,11 @@ def ScaledLinear(*args, with torch.no_grad(): ans.weight[:] *= initial_scale if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, - -0.1 * initial_scale, - 0.1 * initial_scale) + torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) return ans - -def ScaledConv1d(*args, - initial_scale: float = 1.0, - **kwargs ) -> nn.Conv1d: +def ScaledConv1d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv1d: """ Behaves like a constructor of a modified version of nn.Conv1d that gives an easy way to set the default initial parameter scale. @@ -440,13 +454,10 @@ def ScaledConv1d(*args, with torch.no_grad(): ans.weight[:] *= initial_scale if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, - -0.1 * initial_scale, - 0.1 * initial_scale) + torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) return ans - class ActivationBalancer(torch.nn.Module): """ Modifies the backpropped derivatives of a function to try to encourage, for @@ -486,18 +497,19 @@ class ActivationBalancer(torch.nn.Module): from doing it at the same time. Early in training we may use higher probabilities than this; it will decay to this value. """ + def __init__( - self, - num_channels: int, - channel_dim: int, - min_positive: float = 0.05, - max_positive: float = 0.95, - max_factor: float = 0.04, - sign_gain_factor: float = 0.01, - scale_gain_factor: float = 0.02, - min_abs: float = 0.2, - max_abs: float = 100.0, - min_prob: float = 0.1, + self, + num_channels: int, + channel_dim: int, + min_positive: float = 0.05, + max_positive: float = 0.95, + max_factor: float = 0.04, + sign_gain_factor: float = 0.01, + scale_gain_factor: float = 0.02, + min_abs: float = 0.2, + max_abs: float = 100.0, + min_prob: float = 0.1, ): super(ActivationBalancer, self).__init__() self.num_channels = num_channels @@ -515,9 +527,7 @@ class ActivationBalancer(torch.nn.Module): # We occasionally sync this to a tensor called `count`, that exists to # make sure it is synced to disk when we load and save the model. self.cpu_count = 0 - self.register_buffer('count', torch.tensor(0, dtype=torch.int64)) - - + self.register_buffer("count", torch.tensor(0, dtype=torch.int64)) def forward(self, x: Tensor) -> Tensor: if torch.jit.is_scripting() or not x.requires_grad: @@ -535,26 +545,35 @@ class ActivationBalancer(torch.nn.Module): # the prob of doing some work exponentially decreases from 0.5 till it hits # a floor at min_prob (==0.1, by default) - prob = max(self.min_prob, 0.5 ** (1 + (count/4000.0))) + prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0))) if random.random() < prob: sign_gain_factor = 0.5 if self.min_positive != 0.0 or self.max_positive != 1.0: - sign_factor = _compute_sign_factor(x, self.channel_dim, - self.min_positive, self.max_positive, - gain_factor=self.sign_gain_factor / prob, - max_factor=self.max_factor) + sign_factor = _compute_sign_factor( + x, + self.channel_dim, + self.min_positive, + self.max_positive, + gain_factor=self.sign_gain_factor / prob, + max_factor=self.max_factor, + ) else: sign_factor = None - - scale_factor = _compute_scale_factor(x, self.channel_dim, - min_abs=self.min_abs, - max_abs=self.max_abs, - gain_factor=self.scale_gain_factor / prob, - max_factor=self.max_factor) + scale_factor = _compute_scale_factor( + x, + self.channel_dim, + min_abs=self.min_abs, + max_abs=self.max_abs, + gain_factor=self.scale_gain_factor / prob, + max_factor=self.max_factor, + ) return ActivationBalancerFunction.apply( - x, scale_factor, sign_factor, self.channel_dim, + x, + scale_factor, + sign_factor, + self.channel_dim, ) else: return _no_op(x) @@ -594,13 +613,12 @@ def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims. else: (batch, dim, dim) = x.shape x = x.reshape(batch, dim * dim) - x = x[:, ::dim+1] + x = x[:, :: dim + 1] assert x.shape == (batch, dim) return x -def _whitening_metric(x: Tensor, - num_groups: int): +def _whitening_metric(x: Tensor, num_groups: int): """ Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of of the centered feature covariance are the same within each group's covariance matrix @@ -630,19 +648,17 @@ def _whitening_metric(x: Tensor, # the following expression is what we'd get if we took the matrix product # of each covariance and measured the mean of its trace, i.e. # the same as _diag(torch.matmul(x_covar, x_covar)).mean(). - x_covarsq_mean_diag = (x_covar ** 2).sum() / (num_groups * channels_per_group) + x_covarsq_mean_diag = (x_covar**2).sum() / (num_groups * channels_per_group) # this metric will be >= 1.0; the larger it is, the less 'white' the data was. - metric = x_covarsq_mean_diag / (x_covar_mean_diag ** 2 + 1.0e-20) + metric = x_covarsq_mean_diag / (x_covar_mean_diag**2 + 1.0e-20) return metric class WhiteningPenaltyFunction(torch.autograd.Function): @staticmethod - def forward(ctx, - x: Tensor, - num_groups: int, - whitening_limit: float, - grad_scale: float) -> Tensor: + def forward( + ctx, x: Tensor, num_groups: int, whitening_limit: float, grad_scale: float + ) -> Tensor: ctx.save_for_backward(x) ctx.num_groups = num_groups ctx.whitening_limit = whitening_limit @@ -650,9 +666,8 @@ class WhiteningPenaltyFunction(torch.autograd.Function): return x @staticmethod - def backward(ctx, - x_grad: Tensor): - x_orig, = ctx.saved_tensors + def backward(ctx, x_grad: Tensor): + (x_orig,) = ctx.saved_tensors with torch.enable_grad(): with torch.cuda.amp.autocast(enabled=False): x_detached = x_orig.to(torch.float32).detach() @@ -661,25 +676,28 @@ class WhiteningPenaltyFunction(torch.autograd.Function): metric = _whitening_metric(x_detached, ctx.num_groups) if random.random() < 0.005 or __name__ == "__main__": - logging.info(f"Whitening: num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, " - f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}") + logging.info( + f"Whitening: num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, " + f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}" + ) (metric - ctx.whitening_limit).relu().backward() penalty_grad = x_detached.grad - scale = ctx.grad_scale * (x_grad.to(torch.float32).norm() / - (penalty_grad.norm() + 1.0e-20)) + scale = ctx.grad_scale * ( + x_grad.to(torch.float32).norm() / (penalty_grad.norm() + 1.0e-20) + ) penalty_grad = penalty_grad * scale return x_grad + penalty_grad.to(x_grad.dtype), None, None, None - class Whiten(nn.Module): def __init__( - self, - num_groups: int, - whitening_limit: float, - prob: Union[float, Tuple[float,float]], - grad_scale: float): + self, + num_groups: int, + whitening_limit: float, + prob: Union[float, Tuple[float, float]], + grad_scale: float, + ): """ Args: num_groups: the number of groups to divide the channel dim into before @@ -714,8 +732,7 @@ class Whiten(nn.Module): self.grad_scale = grad_scale - def forward(self, - x: Tensor) -> Tensor: + def forward(self, x: Tensor) -> Tensor: """ In the forward pass, this function just returns the input unmodified. In the backward pass, it will modify the gradients to ensure that the @@ -735,19 +752,21 @@ class Whiten(nn.Module): if not x.requires_grad or random.random() > self.prob or self.grad_scale == 0: return _no_op(x) else: - if hasattr(self, 'min_prob') and random.random() < 0.25: + if hasattr(self, "min_prob") and random.random() < 0.25: # occasionally switch between min_prob and max_prob, based on whether # we are above or below the threshold. - if _whitening_metric(x.to(torch.float32), self.num_groups) > self.whitening_limit: + if ( + _whitening_metric(x.to(torch.float32), self.num_groups) + > self.whitening_limit + ): # there would be a change to the grad. self.prob = self.max_prob else: self.prob = self.min_prob - return WhiteningPenaltyFunction.apply(x, - self.num_groups, - self.whitening_limit, - self.grad_scale) + return WhiteningPenaltyFunction.apply( + x, self.num_groups, self.whitening_limit, self.grad_scale + ) class WithLoss(torch.autograd.Function): @@ -755,11 +774,14 @@ class WithLoss(torch.autograd.Function): def forward(ctx, x: Tensor, y: Tensor): ctx.y_shape = y.shape return x + @staticmethod def backward(ctx, ans_grad: Tensor): - return ans_grad, torch.ones(ctx.y_shape, - dtype=ans_grad.dtype, - device=ans_grad.device) + return ans_grad, torch.ones( + ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device + ) + + def with_loss(x, y): if torch.jit.is_scripting(): return x @@ -768,7 +790,7 @@ def with_loss(x, y): def _no_op(x: Tensor) -> Tensor: - if (torch.jit.is_scripting()): + if torch.jit.is_scripting(): return x else: # a no-op function that will have a node in the autograd graph, @@ -783,6 +805,7 @@ class Identity(torch.nn.Module): def forward(self, x): return _no_op(x) + class MaxEig(torch.nn.Module): """ Modifies the backpropped derivatives of a function to try to discourage @@ -803,13 +826,14 @@ class MaxEig(torch.nn.Module): scale: determines the scale with which we modify the gradients, relative to the existing / unmodified gradients """ + def __init__( - self, - num_channels: int, - channel_dim: int, - max_var_per_eig: float = 0.2, - min_prob: float = 0.01, - scale: float = 0.01, + self, + num_channels: int, + channel_dim: int, + max_var_per_eig: float = 0.2, + min_prob: float = 0.01, + scale: float = 0.01, ): super(MaxEig, self).__init__() self.num_channels = num_channels @@ -825,7 +849,7 @@ class MaxEig(torch.nn.Module): # random parameters unchanged for comparison direction = torch.arange(num_channels).to(torch.float) direction = direction / direction.norm() - self.register_buffer('max_eig_direction', direction) + self.register_buffer("max_eig_direction", direction) self.min_prob = min_prob # cur_prob is the current probability we'll use to apply the ActivationBalancer. @@ -833,12 +857,12 @@ class MaxEig(torch.nn.Module): # active. self.cur_prob = 1.0 - - def forward(self, x: Tensor) -> Tensor: - if (torch.jit.is_scripting() or - self.max_var_per_eig <= 0 or - random.random() > self.cur_prob): + if ( + torch.jit.is_scripting() + or self.max_var_per_eig <= 0 + or random.random() > self.cur_prob + ): return _no_op(x) with torch.cuda.amp.autocast(enabled=False): @@ -848,7 +872,9 @@ class MaxEig(torch.nn.Module): with torch.no_grad(): x = x.transpose(self.channel_dim, -1).reshape(-1, self.num_channels) x = x - x.mean(dim=0) - new_direction, coeffs = self._find_direction_coeffs(x, self.max_eig_direction) + new_direction, coeffs = self._find_direction_coeffs( + x, self.max_eig_direction + ) x_var = (x**2).mean() x_residual = x - coeffs * new_direction x_residual_var = (x_residual**2).mean() @@ -861,7 +887,9 @@ class MaxEig(torch.nn.Module): self._set_direction(0.1 * self.max_eig_direction + new_direction) if random.random() < 0.01 or __name__ == "__main__": - logging.info(f"variance_proportion = {variance_proportion.item()}, shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}") + logging.info( + f"variance_proportion = {variance_proportion.item()}, shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}" + ) if variance_proportion >= self.max_var_per_eig: # The constraint is active. Note, we should quite rarely @@ -869,17 +897,16 @@ class MaxEig(torch.nn.Module): # starting to diverge, should this constraint be active. cur_prob = self.cur_prob self.cur_prob = 1.0 # next time, do the update with probability 1.0. - return MaxEigLimiterFunction.apply(orig_x, coeffs, new_direction, - self.channel_dim, self.scale) + return MaxEigLimiterFunction.apply( + orig_x, coeffs, new_direction, self.channel_dim, self.scale + ) else: # let self.cur_prob exponentially approach self.min_prob, as # long as the constraint is inactive. self.cur_prob = 0.75 * self.cur_prob + 0.25 * self.min_prob return orig_x - - def _set_direction(self, - direction: Tensor): + def _set_direction(self, direction: Tensor): """ Sets self.max_eig_direction to a normalized version of `direction` """ @@ -889,40 +916,39 @@ class MaxEig(torch.nn.Module): if direction_sum - direction_sum == 0: # no inf/nan self.max_eig_direction[:] = direction else: - logging.info(f"Warning: sum of direction in MaxEig is {direction_sum}, " - "num_channels={self.num_channels}, channel_dim={self.channel_dim}") + logging.info( + f"Warning: sum of direction in MaxEig is {direction_sum}, " + "num_channels={self.num_channels}, channel_dim={self.channel_dim}" + ) - - def _find_direction_coeffs(self, - x: Tensor, - prev_direction: Tensor) -> Tuple[Tensor, Tensor, Tensor]: + def _find_direction_coeffs( + self, x: Tensor, prev_direction: Tensor + ) -> Tuple[Tensor, Tensor, Tensor]: """ - Figure out (an approximation to) the proportion of the variance of a set of - feature vectors that can be attributed to the top eigen-direction. - Args: - x: a Tensor of shape (num_frames, num_channels), with num_frames > 1. - prev_direction: a Tensor of shape (num_channels,), that is our previous estimate - of the top eigen-direction, or a random direction if this is the first - iteration. Does not have to be normalized, but should be nonzero. + Figure out (an approximation to) the proportion of the variance of a set of + feature vectors that can be attributed to the top eigen-direction. + Args: + x: a Tensor of shape (num_frames, num_channels), with num_frames > 1. + prev_direction: a Tensor of shape (num_channels,), that is our previous estimate + of the top eigen-direction, or a random direction if this is the first + iteration. Does not have to be normalized, but should be nonzero. - Returns: (cur_direction, coeffs), where: - cur_direction: a Tensor of shape (num_channels,) that is the current - estimate of the top eigen-direction. - coeffs: a Tensor of shape (num_frames, 1) that minimizes, or - approximately minimizes, (x - coeffs * cur_direction).norm() - """ + Returns: (cur_direction, coeffs), where: + cur_direction: a Tensor of shape (num_channels,) that is the current + estimate of the top eigen-direction. + coeffs: a Tensor of shape (num_frames, 1) that minimizes, or + approximately minimizes, (x - coeffs * cur_direction).norm() + """ (num_frames, num_channels) = x.shape assert num_channels > 1 and num_frames > 1 assert prev_direction.shape == (num_channels,) # `coeffs` are the coefficients of `prev_direction` in x. # actually represent the coeffs up to a constant positive factor. coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10 - cur_direction = (x * coeffs).sum(dim=0) / ((coeffs ** 2).sum() + 1.0e-20) + cur_direction = (x * coeffs).sum(dim=0) / ((coeffs**2).sum() + 1.0e-20) return cur_direction, coeffs - - class DoubleSwishFunction(torch.autograd.Function): """ double_swish(x) = x * torch.sigmoid(x-1) @@ -950,7 +976,7 @@ class DoubleSwishFunction(torch.autograd.Function): y = x * s if requires_grad: - deriv = (y * (1 - s) + s) + deriv = y * (1 - s) + s # notes on derivative of x * sigmoid(x - 1): # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29 # min \simeq -0.043638. Take floor as -0.043637 so it's a lower bund @@ -959,7 +985,9 @@ class DoubleSwishFunction(torch.autograd.Function): # floors), should be expectation-preserving. floor = -0.043637 ceil = 1.2 - d_scaled = ((deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(deriv)) + d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like( + deriv + ) if __name__ == "__main__": # for self-testing only. assert d_scaled.min() >= 0.0 @@ -972,12 +1000,12 @@ class DoubleSwishFunction(torch.autograd.Function): @staticmethod def backward(ctx, y_grad: Tensor) -> Tensor: - d, = ctx.saved_tensors + (d,) = ctx.saved_tensors # the same constants as used in forward pass. floor = -0.043637 ceil = 1.2 - d = (d * ((ceil - floor) / 255.0) + floor) - return (y_grad * d) + d = d * ((ceil - floor) / 255.0) + floor + return y_grad * d class DoubleSwish(torch.nn.Module): @@ -990,7 +1018,6 @@ class DoubleSwish(torch.nn.Module): return DoubleSwishFunction.apply(x) - def _test_max_eig(): for proportion in [0.1, 0.5, 10.0]: logging.info(f"proportion = {proportion}") @@ -1002,11 +1029,9 @@ def _test_max_eig(): x.requires_grad = True num_channels = 128 - m = MaxEig(num_channels, - 1, # channel_dim - 0.5, # max_var_per_eig - scale=0.1) # grad_scale - + m = MaxEig( + num_channels, 1, 0.5, scale=0.1 # channel_dim # max_var_per_eig + ) # grad_scale for _ in range(4): y = m(x) @@ -1031,11 +1056,9 @@ def _test_whiten(): x.requires_grad = True num_channels = 128 - m = Whiten(1, # num_groups - 5.0, # whitening_limit, - prob=1.0, - grad_scale=0.1) # grad_scale - + m = Whiten( + 1, 5.0, prob=1.0, grad_scale=0.1 # num_groups # whitening_limit, + ) # grad_scale for _ in range(4): y = m(x) @@ -1049,7 +1072,6 @@ def _test_whiten(): assert not torch.allclose(x.grad, y_grad) - def _test_activation_balancer_sign(): probs = torch.arange(0, 1, 0.01) N = 1000 @@ -1077,9 +1099,7 @@ def _test_activation_balancer_sign(): def _test_activation_balancer_magnitude(): magnitudes = torch.arange(0, 1, 0.01) N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze( - -1 - ) + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) x = x.detach() x.requires_grad = True m = ActivationBalancer( @@ -1111,8 +1131,8 @@ def _test_basic_norm(): y = m(x) assert y.shape == x.shape - x_rms = (x ** 2).mean().sqrt() - y_rms = (y ** 2).mean().sqrt() + x_rms = (x**2).mean().sqrt() + y_rms = (y**2).mean().sqrt() print("x rms = ", x_rms) print("y rms = ", y_rms) assert y_rms < x_rms @@ -1124,30 +1144,27 @@ def _test_double_swish_deriv(): x.requires_grad = True m = DoubleSwish() - tol = ((1.2-(-0.043637))/255.0) + tol = (1.2 - (-0.043637)) / 255.0 torch.autograd.gradcheck(m, x, atol=tol) - # for self-test. x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 x.requires_grad = True y = m(x) - def _test_softmax(): a = torch.randn(2, 10, dtype=torch.float64) b = a.clone() a.requires_grad = True b.requires_grad = True - a.softmax(dim=1)[:,0].sum().backward() + a.softmax(dim=1)[:, 0].sum().backward() print("a grad = ", a.grad) - softmax(b, dim=1)[:,0].sum().backward() + softmax(b, dim=1)[:, 0].sum().backward() print("b grad = ", b.grad) assert torch.allclose(a.grad, b.grad) - if __name__ == "__main__": logging.getLogger().setLevel(logging.INFO) torch.set_num_threads(1) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py index 8d357b15f..56165d1f9 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py @@ -26,11 +26,7 @@ from typing import List import torch import torch.nn as nn -from scaling import ( - ActivationBalancer, - BasicNorm, - Whiten, -) +from scaling import ActivationBalancer, BasicNorm, Whiten class NonScaledNorm(nn.Module): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index 3f27736b3..7160fc54a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -84,9 +84,7 @@ from icefall.env import get_env_info from icefall.hooks import register_inf_check_hooks from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: @@ -269,8 +267,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -293,8 +290,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -646,11 +642,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -697,9 +689,7 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -870,9 +860,7 @@ def train_one_epoch( # of the grad scaler is configurable, but we can't configure it to have different # behavior depending on the current grad scale. cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 1.0 or ( - cur_grad_scale < 8.0 and batch_idx % 400 == 0 - ): + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): scaler.update(cur_grad_scale * 2.0) if cur_grad_scale < 0.01: logging.warning(f"Grad scale is small: {cur_grad_scale}") @@ -890,11 +878,7 @@ def train_one_epoch( f"batch {batch_idx}, loss[{loss_info}], " f"tot_loss[{tot_loss}], batch size: {batch_size}, " f"lr: {cur_lr:.2e}, " - + ( - f"grad_scale: {scaler._scale.item()}" - if params.use_fp16 - else "" - ) + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") ) if tb_writer is not None: @@ -905,9 +889,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if params.use_fp16: tb_writer.add_scalar( "train/grad_scale", @@ -915,10 +897,7 @@ def train_one_epoch( params.batch_idx_train, ) - if ( - batch_idx % params.valid_interval == 0 - and not params.print_diagnostics - ): + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: logging.info("Computing validation loss") valid_info = compute_validation_loss( params=params, @@ -1009,9 +988,7 @@ def run(rank, world_size, args): logging.info("Using DDP") model = DDP(model, device_ids=[rank], find_unused_parameters=True) - optimizer = ScaledAdam( - model.parameters(), lr=params.base_lr, clipping_scale=2.0 - ) + optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=2.0) scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) @@ -1029,7 +1006,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -1054,8 +1031,7 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index 023dec97d..b007a7308 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -16,32 +16,35 @@ # limitations under the License. import copy -import math -import warnings import itertools -from typing import List, Optional, Tuple, Union import logging -import torch +import math import random +import warnings +from typing import List, Optional, Tuple, Union + +import torch from encoder_interface import EncoderInterface +from scaling import ( + ScaledLinear, # not as in other dirs.. just scales down initial parameter values. +) from scaling import ( ActivationBalancer, BasicNorm, - MaxEig, DoubleSwish, - ScaledConv1d, - ScaledLinear, # not as in other dirs.. just scales down initial parameter values. - Whiten, Identity, + MaxEig, + ScaledConv1d, + Whiten, _diag, - random_clamp, penalize_abs_values_gt, + random_clamp, softmax, ) from torch import Tensor, nn -from icefall.utils import make_pad_mask from icefall.dist import get_rank +from icefall.utils import make_pad_mask class Zipformer(EncoderInterface): @@ -89,7 +92,7 @@ class Zipformer(EncoderInterface): self.batch_count = 0 self.warmup_end = warmup_batches - for u,d in zip(encoder_unmasked_dims, encoder_dims): + for u, d in zip(encoder_unmasked_dims, encoder_dims): assert u <= d, (u, d) # self.encoder_embed converts the input of shape (N, T, num_features) @@ -97,9 +100,9 @@ class Zipformer(EncoderInterface): # That is, it does two things simultaneously: # (1) subsampling: T -> (T - 7)//2 # (2) embedding: num_features -> encoder_dims - self.encoder_embed = Conv2dSubsampling(num_features, encoder_dims[0], - dropout=dropout) - + self.encoder_embed = Conv2dSubsampling( + num_features, encoder_dims[0], dropout=dropout + ) # each one will be ZipformerEncoder or DownsampledZipformerEncoder encoders = [] @@ -123,13 +126,13 @@ class Zipformer(EncoderInterface): num_encoder_layers[i], dropout, warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1), - warmup_end=warmup_batches * (i + 2) / (num_encoders + 1) + warmup_end=warmup_batches * (i + 2) / (num_encoders + 1), ) if zipformer_downsampling_factors[i] != 1: encoder = DownsampledZipformerEncoder( encoder, - input_dim=encoder_dims[i-1] if i > 0 else encoder_dims[0], + input_dim=encoder_dims[i - 1] if i > 0 else encoder_dims[0], output_dim=encoder_dims[i], downsample=zipformer_downsampling_factors[i], ) @@ -139,10 +142,9 @@ class Zipformer(EncoderInterface): # initializes self.skip_layers and self.skip_modules self._init_skip_modules() - self.downsample_output = AttentionDownsample(encoder_dims[-1], - encoder_dims[-1], - downsample=output_downsampling_factor) - + self.downsample_output = AttentionDownsample( + encoder_dims[-1], encoder_dims[-1], downsample=output_downsampling_factor + ) def _get_layer_skip_dropout_prob(self): if not self.training: @@ -166,27 +168,31 @@ class Zipformer(EncoderInterface): skip_modules = [] z = self.zipformer_downsampling_factors for i in range(len(z)): - if i <= 1 or z[i-1] <= z[i]: + if i <= 1 or z[i - 1] <= z[i]: skip_layers.append(None) skip_modules.append(SimpleCombinerIdentity()) else: # TEMP - for j in range(i-2, -1, -1): + for j in range(i - 2, -1, -1): if z[j] <= z[i] or j == 0: # TEMP logging statement. - logging.info(f"At encoder stack {i}, which has downsampling_factor={z[i]}, we will " - f"combine the outputs of layers {j} and {i-1}, with downsampling_factors={z[j]} and {z[i-1]}.") + logging.info( + f"At encoder stack {i}, which has downsampling_factor={z[i]}, we will " + f"combine the outputs of layers {j} and {i-1}, with downsampling_factors={z[j]} and {z[i-1]}." + ) skip_layers.append(j) - skip_modules.append(SimpleCombiner(self.encoder_dims[j], - self.encoder_dims[i-1], - min_weight=(0.0,0.25))) + skip_modules.append( + SimpleCombiner( + self.encoder_dims[j], + self.encoder_dims[i - 1], + min_weight=(0.0, 0.25), + ) + ) break self.skip_layers = skip_layers self.skip_modules = nn.ModuleList(skip_modules) - def get_feature_masks( - self, - x: torch.Tensor) -> List[float]: + def get_feature_masks(self, x: torch.Tensor) -> List[float]: # Note: The actual return type is Union[List[float], List[Tensor]], # but to make torch.jit.script() work, we use List[float] """ @@ -206,46 +212,56 @@ class Zipformer(EncoderInterface): """ num_encoders = len(self.encoder_dims) if torch.jit.is_scripting() or not self.training: - return [ 1.0 ] * num_encoders + return [1.0] * num_encoders (num_frames0, batch_size, _encoder_dims0) = x.shape - - assert self.encoder_dims[0] == _encoder_dims0, (self.encoder_dims, _encoder_dims0) + assert self.encoder_dims[0] == _encoder_dims0, ( + self.encoder_dims, + _encoder_dims0, + ) max_downsampling_factor = max(self.zipformer_downsampling_factors) - num_frames_max = (num_frames0 + max_downsampling_factor - 1) - + num_frames_max = num_frames0 + max_downsampling_factor - 1 feature_mask_dropout_prob = 0.15 # frame_mask_max shape: (num_frames_max, batch_size, 1) - frame_mask_max = (torch.rand(num_frames_max, batch_size, 1, - device=x.device) > - feature_mask_dropout_prob).to(x.dtype) + frame_mask_max = ( + torch.rand(num_frames_max, batch_size, 1, device=x.device) + > feature_mask_dropout_prob + ).to(x.dtype) feature_masks = [] for i in range(num_encoders): ds = self.zipformer_downsampling_factors[i] - upsample_factor = (max_downsampling_factor // ds) + upsample_factor = max_downsampling_factor // ds - frame_mask = (frame_mask_max.unsqueeze(1).expand(num_frames_max, upsample_factor, - batch_size, 1) - .reshape(num_frames_max * upsample_factor, batch_size, 1)) + frame_mask = ( + frame_mask_max.unsqueeze(1) + .expand(num_frames_max, upsample_factor, batch_size, 1) + .reshape(num_frames_max * upsample_factor, batch_size, 1) + ) num_frames = (num_frames0 + ds - 1) // ds frame_mask = frame_mask[:num_frames] - feature_mask = torch.ones(num_frames, batch_size, self.encoder_dims[i], - dtype=x.dtype, device=x.device) + feature_mask = torch.ones( + num_frames, + batch_size, + self.encoder_dims[i], + dtype=x.dtype, + device=x.device, + ) u = self.encoder_unmasked_dims[i] feature_mask[:, :, u:] *= frame_mask feature_masks.append(feature_mask) return feature_masks - def forward( - self, x: torch.Tensor, x_lens: torch.Tensor, + self, + x: torch.Tensor, + x_lens: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: @@ -271,7 +287,9 @@ class Zipformer(EncoderInterface): outputs = [] feature_masks = self.get_feature_masks(x) - for i, (module, skip_module) in enumerate(zip(self.encoders, self.skip_modules)): + for i, (module, skip_module) in enumerate( + zip(self.encoders, self.skip_modules) + ): ds = self.zipformer_downsampling_factors[i] k = self.skip_layers[i] if isinstance(k, int): @@ -280,9 +298,11 @@ class Zipformer(EncoderInterface): x = skip_module(outputs[k], x) elif (not self.training) or random.random() > layer_skip_dropout_prob: x = skip_module(outputs[k], x) - x = module(x, - feature_mask=feature_masks[i], - src_key_padding_mask=None if mask is None else mask[...,::ds]) + x = module( + x, + feature_mask=feature_masks[i], + src_key_padding_mask=None if mask is None else mask[..., ::ds], + ) outputs.append(x) x = self.downsample_output(x) @@ -312,15 +332,16 @@ class ZipformerEncoderLayer(nn.Module): >>> pos_emb = torch.rand(32, 19, 512) >>> out = encoder_layer(src, pos_emb) """ + def __init__( - self, - d_model: int, - attention_dim: int, - nhead: int, - feedforward_dim: int = 2048, - dropout: float = 0.1, - cnn_module_kernel: int = 31, - pos_dim: int = 4, + self, + d_model: int, + attention_dim: int, + nhead: int, + feedforward_dim: int = 2048, + dropout: float = 0.1, + cnn_module_kernel: int = 31, + pos_dim: int = 4, ) -> None: super(ZipformerEncoderLayer, self).__init__() @@ -330,29 +351,24 @@ class ZipformerEncoderLayer(nn.Module): self.batch_count = 0 self.self_attn = RelPositionMultiheadAttention( - d_model, attention_dim, nhead, pos_dim, dropout=0.0, + d_model, + attention_dim, + nhead, + pos_dim, + dropout=0.0, ) self.pooling = PoolingModule(d_model) - self.feed_forward1 = FeedforwardModule(d_model, - feedforward_dim, - dropout) + self.feed_forward1 = FeedforwardModule(d_model, feedforward_dim, dropout) - self.feed_forward2 = FeedforwardModule(d_model, - feedforward_dim, - dropout) + self.feed_forward2 = FeedforwardModule(d_model, feedforward_dim, dropout) - self.feed_forward3 = FeedforwardModule(d_model, - feedforward_dim, - dropout) + self.feed_forward3 = FeedforwardModule(d_model, feedforward_dim, dropout) + self.conv_module1 = ConvolutionModule(d_model, cnn_module_kernel) - self.conv_module1 = ConvolutionModule(d_model, - cnn_module_kernel) - - self.conv_module2 = ConvolutionModule(d_model, - cnn_module_kernel) + self.conv_module2 = ConvolutionModule(d_model, cnn_module_kernel) self.norm_final = BasicNorm(d_model) @@ -360,14 +376,15 @@ class ZipformerEncoderLayer(nn.Module): # try to ensure the output is close to zero-mean (or at least, zero-median). self.balancer = ActivationBalancer( - d_model, channel_dim=-1, - min_positive=0.45, max_positive=0.55, + d_model, + channel_dim=-1, + min_positive=0.45, + max_positive=0.55, max_abs=6.0, ) - self.whiten = Whiten(num_groups=1, - whitening_limit=5.0, - prob=(0.025, 0.25), - grad_scale=0.01) + self.whiten = Whiten( + num_groups=1, whitening_limit=5.0, prob=(0.025, 0.25), grad_scale=0.01 + ) def get_bypass_scale(self): if torch.jit.is_scripting() or not self.training: @@ -382,8 +399,9 @@ class ZipformerEncoderLayer(nn.Module): if self.batch_count > warmup_period: clamp_min = final_clamp_min else: - clamp_min = (initial_clamp_min - - (self.batch_count / warmup_period) * (initial_clamp_min - final_clamp_min)) + clamp_min = initial_clamp_min - (self.batch_count / warmup_period) * ( + initial_clamp_min - final_clamp_min + ) return self.bypass_scale.clamp(min=clamp_min, max=1.0) def get_dynamic_dropout_rate(self): @@ -398,8 +416,9 @@ class ZipformerEncoderLayer(nn.Module): if self.batch_count > warmup_period: return final_dropout_rate else: - return (initial_dropout_rate - - (initial_dropout_rate * final_dropout_rate) * (self.batch_count / warmup_period)) + return initial_dropout_rate - ( + initial_dropout_rate * final_dropout_rate + ) * (self.batch_count / warmup_period) def forward( self, @@ -508,13 +527,14 @@ class ZipformerEncoder(nn.Module): >>> src = torch.rand(10, 32, 512) >>> out = zipformer_encoder(src) """ + def __init__( - self, - encoder_layer: nn.Module, - num_layers: int, - dropout: float, - warmup_begin: float, - warmup_end: float + self, + encoder_layer: nn.Module, + num_layers: int, + dropout: float, + warmup_begin: float, + warmup_end: float, ) -> None: super().__init__() # will be written to, see set_batch_count() Note: in inference time this @@ -528,8 +548,7 @@ class ZipformerEncoder(nn.Module): # so that we can keep this consistent across worker tasks (for efficiency). self.module_seed = torch.randint(0, 1000, ()).item() - self.encoder_pos = RelPositionalEncoding(encoder_layer.d_model, - dropout) + self.encoder_pos = RelPositionalEncoding(encoder_layer.d_model, dropout) self.layers = nn.ModuleList( [copy.deepcopy(encoder_layer) for i in range(num_layers)] @@ -538,15 +557,13 @@ class ZipformerEncoder(nn.Module): assert 0 <= warmup_begin <= warmup_end, (warmup_begin, warmup_end) - - delta = (1. / num_layers) * (warmup_end - warmup_begin) + delta = (1.0 / num_layers) * (warmup_end - warmup_begin) cur_begin = warmup_begin for i in range(num_layers): self.layers[i].warmup_begin = cur_begin cur_begin += delta self.layers[i].warmup_end = cur_begin - def get_layers_to_drop(self, rnd_seed: int): ans = set() if not self.training: @@ -579,12 +596,14 @@ class ZipformerEncoder(nn.Module): # linearly interpolate t = (batch_count - layer_warmup_begin) / layer_warmup_end assert 0.0 <= t < 1.001, t - return initial_layerdrop_prob + t * (final_layerdrop_prob - initial_layerdrop_prob) + return initial_layerdrop_prob + t * ( + final_layerdrop_prob - initial_layerdrop_prob + ) shared_rng = random.Random(batch_count + self.module_seed) independent_rng = random.Random(rnd_seed) - layerdrop_probs = [ get_layerdrop_prob(i) for i in range(num_layers) ] + layerdrop_probs = [get_layerdrop_prob(i) for i in range(num_layers)] tot = sum(layerdrop_probs) # Instead of drawing the samples independently, we first randomly decide # how many layers to drop out, using the same random number generator between @@ -604,11 +623,12 @@ class ZipformerEncoder(nn.Module): if len(ans) == num_to_drop: break if shared_rng.random() < 0.005 or __name__ == "__main__": - logging.info(f"warmup_begin={self.warmup_begin:.1f}, warmup_end={self.warmup_end:.1f}, " - f"batch_count={batch_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}") + logging.info( + f"warmup_begin={self.warmup_begin:.1f}, warmup_end={self.warmup_end:.1f}, " + f"batch_count={batch_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}" + ) return ans - def forward( self, src: Tensor, @@ -639,7 +659,6 @@ class ZipformerEncoder(nn.Module): pos_emb = self.encoder_pos(src) output = src - if torch.jit.is_scripting(): layers_to_drop = [] else: @@ -670,28 +689,27 @@ class DownsampledZipformerEncoder(nn.Module): after convolutional downsampling, and then upsampled again at the output, and combined with the origin input, so that the output has the same shape as the input. """ - def __init__(self, - encoder: nn.Module, - input_dim: int, - output_dim: int, - downsample: int): + + def __init__( + self, encoder: nn.Module, input_dim: int, output_dim: int, downsample: int + ): super(DownsampledZipformerEncoder, self).__init__() self.downsample_factor = downsample self.downsample = AttentionDownsample(input_dim, output_dim, downsample) self.encoder = encoder self.upsample = SimpleUpsample(output_dim, downsample) - self.out_combiner = SimpleCombiner(input_dim, - output_dim, - min_weight=(0.0, 0.25)) + self.out_combiner = SimpleCombiner( + input_dim, output_dim, min_weight=(0.0, 0.25) + ) - - def forward(self, - src: Tensor, - # Note: the type of feature_mask should be Unino[float, Tensor], - # but to make torch.jit.script() happ, we use float here - feature_mask: float = 1.0, - mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, + def forward( + self, + src: Tensor, + # Note: the type of feature_mask should be Unino[float, Tensor], + # but to make torch.jit.script() happ, we use float here + feature_mask: float = 1.0, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, ) -> Tensor: r"""Downsample, go through encoder, upsample. @@ -718,42 +736,43 @@ class DownsampledZipformerEncoder(nn.Module): src = self.downsample(src) ds = self.downsample_factor if mask is not None: - mask = mask[::ds,::ds] + mask = mask[::ds, ::ds] src = self.encoder( - src, feature_mask=feature_mask, mask=mask, src_key_padding_mask=mask, + src, + feature_mask=feature_mask, + mask=mask, + src_key_padding_mask=mask, ) src = self.upsample(src) # remove any extra frames that are not a multiple of downsample_factor - src = src[:src_orig.shape[0]] + src = src[: src_orig.shape[0]] return self.out_combiner(src_orig, src) + class AttentionDownsample(torch.nn.Module): """ Does downsampling with attention, by weighted sum, and a projection.. """ - def __init__(self, - in_channels: int, - out_channels: int, - downsample: int): + + def __init__(self, in_channels: int, out_channels: int, downsample: int): """ Require out_channels > in_channels. """ super(AttentionDownsample, self).__init__() - self.query = nn.Parameter(torch.randn(in_channels) * (in_channels ** -0.5)) + self.query = nn.Parameter(torch.randn(in_channels) * (in_channels**-0.5)) # fill in the extra dimensions with a projection of the input if out_channels > in_channels: - self.extra_proj = nn.Linear(in_channels * downsample, - out_channels - in_channels, - bias=False) + self.extra_proj = nn.Linear( + in_channels * downsample, out_channels - in_channels, bias=False + ) else: self.extra_proj = None self.downsample = downsample - def forward(self, - src: Tensor) -> Tensor: + def forward(self, src: Tensor) -> Tensor: """ x: (seq_len, batch_size, in_channels) Returns a tensor of shape @@ -767,16 +786,14 @@ class AttentionDownsample(torch.nn.Module): if seq_len != d_seq_len * ds: # right-pad src, repeating the last element. pad = d_seq_len * ds - seq_len - src_extra = src[src.shape[0]-1:].expand(pad, src.shape[1], src.shape[2]) + src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2]) src = torch.cat((src, src_extra), dim=0) assert src.shape[0] == d_seq_len * ds, (src.shape[0], d_seq_len, ds) src = src.reshape(d_seq_len, ds, batch_size, in_channels) scores = (src * self.query).sum(dim=-1, keepdim=True) - scores = penalize_abs_values_gt(scores, - limit=10.0, - penalty=1.0e-04) + scores = penalize_abs_values_gt(scores, limit=10.0, penalty=1.0e-04) weights = scores.softmax(dim=1) @@ -795,14 +812,12 @@ class SimpleUpsample(torch.nn.Module): A very simple form of upsampling that mostly just repeats the input, but also adds a position-specific bias. """ - def __init__(self, - num_channels: int, - upsample: int): + + def __init__(self, num_channels: int, upsample: int): super(SimpleUpsample, self).__init__() self.bias = nn.Parameter(torch.randn(upsample, num_channels) * 0.01) - def forward(self, - src: Tensor) -> Tensor: + def forward(self, src: Tensor) -> Tensor: """ x: (seq_len, batch_size, num_channels) Returns a tensor of shape @@ -815,6 +830,7 @@ class SimpleUpsample(torch.nn.Module): src = src.reshape(seq_len * upsample, batch_size, num_channels) return src + class SimpleCombinerIdentity(nn.Module): def __init__(self, *args, **kwargs): super().__init__() @@ -822,6 +838,7 @@ class SimpleCombinerIdentity(nn.Module): def forward(self, src1: Tensor, src2: Tensor) -> Tensor: return src1 + class SimpleCombiner(torch.nn.Module): """ A very simple way of combining 2 vectors of 2 different dims, via a @@ -831,18 +848,14 @@ class SimpleCombiner(torch.nn.Module): dim2: the dimension of the second input, e.g. 384. The output will have the same dimension as dim2. """ - def __init__(self, - dim1: int, - dim2: int, - min_weight: Tuple[float] = (0., 0.)): + + def __init__(self, dim1: int, dim2: int, min_weight: Tuple[float] = (0.0, 0.0)): super(SimpleCombiner, self).__init__() assert dim2 >= dim1, (dim2, dim1) self.weight1 = nn.Parameter(torch.zeros(())) self.min_weight = min_weight - def forward(self, - src1: Tensor, - src2: Tensor) -> Tensor: + def forward(self, src1: Tensor, src2: Tensor) -> Tensor: """ src1: (*, dim1) src2: (*, dim2) @@ -853,10 +866,14 @@ class SimpleCombiner(torch.nn.Module): weight1 = self.weight1 if not torch.jit.is_scripting(): - if self.training and random.random() < 0.25 and self.min_weight != (0., 0.): - weight1 = weight1.clamp(min=self.min_weight[0], - max=1.0-self.min_weight[1]) - + if ( + self.training + and random.random() < 0.25 + and self.min_weight != (0.0, 0.0) + ): + weight1 = weight1.clamp( + min=self.min_weight[0], max=1.0 - self.min_weight[1] + ) src1 = src1 * weight1 src2 = src2 * (1.0 - weight1) @@ -869,12 +886,9 @@ class SimpleCombiner(torch.nn.Module): else: src1 = src1[:src2_dim] - return src1 + src2 - - class RelPositionalEncoding(torch.nn.Module): """Relative positional encoding module. @@ -888,9 +902,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct a PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -905,9 +917,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(0) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vecotr and `j` means the @@ -955,7 +965,6 @@ class RelPositionalEncoding(torch.nn.Module): return self.dropout(pos_emb) - class RelPositionMultiheadAttention(nn.Module): r"""Multi-Head Attention layer with relative position encoding @@ -992,34 +1001,43 @@ class RelPositionMultiheadAttention(nn.Module): self.head_dim = attention_dim // num_heads self.pos_dim = pos_dim assert self.head_dim % 2 == 0, self.head_dim - assert ( - self.head_dim * num_heads == attention_dim - ), (self.head_dim, num_heads, attention_dim) + assert self.head_dim * num_heads == attention_dim, ( + self.head_dim, + num_heads, + attention_dim, + ) # the initial_scale is supposed to take over the "scaling" factor of # head_dim ** -0.5, dividing it between the query and key. - in_proj_dim = (2 * attention_dim + # query, key - attention_dim // 2 + # value - pos_dim * num_heads) # positional encoding query + in_proj_dim = ( + 2 * attention_dim + + attention_dim // 2 # query, key + + pos_dim * num_heads # value + ) # positional encoding query - self.in_proj = ScaledLinear(embed_dim, in_proj_dim, bias=True, - initial_scale=self.head_dim**-0.25) + self.in_proj = ScaledLinear( + embed_dim, in_proj_dim, bias=True, initial_scale=self.head_dim**-0.25 + ) # self.whiten_values is applied on the values in forward(); # it just copies the keys but prevents low-rank distribution by modifying grads. - self.whiten_values = Whiten(num_groups=num_heads, - whitening_limit=2.0, - prob=(0.025, 0.25), - grad_scale=0.025) - self.whiten_keys = Whiten(num_groups=num_heads, - whitening_limit=2.0, - prob=(0.025, 0.25), - grad_scale=0.025) - + self.whiten_values = Whiten( + num_groups=num_heads, + whitening_limit=2.0, + prob=(0.025, 0.25), + grad_scale=0.025, + ) + self.whiten_keys = Whiten( + num_groups=num_heads, + whitening_limit=2.0, + prob=(0.025, 0.25), + grad_scale=0.025, + ) # linear transformation for positional encoding. - self.linear_pos = ScaledLinear(embed_dim, num_heads * pos_dim, bias=False, - initial_scale=0.05) + self.linear_pos = ScaledLinear( + embed_dim, num_heads * pos_dim, bias=False, initial_scale=0.05 + ) # the following are for diagnosics only, see --print-diagnostics option. # they only copy their inputs. @@ -1031,14 +1049,16 @@ class RelPositionMultiheadAttention(nn.Module): ) self.in_proj2 = nn.Linear(embed_dim, attention_dim // 2, bias=False) - self.out_proj2 = ScaledLinear(attention_dim // 2, embed_dim, bias=True, - initial_scale=0.05) + self.out_proj2 = ScaledLinear( + attention_dim // 2, embed_dim, bias=True, initial_scale=0.05 + ) # self.whiten_values2 is applied on the values in forward2() - self.whiten_values2 = Whiten(num_groups=num_heads, - whitening_limit=2.0, - prob=(0.025, 0.25), - grad_scale=0.025) - + self.whiten_values2 = Whiten( + num_groups=num_heads, + whitening_limit=2.0, + prob=(0.025, 0.25), + grad_scale=0.025, + ) def forward( self, @@ -1098,7 +1118,6 @@ class RelPositionMultiheadAttention(nn.Module): ) return x, weights - def multi_head_attention_forward( self, x_proj: Tensor, @@ -1158,24 +1177,21 @@ class RelPositionMultiheadAttention(nn.Module): pos_dim = self.pos_dim # positional-encoding dim per head assert ( head_dim * num_heads == attention_dim - ), f"attention_dim must be divisible by num_heads: {head_dim}, {num_heads}, {attention_dim}" - + ), f"attention_dim must be divisible by num_heads: {head_dim}, {num_heads}, {attention_dim}" # self-attention - q = x_proj[...,0:attention_dim] - k = x_proj[...,attention_dim:2*attention_dim] + q = x_proj[..., 0:attention_dim] + k = x_proj[..., attention_dim : 2 * attention_dim] value_dim = attention_dim // 2 - v = x_proj[...,2*attention_dim:2*attention_dim+value_dim] + v = x_proj[..., 2 * attention_dim : 2 * attention_dim + value_dim] # p is the position-encoding query, its dimension is num_heads*pos_dim.. - p = x_proj[...,2*attention_dim+value_dim:] - + p = x_proj[..., 2 * attention_dim + value_dim :] k = self.whiten_keys(k) # does nothing in the forward pass. v = self.whiten_values(v) # does nothing in the forward pass. q = self.copy_query(q) # for diagnostics only, does nothing. p = self.copy_pos_query(p) # for diagnostics only, does nothing. - if attn_mask is not None: assert ( attn_mask.dtype == torch.float32 @@ -1195,31 +1211,22 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, seq_len, seq_len]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, seq_len, seq_len, ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) @@ -1230,7 +1237,6 @@ class RelPositionMultiheadAttention(nn.Module): k = k.reshape(seq_len, bsz, num_heads, head_dim) v = v.reshape(seq_len, bsz * num_heads, head_dim // 2).transpose(0, 1) - if key_padding_mask is not None: assert key_padding_mask.size(0) == bsz, "{} == {}".format( key_padding_mask.size(0), bsz @@ -1239,13 +1245,10 @@ class RelPositionMultiheadAttention(nn.Module): key_padding_mask.size(1), seq_len ) - - q = q.permute(1, 2, 0, 3) # (batch, head, time1, head_dim) p = p.permute(1, 2, 0, 3) # (batch, head, time1, pos_dim) k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - seq_len2 = 2 * seq_len - 1 pos = pos.reshape(1, seq_len2, num_heads, pos_dim).permute(0, 2, 3, 1) # pos shape now: (batch, head, pos_dim, seq_len2) @@ -1256,13 +1259,16 @@ class RelPositionMultiheadAttention(nn.Module): # the following .as_strided() expression converts the last axis of pos_weights from relative # to absolute position. I don't know whether I might have got the time-offsets backwards or # not, but let this code define which way round it is supposed to be. - pos_weights = pos_weights.as_strided((bsz, num_heads, seq_len, seq_len), - (pos_weights.stride(0), - pos_weights.stride(1), - pos_weights.stride(2)-pos_weights.stride(3), - pos_weights.stride(3)), - storage_offset=pos_weights.stride(3) * (seq_len - 1)) - + pos_weights = pos_weights.as_strided( + (bsz, num_heads, seq_len, seq_len), + ( + pos_weights.stride(0), + pos_weights.stride(1), + pos_weights.stride(2) - pos_weights.stride(3), + pos_weights.stride(3), + ), + storage_offset=pos_weights.stride(3) * (seq_len - 1), + ) # caution: they are really scores at this point. attn_output_weights = torch.matmul(q, k) + pos_weights @@ -1275,10 +1281,9 @@ class RelPositionMultiheadAttention(nn.Module): # this mechanism instead of, say, a limit on entropy, because once the entropy # gets very small gradients through the softmax can become very small, and # some mechanisms like that become ineffective. - attn_output_weights = penalize_abs_values_gt(attn_output_weights, - limit=25.0, - penalty=1.0e-04) - + attn_output_weights = penalize_abs_values_gt( + attn_output_weights, limit=25.0, penalty=1.0e-04 + ) # attn_output_weights: (batch, head, time1, time2) attn_output_weights = attn_output_weights.view( @@ -1320,20 +1325,16 @@ class RelPositionMultiheadAttention(nn.Module): ) attn_output = torch.bmm(attn_output_weights, v) - assert list(attn_output.size()) == [bsz * num_heads, seq_len, - head_dim // 2] + assert list(attn_output.size()) == [bsz * num_heads, seq_len, head_dim // 2] attn_output = ( attn_output.transpose(0, 1) .contiguous() .view(seq_len, bsz, attention_dim // 2) ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias - ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) return attn_output, attn_output_weights - def forward2( self, x: Tensor, @@ -1372,11 +1373,7 @@ class RelPositionMultiheadAttention(nn.Module): # returned value is of shape (seq_len, bsz, embed_dim), like x. return self.out_proj2(attn_output) - - def _print_attn_stats( - self, - attn_weights: Tensor, - attn_output: Tensor): + def _print_attn_stats(self, attn_weights: Tensor, attn_output: Tensor): # attn_weights: (batch_size * num_heads, seq_len, seq_len) # attn_output: (bsz * num_heads, seq_len, head_dim) (n, seq_len, head_dim) = attn_output.shape @@ -1387,39 +1384,48 @@ class RelPositionMultiheadAttention(nn.Module): with torch.cuda.amp.autocast(enabled=False): attn_weights = attn_weights.to(torch.float32) attn_output = attn_output.to(torch.float32) - attn_weights_entropy = -((attn_weights + 1.0e-20).log() * attn_weights).sum( - dim=-1).reshape(bsz, num_heads, seq_len).mean(dim=(0,2)) + attn_weights_entropy = ( + -((attn_weights + 1.0e-20).log() * attn_weights) + .sum(dim=-1) + .reshape(bsz, num_heads, seq_len) + .mean(dim=(0, 2)) + ) attn_output = attn_output.reshape(bsz, num_heads, seq_len, head_dim) - attn_output = attn_output.permute(1, 0, 2, 3).reshape(num_heads, bsz * seq_len, head_dim) + attn_output = attn_output.permute(1, 0, 2, 3).reshape( + num_heads, bsz * seq_len, head_dim + ) attn_output_mean = attn_output.mean(dim=1, keepdim=True) attn_output = attn_output - attn_output_mean - attn_covar = torch.matmul(attn_output.transpose(1, 2), attn_output) / (bsz * seq_len) + attn_covar = torch.matmul(attn_output.transpose(1, 2), attn_output) / ( + bsz * seq_len + ) # attn_covar: (num_heads, head_dim, head_dim) - #eigs, _ = torch.symeig(attn_covar) - #logging.info(f"attn_weights_entropy = {attn_weights_entropy}, output_eigs = {eigs}") + # eigs, _ = torch.symeig(attn_covar) + # logging.info(f"attn_weights_entropy = {attn_weights_entropy}, output_eigs = {eigs}") attn_covar = _diag(attn_covar).mean(dim=1) # (num_heads,) embed_dim = self.in_proj2.weight.shape[1] - in_proj_covar = (self.in_proj2.weight.reshape(num_heads, head_dim, embed_dim) ** 2).mean(dim=(1,2)) - out_proj_covar = (self.out_proj2.weight.reshape(embed_dim, num_heads, head_dim) ** 2).mean(dim=(0,2)) - logging.info(f"attn_weights_entropy = {attn_weights_entropy}, covar={attn_covar}, in_proj_covar={in_proj_covar}, out_proj_covar={out_proj_covar}") - - + in_proj_covar = ( + self.in_proj2.weight.reshape(num_heads, head_dim, embed_dim) ** 2 + ).mean(dim=(1, 2)) + out_proj_covar = ( + self.out_proj2.weight.reshape(embed_dim, num_heads, head_dim) ** 2 + ).mean(dim=(0, 2)) + logging.info( + f"attn_weights_entropy = {attn_weights_entropy}, covar={attn_covar}, in_proj_covar={in_proj_covar}, out_proj_covar={out_proj_covar}" + ) class PoolingModule(nn.Module): """ Averages the input over the time dimension and project with a square matrix. """ - def __init__(self, - d_model: int): - super().__init__() - self.proj = ScaledLinear(d_model, d_model, - initial_scale=0.1, bias=False) - def forward(self, - x: Tensor, - key_padding_mask: Optional[Tensor] = None): + def __init__(self, d_model: int): + super().__init__() + self.proj = ScaledLinear(d_model, d_model, initial_scale=0.1, bias=False) + + def forward(self, x: Tensor, key_padding_mask: Optional[Tensor] = None): """ Args: x: a Tensor of shape (T, N, C) @@ -1430,7 +1436,7 @@ class PoolingModule(nn.Module): """ if key_padding_mask is not None: pooling_mask = key_padding_mask.logical_not().to(x.dtype) # (N, T) - pooling_mask = (pooling_mask / pooling_mask.sum(dim=1, keepdim=True)) + pooling_mask = pooling_mask / pooling_mask.sum(dim=1, keepdim=True) pooling_mask = pooling_mask.transpose(0, 1).contiguous().unsqueeze(-1) # now pooling_mask: (T, N, 1) x = (x * pooling_mask).sum(dim=0, keepdim=True) @@ -1444,24 +1450,19 @@ class PoolingModule(nn.Module): class FeedforwardModule(nn.Module): - """Feedforward module in Zipformer model. - """ - def __init__(self, - d_model: int, - feedforward_dim: int, - dropout: float): + """Feedforward module in Zipformer model.""" + + def __init__(self, d_model: int, feedforward_dim: int, dropout: float): super(FeedforwardModule, self).__init__() self.in_proj = nn.Linear(d_model, feedforward_dim) - self.balancer = ActivationBalancer(feedforward_dim, - channel_dim=-1, max_abs=10.0, - min_prob=0.25) + self.balancer = ActivationBalancer( + feedforward_dim, channel_dim=-1, max_abs=10.0, min_prob=0.25 + ) self.activation = DoubleSwish() self.dropout = nn.Dropout(dropout) - self.out_proj = ScaledLinear(feedforward_dim, d_model, - initial_scale=0.01) + self.out_proj = ScaledLinear(feedforward_dim, d_model, initial_scale=0.01) - def forward(self, - x: Tensor): + def forward(self, x: Tensor): x = self.in_proj(x) x = self.balancer(x) x = self.activation(x) @@ -1481,9 +1482,7 @@ class ConvolutionModule(nn.Module): """ - def __init__( - self, channels: int, kernel_size: int, bias: bool = True - ) -> None: + def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding @@ -1513,7 +1512,10 @@ class ConvolutionModule(nn.Module): # the correct range. self.deriv_balancer1 = ActivationBalancer( 2 * channels, - channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0 + channel_dim=1, + max_abs=10.0, + min_positive=0.05, + max_positive=1.0, ) self.depthwise_conv = nn.Conv1d( @@ -1527,8 +1529,10 @@ class ConvolutionModule(nn.Module): ) self.deriv_balancer2 = ActivationBalancer( - channels, channel_dim=1, - min_positive=0.05, max_positive=1.0, + channels, + channel_dim=1, + min_positive=0.05, + max_positive=1.0, max_abs=20.0, ) @@ -1544,9 +1548,10 @@ class ConvolutionModule(nn.Module): initial_scale=0.05, ) - def forward(self, - x: Tensor, - src_key_padding_mask: Optional[Tensor] = None, + def forward( + self, + x: Tensor, + src_key_padding_mask: Optional[Tensor] = None, ) -> Tensor: """Compute convolution module. @@ -1626,8 +1631,7 @@ class Conv2dSubsampling(nn.Module): kernel_size=3, padding=(0, 1), # (time, freq) ), - ActivationBalancer(layer1_channels, - channel_dim=1), + ActivationBalancer(layer1_channels, channel_dim=1), DoubleSwish(), nn.Conv2d( in_channels=layer1_channels, @@ -1636,24 +1640,21 @@ class Conv2dSubsampling(nn.Module): stride=2, padding=0, ), - ActivationBalancer(layer2_channels, - channel_dim=1), + ActivationBalancer(layer2_channels, channel_dim=1), DoubleSwish(), nn.Conv2d( in_channels=layer2_channels, out_channels=layer3_channels, kernel_size=3, - stride=(1, 2), # (time, freq) + stride=(1, 2), # (time, freq) ), - ActivationBalancer(layer3_channels, - channel_dim=1), + ActivationBalancer(layer3_channels, channel_dim=1), DoubleSwish(), ) out_height = (((in_channels - 1) // 2) - 1) // 2 self.out = ScaledLinear(out_height * layer3_channels, out_channels) self.dropout = nn.Dropout(dropout) - def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. @@ -1674,6 +1675,7 @@ class Conv2dSubsampling(nn.Module): x = self.dropout(x) return x + class AttentionCombine(nn.Module): """ This module combines a list of Tensors, all with the same shape, to @@ -1717,15 +1719,12 @@ class AttentionCombine(nn.Module): self.random_prob = random_prob self.single_prob = single_prob - self.weight = torch.nn.Parameter(torch.zeros(num_channels, - num_inputs)) + self.weight = torch.nn.Parameter(torch.zeros(num_channels, num_inputs)) self.bias = torch.nn.Parameter(torch.zeros(num_inputs)) assert 0 <= random_prob <= 1, random_prob assert 0 <= single_prob <= 1, single_prob - - def forward(self, inputs: List[Tensor]) -> Tensor: """Forward function. Args: @@ -1756,28 +1755,35 @@ class AttentionCombine(nn.Module): if self.training: # random masking.. - mask_start = torch.randint(low=1, high=int(num_inputs / self.random_prob), - size=(num_frames,), device=scores.device).unsqueeze(1) + mask_start = torch.randint( + low=1, + high=int(num_inputs / self.random_prob), + size=(num_frames,), + device=scores.device, + ).unsqueeze(1) # mask will have rows like: [ False, False, False, True, True, .. ] - arange = torch.arange(num_inputs, device=scores.device).unsqueeze(0).expand( - num_frames, num_inputs) + arange = ( + torch.arange(num_inputs, device=scores.device) + .unsqueeze(0) + .expand(num_frames, num_inputs) + ) mask = arange >= mask_start - apply_single_prob = torch.logical_and(torch.rand(size=(num_frames, 1), - device=scores.device) < self.single_prob, - mask_start < num_inputs) - single_prob_mask = torch.logical_and(apply_single_prob, - arange < mask_start - 1) + apply_single_prob = torch.logical_and( + torch.rand(size=(num_frames, 1), device=scores.device) + < self.single_prob, + mask_start < num_inputs, + ) + single_prob_mask = torch.logical_and( + apply_single_prob, arange < mask_start - 1 + ) - mask = torch.logical_or(mask, - single_prob_mask) + mask = torch.logical_or(mask, single_prob_mask) - scores = scores.masked_fill(mask, float('-inf')) + scores = scores.masked_fill(mask, float("-inf")) if self.training and random.random() < 0.1: - scores = penalize_abs_values_gt(scores, - limit=10.0, - penalty=1.0e-04) + scores = penalize_abs_values_gt(scores, limit=10.0, penalty=1.0e-04) weights = scores.softmax(dim=1) @@ -1792,7 +1798,6 @@ class AttentionCombine(nn.Module): return ans - def _test_random_combine(): print("_test_random_combine()") num_inputs = 3 @@ -1801,8 +1806,8 @@ def _test_random_combine(): num_channels=num_channels, num_inputs=num_inputs, random_prob=0.5, - single_prob=0.0) - + single_prob=0.0, + ) x = [torch.ones(3, 4, num_channels) for _ in range(num_inputs)] @@ -1819,7 +1824,10 @@ def _test_zipformer_main(): # Just make sure the forward pass runs. c = Zipformer( - num_features=feature_dim, encoder_dims=(64,96), encoder_unmasked_dims=(48,64), nhead=(4,4) + num_features=feature_dim, + encoder_dims=(64, 96), + encoder_unmasked_dims=(48, 64), + nhead=(4, 4), ) batch_size = 5 seq_len = 20 @@ -1837,19 +1845,18 @@ def _test_zipformer_main(): ) f # to remove flake8 warnings + def _test_conv2d_subsampling(): num_features = 80 encoder_dims = 384 dropout = 0.1 - encoder_embed = Conv2dSubsampling(num_features, encoder_dims, - dropout=dropout) + encoder_embed = Conv2dSubsampling(num_features, encoder_dims, dropout=dropout) for i in range(20, 40): x = torch.rand(2, i, num_features) y = encoder_embed(x) assert (x.shape[1] - 7) // 2 == y.shape[1], (x.shape[1], y.shape[1]) - if __name__ == "__main__": logging.getLogger().setLevel(logging.INFO) torch.set_num_threads(1) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py index 9d7335e77..3d89ae00a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py @@ -273,8 +273,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -394,9 +393,7 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -455,10 +452,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -589,9 +583,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -624,8 +616,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -680,9 +671,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -719,9 +708,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -753,9 +742,9 @@ def main(): ) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -816,9 +805,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/export.py b/egs/librispeech/ASR/pruned_transducer_stateless8/export.py index 49f469e29..0a962149d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/export.py @@ -176,8 +176,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) add_model_arguments(parser) @@ -217,9 +216,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -252,9 +251,9 @@ def main(): ) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -326,9 +325,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py index e79a3a3aa..c458ee5a9 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py @@ -94,8 +94,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -267,9 +266,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/model.py b/egs/librispeech/ASR/pruned_transducer_stateless8/model.py index 497b89136..39a360796 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/model.py @@ -160,9 +160,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py index 373a48fc1..f1f0771ef 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py @@ -177,8 +177,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -210,8 +209,7 @@ def read_sound_files( for f in filenames: wave, sample_rate = torchaudio.load(f) assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" + f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}" ) # We use only the first channel ans.append(wave[0]) @@ -275,15 +273,11 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lengths - ) + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) num_waves = encoder_out.size(0) hyps = [] @@ -355,9 +349,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py index 2603bb854..ba8ed3ea8 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py @@ -92,9 +92,7 @@ from icefall.env import get_env_info from icefall.hooks import register_inf_check_hooks from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: @@ -214,8 +212,7 @@ def get_parser(): "--full-libri", type=str2bool, default=True, - help="When enabled, use 960h LibriSpeech. " - "Otherwise, use 100h subset.", + help="When enabled, use 960h LibriSpeech. " "Otherwise, use 100h subset.", ) parser.add_argument( @@ -285,8 +282,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -309,8 +305,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -691,11 +686,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -744,9 +735,7 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -952,9 +941,7 @@ def train_one_epoch( # of the grad scaler is configurable, but we can't configure it to have different # behavior depending on the current grad scale. cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 1.0 or ( - cur_grad_scale < 8.0 and batch_idx % 400 == 0 - ): + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): scaler.update(cur_grad_scale * 2.0) if cur_grad_scale < 0.01: logging.warning(f"Grad scale is small: {cur_grad_scale}") @@ -975,11 +962,7 @@ def train_one_epoch( f"giga_tot_loss[{giga_tot_loss}], " f"batch size: {batch_size}, " f"lr: {cur_lr:.2e}, " - + ( - f"grad_scale: {scaler._scale.item()}" - if params.use_fp16 - else "" - ) + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") ) if tb_writer is not None: @@ -992,12 +975,8 @@ def train_one_epoch( f"train/current_{prefix}_", params.batch_idx_train, ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) libri_tot_loss.write_summary( tb_writer, "train/libri_tot_", params.batch_idx_train ) @@ -1011,10 +990,7 @@ def train_one_epoch( params.batch_idx_train, ) - if ( - batch_idx % params.valid_interval == 0 - and not params.print_diagnostics - ): + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: logging.info("Computing validation loss") valid_info = compute_validation_loss( params=params, @@ -1054,8 +1030,7 @@ def filter_short_and_long_utterances( # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}" ) return False @@ -1152,9 +1127,7 @@ def run(rank, world_size, args): logging.info("Using DDP") model = DDP(model, device_ids=[rank], find_unused_parameters=True) - optimizer = ScaledAdam( - model.parameters(), lr=params.base_lr, clipping_scale=2.0 - ) + optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=2.0) scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) @@ -1172,7 +1145,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -1207,9 +1180,7 @@ def run(rank, world_size, args): train_giga_cuts = train_giga_cuts.repeat(times=None) if args.enable_musan: - cuts_musan = load_manifest( - Path(args.manifest_dir) / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz") else: cuts_musan = None diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/README.md b/egs/librispeech/ASR/streaming_conformer_ctc/README.md index 01be7090b..53f383c99 100644 --- a/egs/librispeech/ASR/streaming_conformer_ctc/README.md +++ b/egs/librispeech/ASR/streaming_conformer_ctc/README.md @@ -1,20 +1,20 @@ ## Train and Decode -Commands of data preparation/train/decode steps are almost the same with +Commands of data preparation/train/decode steps are almost the same with ../conformer_ctc experiment except some options. Please read the code and understand following new added options before running this experiment: For data preparation: - + Nothing new. For streaming_conformer_ctc/train.py: - + --dynamic-chunk-training --short-chunk-proportion For streaming_conformer_ctc/streaming_decode.py: - + --chunk-size --tailing-num-frames --simulate-streaming @@ -57,10 +57,10 @@ And check md5sum values again. Finally, following files will be downloaded:

-streaming_models/  
-|-- lang_bpe  
-|   |-- L.pt  
-|   |-- Linv.pt  
+streaming_models/
+|-- lang_bpe
+|   |-- L.pt
+|   |-- Linv.pt
 |   |-- bpe.model
 |   |-- tokens.txt
 |   `-- words.txt
diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py b/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py
index ff4c91446..5fe92172e 100644
--- a/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py
+++ b/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py
@@ -309,36 +309,26 @@ class Conformer(Transformer):
 
                 # start chunk_by_chunk decoding
                 offset = 0
-                for cur in range(
-                    0, num_frames - embed_left_context + 1, stride
-                ):
+                for cur in range(0, num_frames - embed_left_context + 1, stride):
                     end = min(cur + decoding_window, num_frames)
                     cur_feature = feature[:, cur:end, :]
                     cur_feature = self.encoder_embed(cur_feature)
-                    cur_embed, cur_pos_emb = self.encoder_pos(
-                        cur_feature, offset
-                    )
-                    cur_embed = cur_embed.permute(
-                        1, 0, 2
-                    )  # (B, T, F) -> (T, B, F)
+                    cur_embed, cur_pos_emb = self.encoder_pos(cur_feature, offset)
+                    cur_embed = cur_embed.permute(1, 0, 2)  # (B, T, F) -> (T, B, F)
 
                     cur_T = cur_feature.size(1)
                     if cur == 0:
                         # for first chunk extract the central pos embedding
-                        pos_emb_central = cur_pos_emb[
-                            0, (chunk_size - 1), :
-                        ].view(1, 1, -1)
+                        pos_emb_central = cur_pos_emb[0, (chunk_size - 1), :].view(
+                            1, 1, -1
+                        )
                         cur_T -= 1
                     pos_emb_positive.append(cur_pos_emb[0, :cur_T].flip(0))
                     pos_emb_negative.append(cur_pos_emb[0, -cur_T:])
                     assert pos_emb_positive[-1].size(0) == cur_T
 
-                    pos_emb_pos = torch.cat(pos_emb_positive, dim=0).unsqueeze(
-                        0
-                    )
-                    pos_emb_neg = torch.cat(pos_emb_negative, dim=0).unsqueeze(
-                        0
-                    )
+                    pos_emb_pos = torch.cat(pos_emb_positive, dim=0).unsqueeze(0)
+                    pos_emb_neg = torch.cat(pos_emb_negative, dim=0).unsqueeze(0)
                     cur_pos_emb = torch.cat(
                         [pos_emb_pos.flip(1), pos_emb_central, pos_emb_neg],
                         dim=1,
@@ -413,9 +403,7 @@ class ConformerEncoderLayer(nn.Module):
         causal: bool = False,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
-        self.self_attn = RelPositionMultiheadAttention(
-            d_model, nhead, dropout=0.0
-        )
+        self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
 
         self.feed_forward = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
@@ -431,22 +419,16 @@ class ConformerEncoderLayer(nn.Module):
             nn.Linear(dim_feedforward, d_model),
         )
 
-        self.conv_module = ConvolutionModule(
-            d_model, cnn_module_kernel, causal=causal
-        )
+        self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal)
 
-        self.norm_ff_macaron = nn.LayerNorm(
-            d_model
-        )  # for the macaron style FNN module
+        self.norm_ff_macaron = nn.LayerNorm(d_model)  # for the macaron style FNN module
         self.norm_ff = nn.LayerNorm(d_model)  # for the FNN module
         self.norm_mha = nn.LayerNorm(d_model)  # for the MHA module
 
         self.ff_scale = 0.5
 
         self.norm_conv = nn.LayerNorm(d_model)  # for the CNN module
-        self.norm_final = nn.LayerNorm(
-            d_model
-        )  # for the final output of the block
+        self.norm_final = nn.LayerNorm(d_model)  # for the final output of the block
 
         self.dropout = nn.Dropout(dropout)
 
@@ -480,9 +462,7 @@ class ConformerEncoderLayer(nn.Module):
         residual = src
         if self.normalize_before:
             src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(
-            self.feed_forward_macaron(src)
-        )
+        src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)
 
@@ -554,9 +534,7 @@ class ConformerEncoderLayer(nn.Module):
         residual = src
         if self.normalize_before:
             src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(
-            self.feed_forward_macaron(src)
-        )
+        src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)
 
@@ -736,9 +714,7 @@ class RelPositionalEncoding(torch.nn.Module):
 
     """
 
-    def __init__(
-        self, d_model: int, dropout_rate: float, max_len: int = 5000
-    ) -> None:
+    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
@@ -755,9 +731,7 @@ class RelPositionalEncoding(torch.nn.Module):
             # the length of self.pe is 2 * input_len - 1
             if self.pe.size(1) >= x_size_1 * 2 - 1:
                 # Note: TorchScript doesn't implement operator== for torch.Device
-                if self.pe.dtype != x.dtype or str(self.pe.device) != str(
-                    x.device
-                ):
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
                     self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                 return
         # Suppose `i` means to the position of query vector and `j` means the
@@ -783,9 +757,7 @@ class RelPositionalEncoding(torch.nn.Module):
         pe = torch.cat([pe_positive, pe_negative], dim=1)
         self.pe = pe.to(device=x.device, dtype=x.dtype)
 
-    def forward(
-        self, x: torch.Tensor, offset: int = 0
-    ) -> Tuple[Tensor, Tensor]:
+    def forward(self, x: torch.Tensor, offset: int = 0) -> Tuple[Tensor, Tensor]:
         """Add positional encoding.
 
         Args:
@@ -813,9 +785,7 @@ class RelPositionalEncoding(torch.nn.Module):
             pos_emb = torch.cat(
                 [
                     pos_emb[:, : (x_T - 1)],
-                    self.pe[0, self.pe.size(1) // 2].view(
-                        1, 1, self.pe.size(-1)
-                    ),
+                    self.pe[0, self.pe.size(1) // 2].view(1, 1, self.pe.size(-1)),
                     pos_emb[:, -(x_T - 1) :],  # noqa: E203
                 ],
                 dim=1,
@@ -1050,9 +1020,9 @@ class RelPositionMultiheadAttention(nn.Module):
 
         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(
-                query, in_proj_weight, in_proj_bias
-            ).chunk(3, dim=-1)
+            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
+                3, dim=-1
+            )
 
         elif torch.equal(key, value):
             # encoder-decoder attention
@@ -1120,31 +1090,22 @@ class RelPositionMultiheadAttention(nn.Module):
             if attn_mask.dim() == 2:
                 attn_mask = attn_mask.unsqueeze(0)
                 if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
-                    raise RuntimeError(
-                        "The size of the 2D attn_mask is not correct."
-                    )
+                    raise RuntimeError("The size of the 2D attn_mask is not correct.")
             elif attn_mask.dim() == 3:
                 if list(attn_mask.size()) != [
                     bsz * num_heads,
                     query.size(0),
                     key.size(0),
                 ]:
-                    raise RuntimeError(
-                        "The size of the 3D attn_mask is not correct."
-                    )
+                    raise RuntimeError("The size of the 3D attn_mask is not correct.")
             else:
                 raise RuntimeError(
-                    "attn_mask's dimension {} is not supported".format(
-                        attn_mask.dim()
-                    )
+                    "attn_mask's dimension {} is not supported".format(attn_mask.dim())
                 )
             # attn_mask's dim is 3 now.
 
         # convert ByteTensor key_padding_mask to bool
-        if (
-            key_padding_mask is not None
-            and key_padding_mask.dtype == torch.uint8
-        ):
+        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
             warnings.warn(
                 "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
             )
@@ -1185,24 +1146,16 @@ class RelPositionMultiheadAttention(nn.Module):
         # first compute matrix a and matrix c
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
-        matrix_ac = torch.matmul(
-            q_with_bias_u, k
-        )  # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch, head, time1, time2)
 
         # compute matrix b and matrix d
-        matrix_bd = torch.matmul(
-            q_with_bias_v, p
-        )  # (batch, head, time1, 2*time1-1)
-        matrix_bd = self.rel_shift(
-            matrix_bd, offset=offset
-        )  # [B, head, time1, time2]
+        matrix_bd = torch.matmul(q_with_bias_v, p)  # (batch, head, time1, 2*time1-1)
+        matrix_bd = self.rel_shift(matrix_bd, offset=offset)  # [B, head, time1, time2]
         attn_output_weights = (
             matrix_ac + matrix_bd
         ) * scaling  # (batch, head, time1, time2)
 
-        attn_output_weights = attn_output_weights.view(
-            bsz * num_heads, tgt_len, -1
-        )
+        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
 
         assert list(attn_output_weights.size()) == [
             bsz * num_heads,
@@ -1236,13 +1189,9 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output = torch.bmm(attn_output_weights, v)
         assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
         attn_output = (
-            attn_output.transpose(0, 1)
-            .contiguous()
-            .view(tgt_len, bsz, embed_dim)
-        )
-        attn_output = nn.functional.linear(
-            attn_output, out_proj_weight, out_proj_bias
+            attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
         )
+        attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
 
         if need_weights:
             # average attention weights over heads
diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py b/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py
index a74c51836..3965bd5c3 100755
--- a/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py
+++ b/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py
@@ -28,6 +28,7 @@ import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
 from conformer import Conformer
+
 from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import average_checkpoints, load_checkpoint
 from icefall.lexicon import Lexicon
@@ -86,8 +87,7 @@ def get_parser():
         "--tailing-num-frames",
         type=int,
         default=20,
-        help="tailing dummy frames padded to the right,"
-        "only used during decoding",
+        help="tailing dummy frames padded to the right," "only used during decoding",
     )
 
     parser.add_argument(
@@ -248,13 +248,9 @@ def decode_one_batch(
     maxlen = nnet_output.size(1)
     topk_prob, topk_index = nnet_output.topk(1, dim=2)  # (B, maxlen, 1)
     topk_index = topk_index.view(batch_size, maxlen)  # (B, maxlen)
-    topk_index = topk_index.masked_fill_(
-        memory_key_padding_mask, 0
-    )  # (B, maxlen)
+    topk_index = topk_index.masked_fill_(memory_key_padding_mask, 0)  # (B, maxlen)
     token_ids = [token_id.tolist() for token_id in topk_index]
-    token_ids = [
-        remove_duplicates_and_blank(token_id) for token_id in token_ids
-    ]
+    token_ids = [remove_duplicates_and_blank(token_id) for token_id in token_ids]
     hyps = bpe_model.decode(token_ids)
     hyps = [s.split() for s in hyps]
     return {key: hyps}
@@ -337,9 +333,7 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
 
     return results
 
@@ -364,8 +358,7 @@ def save_results(
         -{params.chunk_size}-tailing-num-frames-{params.tailing_num_frames}-"
     for key, results in results_dict.items():
         recog_path = (
-            params.exp_dir
-            / f"{result_file_prefix}recogs-{test_set_name}-{key}.txt"
+            params.exp_dir / f"{result_file_prefix}recogs-{test_set_name}-{key}.txt"
         )
         store_transcripts(filename=recog_path, texts=results)
         if enable_log:
@@ -374,8 +367,7 @@ def save_results(
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.exp_dir
-            / f"{result_file_prefix}-errs-{test_set_name}-{key}.txt"
+            params.exp_dir / f"{result_file_prefix}-errs-{test_set_name}-{key}.txt"
         )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
@@ -384,9 +376,7 @@ def save_results(
             test_set_wers[key] = wer
 
         if enable_log:
-            logging.info(
-                "Wrote detailed error stats to {}".format(errs_filename)
-            )
+            logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"
@@ -474,9 +464,7 @@ def main():
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save(
-            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
-        )
+        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
         return
 
     model.to(device)
@@ -507,9 +495,7 @@ def main():
             simulate_streaming=params.simulate_streaming,
         )
 
-        save_results(
-            params=params, test_set_name=test_set, results_dict=results_dict
-        )
+        save_results(params=params, test_set_name=test_set, results_dict=results_dict)
 
     logging.info("Done!")
 
diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/train.py b/egs/librispeech/ASR/streaming_conformer_ctc/train.py
index e41b7ea78..553b7d092 100755
--- a/egs/librispeech/ASR/streaming_conformer_ctc/train.py
+++ b/egs/librispeech/ASR/streaming_conformer_ctc/train.py
@@ -405,9 +405,7 @@ def compute_loss(
             #
             # See https://github.com/k2-fsa/icefall/issues/97
             # for more details
-            unsorted_token_ids = graph_compiler.texts_to_ids(
-                supervisions["text"]
-            )
+            unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"])
             att_loss = mmodel.decoder_forward(
                 encoder_memory,
                 memory_mask,
@@ -436,9 +434,7 @@ def compute_loss(
     info["utt_duration"] = supervisions["num_frames"].sum().item()
     # averaged padding proportion over utterances
     info["utt_pad_proportion"] = (
-        ((feature.size(1) - supervisions["num_frames"]) / feature.size(1))
-        .sum()
-        .item()
+        ((feature.size(1) - supervisions["num_frames"]) / feature.size(1)).sum().item()
     )
 
     return loss, info
@@ -551,9 +547,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -668,9 +662,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/transformer.py b/egs/librispeech/ASR/streaming_conformer_ctc/transformer.py
index bc78e4a41..0c87fdf1b 100644
--- a/egs/librispeech/ASR/streaming_conformer_ctc/transformer.py
+++ b/egs/librispeech/ASR/streaming_conformer_ctc/transformer.py
@@ -149,9 +149,7 @@ class Transformer(nn.Module):
                 norm=decoder_norm,
             )
 
-            self.decoder_output_layer = torch.nn.Linear(
-                d_model, self.decoder_num_class
-            )
+            self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class)
 
             self.decoder_criterion = LabelSmoothingLoss()
         else:
@@ -286,23 +284,17 @@ class Transformer(nn.Module):
         """
         ys_in = add_sos(token_ids, sos_id=sos_id)
         ys_in = [torch.tensor(y) for y in ys_in]
-        ys_in_pad = pad_sequence(
-            ys_in, batch_first=True, padding_value=float(eos_id)
-        )
+        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
 
         ys_out = add_eos(token_ids, eos_id=eos_id)
         ys_out = [torch.tensor(y) for y in ys_out]
-        ys_out_pad = pad_sequence(
-            ys_out, batch_first=True, padding_value=float(-1)
-        )
+        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
 
         device = memory.device
         ys_in_pad = ys_in_pad.to(device)
         ys_out_pad = ys_out_pad.to(device)
 
-        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
-            device
-        )
+        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
 
         tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
         # TODO: Use length information to create the decoder padding mask
@@ -363,23 +355,17 @@ class Transformer(nn.Module):
 
         ys_in = add_sos(token_ids, sos_id=sos_id)
         ys_in = [torch.tensor(y) for y in ys_in]
-        ys_in_pad = pad_sequence(
-            ys_in, batch_first=True, padding_value=float(eos_id)
-        )
+        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
 
         ys_out = add_eos(token_ids, eos_id=eos_id)
         ys_out = [torch.tensor(y) for y in ys_out]
-        ys_out_pad = pad_sequence(
-            ys_out, batch_first=True, padding_value=float(-1)
-        )
+        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
 
         device = memory.device
         ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
         ys_out_pad = ys_out_pad.to(device, dtype=torch.int64)
 
-        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
-            device
-        )
+        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
 
         tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
         # TODO: Use length information to create the decoder padding mask
@@ -652,9 +638,7 @@ def _get_activation_fn(activation: str):
     elif activation == "gelu":
         return nn.functional.gelu
 
-    raise RuntimeError(
-        "activation should be relu/gelu, not {}".format(activation)
-    )
+    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
 
 
 class PositionalEncoding(nn.Module):
@@ -856,9 +840,7 @@ def encoder_padding_mask(
         1,
     ).to(torch.int32)
 
-    lengths = [
-        0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)
-    ]
+    lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)]
     for idx in range(supervision_segments.size(0)):
         # Note: TorchScript doesn't allow to unpack tensors as tuples
         sequence_idx = supervision_segments[idx, 0].item()
@@ -879,9 +861,7 @@ def encoder_padding_mask(
     return mask
 
 
-def decoder_padding_mask(
-    ys_pad: torch.Tensor, ignore_id: int = -1
-) -> torch.Tensor:
+def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor:
     """Generate a length mask for input.
 
     The masked position are filled with True,
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
index 355ccc99a..993a7cab5 100644
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -86,8 +86,7 @@ class LibriSpeechAsrDataModule:
             "--full-libri",
             type=str2bool,
             default=True,
-            help="When enabled, use 960h LibriSpeech. "
-            "Otherwise, use 100h subset.",
+            help="When enabled, use 960h LibriSpeech. " "Otherwise, use 100h subset.",
         )
         group.add_argument(
             "--manifest-dir",
@@ -224,13 +223,9 @@ class LibriSpeechAsrDataModule:
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             logging.info("About to get Musan cuts")
-            cuts_musan = load_manifest(
-                self.args.manifest_dir / "musan_cuts.jsonl.gz"
-            )
+            cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
             transforms.append(
-                CutMix(
-                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
-                )
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")
@@ -252,9 +247,7 @@ class LibriSpeechAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(
-                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
-            )
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -298,9 +291,7 @@ class LibriSpeechAsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -356,9 +347,7 @@ class LibriSpeechAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 return_cuts=self.args.return_cuts,
             )
         else:
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
index 7d0cd0bf3..92529e06c 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
@@ -336,9 +336,7 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -400,9 +398,7 @@ def main():
 
     logging.info(f"device: {device}")
 
-    HLG = k2.Fsa.from_dict(
-        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
-    )
+    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
 
@@ -467,9 +463,7 @@ def main():
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save(
-            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
-        )
+        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
         return
 
     model.to(device)
@@ -498,9 +492,7 @@ def main():
             G=G,
         )
 
-        save_results(
-            params=params, test_set_name=test_set, results_dict=results_dict
-        )
+        save_results(params=params, test_set_name=test_set, results_dict=results_dict)
 
     logging.info("Done!")
 
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/model.py b/egs/librispeech/ASR/tdnn_lstm_ctc/model.py
index 5e04c11b4..1731e1ebe 100644
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/model.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/model.py
@@ -66,10 +66,7 @@ class TdnnLstm(nn.Module):
             nn.BatchNorm1d(num_features=500, affine=False),
         )
         self.lstms = nn.ModuleList(
-            [
-                nn.LSTM(input_size=500, hidden_size=500, num_layers=1)
-                for _ in range(5)
-            ]
+            [nn.LSTM(input_size=500, hidden_size=500, num_layers=1) for _ in range(5)]
         )
         self.lstm_bnorms = nn.ModuleList(
             [nn.BatchNorm1d(num_features=500, affine=False) for _ in range(5)]
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py
index 2baeb6bba..addadbe4e 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py
@@ -29,11 +29,7 @@ import torchaudio
 from model import TdnnLstm
 from torch.nn.utils.rnn import pad_sequence
 
-from icefall.decode import (
-    get_lattice,
-    one_best_decoding,
-    rescore_with_whole_lattice,
-)
+from icefall.decode import get_lattice, one_best_decoding, rescore_with_whole_lattice
 from icefall.utils import AttributeDict, get_texts
 
 
@@ -58,9 +54,7 @@ def get_parser():
         help="Path to words.txt",
     )
 
-    parser.add_argument(
-        "--HLG", type=str, required=True, help="Path to HLG.pt."
-    )
+    parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.")
 
     parser.add_argument(
         "--method",
@@ -145,8 +139,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])
@@ -215,9 +208,7 @@ def main():
     logging.info("Decoding started")
     features = fbank(waves)
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
     features = features.permute(0, 2, 1)  # now features is (N, C, T)
 
     with torch.no_grad():
@@ -269,9 +260,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
index 6b37d5c23..071ac792b 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
@@ -355,9 +355,7 @@ def compute_loss(
     info["utt_duration"] = supervisions["num_frames"].sum().item()
     # averaged padding proportion over utterances
     info["utt_pad_proportion"] = (
-        ((feature.size(2) - supervisions["num_frames"]) / feature.size(2))
-        .sum()
-        .item()
+        ((feature.size(2) - supervisions["num_frames"]) / feature.size(2)).sum().item()
     )
 
     return loss, info
@@ -470,9 +468,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             valid_info = compute_validation_loss(
diff --git a/egs/librispeech/ASR/transducer/beam_search.py b/egs/librispeech/ASR/transducer/beam_search.py
index 11032f31a..b45b6a9d8 100644
--- a/egs/librispeech/ASR/transducer/beam_search.py
+++ b/egs/librispeech/ASR/transducer/beam_search.py
@@ -38,9 +38,7 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
     blank_id = model.decoder.blank_id
     device = model.device
 
-    sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(
-        1, 1
-    )
+    sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(1, 1)
     decoder_out, (h, c) = model.decoder(sos)
     T = encoder_out.size(1)
     t = 0
@@ -123,9 +121,7 @@ def beam_search(
     max_u = 20000  # terminate after this number of steps
     u = 0
 
-    cache: Dict[
-        str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
-    ] = {}
+    cache: Dict[str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = {}
 
     while t < T and u < max_u:
         # fmt: off
@@ -157,9 +153,9 @@ def beam_search(
             cached_key = "_".join(map(str, y_star.ys))
 
             if cached_key not in cache:
-                decoder_input = torch.tensor(
-                    [y_star.ys[-1]], device=device
-                ).reshape(1, 1)
+                decoder_input = torch.tensor([y_star.ys[-1]], device=device).reshape(
+                    1, 1
+                )
 
                 decoder_out, decoder_state = model.decoder(
                     decoder_input,
diff --git a/egs/librispeech/ASR/transducer/decode.py b/egs/librispeech/ASR/transducer/decode.py
index 5f233df87..804713a20 100755
--- a/egs/librispeech/ASR/transducer/decode.py
+++ b/egs/librispeech/ASR/transducer/decode.py
@@ -228,9 +228,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []
     batch_size = encoder_out.size(0)
 
@@ -245,9 +243,7 @@ def decode_one_batch(
                 model=model, encoder_out=encoder_out_i, beam=params.beam_size
             )
         else:
-            raise ValueError(
-                f"Unsupported decoding method: {params.decoding_method}"
-            )
+            raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
         hyps.append(sp.decode(hyp).split())
 
     if params.decoding_method == "greedy_search":
@@ -318,9 +314,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -353,8 +347,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/librispeech/ASR/transducer/export.py b/egs/librispeech/ASR/transducer/export.py
index 5a5db30c4..6db0272f0 100755
--- a/egs/librispeech/ASR/transducer/export.py
+++ b/egs/librispeech/ASR/transducer/export.py
@@ -238,9 +238,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer/pretrained.py b/egs/librispeech/ASR/transducer/pretrained.py
index 1db2df648..b1ff7b2b1 100755
--- a/egs/librispeech/ASR/transducer/pretrained.py
+++ b/egs/librispeech/ASR/transducer/pretrained.py
@@ -189,8 +189,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])
@@ -249,9 +248,7 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -287,9 +284,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer/rnn.py b/egs/librispeech/ASR/transducer/rnn.py
index 2a165b0c1..fe8732301 100644
--- a/egs/librispeech/ASR/transducer/rnn.py
+++ b/egs/librispeech/ASR/transducer/rnn.py
@@ -117,12 +117,8 @@ class LayerNormLSTMCell(nn.Module):
         )
 
         if bias:
-            self.bias_ih = nn.Parameter(
-                torch.empty(4 * hidden_size, **factory_kwargs)
-            )
-            self.bias_hh = nn.Parameter(
-                torch.empty(4 * hidden_size, **factory_kwargs)
-            )
+            self.bias_ih = nn.Parameter(torch.empty(4 * hidden_size, **factory_kwargs))
+            self.bias_hh = nn.Parameter(torch.empty(4 * hidden_size, **factory_kwargs))
         else:
             self.register_parameter("bias_ih", None)
             self.register_parameter("bias_hh", None)
@@ -348,9 +344,7 @@ class LayerNormLSTM(nn.Module):
             device=device,
             dtype=dtype,
         )
-        first_layer = LayerNormLSTMLayer(
-            input_size=input_size, **factory_kwargs
-        )
+        first_layer = LayerNormLSTMLayer(input_size=input_size, **factory_kwargs)
         layers = [first_layer]
         for i in range(1, num_layers):
             layers.append(
@@ -385,9 +379,7 @@ class LayerNormLSTM(nn.Module):
             - List[(next_h, next_c)] containing the hidden states for all layers
 
         """
-        output_states = torch.jit.annotate(
-            List[Tuple[torch.Tensor, torch.Tensor]], []
-        )
+        output_states = torch.jit.annotate(List[Tuple[torch.Tensor, torch.Tensor]], [])
         output = input
         for i, rnn_layer in enumerate(self.layers):
             state = states[i]
@@ -456,12 +448,8 @@ class LayerNormGRUCell(nn.Module):
         )
 
         if bias:
-            self.bias_ih = nn.Parameter(
-                torch.empty(3 * hidden_size, **factory_kwargs)
-            )
-            self.bias_hh = nn.Parameter(
-                torch.empty(3 * hidden_size, **factory_kwargs)
-            )
+            self.bias_ih = nn.Parameter(torch.empty(3 * hidden_size, **factory_kwargs))
+            self.bias_hh = nn.Parameter(torch.empty(3 * hidden_size, **factory_kwargs))
         else:
             self.register_parameter("bias_ih", None)
             self.register_parameter("bias_hh", None)
diff --git a/egs/librispeech/ASR/transducer/test_rnn.py b/egs/librispeech/ASR/transducer/test_rnn.py
index 8591e2d8a..74c94cc70 100755
--- a/egs/librispeech/ASR/transducer/test_rnn.py
+++ b/egs/librispeech/ASR/transducer/test_rnn.py
@@ -254,9 +254,7 @@ def test_layernorm_lstm_layer_with_projection_forward(device="cpu"):
         for name, self_param in self_layer.cell.named_parameters():
             getattr(torch_layer, f"{name}_l0").copy_(self_param)
 
-    torch_y, (torch_h, torch_c) = torch_layer(
-        x_clone, (h.unsqueeze(0), c.unsqueeze(0))
-    )
+    torch_y, (torch_h, torch_c) = torch_layer(x_clone, (h.unsqueeze(0), c.unsqueeze(0)))
     assert_allclose(self_y, torch_y)
     assert_allclose(self_h, torch_h)
     assert_allclose(self_c, torch_c)
@@ -303,9 +301,7 @@ def test_layernorm_lstm_layer_forward(device="cpu"):
         for name, self_param in self_layer.cell.named_parameters():
             getattr(torch_layer, f"{name}_l0").copy_(self_param)
 
-    torch_y, (torch_h, torch_c) = torch_layer(
-        x_clone, (h.unsqueeze(0), c.unsqueeze(0))
-    )
+    torch_y, (torch_h, torch_c) = torch_layer(x_clone, (h.unsqueeze(0), c.unsqueeze(0)))
     assert_allclose(self_y, torch_y)
     assert_allclose(self_h, torch_h)
     assert_allclose(self_c, torch_c)
@@ -594,9 +590,7 @@ def test_layernorm_gru_cell_forward(device="cpu"):
 
     assert_allclose(self_h, torch_h, atol=1e-5)
 
-    (
-        self_h.reshape(-1) * torch.arange(self_h.numel(), device=device)
-    ).sum().backward()
+    (self_h.reshape(-1) * torch.arange(self_h.numel(), device=device)).sum().backward()
     (
         torch_h.reshape(-1) * torch.arange(torch_h.numel(), device=device)
     ).sum().backward()
@@ -718,9 +712,7 @@ def test_layernorm_gru_forward(device="cpu"):
     T = torch.randint(low=2, high=100, size=(1,))
 
     x = torch.rand(N, T, input_size, device=device).requires_grad_()
-    states = [
-        torch.rand(N, hidden_size, device=device) for _ in range(num_layers)
-    ]
+    states = [torch.rand(N, hidden_size, device=device) for _ in range(num_layers)]
 
     x_clone = x.detach().clone().requires_grad_()
 
diff --git a/egs/librispeech/ASR/transducer/train.py b/egs/librispeech/ASR/transducer/train.py
index 1dd65eddb..674ea10a6 100755
--- a/egs/librispeech/ASR/transducer/train.py
+++ b/egs/librispeech/ASR/transducer/train.py
@@ -396,9 +396,7 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -520,9 +518,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -659,9 +655,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/librispeech/ASR/transducer_lstm/beam_search.py b/egs/librispeech/ASR/transducer_lstm/beam_search.py
index 3531a9633..5342c3e8c 100644
--- a/egs/librispeech/ASR/transducer_lstm/beam_search.py
+++ b/egs/librispeech/ASR/transducer_lstm/beam_search.py
@@ -38,9 +38,7 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
     blank_id = model.decoder.blank_id
     device = model.device
 
-    sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(
-        1, 1
-    )
+    sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(1, 1)
     decoder_out, (h, c) = model.decoder(sos)
     T = encoder_out.size(1)
     t = 0
@@ -124,9 +122,7 @@ def beam_search(
     max_u = 20000  # terminate after this number of steps
     u = 0
 
-    cache: Dict[
-        str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
-    ] = {}
+    cache: Dict[str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = {}
 
     while t < T and u < max_u:
         # fmt: off
@@ -158,9 +154,9 @@ def beam_search(
             cached_key = "_".join(map(str, y_star.ys))
 
             if cached_key not in cache:
-                decoder_input = torch.tensor(
-                    [y_star.ys[-1]], device=device
-                ).reshape(1, 1)
+                decoder_input = torch.tensor([y_star.ys[-1]], device=device).reshape(
+                    1, 1
+                )
 
                 decoder_out, decoder_state = model.decoder(
                     decoder_input,
diff --git a/egs/librispeech/ASR/transducer_lstm/decode.py b/egs/librispeech/ASR/transducer_lstm/decode.py
index 604235e2a..9511ca6d7 100755
--- a/egs/librispeech/ASR/transducer_lstm/decode.py
+++ b/egs/librispeech/ASR/transducer_lstm/decode.py
@@ -225,9 +225,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []
     batch_size = encoder_out.size(0)
 
@@ -242,9 +240,7 @@ def decode_one_batch(
                 model=model, encoder_out=encoder_out_i, beam=params.beam_size
             )
         else:
-            raise ValueError(
-                f"Unsupported decoding method: {params.decoding_method}"
-            )
+            raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
         hyps.append(sp.decode(hyp).split())
 
     if params.decoding_method == "greedy_search":
@@ -315,9 +311,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -350,8 +344,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/librispeech/ASR/transducer_lstm/encoder.py b/egs/librispeech/ASR/transducer_lstm/encoder.py
index 3dc992dd2..038d80077 100644
--- a/egs/librispeech/ASR/transducer_lstm/encoder.py
+++ b/egs/librispeech/ASR/transducer_lstm/encoder.py
@@ -48,9 +48,7 @@ class LstmEncoder(EncoderInterface):
         if vgg_frontend:
             self.encoder_embed = VggSubsampling(num_features, real_hidden_size)
         else:
-            self.encoder_embed = Conv2dSubsampling(
-                num_features, real_hidden_size
-            )
+            self.encoder_embed = Conv2dSubsampling(num_features, real_hidden_size)
 
         self.rnn = nn.LSTM(
             input_size=hidden_size,
diff --git a/egs/librispeech/ASR/transducer_lstm/train.py b/egs/librispeech/ASR/transducer_lstm/train.py
index cdb801e79..57bda63fd 100755
--- a/egs/librispeech/ASR/transducer_lstm/train.py
+++ b/egs/librispeech/ASR/transducer_lstm/train.py
@@ -400,9 +400,7 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -524,9 +522,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -665,9 +661,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/librispeech/ASR/transducer_stateless/alignment.py b/egs/librispeech/ASR/transducer_stateless/alignment.py
index f143611ea..65f2c58d8 100644
--- a/egs/librispeech/ASR/transducer_stateless/alignment.py
+++ b/egs/librispeech/ASR/transducer_stateless/alignment.py
@@ -193,9 +193,7 @@ def force_alignment(
         decoder_out = model.decoder(decoder_input, need_pad=False)
         # decoder_output is of shape (num_active_items, 1, decoder_output_dim)
 
-        current_encoder_out = current_encoder_out.expand(
-            decoder_out.size(0), 1, -1
-        )
+        current_encoder_out = current_encoder_out.expand(decoder_out.size(0), 1, -1)
 
         logits = model.joiner(
             current_encoder_out,
diff --git a/egs/librispeech/ASR/transducer_stateless/beam_search.py b/egs/librispeech/ASR/transducer_stateless/beam_search.py
index ea985f30d..1d79eef9d 100644
--- a/egs/librispeech/ASR/transducer_stateless/beam_search.py
+++ b/egs/librispeech/ASR/transducer_stateless/beam_search.py
@@ -316,9 +316,9 @@ def greedy_search(
         y = logits.argmax().item()
         if y != blank_id:
             hyp.append(y)
-            decoder_input = torch.tensor(
-                [hyp[-context_size:]], device=device
-            ).reshape(1, context_size)
+            decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape(
+                1, context_size
+            )
 
             decoder_out = model.decoder(decoder_input, need_pad=False)
 
@@ -478,9 +478,7 @@ class HypothesisList(object):
         key = hyp.key
         if key in self:
             old_hyp = self._data[key]  # shallow copy
-            torch.logaddexp(
-                old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
-            )
+            torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob)
         else:
             self._data[key] = hyp
 
@@ -496,9 +494,7 @@ class HypothesisList(object):
           Return the hypothesis that has the largest `log_prob`.
         """
         if length_norm:
-            return max(
-                self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)
-            )
+            return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys))
         else:
             return max(self._data.values(), key=lambda hyp: hyp.log_prob)
 
@@ -786,9 +782,7 @@ def modified_beam_search(
         log_probs_shape = k2.ragged.create_ragged_shape2(
             row_splits=row_splits, cached_tot_size=log_probs.numel()
         )
-        ragged_log_probs = k2.RaggedTensor(
-            shape=log_probs_shape, value=log_probs
-        )
+        ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)
 
         for i in range(batch_size):
             topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
@@ -887,9 +881,7 @@ def _deprecated_modified_beam_search(
         decoder_out = model.decoder(decoder_input, need_pad=False)
         # decoder_output is of shape (num_hyps, 1, decoder_output_dim)
 
-        current_encoder_out = current_encoder_out.expand(
-            decoder_out.size(0), 1, -1
-        )
+        current_encoder_out = current_encoder_out.expand(decoder_out.size(0), 1, -1)
 
         logits = model.joiner(
             current_encoder_out,
@@ -959,9 +951,9 @@ def beam_search(
 
     device = model.device
 
-    decoder_input = torch.tensor(
-        [blank_id] * context_size, device=device
-    ).reshape(1, context_size)
+    decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape(
+        1, context_size
+    )
 
     decoder_out = model.decoder(decoder_input, need_pad=False)
 
diff --git a/egs/librispeech/ASR/transducer_stateless/compute_ali.py b/egs/librispeech/ASR/transducer_stateless/compute_ali.py
index 48769e9d1..c91198bb9 100755
--- a/egs/librispeech/ASR/transducer_stateless/compute_ali.py
+++ b/egs/librispeech/ASR/transducer_stateless/compute_ali.py
@@ -124,8 +124,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     return parser
@@ -162,9 +161,7 @@ def compute_alignments(
 
         feature_lens = supervisions["num_frames"].to(device)
 
-        encoder_out, encoder_out_lens = model.encoder(
-            x=feature, x_lens=feature_lens
-        )
+        encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
 
         batch_size = encoder_out.size(0)
 
@@ -204,9 +201,7 @@ def compute_alignments(
         if batch_idx % 2 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
 
     return CutSet.from_cuts(cuts)
 
diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py
index cde52c9fc..01e8c5b21 100644
--- a/egs/librispeech/ASR/transducer_stateless/conformer.py
+++ b/egs/librispeech/ASR/transducer_stateless/conformer.py
@@ -209,10 +209,7 @@ class Conformer(Transformer):
 
           NOTE: the returned tensors are on the given device.
         """
-        if (
-            len(self._init_state) == 2
-            and self._init_state[0].size(1) == left_context
-        ):
+        if len(self._init_state) == 2 and self._init_state[0].size(1) == left_context:
             # Note: It is OK to share the init state as it is
             # not going to be modified by the model
             return self._init_state
@@ -421,9 +418,7 @@ class ConformerEncoderLayer(nn.Module):
         causal: bool = False,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
-        self.self_attn = RelPositionMultiheadAttention(
-            d_model, nhead, dropout=0.0
-        )
+        self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
 
         self.feed_forward = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
@@ -439,22 +434,16 @@ class ConformerEncoderLayer(nn.Module):
             nn.Linear(dim_feedforward, d_model),
         )
 
-        self.conv_module = ConvolutionModule(
-            d_model, cnn_module_kernel, causal=causal
-        )
+        self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal)
 
-        self.norm_ff_macaron = nn.LayerNorm(
-            d_model
-        )  # for the macaron style FNN module
+        self.norm_ff_macaron = nn.LayerNorm(d_model)  # for the macaron style FNN module
         self.norm_ff = nn.LayerNorm(d_model)  # for the FNN module
         self.norm_mha = nn.LayerNorm(d_model)  # for the MHA module
 
         self.ff_scale = 0.5
 
         self.norm_conv = nn.LayerNorm(d_model)  # for the CNN module
-        self.norm_final = nn.LayerNorm(
-            d_model
-        )  # for the final output of the block
+        self.norm_final = nn.LayerNorm(d_model)  # for the final output of the block
 
         self.dropout = nn.Dropout(dropout)
 
@@ -486,9 +475,7 @@ class ConformerEncoderLayer(nn.Module):
         residual = src
         if self.normalize_before:
             src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(
-            self.feed_forward_macaron(src)
-        )
+        src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)
 
@@ -514,9 +501,7 @@ class ConformerEncoderLayer(nn.Module):
         if self.normalize_before:
             src = self.norm_conv(src)
 
-        src, _ = self.conv_module(
-            src, src_key_padding_mask=src_key_padding_mask
-        )
+        src, _ = self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
         src = residual + self.dropout(src)
 
         if not self.normalize_before:
@@ -581,9 +566,7 @@ class ConformerEncoderLayer(nn.Module):
         residual = src
         if self.normalize_before:
             src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(
-            self.feed_forward_macaron(src)
-        )
+        src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)
 
@@ -625,9 +608,7 @@ class ConformerEncoderLayer(nn.Module):
         if self.normalize_before:
             src = self.norm_conv(src)
 
-        src, conv_cache = self.conv_module(
-            src, states[1], right_context=right_context
-        )
+        src, conv_cache = self.conv_module(src, states[1], right_context=right_context)
         states[1] = conv_cache
         src = residual + self.dropout(src)
 
@@ -779,9 +760,7 @@ class RelPositionalEncoding(torch.nn.Module):
 
     """
 
-    def __init__(
-        self, d_model: int, dropout_rate: float, max_len: int = 5000
-    ) -> None:
+    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
@@ -798,9 +777,7 @@ class RelPositionalEncoding(torch.nn.Module):
             # the length of self.pe is 2 * input_len - 1
             if self.pe.size(1) >= x_size_1 * 2 - 1:
                 # Note: TorchScript doesn't implement operator== for torch.Device
-                if self.pe.dtype != x.dtype or str(self.pe.device) != str(
-                    x.device
-                ):
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
                     self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                 return
         # Suppose `i` means to the position of query vector and `j` means the
@@ -826,9 +803,7 @@ class RelPositionalEncoding(torch.nn.Module):
         pe = torch.cat([pe_positive, pe_negative], dim=1)
         self.pe = pe.to(device=x.device, dtype=x.dtype)
 
-    def forward(
-        self, x: torch.Tensor, left_context: int = 0
-    ) -> Tuple[Tensor, Tensor]:
+    def forward(self, x: torch.Tensor, left_context: int = 0) -> Tuple[Tensor, Tensor]:
         """Add positional encoding.
 
         Args:
@@ -1092,9 +1067,9 @@ class RelPositionMultiheadAttention(nn.Module):
 
         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(
-                query, in_proj_weight, in_proj_bias
-            ).chunk(3, dim=-1)
+            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
+                3, dim=-1
+            )
 
         elif torch.equal(key, value):
             # encoder-decoder attention
@@ -1163,31 +1138,22 @@ class RelPositionMultiheadAttention(nn.Module):
             if attn_mask.dim() == 2:
                 attn_mask = attn_mask.unsqueeze(0)
                 if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
-                    raise RuntimeError(
-                        "The size of the 2D attn_mask is not correct."
-                    )
+                    raise RuntimeError("The size of the 2D attn_mask is not correct.")
             elif attn_mask.dim() == 3:
                 if list(attn_mask.size()) != [
                     bsz * num_heads,
                     query.size(0),
                     key.size(0),
                 ]:
-                    raise RuntimeError(
-                        "The size of the 3D attn_mask is not correct."
-                    )
+                    raise RuntimeError("The size of the 3D attn_mask is not correct.")
             else:
                 raise RuntimeError(
-                    "attn_mask's dimension {} is not supported".format(
-                        attn_mask.dim()
-                    )
+                    "attn_mask's dimension {} is not supported".format(attn_mask.dim())
                 )
             # attn_mask's dim is 3 now.
 
         # convert ByteTensor key_padding_mask to bool
-        if (
-            key_padding_mask is not None
-            and key_padding_mask.dtype == torch.uint8
-        ):
+        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
             warnings.warn(
                 "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
             )
@@ -1228,14 +1194,10 @@ class RelPositionMultiheadAttention(nn.Module):
         # first compute matrix a and matrix c
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
-        matrix_ac = torch.matmul(
-            q_with_bias_u, k
-        )  # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch, head, time1, time2)
 
         # compute matrix b and matrix d
-        matrix_bd = torch.matmul(
-            q_with_bias_v, p
-        )  # (batch, head, time1, 2*time1-1)
+        matrix_bd = torch.matmul(q_with_bias_v, p)  # (batch, head, time1, 2*time1-1)
 
         matrix_bd = self.rel_shift(matrix_bd, left_context=left_context)
 
@@ -1243,9 +1205,7 @@ class RelPositionMultiheadAttention(nn.Module):
             matrix_ac + matrix_bd
         ) * scaling  # (batch, head, time1, time2)
 
-        attn_output_weights = attn_output_weights.view(
-            bsz * num_heads, tgt_len, -1
-        )
+        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
 
         assert list(attn_output_weights.size()) == [
             bsz * num_heads,
@@ -1290,9 +1250,7 @@ class RelPositionMultiheadAttention(nn.Module):
             attn_output_weights = attn_output_weights.view(
                 bsz, num_heads, tgt_len, src_len
             )
-            attn_output_weights = attn_output_weights.masked_fill(
-                combined_mask, 0.0
-            )
+            attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0)
             attn_output_weights = attn_output_weights.view(
                 bsz * num_heads, tgt_len, src_len
             )
@@ -1304,13 +1262,9 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output = torch.bmm(attn_output_weights, v)
         assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
         attn_output = (
-            attn_output.transpose(0, 1)
-            .contiguous()
-            .view(tgt_len, bsz, embed_dim)
-        )
-        attn_output = nn.functional.linear(
-            attn_output, out_proj_weight, out_proj_bias
+            attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
         )
+        attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
 
         if need_weights:
             # average attention weights over heads
@@ -1418,16 +1372,12 @@ class ConvolutionModule(nn.Module):
                 # manualy padding self.lorder zeros to the left
                 x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
             else:
-                assert (
-                    not self.training
-                ), "Cache should be None in training time"
+                assert not self.training, "Cache should be None in training time"
                 assert cache.size(0) == self.lorder
                 x = torch.cat([cache.permute(1, 2, 0), x], dim=2)
                 if right_context > 0:
                     cache = x.permute(2, 0, 1)[
-                        -(self.lorder + right_context) : (  # noqa
-                            -right_context
-                        ),
+                        -(self.lorder + right_context) : (-right_context),  # noqa
                         ...,
                     ]
                 else:
diff --git a/egs/librispeech/ASR/transducer_stateless/decode.py b/egs/librispeech/ASR/transducer_stateless/decode.py
index 74bba9cad..688e214c8 100755
--- a/egs/librispeech/ASR/transducer_stateless/decode.py
+++ b/egs/librispeech/ASR/transducer_stateless/decode.py
@@ -171,8 +171,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -230,9 +229,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
 
     hyps = []
 
@@ -248,10 +245,7 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -374,9 +368,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -409,8 +401,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -450,9 +441,7 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += (
-            f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
diff --git a/egs/librispeech/ASR/transducer_stateless/decoder.py b/egs/librispeech/ASR/transducer_stateless/decoder.py
index fbc2373a9..a182d91e2 100644
--- a/egs/librispeech/ASR/transducer_stateless/decoder.py
+++ b/egs/librispeech/ASR/transducer_stateless/decoder.py
@@ -87,9 +87,7 @@ class Decoder(nn.Module):
         if self.context_size > 1:
             embedding_out = embedding_out.permute(0, 2, 1)
             if need_pad is True:
-                embedding_out = F.pad(
-                    embedding_out, pad=(self.context_size - 1, 0)
-                )
+                embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
             else:
                 # During inference time, there is no need to do extra padding
                 # as we only need one output
diff --git a/egs/librispeech/ASR/transducer_stateless/export.py b/egs/librispeech/ASR/transducer_stateless/export.py
index 8bd0bdea1..c617e6c4c 100755
--- a/egs/librispeech/ASR/transducer_stateless/export.py
+++ b/egs/librispeech/ASR/transducer_stateless/export.py
@@ -109,8 +109,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     return parser
@@ -244,9 +243,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless/joiner.py b/egs/librispeech/ASR/transducer_stateless/joiner.py
index 93cccbd8c..e1625992d 100644
--- a/egs/librispeech/ASR/transducer_stateless/joiner.py
+++ b/egs/librispeech/ASR/transducer_stateless/joiner.py
@@ -60,13 +60,9 @@ class Joiner(nn.Module):
         encoder_out_len: List[int] = encoder_out_len.tolist()
         decoder_out_len: List[int] = decoder_out_len.tolist()
 
-        encoder_out_list = [
-            encoder_out[i, : encoder_out_len[i], :] for i in range(N)
-        ]
+        encoder_out_list = [encoder_out[i, : encoder_out_len[i], :] for i in range(N)]
 
-        decoder_out_list = [
-            decoder_out[i, : decoder_out_len[i], :] for i in range(N)
-        ]
+        decoder_out_list = [decoder_out[i, : decoder_out_len[i], :] for i in range(N)]
 
         x = [
             e.unsqueeze(1) + d.unsqueeze(0)
diff --git a/egs/librispeech/ASR/transducer_stateless/pretrained.py b/egs/librispeech/ASR/transducer_stateless/pretrained.py
index b64521801..c393974e6 100755
--- a/egs/librispeech/ASR/transducer_stateless/pretrained.py
+++ b/egs/librispeech/ASR/transducer_stateless/pretrained.py
@@ -167,8 +167,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -198,8 +197,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])
@@ -259,9 +257,7 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -334,9 +330,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless/test_compute_ali.py b/egs/librispeech/ASR/transducer_stateless/test_compute_ali.py
index b00fc34f1..9af46846a 100755
--- a/egs/librispeech/ASR/transducer_stateless/test_compute_ali.py
+++ b/egs/librispeech/ASR/transducer_stateless/test_compute_ali.py
@@ -140,16 +140,13 @@ def main():
                 token_alignment[i, : token_alignment_length[i]].tolist(), sp=sp
             )
             word_starting_time = [
-                "{:.2f}".format(i * frame_shift_in_second)
-                for i in word_starting_frames
+                "{:.2f}".format(i * frame_shift_in_second) for i in word_starting_frames
             ]
 
             words = supervisions["text"][i].split()
 
             assert len(word_starting_frames) == len(words)
-            word_starting_time_dict[cuts[i].id] = list(
-                zip(words, word_starting_time)
-            )
+            word_starting_time_dict[cuts[i].id] = list(zip(words, word_starting_time))
 
         # This is a demo script and we exit here after processing
         # one batch.
@@ -160,9 +157,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless/test_conformer.py b/egs/librispeech/ASR/transducer_stateless/test_conformer.py
index d1350c8ab..65b08d425 100755
--- a/egs/librispeech/ASR/transducer_stateless/test_conformer.py
+++ b/egs/librispeech/ASR/transducer_stateless/test_conformer.py
@@ -29,9 +29,7 @@ from conformer import Conformer
 
 def test_conformer():
     feature_dim = 50
-    c = Conformer(
-        num_features=feature_dim, output_dim=256, d_model=128, nhead=4
-    )
+    c = Conformer(num_features=feature_dim, output_dim=256, d_model=128, nhead=4)
     batch_size = 5
     seq_len = 20
     # Just make sure the forward pass runs.
diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py
index ae93f3348..c86125f44 100755
--- a/egs/librispeech/ASR/transducer_stateless/train.py
+++ b/egs/librispeech/ASR/transducer_stateless/train.py
@@ -136,8 +136,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -422,9 +421,7 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -545,9 +542,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -664,13 +659,9 @@ def run(rank, world_size, args):
         num_removed = num_in_total - num_left
         removed_percent = num_removed / num_in_total * 100
 
-        logging.info(
-            f"Before removing short and long utterances: {num_in_total}"
-        )
+        logging.info(f"Before removing short and long utterances: {num_in_total}")
         logging.info(f"After removing short and long utterances: {num_left}")
-        logging.info(
-            f"Removed {num_removed} utterances ({removed_percent:.5f}%)"
-        )
+        logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
     except TypeError as e:
         # You can ignore this error as previous versions of Lhotse work fine
         # for the above code. In recent versions of Lhotse, it uses
@@ -698,9 +689,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/librispeech/ASR/transducer_stateless/transformer.py b/egs/librispeech/ASR/transducer_stateless/transformer.py
index e851dcc32..b3ff153c1 100644
--- a/egs/librispeech/ASR/transducer_stateless/transformer.py
+++ b/egs/librispeech/ASR/transducer_stateless/transformer.py
@@ -250,9 +250,7 @@ def _get_activation_fn(activation: str):
     elif activation == "gelu":
         return nn.functional.gelu
 
-    raise RuntimeError(
-        "activation should be relu/gelu, not {}".format(activation)
-    )
+    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
 
 
 class PositionalEncoding(nn.Module):
diff --git a/egs/librispeech/ASR/transducer_stateless2/decode.py b/egs/librispeech/ASR/transducer_stateless2/decode.py
index ac2807241..c642b16bd 100755
--- a/egs/librispeech/ASR/transducer_stateless2/decode.py
+++ b/egs/librispeech/ASR/transducer_stateless2/decode.py
@@ -171,8 +171,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -230,9 +229,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
 
     hyps = []
 
@@ -248,10 +245,7 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -374,9 +368,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -409,8 +401,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -450,9 +441,7 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += (
-            f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
diff --git a/egs/librispeech/ASR/transducer_stateless2/export.py b/egs/librispeech/ASR/transducer_stateless2/export.py
index 57c1a6094..229c514b9 100755
--- a/egs/librispeech/ASR/transducer_stateless2/export.py
+++ b/egs/librispeech/ASR/transducer_stateless2/export.py
@@ -104,8 +104,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     return parser
@@ -176,9 +175,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless2/pretrained.py b/egs/librispeech/ASR/transducer_stateless2/pretrained.py
index 292f77f03..9053bc6e0 100755
--- a/egs/librispeech/ASR/transducer_stateless2/pretrained.py
+++ b/egs/librispeech/ASR/transducer_stateless2/pretrained.py
@@ -167,8 +167,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -198,8 +197,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])
@@ -259,9 +257,7 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -334,9 +330,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless2/train.py b/egs/librispeech/ASR/transducer_stateless2/train.py
index ea15c9040..71c9c5df7 100755
--- a/egs/librispeech/ASR/transducer_stateless2/train.py
+++ b/egs/librispeech/ASR/transducer_stateless2/train.py
@@ -136,8 +136,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -410,9 +409,7 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -533,9 +530,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -652,13 +647,9 @@ def run(rank, world_size, args):
         num_removed = num_in_total - num_left
         removed_percent = num_removed / num_in_total * 100
 
-        logging.info(
-            f"Before removing short and long utterances: {num_in_total}"
-        )
+        logging.info(f"Before removing short and long utterances: {num_in_total}")
         logging.info(f"After removing short and long utterances: {num_left}")
-        logging.info(
-            f"Removed {num_removed} utterances ({removed_percent:.5f}%)"
-        )
+        logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
     except TypeError as e:
         # You can ignore this error as previous versions of Lhotse work fine
         # for the above code. In recent versions of Lhotse, it uses
@@ -686,9 +677,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py
index d596e05cb..253821028 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py
@@ -172,8 +172,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -231,9 +230,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
 
     hyps = []
 
@@ -249,10 +246,7 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -375,9 +369,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -410,8 +402,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -451,9 +442,7 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += (
-            f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py
index b6b69d932..97b0eea4a 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py
@@ -110,8 +110,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     return parser
@@ -247,9 +246,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py
index f297fa2b2..c698a35b0 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py
@@ -167,8 +167,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -198,8 +197,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])
@@ -259,9 +257,7 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -334,9 +330,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/test_asr_datamodule.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/test_asr_datamodule.py
index ef51a7811..1e1188ca6 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/test_asr_datamodule.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/test_asr_datamodule.py
@@ -41,9 +41,7 @@ def test_dataset():
     print(args)
 
     if args.enable_musan:
-        cuts_musan = load_manifest(
-            Path(args.manifest_dir) / "musan_cuts.jsonl.gz"
-        )
+        cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz")
     else:
         cuts_musan = None
 
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py
index 27912738c..e5b7dc390 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py
@@ -114,8 +114,7 @@ def get_parser():
         "--full-libri",
         type=str2bool,
         default=True,
-        help="When enabled, use 960h LibriSpeech. "
-        "Otherwise, use 100h subset.",
+        help="When enabled, use 960h LibriSpeech. " "Otherwise, use 100h subset.",
     )
 
     parser.add_argument(
@@ -170,8 +169,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -469,9 +467,7 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -635,9 +631,7 @@ def train_one_epoch(
                     f"train/current_{prefix}_",
                     params.batch_idx_train,
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
                 libri_tot_loss.write_summary(
                     tb_writer, "train/libri_tot_", params.batch_idx_train
                 )
@@ -784,9 +778,7 @@ def run(rank, world_size, args):
     train_giga_cuts = train_giga_cuts.repeat(times=None)
 
     if args.enable_musan:
-        cuts_musan = load_manifest(
-            Path(args.manifest_dir) / "musan_cuts.jsonl.gz"
-        )
+        cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz")
     else:
         cuts_musan = None
 
@@ -825,9 +817,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/ptb/LM/local/sort_lm_training_data.py b/egs/ptb/LM/local/sort_lm_training_data.py
index af54dbd07..bed3856e4 100755
--- a/egs/ptb/LM/local/sort_lm_training_data.py
+++ b/egs/ptb/LM/local/sort_lm_training_data.py
@@ -135,9 +135,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/ptb/LM/local/test_prepare_lm_training_data.py b/egs/ptb/LM/local/test_prepare_lm_training_data.py
index 877720e7b..3790045fa 100755
--- a/egs/ptb/LM/local/test_prepare_lm_training_data.py
+++ b/egs/ptb/LM/local/test_prepare_lm_training_data.py
@@ -54,9 +54,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/spgispeech/ASR/local/compute_fbank_musan.py b/egs/spgispeech/ASR/local/compute_fbank_musan.py
index 6cb8b65ae..9bea28a41 100755
--- a/egs/spgispeech/ASR/local/compute_fbank_musan.py
+++ b/egs/spgispeech/ASR/local/compute_fbank_musan.py
@@ -87,9 +87,7 @@ def compute_fbank_musan():
     # create chunks of Musan with duration 5 - 10 seconds
     musan_cuts = (
         CutSet.from_manifests(
-            recordings=combine(
-                part["recordings"] for part in manifests.values()
-            )
+            recordings=combine(part["recordings"] for part in manifests.values())
         )
         .cut_into_windows(10.0)
         .filter(lambda c: c.duration > 5)
@@ -108,8 +106,6 @@ def compute_fbank_musan():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     logging.basicConfig(format=formatter, level=logging.INFO)
     compute_fbank_musan()
diff --git a/egs/spgispeech/ASR/local/compute_fbank_spgispeech.py b/egs/spgispeech/ASR/local/compute_fbank_spgispeech.py
index 8116e7605..20ff6d7ab 100755
--- a/egs/spgispeech/ASR/local/compute_fbank_spgispeech.py
+++ b/egs/spgispeech/ASR/local/compute_fbank_spgispeech.py
@@ -103,11 +103,7 @@ def compute_fbank_spgispeech(args):
             chunk_size=chunk_size,
         )
         start = args.start
-        stop = (
-            min(args.stop, args.num_splits)
-            if args.stop > 0
-            else args.num_splits
-        )
+        stop = min(args.stop, args.num_splits) if args.stop > 0 else args.num_splits
         num_digits = len(str(args.num_splits))
         for i in range(start, stop):
             idx = f"{i + 1}".zfill(num_digits)
@@ -129,9 +125,7 @@ def compute_fbank_spgispeech(args):
                 logging.info(f"{partition} already exists - skipping.")
                 continue
             logging.info(f"Processing {partition}")
-            cut_set = load_manifest_lazy(
-                src_dir / f"cuts_{partition}_raw.jsonl.gz"
-            )
+            cut_set = load_manifest_lazy(src_dir / f"cuts_{partition}_raw.jsonl.gz")
             cut_set = cut_set.compute_and_store_features_batch(
                 extractor=extractor,
                 storage_path=output_dir / f"feats_{partition}",
@@ -144,9 +138,7 @@ def compute_fbank_spgispeech(args):
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     logging.basicConfig(format=formatter, level=logging.INFO)
 
     args = get_args()
diff --git a/egs/spgispeech/ASR/local/prepare_splits.py b/egs/spgispeech/ASR/local/prepare_splits.py
index 8c8f1c133..508d4acd8 100755
--- a/egs/spgispeech/ASR/local/prepare_splits.py
+++ b/egs/spgispeech/ASR/local/prepare_splits.py
@@ -55,9 +55,7 @@ def split_spgispeech_train():
 
     # Add speed perturbation
     train_cuts = (
-        train_cuts
-        + train_cuts.perturb_speed(0.9)
-        + train_cuts.perturb_speed(1.1)
+        train_cuts + train_cuts.perturb_speed(0.9) + train_cuts.perturb_speed(1.1)
     )
 
     # Write the manifests to disk.
@@ -73,9 +71,7 @@ def split_spgispeech_train():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     logging.basicConfig(format=formatter, level=logging.INFO)
 
     split_spgispeech_train()
diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
index f165f6e60..d94a92503 100644
--- a/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
+++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
@@ -176,17 +176,13 @@ class SPGISpeechAsrDataModule:
             The state dict for the training sampler.
         """
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(
-            self.args.manifest_dir / "cuts_musan.jsonl.gz"
-        )
+        cuts_musan = load_manifest(self.args.manifest_dir / "cuts_musan.jsonl.gz")
 
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(
-                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
-                )
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")
@@ -208,9 +204,7 @@ class SPGISpeechAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(
-                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
-            )
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
             input_transforms.append(
                 SpecAugment(
                     time_warp_factor=self.args.spec_aug_time_warp_factor,
@@ -227,9 +221,7 @@ class SPGISpeechAsrDataModule:
         if self.args.on_the_fly_feats:
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 input_transforms=input_transforms,
             )
         else:
@@ -282,9 +274,7 @@ class SPGISpeechAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
             )
         else:
             validate = K2SpeechRecognitionDataset(
@@ -328,9 +318,7 @@ class SPGISpeechAsrDataModule:
     @lru_cache()
     def train_cuts(self) -> CutSet:
         logging.info("About to get SPGISpeech train cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "cuts_train_shuf.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_train_shuf.jsonl.gz")
 
     @lru_cache()
     def dev_cuts(self) -> CutSet:
diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py
index c39bd0530..098da3ff0 100755
--- a/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py
@@ -76,11 +76,7 @@ from beam_search import (
 )
 from train import get_params, get_transducer_model
 
-from icefall.checkpoint import (
-    average_checkpoints,
-    find_checkpoints,
-    load_checkpoint,
-)
+from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
 from icefall.utils import (
     AttributeDict,
     setup_logger,
@@ -187,8 +183,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -246,9 +241,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
@@ -263,10 +256,7 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -389,9 +379,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -424,9 +412,7 @@ def save_results(
         # we also compute CER for spgispeech dataset.
         results_char = []
         for res in results:
-            results_char.append(
-                (res[0], list("".join(res[1])), list("".join(res[2])))
-            )
+            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
         cers_filename = (
             params.res_dir / f"cers-{test_set_name}-{key}-{params.suffix}.txt"
         )
@@ -438,32 +424,23 @@ def save_results(
 
         logging.info("Wrote detailed error stats to {}".format(wers_filename))
 
-    test_set_wers = {
-        k: v for k, v in sorted(test_set_wers.items(), key=lambda x: x[1])
-    }
-    test_set_cers = {
-        k: v for k, v in sorted(test_set_cers.items(), key=lambda x: x[1])
-    }
+    test_set_wers = {k: v for k, v in sorted(test_set_wers.items(), key=lambda x: x[1])}
+    test_set_cers = {k: v for k, v in sorted(test_set_cers.items(), key=lambda x: x[1])}
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER\tCER", file=f)
         for key in test_set_wers:
             print(
-                "{}\t{}\t{}".format(
-                    key, test_set_wers[key], test_set_cers[key]
-                ),
+                "{}\t{}\t{}".format(key, test_set_wers[key], test_set_cers[key]),
                 file=f,
             )
 
     s = "\nFor {}, WER/CER of different settings are:\n".format(test_set_name)
     note = "\tbest for {}".format(test_set_name)
     for key in test_set_wers:
-        s += "{}\t{}\t{}{}\n".format(
-            key, test_set_wers[key], test_set_cers[key], note
-        )
+        s += "{}\t{}\t{}{}\n".format(key, test_set_wers[key], test_set_cers[key], note)
         note = ""
     logging.info(s)
 
@@ -496,9 +473,7 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += (
-            f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -530,8 +505,7 @@ def main():
         ]
         if len(filenames) == 0:
             raise ValueError(
-                f"No checkpoints found for"
-                f" --iter {params.iter}, --avg {params.avg}"
+                f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}"
             )
         elif len(filenames) < params.avg:
             raise ValueError(
diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py
index 77faa3c0e..e79cb300d 100755
--- a/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py
@@ -50,11 +50,7 @@ import sentencepiece as spm
 import torch
 from train import get_params, get_transducer_model
 
-from icefall.checkpoint import (
-    average_checkpoints,
-    find_checkpoints,
-    load_checkpoint,
-)
+from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
 from icefall.utils import str2bool
 
 
@@ -119,8 +115,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     return parser
@@ -196,9 +191,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py
index dda29b3e5..213635894 100755
--- a/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py
@@ -77,9 +77,7 @@ from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[
-    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
-]
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
 
 def get_parser():
@@ -155,8 +153,7 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need to be "
-        "changed.",
+        help="The initial learning rate.  This value should not need to be " "changed.",
     )
 
     parser.add_argument(
@@ -179,8 +176,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -203,8 +199,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
+        help="The scale to smooth the loss with am (output of encoder network)" "part.",
     )
 
     parser.add_argument(
@@ -554,23 +549,16 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0
-            if warmup < 1.0
-            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
-        )
-        loss = (
-            params.simple_loss_scale * simple_loss
-            + pruned_loss_scale * pruned_loss
+            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
         )
+        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
 
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -733,9 +721,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
diff --git a/egs/tal_csasr/ASR/local/compute_fbank_tal_csasr.py b/egs/tal_csasr/ASR/local/compute_fbank_tal_csasr.py
index 4582609ac..602e50d29 100755
--- a/egs/tal_csasr/ASR/local/compute_fbank_tal_csasr.py
+++ b/egs/tal_csasr/ASR/local/compute_fbank_tal_csasr.py
@@ -84,9 +84,7 @@ def compute_fbank_tal_csasr(num_mel_bins: int = 80):
             )
             if "train" in partition:
                 cut_set = (
-                    cut_set
-                    + cut_set.perturb_speed(0.9)
-                    + cut_set.perturb_speed(1.1)
+                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                 )
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
@@ -112,9 +110,7 @@ def get_args():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/tal_csasr/ASR/local/prepare_char.py b/egs/tal_csasr/ASR/local/prepare_char.py
index 2c5b8b8b3..1262baf63 100755
--- a/egs/tal_csasr/ASR/local/prepare_char.py
+++ b/egs/tal_csasr/ASR/local/prepare_char.py
@@ -87,9 +87,7 @@ def lexicon_to_fst_no_sil(
         cur_state = loop_state
 
         word = word2id[word]
-        pieces = [
-            token2id[i] if i in token2id else token2id[""] for i in pieces
-        ]
+        pieces = [token2id[i] if i in token2id else token2id[""] for i in pieces]
 
         for i in range(len(pieces) - 1):
             w = word if i == 0 else eps
diff --git a/egs/tal_csasr/ASR/local/prepare_lang.py b/egs/tal_csasr/ASR/local/prepare_lang.py
index e5ae89ec4..c8cf9b881 100755
--- a/egs/tal_csasr/ASR/local/prepare_lang.py
+++ b/egs/tal_csasr/ASR/local/prepare_lang.py
@@ -317,9 +317,7 @@ def lexicon_to_fst(
 
 def get_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--lang-dir", type=str, help="The lang dir, data/lang_phone"
-    )
+    parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone")
     return parser.parse_args()
 
 
diff --git a/egs/tal_csasr/ASR/local/test_prepare_lang.py b/egs/tal_csasr/ASR/local/test_prepare_lang.py
index d4cf62bba..74e025ad7 100755
--- a/egs/tal_csasr/ASR/local/test_prepare_lang.py
+++ b/egs/tal_csasr/ASR/local/test_prepare_lang.py
@@ -88,9 +88,7 @@ def test_read_lexicon(filename: str):
     fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa.draw("L.pdf", title="L")
 
-    fsa_disambig = lexicon_to_fst(
-        lexicon_disambig, phone2id=phone2id, word2id=word2id
-    )
+    fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id)
     fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
     fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa_disambig.draw("L_disambig.pdf", title="L_disambig")
diff --git a/egs/tal_csasr/ASR/local/text2token.py b/egs/tal_csasr/ASR/local/text2token.py
index 71be2a613..85047c367 100755
--- a/egs/tal_csasr/ASR/local/text2token.py
+++ b/egs/tal_csasr/ASR/local/text2token.py
@@ -56,9 +56,7 @@ def get_parser():
     parser.add_argument(
         "--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
     )
-    parser.add_argument(
-        "--space", default="", type=str, help="space symbol"
-    )
+    parser.add_argument("--space", default="", type=str, help="space symbol")
     parser.add_argument(
         "--non-lang-syms",
         "-l",
@@ -66,9 +64,7 @@ def get_parser():
         type=str,
         help="list of non-linguistic symobles, e.g.,  etc.",
     )
-    parser.add_argument(
-        "text", type=str, default=False, nargs="?", help="input text"
-    )
+    parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
     parser.add_argument(
         "--trans_type",
         "-t",
@@ -108,8 +104,7 @@ def token2id(
             if token_type == "lazy_pinyin":
                 text = lazy_pinyin(chars_list)
                 sub_ids = [
-                    token_table[txt] if txt in token_table else oov_id
-                    for txt in text
+                    token_table[txt] if txt in token_table else oov_id for txt in text
                 ]
                 ids.append(sub_ids)
             else:  # token_type = "pinyin"
@@ -135,9 +130,7 @@ def main():
     if args.text:
         f = codecs.open(args.text, encoding="utf-8")
     else:
-        f = codecs.getreader("utf-8")(
-            sys.stdin if is_python2 else sys.stdin.buffer
-        )
+        f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
 
     sys.stdout = codecs.getwriter("utf-8")(
         sys.stdout if is_python2 else sys.stdout.buffer
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py
index 49bfb148b..2240c1c1d 100644
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py
@@ -222,17 +222,13 @@ class TAL_CSASRAsrDataModule:
             The state dict for the training sampler.
         """
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(
-            self.args.manifest_dir / "musan_cuts.jsonl.gz"
-        )
+        cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
 
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(
-                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
-                )
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")
@@ -254,9 +250,7 @@ class TAL_CSASRAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(
-                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
-            )
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -300,9 +294,7 @@ class TAL_CSASRAsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -360,9 +352,7 @@ class TAL_CSASRAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 return_cuts=self.args.return_cuts,
             )
         else:
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py
index b624913f5..82e1a9437 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py
@@ -208,8 +208,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -268,9 +267,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []
     zh_hyps = []
     en_hyps = []
@@ -303,10 +300,7 @@ def decode_one_batch(
             hyps.append(chars_new)
             zh_hyps.append(zh_text)
             en_hyps.append(en_text)
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -375,9 +369,7 @@ def decode_one_batch(
                     f"Unsupported decoding method: {params.decoding_method}"
                 )
             for i in range(encoder_out.size(0)):
-                hyp = sp.decode(
-                    [lexicon.token_table[idx] for idx in hyp_tokens[i]]
-                )
+                hyp = sp.decode([lexicon.token_table[idx] for idx in hyp_tokens[i]])
                 chars = pattern.split(hyp.upper())
                 chars_new = []
                 zh_text = []
@@ -506,9 +498,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results, zh_results, en_results
 
 
@@ -541,8 +531,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -585,9 +574,7 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += (
-            f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -619,9 +606,9 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
@@ -648,9 +635,9 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg + 1]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py
index 8f900208a..d0875c5f5 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py
@@ -139,8 +139,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     add_model_arguments(parser)
@@ -176,9 +175,9 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
@@ -205,9 +204,9 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg + 1]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
@@ -277,9 +276,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py
index dbe213b24..da4e3bc2f 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py
@@ -165,8 +165,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -198,8 +197,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])
@@ -263,15 +261,11 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=features, x_lens=feature_lengths
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
 
     num_waves = encoder_out.size(0)
     hyps = []
@@ -367,9 +361,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py
index ca35eba45..97d434157 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py
@@ -86,9 +86,7 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[
-    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
-]
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
 
 def add_model_arguments(parser: argparse.ArgumentParser):
@@ -214,8 +212,7 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need "
-        "to be changed.",
+        help="The initial learning rate.  This value should not need " "to be changed.",
     )
 
     parser.add_argument(
@@ -238,8 +235,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -262,8 +258,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
+        help="The scale to smooth the loss with am (output of encoder network)" "part.",
     )
 
     parser.add_argument(
@@ -600,11 +595,7 @@ def compute_loss(
      warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
-    device = (
-        model.device
-        if isinstance(model, DDP)
-        else next(model.parameters()).device
-    )
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
@@ -634,22 +625,15 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0
-            if warmup < 1.0
-            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
-        )
-        loss = (
-            params.simple_loss_scale * simple_loss
-            + pruned_loss_scale * pruned_loss
+            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
         )
+        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -828,9 +812,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -944,7 +926,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
+            2**22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
diff --git a/egs/tedlium3/ASR/local/compute_fbank_tedlium.py b/egs/tedlium3/ASR/local/compute_fbank_tedlium.py
index 327962a79..733ebf235 100755
--- a/egs/tedlium3/ASR/local/compute_fbank_tedlium.py
+++ b/egs/tedlium3/ASR/local/compute_fbank_tedlium.py
@@ -83,9 +83,7 @@ def compute_fbank_tedlium():
             )
             if "train" in partition:
                 cut_set = (
-                    cut_set
-                    + cut_set.perturb_speed(0.9)
-                    + cut_set.perturb_speed(1.1)
+                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                 )
             cur_num_jobs = num_jobs if ex is None else 80
             cur_num_jobs = min(cur_num_jobs, len(cut_set))
@@ -104,9 +102,7 @@ def compute_fbank_tedlium():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py b/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py
index 49544ccb3..9dbcc9d9e 100644
--- a/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py
+++ b/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py
@@ -25,9 +25,7 @@ import sentencepiece as spm
 
 def get_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--texts", type=List[str], help="The input transcripts list."
-    )
+    parser.add_argument("--texts", type=List[str], help="The input transcripts list.")
     parser.add_argument(
         "--bpe-model",
         type=str,
diff --git a/egs/tedlium3/ASR/local/prepare_lexicon.py b/egs/tedlium3/ASR/local/prepare_lexicon.py
index 35dd332e8..b9160b6d4 100755
--- a/egs/tedlium3/ASR/local/prepare_lexicon.py
+++ b/egs/tedlium3/ASR/local/prepare_lexicon.py
@@ -23,11 +23,12 @@ consisting of supervisions_train.json and does the following:
 1. Generate lexicon_words.txt.
 
 """
-import lhotse
 import argparse
 import logging
 from pathlib import Path
 
+import lhotse
+
 
 def get_args():
     parser = argparse.ArgumentParser()
@@ -61,9 +62,7 @@ def prepare_lexicon(manifests_dir: str, lang_dir: str):
     words = set()
 
     lexicon = Path(lang_dir) / "lexicon_words.txt"
-    sups = lhotse.load_manifest(
-        f"{manifests_dir}/tedlium_supervisions_train.jsonl.gz"
-    )
+    sups = lhotse.load_manifest(f"{manifests_dir}/tedlium_supervisions_train.jsonl.gz")
     for s in sups:
         # list the words units and filter the empty item
         words_list = list(filter(None, s.text.split()))
@@ -88,9 +87,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/tedlium3/ASR/local/prepare_transcripts.py b/egs/tedlium3/ASR/local/prepare_transcripts.py
index 1039ac5bb..7ea4e89a4 100755
--- a/egs/tedlium3/ASR/local/prepare_transcripts.py
+++ b/egs/tedlium3/ASR/local/prepare_transcripts.py
@@ -23,11 +23,12 @@ consisting of supervisions_train.json and does the following:
 1. Generate train.text.
 
 """
-import lhotse
 import argparse
 import logging
 from pathlib import Path
 
+import lhotse
+
 
 def get_args():
     parser = argparse.ArgumentParser()
@@ -61,9 +62,7 @@ def prepare_transcripts(manifests_dir: str, lang_dir: str):
     texts = []
 
     train_text = Path(lang_dir) / "train.text"
-    sups = lhotse.load_manifest(
-        f"{manifests_dir}/tedlium_supervisions_train.jsonl.gz"
-    )
+    sups = lhotse.load_manifest(f"{manifests_dir}/tedlium_supervisions_train.jsonl.gz")
     for s in sups:
         texts.append(s.text)
 
@@ -83,9 +82,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py b/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py
index 2b294e601..8ca875c24 100755
--- a/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py
@@ -172,8 +172,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -231,9 +230,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
@@ -248,10 +245,7 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -374,9 +368,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -409,8 +401,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/export.py b/egs/tedlium3/ASR/pruned_transducer_stateless/export.py
index a1c3bcea3..71a9e2d71 100644
--- a/egs/tedlium3/ASR/pruned_transducer_stateless/export.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless/export.py
@@ -106,8 +106,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     return parser
@@ -179,9 +178,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py b/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py
index 8480ac029..e8a453c80 100644
--- a/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py
@@ -165,8 +165,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -204,8 +203,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])
@@ -271,9 +269,7 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -298,10 +294,7 @@ def main():
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -353,9 +346,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/train.py b/egs/tedlium3/ASR/pruned_transducer_stateless/train.py
index 8d5cdf683..59d80a0d8 100755
--- a/egs/tedlium3/ASR/pruned_transducer_stateless/train.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless/train.py
@@ -133,8 +133,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -157,8 +156,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
+        help="The scale to smooth the loss with am (output of encoder network)" "part.",
     )
 
     parser.add_argument(
@@ -556,9 +554,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -678,9 +674,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py b/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py
index 94784c4c4..c647392f0 100644
--- a/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py
+++ b/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py
@@ -18,7 +18,6 @@
 
 import argparse
 import logging
-
 from functools import lru_cache
 from pathlib import Path
 from typing import Any, Dict, Optional
@@ -171,9 +170,7 @@ class TedLiumAsrDataModule:
         )
 
     def train_dataloaders(
-        self,
-        cuts_train: CutSet,
-        sampler_state_dict: Optional[Dict[str, Any]] = None
+        self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None
     ) -> DataLoader:
         """
         Args:
@@ -186,9 +183,7 @@ class TedLiumAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(
-                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
-            )
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
 
             input_transforms.append(
                 SpecAugment(
@@ -208,13 +203,9 @@ class TedLiumAsrDataModule:
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
-            cuts_musan = load_manifest(
-                self.args.manifest_dir / "musan_cuts.jsonl.gz"
-            )
+            cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
             transforms.append(
-                CutMix(
-                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
-                )
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")
@@ -247,9 +238,7 @@ class TedLiumAsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -306,9 +295,7 @@ class TedLiumAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 return_cuts=self.args.return_cuts,
             )
         else:
@@ -339,9 +326,7 @@ class TedLiumAsrDataModule:
         logging.debug("About to create test dataset")
         if self.args.on_the_fly_feats:
             test = K2SpeechRecognitionDataset(
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 return_cuts=self.args.return_cuts,
             )
         else:
@@ -375,13 +360,9 @@ class TedLiumAsrDataModule:
     @lru_cache()
     def dev_cuts(self) -> CutSet:
         logging.info("About to get dev cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "tedlium_cuts_dev.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "tedlium_cuts_dev.jsonl.gz")
 
     @lru_cache()
     def test_cuts(self) -> CutSet:
         logging.info("About to get test cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "tedlium_cuts_test.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "tedlium_cuts_test.jsonl.gz")
diff --git a/egs/tedlium3/ASR/transducer_stateless/beam_search.py b/egs/tedlium3/ASR/transducer_stateless/beam_search.py
index 77caf6460..1f99edaf3 100644
--- a/egs/tedlium3/ASR/transducer_stateless/beam_search.py
+++ b/egs/tedlium3/ASR/transducer_stateless/beam_search.py
@@ -87,9 +87,9 @@ def greedy_search(
         y = logits.argmax().item()
         if y != blank_id and y != unk_id:
             hyp.append(y)
-            decoder_input = torch.tensor(
-                [hyp[-context_size:]], device=device
-            ).reshape(1, context_size)
+            decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape(
+                1, context_size
+            )
 
             decoder_out = model.decoder(decoder_input, need_pad=False)
 
@@ -148,9 +148,7 @@ class HypothesisList(object):
         key = hyp.key
         if key in self:
             old_hyp = self._data[key]  # shallow copy
-            torch.logaddexp(
-                old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
-            )
+            torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob)
         else:
             self._data[key] = hyp
 
@@ -166,9 +164,7 @@ class HypothesisList(object):
           Return the hypothesis that has the largest `log_prob`.
         """
         if length_norm:
-            return max(
-                self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)
-            )
+            return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys))
         else:
             return max(self._data.values(), key=lambda hyp: hyp.log_prob)
 
@@ -344,9 +340,9 @@ def modified_beam_search(
 
     device = model.device
 
-    decoder_input = torch.tensor(
-        [blank_id] * context_size, device=device
-    ).reshape(1, context_size)
+    decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape(
+        1, context_size
+    )
 
     decoder_out = model.decoder(decoder_input, need_pad=False)
 
@@ -383,9 +379,7 @@ def modified_beam_search(
         decoder_out = model.decoder(decoder_input, need_pad=False)
         # decoder_output is of shape (num_hyps, 1, decoder_output_dim)
 
-        current_encoder_out = current_encoder_out.expand(
-            decoder_out.size(0), 1, -1
-        )
+        current_encoder_out = current_encoder_out.expand(decoder_out.size(0), 1, -1)
 
         logits = model.joiner(
             current_encoder_out,
@@ -454,9 +448,9 @@ def beam_search(
 
     device = model.device
 
-    decoder_input = torch.tensor(
-        [blank_id] * context_size, device=device
-    ).reshape(1, context_size)
+    decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape(
+        1, context_size
+    )
 
     decoder_out = model.decoder(decoder_input, need_pad=False)
 
diff --git a/egs/tedlium3/ASR/transducer_stateless/decode.py b/egs/tedlium3/ASR/transducer_stateless/decode.py
index d3e9e55e7..e5ab2c107 100755
--- a/egs/tedlium3/ASR/transducer_stateless/decode.py
+++ b/egs/tedlium3/ASR/transducer_stateless/decode.py
@@ -130,8 +130,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -250,9 +249,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []
     batch_size = encoder_out.size(0)
 
@@ -275,9 +272,7 @@ def decode_one_batch(
                 model=model, encoder_out=encoder_out_i, beam=params.beam_size
             )
         else:
-            raise ValueError(
-                f"Unsupported decoding method: {params.decoding_method}"
-            )
+            raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
         hyps.append(sp.decode(hyp).split())
 
     if params.decoding_method == "greedy_search":
@@ -348,9 +343,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -383,8 +376,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/tedlium3/ASR/transducer_stateless/decoder.py b/egs/tedlium3/ASR/transducer_stateless/decoder.py
index f0c6f32b6..f9a3814c6 100644
--- a/egs/tedlium3/ASR/transducer_stateless/decoder.py
+++ b/egs/tedlium3/ASR/transducer_stateless/decoder.py
@@ -90,9 +90,7 @@ class Decoder(nn.Module):
         if self.context_size > 1:
             embedding_out = embedding_out.permute(0, 2, 1)
             if need_pad is True:
-                embedding_out = F.pad(
-                    embedding_out, pad=(self.context_size - 1, 0)
-                )
+                embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
             else:
                 # During inference time, there is no need to do extra padding
                 # as we only need one output
diff --git a/egs/tedlium3/ASR/transducer_stateless/export.py b/egs/tedlium3/ASR/transducer_stateless/export.py
index c32b1d002..c2ec7a590 100644
--- a/egs/tedlium3/ASR/transducer_stateless/export.py
+++ b/egs/tedlium3/ASR/transducer_stateless/export.py
@@ -110,8 +110,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     return parser
@@ -247,9 +246,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/tedlium3/ASR/transducer_stateless/pretrained.py b/egs/tedlium3/ASR/transducer_stateless/pretrained.py
index c0e3bb844..070b070a7 100644
--- a/egs/tedlium3/ASR/transducer_stateless/pretrained.py
+++ b/egs/tedlium3/ASR/transducer_stateless/pretrained.py
@@ -127,8 +127,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -223,8 +222,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])
@@ -285,9 +283,7 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -335,9 +331,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/tedlium3/ASR/transducer_stateless/train.py b/egs/tedlium3/ASR/transducer_stateless/train.py
index 09cbf4a00..4fc13b1da 100755
--- a/egs/tedlium3/ASR/transducer_stateless/train.py
+++ b/egs/tedlium3/ASR/transducer_stateless/train.py
@@ -133,8 +133,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -525,9 +524,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -647,9 +644,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/timit/ASR/RESULTS.md b/egs/timit/ASR/RESULTS.md
index b78c16b88..d8ceb82b6 100644
--- a/egs/timit/ASR/RESULTS.md
+++ b/egs/timit/ASR/RESULTS.md
@@ -71,4 +71,4 @@ python tdnn_ligru_ctc/decode.py --epoch 25 \
                                --avg 17 \
                                --max-duration 20 \
                                --lang-dir data/lang_phone
-```
\ No newline at end of file
+```
diff --git a/egs/timit/ASR/local/compile_hlg.py b/egs/timit/ASR/local/compile_hlg.py
index 58cab4cf2..32c248d7e 100644
--- a/egs/timit/ASR/local/compile_hlg.py
+++ b/egs/timit/ASR/local/compile_hlg.py
@@ -146,9 +146,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/timit/ASR/local/compute_fbank_timit.py b/egs/timit/ASR/local/compute_fbank_timit.py
index f25786a0c..ecdf10ba9 100644
--- a/egs/timit/ASR/local/compute_fbank_timit.py
+++ b/egs/timit/ASR/local/compute_fbank_timit.py
@@ -85,9 +85,7 @@ def compute_fbank_timit():
             )
             if partition == "TRAIN":
                 cut_set = (
-                    cut_set
-                    + cut_set.perturb_speed(0.9)
-                    + cut_set.perturb_speed(1.1)
+                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                 )
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
@@ -101,9 +99,7 @@ def compute_fbank_timit():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/timit/ASR/local/prepare_lexicon.py b/egs/timit/ASR/local/prepare_lexicon.py
index 04023a9ab..0cf0f0deb 100644
--- a/egs/timit/ASR/local/prepare_lexicon.py
+++ b/egs/timit/ASR/local/prepare_lexicon.py
@@ -62,9 +62,7 @@ def prepare_lexicon(manifests_dir: str, lang_dir: str):
 
     phones = set()
 
-    supervisions_train = (
-        Path(manifests_dir) / "timit_supervisions_TRAIN.jsonl.gz"
-    )
+    supervisions_train = Path(manifests_dir) / "timit_supervisions_TRAIN.jsonl.gz"
     lexicon = Path(lang_dir) / "lexicon.txt"
 
     logging.info(f"Loading {supervisions_train}!")
@@ -97,9 +95,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/timit/ASR/prepare.sh b/egs/timit/ASR/prepare.sh
index ae1b96a68..d11cd3a05 100644
--- a/egs/timit/ASR/prepare.sh
+++ b/egs/timit/ASR/prepare.sh
@@ -20,9 +20,9 @@ stop_stage=100
 #  - $dl_dir/lm
 #      This directory contains the language model(LM) downloaded from
 #      https://huggingface.co/luomingshuang/timit_lm, and the LM is based
-#	     on 39 phones. About how to get these LM files, you can know it 
+#	     on 39 phones. About how to get these LM files, you can know it
 #      from https://github.com/luomingshuang/Train_LM_with_kaldilm.
-#	
+#
 #	    - lm_3_gram.arpa
 #     - lm_4_gram.arpa
 #
diff --git a/egs/timit/ASR/tdnn_ligru_ctc/decode.py b/egs/timit/ASR/tdnn_ligru_ctc/decode.py
index 4f2aa2340..4beeed18c 100644
--- a/egs/timit/ASR/tdnn_ligru_ctc/decode.py
+++ b/egs/timit/ASR/tdnn_ligru_ctc/decode.py
@@ -336,9 +336,7 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -400,9 +398,7 @@ def main():
 
     logging.info(f"device: {device}")
 
-    HLG = k2.Fsa.from_dict(
-        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
-    )
+    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
 
@@ -462,9 +458,7 @@ def main():
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save(
-            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
-        )
+        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
         return
 
     model.to(device)
@@ -485,9 +479,7 @@ def main():
         G=G,
     )
 
-    save_results(
-        params=params, test_set_name=test_set, results_dict=results_dict
-    )
+    save_results(params=params, test_set_name=test_set, results_dict=results_dict)
 
     logging.info("Done!")
 
diff --git a/egs/timit/ASR/tdnn_ligru_ctc/model.py b/egs/timit/ASR/tdnn_ligru_ctc/model.py
index 4d2199ace..9a594a969 100644
--- a/egs/timit/ASR/tdnn_ligru_ctc/model.py
+++ b/egs/timit/ASR/tdnn_ligru_ctc/model.py
@@ -16,11 +16,11 @@
 # limitations under the License.
 
 
+from typing import Optional
+
 import torch
 import torch.nn as nn
-
 from torch import Tensor
-from typing import Optional
 
 
 class TdnnLiGRU(nn.Module):
@@ -261,9 +261,7 @@ class LiGRU(torch.nn.Module):
         h = []
         if hx is not None:
             if self.bidirectional:
-                hx = hx.reshape(
-                    self.num_layers, self.batch_size * 2, self.hidden_size
-                )
+                hx = hx.reshape(self.num_layers, self.batch_size * 2, self.hidden_size)
         # Processing the different layers
         for i, ligru_lay in enumerate(self.rnn):
             if hx is not None:
@@ -445,9 +443,7 @@ class LiGRU_Layer(torch.nn.Module):
             if self.drop_mask_cnt + self.batch_size > self.N_drop_masks:
                 self.drop_mask_cnt = 0
                 self.drop_masks = self.drop(
-                    torch.ones(
-                        self.N_drop_masks, self.hidden_size, device=w.device
-                    )
+                    torch.ones(self.N_drop_masks, self.hidden_size, device=w.device)
                 ).data
 
             # Sampling the mask
diff --git a/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py b/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py
index 7da285944..4ef134412 100644
--- a/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py
+++ b/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py
@@ -29,11 +29,7 @@ import torchaudio
 from model import TdnnLiGRU
 from torch.nn.utils.rnn import pad_sequence
 
-from icefall.decode import (
-    get_lattice,
-    one_best_decoding,
-    rescore_with_whole_lattice,
-)
+from icefall.decode import get_lattice, one_best_decoding, rescore_with_whole_lattice
 from icefall.utils import AttributeDict, get_texts
 
 
@@ -58,9 +54,7 @@ def get_parser():
         help="Path to words.txt",
     )
 
-    parser.add_argument(
-        "--HLG", type=str, required=True, help="Path to HLG.pt."
-    )
+    parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.")
 
     parser.add_argument(
         "--method",
@@ -145,8 +139,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])
@@ -215,9 +208,7 @@ def main():
     logging.info("Decoding started")
     features = fbank(waves)
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
     features = features.permute(0, 2, 1)  # now features is (N, C, T)
 
     with torch.no_grad():
@@ -269,9 +260,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/timit/ASR/tdnn_ligru_ctc/train.py b/egs/timit/ASR/tdnn_ligru_ctc/train.py
index 452c2a7cb..48b7feda0 100644
--- a/egs/timit/ASR/tdnn_ligru_ctc/train.py
+++ b/egs/timit/ASR/tdnn_ligru_ctc/train.py
@@ -449,9 +449,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             valid_info = compute_validation_loss(
diff --git a/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py
index 1554e987f..51ca4cc6e 100644
--- a/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -154,9 +154,7 @@ class TimitAsrDataModule(DataModule):
         cuts_train = self.train_cuts()
 
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(
-            self.args.feature_dir / "musan_cuts.jsonl.gz"
-        )
+        cuts_musan = load_manifest(self.args.feature_dir / "musan_cuts.jsonl.gz")
 
         logging.info("About to create train dataset")
         transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))]
@@ -178,9 +176,9 @@ class TimitAsrDataModule(DataModule):
         # In different Lhotse's versions, the default of num_frame_masks is
         # different.
         num_frame_masks = 10
-        num_frame_masks_parameter = inspect.signature(
-            SpecAugment.__init__
-        ).parameters["num_frame_masks"]
+        num_frame_masks_parameter = inspect.signature(SpecAugment.__init__).parameters[
+            "num_frame_masks"
+        ]
         if num_frame_masks_parameter.default == 1:
             num_frame_masks = 2
         logging.info(f"Num frame mask: {num_frame_masks}")
@@ -212,9 +210,7 @@ class TimitAsrDataModule(DataModule):
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -263,9 +259,7 @@ class TimitAsrDataModule(DataModule):
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 return_cuts=self.args.return_cuts,
             )
         else:
@@ -299,20 +293,14 @@ class TimitAsrDataModule(DataModule):
         for cuts_test in cuts:
             logging.debug("About to create test dataset")
             test = K2SpeechRecognitionDataset(
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                )
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
                 if self.args.on_the_fly_feats
                 else PrecomputedFeatures(),
                 return_cuts=self.args.return_cuts,
             )
-            sampler = SingleCutSampler(
-                cuts_test, max_duration=self.args.max_duration
-            )
+            sampler = SingleCutSampler(cuts_test, max_duration=self.args.max_duration)
             logging.debug("About to create test dataloader")
-            test_dl = DataLoader(
-                test, batch_size=None, sampler=sampler, num_workers=1
-            )
+            test_dl = DataLoader(test, batch_size=None, sampler=sampler, num_workers=1)
             test_loaders.append(test_dl)
 
         if is_list:
diff --git a/egs/timit/ASR/tdnn_lstm_ctc/decode.py b/egs/timit/ASR/tdnn_lstm_ctc/decode.py
index 5e7300cf2..502a48def 100644
--- a/egs/timit/ASR/tdnn_lstm_ctc/decode.py
+++ b/egs/timit/ASR/tdnn_lstm_ctc/decode.py
@@ -335,9 +335,7 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -399,9 +397,7 @@ def main():
 
     logging.info(f"device: {device}")
 
-    HLG = k2.Fsa.from_dict(
-        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
-    )
+    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
 
@@ -461,9 +457,7 @@ def main():
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save(
-            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
-        )
+        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
         return
 
     model.to(device)
@@ -483,9 +477,7 @@ def main():
         G=G,
     )
 
-    save_results(
-        params=params, test_set_name=test_set, results_dict=results_dict
-    )
+    save_results(params=params, test_set_name=test_set, results_dict=results_dict)
 
     logging.info("Done!")
 
diff --git a/egs/timit/ASR/tdnn_lstm_ctc/model.py b/egs/timit/ASR/tdnn_lstm_ctc/model.py
index 51edb97e2..e211ad80d 100644
--- a/egs/timit/ASR/tdnn_lstm_ctc/model.py
+++ b/egs/timit/ASR/tdnn_lstm_ctc/model.py
@@ -74,10 +74,7 @@ class TdnnLstm(nn.Module):
             nn.BatchNorm1d(num_features=512, affine=False),
         )
         self.lstms = nn.ModuleList(
-            [
-                nn.LSTM(input_size=512, hidden_size=512, num_layers=1)
-                for _ in range(4)
-            ]
+            [nn.LSTM(input_size=512, hidden_size=512, num_layers=1) for _ in range(4)]
         )
         self.lstm_bnorms = nn.ModuleList(
             [nn.BatchNorm1d(num_features=512, affine=False) for _ in range(5)]
diff --git a/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py b/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py
index 5f478da1c..3f143912e 100644
--- a/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py
+++ b/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py
@@ -29,11 +29,7 @@ import torchaudio
 from model import TdnnLstm
 from torch.nn.utils.rnn import pad_sequence
 
-from icefall.decode import (
-    get_lattice,
-    one_best_decoding,
-    rescore_with_whole_lattice,
-)
+from icefall.decode import get_lattice, one_best_decoding, rescore_with_whole_lattice
 from icefall.utils import AttributeDict, get_texts
 
 
@@ -58,9 +54,7 @@ def get_parser():
         help="Path to words.txt",
     )
 
-    parser.add_argument(
-        "--HLG", type=str, required=True, help="Path to HLG.pt."
-    )
+    parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.")
 
     parser.add_argument(
         "--method",
@@ -145,8 +139,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])
@@ -215,9 +208,7 @@ def main():
     logging.info("Decoding started")
     features = fbank(waves)
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
     features = features.permute(0, 2, 1)  # now features is (N, C, T)
 
     with torch.no_grad():
@@ -269,9 +260,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/timit/ASR/tdnn_lstm_ctc/train.py b/egs/timit/ASR/tdnn_lstm_ctc/train.py
index 849256b98..be1ecffaa 100644
--- a/egs/timit/ASR/tdnn_lstm_ctc/train.py
+++ b/egs/timit/ASR/tdnn_lstm_ctc/train.py
@@ -449,9 +449,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             valid_info = compute_validation_loss(
diff --git a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py
index 8a9f6ed30..bd73e520e 100755
--- a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py
+++ b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py
@@ -20,12 +20,7 @@ import logging
 from pathlib import Path
 
 import torch
-from lhotse import (
-    CutSet,
-    KaldifeatFbank,
-    KaldifeatFbankConfig,
-    LilcomHdf5Writer,
-)
+from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, LilcomHdf5Writer
 
 # Torch's multithreaded behavior needs to be disabled or
 # it wastes a lot of CPU and slow things down.
@@ -83,9 +78,7 @@ def compute_fbank_wenetspeech_dev_test():
 
 
 def main():
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     logging.basicConfig(format=formatter, level=logging.INFO)
 
     compute_fbank_wenetspeech_dev_test()
diff --git a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py
index a882b6113..1b257fb70 100755
--- a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py
+++ b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py
@@ -152,9 +152,7 @@ def main():
     date_time = now.strftime("%Y-%m-%d-%H-%M-%S")
 
     log_filename = "log-compute_fbank_wenetspeech_splits"
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     log_filename = f"{log_filename}-{date_time}"
 
     logging.basicConfig(
diff --git a/egs/wenetspeech/ASR/local/prepare_char.py b/egs/wenetspeech/ASR/local/prepare_char.py
index 8bc073c75..d8622842f 100755
--- a/egs/wenetspeech/ASR/local/prepare_char.py
+++ b/egs/wenetspeech/ASR/local/prepare_char.py
@@ -83,9 +83,7 @@ def lexicon_to_fst_no_sil(
         cur_state = loop_state
 
         word = word2id[word]
-        pieces = [
-            token2id[i] if i in token2id else token2id[""] for i in pieces
-        ]
+        pieces = [token2id[i] if i in token2id else token2id[""] for i in pieces]
 
         for i in range(len(pieces) - 1):
             w = word if i == 0 else eps
@@ -138,9 +136,7 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
     return False
 
 
-def generate_lexicon(
-    token_sym_table: Dict[str, int], words: List[str]
-) -> Lexicon:
+def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon:
     """Generate a lexicon from a word list and token_sym_table.
     Args:
       token_sym_table:
diff --git a/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py b/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py
index 817969c47..93ce750f8 100755
--- a/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py
+++ b/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py
@@ -115,11 +115,7 @@ def preprocess_wenet_speech():
                 f"Speed perturb for {partition} with factors 0.9 and 1.1 "
                 "(Perturbing may take 8 minutes and saving may take 20 minutes)"
             )
-            cut_set = (
-                cut_set
-                + cut_set.perturb_speed(0.9)
-                + cut_set.perturb_speed(1.1)
-            )
+            cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
         logging.info(f"Saving to {raw_cuts_path}")
         cut_set.to_file(raw_cuts_path)
 
diff --git a/egs/wenetspeech/ASR/local/text2token.py b/egs/wenetspeech/ASR/local/text2token.py
index 1c463cf1c..d1d237a68 100755
--- a/egs/wenetspeech/ASR/local/text2token.py
+++ b/egs/wenetspeech/ASR/local/text2token.py
@@ -56,9 +56,7 @@ def get_parser():
     parser.add_argument(
         "--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
     )
-    parser.add_argument(
-        "--space", default="", type=str, help="space symbol"
-    )
+    parser.add_argument("--space", default="", type=str, help="space symbol")
     parser.add_argument(
         "--non-lang-syms",
         "-l",
@@ -66,9 +64,7 @@ def get_parser():
         type=str,
         help="list of non-linguistic symobles, e.g.,  etc.",
     )
-    parser.add_argument(
-        "text", type=str, default=False, nargs="?", help="input text"
-    )
+    parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
     parser.add_argument(
         "--trans_type",
         "-t",
@@ -108,8 +104,7 @@ def token2id(
             if token_type == "lazy_pinyin":
                 text = lazy_pinyin(chars_list)
                 sub_ids = [
-                    token_table[txt] if txt in token_table else oov_id
-                    for txt in text
+                    token_table[txt] if txt in token_table else oov_id for txt in text
                 ]
                 ids.append(sub_ids)
             else:  # token_type = "pinyin"
@@ -135,9 +130,7 @@ def main():
     if args.text:
         f = codecs.open(args.text, encoding="utf-8")
     else:
-        f = codecs.getreader("utf-8")(
-            sys.stdin if is_python2 else sys.stdin.buffer
-        )
+        f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
 
     sys.stdout = codecs.getwriter("utf-8")(
         sys.stdout if is_python2 else sys.stdout.buffer
diff --git a/egs/wenetspeech/ASR/prepare.sh b/egs/wenetspeech/ASR/prepare.sh
index 755fbb2d7..da7d7e061 100755
--- a/egs/wenetspeech/ASR/prepare.sh
+++ b/egs/wenetspeech/ASR/prepare.sh
@@ -190,7 +190,7 @@ if [ $stage -le 15 ] && [ $stop_stage -ge 15 ]; then
   mkdir -p $lang_char_dir
 
   if ! which jq; then
-      echo "This script is intended to be used with jq but you have not installed jq 
+      echo "This script is intended to be used with jq but you have not installed jq
       Note: in Linux, you can install jq with the following command:
       1. wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64
       2. chmod +x ./jq
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
index 10c953e3b..9c07263a2 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
@@ -212,17 +212,13 @@ class WenetSpeechAsrDataModule:
             The state dict for the training sampler.
         """
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(
-            self.args.manifest_dir / "musan_cuts.jsonl.gz"
-        )
+        cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
 
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(
-                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
-                )
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")
@@ -244,9 +240,7 @@ class WenetSpeechAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(
-                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
-            )
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -289,9 +283,7 @@ class WenetSpeechAsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -348,9 +340,7 @@ class WenetSpeechAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 return_cuts=self.args.return_cuts,
             )
         else:
@@ -414,8 +404,7 @@ class WenetSpeechAsrDataModule:
     def train_cuts(self) -> CutSet:
         logging.info("About to get train cuts")
         cuts_train = load_manifest_lazy(
-            self.args.manifest_dir
-            / f"cuts_{self.args.training_subset}.jsonl.gz"
+            self.args.manifest_dir / f"cuts_{self.args.training_subset}.jsonl.gz"
         )
         return cuts_train
 
@@ -427,13 +416,9 @@ class WenetSpeechAsrDataModule:
     @lru_cache()
     def test_net_cuts(self) -> List[CutSet]:
         logging.info("About to get TEST_NET cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "cuts_TEST_NET.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_NET.jsonl.gz")
 
     @lru_cache()
     def test_meeting_cuts(self) -> List[CutSet]:
         logging.info("About to get TEST_MEETING cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "cuts_TEST_MEETING.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_MEETING.jsonl.gz")
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py
index f0c9bebec..cd9ed57b9 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py
@@ -114,11 +114,7 @@ from beam_search import (
 from train import get_params, get_transducer_model
 
 from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
-from icefall.checkpoint import (
-    average_checkpoints,
-    find_checkpoints,
-    load_checkpoint,
-)
+from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
@@ -252,8 +248,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -328,9 +323,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
@@ -389,10 +382,7 @@ def decode_one_batch(
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -515,9 +505,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -550,8 +538,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -663,9 +650,7 @@ def main():
             )
             decoding_graph.scores *= params.ngram_lm_scale
         else:
-            decoding_graph = k2.trivial_graph(
-                params.vocab_size - 1, device=device
-            )
+            decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
     else:
         decoding_graph = None
 
@@ -716,8 +701,7 @@ def main():
         )
 
     dev_shards = [
-        str(path)
-        for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
+        str(path) for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
     ]
     cuts_dev_webdataset = CutSet.from_webdataset(
         dev_shards,
@@ -727,8 +711,7 @@ def main():
     )
 
     test_net_shards = [
-        str(path)
-        for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
+        str(path) for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
     ]
     cuts_test_net_webdataset = CutSet.from_webdataset(
         test_net_shards,
@@ -739,9 +722,7 @@ def main():
 
     test_meeting_shards = [
         str(path)
-        for path in sorted(
-            glob.glob(os.path.join(test_meeting, "shared-*.tar"))
-        )
+        for path in sorted(glob.glob(os.path.join(test_meeting, "shared-*.tar")))
     ]
     cuts_test_meeting_webdataset = CutSet.from_webdataset(
         test_meeting_shards,
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py
index 933642a0f..df2fc5df5 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py
@@ -205,8 +205,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     return parser
@@ -468,13 +467,9 @@ def export_joiner_model_onnx(
 
         - projected_decoder_out: a tensor of shape (N, joiner_dim)
     """
-    encoder_proj_filename = str(joiner_filename).replace(
-        ".onnx", "_encoder_proj.onnx"
-    )
+    encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx")
 
-    decoder_proj_filename = str(joiner_filename).replace(
-        ".onnx", "_decoder_proj.onnx"
-    )
+    decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx")
 
     encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
     decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
@@ -645,9 +640,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py
index e5cc47bfe..42ffbcfb8 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py
@@ -146,8 +146,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])
@@ -331,9 +330,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py
index c396c50ef..a46ff5a07 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py
@@ -219,9 +219,7 @@ def test_joiner(
         )
 
         # Now test encoder_proj
-        joiner_encoder_proj_inputs = {
-            encoder_proj_input_name: encoder_out.numpy()
-        }
+        joiner_encoder_proj_inputs = {encoder_proj_input_name: encoder_out.numpy()}
         joiner_encoder_proj_out = joiner_encoder_proj_session.run(
             [encoder_proj_output_name], joiner_encoder_proj_inputs
         )[0]
@@ -230,16 +228,10 @@ def test_joiner(
         torch_joiner_encoder_proj_out = model.joiner.encoder_proj(encoder_out)
         assert torch.allclose(
             joiner_encoder_proj_out, torch_joiner_encoder_proj_out, atol=1e-5
-        ), (
-            (joiner_encoder_proj_out - torch_joiner_encoder_proj_out)
-            .abs()
-            .max()
-        )
+        ), ((joiner_encoder_proj_out - torch_joiner_encoder_proj_out).abs().max())
 
         # Now test decoder_proj
-        joiner_decoder_proj_inputs = {
-            decoder_proj_input_name: decoder_out.numpy()
-        }
+        joiner_decoder_proj_inputs = {decoder_proj_input_name: decoder_out.numpy()}
         joiner_decoder_proj_out = joiner_decoder_proj_session.run(
             [decoder_proj_output_name], joiner_decoder_proj_inputs
         )[0]
@@ -248,11 +240,7 @@ def test_joiner(
         torch_joiner_decoder_proj_out = model.joiner.decoder_proj(decoder_out)
         assert torch.allclose(
             joiner_decoder_proj_out, torch_joiner_decoder_proj_out, atol=1e-5
-        ), (
-            (joiner_decoder_proj_out - torch_joiner_decoder_proj_out)
-            .abs()
-            .max()
-        )
+        ), ((joiner_decoder_proj_out - torch_joiner_decoder_proj_out).abs().max())
 
 
 @torch.no_grad()
@@ -304,9 +292,7 @@ def main():
 
 if __name__ == "__main__":
     torch.manual_seed(20220727)
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py
index 3770fbbb4..ca1e408fa 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py
@@ -150,8 +150,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])
@@ -200,11 +199,7 @@ def greedy_search(
 
     projected_encoder_out = joiner_encoder_proj.run(
         [joiner_encoder_proj.get_outputs()[0].name],
-        {
-            joiner_encoder_proj.get_inputs()[
-                0
-            ].name: packed_encoder_out.data.numpy()
-        },
+        {joiner_encoder_proj.get_inputs()[0].name: packed_encoder_out.data.numpy()},
     )[0]
 
     blank_id = 0  # hard-code to 0
@@ -389,9 +384,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py
index 9a549efd9..aaf7ac874 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py
@@ -158,8 +158,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -190,8 +189,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])
@@ -253,9 +251,7 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -280,10 +276,7 @@ def main():
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -335,9 +328,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
index d3cc7c9c9..7aba0711d 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
@@ -115,9 +115,7 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[
-    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
-]
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
 
 def get_parser():
@@ -219,8 +217,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -243,8 +240,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
+        help="The scale to smooth the loss with am (output of encoder network)" "part.",
     )
 
     parser.add_argument(
@@ -590,22 +586,15 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0
-            if warmup < 1.0
-            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
-        )
-        loss = (
-            params.simple_loss_scale * simple_loss
-            + pruned_loss_scale * pruned_loss
+            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
         )
+        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -762,9 +751,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -864,7 +851,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
+            2**22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py
index dd27c17f0..9bb55d07a 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py
@@ -210,10 +210,7 @@ class Conformer(EncoderInterface):
           (num_encoder_layers, cnn_module_kernel - 1, encoder_dim).
           NOTE: the returned tensors are on the given device.
         """
-        if (
-            len(self._init_state) == 2
-            and self._init_state[0].size(1) == left_context
-        ):
+        if len(self._init_state) == 2 and self._init_state[0].size(1) == left_context:
             # Note: It is OK to share the init state as it is
             # not going to be modified by the model
             return self._init_state
@@ -433,9 +430,7 @@ class ConformerEncoderLayer(nn.Module):
 
         self.d_model = d_model
 
-        self.self_attn = RelPositionMultiheadAttention(
-            d_model, nhead, dropout=0.0
-        )
+        self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
 
         self.feed_forward = nn.Sequential(
             ScaledLinear(d_model, dim_feedforward),
@@ -453,9 +448,7 @@ class ConformerEncoderLayer(nn.Module):
             ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
         )
 
-        self.conv_module = ConvolutionModule(
-            d_model, cnn_module_kernel, causal=causal
-        )
+        self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal)
 
         self.norm_final = BasicNorm(d_model)
 
@@ -520,9 +513,7 @@ class ConformerEncoderLayer(nn.Module):
         src = src + self.dropout(src_att)
 
         # convolution module
-        conv, _ = self.conv_module(
-            src, src_key_padding_mask=src_key_padding_mask
-        )
+        conv, _ = self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
         src = src + self.dropout(conv)
 
         # feed forward module
@@ -766,9 +757,7 @@ class RelPositionalEncoding(torch.nn.Module):
         max_len: Maximum input length.
     """
 
-    def __init__(
-        self, d_model: int, dropout_rate: float, max_len: int = 5000
-    ) -> None:
+    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
@@ -784,9 +773,7 @@ class RelPositionalEncoding(torch.nn.Module):
             # the length of self.pe is 2 * input_len - 1
             if self.pe.size(1) >= x_size_1 * 2 - 1:
                 # Note: TorchScript doesn't implement operator== for torch.Device
-                if self.pe.dtype != x.dtype or str(self.pe.device) != str(
-                    x.device
-                ):
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
                     self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                 return
         # Suppose `i` means to the position of query vector and `j` means the
@@ -1073,9 +1060,9 @@ class RelPositionMultiheadAttention(nn.Module):
 
         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(
-                query, in_proj_weight, in_proj_bias
-            ).chunk(3, dim=-1)
+            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
+                3, dim=-1
+            )
 
         elif torch.equal(key, value):
             # encoder-decoder attention
@@ -1144,31 +1131,22 @@ class RelPositionMultiheadAttention(nn.Module):
             if attn_mask.dim() == 2:
                 attn_mask = attn_mask.unsqueeze(0)
                 if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
-                    raise RuntimeError(
-                        "The size of the 2D attn_mask is not correct."
-                    )
+                    raise RuntimeError("The size of the 2D attn_mask is not correct.")
             elif attn_mask.dim() == 3:
                 if list(attn_mask.size()) != [
                     bsz * num_heads,
                     query.size(0),
                     key.size(0),
                 ]:
-                    raise RuntimeError(
-                        "The size of the 3D attn_mask is not correct."
-                    )
+                    raise RuntimeError("The size of the 3D attn_mask is not correct.")
             else:
                 raise RuntimeError(
-                    "attn_mask's dimension {} is not supported".format(
-                        attn_mask.dim()
-                    )
+                    "attn_mask's dimension {} is not supported".format(attn_mask.dim())
                 )
             # attn_mask's dim is 3 now.
 
         # convert ByteTensor key_padding_mask to bool
-        if (
-            key_padding_mask is not None
-            and key_padding_mask.dtype == torch.uint8
-        ):
+        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
             warnings.warn(
                 "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
             )
@@ -1208,23 +1186,15 @@ class RelPositionMultiheadAttention(nn.Module):
         # first compute matrix a and matrix c
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
-        matrix_ac = torch.matmul(
-            q_with_bias_u, k
-        )  # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch, head, time1, time2)
 
         # compute matrix b and matrix d
-        matrix_bd = torch.matmul(
-            q_with_bias_v, p
-        )  # (batch, head, time1, 2*time1-1)
+        matrix_bd = torch.matmul(q_with_bias_v, p)  # (batch, head, time1, 2*time1-1)
         matrix_bd = self.rel_shift(matrix_bd, left_context)
 
-        attn_output_weights = (
-            matrix_ac + matrix_bd
-        )  # (batch, head, time1, time2)
+        attn_output_weights = matrix_ac + matrix_bd  # (batch, head, time1, time2)
 
-        attn_output_weights = attn_output_weights.view(
-            bsz * num_heads, tgt_len, -1
-        )
+        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
 
         assert list(attn_output_weights.size()) == [
             bsz * num_heads,
@@ -1265,21 +1235,17 @@ class RelPositionMultiheadAttention(nn.Module):
         ):
             if attn_mask.size(0) != 1:
                 attn_mask = attn_mask.view(bsz, num_heads, tgt_len, src_len)
-                combined_mask = attn_mask | key_padding_mask.unsqueeze(
-                    1
-                ).unsqueeze(2)
+                combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2)
             else:
                 # attn_mask.shape == (1, tgt_len, src_len)
-                combined_mask = attn_mask.unsqueeze(
-                    0
-                ) | key_padding_mask.unsqueeze(1).unsqueeze(2)
+                combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze(
+                    1
+                ).unsqueeze(2)
 
             attn_output_weights = attn_output_weights.view(
                 bsz, num_heads, tgt_len, src_len
             )
-            attn_output_weights = attn_output_weights.masked_fill(
-                combined_mask, 0.0
-            )
+            attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0)
             attn_output_weights = attn_output_weights.view(
                 bsz * num_heads, tgt_len, src_len
             )
@@ -1291,13 +1257,9 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output = torch.bmm(attn_output_weights, v)
         assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
         attn_output = (
-            attn_output.transpose(0, 1)
-            .contiguous()
-            .view(tgt_len, bsz, embed_dim)
-        )
-        attn_output = nn.functional.linear(
-            attn_output, out_proj_weight, out_proj_bias
+            attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
         )
+        attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
 
         if need_weights:
             # average attention weights over heads
@@ -1430,16 +1392,12 @@ class ConvolutionModule(nn.Module):
                 # manualy padding self.lorder zeros to the left
                 x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
             else:
-                assert (
-                    not self.training
-                ), "Cache should be None in training time"
+                assert not self.training, "Cache should be None in training time"
                 assert cache.size(0) == self.lorder
                 x = torch.cat([cache.permute(1, 2, 0), x], dim=2)
                 if right_context > 0:
                     cache = x.permute(2, 0, 1)[
-                        -(self.lorder + right_context) : (  # noqa
-                            -right_context
-                        ),
+                        -(self.lorder + right_context) : (-right_context),  # noqa
                         ...,
                     ]
                 else:
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
index 344e31283..166497c31 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
@@ -244,8 +244,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -342,9 +341,7 @@ def decode_one_batch(
             simulate_streaming=True,
         )
     else:
-        encoder_out, encoder_out_lens = model.encoder(
-            x=feature, x_lens=feature_lens
-        )
+        encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
 
     hyps = []
 
@@ -360,10 +357,7 @@ def decode_one_batch(
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -484,9 +478,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -519,8 +511,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -589,9 +580,9 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
@@ -618,9 +609,9 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg + 1]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
@@ -720,8 +711,7 @@ def main():
         )
 
     dev_shards = [
-        str(path)
-        for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
+        str(path) for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
     ]
     cuts_dev_webdataset = CutSet.from_webdataset(
         dev_shards,
@@ -731,8 +721,7 @@ def main():
     )
 
     test_net_shards = [
-        str(path)
-        for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
+        str(path) for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
     ]
     cuts_test_net_webdataset = CutSet.from_webdataset(
         test_net_shards,
@@ -743,9 +732,7 @@ def main():
 
     test_meeting_shards = [
         str(path)
-        for path in sorted(
-            glob.glob(os.path.join(test_meeting, "shared-*.tar"))
-        )
+        for path in sorted(glob.glob(os.path.join(test_meeting, "shared-*.tar")))
     ]
     cuts_test_meeting_webdataset = CutSet.from_webdataset(
         test_meeting_shards,
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode_stream.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode_stream.py
index 386248554..e522943c0 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode_stream.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode_stream.py
@@ -75,9 +75,7 @@ class DecodeStream(object):
         # encoder.streaming_forward
         self.done_frames: int = 0
 
-        self.pad_length = (
-            params.right_context + 2
-        ) * params.subsampling_factor + 3
+        self.pad_length = (params.right_context + 2) * params.subsampling_factor + 3
 
         if params.decoding_method == "greedy_search":
             self.hyp = [params.blank_id] * params.context_size
@@ -91,13 +89,11 @@ class DecodeStream(object):
             )
         elif params.decoding_method == "fast_beam_search":
             # The rnnt_decoding_stream for fast_beam_search.
-            self.rnnt_decoding_stream: k2.RnntDecodingStream = (
-                k2.RnntDecodingStream(decoding_graph)
+            self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream(
+                decoding_graph
             )
         else:
-            raise ValueError(
-                f"Unsupported decoding method: {params.decoding_method}"
-            )
+            raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
 
     @property
     def done(self) -> bool:
@@ -126,13 +122,10 @@ class DecodeStream(object):
         """Consume chunk_size frames of features"""
         chunk_length = chunk_size + self.pad_length
 
-        ret_length = min(
-            self.num_frames - self.num_processed_frames, chunk_length
-        )
+        ret_length = min(self.num_frames - self.num_processed_frames, chunk_length)
 
         ret_features = self.features[
-            self.num_processed_frames : self.num_processed_frames  # noqa
-            + ret_length
+            self.num_processed_frames : self.num_processed_frames + ret_length  # noqa
         ]
 
         self.num_processed_frames += chunk_size
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py
index d0a7fd69f..ff2c4db38 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py
@@ -131,8 +131,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
     add_model_arguments(parser)
 
@@ -201,9 +200,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py
index 1b064c874..7e4829a60 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py
@@ -157,8 +157,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -190,8 +189,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])
@@ -253,9 +251,7 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -280,10 +276,7 @@ def main():
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -335,9 +328,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_beam_search.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_beam_search.py
index 651aff6c9..810d94135 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_beam_search.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_beam_search.py
@@ -173,14 +173,10 @@ def modified_beam_search(
         log_probs_shape = k2.ragged.create_ragged_shape2(
             row_splits=row_splits, cached_tot_size=log_probs.numel()
         )
-        ragged_log_probs = k2.RaggedTensor(
-            shape=log_probs_shape, value=log_probs
-        )
+        ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)
 
         for i in range(batch_size):
-            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(
-                num_active_paths
-            )
+            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(num_active_paths)
 
             with warnings.catch_warnings():
                 warnings.simplefilter("ignore")
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py
index ff96c6487..6909f40be 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py
@@ -201,8 +201,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -311,9 +310,7 @@ def decode_one_chunk(
     encoder_out = model.joiner.encoder_proj(encoder_out)
 
     if params.decoding_method == "greedy_search":
-        greedy_search(
-            model=model, encoder_out=encoder_out, streams=decode_streams
-        )
+        greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams)
     elif params.decoding_method == "fast_beam_search":
         processed_lens = processed_lens + encoder_out_lens
         fast_beam_search_one_best(
@@ -333,9 +330,7 @@ def decode_one_chunk(
             num_active_paths=params.num_active_paths,
         )
     else:
-        raise ValueError(
-            f"Unsupported decoding method: {params.decoding_method}"
-        )
+        raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
 
     states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)]
 
@@ -389,9 +384,7 @@ def decode_dataset(
     decode_results = []
     # Contain decode streams currently running.
     decode_streams = []
-    initial_states = model.encoder.get_init_state(
-        params.left_context, device=device
-    )
+    initial_states = model.encoder.get_init_state(params.left_context, device=device)
     for num, cut in enumerate(cuts):
         # each utterance has a DecodeStream.
         decode_stream = DecodeStream(
@@ -461,9 +454,7 @@ def decode_dataset(
     elif params.decoding_method == "modified_beam_search":
         key = f"num_active_paths_{params.num_active_paths}"
     else:
-        raise ValueError(
-            f"Unsupported decoding method: {params.decoding_method}"
-        )
+        raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
 
     return {key: decode_results}
 
@@ -499,8 +490,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -565,9 +555,9 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
@@ -594,9 +584,9 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg + 1]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py
index 2052e9da7..5f614e77c 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py
@@ -98,9 +98,7 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[
-    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
-]
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
 
 def add_model_arguments(parser: argparse.ArgumentParser):
@@ -260,8 +258,7 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need "
-        "to be changed.",
+        help="The initial learning rate.  This value should not need " "to be changed.",
     )
 
     parser.add_argument(
@@ -284,8 +281,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -308,8 +304,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
+        help="The scale to smooth the loss with am (output of encoder network)" "part.",
     )
 
     parser.add_argument(
@@ -665,11 +660,7 @@ def compute_loss(
      warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
-    device = (
-        model.device
-        if isinstance(model, DDP)
-        else next(model.parameters()).device
-    )
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
@@ -701,23 +692,16 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0
-            if warmup < 1.0
-            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
-        )
-        loss = (
-            params.simple_loss_scale * simple_loss
-            + pruned_loss_scale * pruned_loss
+            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
         )
+        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
 
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -841,9 +825,7 @@ def train_one_epoch(
             scaler.update()
             optimizer.zero_grad()
         except:  # noqa
-            display_and_save_batch(
-                batch, params=params, graph_compiler=graph_compiler
-            )
+            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
             raise
 
         if params.print_diagnostics and batch_idx == 5:
@@ -901,9 +883,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -1016,7 +996,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
+            2**22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
@@ -1184,9 +1164,7 @@ def scan_pessimistic_batches_for_oom(
                     f"Failing criterion: {criterion} "
                     f"(={crit_values[criterion]}) ..."
                 )
-            display_and_save_batch(
-                batch, params=params, graph_compiler=graph_compiler
-            )
+            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
             raise
 
 
diff --git a/egs/yesno/ASR/local/compile_hlg.py b/egs/yesno/ASR/local/compile_hlg.py
index f83be05cf..7234ca929 100755
--- a/egs/yesno/ASR/local/compile_hlg.py
+++ b/egs/yesno/ASR/local/compile_hlg.py
@@ -128,9 +128,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/yesno/ASR/local/compute_fbank_yesno.py b/egs/yesno/ASR/local/compute_fbank_yesno.py
index 9a4e8a36f..75d95df68 100755
--- a/egs/yesno/ASR/local/compute_fbank_yesno.py
+++ b/egs/yesno/ASR/local/compute_fbank_yesno.py
@@ -54,9 +54,7 @@ def compute_fbank_yesno():
         dataset_parts,
     )
 
-    extractor = Fbank(
-        FbankConfig(sampling_rate=8000, num_mel_bins=num_mel_bins)
-    )
+    extractor = Fbank(FbankConfig(sampling_rate=8000, num_mel_bins=num_mel_bins))
 
     with get_executor() as ex:  # Initialize the executor only once.
         for partition, m in manifests.items():
@@ -71,9 +69,7 @@ def compute_fbank_yesno():
             )
             if "train" in partition:
                 cut_set = (
-                    cut_set
-                    + cut_set.perturb_speed(0.9)
-                    + cut_set.perturb_speed(1.1)
+                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                 )
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
@@ -87,9 +83,7 @@ def compute_fbank_yesno():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/yesno/ASR/tdnn/decode.py b/egs/yesno/ASR/tdnn/decode.py
index 9d4ab4b61..d5efb41df 100755
--- a/egs/yesno/ASR/tdnn/decode.py
+++ b/egs/yesno/ASR/tdnn/decode.py
@@ -201,9 +201,7 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -274,9 +272,7 @@ def main():
 
     logging.info(f"device: {device}")
 
-    HLG = k2.Fsa.from_dict(
-        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
-    )
+    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
 
@@ -297,9 +293,7 @@ def main():
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save(
-            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
-        )
+        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
         return
 
     model.to(device)
@@ -317,9 +311,7 @@ def main():
         word_table=lexicon.word_table,
     )
 
-    save_results(
-        exp_dir=params.exp_dir, test_set_name="test_set", results=results
-    )
+    save_results(exp_dir=params.exp_dir, test_set_name="test_set", results=results)
 
     logging.info("Done!")
 
diff --git a/egs/yesno/ASR/tdnn/pretrained.py b/egs/yesno/ASR/tdnn/pretrained.py
index 14220be19..88d5eca5d 100755
--- a/egs/yesno/ASR/tdnn/pretrained.py
+++ b/egs/yesno/ASR/tdnn/pretrained.py
@@ -53,9 +53,7 @@ def get_parser():
         help="Path to words.txt",
     )
 
-    parser.add_argument(
-        "--HLG", type=str, required=True, help="Path to HLG.pt."
-    )
+    parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.")
 
     parser.add_argument(
         "sound_files",
@@ -102,8 +100,7 @@ def read_sound_files(
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
         assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
         ans.append(wave[0])
@@ -159,9 +156,7 @@ def main():
     logging.info("Decoding started")
     features = fbank(waves)
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     # Note: We don't use key padding mask for attention during decoding
     with torch.no_grad():
@@ -201,9 +196,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/yesno/ASR/tdnn/train.py b/egs/yesno/ASR/tdnn/train.py
index f32a27f35..335493491 100755
--- a/egs/yesno/ASR/tdnn/train.py
+++ b/egs/yesno/ASR/tdnn/train.py
@@ -430,9 +430,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             valid_info = compute_validation_loss(
diff --git a/egs/yesno/ASR/transducer/decode.py b/egs/yesno/ASR/transducer/decode.py
index 6714180db..7f13e417a 100755
--- a/egs/yesno/ASR/transducer/decode.py
+++ b/egs/yesno/ASR/transducer/decode.py
@@ -116,9 +116,7 @@ def decode_one_batch(
     # at entry, feature is (N, T, C)
     feature_lens = batch["supervisions"]["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
 
     hyps = []
     batch_size = encoder_out.size(0)
@@ -186,9 +184,7 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -303,9 +299,7 @@ def main():
         model=model,
     )
 
-    save_results(
-        exp_dir=params.exp_dir, test_set_name="test_set", results=results
-    )
+    save_results(exp_dir=params.exp_dir, test_set_name="test_set", results=results)
 
     logging.info("Done!")
 
diff --git a/egs/yesno/ASR/transducer/train.py b/egs/yesno/ASR/transducer/train.py
index deb92107d..88866ae81 100755
--- a/egs/yesno/ASR/transducer/train.py
+++ b/egs/yesno/ASR/transducer/train.py
@@ -430,9 +430,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             valid_info = compute_validation_loss(
diff --git a/icefall/char_graph_compiler.py b/icefall/char_graph_compiler.py
index 235160e14..c31db6e4c 100644
--- a/icefall/char_graph_compiler.py
+++ b/icefall/char_graph_compiler.py
@@ -71,9 +71,7 @@ class CharCtcTrainingGraphCompiler(object):
         for text in texts:
             text = re.sub(whitespace, "", text)
             sub_ids = [
-                self.token_table[txt]
-                if txt in self.token_table
-                else self.oov_id
+                self.token_table[txt] if txt in self.token_table else self.oov_id
                 for txt in text
             ]
             ids.append(sub_ids)
@@ -96,9 +94,7 @@ class CharCtcTrainingGraphCompiler(object):
         for text in texts:
             text = text.split("/")
             sub_ids = [
-                self.token_table[txt]
-                if txt in self.token_table
-                else self.oov_id
+                self.token_table[txt] if txt in self.token_table else self.oov_id
                 for txt in text
             ]
             ids.append(sub_ids)
diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py
index 5069b78e8..8aa0a8eeb 100644
--- a/icefall/checkpoint.py
+++ b/icefall/checkpoint.py
@@ -292,15 +292,11 @@ def find_checkpoints(out_dir: Path, iteration: int = 0) -> List[str]:
     """
     checkpoints = list(glob.glob(f"{out_dir}/checkpoint-[0-9]*.pt"))
     pattern = re.compile(r"checkpoint-([0-9]+).pt")
-    iter_checkpoints = [
-        (int(pattern.search(c).group(1)), c) for c in checkpoints
-    ]
+    iter_checkpoints = [(int(pattern.search(c).group(1)), c) for c in checkpoints]
     # iter_checkpoints is a list of tuples. Each tuple contains
     # two elements: (iteration_number, checkpoint-iteration_number.pt)
 
-    iter_checkpoints = sorted(
-        iter_checkpoints, reverse=True, key=lambda x: x[0]
-    )
+    iter_checkpoints = sorted(iter_checkpoints, reverse=True, key=lambda x: x[0])
     if iteration >= 0:
         ans = [ic[1] for ic in iter_checkpoints if ic[0] >= iteration]
     else:
@@ -469,7 +465,5 @@ def average_state_dict(
         v = state_dict_1[k]
         if torch.is_floating_point(v):
             v *= weight_1
-            v += (
-                state_dict_2[k].to(device=state_dict_1[k].device) * weight_2
-            )
+            v += state_dict_2[k].to(device=state_dict_1[k].device) * weight_2
             v *= scaling_factor
diff --git a/icefall/decode.py b/icefall/decode.py
index 099e2d171..e4c614c4e 100644
--- a/icefall/decode.py
+++ b/icefall/decode.py
@@ -334,13 +334,9 @@ class Nbest(object):
         if hasattr(lattice, "aux_labels"):
             # delete token IDs as it is not needed
             del word_fsa.aux_labels
-            word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(
-                word_fsa
-            )
+            word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa)
         else:
-            word_fsa_with_epsilon_loops = k2.linear_fst_with_self_loops(
-                word_fsa
-            )
+            word_fsa_with_epsilon_loops = k2.linear_fst_with_self_loops(word_fsa)
 
         path_to_utt_map = self.shape.row_ids(1)
 
@@ -370,9 +366,7 @@ class Nbest(object):
         # path_lattice has word IDs as labels and token IDs as aux_labels
         path_lattice = k2.top_sort(k2.connect(path_lattice))
 
-        one_best = k2.shortest_path(
-            path_lattice, use_double_scores=use_double_scores
-        )
+        one_best = k2.shortest_path(path_lattice, use_double_scores=use_double_scores)
 
         one_best = k2.invert(one_best)
         # Now one_best has token IDs as labels and word IDs as aux_labels
@@ -442,9 +436,7 @@ class Nbest(object):
         scores_shape = self.fsa.arcs.shape().remove_axis(1)
         # scores_shape has axes [path][arc]
 
-        ragged_scores = k2.RaggedTensor(
-            scores_shape, self.fsa.scores.contiguous()
-        )
+        ragged_scores = k2.RaggedTensor(scores_shape, self.fsa.scores.contiguous())
 
         tot_scores = ragged_scores.sum()
 
@@ -483,9 +475,7 @@ def one_best_decoding(
             am_scores = saved_am_scores / lm_scale
             lattice.scores = am_scores + lattice.lm_scores
 
-            best_path = k2.shortest_path(
-                lattice, use_double_scores=use_double_scores
-            )
+            best_path = k2.shortest_path(lattice, use_double_scores=use_double_scores)
             key = f"lm_scale_{lm_scale}"
             ans[key] = best_path
         return ans
@@ -696,9 +686,7 @@ def rescore_with_n_best_list(
             logging.info(f"num_paths before decreasing: {num_paths}")
             num_paths = int(num_paths / 2)
             if loop_count >= max_loop_count or num_paths <= 0:
-                logging.info(
-                    "Return None as the resulting lattice is too large."
-                )
+                logging.info("Return None as the resulting lattice is too large.")
                 return None
             logging.info(
                 "This OOM is not an error. You can ignore it. "
@@ -805,13 +793,9 @@ def rescore_with_whole_lattice(
         except RuntimeError as e:
             logging.info(f"Caught exception:\n{e}\n")
             if loop_count >= max_loop_count:
-                logging.info(
-                    "Return None as the resulting lattice is too large."
-                )
+                logging.info("Return None as the resulting lattice is too large.")
                 return None
-            logging.info(
-                f"num_arcs before pruning: {inv_lattice.arcs.num_elements()}"
-            )
+            logging.info(f"num_arcs before pruning: {inv_lattice.arcs.num_elements()}")
             logging.info(
                 "This OOM is not an error. You can ignore it. "
                 "If your model does not converge well, or --max-duration "
@@ -823,9 +807,7 @@ def rescore_with_whole_lattice(
                 prune_th_list[loop_count],
                 True,
             )
-            logging.info(
-                f"num_arcs after pruning: {inv_lattice.arcs.num_elements()}"
-            )
+            logging.info(f"num_arcs after pruning: {inv_lattice.arcs.num_elements()}")
         loop_count += 1
 
     # lat has token IDs as labels
@@ -912,9 +894,7 @@ def rescore_with_attention_decoder(
             logging.info(f"num_paths before decreasing: {num_paths}")
             num_paths = int(num_paths / 2)
             if loop_count >= max_loop_count or num_paths <= 0:
-                logging.info(
-                    "Return None as the resulting lattice is too large."
-                )
+                logging.info("Return None as the resulting lattice is too large.")
                 return None
             logging.info(
                 "This OOM is not an error. You can ignore it. "
diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py
index b075aceac..207c12bf1 100644
--- a/icefall/diagnostics.py
+++ b/icefall/diagnostics.py
@@ -19,7 +19,7 @@
 
 import random
 from dataclasses import dataclass
-from typing import Optional, Tuple, List
+from typing import List, Optional, Tuple
 
 import torch
 from torch import Tensor, nn
@@ -78,11 +78,11 @@ def get_tensor_stats(
     elif stats_type == "abs":
         x = x.abs()
     elif stats_type == "rms":
-        x = x ** 2
+        x = x**2
     elif stats_type == "positive":
         x = (x > 0).to(dtype=torch.float)
     else:
-        assert stats_type in [ "value", "max", "min" ]
+        assert stats_type in ["value", "max", "min"]
 
     sum_dims = [d for d in range(x.ndim) if d != dim]
     if len(sum_dims) > 0:
@@ -121,7 +121,9 @@ class TensorDiagnostic(object):
         self.name = name
         self.class_name = None  # will assign in accumulate()
 
-        self.stats = None  # we'll later assign a list to this data member.  It's a list of dict.
+        self.stats = (
+            None  # we'll later assign a list to this data member.  It's a list of dict.
+        )
 
         # the keys into self.stats[dim] are strings, whose values can be
         # "abs", "max", "min" ,"value", "positive", "rms", "value".
@@ -133,7 +135,6 @@ class TensorDiagnostic(object):
         # only adding a new element to the list if there was a different dim.
         # if the string in the key is "eigs", if we detect a length mismatch we put None as the value.
 
-
     def accumulate(self, x, class_name: Optional[str] = None):
         """
         Accumulate tensors.
@@ -185,17 +186,12 @@ class TensorDiagnostic(object):
                         done = True
                         break
                 if not done:
-                    if (
-                        this_dim_stats[stats_type] != []
-                        and stats_type == "eigs"
-                    ):
+                    if this_dim_stats[stats_type] != [] and stats_type == "eigs":
                         # >1 size encountered on this dim, e.g. it's a batch or time dimension,
                         # don't accumulat "eigs" stats type, it uses too much memory
                         this_dim_stats[stats_type] = None
                     else:
-                        this_dim_stats[stats_type].append(
-                            TensorAndCount(stats, count)
-                        )
+                        this_dim_stats[stats_type].append(TensorAndCount(stats, count))
 
     def print_diagnostics(self):
         """Print diagnostics for each dimension of the tensor."""
@@ -211,7 +207,6 @@ class TensorDiagnostic(object):
                     assert stats_type == "eigs"
                     continue
 
-
                 def get_count(count):
                     return 1 if stats_type in ["max", "min"] else count
 
@@ -229,9 +224,7 @@ class TensorDiagnostic(object):
                         eigs, _ = torch.symeig(stats)
                         stats = eigs.abs().sqrt()
                     except:  # noqa
-                        print(
-                            "Error getting eigenvalues, trying another method."
-                        )
+                        print("Error getting eigenvalues, trying another method.")
                         eigs, _ = torch.eig(stats)
                         stats = eigs.abs().sqrt()
                         # sqrt so it reflects data magnitude, like stddev- not variance
@@ -242,9 +235,9 @@ class TensorDiagnostic(object):
 
                 # if `summarize` we print percentiles of the stats; else,
                 # we print out individual elements.
-                summarize = (
-                    len(stats_list) > 1
-                ) or self.opts.dim_is_summarized(stats.numel())
+                summarize = (len(stats_list) > 1) or self.opts.dim_is_summarized(
+                    stats.numel()
+                )
                 if summarize:  # usually `summarize` will be true
                     # print out percentiles.
                     stats = stats.sort()[0]
@@ -261,15 +254,15 @@ class TensorDiagnostic(object):
                     ans = stats.tolist()
                     ans = ["%.2g" % x for x in ans]
                     ans = "[" + " ".join(ans) + "]"
-                if stats_type in [ "value", "rms", "eigs" ]:
+                if stats_type in ["value", "rms", "eigs"]:
                     # This norm is useful because it is strictly less than the largest
                     # sqrt(eigenvalue) of the variance, which we print out, and shows,
                     # speaking in an approximate way, how much of that largest eigenvalue
                     # can be attributed to the mean of the distribution.
-                    norm = (stats ** 2).sum().sqrt().item()
+                    norm = (stats**2).sum().sqrt().item()
                     ans += f", norm={norm:.2g}"
                 mean = stats.mean().item()
-                rms = (stats ** 2).mean().sqrt().item()
+                rms = (stats**2).mean().sqrt().item()
                 ans += f", mean={mean:.3g}, rms={rms:.3g}"
 
                 # OK, "ans" contains the actual stats, e.g.
@@ -277,17 +270,16 @@ class TensorDiagnostic(object):
 
                 sizes = [x.tensor.shape[0] for x in stats_list]
                 size_str = (
-                    f"{sizes[0]}"
-                    if len(sizes) == 1
-                    else f"{min(sizes)}..{max(sizes)}"
+                    f"{sizes[0]}" if len(sizes) == 1 else f"{min(sizes)}..{max(sizes)}"
+                )
+                maybe_class_name = (
+                    f" type={self.class_name}," if self.class_name is not None else ""
                 )
-                maybe_class_name = f" type={self.class_name}," if self.class_name is not None else ""
                 print(
                     f"module={self.name},{maybe_class_name} dim={dim}, size={size_str}, {stats_type} {ans}"
                 )
 
 
-
 class ModelDiagnostic(object):
     """This class stores diagnostics for all tensors in the torch.nn.Module.
 
@@ -345,32 +337,32 @@ def attach_diagnostics(
         # (matters for name, since the variable gets overwritten).
         # These closures don't really capture by value, only by
         # "the final value the variable got in the function" :-(
-        def forward_hook(
-            _module, _input, _output, _model_diagnostic=ans, _name=name
-        ):
+        def forward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name):
             if isinstance(_output, tuple) and len(_output) == 1:
                 _output = _output[0]
 
             if isinstance(_output, Tensor):
-                _model_diagnostic[f"{_name}.output"].accumulate(_output,
-                                                                class_name=type(_module).__name__)
+                _model_diagnostic[f"{_name}.output"].accumulate(
+                    _output, class_name=type(_module).__name__
+                )
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
-                    _model_diagnostic[f"{_name}.output[{i}]"].accumulate(o,
-                                                                         class_name=type(_module).__name__)
+                    _model_diagnostic[f"{_name}.output[{i}]"].accumulate(
+                        o, class_name=type(_module).__name__
+                    )
 
-        def backward_hook(
-            _module, _input, _output, _model_diagnostic=ans, _name=name
-        ):
+        def backward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name):
             if isinstance(_output, tuple) and len(_output) == 1:
                 _output = _output[0]
             if isinstance(_output, Tensor):
-                _model_diagnostic[f"{_name}.grad"].accumulate(_output,
-                                                              class_name=type(_module).__name__)
+                _model_diagnostic[f"{_name}.grad"].accumulate(
+                    _output, class_name=type(_module).__name__
+                )
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
-                    _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o,
-                                                                       class_name=type(_module).__name__)
+                    _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(
+                        o, class_name=type(_module).__name__
+                    )
 
         module.register_forward_hook(forward_hook)
         module.register_backward_hook(backward_hook)
diff --git a/icefall/dist.py b/icefall/dist.py
index 7016beafb..9df1c5bd1 100644
--- a/icefall/dist.py
+++ b/icefall/dist.py
@@ -29,9 +29,7 @@ def setup_dist(rank, world_size, master_port=None, use_ddp_launch=False):
         os.environ["MASTER_ADDR"] = "localhost"
 
     if "MASTER_PORT" not in os.environ:
-        os.environ["MASTER_PORT"] = (
-            "12354" if master_port is None else str(master_port)
-        )
+        os.environ["MASTER_PORT"] = "12354" if master_port is None else str(master_port)
 
     if use_ddp_launch is False:
         dist.init_process_group("nccl", rank=rank, world_size=world_size)
diff --git a/icefall/env.py b/icefall/env.py
index 8aeda6be2..373e9a9ff 100644
--- a/icefall/env.py
+++ b/icefall/env.py
@@ -53,9 +53,7 @@ def get_git_sha1():
             )
             > 0
         )
-        git_commit = (
-            git_commit + "-dirty" if dirty_commit else git_commit + "-clean"
-        )
+        git_commit = git_commit + "-dirty" if dirty_commit else git_commit + "-clean"
     except:  # noqa
         return None
 
diff --git a/icefall/graph_compiler.py b/icefall/graph_compiler.py
index 570ed7d7a..e2ff03f61 100644
--- a/icefall/graph_compiler.py
+++ b/icefall/graph_compiler.py
@@ -75,9 +75,7 @@ class CtcTrainingGraphCompiler(object):
 
         # NOTE: k2.compose runs on CUDA only when treat_epsilons_specially
         # is False, so we add epsilon self-loops here
-        fsa_with_self_loops = k2.remove_epsilon_and_add_self_loops(
-            transcript_fsa
-        )
+        fsa_with_self_loops = k2.remove_epsilon_and_add_self_loops(transcript_fsa)
 
         fsa_with_self_loops = k2.arc_sort(fsa_with_self_loops)
 
diff --git a/icefall/hooks.py b/icefall/hooks.py
index fbcf5e148..398a5f689 100644
--- a/icefall/hooks.py
+++ b/icefall/hooks.py
@@ -14,10 +14,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 import random
+
 import torch
 from torch import Tensor, nn
-import logging
 
 
 def register_inf_check_hooks(model: nn.Module) -> None:
@@ -56,7 +57,7 @@ def register_inf_check_hooks(model: nn.Module) -> None:
             if isinstance(_output, Tensor):
                 if not torch.isfinite(_output.to(torch.float32).sum()):
                     logging.warning(
-                        f"The sum of {_name}.grad is not finite" # ": {_output}"
+                        f"The sum of {_name}.grad is not finite"  # ": {_output}"
                     )
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
@@ -65,28 +66,20 @@ def register_inf_check_hooks(model: nn.Module) -> None:
                     if not isinstance(o, Tensor):
                         continue
                     if not torch.isfinite(o.to(torch.float32).sum()):
-                        logging.warning(
-                            f"The sum of {_name}.grad[{i}] is not finite"
-                        )
+                        logging.warning(f"The sum of {_name}.grad[{i}] is not finite")
 
         module.register_forward_hook(forward_hook)
         module.register_backward_hook(backward_hook)
 
-
     for name, parameter in model.named_parameters():
 
-        def param_backward_hook(
-                grad, _name=name
-        ):
+        def param_backward_hook(grad, _name=name):
             if not torch.isfinite(grad.to(torch.float32).sum()):
-                logging.warning(
-                    f"The sum of {_name}.param_grad is not finite"
-                )
+                logging.warning(f"The sum of {_name}.param_grad is not finite")
 
         parameter.register_hook(param_backward_hook)
 
 
-
 def _test_inf_check_hooks():
     model = nn.Sequential(nn.Linear(100, 50), nn.Linear(50, 80))
 
diff --git a/icefall/lexicon.py b/icefall/lexicon.py
index 80bd7c1ee..22e1b78bb 100644
--- a/icefall/lexicon.py
+++ b/icefall/lexicon.py
@@ -49,18 +49,12 @@ def read_lexicon(filename: str) -> List[Tuple[str, List[str]]]:
                 continue
 
             if len(a) < 2:
-                logging.info(
-                    f"Found bad line {line} in lexicon file {filename}"
-                )
-                logging.info(
-                    "Every line is expected to contain at least 2 fields"
-                )
+                logging.info(f"Found bad line {line} in lexicon file {filename}")
+                logging.info("Every line is expected to contain at least 2 fields")
                 sys.exit(1)
             word = a[0]
             if word == "":
-                logging.info(
-                    f"Found bad line {line} in lexicon file {filename}"
-                )
+                logging.info(f"Found bad line {line} in lexicon file {filename}")
                 logging.info(" should not be a valid word")
                 sys.exit(1)
 
@@ -119,9 +113,7 @@ def convert_lexicon_to_ragged(
     lexicon_tmp = read_lexicon(filename)
     lexicon = dict(lexicon_tmp)
     if len(lexicon_tmp) != len(lexicon):
-        raise RuntimeError(
-            "It's assumed that each word has a unique pronunciation"
-        )
+        raise RuntimeError("It's assumed that each word has a unique pronunciation")
 
     for i in range(disambig_id):
         w = word_table[i]
diff --git a/icefall/mmi.py b/icefall/mmi.py
index 2c479fc2c..16ed6e032 100644
--- a/icefall/mmi.py
+++ b/icefall/mmi.py
@@ -63,10 +63,7 @@ def _compute_mmi_loss_exact_optimized(
 
     # [0, num_fsas, 1, num_fsas, 2, num_fsas, ... ]
     num_den_graphs_indexes = (
-        torch.stack([num_graphs_indexes, den_graphs_indexes])
-        .t()
-        .reshape(-1)
-        .to(device)
+        torch.stack([num_graphs_indexes, den_graphs_indexes]).t().reshape(-1).to(device)
     )
 
     num_den_reordered_graphs = k2.index(num_den_graphs, num_den_graphs_indexes)
@@ -115,20 +112,12 @@ def _compute_mmi_loss_exact_non_optimized(
     num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=True)
 
     # TODO: pass output_beam as function argument
-    num_lats = k2.intersect_dense(
-        num_graphs, dense_fsa_vec, output_beam=beam_size
-    )
-    den_lats = k2.intersect_dense(
-        den_graphs, dense_fsa_vec, output_beam=beam_size
-    )
+    num_lats = k2.intersect_dense(num_graphs, dense_fsa_vec, output_beam=beam_size)
+    den_lats = k2.intersect_dense(den_graphs, dense_fsa_vec, output_beam=beam_size)
 
-    num_tot_scores = num_lats.get_tot_scores(
-        log_semiring=True, use_double_scores=True
-    )
+    num_tot_scores = num_lats.get_tot_scores(log_semiring=True, use_double_scores=True)
 
-    den_tot_scores = den_lats.get_tot_scores(
-        log_semiring=True, use_double_scores=True
-    )
+    den_tot_scores = den_lats.get_tot_scores(log_semiring=True, use_double_scores=True)
 
     tot_scores = num_tot_scores - den_scale * den_tot_scores
 
@@ -168,13 +157,9 @@ def _compute_mmi_loss_pruned(
         max_active_states=10000,
     )
 
-    num_tot_scores = num_lats.get_tot_scores(
-        log_semiring=True, use_double_scores=True
-    )
+    num_tot_scores = num_lats.get_tot_scores(log_semiring=True, use_double_scores=True)
 
-    den_tot_scores = den_lats.get_tot_scores(
-        log_semiring=True, use_double_scores=True
-    )
+    den_tot_scores = den_lats.get_tot_scores(log_semiring=True, use_double_scores=True)
 
     tot_scores = num_tot_scores - den_scale * den_tot_scores
 
diff --git a/icefall/mmi_graph_compiler.py b/icefall/mmi_graph_compiler.py
index 0d901227d..9f680f83d 100644
--- a/icefall/mmi_graph_compiler.py
+++ b/icefall/mmi_graph_compiler.py
@@ -137,9 +137,7 @@ class MmiTrainingGraphCompiler(object):
             transcript_fsa
         )
 
-        transcript_fsa_with_self_loops = k2.arc_sort(
-            transcript_fsa_with_self_loops
-        )
+        transcript_fsa_with_self_loops = k2.arc_sort(transcript_fsa_with_self_loops)
 
         num = k2.compose(
             self.ctc_topo_P,
@@ -155,9 +153,7 @@ class MmiTrainingGraphCompiler(object):
 
         ctc_topo_P_vec = k2.create_fsa_vec([self.ctc_topo_P])
         if replicate_den:
-            indexes = torch.zeros(
-                len(texts), dtype=torch.int32, device=self.device
-            )
+            indexes = torch.zeros(len(texts), dtype=torch.int32, device=self.device)
             den = k2.index_fsa(ctc_topo_P_vec, indexes)
         else:
             den = ctc_topo_P_vec
diff --git a/icefall/rnn_lm/dataset.py b/icefall/rnn_lm/dataset.py
index 598e329c4..4bf982503 100644
--- a/icefall/rnn_lm/dataset.py
+++ b/icefall/rnn_lm/dataset.py
@@ -155,12 +155,8 @@ class LmDatasetCollate:
         sentence_tokens_with_sos = add_sos(sentence_tokens, self.sos_id)
         sentence_tokens_with_eos = add_eos(sentence_tokens, self.eos_id)
 
-        x = sentence_tokens_with_sos.pad(
-            mode="constant", padding_value=self.blank_id
-        )
-        y = sentence_tokens_with_eos.pad(
-            mode="constant", padding_value=self.blank_id
-        )
+        x = sentence_tokens_with_sos.pad(mode="constant", padding_value=self.blank_id)
+        y = sentence_tokens_with_eos.pad(mode="constant", padding_value=self.blank_id)
         sentence_token_lengths += 1  # plus 1 since we added a SOS
 
         return x.to(torch.int64), y.to(torch.int64), sentence_token_lengths
diff --git a/icefall/rnn_lm/export.py b/icefall/rnn_lm/export.py
index 094035fce..2411cb1f0 100644
--- a/icefall/rnn_lm/export.py
+++ b/icefall/rnn_lm/export.py
@@ -159,9 +159,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/icefall/rnn_lm/model.py b/icefall/rnn_lm/model.py
index a6144727a..9eef88840 100644
--- a/icefall/rnn_lm/model.py
+++ b/icefall/rnn_lm/model.py
@@ -129,9 +129,7 @@ class RnnLmModel(torch.nn.Module):
         tokens_eos = add_eos(tokens, eos_id)
         sos_tokens_row_splits = sos_tokens.shape.row_splits(1)
 
-        sentence_lengths = (
-            sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]
-        )
+        sentence_lengths = sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]
 
         x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id)
         y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id)
@@ -161,12 +159,12 @@ class RnnLmModel(torch.nn.Module):
         if state:
             h, c = state
         else:
-            h = torch.zeros(
-                self.rnn.num_layers, batch_size, self.rnn.input_size
-            ).to(device)
-            c = torch.zeros(
-                self.rnn.num_layers, batch_size, self.rnn.input_size
-            ).to(device)
+            h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size).to(
+                device
+            )
+            c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size).to(
+                device
+            )
 
         embedding = self.input_embedding(tokens)
         rnn_out, states = self.rnn(embedding, (h, c))
@@ -181,12 +179,8 @@ class RnnLmModel(torch.nn.Module):
         if state:
             h, c = state
         else:
-            h = torch.zeros(
-                self.rnn.num_layers, batch_size, self.rnn.input_size
-            )
-            c = torch.zeros(
-                self.rnn.num_layers, batch_size, self.rnn.input_size
-            )
+            h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size)
+            c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size)
 
         device = next(self.parameters()).device
 
@@ -194,9 +188,7 @@ class RnnLmModel(torch.nn.Module):
         tokens_eos = add_eos(tokens, eos_id)
         sos_tokens_row_splits = sos_tokens.shape.row_splits(1)
 
-        sentence_lengths = (
-            sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]
-        )
+        sentence_lengths = sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]
 
         x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id)
         y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id)
diff --git a/icefall/rnn_lm/train.py b/icefall/rnn_lm/train.py
index bb5f03fb9..3ba5bfbee 100755
--- a/icefall/rnn_lm/train.py
+++ b/icefall/rnn_lm/train.py
@@ -446,17 +446,13 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
                 tb_writer.add_scalar(
                     "train/current_ppl", this_batch_ppl, params.batch_idx_train
                 )
 
-                tb_writer.add_scalar(
-                    "train/tot_ppl", tot_ppl, params.batch_idx_train
-                )
+                tb_writer.add_scalar("train/tot_ppl", tot_ppl, params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
diff --git a/icefall/shared/make_kn_lm.py b/icefall/shared/make_kn_lm.py
index c2edd823e..b1220d55e 100755
--- a/icefall/shared/make_kn_lm.py
+++ b/icefall/shared/make_kn_lm.py
@@ -15,30 +15,43 @@
 # The data structure is based on: kaldi/egs/wsj/s5/utils/lang/make_phone_lm.py
 # The smoothing algorithm is based on: http://www.speech.sri.com/projects/srilm/manpages/ngram-discount.7.html
 
-import sys
-import os
-import re
+import argparse
 import io
 import math
-import argparse
+import os
+import re
+import sys
 from collections import Counter, defaultdict
 
-
-parser = argparse.ArgumentParser(description="""
+parser = argparse.ArgumentParser(
+    description="""
     Generate kneser-ney language model as arpa format. By default,
     it will read the corpus from standard input, and output to standard output.
-    """)
-parser.add_argument("-ngram-order", type=int, default=4, choices=[2, 3, 4, 5, 6, 7], help="Order of n-gram")
+    """
+)
+parser.add_argument(
+    "-ngram-order",
+    type=int,
+    default=4,
+    choices=[2, 3, 4, 5, 6, 7],
+    help="Order of n-gram",
+)
 parser.add_argument("-text", type=str, default=None, help="Path to the corpus file")
-parser.add_argument("-lm", type=str, default=None, help="Path to output arpa file for language models")
-parser.add_argument("-verbose", type=int, default=0, choices=[0, 1, 2, 3, 4, 5], help="Verbose level")
+parser.add_argument(
+    "-lm", type=str, default=None, help="Path to output arpa file for language models"
+)
+parser.add_argument(
+    "-verbose", type=int, default=0, choices=[0, 1, 2, 3, 4, 5], help="Verbose level"
+)
 args = parser.parse_args()
 
-default_encoding = "latin-1"  # For encoding-agnostic scripts, we assume byte stream as input.
-                              # Need to be very careful about the use of strip() and split()
-                              # in this case, because there is a latin-1 whitespace character
-                              # (nbsp) which is part of the unicode encoding range.
-                              # Ref: kaldi/egs/wsj/s5/utils/lang/bpe/prepend_words.py @ 69cd717
+default_encoding = (
+    "latin-1"  # For encoding-agnostic scripts, we assume byte stream as input.
+)
+# Need to be very careful about the use of strip() and split()
+# in this case, because there is a latin-1 whitespace character
+# (nbsp) which is part of the unicode encoding range.
+# Ref: kaldi/egs/wsj/s5/utils/lang/bpe/prepend_words.py @ 69cd717
 strip_chars = " \t\r\n"
 whitespace = re.compile("[ \t]+")
 
@@ -52,7 +65,9 @@ class CountsForHistory:
         # The 'lambda: defaultdict(float)' is an anonymous function taking no
         # arguments that returns a new defaultdict(float).
         self.word_to_count = defaultdict(int)
-        self.word_to_context = defaultdict(set)  # using a set to count the number of unique contexts
+        self.word_to_context = defaultdict(
+            set
+        )  # using a set to count the number of unique contexts
         self.word_to_f = dict()  # discounted probability
         self.word_to_bow = dict()  # back-off weight
         self.total_count = 0
@@ -62,10 +77,15 @@ class CountsForHistory:
 
     def __str__(self):
         # e.g. returns ' total=12: 3->4, 4->6, -1->2'
-        return ' total={0}: {1}'.format(
+        return " total={0}: {1}".format(
             str(self.total_count),
-            ', '.join(['{0} -> {1}'.format(word, count)
-                      for word, count in self.word_to_count.items()]))
+            ", ".join(
+                [
+                    "{0} -> {1}".format(word, count)
+                    for word, count in self.word_to_count.items()
+                ]
+            ),
+        )
 
     def add_count(self, predicted_word, context_word, count):
         assert count >= 0
@@ -85,7 +105,7 @@ class NgramCounts:
     # accumulating the 4-gram count for the '8' in the sequence '5 6 7 8', we'd
     # do as follows: self.counts[3][[5,6,7]][8] += 1.0 where the [3] indexes an
     # array, the [[5,6,7]] indexes a dict, and the [8] indexes a dict.
-    def __init__(self, ngram_order, bos_symbol='', eos_symbol=''):
+    def __init__(self, ngram_order, bos_symbol="", eos_symbol=""):
         assert ngram_order >= 2
 
         self.ngram_order = ngram_order
@@ -103,39 +123,48 @@ class NgramCounts:
     # would be (6,7,8) and 'predicted_word' would be 9; 'count' would be
     # 1.
     def add_count(self, history, predicted_word, context_word, count):
-        self.counts[len(history)][history].add_count(predicted_word, context_word, count)
+        self.counts[len(history)][history].add_count(
+            predicted_word, context_word, count
+        )
 
     # 'line' is a string containing a sequence of integer word-ids.
     # This function adds the un-smoothed counts from this line of text.
     def add_raw_counts_from_line(self, line):
-        if line == '':
+        if line == "":
             words = [self.bos_symbol, self.eos_symbol]
         else:
             words = [self.bos_symbol] + whitespace.split(line) + [self.eos_symbol]
 
         for i in range(len(words)):
-            for n in range(1, self.ngram_order+1):
+            for n in range(1, self.ngram_order + 1):
                 if i + n > len(words):
                     break
-                ngram = words[i: i + n]
+                ngram = words[i : i + n]
                 predicted_word = ngram[-1]
-                history = tuple(ngram[: -1])
+                history = tuple(ngram[:-1])
                 if i == 0 or n == self.ngram_order:
                     context_word = None
                 else:
-                    context_word = words[i-1]
+                    context_word = words[i - 1]
 
                 self.add_count(history, predicted_word, context_word, 1)
 
     def add_raw_counts_from_standard_input(self):
         lines_processed = 0
-        infile = io.TextIOWrapper(sys.stdin.buffer, encoding=default_encoding)  # byte stream as input
+        infile = io.TextIOWrapper(
+            sys.stdin.buffer, encoding=default_encoding
+        )  # byte stream as input
         for line in infile:
             line = line.strip(strip_chars)
             self.add_raw_counts_from_line(line)
             lines_processed += 1
         if lines_processed == 0 or args.verbose > 0:
-            print("make_phone_lm.py: processed {0} lines of input".format(lines_processed), file=sys.stderr)
+            print(
+                "make_phone_lm.py: processed {0} lines of input".format(
+                    lines_processed
+                ),
+                file=sys.stderr,
+            )
 
     def add_raw_counts_from_file(self, filename):
         lines_processed = 0
@@ -145,7 +174,12 @@ class NgramCounts:
                 self.add_raw_counts_from_line(line)
                 lines_processed += 1
         if lines_processed == 0 or args.verbose > 0:
-            print("make_phone_lm.py: processed {0} lines of input".format(lines_processed), file=sys.stderr)
+            print(
+                "make_phone_lm.py: processed {0} lines of input".format(
+                    lines_processed
+                ),
+                file=sys.stderr,
+            )
 
     def cal_discounting_constants(self):
         # For each order N of N-grams, we calculate discounting constant D_N = n1_N / (n1_N + 2 * n2_N),
@@ -153,9 +187,11 @@ class NgramCounts:
         # This constant is used similarly to absolute discounting.
         # Return value: d is a list of floats, where d[N+1] = D_N
 
-        self.d = [0]  # for the lowest order, i.e., 1-gram, we do not need to discount, thus the constant is 0
-                      # This is a special case: as we currently assumed having seen all vocabularies in the dictionary,
-                      # but perhaps this is not the case for some other scenarios.
+        self.d = [
+            0
+        ]  # for the lowest order, i.e., 1-gram, we do not need to discount, thus the constant is 0
+        # This is a special case: as we currently assumed having seen all vocabularies in the dictionary,
+        # but perhaps this is not the case for some other scenarios.
         for n in range(1, self.ngram_order):
             this_order_counts = self.counts[n]
             n1 = 0
@@ -165,9 +201,11 @@ class NgramCounts:
                 n1 += stat[1]
                 n2 += stat[2]
             assert n1 + 2 * n2 > 0
-            self.d.append(max(0.1, n1 * 1.0) / (n1 + 2 * n2))   # We are doing this max(0.001, xxx) to avoid zero discounting constant D due to n1=0, 
-                                                                # which could happen if the number of symbols is small.
-                                                                # Otherwise, zero discounting constant can cause division by zero in computing BOW.
+            self.d.append(
+                max(0.1, n1 * 1.0) / (n1 + 2 * n2)
+            )  # We are doing this max(0.001, xxx) to avoid zero discounting constant D due to n1=0,
+            # which could happen if the number of symbols is small.
+            # Otherwise, zero discounting constant can cause division by zero in computing BOW.
 
     def cal_f(self):
         # f(a_z) is a probability distribution of word sequence a_z.
@@ -182,7 +220,9 @@ class NgramCounts:
         this_order_counts = self.counts[n]
         for hist, counts_for_hist in this_order_counts.items():
             for w, c in counts_for_hist.word_to_count.items():
-                counts_for_hist.word_to_f[w] = max((c - self.d[n]), 0) * 1.0 / counts_for_hist.total_count
+                counts_for_hist.word_to_f[w] = (
+                    max((c - self.d[n]), 0) * 1.0 / counts_for_hist.total_count
+                )
 
         # lower order N-grams
         for n in range(0, self.ngram_order - 1):
@@ -196,11 +236,17 @@ class NgramCounts:
                 if n_star_star != 0:
                     for w in counts_for_hist.word_to_count.keys():
                         n_star_z = len(counts_for_hist.word_to_context[w])
-                        counts_for_hist.word_to_f[w] = max((n_star_z - self.d[n]), 0) * 1.0 / n_star_star
+                        counts_for_hist.word_to_f[w] = (
+                            max((n_star_z - self.d[n]), 0) * 1.0 / n_star_star
+                        )
                 else:  # patterns begin with , they do not have "modified count", so use raw count instead
                     for w in counts_for_hist.word_to_count.keys():
                         n_star_z = counts_for_hist.word_to_count[w]
-                        counts_for_hist.word_to_f[w] = max((n_star_z - self.d[n]), 0) * 1.0 / counts_for_hist.total_count
+                        counts_for_hist.word_to_f[w] = (
+                            max((n_star_z - self.d[n]), 0)
+                            * 1.0
+                            / counts_for_hist.total_count
+                        )
 
     def cal_bow(self):
         # Backoff weights are only necessary for ngrams which form a prefix of a longer ngram.
@@ -240,12 +286,18 @@ class NgramCounts:
                         sum_z1_f_z = 0
                         _ = a_[1:]
                         _counts_for_hist = self.counts[len(_)][_]
-                        for u in a_counts_for_hist.word_to_count.keys():  # Should be careful here: what is Z1
+                        for (
+                            u
+                        ) in (
+                            a_counts_for_hist.word_to_count.keys()
+                        ):  # Should be careful here: what is Z1
                             sum_z1_f_z += _counts_for_hist.word_to_f[u]
 
                         if sum_z1_f_z < 1:
                             # assert sum_z1_f_a_z < 1
-                            counts_for_hist.word_to_bow[w] = (1.0 - sum_z1_f_a_z) / (1.0 - sum_z1_f_z)
+                            counts_for_hist.word_to_bow[w] = (1.0 - sum_z1_f_a_z) / (
+                                1.0 - sum_z1_f_z
+                            )
                         else:
                             counts_for_hist.word_to_bow[w] = None
 
@@ -259,7 +311,9 @@ class NgramCounts:
                     ngram = " ".join(hist) + " " + w
                     ngram = ngram.strip(strip_chars)
 
-                    res.append("{0}\t{1}".format(ngram, counts_for_hist.word_to_count[w]))
+                    res.append(
+                        "{0}\t{1}".format(ngram, counts_for_hist.word_to_count[w])
+                    )
         res.sort(reverse=True)
         for r in res:
             print(r)
@@ -322,27 +376,40 @@ class NgramCounts:
                     if bow is None:
                         res.append("{1}\t{0}".format(ngram, math.log(f, 10)))
                     else:
-                        res.append("{1}\t{0}\t{2}".format(ngram, math.log(f, 10), math.log(bow, 10)))
+                        res.append(
+                            "{1}\t{0}\t{2}".format(
+                                ngram, math.log(f, 10), math.log(bow, 10)
+                            )
+                        )
         res.sort(reverse=True)
         for r in res:
             print(r)
 
-    def print_as_arpa(self, fout=io.TextIOWrapper(sys.stdout.buffer, encoding='latin-1')):
+    def print_as_arpa(
+        self, fout=io.TextIOWrapper(sys.stdout.buffer, encoding="latin-1")
+    ):
         # print as ARPA format.
 
-        print('\\data\\', file=fout)
+        print("\\data\\", file=fout)
         for hist_len in range(self.ngram_order):
             # print the number of n-grams.
-            print('ngram {0}={1}'.format(
-                hist_len + 1,
-                sum([len(counts_for_hist.word_to_f) for counts_for_hist in self.counts[hist_len].values()])),
-                file=fout
+            print(
+                "ngram {0}={1}".format(
+                    hist_len + 1,
+                    sum(
+                        [
+                            len(counts_for_hist.word_to_f)
+                            for counts_for_hist in self.counts[hist_len].values()
+                        ]
+                    ),
+                ),
+                file=fout,
             )
 
-        print('', file=fout)
+        print("", file=fout)
 
         for hist_len in range(self.ngram_order):
-            print('\\{0}-grams:'.format(hist_len + 1), file=fout)
+            print("\\{0}-grams:".format(hist_len + 1), file=fout)
 
             this_order_counts = self.counts[hist_len]
             for hist, counts_for_hist in this_order_counts.items():
@@ -354,12 +421,12 @@ class NgramCounts:
                     if prob == 0:  # f() is always 0
                         prob = 1e-99
 
-                    line = '{0}\t{1}'.format('%.7f' % math.log10(prob), ' '.join(ngram))
+                    line = "{0}\t{1}".format("%.7f" % math.log10(prob), " ".join(ngram))
                     if bow is not None:
-                        line += '\t{0}'.format('%.7f' % math.log10(bow))
+                        line += "\t{0}".format("%.7f" % math.log10(bow))
                     print(line, file=fout)
-            print('', file=fout)
-        print('\\end\\', file=fout)
+            print("", file=fout)
+        print("\\end\\", file=fout)
 
 
 if __name__ == "__main__":
@@ -379,5 +446,5 @@ if __name__ == "__main__":
     if args.lm is None:
         ngram_counts.print_as_arpa()
     else:
-        with open(args.lm, 'w', encoding=default_encoding) as f:
+        with open(args.lm, "w", encoding=default_encoding) as f:
             ngram_counts.print_as_arpa(fout=f)
diff --git a/icefall/utils.py b/icefall/utils.py
index 143c79497..b4d8e9a51 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -130,9 +130,7 @@ def setup_logger(
         formatter = f"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] ({rank}/{world_size}) %(message)s"  # noqa
         log_filename = f"{log_filename}-{date_time}-{rank}"
     else:
-        formatter = (
-            "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-        )
+        formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
         log_filename = f"{log_filename}-{date_time}"
 
     os.makedirs(os.path.dirname(log_filename), exist_ok=True)
@@ -203,7 +201,7 @@ def encode_supervisions(
                 supervisions["num_frames"],
                 subsampling_factor,
                 rounding_mode="floor",
-            )
+            ),
         ),
         1,
     ).to(torch.int32)
@@ -288,13 +286,9 @@ def get_texts_with_timestamp(
     """
     if isinstance(best_paths.aux_labels, k2.RaggedTensor):
         all_aux_shape = (
-            best_paths.arcs.shape()
-            .remove_axis(1)
-            .compose(best_paths.aux_labels.shape)
-        )
-        all_aux_labels = k2.RaggedTensor(
-            all_aux_shape, best_paths.aux_labels.values
+            best_paths.arcs.shape().remove_axis(1).compose(best_paths.aux_labels.shape)
         )
+        all_aux_labels = k2.RaggedTensor(all_aux_shape, best_paths.aux_labels.values)
         # remove 0's and -1's.
         aux_labels = best_paths.aux_labels.remove_values_leq(0)
         # TODO: change arcs.shape() to arcs.shape
@@ -363,9 +357,7 @@ def get_alignments(best_paths: k2.Fsa, kind: str) -> List[List[int]]:
     # arc.shape() has axes [fsa][state][arc], we remove "state"-axis here
     token_shape = best_paths.arcs.shape().remove_axis(1)
     # token_shape has axes [fsa][arc]
-    tokens = k2.RaggedTensor(
-        token_shape, getattr(best_paths, kind).contiguous()
-    )
+    tokens = k2.RaggedTensor(token_shape, getattr(best_paths, kind).contiguous())
     tokens = tokens.remove_values_eq(-1)
     return tokens.tolist()
 
@@ -586,9 +578,7 @@ def write_error_stats(
             f"{cut_id}:\t"
             + " ".join(
                 (
-                    ref_word
-                    if ref_word == hyp_word
-                    else f"({ref_word}->{hyp_word})"
+                    ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
                     for ref_word, hyp_word in ali
                 )
             ),
@@ -598,9 +588,7 @@ def write_error_stats(
     print("", file=f)
     print("SUBSTITUTIONS: count ref -> hyp", file=f)
 
-    for count, (ref, hyp) in sorted(
-        [(v, k) for k, v in subs.items()], reverse=True
-    ):
+    for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
         print(f"{count}   {ref} -> {hyp}", file=f)
 
     print("", file=f)
@@ -614,9 +602,7 @@ def write_error_stats(
         print(f"{count}   {hyp}", file=f)
 
     print("", file=f)
-    print(
-        "PER-WORD STATS: word  corr tot_errs count_in_ref count_in_hyp", file=f
-    )
+    print("PER-WORD STATS: word  corr tot_errs count_in_ref count_in_hyp", file=f)
     for _, word, counts in sorted(
         [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
     ):
@@ -791,9 +777,7 @@ def write_error_stats_with_timestamps(
             f"{cut_id}:\t"
             + " ".join(
                 (
-                    ref_word
-                    if ref_word == hyp_word
-                    else f"({ref_word}->{hyp_word})"
+                    ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
                     for ref_word, hyp_word in ali
                 )
             ),
@@ -803,9 +787,7 @@ def write_error_stats_with_timestamps(
     print("", file=f)
     print("SUBSTITUTIONS: count ref -> hyp", file=f)
 
-    for count, (ref, hyp) in sorted(
-        [(v, k) for k, v in subs.items()], reverse=True
-    ):
+    for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
         print(f"{count}   {ref} -> {hyp}", file=f)
 
     print("", file=f)
@@ -819,9 +801,7 @@ def write_error_stats_with_timestamps(
         print(f"{count}   {hyp}", file=f)
 
     print("", file=f)
-    print(
-        "PER-WORD STATS: word  corr tot_errs count_in_ref count_in_hyp", file=f
-    )
+    print("PER-WORD STATS: word  corr tot_errs count_in_ref count_in_hyp", file=f)
     for _, word, counts in sorted(
         [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
     ):
@@ -891,9 +871,7 @@ class MetricsTracker(collections.defaultdict):
             if k == "frames" or k == "utterances":
                 continue
             norm_value = (
-                float(v) / num_frames
-                if "utt_" not in k
-                else float(v) / num_utterances
+                float(v) / num_frames if "utt_" not in k else float(v) / num_utterances
             )
             ans.append((k, norm_value))
         return ans
@@ -927,9 +905,7 @@ class MetricsTracker(collections.defaultdict):
             tb_writer.add_scalar(prefix + k, v, batch_idx)
 
 
-def concat(
-    ragged: k2.RaggedTensor, value: int, direction: str
-) -> k2.RaggedTensor:
+def concat(ragged: k2.RaggedTensor, value: int, direction: str) -> k2.RaggedTensor:
     """Prepend a value to the beginning of each sublist or append a value.
     to the end of each sublist.
 
@@ -1101,9 +1077,7 @@ def linf_norm(x):
     return torch.max(torch.abs(x))
 
 
-def measure_weight_norms(
-    model: nn.Module, norm: str = "l2"
-) -> Dict[str, float]:
+def measure_weight_norms(model: nn.Module, norm: str = "l2") -> Dict[str, float]:
     """
     Compute the norms of the model's parameters.
 
@@ -1126,9 +1100,7 @@ def measure_weight_norms(
         return norms
 
 
-def measure_gradient_norms(
-    model: nn.Module, norm: str = "l1"
-) -> Dict[str, float]:
+def measure_gradient_norms(model: nn.Module, norm: str = "l1") -> Dict[str, float]:
     """
     Compute the norms of the gradients for each of model's parameters.
 
@@ -1413,9 +1385,7 @@ def parse_hyp_and_timestamp(
         use_word_table = True
 
     for i in range(N):
-        time = convert_timestamp(
-            res.timestamps[i], subsampling_factor, frame_shift_ms
-        )
+        time = convert_timestamp(res.timestamps[i], subsampling_factor, frame_shift_ms)
         if use_word_table:
             words = [word_table[i] for i in res.hyps[i]]
         else:
diff --git a/pyproject.toml b/pyproject.toml
index b4f8c3377..3183055d4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -3,7 +3,7 @@ profile = "black"
 skip = ["icefall/__init__.py"]
 
 [tool.black]
-line-length = 80
+line-length = 88
 exclude = '''
 /(
     \.git
diff --git a/setup.py b/setup.py
index 6c720e121..ccd2503ff 100644
--- a/setup.py
+++ b/setup.py
@@ -1,8 +1,9 @@
 #!/usr/bin/env python3
 
-from setuptools import find_packages, setup
 from pathlib import Path
 
+from setuptools import find_packages, setup
+
 icefall_dir = Path(__file__).parent
 install_requires = (icefall_dir / "requirements.txt").read_text().splitlines()
 
diff --git a/test/test_checkpoint.py b/test/test_checkpoint.py
index 511a11c23..34e829642 100644
--- a/test/test_checkpoint.py
+++ b/test/test_checkpoint.py
@@ -20,11 +20,7 @@ import pytest
 import torch
 import torch.nn as nn
 
-from icefall.checkpoint import (
-    average_checkpoints,
-    load_checkpoint,
-    save_checkpoint,
-)
+from icefall.checkpoint import average_checkpoints, load_checkpoint, save_checkpoint
 
 
 @pytest.fixture
diff --git a/test/test_decode.py b/test/test_decode.py
index 97964ac67..4c2e192a7 100644
--- a/test/test_decode.py
+++ b/test/test_decode.py
@@ -23,6 +23,7 @@ You can run this file in one of the two ways:
 """
 
 import k2
+
 from icefall.decode import Nbest
 
 
diff --git a/test/test_graph_compiler.py b/test/test_graph_compiler.py
index ccfb57d49..10443cf22 100644
--- a/test/test_graph_compiler.py
+++ b/test/test_graph_compiler.py
@@ -154,9 +154,7 @@ class TestCtcTrainingGraphCompiler(object):
         fsas = k2.Fsa.from_fsas([fsa1, fsa2])
 
         decoding_graph = k2.arc_sort(decoding_graph)
-        lattice = k2.intersect(
-            decoding_graph, fsas, treat_epsilons_specially=False
-        )
+        lattice = k2.intersect(decoding_graph, fsas, treat_epsilons_specially=False)
         lattice = k2.connect(lattice)
 
         aux_labels0 = lattice[0].aux_labels[:-1]
diff --git a/test/test_utils.py b/test/test_utils.py
index 6a9ce7853..31f06bd51 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -50,9 +50,7 @@ def test_encode_supervisions(sup):
     assert torch.all(
         torch.eq(
             supervision_segments,
-            torch.tensor(
-                [[1, 0, 30 // 4], [0, 0, 20 // 4], [2, 9 // 4, 10 // 4]]
-            ),
+            torch.tensor([[1, 0, 30 // 4], [0, 0, 20 // 4], [2, 9 // 4, 10 // 4]]),
         )
     )
     assert texts == ["two", "one", "three"]

From 18e3a7a9d59ed4079a6ec53039ef60f2aeeb89f4 Mon Sep 17 00:00:00 2001
From: Desh Raj 
Date: Thu, 17 Nov 2022 09:43:48 -0500
Subject: [PATCH 012/174] add git blame ignore file

---
 .git-blame-ignore-revs | 2 ++
 1 file changed, 2 insertions(+)
 create mode 100644 .git-blame-ignore-revs

diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs
new file mode 100644
index 000000000..be5901517
--- /dev/null
+++ b/.git-blame-ignore-revs
@@ -0,0 +1,2 @@
+# Migrate to 88 characters per line (see: https://github.com/lhotse-speech/lhotse/issues/890)
+107df3b115a58f1b68a6458c3f94a130004be34c

From d31db010371a4128856480382876acdc0d1739ed Mon Sep 17 00:00:00 2001
From: Desh Raj 
Date: Thu, 17 Nov 2022 14:18:05 -0500
Subject: [PATCH 013/174] manual correction of black formatting

---
 .../pruned_transducer_stateless2/decode.py    |  2 +-
 .../pruned_transducer_stateless2/export.py    |  2 +-
 .../pretrained.py                             |  8 ++---
 .../ASR/pruned_transducer_stateless2/train.py |  4 +--
 egs/aishell/ASR/conformer_ctc/pretrained.py   |  6 ++--
 .../pruned_transducer_stateless2/decode.py    |  4 +--
 .../pruned_transducer_stateless2/export.py    |  4 +--
 .../pretrained.py                             |  8 ++---
 .../ASR/pruned_transducer_stateless2/train.py |  6 ++--
 .../pruned_transducer_stateless3/decode.py    |  2 +-
 .../pruned_transducer_stateless3/export.py    |  2 +-
 .../pretrained.py                             |  8 ++---
 .../ASR/pruned_transducer_stateless3/train.py |  6 ++--
 egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py   |  6 ++--
 .../ASR/transducer_stateless/decode.py        |  2 +-
 .../ASR/transducer_stateless/export.py        |  2 +-
 .../ASR/transducer_stateless/pretrained.py    |  8 ++---
 egs/aishell/ASR/transducer_stateless/train.py |  2 +-
 .../transducer_stateless_modified-2/decode.py |  2 +-
 .../transducer_stateless_modified-2/export.py |  2 +-
 .../pretrained.py                             |  8 ++---
 .../transducer_stateless_modified-2/train.py  |  4 +--
 .../transducer_stateless_modified/decode.py   |  2 +-
 .../transducer_stateless_modified/export.py   |  2 +-
 .../pretrained.py                             |  8 ++---
 .../transducer_stateless_modified/train.py    |  2 +-
 .../pruned_transducer_stateless5/decode.py    |  2 +-
 .../pruned_transducer_stateless5/export.py    |  2 +-
 .../pretrained.py                             |  8 ++---
 .../ASR/pruned_transducer_stateless5/train.py |  6 ++--
 .../pruned_transducer_stateless5/decode.py    |  2 +-
 .../pruned_transducer_stateless5/export.py    |  2 +-
 .../pretrained.py                             |  8 ++---
 .../ASR/pruned_transducer_stateless5/train.py |  6 ++--
 .../pruned_transducer_stateless2/decode.py    |  2 +-
 .../pruned_transducer_stateless2/export.py    |  2 +-
 .../pretrained.py                             |  8 ++---
 .../ASR/pruned_transducer_stateless2/train.py |  4 +--
 egs/csj/ASR/local/compute_fbank_csj.py        |  8 +++--
 egs/csj/ASR/local/prepare_lang_char.py        |  6 ++--
 .../ASR/conformer_ctc/asr_datamodule.py       |  2 +-
 .../asr_datamodule.py                         |  2 +-
 .../pruned_transducer_stateless2/decode.py    |  4 +--
 .../pruned_transducer_stateless2/export.py    |  4 +--
 .../ASR/pruned_transducer_stateless2/train.py |  4 +--
 .../ASR/conformer_ctc/pretrained.py           |  6 ++--
 .../decode.py                                 |  2 +-
 .../emformer.py                               |  2 +-
 .../export.py                                 |  2 +-
 .../streaming_decode.py                       |  2 +-
 .../train.py                                  |  4 +--
 .../decode.py                                 |  2 +-
 .../emformer.py                               |  2 +-
 .../export.py                                 |  2 +-
 .../streaming_decode.py                       |  2 +-
 .../train.py                                  |  4 +--
 egs/librispeech/ASR/local/filter_cuts.py      |  4 +--
 .../ASR/local/prepare_lm_training_data.py     |  2 +-
 .../ASR/lstm_transducer_stateless/decode.py   |  2 +-
 .../ASR/lstm_transducer_stateless/export.py   |  2 +-
 .../jit_pretrained.py                         |  6 ++--
 .../lstm_transducer_stateless/pretrained.py   |  8 ++---
 .../streaming_decode.py                       |  2 +-
 .../ASR/lstm_transducer_stateless/train.py    |  6 ++--
 .../ASR/lstm_transducer_stateless2/decode.py  |  2 +-
 .../ASR/lstm_transducer_stateless2/export.py  |  2 +-
 .../jit_pretrained.py                         |  6 ++--
 .../lstm_transducer_stateless2/ncnn-decode.py |  6 ++--
 .../lstm_transducer_stateless2/pretrained.py  |  8 ++---
 .../streaming-ncnn-decode.py                  |  6 ++--
 .../streaming-onnx-decode.py                  |  6 ++--
 .../ASR/lstm_transducer_stateless2/train.py   |  8 ++---
 .../ASR/lstm_transducer_stateless3/decode.py  |  2 +-
 .../ASR/lstm_transducer_stateless3/export.py  |  2 +-
 .../jit_pretrained.py                         |  6 ++--
 .../lstm_transducer_stateless3/pretrained.py  |  8 ++---
 .../streaming_decode.py                       |  2 +-
 .../ASR/lstm_transducer_stateless3/train.py   |  6 ++--
 .../ASR/pruned2_knowledge/asr_datamodule.py   |  2 +-
 .../ASR/pruned2_knowledge/decode.py           |  2 +-
 .../ASR/pruned2_knowledge/export.py           |  2 +-
 .../ASR/pruned2_knowledge/train.py            |  4 +--
 .../pruned_stateless_emformer_rnnt2/decode.py |  2 +-
 .../pruned_stateless_emformer_rnnt2/export.py |  2 +-
 .../pruned_stateless_emformer_rnnt2/train.py  |  6 ++--
 .../ASR/pruned_transducer_stateless/decode.py |  4 +--
 .../ASR/pruned_transducer_stateless/export.py |  2 +-
 .../pruned_transducer_stateless/pretrained.py |  8 ++---
 .../streaming_decode.py                       |  4 +--
 .../ASR/pruned_transducer_stateless/train.py  |  6 ++--
 .../pruned_transducer_stateless2/decode.py    |  4 +--
 .../pruned_transducer_stateless2/export.py    |  4 +--
 .../pretrained.py                             |  8 ++---
 .../streaming_decode.py                       |  4 +--
 .../ASR/pruned_transducer_stateless2/train.py |  8 ++---
 .../decode-giga.py                            |  4 +--
 .../pruned_transducer_stateless3/decode.py    |  6 ++--
 .../pruned_transducer_stateless3/export.py    |  4 +--
 .../jit_pretrained.py                         |  6 ++--
 .../onnx_pretrained.py                        |  6 ++--
 .../pretrained.py                             |  8 ++---
 .../streaming_decode.py                       |  4 +--
 .../ASR/pruned_transducer_stateless3/train.py | 10 +++---
 .../pruned_transducer_stateless4/decode.py    |  2 +-
 .../pruned_transducer_stateless4/export.py    |  2 +-
 .../streaming_decode.py                       |  2 +-
 .../ASR/pruned_transducer_stateless4/train.py |  6 ++--
 .../pruned_transducer_stateless5/decode.py    |  2 +-
 .../pruned_transducer_stateless5/export.py    |  2 +-
 .../pretrained.py                             |  8 ++---
 .../streaming_decode.py                       |  2 +-
 .../ASR/pruned_transducer_stateless5/train.py |  8 ++---
 .../pruned_transducer_stateless6/decode.py    |  2 +-
 .../pruned_transducer_stateless6/export.py    |  4 +--
 .../ASR/pruned_transducer_stateless6/train.py |  6 ++--
 .../pruned_transducer_stateless7/decode.py    |  2 +-
 .../pruned_transducer_stateless7/export.py    |  2 +-
 .../jit_pretrained.py                         |  6 ++--
 .../pretrained.py                             |  8 ++---
 .../ASR/pruned_transducer_stateless7/train.py |  6 ++--
 .../pruned_transducer_stateless8/decode.py    |  2 +-
 .../pruned_transducer_stateless8/export.py    |  2 +-
 .../jit_pretrained.py                         |  6 ++--
 .../pretrained.py                             |  8 ++---
 .../ASR/pruned_transducer_stateless8/train.py |  8 ++---
 .../streaming_decode.py                       |  2 +-
 .../ASR/tdnn_lstm_ctc/asr_datamodule.py       |  2 +-
 .../ASR/tdnn_lstm_ctc/pretrained.py           |  6 ++--
 egs/librispeech/ASR/transducer/pretrained.py  |  6 ++--
 .../ASR/transducer_stateless/compute_ali.py   |  2 +-
 .../ASR/transducer_stateless/decode.py        |  2 +-
 .../ASR/transducer_stateless/export.py        |  2 +-
 .../ASR/transducer_stateless/pretrained.py    |  8 ++---
 .../ASR/transducer_stateless/train.py         |  2 +-
 .../ASR/transducer_stateless2/decode.py       |  2 +-
 .../ASR/transducer_stateless2/export.py       |  2 +-
 .../ASR/transducer_stateless2/pretrained.py   |  8 ++---
 .../ASR/transducer_stateless2/train.py        |  2 +-
 .../decode.py                                 |  2 +-
 .../export.py                                 |  2 +-
 .../pretrained.py                             |  8 ++---
 .../train.py                                  |  4 +--
 .../pruned_transducer_stateless2/decode.py    |  4 +--
 .../pruned_transducer_stateless2/export.py    |  2 +-
 .../ASR/pruned_transducer_stateless2/train.py |  6 ++--
 .../pruned_transducer_stateless5/decode.py    |  2 +-
 .../pruned_transducer_stateless5/export.py    |  2 +-
 .../pretrained.py                             |  8 ++---
 .../ASR/pruned_transducer_stateless5/train.py |  6 ++--
 .../ASR/pruned_transducer_stateless/decode.py |  2 +-
 .../ASR/pruned_transducer_stateless/export.py |  2 +-
 .../pruned_transducer_stateless/pretrained.py |  8 ++---
 .../ASR/pruned_transducer_stateless/train.py  |  4 +--
 .../ASR/transducer_stateless/decode.py        |  2 +-
 .../ASR/transducer_stateless/export.py        |  2 +-
 .../ASR/transducer_stateless/pretrained.py    |  8 ++---
 .../ASR/transducer_stateless/train.py         |  2 +-
 egs/timit/ASR/tdnn_ligru_ctc/pretrained.py    |  6 ++--
 egs/timit/ASR/tdnn_lstm_ctc/pretrained.py     |  6 ++--
 .../pruned_transducer_stateless2/decode.py    |  2 +-
 .../pruned_transducer_stateless2/export.py    |  2 +-
 .../jit_pretrained.py                         |  6 ++--
 .../onnx_pretrained.py                        |  6 ++--
 .../pretrained.py                             |  8 ++---
 .../ASR/pruned_transducer_stateless2/train.py |  4 +--
 .../pruned_transducer_stateless5/decode.py    |  2 +-
 .../pruned_transducer_stateless5/export.py    |  2 +-
 .../pretrained.py                             |  8 ++---
 .../streaming_decode.py                       |  2 +-
 .../ASR/pruned_transducer_stateless5/train.py |  6 ++--
 egs/yesno/ASR/tdnn/pretrained.py              |  6 ++--
 icefall/shared/make_kn_lm.py                  | 34 ++++++++-----------
 172 files changed, 381 insertions(+), 383 deletions(-)

diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py
index b1c7c2839..d0f118959 100755
--- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py
@@ -188,7 +188,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py
index de37ec7e4..e348f7b2b 100644
--- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py
@@ -103,7 +103,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py
index 548b7263c..75c316eaf 100644
--- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py
@@ -162,7 +162,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -192,9 +192,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py
index 322fa6b00..c9d9c4aa8 100644
--- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py
@@ -185,7 +185,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -208,7 +208,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
diff --git a/egs/aishell/ASR/conformer_ctc/pretrained.py b/egs/aishell/ASR/conformer_ctc/pretrained.py
index e0dcb8ad4..66d583396 100755
--- a/egs/aishell/ASR/conformer_ctc/pretrained.py
+++ b/egs/aishell/ASR/conformer_ctc/pretrained.py
@@ -210,9 +210,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/decode.py b/egs/aishell/ASR/pruned_transducer_stateless2/decode.py
index 199acf6c3..20a4f21c7 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless2/decode.py
@@ -184,7 +184,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -487,7 +487,7 @@ def main():
         ]
         if len(filenames) == 0:
             raise ValueError(
-                f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}"
+                f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
             )
         elif len(filenames) < params.avg:
             raise ValueError(
diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/export.py b/egs/aishell/ASR/pruned_transducer_stateless2/export.py
index 4d41e425c..2ce5cfe69 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless2/export.py
@@ -116,7 +116,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     add_model_arguments(parser)
@@ -152,7 +152,7 @@ def main():
         ]
         if len(filenames) == 0:
             raise ValueError(
-                f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}"
+                f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
             )
         elif len(filenames) < params.avg:
             raise ValueError(
diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py
index 8aa0fbdd7..82c10f129 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py
@@ -165,7 +165,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -195,9 +195,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/train.py b/egs/aishell/ASR/pruned_transducer_stateless2/train.py
index f81ab2568..d08908238 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless2/train.py
@@ -200,7 +200,7 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need " "to be changed.",
+        help="The initial learning rate.  This value should not need to be changed.",
     )
 
     parser.add_argument(
@@ -223,7 +223,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -246,7 +246,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/decode.py b/egs/aishell/ASR/pruned_transducer_stateless3/decode.py
index f6c919e9d..bac829ae1 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless3/decode.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/decode.py
@@ -202,7 +202,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/export.py b/egs/aishell/ASR/pruned_transducer_stateless3/export.py
index 5e701c121..7f10eb36e 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless3/export.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/export.py
@@ -132,7 +132,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     add_model_arguments(parser)
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py
index 40926173c..ead393e6e 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py
@@ -165,7 +165,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -195,9 +195,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/train.py b/egs/aishell/ASR/pruned_transducer_stateless3/train.py
index 680986ee9..62e67530d 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless3/train.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/train.py
@@ -222,7 +222,7 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need " "to be changed.",
+        help="The initial learning rate.  This value should not need to be changed.",
     )
 
     parser.add_argument(
@@ -245,7 +245,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=1,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -268,7 +268,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py b/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py
index fe197a9f9..7e7213501 100644
--- a/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py
+++ b/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py
@@ -110,9 +110,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/aishell/ASR/transducer_stateless/decode.py b/egs/aishell/ASR/transducer_stateless/decode.py
index fbc54f68b..e019d2329 100755
--- a/egs/aishell/ASR/transducer_stateless/decode.py
+++ b/egs/aishell/ASR/transducer_stateless/decode.py
@@ -99,7 +99,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
diff --git a/egs/aishell/ASR/transducer_stateless/export.py b/egs/aishell/ASR/transducer_stateless/export.py
index eea9b6883..01de5d772 100755
--- a/egs/aishell/ASR/transducer_stateless/export.py
+++ b/egs/aishell/ASR/transducer_stateless/export.py
@@ -110,7 +110,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
diff --git a/egs/aishell/ASR/transducer_stateless/pretrained.py b/egs/aishell/ASR/transducer_stateless/pretrained.py
index b03a2643a..40f430e13 100755
--- a/egs/aishell/ASR/transducer_stateless/pretrained.py
+++ b/egs/aishell/ASR/transducer_stateless/pretrained.py
@@ -117,7 +117,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -210,9 +210,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/aishell/ASR/transducer_stateless/train.py b/egs/aishell/ASR/transducer_stateless/train.py
index 4ea902507..62ffff473 100755
--- a/egs/aishell/ASR/transducer_stateless/train.py
+++ b/egs/aishell/ASR/transducer_stateless/train.py
@@ -126,7 +126,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/decode.py b/egs/aishell/ASR/transducer_stateless_modified-2/decode.py
index cb206af6d..41cc1c01c 100755
--- a/egs/aishell/ASR/transducer_stateless_modified-2/decode.py
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/decode.py
@@ -170,7 +170,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/export.py b/egs/aishell/ASR/transducer_stateless_modified-2/export.py
index 3c56d4a01..c1081c32b 100755
--- a/egs/aishell/ASR/transducer_stateless_modified-2/export.py
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/export.py
@@ -109,7 +109,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py b/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py
index d8c0c5fcd..5d8ca2e11 100755
--- a/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py
@@ -165,7 +165,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -193,9 +193,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/train.py b/egs/aishell/ASR/transducer_stateless_modified-2/train.py
index a9a30d7f7..8fb7d1e49 100755
--- a/egs/aishell/ASR/transducer_stateless_modified-2/train.py
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/train.py
@@ -149,7 +149,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -167,7 +167,7 @@ def get_parser():
         "--datatang-prob",
         type=float,
         default=0.2,
-        help="The probability to select a batch from the " "aidatatang_200zh dataset",
+        help="The probability to select a batch from the aidatatang_200zh dataset",
     )
 
     return parser
diff --git a/egs/aishell/ASR/transducer_stateless_modified/decode.py b/egs/aishell/ASR/transducer_stateless_modified/decode.py
index ba3cb3218..7c06e6e51 100755
--- a/egs/aishell/ASR/transducer_stateless_modified/decode.py
+++ b/egs/aishell/ASR/transducer_stateless_modified/decode.py
@@ -171,7 +171,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
diff --git a/egs/aishell/ASR/transducer_stateless_modified/export.py b/egs/aishell/ASR/transducer_stateless_modified/export.py
index cbdbdbeb6..3e14ad69c 100755
--- a/egs/aishell/ASR/transducer_stateless_modified/export.py
+++ b/egs/aishell/ASR/transducer_stateless_modified/export.py
@@ -109,7 +109,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
diff --git a/egs/aishell/ASR/transducer_stateless_modified/pretrained.py b/egs/aishell/ASR/transducer_stateless_modified/pretrained.py
index 7dfa92a3c..9e4459247 100755
--- a/egs/aishell/ASR/transducer_stateless_modified/pretrained.py
+++ b/egs/aishell/ASR/transducer_stateless_modified/pretrained.py
@@ -165,7 +165,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -193,9 +193,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/aishell/ASR/transducer_stateless_modified/train.py b/egs/aishell/ASR/transducer_stateless_modified/train.py
index c4bf4dd56..5f116f2bd 100755
--- a/egs/aishell/ASR/transducer_stateless_modified/train.py
+++ b/egs/aishell/ASR/transducer_stateless_modified/train.py
@@ -142,7 +142,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py b/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py
index 7900c5883..b5da0959b 100755
--- a/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py
@@ -269,7 +269,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/export.py b/egs/aishell2/ASR/pruned_transducer_stateless5/export.py
index ea4a8d4f9..8a5be94d0 100755
--- a/egs/aishell2/ASR/pruned_transducer_stateless5/export.py
+++ b/egs/aishell2/ASR/pruned_transducer_stateless5/export.py
@@ -133,7 +133,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     add_model_arguments(parser)
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py b/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py
index 94536fa6f..bc3ae7abf 100755
--- a/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py
+++ b/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py
@@ -159,7 +159,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -190,9 +190,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/train.py b/egs/aishell2/ASR/pruned_transducer_stateless5/train.py
index 4a228113d..74bf68ccb 100755
--- a/egs/aishell2/ASR/pruned_transducer_stateless5/train.py
+++ b/egs/aishell2/ASR/pruned_transducer_stateless5/train.py
@@ -218,7 +218,7 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need " "to be changed.",
+        help="The initial learning rate.  This value should not need to be changed.",
     )
 
     parser.add_argument(
@@ -241,7 +241,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -264,7 +264,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py b/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py
index cb533df35..37d766ec8 100755
--- a/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py
@@ -201,7 +201,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/export.py b/egs/aishell4/ASR/pruned_transducer_stateless5/export.py
index cc9b7b444..bf9856c60 100755
--- a/egs/aishell4/ASR/pruned_transducer_stateless5/export.py
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/export.py
@@ -136,7 +136,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     add_model_arguments(parser)
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py b/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py
index a234f9d65..ee898c303 100755
--- a/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py
@@ -172,7 +172,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -203,9 +203,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/train.py b/egs/aishell4/ASR/pruned_transducer_stateless5/train.py
index 73ee34284..d7c69f226 100755
--- a/egs/aishell4/ASR/pruned_transducer_stateless5/train.py
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/train.py
@@ -211,7 +211,7 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need " "to be changed.",
+        help="The initial learning rate.  This value should not need to be changed.",
     )
 
     parser.add_argument(
@@ -234,7 +234,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -257,7 +257,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py
index f3b63b222..e4a90ef71 100755
--- a/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py
@@ -189,7 +189,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py
index 538853f67..8e5cc6075 100644
--- a/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py
@@ -103,7 +103,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py
index 4da8d8e14..f5a0dd8c8 100644
--- a/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py
+++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py
@@ -162,7 +162,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -192,9 +192,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py
index c9d2f3cb9..e57b5c859 100644
--- a/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py
@@ -185,7 +185,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -208,7 +208,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
diff --git a/egs/csj/ASR/local/compute_fbank_csj.py b/egs/csj/ASR/local/compute_fbank_csj.py
index c248aa668..667ad427e 100644
--- a/egs/csj/ASR/local/compute_fbank_csj.py
+++ b/egs/csj/ASR/local/compute_fbank_csj.py
@@ -25,7 +25,9 @@ from random import Random
 from typing import List, Tuple
 
 import torch
-from lhotse import (  # fmt: off; See the following for why LilcomChunkyWriter is preferred; https://github.com/k2-fsa/icefall/pull/404; https://github.com/lhotse-speech/lhotse/pull/527; fmt: on
+
+# fmt: off
+from lhotse import (  # See the following for why LilcomChunkyWriter is preferred; https://github.com/k2-fsa/icefall/pull/404; https://github.com/lhotse-speech/lhotse/pull/527
     CutSet,
     Fbank,
     FbankConfig,
@@ -34,6 +36,8 @@ from lhotse import (  # fmt: off; See the following for why LilcomChunkyWriter i
     SupervisionSet,
 )
 
+# fmt: on
+
 ARGPARSE_DESCRIPTION = """
 This script follows the espnet method of splitting the remaining core+noncore
 utterances into valid and train cutsets at an index which is by default 4000.
@@ -92,7 +96,7 @@ def make_cutset_blueprints(
     cut_set = cut_set.shuffle(Random(RNG_SEED))
 
     logging.info(
-        "Creating valid and train cuts from core and noncore," f"split at {split}."
+        "Creating valid and train cuts from core and noncore, split at {split}."
     )
     valid_set = CutSet.from_cuts(islice(cut_set, 0, split))
 
diff --git a/egs/csj/ASR/local/prepare_lang_char.py b/egs/csj/ASR/local/prepare_lang_char.py
index ef91f6e43..16107f543 100644
--- a/egs/csj/ASR/local/prepare_lang_char.py
+++ b/egs/csj/ASR/local/prepare_lang_char.py
@@ -87,7 +87,7 @@ def main():
     args = get_args()
 
     logging.basicConfig(
-        format=("%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] " "%(message)s"),
+        format=("%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"),
         level=logging.INFO,
     )
 
@@ -109,7 +109,7 @@ def main():
 
     words = set()
     logging.info(
-        f"Creating vocabulary from {args.train_cut.name}" f" at {args.trans_mode} mode."
+        f"Creating vocabulary from {args.train_cut.name} at {args.trans_mode} mode."
     )
     for cut in train_set:
         try:
@@ -120,7 +120,7 @@ def main():
             )
         except KeyError:
             raise KeyError(
-                f"Could not find {args.trans_mode} in " f"{cut.supervisions[0].custom}"
+                f"Could not find {args.trans_mode} in {cut.supervisions[0].custom}"
             )
         for t in text.split():
             if t in args.userdef_string:
diff --git a/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py b/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py
index 72dcd772a..9437c935c 100644
--- a/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py
+++ b/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py
@@ -183,7 +183,7 @@ class GigaSpeechAsrDataModule:
             "--small-dev",
             type=str2bool,
             default=False,
-            help="Should we use only 1000 utterances for dev " "(speeds up training)",
+            help="Should we use only 1000 utterances for dev (speeds up training)",
         )
 
     def train_dataloaders(self, cuts_train: CutSet) -> DataLoader:
diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
index 7f114fba6..5c01d7190 100644
--- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
+++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
@@ -195,7 +195,7 @@ class GigaSpeechAsrDataModule:
             "--small-dev",
             type=str2bool,
             default=False,
-            help="Should we use only 1000 utterances for dev " "(speeds up training)",
+            help="Should we use only 1000 utterances for dev (speeds up training)",
         )
 
     def train_dataloaders(
diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py
index c0b17750e..8595c27bd 100755
--- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py
@@ -184,7 +184,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -498,7 +498,7 @@ def main():
         ]
         if len(filenames) == 0:
             raise ValueError(
-                f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}"
+                f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
             )
         elif len(filenames) < params.avg:
             raise ValueError(
diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py
index 3d1e7bc18..b6190e8a6 100755
--- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py
@@ -116,7 +116,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
@@ -155,7 +155,7 @@ def main():
         ]
         if len(filenames) == 0:
             raise ValueError(
-                f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}"
+                f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
             )
         elif len(filenames) < params.avg:
             raise ValueError(
diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py
index f51584120..9edc42b61 100755
--- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py
@@ -176,7 +176,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -199,7 +199,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
diff --git a/egs/librispeech/ASR/conformer_ctc/pretrained.py b/egs/librispeech/ASR/conformer_ctc/pretrained.py
index 8200af866..30def9c40 100755
--- a/egs/librispeech/ASR/conformer_ctc/pretrained.py
+++ b/egs/librispeech/ASR/conformer_ctc/pretrained.py
@@ -236,9 +236,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py
index 6854c82d8..365e8b8a7 100755
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py
@@ -215,7 +215,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py
index 1aaa3b9cb..91f50cf67 100644
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py
@@ -445,7 +445,7 @@ class EmformerAttention(nn.Module):
 
         if embed_dim % nhead != 0:
             raise ValueError(
-                f"embed_dim ({embed_dim}) is not a multiple of" f"nhead ({nhead})."
+                f"embed_dim ({embed_dim}) is not a multiple of nhead ({nhead})."
             )
 
         self.embed_dim = embed_dim
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py
index 334682ad6..09a3e96b0 100755
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py
@@ -136,7 +136,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py
index 621eeb952..c93125c80 100755
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py
@@ -211,7 +211,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py
index 3d8d4a18a..213115854 100755
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py
@@ -263,7 +263,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -286,7 +286,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py
index d3c001942..78e1f4096 100755
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py
@@ -215,7 +215,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py
index c3739566f..3cedf99b6 100644
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py
@@ -445,7 +445,7 @@ class EmformerAttention(nn.Module):
 
         if embed_dim % nhead != 0:
             raise ValueError(
-                f"embed_dim ({embed_dim}) is not a multiple of" f"nhead ({nhead})."
+                f"embed_dim ({embed_dim}) is not a multiple of nhead ({nhead})."
             )
 
         self.embed_dim = embed_dim
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py
index 998fb6e81..949214aec 100755
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py
@@ -136,7 +136,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py
index 618d8bb63..b2cb2c96b 100755
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py
@@ -211,7 +211,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py
index 542f524a9..6a019fd63 100755
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py
@@ -263,7 +263,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -286,7 +286,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
diff --git a/egs/librispeech/ASR/local/filter_cuts.py b/egs/librispeech/ASR/local/filter_cuts.py
index b3f0956c3..fbcc9e24a 100644
--- a/egs/librispeech/ASR/local/filter_cuts.py
+++ b/egs/librispeech/ASR/local/filter_cuts.py
@@ -79,7 +79,7 @@ def filter_cuts(cut_set: CutSet, sp: spm.SentencePieceProcessor):
         total += 1
         if c.duration < 1.0 or c.duration > 20.0:
             logging.warning(
-                f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}"
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
             )
             removed += 1
             return False
@@ -124,7 +124,7 @@ def filter_cuts(cut_set: CutSet, sp: spm.SentencePieceProcessor):
     ans = cut_set.filter(remove_short_and_long_utterances).to_eager()
     ratio = removed / total * 100
     logging.info(
-        f"Removed {removed} cuts from {total} cuts. " f"{ratio:.3f}% data is removed."
+        f"Removed {removed} cuts from {total} cuts. {ratio:.3f}% data is removed."
     )
     return ans
 
diff --git a/egs/librispeech/ASR/local/prepare_lm_training_data.py b/egs/librispeech/ASR/local/prepare_lm_training_data.py
index 32ae8c580..70343fef7 100755
--- a/egs/librispeech/ASR/local/prepare_lm_training_data.py
+++ b/egs/librispeech/ASR/local/prepare_lm_training_data.py
@@ -137,7 +137,7 @@ def main():
     for i in range(num_sentences):
         if step and i % step == 0:
             logging.info(
-                f"Processed number of lines: {i} " f"({i/num_sentences*100: .3f}%)"
+                f"Processed number of lines: {i} ({i/num_sentences*100: .3f}%)"
             )
 
         word_ids = sentences[i]
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless/decode.py
index 79b21fab1..3ad08f56a 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless/decode.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless/decode.py
@@ -272,7 +272,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/export.py b/egs/librispeech/ASR/lstm_transducer_stateless/export.py
index 45fa6d662..e338342cc 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless/export.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless/export.py
@@ -172,7 +172,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     add_model_arguments(parser)
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py
index 51f4a2e8a..c07956243 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py
@@ -123,9 +123,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py
index 9263b41b2..b3a34a9e3 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py
@@ -166,7 +166,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -197,9 +197,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py
index 4cc2aabb2..961d8ddfb 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py
@@ -199,7 +199,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/train.py b/egs/librispeech/ASR/lstm_transducer_stateless/train.py
index b9a68753e..a54108f6d 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless/train.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless/train.py
@@ -220,7 +220,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -243,7 +243,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
@@ -970,7 +970,7 @@ def run(rank, world_size, args):
         # the threshold
         if c.duration < 1.0 or c.duration > 20.0:
             logging.warning(
-                f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}"
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
             )
             return False
 
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py
index 41602d207..69f695fef 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py
@@ -295,7 +295,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export.py
index 2a25cb46a..5977cb36d 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless2/export.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export.py
@@ -225,7 +225,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     add_model_arguments(parser)
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py
index 40f11018f..728b09104 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py
@@ -124,9 +124,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py
index ab2f17480..3b471fa85 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py
@@ -198,9 +198,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py
index 2983328bf..f3f272b9f 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py
@@ -169,7 +169,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -200,9 +200,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py
index a787a00e6..baff15ea6 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py
@@ -186,9 +186,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py
index e896fd510..34d2e5630 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py
@@ -147,9 +147,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py
index 056285c64..8736384b4 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py
@@ -161,7 +161,7 @@ def get_parser():
         "--full-libri",
         type=str2bool,
         default=True,
-        help="When enabled, use 960h LibriSpeech. " "Otherwise, use 100h subset.",
+        help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.",
     )
 
     parser.add_argument(
@@ -235,7 +235,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -258,7 +258,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
@@ -986,7 +986,7 @@ def filter_short_and_long_utterances(
         # the threshold
         if c.duration < 1.0 or c.duration > 20.0:
             logging.warning(
-                f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}"
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
             )
             return False
 
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py
index cba1ac689..b7953e5e3 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py
@@ -290,7 +290,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/export.py b/egs/librispeech/ASR/lstm_transducer_stateless3/export.py
index 457bd472f..a82cad043 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless3/export.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless3/export.py
@@ -172,7 +172,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     add_model_arguments(parser)
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py
index 71b37ac55..237591a36 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py
@@ -123,9 +123,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py
index e72f4ee42..f49e9c518 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py
@@ -166,7 +166,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -197,9 +197,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py b/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py
index dad6b905f..109746ed5 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py
@@ -199,7 +199,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py
index 97ca4b94c..f56b4fd83 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py
@@ -230,7 +230,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -253,7 +253,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
@@ -987,7 +987,7 @@ def run(rank, world_size, args):
         # the threshold
         if c.duration < 1.0 or c.duration > 20.0:
             logging.warning(
-                f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}"
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
             )
             return False
 
diff --git a/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py b/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py
index 3dc9164f8..b839a4a4c 100644
--- a/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py
+++ b/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py
@@ -83,7 +83,7 @@ class LibriSpeechAsrDataModule:
             "--full-libri",
             type=str2bool,
             default=True,
-            help="When enabled, use 960h LibriSpeech. " "Otherwise, use 100h subset.",
+            help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.",
         )
         group.add_argument(
             "--manifest-dir",
diff --git a/egs/librispeech/ASR/pruned2_knowledge/decode.py b/egs/librispeech/ASR/pruned2_knowledge/decode.py
index c3e7b01ab..40d14bb5a 100755
--- a/egs/librispeech/ASR/pruned2_knowledge/decode.py
+++ b/egs/librispeech/ASR/pruned2_knowledge/decode.py
@@ -182,7 +182,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
diff --git a/egs/librispeech/ASR/pruned2_knowledge/export.py b/egs/librispeech/ASR/pruned2_knowledge/export.py
index ce5f162bf..51020aa30 100755
--- a/egs/librispeech/ASR/pruned2_knowledge/export.py
+++ b/egs/librispeech/ASR/pruned2_knowledge/export.py
@@ -105,7 +105,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
diff --git a/egs/librispeech/ASR/pruned2_knowledge/train.py b/egs/librispeech/ASR/pruned2_knowledge/train.py
index c322abaf8..123d448bb 100755
--- a/egs/librispeech/ASR/pruned2_knowledge/train.py
+++ b/egs/librispeech/ASR/pruned2_knowledge/train.py
@@ -177,7 +177,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -200,7 +200,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py
index 891719f3d..0e3b7ff74 100755
--- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py
+++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py
@@ -204,7 +204,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py
index 047a1d476..3612a2bfd 100755
--- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py
+++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py
@@ -133,7 +133,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     add_model_arguments(parser)
diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py
index 69e74cc57..ed3fa1521 100755
--- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py
+++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py
@@ -209,7 +209,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -232,7 +232,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
@@ -898,7 +898,7 @@ def run(rank, world_size, args):
         # the threshold
         if c.duration < 1.0 or c.duration > 20.0:
             logging.warning(
-                f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}"
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
             )
             return False
 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py
index 12bd7f9bb..0444afe40 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py
@@ -265,7 +265,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -703,7 +703,7 @@ def main():
         ]
         if len(filenames) == 0:
             raise ValueError(
-                f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}"
+                f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
             )
         elif len(filenames) < params.avg:
             raise ValueError(
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/export.py b/egs/librispeech/ASR/pruned_transducer_stateless/export.py
index be45536d8..a19f9ab9a 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless/export.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/export.py
@@ -105,7 +105,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py
index 6e91e0501..2ed1725b4 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py
@@ -168,7 +168,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -220,9 +220,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py
index ce8e2f348..fbc39fb65 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py
@@ -158,7 +158,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -519,7 +519,7 @@ def main():
         ]
         if len(filenames) == 0:
             raise ValueError(
-                f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}"
+                f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
             )
         elif len(filenames) < params.avg:
             raise ValueError(
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/train.py b/egs/librispeech/ASR/pruned_transducer_stateless/train.py
index 7861df874..4dabbccc1 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/train.py
@@ -203,7 +203,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -226,7 +226,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
@@ -889,7 +889,7 @@ def run(rank, world_size, args):
         # the threshold
         if c.duration < 1.0 or c.duration > 20.0:
             logging.warning(
-                f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}"
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
             )
             return False
 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py
index 92138a5ea..5f135f219 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py
@@ -271,7 +271,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -725,7 +725,7 @@ def main():
         ]
         if len(filenames) == 0:
             raise ValueError(
-                f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}"
+                f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
             )
         elif len(filenames) < params.avg:
             raise ValueError(
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py
index 4f1170bbc..984caf5f2 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py
@@ -116,7 +116,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -168,7 +168,7 @@ def main():
         ]
         if len(filenames) == 0:
             raise ValueError(
-                f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}"
+                f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
             )
         elif len(filenames) < params.avg:
             raise ValueError(
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py
index e5b5aeba5..013964720 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py
@@ -168,7 +168,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -221,9 +221,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py
index 0eea3a782..bb08246d9 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py
@@ -158,7 +158,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -522,7 +522,7 @@ def main():
         ]
         if len(filenames) == 0:
             raise ValueError(
-                f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}"
+                f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
             )
         elif len(filenames) < params.avg:
             raise ValueError(
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
index f6702ef16..86333fc97 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
@@ -208,7 +208,7 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need to " "be changed.",
+        help="The initial learning rate.  This value should not need to be changed.",
     )
 
     parser.add_argument(
@@ -231,7 +231,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -254,7 +254,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
@@ -947,7 +947,7 @@ def run(rank, world_size, args):
         # the threshold
         if c.duration < 1.0 or c.duration > 20.0:
             logging.warning(
-                f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}"
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
             )
             return False
 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py
index df24d9585..b4804ecde 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py
@@ -188,7 +188,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -552,7 +552,7 @@ def main():
         ]
         if len(filenames) == 0:
             raise ValueError(
-                f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}"
+                f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
             )
         elif len(filenames) < params.avg:
             raise ValueError(
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py
index 55585e08c..03137501f 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py
@@ -261,7 +261,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -681,7 +681,7 @@ def decode_one_batch(
         return {key: hyps}
     else:
         return {
-            (f"beam_size_{params.beam_size}_" f"temperature_{params.temperature}"): hyps
+            (f"beam_size_{params.beam_size}_temperature_{params.temperature}"): hyps
         }
 
 
@@ -963,7 +963,7 @@ def main():
         ]
         if len(filenames) == 0:
             raise ValueError(
-                f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}"
+                f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
             )
         elif len(filenames) < params.avg:
             raise ValueError(
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py
index 2e444353c..239bdc12f 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py
@@ -231,7 +231,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -607,7 +607,7 @@ def main():
         ]
         if len(filenames) == 0:
             raise ValueError(
-                f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}"
+                f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
             )
         elif len(filenames) < params.avg:
             raise ValueError(
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py
index 86cb45c09..0669284b3 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py
@@ -142,9 +142,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py
index 825c6510b..550cf6aad 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py
@@ -140,9 +140,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py
index 77bd6d13d..7c3dfc660 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py
@@ -177,7 +177,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -230,9 +230,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py
index e85d2060a..0e5111f33 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py
@@ -159,7 +159,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -521,7 +521,7 @@ def main():
         ]
         if len(filenames) == 0:
             raise ValueError(
-                f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}"
+                f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
             )
         elif len(filenames) < params.avg:
             raise ValueError(
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py
index e9ceb60de..281ba4650 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py
@@ -161,7 +161,7 @@ def get_parser():
         "--full-libri",
         type=str2bool,
         default=True,
-        help="When enabled, use 960h LibriSpeech. " "Otherwise, use 100h subset.",
+        help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.",
     )
 
     parser.add_argument(
@@ -211,7 +211,7 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need " "to be changed.",
+        help="The initial learning rate.  This value should not need to be changed.",
     )
 
     parser.add_argument(
@@ -234,7 +234,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -257,7 +257,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
@@ -950,7 +950,7 @@ def filter_short_and_long_utterances(
         # the threshold
         if c.duration < 1.0 or c.duration > 20.0:
             logging.warning(
-                f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}"
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
             )
             return False
 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
index 2f9a60f13..f5cbc21f7 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
@@ -306,7 +306,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/export.py b/egs/librispeech/ASR/pruned_transducer_stateless4/export.py
index 64ef89733..401b3ef3a 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless4/export.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/export.py
@@ -133,7 +133,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py
index d74d1c89d..c4e3cef16 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py
@@ -175,7 +175,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py
index 97f3e56a9..cb56c8294 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py
@@ -237,7 +237,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -260,7 +260,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
@@ -994,7 +994,7 @@ def run(rank, world_size, args):
         # the threshold
         if c.duration < 1.0 or c.duration > 20.0:
             logging.warning(
-                f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}"
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
             )
             return False
 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py
index 5c76afde6..8b993f638 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py
@@ -303,7 +303,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/export.py b/egs/librispeech/ASR/pruned_transducer_stateless5/export.py
index f0bfd3b4c..a4fad1e59 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless5/export.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless5/export.py
@@ -133,7 +133,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py
index 77ba0873b..74a2210c3 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py
@@ -166,7 +166,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -197,9 +197,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py
index e750f5554..064811f1c 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py
@@ -175,7 +175,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py
index a1a810d3e..436620744 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py
@@ -246,7 +246,7 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need " "to be changed.",
+        help="The initial learning rate.  This value should not need to be changed.",
     )
 
     parser.add_argument(
@@ -269,7 +269,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -292,7 +292,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
@@ -1025,7 +1025,7 @@ def run(rank, world_size, args):
         # the threshold
         if c.duration < 1.0 or c.duration > 20.0:
             logging.warning(
-                f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}"
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
             )
             return False
 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py
index 3734564fe..fd9de052a 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py
@@ -208,7 +208,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/export.py b/egs/librispeech/ASR/pruned_transducer_stateless6/export.py
index 3d1e7bc18..b6190e8a6 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless6/export.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless6/export.py
@@ -116,7 +116,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
@@ -155,7 +155,7 @@ def main():
         ]
         if len(filenames) == 0:
             raise ValueError(
-                f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}"
+                f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
             )
         elif len(filenames) < params.avg:
             raise ValueError(
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py
index a24becb14..8f4d3b879 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py
@@ -201,7 +201,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -224,7 +224,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
@@ -986,7 +986,7 @@ def run(rank, world_size, args):
         # the threshold
         if c.duration < 1.0 or c.duration > 20.0:
             logging.warning(
-                f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}"
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
             )
             return False
 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py
index 162966df8..bc15948fc 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py
@@ -272,7 +272,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7/export.py
index 57af52fb1..9a6f3ed37 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/export.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/export.py
@@ -176,7 +176,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     add_model_arguments(parser)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py
index f469442ed..5af6dae25 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py
@@ -93,9 +93,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py
index 758e0c036..d05bafcfb 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py
@@ -177,7 +177,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -208,9 +208,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index 7160fc54a..b27c573ab 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -267,7 +267,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -290,7 +290,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
@@ -1031,7 +1031,7 @@ def run(rank, world_size, args):
         # the threshold
         if c.duration < 1.0 or c.duration > 20.0:
             logging.warning(
-                f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}"
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
             )
             return False
 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py
index 3d89ae00a..e61367134 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py
@@ -273,7 +273,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/export.py b/egs/librispeech/ASR/pruned_transducer_stateless8/export.py
index 0a962149d..d4a228b47 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless8/export.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless8/export.py
@@ -176,7 +176,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     add_model_arguments(parser)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py
index c458ee5a9..129497d5a 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py
@@ -93,9 +93,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py
index f1f0771ef..486d9d74e 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py
@@ -177,7 +177,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -208,9 +208,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py
index ba8ed3ea8..abe249c7b 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py
@@ -212,7 +212,7 @@ def get_parser():
         "--full-libri",
         type=str2bool,
         default=True,
-        help="When enabled, use 960h LibriSpeech. " "Otherwise, use 100h subset.",
+        help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.",
     )
 
     parser.add_argument(
@@ -282,7 +282,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -305,7 +305,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
@@ -1030,7 +1030,7 @@ def filter_short_and_long_utterances(
         # the threshold
         if c.duration < 1.0 or c.duration > 20.0:
             logging.warning(
-                f"Exclude cut with ID {c.id} from training. " f"Duration: {c.duration}"
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
             )
             return False
 
diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py b/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py
index 3965bd5c3..a26d0b789 100755
--- a/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py
+++ b/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py
@@ -87,7 +87,7 @@ def get_parser():
         "--tailing-num-frames",
         type=int,
         default=20,
-        help="tailing dummy frames padded to the right," "only used during decoding",
+        help="tailing dummy frames padded to the right, only used during decoding",
     )
 
     parser.add_argument(
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
index 993a7cab5..95d1b273a 100644
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -86,7 +86,7 @@ class LibriSpeechAsrDataModule:
             "--full-libri",
             type=str2bool,
             default=True,
-            help="When enabled, use 960h LibriSpeech. " "Otherwise, use 100h subset.",
+            help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.",
         )
         group.add_argument(
             "--manifest-dir",
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py
index addadbe4e..fde724866 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py
@@ -138,9 +138,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/librispeech/ASR/transducer/pretrained.py b/egs/librispeech/ASR/transducer/pretrained.py
index b1ff7b2b1..511610245 100755
--- a/egs/librispeech/ASR/transducer/pretrained.py
+++ b/egs/librispeech/ASR/transducer/pretrained.py
@@ -188,9 +188,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/librispeech/ASR/transducer_stateless/compute_ali.py b/egs/librispeech/ASR/transducer_stateless/compute_ali.py
index c91198bb9..f479389df 100755
--- a/egs/librispeech/ASR/transducer_stateless/compute_ali.py
+++ b/egs/librispeech/ASR/transducer_stateless/compute_ali.py
@@ -124,7 +124,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
diff --git a/egs/librispeech/ASR/transducer_stateless/decode.py b/egs/librispeech/ASR/transducer_stateless/decode.py
index 688e214c8..643238f1b 100755
--- a/egs/librispeech/ASR/transducer_stateless/decode.py
+++ b/egs/librispeech/ASR/transducer_stateless/decode.py
@@ -171,7 +171,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
diff --git a/egs/librispeech/ASR/transducer_stateless/export.py b/egs/librispeech/ASR/transducer_stateless/export.py
index c617e6c4c..89359f1a4 100755
--- a/egs/librispeech/ASR/transducer_stateless/export.py
+++ b/egs/librispeech/ASR/transducer_stateless/export.py
@@ -109,7 +109,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
diff --git a/egs/librispeech/ASR/transducer_stateless/pretrained.py b/egs/librispeech/ASR/transducer_stateless/pretrained.py
index c393974e6..915a6069d 100755
--- a/egs/librispeech/ASR/transducer_stateless/pretrained.py
+++ b/egs/librispeech/ASR/transducer_stateless/pretrained.py
@@ -167,7 +167,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -196,9 +196,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py
index c86125f44..bcb883fa5 100755
--- a/egs/librispeech/ASR/transducer_stateless/train.py
+++ b/egs/librispeech/ASR/transducer_stateless/train.py
@@ -136,7 +136,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
diff --git a/egs/librispeech/ASR/transducer_stateless2/decode.py b/egs/librispeech/ASR/transducer_stateless2/decode.py
index c642b16bd..9a6363629 100755
--- a/egs/librispeech/ASR/transducer_stateless2/decode.py
+++ b/egs/librispeech/ASR/transducer_stateless2/decode.py
@@ -171,7 +171,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
diff --git a/egs/librispeech/ASR/transducer_stateless2/export.py b/egs/librispeech/ASR/transducer_stateless2/export.py
index 229c514b9..d33d02642 100755
--- a/egs/librispeech/ASR/transducer_stateless2/export.py
+++ b/egs/librispeech/ASR/transducer_stateless2/export.py
@@ -104,7 +104,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
diff --git a/egs/librispeech/ASR/transducer_stateless2/pretrained.py b/egs/librispeech/ASR/transducer_stateless2/pretrained.py
index 9053bc6e0..0738f30c0 100755
--- a/egs/librispeech/ASR/transducer_stateless2/pretrained.py
+++ b/egs/librispeech/ASR/transducer_stateless2/pretrained.py
@@ -167,7 +167,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -196,9 +196,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/librispeech/ASR/transducer_stateless2/train.py b/egs/librispeech/ASR/transducer_stateless2/train.py
index 71c9c5df7..68e247f23 100755
--- a/egs/librispeech/ASR/transducer_stateless2/train.py
+++ b/egs/librispeech/ASR/transducer_stateless2/train.py
@@ -136,7 +136,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py
index 253821028..56ad558c6 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py
@@ -172,7 +172,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py
index 97b0eea4a..3735ef452 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py
@@ -110,7 +110,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py
index c698a35b0..8c7726367 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py
@@ -167,7 +167,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -196,9 +196,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py
index e5b7dc390..88987d91c 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py
@@ -114,7 +114,7 @@ def get_parser():
         "--full-libri",
         type=str2bool,
         default=True,
-        help="When enabled, use 960h LibriSpeech. " "Otherwise, use 100h subset.",
+        help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.",
     )
 
     parser.add_argument(
@@ -169,7 +169,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py
index 098da3ff0..219c96d60 100755
--- a/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py
@@ -183,7 +183,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -505,7 +505,7 @@ def main():
         ]
         if len(filenames) == 0:
             raise ValueError(
-                f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}"
+                f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
             )
         elif len(filenames) < params.avg:
             raise ValueError(
diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py
index e79cb300d..68763808a 100755
--- a/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py
@@ -115,7 +115,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py
index 213635894..d943180b1 100755
--- a/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py
@@ -153,7 +153,7 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need to be " "changed.",
+        help="The initial learning rate.  This value should not need to be changed.",
     )
 
     parser.add_argument(
@@ -176,7 +176,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -199,7 +199,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py
index 82e1a9437..bf91fef7e 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py
@@ -208,7 +208,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py
index d0875c5f5..bc33dd160 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py
@@ -139,7 +139,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     add_model_arguments(parser)
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py
index da4e3bc2f..3305f5bd3 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py
@@ -165,7 +165,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -196,9 +196,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py
index 97d434157..43f3231ba 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py
@@ -212,7 +212,7 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need " "to be changed.",
+        help="The initial learning rate.  This value should not need to be changed.",
     )
 
     parser.add_argument(
@@ -235,7 +235,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -258,7 +258,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py b/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py
index 8ca875c24..38f2ae83c 100755
--- a/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py
@@ -172,7 +172,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/export.py b/egs/tedlium3/ASR/pruned_transducer_stateless/export.py
index 71a9e2d71..aa22f82ec 100644
--- a/egs/tedlium3/ASR/pruned_transducer_stateless/export.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless/export.py
@@ -106,7 +106,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py b/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py
index e8a453c80..8a89c3578 100644
--- a/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py
@@ -165,7 +165,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -202,9 +202,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/train.py b/egs/tedlium3/ASR/pruned_transducer_stateless/train.py
index 59d80a0d8..170f37767 100755
--- a/egs/tedlium3/ASR/pruned_transducer_stateless/train.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless/train.py
@@ -133,7 +133,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -156,7 +156,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
diff --git a/egs/tedlium3/ASR/transducer_stateless/decode.py b/egs/tedlium3/ASR/transducer_stateless/decode.py
index e5ab2c107..01f08ce59 100755
--- a/egs/tedlium3/ASR/transducer_stateless/decode.py
+++ b/egs/tedlium3/ASR/transducer_stateless/decode.py
@@ -130,7 +130,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
diff --git a/egs/tedlium3/ASR/transducer_stateless/export.py b/egs/tedlium3/ASR/transducer_stateless/export.py
index c2ec7a590..48dcdc736 100644
--- a/egs/tedlium3/ASR/transducer_stateless/export.py
+++ b/egs/tedlium3/ASR/transducer_stateless/export.py
@@ -110,7 +110,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
diff --git a/egs/tedlium3/ASR/transducer_stateless/pretrained.py b/egs/tedlium3/ASR/transducer_stateless/pretrained.py
index 070b070a7..81afd6a4e 100644
--- a/egs/tedlium3/ASR/transducer_stateless/pretrained.py
+++ b/egs/tedlium3/ASR/transducer_stateless/pretrained.py
@@ -127,7 +127,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -221,9 +221,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/tedlium3/ASR/transducer_stateless/train.py b/egs/tedlium3/ASR/transducer_stateless/train.py
index 4fc13b1da..6fed32e81 100755
--- a/egs/tedlium3/ASR/transducer_stateless/train.py
+++ b/egs/tedlium3/ASR/transducer_stateless/train.py
@@ -133,7 +133,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
diff --git a/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py b/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py
index 4ef134412..3fdf3b855 100644
--- a/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py
+++ b/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py
@@ -138,9 +138,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py b/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py
index 3f143912e..98c746ce5 100644
--- a/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py
+++ b/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py
@@ -138,9 +138,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py
index cd9ed57b9..04602ea2e 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py
@@ -248,7 +248,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py
index df2fc5df5..8c4fbdd47 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py
@@ -205,7 +205,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py
index 42ffbcfb8..f90dd2b43 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py
@@ -145,9 +145,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py
index ca1e408fa..9e34b4427 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py
@@ -149,9 +149,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py
index aaf7ac874..bc499f3dd 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py
@@ -158,7 +158,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -188,9 +188,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
index 7aba0711d..43fa0d01b 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
@@ -217,7 +217,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -240,7 +240,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
index 166497c31..7bd1177bd 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
@@ -244,7 +244,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py
index ff2c4db38..35577c327 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py
@@ -131,7 +131,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     add_model_arguments(parser)
 
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py
index 7e4829a60..1cac20435 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py
@@ -157,7 +157,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -188,9 +188,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py
index 6909f40be..c7863415b 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py
@@ -201,7 +201,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py
index 5f614e77c..440b65f32 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py
@@ -258,7 +258,7 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need " "to be changed.",
+        help="The initial learning rate.  This value should not need to be changed.",
     )
 
     parser.add_argument(
@@ -281,7 +281,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -304,7 +304,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
diff --git a/egs/yesno/ASR/tdnn/pretrained.py b/egs/yesno/ASR/tdnn/pretrained.py
index 88d5eca5d..65be77db1 100755
--- a/egs/yesno/ASR/tdnn/pretrained.py
+++ b/egs/yesno/ASR/tdnn/pretrained.py
@@ -99,9 +99,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
diff --git a/icefall/shared/make_kn_lm.py b/icefall/shared/make_kn_lm.py
index b1220d55e..7150297d6 100755
--- a/icefall/shared/make_kn_lm.py
+++ b/icefall/shared/make_kn_lm.py
@@ -45,13 +45,13 @@ parser.add_argument(
 )
 args = parser.parse_args()
 
-default_encoding = (
-    "latin-1"  # For encoding-agnostic scripts, we assume byte stream as input.
-)
+# For encoding-agnostic scripts, we assume byte stream as input.
 # Need to be very careful about the use of strip() and split()
 # in this case, because there is a latin-1 whitespace character
 # (nbsp) which is part of the unicode encoding range.
 # Ref: kaldi/egs/wsj/s5/utils/lang/bpe/prepend_words.py @ 69cd717
+default_encoding = "latin-1"
+
 strip_chars = " \t\r\n"
 whitespace = re.compile("[ \t]+")
 
@@ -65,9 +65,8 @@ class CountsForHistory:
         # The 'lambda: defaultdict(float)' is an anonymous function taking no
         # arguments that returns a new defaultdict(float).
         self.word_to_count = defaultdict(int)
-        self.word_to_context = defaultdict(
-            set
-        )  # using a set to count the number of unique contexts
+        # using a set to count the number of unique contexts
+        self.word_to_context = defaultdict(set)
         self.word_to_f = dict()  # discounted probability
         self.word_to_bow = dict()  # back-off weight
         self.total_count = 0
@@ -151,9 +150,8 @@ class NgramCounts:
 
     def add_raw_counts_from_standard_input(self):
         lines_processed = 0
-        infile = io.TextIOWrapper(
-            sys.stdin.buffer, encoding=default_encoding
-        )  # byte stream as input
+        # byte stream as input
+        infile = io.TextIOWrapper(sys.stdin.buffer, encoding=default_encoding)
         for line in infile:
             line = line.strip(strip_chars)
             self.add_raw_counts_from_line(line)
@@ -187,11 +185,10 @@ class NgramCounts:
         # This constant is used similarly to absolute discounting.
         # Return value: d is a list of floats, where d[N+1] = D_N
 
-        self.d = [
-            0
-        ]  # for the lowest order, i.e., 1-gram, we do not need to discount, thus the constant is 0
+        # for the lowest order, i.e., 1-gram, we do not need to discount, thus the constant is 0
         # This is a special case: as we currently assumed having seen all vocabularies in the dictionary,
         # but perhaps this is not the case for some other scenarios.
+        self.d = [0]
         for n in range(1, self.ngram_order):
             this_order_counts = self.counts[n]
             n1 = 0
@@ -201,11 +198,11 @@ class NgramCounts:
                 n1 += stat[1]
                 n2 += stat[2]
             assert n1 + 2 * n2 > 0
-            self.d.append(
-                max(0.1, n1 * 1.0) / (n1 + 2 * n2)
-            )  # We are doing this max(0.001, xxx) to avoid zero discounting constant D due to n1=0,
+
+            # We are doing this max(0.001, xxx) to avoid zero discounting constant D due to n1=0,
             # which could happen if the number of symbols is small.
             # Otherwise, zero discounting constant can cause division by zero in computing BOW.
+            self.d.append(max(0.1, n1 * 1.0) / (n1 + 2 * n2))
 
     def cal_f(self):
         # f(a_z) is a probability distribution of word sequence a_z.
@@ -286,11 +283,8 @@ class NgramCounts:
                         sum_z1_f_z = 0
                         _ = a_[1:]
                         _counts_for_hist = self.counts[len(_)][_]
-                        for (
-                            u
-                        ) in (
-                            a_counts_for_hist.word_to_count.keys()
-                        ):  # Should be careful here: what is Z1
+                        # Should be careful here: what is Z1
+                        for u in a_counts_for_hist.word_to_count.keys():
                             sum_z1_f_z += _counts_for_hist.word_to_f[u]
 
                         if sum_z1_f_z < 1:

From 349dae35037ee468340e889fd99704336c16c2a2 Mon Sep 17 00:00:00 2001
From: Desh Raj 
Date: Thu, 17 Nov 2022 14:18:50 -0500
Subject: [PATCH 014/174] add revision commit to git blame ignore

---
 .git-blame-ignore-revs | 1 +
 1 file changed, 1 insertion(+)

diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs
index be5901517..5d65b98e9 100644
--- a/.git-blame-ignore-revs
+++ b/.git-blame-ignore-revs
@@ -1,2 +1,3 @@
 # Migrate to 88 characters per line (see: https://github.com/lhotse-speech/lhotse/issues/890)
 107df3b115a58f1b68a6458c3f94a130004be34c
+d31db010371a4128856480382876acdc0d1739ed

From fbe1e35b74e3ae593e3798e1a56e9d5b708a6767 Mon Sep 17 00:00:00 2001
From: Desh Raj 
Date: Fri, 18 Nov 2022 09:24:07 -0500
Subject: [PATCH 015/174] update code style docs

---
 docs/source/contributing/code-style.rst | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)

diff --git a/docs/source/contributing/code-style.rst b/docs/source/contributing/code-style.rst
index 7d61a3ba1..3baaaeec2 100644
--- a/docs/source/contributing/code-style.rst
+++ b/docs/source/contributing/code-style.rst
@@ -11,9 +11,9 @@ We use the following tools to make the code style to be as consistent as possibl
 
 The following versions of the above tools are used:
 
-  - ``black == 12.6b0``
-  - ``flake8 == 3.9.2``
-  - ``isort == 5.9.2``
+  - ``black == 22.3.0``
+  - ``flake8 == 5.0.4``
+  - ``isort == 5.10.1``
 
 After running the following commands:
 
@@ -54,10 +54,17 @@ it should succeed this time:
 If you want to check the style of your code before ``git commit``, you
 can do the following:
 
+  .. code-block:: bash
+
+    $ pre-commit install
+    $ pre-commit run
+
+Or without installing the pre-commit hooks:
+
   .. code-block:: bash
 
     $ cd icefall
-    $ pip install black==21.6b0 flake8==3.9.2 isort==5.9.2
+    $ pip install black==22.3.0 flake8==5.0.4 isort==5.10.1
     $ black --check your_changed_file.py
     $ black your_changed_file.py  # modify it in-place
     $

From 53454701cb69ec23be8a37c6ab69f1cf5104585d Mon Sep 17 00:00:00 2001
From: marcoyang 
Date: Tue, 22 Nov 2022 11:39:21 +0800
Subject: [PATCH 016/174] fix segmentation fault

---
 egs/aidatatang_200zh/ASR/prepare.sh | 3 +++
 egs/aishell/ASR/prepare.sh          | 3 +++
 egs/aishell2/ASR/prepare.sh         | 3 +++
 egs/aishell4/ASR/prepare.sh         | 3 +++
 egs/alimeeting/ASR/prepare.sh       | 3 +++
 egs/csj/ASR/prepare.sh              | 3 +++
 egs/gigaspeech/ASR/prepare.sh       | 3 +++
 egs/librispeech/ASR/prepare.sh      | 3 +++
 egs/ptb/LM/prepare.sh               | 3 +++
 egs/spgispeech/ASR/prepare.sh       | 3 +++
 egs/tal_csasr/ASR/prepare.sh        | 3 +++
 egs/tedlium3/ASR/prepare.sh         | 3 +++
 egs/timit/ASR/prepare.sh            | 3 +++
 egs/wenetspeech/ASR/prepare.sh      | 3 +++
 egs/yesno/ASR/prepare.sh            | 3 +++
 15 files changed, 45 insertions(+)

diff --git a/egs/aidatatang_200zh/ASR/prepare.sh b/egs/aidatatang_200zh/ASR/prepare.sh
index 4749e1b7f..46ecd5769 100755
--- a/egs/aidatatang_200zh/ASR/prepare.sh
+++ b/egs/aidatatang_200zh/ASR/prepare.sh
@@ -1,5 +1,8 @@
 #!/usr/bin/env bash
 
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
 set -eou pipefail
 
 stage=-1
diff --git a/egs/aishell/ASR/prepare.sh b/egs/aishell/ASR/prepare.sh
index eaeecfc4a..5917668a1 100755
--- a/egs/aishell/ASR/prepare.sh
+++ b/egs/aishell/ASR/prepare.sh
@@ -1,5 +1,8 @@
 #!/usr/bin/env bash
 
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
 set -eou pipefail
 
 nj=15
diff --git a/egs/aishell2/ASR/prepare.sh b/egs/aishell2/ASR/prepare.sh
index 06810bfdd..3e8e840ab 100755
--- a/egs/aishell2/ASR/prepare.sh
+++ b/egs/aishell2/ASR/prepare.sh
@@ -1,5 +1,8 @@
 #!/usr/bin/env bash
 
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
 set -eou pipefail
 
 nj=30
diff --git a/egs/aishell4/ASR/prepare.sh b/egs/aishell4/ASR/prepare.sh
index c351e3964..cb2b73a3e 100755
--- a/egs/aishell4/ASR/prepare.sh
+++ b/egs/aishell4/ASR/prepare.sh
@@ -1,5 +1,8 @@
 #!/usr/bin/env bash
 
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
 set -eou pipefail
 
 stage=-1
diff --git a/egs/alimeeting/ASR/prepare.sh b/egs/alimeeting/ASR/prepare.sh
index 17224bb68..604cc92c6 100755
--- a/egs/alimeeting/ASR/prepare.sh
+++ b/egs/alimeeting/ASR/prepare.sh
@@ -1,5 +1,8 @@
 #!/usr/bin/env bash
 
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
 set -eou pipefail
 
 stage=-1
diff --git a/egs/csj/ASR/prepare.sh b/egs/csj/ASR/prepare.sh
index 052748ca6..c4ce91984 100755
--- a/egs/csj/ASR/prepare.sh
+++ b/egs/csj/ASR/prepare.sh
@@ -35,6 +35,9 @@
 # can generate other transcript formats by supplying your own config files. A few examples of these
 # config files can be found in local/conf.
 
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
 set -eou pipefail
 
 nj=8
diff --git a/egs/gigaspeech/ASR/prepare.sh b/egs/gigaspeech/ASR/prepare.sh
index fd2532741..bd255dc6a 100755
--- a/egs/gigaspeech/ASR/prepare.sh
+++ b/egs/gigaspeech/ASR/prepare.sh
@@ -1,5 +1,8 @@
 #!/usr/bin/env bash
 
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
 set -eou pipefail
 
 nj=15
diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh
index 94e003036..8668af0e4 100755
--- a/egs/librispeech/ASR/prepare.sh
+++ b/egs/librispeech/ASR/prepare.sh
@@ -1,5 +1,8 @@
 #!/usr/bin/env bash
 
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
 set -eou pipefail
 
 nj=15
diff --git a/egs/ptb/LM/prepare.sh b/egs/ptb/LM/prepare.sh
index 70586785d..91c3c667a 100755
--- a/egs/ptb/LM/prepare.sh
+++ b/egs/ptb/LM/prepare.sh
@@ -1,5 +1,8 @@
 #!/usr/bin/env bash
 
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
 set -eou pipefail
 
 nj=15
diff --git a/egs/spgispeech/ASR/prepare.sh b/egs/spgispeech/ASR/prepare.sh
index 231ebd742..4842f52d0 100755
--- a/egs/spgispeech/ASR/prepare.sh
+++ b/egs/spgispeech/ASR/prepare.sh
@@ -1,5 +1,8 @@
 #!/usr/bin/env bash
 
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
 set -eou pipefail
 
 nj=20
diff --git a/egs/tal_csasr/ASR/prepare.sh b/egs/tal_csasr/ASR/prepare.sh
index 340521ad8..d9938fa63 100755
--- a/egs/tal_csasr/ASR/prepare.sh
+++ b/egs/tal_csasr/ASR/prepare.sh
@@ -1,5 +1,8 @@
 #!/usr/bin/env bash
 
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
 set -eou pipefail
 
 stage=-1
diff --git a/egs/tedlium3/ASR/prepare.sh b/egs/tedlium3/ASR/prepare.sh
index ccb307a52..272cf7aed 100755
--- a/egs/tedlium3/ASR/prepare.sh
+++ b/egs/tedlium3/ASR/prepare.sh
@@ -1,5 +1,8 @@
 #!/usr/bin/env bash
 
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
 set -eou pipefail
 
 nj=15
diff --git a/egs/timit/ASR/prepare.sh b/egs/timit/ASR/prepare.sh
index d11cd3a05..148a9f51b 100644
--- a/egs/timit/ASR/prepare.sh
+++ b/egs/timit/ASR/prepare.sh
@@ -1,5 +1,8 @@
 #!/usr/bin/env bash
 
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
 set -eou pipefail
 
 num_phones=39
diff --git a/egs/wenetspeech/ASR/prepare.sh b/egs/wenetspeech/ASR/prepare.sh
index da7d7e061..50a00253d 100755
--- a/egs/wenetspeech/ASR/prepare.sh
+++ b/egs/wenetspeech/ASR/prepare.sh
@@ -1,5 +1,8 @@
 #!/usr/bin/env bash
 
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
 set -eou pipefail
 
 nj=15
diff --git a/egs/yesno/ASR/prepare.sh b/egs/yesno/ASR/prepare.sh
index 8fcee0290..d4ef8d601 100755
--- a/egs/yesno/ASR/prepare.sh
+++ b/egs/yesno/ASR/prepare.sh
@@ -1,5 +1,8 @@
 #!/usr/bin/env bash
 
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
 set -eou pipefail
 
 stage=-1

From 4c636c2cfffd853a4dc1f618dd8a6fede78a3bea Mon Sep 17 00:00:00 2001
From: Senyan Li <1149593720@qq.com>
Date: Fri, 25 Nov 2022 14:39:56 +0800
Subject: [PATCH 017/174] fix librispeech ASR pruned_transducer_stateless5
 export (#704)

---
 egs/librispeech/ASR/pruned_transducer_stateless5/export.py      | 2 ++
 egs/librispeech/ASR/pruned_transducer_stateless5/lstmp.py       | 1 +
 .../ASR/pruned_transducer_stateless5/scaling_converter.py       | 1 +
 3 files changed, 4 insertions(+)
 create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless5/lstmp.py
 create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless5/scaling_converter.py

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/export.py b/egs/librispeech/ASR/pruned_transducer_stateless5/export.py
index a4fad1e59..54f656859 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless5/export.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless5/export.py
@@ -50,6 +50,7 @@ from pathlib import Path
 
 import sentencepiece as spm
 import torch
+from scaling_converter import convert_scaled_to_non_scaled
 from train import add_model_arguments, get_params, get_transducer_model
 
 from icefall.checkpoint import (
@@ -263,6 +264,7 @@ def main():
         # it here.
         # Otherwise, one of its arguments is a ragged tensor and is not
         # torch scriptabe.
+        convert_scaled_to_non_scaled(model, inplace=True)
         model.__class__.forward = torch.jit.ignore(model.__class__.forward)
         logging.info("Using torch.jit.script")
         model = torch.jit.script(model)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/lstmp.py b/egs/librispeech/ASR/pruned_transducer_stateless5/lstmp.py
new file mode 120000
index 000000000..4f377cd01
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless5/lstmp.py
@@ -0,0 +1 @@
+../lstm_transducer_stateless2/lstmp.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless5/scaling_converter.py
new file mode 120000
index 000000000..3b667058d
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless5/scaling_converter.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless3/scaling_converter.py
\ No newline at end of file

From 89c3982a0760f135740556ae67c11d0af434303c Mon Sep 17 00:00:00 2001
From: Guo Liyong 
Date: Sat, 26 Nov 2022 00:50:21 +0800
Subject: [PATCH 018/174] show dominant parameters

---
 .../ASR/pruned_transducer_stateless7/optim.py | 79 ++++++++++++++++---
 .../ASR/pruned_transducer_stateless7/train.py | 13 ++-
 2 files changed, 79 insertions(+), 13 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index 8b90c9a0d..ab55381d7 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -42,7 +42,7 @@ class BatchedOptimizer(Optimizer):
         super(BatchedOptimizer, self).__init__(params, defaults)
 
     @contextlib.contextmanager
-    def batched_params(self, param_group):
+    def batched_params(self, param_group, group_params_names=None):
         """
         This function returns (technically, yields) a list of
           of tuples (p, state), where
@@ -75,20 +75,28 @@ class BatchedOptimizer(Optimizer):
         batches = defaultdict(
             list
         )  # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
+        batches_names = defaultdict(
+            list
+        )  # `batches` maps from tuple (dtype_as_str,*shape) to list of str
 
-        for p in param_group:
+        for p, named_p in zip(param_group, group_params_names):
             key = (str(p.dtype), *p.shape)
             batches[key].append(p)
+            batches_names[key].append(named_p)
+
+        batches_names_keys = list(batches_names.keys())
+        sorted_idx = sorted(range(len(batches_names)), key=lambda i: batches_names_keys[i])
+        batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx]
+        batches = [batches[batches_names_keys[idx]] for idx in sorted_idx]
 
         stacked_params_dict = dict()
 
         # turn batches into a list, in deterministic order.
-        batches = [batches[key] for key in sorted(batches.keys())]
         # pairs will contain pairs of (stacked_param, state), one for each batch
         # in `batches`.
         pairs = []
 
-        for batch in batches:
+        for batch, batch_names in zip(batches, batches_names):
             p = batch[0]
             # we arbitrarily store the state in the
             # state corresponding to the 1st parameter in the
@@ -100,11 +108,11 @@ class BatchedOptimizer(Optimizer):
             )
             p_stacked.grad = grad
             stacked_params_dict[key] = p_stacked
-            pairs.append((p_stacked, state))
+            pairs.append((p_stacked, state, batch_names))
 
         yield pairs  # <-- calling code will do the actual optimization here!
 
-        for ((stacked_params, _state), batch) in zip(pairs, batches):
+        for ((stacked_params, _state, _names), batch) in zip(pairs, batches):
             for i, p in enumerate(batch):  # batch is list of Parameter
                 p.copy_(stacked_params[i])
 
@@ -165,6 +173,8 @@ class ScaledAdam(BatchedOptimizer):
         scalar_max=10.0,
         size_update_period=4,
         clipping_update_period=100,
+        parameters_names=None,
+        show_dominant_parameters=False,
     ):
 
         defaults = dict(
@@ -181,6 +191,8 @@ class ScaledAdam(BatchedOptimizer):
         )
 
         super(ScaledAdam, self).__init__(params, defaults)
+        self.parameters_names = parameters_names
+        self.show_dominant_parameters = show_dominant_parameters
 
     def __setstate__(self, state):
         super(ScaledAdam, self).__setstate__(state)
@@ -199,9 +211,11 @@ class ScaledAdam(BatchedOptimizer):
                 loss = closure()
 
         batch = True
-        for group in self.param_groups:
+        assert len(self.param_groups)  == len(self.parameters_names)
 
-            with self.batched_params(group["params"]) as batches:
+        for group, group_params_names in zip(self.param_groups, self.parameters_names):
+
+            with self.batched_params(group["params"], group_params_names) as batches:
 
                 # batches is list of pairs (stacked_param, state).  stacked_param is like
                 # a regular parameter, and will have a .grad, but the 1st dim corresponds to
@@ -214,7 +228,7 @@ class ScaledAdam(BatchedOptimizer):
                 else:
                     clipping_scale = self._get_clipping_scale(group, batches)
 
-                for p, state in batches:
+                for p, state, _ in batches:
                     # Perform optimization step.
                     # grad is not going to be None, we handled that when creating the batches.
                     grad = p.grad
@@ -276,7 +290,7 @@ class ScaledAdam(BatchedOptimizer):
         state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
 
     def _get_clipping_scale(
-        self, group: dict, pairs: List[Tuple[Tensor, dict]]
+        self, group: dict, pairs: List[Tuple[Tensor, dict, List[str]]]
     ) -> float:
         """
         Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients
@@ -289,7 +303,7 @@ class ScaledAdam(BatchedOptimizer):
         """
         assert len(pairs) >= 1
         clipping_scale = group["clipping_scale"]
-        (first_p, first_state) = pairs[0]
+        (first_p, first_state, _) = pairs[0]
         step = first_state["step"]
         if clipping_scale is None or step == 0:
             # no clipping.  return early on step == 0 because the other
@@ -298,7 +312,7 @@ class ScaledAdam(BatchedOptimizer):
         clipping_update_period = group["clipping_update_period"]
 
         tot_sumsq = torch.tensor(0.0, device=first_p.device)
-        for (p, state) in pairs:
+        for (p, state, param_names) in pairs:
             grad = p.grad
             if grad.is_sparse:
                 raise RuntimeError(
@@ -361,8 +375,49 @@ class ScaledAdam(BatchedOptimizer):
                 logging.warn(
                     f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}"
                 )
+                if self.show_dominant_parameters:
+                    assert p.shape[0] == len(param_names)
+                    self._show_gradient_dominating_parameter(pairs, tot_sumsq)
             return ans
 
+    def _show_gradient_dominating_parameter(self, pairs, tot_sumsq):
+        # ori means calculated with state["param_rms"]
+        # cur means calculated with "param_rms" of current param.
+        # bt is short batch
+        # all_sumsq_ori_rms
+        all_sumsq_ori = {}
+        all_sumsq_cur = {}
+        for (p, state, batch_param_names) in pairs:
+            # p is a stacked batch parameters.
+            grad = p.grad
+            if p.numel() == p.shape[0]:  # a batch of scalars
+                batch_sumsq_ori = grad**2  # sum() to change shape [1] to []
+                batch_sumsq_cur = batch_sumsq_ori  # sum() to change shape [1] to []
+                # Dummpy values used by following `zip` statement.
+                batch_rms_ori = torch.zeros(p.shape[0])
+                batch_rms_cur = batch_rms_ori
+            else:
+                batch_rms_ori = state["param_rms"]
+                batch_sumsq_ori = ((grad * batch_rms_ori) ** 2).sum(dim=list(range(1, grad.ndim)))
+
+                batch_rms_cur = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
+                batch_sumsq_cur = ((grad * batch_rms_cur) ** 2).sum(dim=list(range(1, grad.ndim)))
+
+            for name, sumsq_ori, sumsq_cur in zip(
+                    batch_param_names, batch_sumsq_ori, batch_sumsq_cur):
+
+                proportion_ori = sumsq_ori / tot_sumsq
+                proportion_cur = sumsq_cur / tot_sumsq
+
+                all_sumsq_ori[name] = (proportion_ori, sumsq_ori)
+                all_sumsq_cur[name] = (proportion_cur, sumsq_cur)
+
+        for rms_type, all_sumsq in zip(("ori", "cur"), (all_sumsq_ori, all_sumsq_cur)):
+            sorted_by_proportion = {k: v for k, v in sorted(all_sumsq.items(), key=lambda item: item[1][0], reverse=True)}
+            dominant_param_name = next(iter(sorted_by_proportion))
+            dominant_proportion, dominant_sumsq = sorted_by_proportion[dominant_param_name]
+            logging.info(f"Dominant sumsq with {rms_type}_rms: {dominant_param_name} {dominant_proportion}  {dominant_sumsq} {tot_sumsq}")
+
     def _step_one_batch(
         self, group: dict, p: Tensor, state: dict, clipping_scale: float
     ):
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index b27c573ab..8375b1a18 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -368,6 +368,13 @@ def get_parser():
         help="Whether to use half precision training.",
     )
 
+    parser.add_argument(
+        "--show-dominant-parameters",
+        type=str2bool,
+        default=False,
+        help="Whether to show dominant parameters.",
+    )
+
     add_model_arguments(parser)
 
     return parser
@@ -988,7 +995,11 @@ def run(rank, world_size, args):
         logging.info("Using DDP")
         model = DDP(model, device_ids=[rank], find_unused_parameters=True)
 
-    optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=2.0)
+    parameters_names = []
+    parameters_names.append([name_param_pair[0] for name_param_pair in model.named_parameters()])
+    optimizer = ScaledAdam(model.parameters(), lr=params.base_lr,
+            clipping_scale=2.0, parameters_names=parameters_names,
+            show_dominant_parameters=params.show_dominant_parameters)
 
     scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
 

From db75627e92155c16fd6e74d640ece4f6563f96f2 Mon Sep 17 00:00:00 2001
From: Desh Raj 
Date: Fri, 25 Nov 2022 21:00:45 -0500
Subject: [PATCH 019/174] [recipe] AMI Zipformer transducer (#698)

* remove unnecessary changes

* add AMI prepare scripts

* add zipformer scripts for AMI

* added logs and pretrained model

* minor fix

* remove unwanted changes

* fix missing link

* make suggested changes

* update results
---
 egs/ami/ASR/README.md                         |   48 +
 egs/ami/ASR/RESULTS.md                        |   92 ++
 egs/ami/ASR/local/__init__.py                 |    0
 egs/ami/ASR/local/compute_fbank_ami.py        |  194 +++
 egs/ami/ASR/local/compute_fbank_musan.py      |  114 ++
 egs/ami/ASR/local/prepare_ami_enhanced.py     |  158 +++
 egs/ami/ASR/local/prepare_ami_gss.sh          |   98 ++
 egs/ami/ASR/local/prepare_lang_bpe.py         |    1 +
 egs/ami/ASR/local/train_bpe_model.py          |    1 +
 egs/ami/ASR/prepare.sh                        |  144 ++
 .../pruned_transducer_stateless7/__init__.py  |    0
 .../asr_datamodule.py                         |  430 ++++++
 .../beam_search.py                            |    1 +
 .../pruned_transducer_stateless7/decode.py    |  747 +++++++++++
 .../pruned_transducer_stateless7/decoder.py   |    1 +
 .../encoder_interface.py                      |    1 +
 .../pruned_transducer_stateless7/export.py    |    1 +
 .../pruned_transducer_stateless7/joiner.py    |    1 +
 .../ASR/pruned_transducer_stateless7/model.py |    1 +
 .../ASR/pruned_transducer_stateless7/optim.py |    1 +
 .../pruned_transducer_stateless7/scaling.py   |    1 +
 .../scaling_converter.py                      |    1 +
 .../ASR/pruned_transducer_stateless7/train.py | 1184 +++++++++++++++++
 .../pruned_transducer_stateless7/zipformer.py |    1 +
 egs/ami/ASR/shared                            |    1 +
 25 files changed, 3222 insertions(+)
 create mode 100644 egs/ami/ASR/README.md
 create mode 100644 egs/ami/ASR/RESULTS.md
 create mode 100644 egs/ami/ASR/local/__init__.py
 create mode 100755 egs/ami/ASR/local/compute_fbank_ami.py
 create mode 100755 egs/ami/ASR/local/compute_fbank_musan.py
 create mode 100644 egs/ami/ASR/local/prepare_ami_enhanced.py
 create mode 100755 egs/ami/ASR/local/prepare_ami_gss.sh
 create mode 120000 egs/ami/ASR/local/prepare_lang_bpe.py
 create mode 120000 egs/ami/ASR/local/train_bpe_model.py
 create mode 100755 egs/ami/ASR/prepare.sh
 create mode 100644 egs/ami/ASR/pruned_transducer_stateless7/__init__.py
 create mode 100644 egs/ami/ASR/pruned_transducer_stateless7/asr_datamodule.py
 create mode 120000 egs/ami/ASR/pruned_transducer_stateless7/beam_search.py
 create mode 100755 egs/ami/ASR/pruned_transducer_stateless7/decode.py
 create mode 120000 egs/ami/ASR/pruned_transducer_stateless7/decoder.py
 create mode 120000 egs/ami/ASR/pruned_transducer_stateless7/encoder_interface.py
 create mode 120000 egs/ami/ASR/pruned_transducer_stateless7/export.py
 create mode 120000 egs/ami/ASR/pruned_transducer_stateless7/joiner.py
 create mode 120000 egs/ami/ASR/pruned_transducer_stateless7/model.py
 create mode 120000 egs/ami/ASR/pruned_transducer_stateless7/optim.py
 create mode 120000 egs/ami/ASR/pruned_transducer_stateless7/scaling.py
 create mode 120000 egs/ami/ASR/pruned_transducer_stateless7/scaling_converter.py
 create mode 100755 egs/ami/ASR/pruned_transducer_stateless7/train.py
 create mode 120000 egs/ami/ASR/pruned_transducer_stateless7/zipformer.py
 create mode 120000 egs/ami/ASR/shared

diff --git a/egs/ami/ASR/README.md b/egs/ami/ASR/README.md
new file mode 100644
index 000000000..1c9714bd4
--- /dev/null
+++ b/egs/ami/ASR/README.md
@@ -0,0 +1,48 @@
+# AMI
+
+This is an ASR recipe for the AMI corpus. AMI provides recordings from the speaker's
+headset and lapel microphones, and also 2 array microphones containing 8 channels each.
+We pool data in the following 4 ways and train a single model on the pooled data:
+
+(i) individual headset microphone (IHM)
+(ii) IHM with simulated reverb
+(iii) Single distant microphone (SDM)
+(iv) GSS-enhanced array microphones
+
+Speed perturbation and MUSAN noise augmentation are additionally performed on the pooled
+data. Here are the statistics of the combined training data:
+
+```python
+>>> cuts_train.describe()
+Cuts count: 1222053
+Total duration (hh:mm:ss): 905:00:28
+Speech duration (hh:mm:ss): 905:00:28 (99.9%)
+Duration statistics (seconds):
+mean    2.7
+std     2.8
+min     0.0
+25%     0.6
+50%     1.6
+75%     3.8
+99%     12.3
+99.5%   13.9
+99.9%   18.4
+max     36.8
+```
+
+**Note:** This recipe additionally uses [GSS](https://github.com/desh2608/gss) for enhancement
+of far-field array microphones, but this is optional (see `prepare.sh` for details).
+
+## Performance Record
+
+### pruned_transducer_stateless7
+
+The following are decoded using `modified_beam_search`:
+
+| Evaluation set           | dev WER    | test WER |
+|--------------------------|------------|---------|
+| IHM                      |  18.92  | 17.40 |
+| SDM                      |  31.25  | 32.21 |
+| MDM (GSS-enhanced)       |  21.67  | 22.43 |
+
+See [RESULTS](/egs/ami/ASR/RESULTS.md) for details.
diff --git a/egs/ami/ASR/RESULTS.md b/egs/ami/ASR/RESULTS.md
new file mode 100644
index 000000000..163986021
--- /dev/null
+++ b/egs/ami/ASR/RESULTS.md
@@ -0,0 +1,92 @@
+## Results
+
+### AMI training results (Pruned Transducer)
+
+#### 2022-11-20
+
+#### Zipformer (pruned_transducer_stateless7)
+
+Zipformer encoder + non-current decoder. The decoder
+contains only an embedding layer, a Conv1d (with kernel size 2) and a linear
+layer (to transform tensor dim).
+
+All the results below are using a single model that is trained by combining the following
+data: IHM, IHM+reverb, SDM, and GSS-enhanced MDM. Speed perturbation and MUSAN noise
+augmentation are applied on top of the pooled data.
+
+**WERs for IHM:**
+
+|                           | dev | test | comment                                  |
+|---------------------------|------------|------------|------------------------------------------|
+| greedy search             |  19.25  |  17.83  | --epoch 14 --avg 8 --max-duration 500 |
+| modified beam search      |  18.92  |  17.40  | --epoch 14 --avg 8 --max-duration 500 --beam-size 4 |
+| fast beam search          |  19.44  |  18.04  | --epoch 14 --avg 8 --max-duration 500 --beam-size 4 --max-contexts 4 --max-states 8 |
+
+**WERs for SDM:**
+
+|                           | dev | test | comment                                  |
+|---------------------------|------------|------------|------------------------------------------|
+| greedy search             |  31.32  |  32.38  | --epoch 14 --avg 8 --max-duration 500 |
+| modified beam search      |  31.25  |  32.21  | --epoch 14 --avg 8 --max-duration 500 --beam-size 4 |
+| fast beam search          |  31.11  |  32.10  | --epoch 14 --avg 8 --max-duration 500 --beam-size 4 --max-contexts 4 --max-states 8 |
+
+**WERs for GSS-enhanced MDM:**
+
+|                           | dev | test | comment                                  |
+|---------------------------|------------|------------|------------------------------------------|
+| greedy search             |  22.05  |  22.93  | --epoch 14 --avg 8 --max-duration 500 |
+| modified beam search      |  21.67  |  22.43  | --epoch 14 --avg 8 --max-duration 500 --beam-size 4 |
+| fast beam search          |  22.21  |  22.83  | --epoch 14 --avg 8 --max-duration 500 --beam-size 4 --max-contexts 4 --max-states 8 |
+
+The training command for reproducing is given below:
+
+```
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./pruned_transducer_stateless7/train.py \
+  --world-size 4 \
+  --num-epochs 15 \
+  --exp-dir pruned_transducer_stateless7/exp \
+  --max-duration 150 \
+  --max-cuts 150 \
+  --prune-range 5 \
+  --lr-factor 5 \
+  --lm-scale 0.25 \
+  --use-fp16 True
+```
+
+The decoding command is:
+```
+# greedy search
+./pruned_transducer_stateless7/decode.py \
+        --epoch 14 \
+        --avg 8 \
+        --exp-dir ./pruned_transducer_stateless7/exp \
+        --max-duration 500 \
+        --decoding-method greedy_search
+
+# modified beam search
+./pruned_transducer_stateless7/decode.py \
+        --iter 105000 \
+        --avg 10 \
+        --exp-dir ./pruned_transducer_stateless7/exp \
+        --max-duration 500 \
+        --decoding-method modified_beam_search \
+        --beam-size 4
+
+# fast beam search
+./pruned_transducer_stateless7/decode.py \
+        --iter 105000 \
+        --avg 10 \
+        --exp-dir ./pruned_transducer_stateless5/exp \
+        --max-duration 500 \
+        --decoding-method fast_beam_search \
+        --beam 4 \
+        --max-contexts 4 \
+        --max-states 8
+```
+
+Pretrained model is available at 
+
+The tensorboard training log can be found at
+
diff --git a/egs/ami/ASR/local/__init__.py b/egs/ami/ASR/local/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/egs/ami/ASR/local/compute_fbank_ami.py b/egs/ami/ASR/local/compute_fbank_ami.py
new file mode 100755
index 000000000..4892b40e3
--- /dev/null
+++ b/egs/ami/ASR/local/compute_fbank_ami.py
@@ -0,0 +1,194 @@
+#!/usr/bin/env python3
+# Copyright    2022  Johns Hopkins University        (authors: Desh Raj)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file computes fbank features of the AMI dataset.
+For the training data, we pool together IHM, reverberated IHM, and GSS-enhanced
+audios. For the test data, we separately prepare IHM, SDM, and GSS-enhanced
+parts (which are the 3 evaluation settings).
+It looks for manifests in the directory data/manifests.
+
+The generated fbank features are saved in data/fbank.
+"""
+import logging
+import math
+from pathlib import Path
+
+import torch
+import torch.multiprocessing
+from lhotse import CutSet, LilcomChunkyWriter
+from lhotse.features.kaldifeat import (
+    KaldifeatFbank,
+    KaldifeatFbankConfig,
+    KaldifeatFrameOptions,
+    KaldifeatMelOptions,
+)
+from lhotse.recipes.utils import read_manifests_if_cached
+
+# Torch's multithreaded behavior needs to be disabled or
+# it wastes a lot of CPU and slow things down.
+# Do this outside of main() in case it needs to take effect
+# even when we are not invoking the main (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+torch.multiprocessing.set_sharing_strategy("file_system")
+
+
+def compute_fbank_ami():
+    src_dir = Path("data/manifests")
+    output_dir = Path("data/fbank")
+
+    sampling_rate = 16000
+    num_mel_bins = 80
+
+    extractor = KaldifeatFbank(
+        KaldifeatFbankConfig(
+            frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate),
+            mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
+            device="cuda",
+        )
+    )
+
+    logging.info("Reading manifests")
+    manifests_ihm = read_manifests_if_cached(
+        dataset_parts=["train", "dev", "test"],
+        output_dir=src_dir,
+        prefix="ami-ihm",
+        suffix="jsonl.gz",
+    )
+    manifests_sdm = read_manifests_if_cached(
+        dataset_parts=["train", "dev", "test"],
+        output_dir=src_dir,
+        prefix="ami-sdm",
+        suffix="jsonl.gz",
+    )
+    # For GSS we already have cuts so we read them directly.
+    manifests_gss = read_manifests_if_cached(
+        dataset_parts=["train", "dev", "test"],
+        output_dir=src_dir,
+        prefix="ami-gss",
+        suffix="jsonl.gz",
+    )
+
+    def _extract_feats(cuts: CutSet, storage_path: Path, manifest_path: Path) -> None:
+        cuts = cuts + cuts.perturb_speed(0.9) + cuts.perturb_speed(1.1)
+        _ = cuts.compute_and_store_features_batch(
+            extractor=extractor,
+            storage_path=storage_path,
+            manifest_path=manifest_path,
+            batch_duration=5000,
+            num_workers=8,
+            storage_type=LilcomChunkyWriter,
+        )
+
+    logging.info(
+        "Preparing training cuts: IHM + reverberated IHM + SDM + GSS (optional)"
+    )
+
+    logging.info("Processing train split IHM")
+    cuts_ihm = (
+        CutSet.from_manifests(**manifests_ihm["train"])
+        .trim_to_supervisions(keep_overlapping=False, keep_all_channels=False)
+        .modify_ids(lambda x: x + "-ihm")
+    )
+    _extract_feats(
+        cuts_ihm,
+        output_dir / "feats_train_ihm",
+        src_dir / "cuts_train_ihm.jsonl.gz",
+    )
+
+    logging.info("Processing train split IHM + reverberated IHM")
+    cuts_ihm_rvb = cuts_ihm.reverb_rir()
+    _extract_feats(
+        cuts_ihm_rvb,
+        output_dir / "feats_train_ihm_rvb",
+        src_dir / "cuts_train_ihm_rvb.jsonl.gz",
+    )
+
+    logging.info("Processing train split SDM")
+    cuts_sdm = (
+        CutSet.from_manifests(**manifests_sdm["train"])
+        .trim_to_supervisions(keep_overlapping=False)
+        .modify_ids(lambda x: x + "-sdm")
+    )
+    _extract_feats(
+        cuts_sdm,
+        output_dir / "feats_train_sdm",
+        src_dir / "cuts_train_sdm.jsonl.gz",
+    )
+
+    logging.info("Processing train split GSS")
+    cuts_gss = (
+        CutSet.from_manifests(**manifests_gss["train"])
+        .trim_to_supervisions(keep_overlapping=False)
+        .modify_ids(lambda x: x + "-gss")
+    )
+    _extract_feats(
+        cuts_gss,
+        output_dir / "feats_train_gss",
+        src_dir / "cuts_train_gss.jsonl.gz",
+    )
+
+    logging.info("Preparing test cuts: IHM, SDM, GSS (optional)")
+    for split in ["dev", "test"]:
+        logging.info(f"Processing {split} IHM")
+        cuts_ihm = (
+            CutSet.from_manifests(**manifests_ihm[split])
+            .trim_to_supervisions(keep_overlapping=False, keep_all_channels=False)
+            .compute_and_store_features_batch(
+                extractor=extractor,
+                storage_path=output_dir / f"feats_{split}_ihm",
+                manifest_path=src_dir / f"cuts_{split}_ihm.jsonl.gz",
+                batch_duration=5000,
+                num_workers=8,
+                storage_type=LilcomChunkyWriter,
+            )
+        )
+        logging.info(f"Processing {split} SDM")
+        cuts_sdm = (
+            CutSet.from_manifests(**manifests_sdm[split])
+            .trim_to_supervisions(keep_overlapping=False)
+            .compute_and_store_features_batch(
+                extractor=extractor,
+                storage_path=output_dir / f"feats_{split}_sdm",
+                manifest_path=src_dir / f"cuts_{split}_sdm.jsonl.gz",
+                batch_duration=500,
+                num_workers=4,
+                storage_type=LilcomChunkyWriter,
+            )
+        )
+        logging.info(f"Processing {split} GSS")
+        cuts_gss = (
+            CutSet.from_manifests(**manifests_gss[split])
+            .trim_to_supervisions(keep_overlapping=False)
+            .compute_and_store_features_batch(
+                extractor=extractor,
+                storage_path=output_dir / f"feats_{split}_gss",
+                manifest_path=src_dir / f"cuts_{split}_gss.jsonl.gz",
+                batch_duration=500,
+                num_workers=4,
+                storage_type=LilcomChunkyWriter,
+            )
+        )
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
+    compute_fbank_ami()
diff --git a/egs/ami/ASR/local/compute_fbank_musan.py b/egs/ami/ASR/local/compute_fbank_musan.py
new file mode 100755
index 000000000..1fcf951f9
--- /dev/null
+++ b/egs/ami/ASR/local/compute_fbank_musan.py
@@ -0,0 +1,114 @@
+#!/usr/bin/env python3
+# Copyright    2021  Xiaomi Corp.        (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file computes fbank features of the musan dataset.
+It looks for manifests in the directory data/manifests.
+
+The generated fbank features are saved in data/fbank.
+"""
+
+import logging
+from pathlib import Path
+
+import torch
+from lhotse import CutSet, LilcomChunkyWriter, combine
+from lhotse.features.kaldifeat import (
+    KaldifeatFbank,
+    KaldifeatFbankConfig,
+    KaldifeatFrameOptions,
+    KaldifeatMelOptions,
+)
+from lhotse.recipes.utils import read_manifests_if_cached
+
+# Torch's multithreaded behavior needs to be disabled or
+# it wastes a lot of CPU and slow things down.
+# Do this outside of main() in case it needs to take effect
+# even when we are not invoking the main (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+
+def compute_fbank_musan():
+    src_dir = Path("data/manifests")
+    output_dir = Path("data/fbank")
+
+    sampling_rate = 16000
+    num_mel_bins = 80
+
+    dataset_parts = (
+        "music",
+        "speech",
+        "noise",
+    )
+    prefix = "musan"
+    suffix = "jsonl.gz"
+    manifests = read_manifests_if_cached(
+        dataset_parts=dataset_parts,
+        output_dir=src_dir,
+        prefix=prefix,
+        suffix=suffix,
+    )
+    assert manifests is not None
+
+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
+
+    musan_cuts_path = src_dir / "musan_cuts.jsonl.gz"
+
+    if musan_cuts_path.is_file():
+        logging.info(f"{musan_cuts_path} already exists - skipping")
+        return
+
+    logging.info("Extracting features for Musan")
+
+    extractor = KaldifeatFbank(
+        KaldifeatFbankConfig(
+            frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate),
+            mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
+            device="cuda",
+        )
+    )
+
+    # create chunks of Musan with duration 5 - 10 seconds
+    _ = (
+        CutSet.from_manifests(
+            recordings=combine(part["recordings"] for part in manifests.values())
+        )
+        .cut_into_windows(10.0)
+        .filter(lambda c: c.duration > 5)
+        .compute_and_store_features_batch(
+            extractor=extractor,
+            storage_path=output_dir / "musan_feats",
+            manifest_path=musan_cuts_path,
+            batch_duration=500,
+            num_workers=4,
+            storage_type=LilcomChunkyWriter,
+        )
+    )
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    compute_fbank_musan()
diff --git a/egs/ami/ASR/local/prepare_ami_enhanced.py b/egs/ami/ASR/local/prepare_ami_enhanced.py
new file mode 100644
index 000000000..bed220eb3
--- /dev/null
+++ b/egs/ami/ASR/local/prepare_ami_enhanced.py
@@ -0,0 +1,158 @@
+#!/usr/local/bin/python
+# -*- coding: utf-8 -*-
+# Data preparation for AMI GSS-enhanced dataset.
+
+import logging
+from concurrent.futures import ThreadPoolExecutor
+from pathlib import Path
+
+from lhotse import Recording, RecordingSet, SupervisionSet
+from lhotse.qa import fix_manifests
+from lhotse.recipes.utils import read_manifests_if_cached
+from lhotse.utils import fastcopy
+from tqdm import tqdm
+
+logging.basicConfig(
+    format="%(asctime)s %(levelname)-8s %(message)s",
+    level=logging.INFO,
+    datefmt="%Y-%m-%d %H:%M:%S",
+)
+
+
+def get_args():
+    import argparse
+
+    parser = argparse.ArgumentParser(description="AMI enhanced dataset preparation.")
+    parser.add_argument(
+        "manifests_dir",
+        type=Path,
+        help="Path to directory containing AMI manifests.",
+    )
+    parser.add_argument(
+        "enhanced_dir",
+        type=Path,
+        help="Path to enhanced data directory.",
+    )
+    parser.add_argument(
+        "--num-jobs",
+        "-j",
+        type=int,
+        default=1,
+        help="Number of parallel jobs to run.",
+    )
+    parser.add_argument(
+        "--min-segment-duration",
+        "-d",
+        type=float,
+        default=0.0,
+        help="Minimum duration of a segment in seconds.",
+    )
+    return parser.parse_args()
+
+
+def find_recording_and_create_new_supervision(enhanced_dir, supervision):
+    """
+    Given a supervision (corresponding to original AMI recording), this function finds the
+    enhanced recording correspoding to the supervision, and returns this recording and
+    a new supervision whose start and end times are adjusted to match the enhanced recording.
+    """
+    file_name = Path(
+        f"{supervision.recording_id}-{supervision.speaker}-{int(100*supervision.start):06d}_{int(100*supervision.end):06d}.flac"
+    )
+    save_path = enhanced_dir / f"{supervision.recording_id}" / file_name
+    if save_path.exists():
+        recording = Recording.from_file(save_path)
+        if recording.duration == 0:
+            logging.warning(f"Skipping {save_path} which has duration 0 seconds.")
+            return None
+
+        # Old supervision is wrt to the original recording, we create new supervision
+        # wrt to the enhanced segment
+        new_supervision = fastcopy(
+            supervision,
+            recording_id=recording.id,
+            start=0,
+            duration=recording.duration,
+        )
+        return recording, new_supervision
+    else:
+        logging.warning(f"{save_path} does not exist.")
+        return None
+
+
+def main(args):
+    # Get arguments
+    manifests_dir = args.manifests_dir
+    enhanced_dir = args.enhanced_dir
+
+    # Load manifests from cache if they exist (saves time)
+    manifests = read_manifests_if_cached(
+        dataset_parts=["train", "dev", "test"],
+        output_dir=manifests_dir,
+        prefix="ami-sdm",
+        suffix="jsonl.gz",
+    )
+    if not manifests:
+        raise ValueError("AMI SDM manifests not found in {}".format(manifests_dir))
+
+    with ThreadPoolExecutor(args.num_jobs) as ex:
+        for part in ["train", "dev", "test"]:
+            logging.info(f"Processing {part}...")
+            supervisions_orig = manifests[part]["supervisions"].filter(
+                lambda s: s.duration >= args.min_segment_duration
+            )
+            # Remove TS3009d supervisions since they are not present in the enhanced data
+            supervisions_orig = supervisions_orig.filter(
+                lambda s: s.recording_id != "TS3009d"
+            )
+            futures = []
+
+            for supervision in tqdm(
+                supervisions_orig,
+                desc="Distributing tasks",
+            ):
+                futures.append(
+                    ex.submit(
+                        find_recording_and_create_new_supervision,
+                        enhanced_dir,
+                        supervision,
+                    )
+                )
+
+            recordings = []
+            supervisions = []
+            for future in tqdm(
+                futures,
+                total=len(futures),
+                desc="Processing tasks",
+            ):
+                result = future.result()
+                if result is not None:
+                    recording, new_supervision = result
+                    recordings.append(recording)
+                    supervisions.append(new_supervision)
+
+            # Remove duplicates from the recordings
+            recordings_nodup = {}
+            for recording in recordings:
+                if recording.id not in recordings_nodup:
+                    recordings_nodup[recording.id] = recording
+                else:
+                    logging.warning("Recording {} is duplicated.".format(recording.id))
+            recordings = RecordingSet.from_recordings(recordings_nodup.values())
+            supervisions = SupervisionSet.from_segments(supervisions)
+
+            recordings, supervisions = fix_manifests(
+                recordings=recordings, supervisions=supervisions
+            )
+
+            logging.info(f"Writing {part} enhanced manifests")
+            recordings.to_file(manifests_dir / f"ami-gss_recordings_{part}.jsonl.gz")
+            supervisions.to_file(
+                manifests_dir / f"ami-gss_supervisions_{part}.jsonl.gz"
+            )
+
+
+if __name__ == "__main__":
+    args = get_args()
+    main(args)
diff --git a/egs/ami/ASR/local/prepare_ami_gss.sh b/egs/ami/ASR/local/prepare_ami_gss.sh
new file mode 100755
index 000000000..d5422458b
--- /dev/null
+++ b/egs/ami/ASR/local/prepare_ami_gss.sh
@@ -0,0 +1,98 @@
+#!/bin/bash
+# This script is used to run GSS-based enhancement on AMI data.
+set -euo pipefail
+nj=4
+stage=0
+
+. shared/parse_options.sh || exit 1
+
+if [ $# != 2 ]; then
+   echo "Wrong #arguments ($#, expected 2)"
+   echo "Usage: local/prepare_ami_gss.sh [options]  "
+   echo "e.g. local/prepare_ami_gss.sh data/manifests exp/ami_gss"
+   echo "main options (for others, see top of script file)"
+   echo "  --nj                                 # number of parallel jobs"
+   echo "  --stage                           # stage to start running from"
+   exit 1;
+fi
+
+DATA_DIR=$1
+EXP_DIR=$2
+
+mkdir -p $EXP_DIR
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+if [ $stage -le 1 ]; then
+  log "Stage 1: Prepare cut sets"
+  for part in train dev test; do
+    lhotse cut simple \
+      -r $DATA_DIR/ami-mdm_recordings_${part}.jsonl.gz \
+      -s $DATA_DIR/ami-mdm_supervisions_${part}.jsonl.gz \
+      $EXP_DIR/cuts_${part}.jsonl.gz
+  done
+fi
+
+if [ $stage -le 2 ]; then
+  log "Stage 2: Trim cuts to supervisions (1 cut per supervision segment)"
+  for part in train dev test; do
+    lhotse cut trim-to-supervisions --discard-overlapping \
+        $EXP_DIR/cuts_${part}.jsonl.gz $EXP_DIR/cuts_per_segment_${part}.jsonl.gz
+  done
+fi
+
+if [ $stage -le 3 ]; then
+  log "Stage 3: Split manifests for multi-GPU processing (optional)"
+  for part in train; do
+    gss utils split $nj $EXP_DIR/cuts_per_segment_${part}.jsonl.gz \
+      $EXP_DIR/cuts_per_segment_${part}_split$nj
+  done
+fi
+
+if [ $stage -le 4 ]; then
+  log "Stage 4: Enhance train segments using GSS (requires GPU)"
+  # for train, we use smaller context and larger batches to speed-up processing
+  for JOB in $(seq $nj); do
+    gss enhance cuts $EXP_DIR/cuts_train.jsonl.gz \
+      $EXP_DIR/cuts_per_segment_train_split$nj/cuts_per_segment_train.JOB.jsonl.gz $EXP_DIR/enhanced \
+      --bss-iterations 10 \
+      --context-duration 5.0 \
+      --use-garbage-class \
+      --channels 0,1,2,3,4,5,6,7 \
+      --min-segment-length 0.05 \
+      --max-segment-length 35.0 \
+      --max-batch-duration 60.0 \
+      --num-buckets 3 \
+      --num-workers 2
+  done
+fi
+
+if [ $stage -le 5 ]; then
+  log "Stage 5: Enhance dev/test segments using GSS (using GPU)"
+  # for dev/test, we use larger context and smaller batches to get better quality
+  for part in dev test; do
+    for JOB in $(seq $nj); do
+      gss enhance cuts $EXP_DIR/cuts_${part}.jsonl.gz \
+      $EXP_DIR/cuts_per_segment_${part}_split$nj/cuts_per_segment_${part}.JOB.jsonl.gz \
+      $EXP_DIR/enhanced \
+      --bss-iterations 10 \
+      --context-duration 15.0 \
+      --use-garbage-class \
+      --channels 0,1,2,3,4,5,6,7 \
+      --min-segment-length 0.05 \
+      --max-segment-length 30.0 \
+      --max-batch-duration 45.0 \
+      --num-buckets 3 \
+      --num-workers 2
+    done
+  done
+fi
+
+if [ $stage -le 6 ]; then
+  log "Stage 6: Prepare manifests for GSS-enhanced data"
+  python local/prepare_ami_enhanced.py $DATA_DIR $EXP_DIR/enhanced -j $nj --min-segment-duration 0.05
+fi
diff --git a/egs/ami/ASR/local/prepare_lang_bpe.py b/egs/ami/ASR/local/prepare_lang_bpe.py
new file mode 120000
index 000000000..36b40e7fc
--- /dev/null
+++ b/egs/ami/ASR/local/prepare_lang_bpe.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/prepare_lang_bpe.py
\ No newline at end of file
diff --git a/egs/ami/ASR/local/train_bpe_model.py b/egs/ami/ASR/local/train_bpe_model.py
new file mode 120000
index 000000000..6fad36421
--- /dev/null
+++ b/egs/ami/ASR/local/train_bpe_model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/train_bpe_model.py
\ No newline at end of file
diff --git a/egs/ami/ASR/prepare.sh b/egs/ami/ASR/prepare.sh
new file mode 100755
index 000000000..fb21a8ec6
--- /dev/null
+++ b/egs/ami/ASR/prepare.sh
@@ -0,0 +1,144 @@
+#!/usr/bin/env bash
+
+set -eou pipefail
+
+stage=-1
+stop_stage=100
+use_gss=true  # Use GSS-based enhancement with MDM setting
+
+# We assume dl_dir (download dir) contains the following
+# directories and files. If not, they will be downloaded
+# by this script automatically.
+#
+#  - $dl_dir/amicorpus
+#      You can find audio and transcripts in this path.
+#
+#  - $dl_dir/musan
+#      This directory contains the following directories downloaded from
+#       http://www.openslr.org/17/
+#
+#     - music
+#     - noise
+#     - speech
+#
+#  - $dl_dir/{LDC2004S13,LDC2005S13,LDC2004T19,LDC2005T19}
+#      These contain the Fisher English audio and transcripts. We will
+#      only use the transcripts as extra LM training data (similar to Kaldi).
+#
+dl_dir=$PWD/download
+
+. shared/parse_options.sh || exit 1
+
+# All files generated by this script are saved in "data".
+# You can safely remove "data" and rerun this script to regenerate it.
+mkdir -p data
+vocab_size=500
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+log "dl_dir: $dl_dir"
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+  log "Stage 0: Download data"
+
+  # If you have pre-downloaded it to /path/to/amicorpus,
+  # you can create a symlink
+  #
+  #   ln -sfv /path/to/amicorpus $dl_dir/amicorpus
+  #
+  if [ ! -d $dl_dir/amicorpus ]; then
+    lhotse download ami --mic ihm $dl_dir/amicorpus
+    lhotse download ami --mic mdm $dl_dir/amicorpus
+  fi
+
+  # If you have pre-downloaded it to /path/to/musan,
+  # you can create a symlink
+  #
+  #   ln -sfv /path/to/musan $dl_dir/
+  #
+  if [ ! -d $dl_dir/musan ]; then
+    lhotse download musan $dl_dir
+  fi
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+  log "Stage 1: Prepare AMI manifests"
+  # We assume that you have downloaded the AMI corpus
+  # to $dl_dir/amicorpus. We perform text normalization for the transcripts.
+  mkdir -p data/manifests
+  for mic in ihm sdm mdm; do
+    lhotse prepare ami --mic $mic --partition full-corpus-asr --normalize-text kaldi \
+      --max-words-per-segment 30 $dl_dir/amicorpus data/manifests/
+  done
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+  log "Stage 2: Prepare musan manifest"
+  # We assume that you have downloaded the musan corpus
+  # to $dl_dir/musan
+  mkdir -p data/manifests
+  lhotse prepare musan $dl_dir/musan data/manifests
+fi
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ] && [ $use_gss = true ]; then
+  log "Stage 3: Apply GSS enhancement on MDM data (this stage requires a GPU)"
+  # We assume that you have installed the GSS package: https://github.com/desh2608/gss
+  local/prepare_ami_gss.sh data/manifests exp/ami_gss
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+  log "Stage 4: Compute fbank features for AMI"
+  mkdir -p data/fbank
+  python local/compute_fbank_ami.py
+  log "Combine features from train splits"
+  lhotse combine data/manifests/cuts_train_{ihm,ihm_rvb,sdm,gss}.jsonl.gz - | shuf |\
+    gzip -c > data/manifests/cuts_train_all.jsonl.gz
+fi
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+  log "Stage 5: Compute fbank features for musan"
+  mkdir -p data/fbank
+  python local/compute_fbank_musan.py
+fi
+
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+  log "Stage 6: Dump transcripts for BPE model training."
+  mkdir -p data/lm
+  cat <(gunzip -c data/manifests/ami-sdm_supervisions_train.jsonl.gz | jq '.text' | sed 's:"::g')> data/lm/transcript_words.txt
+fi
+
+if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+  log "Stage 7: Prepare BPE based lang"
+
+  lang_dir=data/lang_bpe_${vocab_size}
+  mkdir -p $lang_dir
+
+  # Add special words to words.txt
+  echo " 0" > $lang_dir/words.txt
+  echo "!SIL 1" >> $lang_dir/words.txt
+  echo " 2" >> $lang_dir/words.txt
+
+  # Add regular words to words.txt
+  cat data/lm/transcript_words.txt | grep -o -E '\w+' | sort -u | awk '{print $0,NR+2}' >> $lang_dir/words.txt
+
+  # Add remaining special word symbols expected by LM scripts.
+  num_words=$(cat $lang_dir/words.txt | wc -l)
+  echo " ${num_words}" >> $lang_dir/words.txt
+  num_words=$(cat $lang_dir/words.txt | wc -l)
+  echo " ${num_words}" >> $lang_dir/words.txt
+  num_words=$(cat $lang_dir/words.txt | wc -l)
+  echo "#0 ${num_words}" >> $lang_dir/words.txt
+
+  ./local/train_bpe_model.py \
+    --lang-dir $lang_dir \
+    --vocab-size $vocab_size \
+    --transcript data/lm/transcript_words.txt
+
+  if [ ! -f $lang_dir/L_disambig.pt ]; then
+    ./local/prepare_lang_bpe.py --lang-dir $lang_dir
+  fi
+fi
diff --git a/egs/ami/ASR/pruned_transducer_stateless7/__init__.py b/egs/ami/ASR/pruned_transducer_stateless7/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/egs/ami/ASR/pruned_transducer_stateless7/asr_datamodule.py b/egs/ami/ASR/pruned_transducer_stateless7/asr_datamodule.py
new file mode 100644
index 000000000..f7ee9c962
--- /dev/null
+++ b/egs/ami/ASR/pruned_transducer_stateless7/asr_datamodule.py
@@ -0,0 +1,430 @@
+# Copyright      2021  Piotr Żelasko
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import logging
+import re
+from functools import lru_cache
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+import torch
+from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
+from lhotse.cut import Cut
+from lhotse.dataset import (
+    CutConcatenate,
+    CutMix,
+    DynamicBucketingSampler,
+    K2SpeechRecognitionDataset,
+    PrecomputedFeatures,
+    SpecAugment,
+)
+from lhotse.dataset.input_strategies import OnTheFlyFeatures
+from lhotse.utils import fix_random_seed
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from icefall.utils import str2bool
+
+
+class _SeedWorkers:
+    def __init__(self, seed: int):
+        self.seed = seed
+
+    def __call__(self, worker_id: int):
+        fix_random_seed(self.seed + worker_id)
+
+
+class AmiAsrDataModule:
+    """
+    DataModule for k2 ASR experiments.
+    It assumes there is always one train and valid dataloader,
+    but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
+    and test-other).
+    It contains all the common data pipeline modules used in ASR
+    experiments, e.g.:
+    - dynamic batch size,
+    - bucketing samplers,
+    - cut concatenation,
+    - augmentation,
+    - on-the-fly feature extraction
+    This class should be derived for specific corpora used in ASR tasks.
+    """
+
+    def __init__(self, args: argparse.Namespace):
+        self.args = args
+
+    @classmethod
+    def add_arguments(cls, parser: argparse.ArgumentParser):
+        group = parser.add_argument_group(
+            title="ASR data related options",
+            description=(
+                "These options are used for the preparation of "
+                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+                "effective batch sizes, sampling strategies, applied data "
+                "augmentations, etc."
+            ),
+        )
+        group.add_argument(
+            "--manifest-dir",
+            type=Path,
+            default=Path("data/manifests"),
+            help="Path to directory with train/valid/test cuts.",
+        )
+        group.add_argument(
+            "--enable-musan",
+            type=str2bool,
+            default=True,
+            help=(
+                "When enabled, select noise from MUSAN and mix it "
+                "with training dataset. "
+            ),
+        )
+        group.add_argument(
+            "--concatenate-cuts",
+            type=str2bool,
+            default=False,
+            help=(
+                "When enabled, utterances (cuts) will be concatenated "
+                "to minimize the amount of padding."
+            ),
+        )
+        group.add_argument(
+            "--duration-factor",
+            type=float,
+            default=1.0,
+            help=(
+                "Determines the maximum duration of a concatenated cut "
+                "relative to the duration of the longest cut in a batch."
+            ),
+        )
+        group.add_argument(
+            "--gap",
+            type=float,
+            default=1.0,
+            help=(
+                "The amount of padding (in seconds) inserted between "
+                "concatenated cuts. This padding is filled with noise when "
+                "noise augmentation is used."
+            ),
+        )
+        group.add_argument(
+            "--max-duration",
+            type=int,
+            default=100.0,
+            help=(
+                "Maximum pooled recordings duration (seconds) in a "
+                "single batch. You can reduce it if it causes CUDA OOM."
+            ),
+        )
+        group.add_argument(
+            "--max-cuts", type=int, default=None, help="Maximum cuts in a single batch."
+        )
+        group.add_argument(
+            "--num-buckets",
+            type=int,
+            default=50,
+            help=(
+                "The number of buckets for the BucketingSampler"
+                "(you might want to increase it for larger datasets)."
+            ),
+        )
+        group.add_argument(
+            "--on-the-fly-feats",
+            type=str2bool,
+            default=False,
+            help=(
+                "When enabled, use on-the-fly cut mixing and feature "
+                "extraction. Will drop existing precomputed feature manifests "
+                "if available."
+            ),
+        )
+        group.add_argument(
+            "--shuffle",
+            type=str2bool,
+            default=True,
+            help=(
+                "When enabled (=default), the examples will be "
+                "shuffled for each epoch."
+            ),
+        )
+
+        group.add_argument(
+            "--num-workers",
+            type=int,
+            default=8,
+            help=(
+                "The number of training dataloader workers that " "collect the batches."
+            ),
+        )
+        group.add_argument(
+            "--enable-spec-aug",
+            type=str2bool,
+            default=True,
+            help="When enabled, use SpecAugment for training dataset.",
+        )
+        group.add_argument(
+            "--spec-aug-time-warp-factor",
+            type=int,
+            default=80,
+            help=(
+                "Used only when --enable-spec-aug is True. "
+                "It specifies the factor for time warping in SpecAugment. "
+                "Larger values mean more warping. "
+                "A value less than 1 means to disable time warp."
+            ),
+        )
+        group.add_argument(
+            "--ihm-only",
+            type=str2bool,
+            default=False,
+            help="When enabled, only use IHM data for training.",
+        )
+
+    def train_dataloaders(
+        self,
+        cuts_train: CutSet,
+        sampler_state_dict: Optional[Dict[str, Any]] = None,
+    ) -> DataLoader:
+        """
+        Args:
+          cuts_train:
+            CutSet for training.
+          sampler_state_dict:
+            The state dict for the training sampler.
+        """
+        logging.info("About to get Musan cuts")
+
+        transforms = []
+        if self.args.enable_musan:
+            logging.info("Enable MUSAN")
+            cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
+            transforms.append(
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+            )
+        else:
+            logging.info("Disable MUSAN")
+
+        if self.args.concatenate_cuts:
+            logging.info(
+                "Using cut concatenation with duration factor "
+                f"{self.args.duration_factor} and gap {self.args.gap}."
+            )
+            # Cut concatenation should be the first transform in the list,
+            # so that if we e.g. mix noise in, it will fill the gaps between
+            # different utterances.
+            transforms = [
+                CutConcatenate(
+                    duration_factor=self.args.duration_factor, gap=self.args.gap
+                )
+            ] + transforms
+
+        input_transforms = []
+        if self.args.enable_spec_aug:
+            logging.info("Enable SpecAugment")
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+            input_transforms.append(
+                SpecAugment(
+                    time_warp_factor=self.args.spec_aug_time_warp_factor,
+                    num_frame_masks=2,
+                    features_mask_size=27,
+                    num_feature_masks=2,
+                    frames_mask_size=100,
+                )
+            )
+        else:
+            logging.info("Disable SpecAugment")
+
+        logging.info("About to create train dataset")
+        if self.args.on_the_fly_feats:
+            train = K2SpeechRecognitionDataset(
+                cut_transforms=transforms,
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_transforms=input_transforms,
+            )
+        else:
+            train = K2SpeechRecognitionDataset(
+                cut_transforms=transforms,
+                input_transforms=input_transforms,
+            )
+
+        logging.info("Using DynamicBucketingSampler.")
+        train_sampler = DynamicBucketingSampler(
+            cuts_train,
+            max_duration=self.args.max_duration,
+            max_cuts=self.args.max_cuts,
+            shuffle=False,
+            num_buckets=self.args.num_buckets,
+            drop_last=True,
+        )
+        logging.info("About to create train dataloader")
+
+        if sampler_state_dict is not None:
+            logging.info("Loading sampler state dict")
+            train_sampler.load_state_dict(sampler_state_dict)
+
+        # 'seed' is derived from the current random state, which will have
+        # previously been set in the main process.
+        seed = torch.randint(0, 100000, ()).item()
+        worker_init_fn = _SeedWorkers(seed)
+
+        train_dl = DataLoader(
+            train,
+            sampler=train_sampler,
+            batch_size=None,
+            num_workers=self.args.num_workers,
+            persistent_workers=False,
+            worker_init_fn=worker_init_fn,
+        )
+
+        return train_dl
+
+    def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
+
+        transforms = []
+        if self.args.concatenate_cuts:
+            transforms = [
+                CutConcatenate(
+                    duration_factor=self.args.duration_factor, gap=self.args.gap
+                )
+            ] + transforms
+
+        logging.info("About to create dev dataset")
+        if self.args.on_the_fly_feats:
+            validate = K2SpeechRecognitionDataset(
+                cut_transforms=transforms,
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+            )
+        else:
+            validate = K2SpeechRecognitionDataset(
+                cut_transforms=transforms,
+            )
+        valid_sampler = DynamicBucketingSampler(
+            cuts_valid,
+            max_duration=self.args.max_duration,
+            shuffle=False,
+        )
+        logging.info("About to create dev dataloader")
+        valid_dl = DataLoader(
+            validate,
+            sampler=valid_sampler,
+            batch_size=None,
+            num_workers=2,
+            persistent_workers=False,
+        )
+
+        return valid_dl
+
+    def test_dataloaders(self, cuts: CutSet) -> DataLoader:
+        logging.debug("About to create test dataset")
+        test = K2SpeechRecognitionDataset(
+            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
+            if self.args.on_the_fly_feats
+            else PrecomputedFeatures(),
+            return_cuts=True,
+        )
+        sampler = DynamicBucketingSampler(
+            cuts, max_duration=self.args.max_duration, shuffle=False
+        )
+        logging.debug("About to create test dataloader")
+        test_dl = DataLoader(
+            test,
+            batch_size=None,
+            sampler=sampler,
+            num_workers=self.args.num_workers,
+        )
+        return test_dl
+
+    def remove_short_cuts(self, cut: Cut) -> bool:
+        """
+        See: https://github.com/k2-fsa/icefall/issues/500
+        Basically, the zipformer model subsamples the input using the following formula:
+        num_out_frames = (num_in_frames - 7)//2
+        For num_out_frames to be at least 1, num_in_frames must be at least 9.
+        """
+        return cut.duration >= 0.09
+
+    @lru_cache()
+    def train_cuts(self, sp: Optional[Any] = None) -> CutSet:
+        logging.info("About to get AMI train cuts")
+
+        def _remove_short_and_long_utt(c: Cut):
+            if c.duration < 0.2 or c.duration > 25.0:
+                return False
+
+            # In pruned RNN-T, we require that T >= S
+            # where T is the number of feature frames after subsampling
+            # and S is the number of tokens in the utterance
+
+            # In ./zipformer.py, the conv module uses the following expression
+            # for subsampling
+            T = ((c.num_frames - 7) // 2 + 1) // 2
+            tokens = sp.encode(c.supervisions[0].text, out_type=str)
+            return T >= len(tokens)
+
+        if self.args.ihm_only:
+            cuts_train = load_manifest_lazy(
+                self.args.manifest_dir / "cuts_train_ihm.jsonl.gz"
+            )
+        else:
+            cuts_train = load_manifest_lazy(
+                self.args.manifest_dir / "cuts_train_all.jsonl.gz"
+            )
+
+        return cuts_train.filter(_remove_short_and_long_utt)
+
+    @lru_cache()
+    def dev_ihm_cuts(self) -> CutSet:
+        logging.info("About to get AMI IHM dev cuts")
+        cs = load_manifest_lazy(self.args.manifest_dir / "cuts_dev_ihm.jsonl.gz")
+        return cs.filter(self.remove_short_cuts)
+
+    @lru_cache()
+    def dev_sdm_cuts(self) -> CutSet:
+        logging.info("About to get AMI SDM dev cuts")
+        cs = load_manifest_lazy(self.args.manifest_dir / "cuts_dev_sdm.jsonl.gz")
+        return cs.filter(self.remove_short_cuts)
+
+    @lru_cache()
+    def dev_gss_cuts(self) -> CutSet:
+        if not (self.args.manifest_dir / "cuts_dev_gss.jsonl.gz").exists():
+            logging.info("No GSS dev cuts found")
+            return None
+        logging.info("About to get AMI GSS-enhanced dev cuts")
+        cs = load_manifest_lazy(self.args.manifest_dir / "cuts_dev_gss.jsonl.gz")
+        return cs.filter(self.remove_short_cuts)
+
+    @lru_cache()
+    def test_ihm_cuts(self) -> CutSet:
+        logging.info("About to get AMI IHM test cuts")
+        cs = load_manifest_lazy(self.args.manifest_dir / "cuts_test_ihm.jsonl.gz")
+        return cs.filter(self.remove_short_cuts)
+
+    @lru_cache()
+    def test_sdm_cuts(self) -> CutSet:
+        logging.info("About to get AMI SDM test cuts")
+        cs = load_manifest_lazy(self.args.manifest_dir / "cuts_test_sdm.jsonl.gz")
+        return cs.filter(self.remove_short_cuts)
+
+    @lru_cache()
+    def test_gss_cuts(self) -> CutSet:
+        if not (self.args.manifest_dir / "cuts_test_gss.jsonl.gz").exists():
+            logging.info("No GSS test cuts found")
+            return None
+        logging.info("About to get AMI GSS-enhanced test cuts")
+        cs = load_manifest_lazy(self.args.manifest_dir / "cuts_test_gss.jsonl.gz")
+        return cs.filter(self.remove_short_cuts)
diff --git a/egs/ami/ASR/pruned_transducer_stateless7/beam_search.py b/egs/ami/ASR/pruned_transducer_stateless7/beam_search.py
new file mode 120000
index 000000000..37516affc
--- /dev/null
+++ b/egs/ami/ASR/pruned_transducer_stateless7/beam_search.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/beam_search.py
\ No newline at end of file
diff --git a/egs/ami/ASR/pruned_transducer_stateless7/decode.py b/egs/ami/ASR/pruned_transducer_stateless7/decode.py
new file mode 100755
index 000000000..f47228fbe
--- /dev/null
+++ b/egs/ami/ASR/pruned_transducer_stateless7/decode.py
@@ -0,0 +1,747 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) greedy search
+./pruned_transducer_stateless7/decode.py \
+        --iter 105000 \
+        --avg 10 \
+        --exp-dir ./pruned_transducer_stateless7/exp \
+        --max-duration 100 \
+        --decoding-method greedy_search
+
+(2) beam search
+./pruned_transducer_stateless7/decode.py \
+        --iter 105000 \
+        --avg 10 \
+        --exp-dir ./pruned_transducer_stateless7/exp \
+        --max-duration 500 \
+        --decoding-method beam_search \
+        --beam-size 4
+
+(3) modified beam search
+./pruned_transducer_stateless7/decode.py \
+        --iter 105000 \
+        --avg 10 \
+        --exp-dir ./pruned_transducer_stateless7/exp \
+        --max-duration 500 \
+        --decoding-method modified_beam_search \
+        --beam-size 4
+
+(4) fast beam search
+./pruned_transducer_stateless7/decode.py \
+        --iter 105000 \
+        --avg 10 \
+        --exp-dir ./pruned_transducer_stateless5/exp \
+        --max-duration 500 \
+        --decoding-method fast_beam_search \
+        --beam 4 \
+        --max-contexts 4 \
+        --max-states 8
+"""
+
+
+import argparse
+import logging
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import AmiAsrDataModule
+from beam_search import (
+    beam_search,
+    fast_beam_search_nbest_LG,
+    fast_beam_search_one_best,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+)
+from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall import NgramLm
+from icefall.checkpoint import (
+    average_checkpoints,
+    average_checkpoints_with_averaged_model,
+    find_checkpoints,
+    load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+    AttributeDict,
+    setup_logger,
+    store_transcripts,
+    str2bool,
+    write_error_stats,
+)
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=30,
+        help="""It specifies the checkpoint to use for decoding.
+        Note: Epoch counts from 0.
+        You can specify --avg to use more checkpoints for model averaging.""",
+    )
+
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=10,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
+    )
+
+    parser.add_argument(
+        "--use-averaged-model",
+        type=str2bool,
+        default=True,
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="pruned_transducer_stateless2/exp",
+        help="The experiment dir",
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=Path,
+        default="data/lang_bpe_500",
+        help="The lang dir containing word table and LG graph",
+    )
+
+    parser.add_argument(
+        "--decoding-method",
+        type=str,
+        default="greedy_search",
+        help="""Possible values are:
+          - greedy_search
+          - beam_search
+          - modified_beam_search
+          - fast_beam_search
+          - fast_beam_search_nbest
+          - fast_beam_search_nbest_oracle
+          - fast_beam_search_nbest_LG
+        If you use fast_beam_search_nbest_LG, you have to specify
+        `--lang-dir`, which should contain `LG.pt`.
+        """,
+    )
+
+    parser.add_argument(
+        "--beam-size",
+        type=int,
+        default=4,
+        help="""An interger indicating how many candidates we will keep for each
+        frame. Used only when --decoding-method is beam_search or
+        modified_beam_search.""",
+    )
+
+    parser.add_argument(
+        "--beam",
+        type=float,
+        default=4,
+        help="""A floating point value to calculate the cutoff score during beam
+        search (i.e., `cutoff = max-score - beam`), which is the same as the
+        `beam` in Kaldi.
+        Used only when --decoding-method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--ngram-lm-scale",
+        type=float,
+        default=0.01,
+        help="""
+        Used only when --decoding_method is fast_beam_search_nbest_LG.
+        It specifies the scale for n-gram LM scores.
+        """,
+    )
+
+    parser.add_argument(
+        "--max-contexts",
+        type=int,
+        default=8,
+        help="""Used only when --decoding-method is
+        fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+        and fast_beam_search_nbest_oracle""",
+    )
+
+    parser.add_argument(
+        "--max-states",
+        type=int,
+        default=64,
+        help="""Used only when --decoding-method is
+        fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+        and fast_beam_search_nbest_oracle""",
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+    )
+    parser.add_argument(
+        "--max-sym-per-frame",
+        type=int,
+        default=1,
+        help="""Maximum number of symbols per frame.
+        Used only when --decoding_method is greedy_search""",
+    )
+
+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=200,
+        help="""Number of paths for nbest decoding.
+        Used only when the decoding method is fast_beam_search_nbest,
+        fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+    )
+
+    parser.add_argument(
+        "--nbest-scale",
+        type=float,
+        default=0.5,
+        help="""Scale applied to lattice scores when computing nbest paths.
+        Used only when the decoding method is fast_beam_search_nbest,
+        fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def decode_one_batch(
+    params: AttributeDict,
+    model: nn.Module,
+    sp: spm.SentencePieceProcessor,
+    batch: dict,
+    decoding_graph: Optional[k2.Fsa] = None,
+    word_table: Optional[k2.SymbolTable] = None,
+) -> Dict[str, List[List[str]]]:
+    """Decode one batch and return the result in a dict. The dict has the
+    following format:
+
+        - key: It indicates the setting used for decoding. For example,
+               if greedy_search is used, it would be "greedy_search"
+               If beam search with a beam size of 7 is used, it would be
+               "beam_7"
+        - value: It contains the decoding result. `len(value)` equals to
+                 batch size. `value[i]` is the decoding result for the i-th
+                 utterance in the given batch.
+    Args:
+      params:
+        It's the return value of :func:`get_params`.
+      model:
+        The neural model.
+      sp:
+        The BPE model.
+      batch:
+        It is the return value from iterating
+        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+        for the format of the `batch`.
+      decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
+        only when --decoding_method is fast_beam_search.
+      word_table:
+        The word symbol table.
+    Returns:
+      Return the decoding result. See above description for the format of
+      the returned dict.
+    """
+    device = model.device
+    feature = batch["inputs"]
+    assert feature.ndim == 3
+
+    feature = feature.to(device)
+    # at entry, feature is (N, T, C)
+
+    supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)
+
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    hyps = []
+
+    if params.decoding_method == "fast_beam_search":
+        hyp_tokens = fast_beam_search_one_best(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.decoding_method == "fast_beam_search_nbest_LG":
+        hyp_tokens = fast_beam_search_nbest_LG(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+            num_paths=params.num_paths,
+            nbest_scale=params.nbest_scale,
+        )
+        for hyp in hyp_tokens:
+            hyps.append([word_table[i] for i in hyp])
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+        hyp_tokens = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.decoding_method == "modified_beam_search":
+        hyp_tokens = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam_size,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    else:
+        batch_size = encoder_out.size(0)
+
+        for i in range(batch_size):
+            # fmt: off
+            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+            # fmt: on
+            if params.decoding_method == "greedy_search":
+                hyp = greedy_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    max_sym_per_frame=params.max_sym_per_frame,
+                )
+            elif params.decoding_method == "beam_search":
+                hyp = beam_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
+                )
+            else:
+                raise ValueError(
+                    f"Unsupported decoding method: {params.decoding_method}"
+                )
+            hyps.append(sp.decode(hyp).split())
+
+    if params.decoding_method == "greedy_search":
+        return {"greedy_search": hyps}
+    elif params.decoding_method == "fast_beam_search":
+        return {
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}"
+            ): hyps
+        }
+    elif "fast_beam_search" in params.decoding_method:
+        key = f"beam_{params.beam}_"
+        key += f"max_contexts_{params.max_contexts}_"
+        key += f"max_states_{params.max_states}"
+        if "nbest" in params.decoding_method:
+            key += f"_num_paths_{params.num_paths}_"
+            key += f"nbest_scale_{params.nbest_scale}"
+            if "LG" in params.decoding_method:
+                key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
+
+        return {key: hyps}
+    else:
+        return {f"beam_size_{params.beam_size}": hyps}
+
+
+def decode_dataset(
+    dl: torch.utils.data.DataLoader,
+    params: AttributeDict,
+    model: nn.Module,
+    sp: spm.SentencePieceProcessor,
+    decoding_graph: Optional[k2.Fsa] = None,
+    word_table: Optional[k2.SymbolTable] = None,
+) -> Dict[str, List[Tuple[List[str], List[str]]]]:
+    """Decode dataset.
+
+    Args:
+      dl:
+        PyTorch's dataloader containing the dataset to decode.
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The neural model.
+      sp:
+        The BPE model.
+      decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
+        only when --decoding_method is fast_beam_search.
+    Returns:
+      Return a dict, whose key may be "greedy_search" if greedy search
+      is used, or it may be "beam_7" if beam size of 7 is used.
+      Its value is a list of tuples. Each tuple contains two elements:
+      The first is the reference transcript, and the second is the
+      predicted result.
+    """
+    num_cuts = 0
+
+    try:
+        num_batches = len(dl)
+    except TypeError:
+        num_batches = "?"
+
+    if params.decoding_method == "greedy_search":
+        log_interval = 100
+    else:
+        log_interval = 2
+
+    results = defaultdict(list)
+    for batch_idx, batch in enumerate(dl):
+        texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+        hyps_dict = decode_one_batch(
+            params=params,
+            model=model,
+            sp=sp,
+            decoding_graph=decoding_graph,
+            word_table=word_table,
+            batch=batch,
+        )
+
+        for name, hyps in hyps_dict.items():
+            this_batch = []
+            assert len(hyps) == len(texts)
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+                ref_words = ref_text.split()
+                this_batch.append((cut_id, ref_words, hyp_words))
+
+            results[name].extend(this_batch)
+
+        num_cuts += len(texts)
+
+        if batch_idx % log_interval == 0:
+            batch_str = f"{batch_idx}/{num_batches}"
+
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+    return results
+
+
+def save_results(
+    params: AttributeDict,
+    test_set_name: str,
+    results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
+):
+    test_set_wers = dict()
+    test_set_cers = dict()
+    for key, results in results_dict.items():
+        recog_path = (
+            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
+        store_transcripts(filename=recog_path, texts=results)
+        logging.info(f"The transcripts are stored in {recog_path}")
+
+        # The following prints out WERs, per-word error statistics and aligned
+        # ref/hyp pairs.
+        wers_filename = (
+            params.res_dir / f"wers-{test_set_name}-{key}-{params.suffix}.txt"
+        )
+        with open(wers_filename, "w") as f:
+            wer = write_error_stats(
+                f, f"{test_set_name}-{key}", results, enable_log=True
+            )
+            test_set_wers[key] = wer
+
+        # we also compute CER for AMI dataset.
+        results_char = []
+        for res in results:
+            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
+        cers_filename = (
+            params.res_dir / f"cers-{test_set_name}-{key}-{params.suffix}.txt"
+        )
+        with open(cers_filename, "w") as f:
+            cer = write_error_stats(
+                f, f"{test_set_name}-{key}", results_char, enable_log=True
+            )
+            test_set_cers[key] = cer
+
+        logging.info("Wrote detailed error stats to {}".format(wers_filename))
+
+    test_set_wers = {k: v for k, v in sorted(test_set_wers.items(), key=lambda x: x[1])}
+    test_set_cers = {k: v for k, v in sorted(test_set_cers.items(), key=lambda x: x[1])}
+    errs_info = (
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+    )
+    with open(errs_info, "w") as f:
+        print("settings\tWER\tCER", file=f)
+        for key in test_set_wers:
+            print(
+                "{}\t{}\t{}".format(key, test_set_wers[key], test_set_cers[key]),
+                file=f,
+            )
+
+    s = "\nFor {}, WER/CER of different settings are:\n".format(test_set_name)
+    note = "\tbest for {}".format(test_set_name)
+    for key in test_set_wers:
+        s += "{}\t{}\t{}{}\n".format(key, test_set_wers[key], test_set_cers[key], note)
+        note = ""
+    logging.info(s)
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    AmiAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    assert params.decoding_method in (
+        "greedy_search",
+        "beam_search",
+        "fast_beam_search",
+        "fast_beam_search_nbest_LG",
+        "modified_beam_search",
+    )
+    params.res_dir = params.exp_dir / params.decoding_method
+
+    if params.iter > 0:
+        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+    else:
+        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+    if "fast_beam_search" in params.decoding_method:
+        params.suffix += f"-beam-{params.beam}"
+        params.suffix += f"-max-contexts-{params.max_contexts}"
+        params.suffix += f"-max-states-{params.max_states}"
+        if "nbest" in params.decoding_method:
+            params.suffix += f"-nbest-scale-{params.nbest_scale}"
+            params.suffix += f"-num-paths-{params.num_paths}"
+            if "LG" in params.decoding_method:
+                params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
+    elif "beam_search" in params.decoding_method:
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+    else:
+        params.suffix += f"-context-{params.context_size}"
+        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+    logging.info("Decoding started")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(f"{params.lang_dir}/bpe.model")
+
+    #  is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("")
+    params.unk_id = sp.piece_to_id("")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+
+    model.to(device)
+    model.eval()
+    model.device = device
+
+    if "fast_beam_search" in params.decoding_method:
+        if params.decoding_method == "fast_beam_search_nbest_LG":
+            lexicon = Lexicon(params.lang_dir)
+            word_table = lexicon.word_table
+            lg_filename = params.lang_dir / "LG.pt"
+            logging.info(f"Loading {lg_filename}")
+            decoding_graph = k2.Fsa.from_dict(
+                torch.load(lg_filename, map_location=device)
+            )
+            decoding_graph.scores *= params.ngram_lm_scale
+        else:
+            word_table = None
+            decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+    else:
+        decoding_graph = None
+        word_table = None
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    ami = AmiAsrDataModule(args)
+
+    dev_ihm_cuts = ami.dev_ihm_cuts()
+    test_ihm_cuts = ami.test_ihm_cuts()
+    dev_sdm_cuts = ami.dev_sdm_cuts()
+    test_sdm_cuts = ami.test_sdm_cuts()
+    dev_gss_cuts = ami.dev_gss_cuts()
+    test_gss_cuts = ami.test_gss_cuts()
+
+    dev_ihm_dl = ami.test_dataloaders(dev_ihm_cuts)
+    test_ihm_dl = ami.test_dataloaders(test_ihm_cuts)
+    dev_sdm_dl = ami.test_dataloaders(dev_sdm_cuts)
+    test_sdm_dl = ami.test_dataloaders(test_sdm_cuts)
+    if dev_gss_cuts is not None:
+        dev_gss_dl = ami.test_dataloaders(dev_gss_cuts)
+    if test_gss_cuts is not None:
+        test_gss_dl = ami.test_dataloaders(test_gss_cuts)
+
+    test_sets = {
+        "dev_ihm": (dev_ihm_dl, dev_ihm_cuts),
+        "test_ihm": (test_ihm_dl, test_ihm_cuts),
+        "dev_sdm": (dev_sdm_dl, dev_sdm_cuts),
+        "test_sdm": (test_sdm_dl, test_sdm_cuts),
+    }
+    if dev_gss_cuts is not None:
+        test_sets["dev_gss"] = (dev_gss_dl, dev_gss_cuts)
+    if test_gss_cuts is not None:
+        test_sets["test_gss"] = (test_gss_dl, test_gss_cuts)
+
+    for test_set in test_sets:
+        logging.info(f"Decoding {test_set}")
+        dl, cuts = test_sets[test_set]
+        results_dict = decode_dataset(
+            dl=dl,
+            params=params,
+            model=model,
+            sp=sp,
+            word_table=word_table,
+            decoding_graph=decoding_graph,
+        )
+
+        save_results(
+            params=params,
+            test_set_name=test_set,
+            results_dict=results_dict,
+        )
+
+    logging.info("Done!")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/ami/ASR/pruned_transducer_stateless7/decoder.py b/egs/ami/ASR/pruned_transducer_stateless7/decoder.py
new file mode 120000
index 000000000..8283d8c5a
--- /dev/null
+++ b/egs/ami/ASR/pruned_transducer_stateless7/decoder.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/decoder.py
\ No newline at end of file
diff --git a/egs/ami/ASR/pruned_transducer_stateless7/encoder_interface.py b/egs/ami/ASR/pruned_transducer_stateless7/encoder_interface.py
new file mode 120000
index 000000000..0c2673d46
--- /dev/null
+++ b/egs/ami/ASR/pruned_transducer_stateless7/encoder_interface.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/encoder_interface.py
\ No newline at end of file
diff --git a/egs/ami/ASR/pruned_transducer_stateless7/export.py b/egs/ami/ASR/pruned_transducer_stateless7/export.py
new file mode 120000
index 000000000..2713792e6
--- /dev/null
+++ b/egs/ami/ASR/pruned_transducer_stateless7/export.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/export.py
\ No newline at end of file
diff --git a/egs/ami/ASR/pruned_transducer_stateless7/joiner.py b/egs/ami/ASR/pruned_transducer_stateless7/joiner.py
new file mode 120000
index 000000000..0f0c3c90a
--- /dev/null
+++ b/egs/ami/ASR/pruned_transducer_stateless7/joiner.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/joiner.py
\ No newline at end of file
diff --git a/egs/ami/ASR/pruned_transducer_stateless7/model.py b/egs/ami/ASR/pruned_transducer_stateless7/model.py
new file mode 120000
index 000000000..0d8bc665b
--- /dev/null
+++ b/egs/ami/ASR/pruned_transducer_stateless7/model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/model.py
\ No newline at end of file
diff --git a/egs/ami/ASR/pruned_transducer_stateless7/optim.py b/egs/ami/ASR/pruned_transducer_stateless7/optim.py
new file mode 120000
index 000000000..8a05abb5f
--- /dev/null
+++ b/egs/ami/ASR/pruned_transducer_stateless7/optim.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/optim.py
\ No newline at end of file
diff --git a/egs/ami/ASR/pruned_transducer_stateless7/scaling.py b/egs/ami/ASR/pruned_transducer_stateless7/scaling.py
new file mode 120000
index 000000000..5f9be9fe0
--- /dev/null
+++ b/egs/ami/ASR/pruned_transducer_stateless7/scaling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/scaling.py
\ No newline at end of file
diff --git a/egs/ami/ASR/pruned_transducer_stateless7/scaling_converter.py b/egs/ami/ASR/pruned_transducer_stateless7/scaling_converter.py
new file mode 120000
index 000000000..f9960e5c6
--- /dev/null
+++ b/egs/ami/ASR/pruned_transducer_stateless7/scaling_converter.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py
\ No newline at end of file
diff --git a/egs/ami/ASR/pruned_transducer_stateless7/train.py b/egs/ami/ASR/pruned_transducer_stateless7/train.py
new file mode 100755
index 000000000..b5efb3405
--- /dev/null
+++ b/egs/ami/ASR/pruned_transducer_stateless7/train.py
@@ -0,0 +1,1184 @@
+#!/usr/bin/env python3
+# Copyright    2021-2022  Xiaomi Corp.        (authors: Fangjun Kuang,
+#                                                       Wei Kang,
+#                                                       Mingshuang Luo,)
+#                                                       Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./pruned_transducer_stateless7/train.py \
+  --world-size 4 \
+  --num-epochs 15 \
+  --start-epoch 1 \
+  --exp-dir pruned_transducer_stateless7/exp \
+  --max-duration 150 \
+    --use-fp16 True
+
+"""
+
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import AmiAsrDataModule
+from decoder import Decoder
+from joiner import Joiner
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import Transducer
+from optim import Eden, ScaledAdam
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+from zipformer import Zipformer
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+    save_checkpoint_with_global_batch_idx,
+    update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+    if isinstance(model, DDP):
+        # get underlying nn.Module
+        model = model.module
+    for module in model.modules():
+        if hasattr(module, "batch_count"):
+            module.batch_count = batch_count
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+    parser.add_argument(
+        "--num-encoder-layers",
+        type=str,
+        default="2,4,3,2,4",
+        help="Number of zipformer encoder layers, comma separated.",
+    )
+
+    parser.add_argument(
+        "--feedforward-dims",
+        type=str,
+        default="1024,1024,2048,2048,1024",
+        help="Feedforward dimension of the zipformer encoder layers, comma separated.",
+    )
+
+    parser.add_argument(
+        "--nhead",
+        type=str,
+        default="8,8,8,8,8",
+        help="Number of attention heads in the zipformer encoder layers.",
+    )
+
+    parser.add_argument(
+        "--encoder-dims",
+        type=str,
+        default="384,384,384,384,384",
+        help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated",
+    )
+
+    parser.add_argument(
+        "--attention-dims",
+        type=str,
+        default="192,192,192,192,192",
+        help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated;
+        not the same as embedding dimension.""",
+    )
+
+    parser.add_argument(
+        "--encoder-unmasked-dims",
+        type=str,
+        default="256,256,256,256,256",
+        help="Unmasked dimensions in the encoders, relates to augmentation during training.  "
+        "Must be <= each of encoder_dims.  Empirically, less than 256 seems to make performance "
+        " worse.",
+    )
+
+    parser.add_argument(
+        "--zipformer-downsampling-factors",
+        type=str,
+        default="1,2,4,8,2",
+        help="Downsampling factor for each stack of encoder layers.",
+    )
+
+    parser.add_argument(
+        "--cnn-module-kernels",
+        type=str,
+        default="31,31,31,31,31",
+        help="Sizes of kernels in convolution modules",
+    )
+
+    parser.add_argument(
+        "--decoder-dim",
+        type=int,
+        default=512,
+        help="Embedding dimension in the decoder model.",
+    )
+
+    parser.add_argument(
+        "--joiner-dim",
+        type=int,
+        default=512,
+        help="""Dimension used in the joiner model.
+        Outputs from the encoder and decoder model are projected
+        to this dimension before adding.
+        """,
+    )
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--world-size",
+        type=int,
+        default=1,
+        help="Number of GPUs for DDP training.",
+    )
+
+    parser.add_argument(
+        "--master-port",
+        type=int,
+        default=12354,
+        help="Master port to use for DDP training.",
+    )
+
+    parser.add_argument(
+        "--tensorboard",
+        type=str2bool,
+        default=True,
+        help="Should various information be logged in tensorboard.",
+    )
+
+    parser.add_argument(
+        "--num-epochs",
+        type=int,
+        default=11,
+        help="Number of epochs to train.",
+    )
+
+    parser.add_argument(
+        "--start-epoch",
+        type=int,
+        default=1,
+        help="""Resume training from this epoch. It should be positive.
+        If larger than 1, it will load checkpoint from
+        exp-dir/epoch-{start_epoch-1}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--start-batch",
+        type=int,
+        default=0,
+        help="""If positive, --start-epoch is ignored and
+        it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="pruned_transducer_stateless7/exp",
+        help="""The experiment dir.
+        It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="data/lang_bpe_500/bpe.model",
+        help="Path to the BPE model",
+    )
+
+    parser.add_argument(
+        "--base-lr", type=float, default=0.05, help="The base learning rate."
+    )
+
+    parser.add_argument(
+        "--lr-batches",
+        type=float,
+        default=5000,
+        help="""Number of steps that affects how rapidly the learning rate
+        decreases. We suggest not to change this.""",
+    )
+
+    parser.add_argument(
+        "--lr-epochs",
+        type=float,
+        default=3.5,
+        help="""Number of epochs that affects how rapidly the learning rate decreases.
+        """,
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+    )
+
+    parser.add_argument(
+        "--prune-range",
+        type=int,
+        default=5,
+        help="The prune range for rnnt loss, it means how many symbols(context)"
+        "we are using to compute the loss",
+    )
+
+    parser.add_argument(
+        "--lm-scale",
+        type=float,
+        default=0.25,
+        help="The scale to smooth the loss with lm "
+        "(output of prediction network) part.",
+    )
+
+    parser.add_argument(
+        "--am-scale",
+        type=float,
+        default=0.0,
+        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+    )
+
+    parser.add_argument(
+        "--simple-loss-scale",
+        type=float,
+        default=0.5,
+        help="To get pruning ranges, we will calculate a simple version"
+        "loss(joiner is just addition), this simple loss also uses for"
+        "training (as a regularization item). We will scale the simple loss"
+        "with this parameter before adding to the final loss.",
+    )
+
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="The seed for random generators intended for reproducibility",
+    )
+
+    parser.add_argument(
+        "--print-diagnostics",
+        type=str2bool,
+        default=False,
+        help="Accumulate stats on activations, print them and exit.",
+    )
+
+    parser.add_argument(
+        "--inf-check",
+        type=str2bool,
+        default=False,
+        help="Add hooks to check for infinite module outputs and gradients.",
+    )
+
+    parser.add_argument(
+        "--save-every-n",
+        type=int,
+        default=5000,
+        help="""Save checkpoint after processing this number of batches"
+        periodically. We save checkpoint to exp-dir/ whenever
+        params.batch_idx_train % save_every_n == 0. The checkpoint filename
+        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+        Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+        end of each epoch where `xxx` is the epoch number counting from 0.
+        """,
+    )
+
+    parser.add_argument(
+        "--keep-last-k",
+        type=int,
+        default=10,
+        help="""Only keep this number of checkpoints on disk.
+        For instance, if it is 3, there are only 3 checkpoints
+        in the exp-dir with filenames `checkpoint-xxx.pt`.
+        It does not affect checkpoints with name `epoch-xxx.pt`.
+        """,
+    )
+
+    parser.add_argument(
+        "--average-period",
+        type=int,
+        default=200,
+        help="""Update the averaged model, namely `model_avg`, after processing
+        this number of batches. `model_avg` is a separate version of model,
+        in which each floating-point parameter is the average of all the
+        parameters from the start of training. Each time we take the average,
+        we do: `model_avg = model * (average_period / batch_idx_train) +
+            model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+        """,
+    )
+
+    parser.add_argument(
+        "--use-fp16",
+        type=str2bool,
+        default=False,
+        help="Whether to use half precision training.",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def get_params() -> AttributeDict:
+    """Return a dict containing training parameters.
+
+    All training related parameters that are not passed from the commandline
+    are saved in the variable `params`.
+
+    Commandline options are merged into `params` after they are parsed, so
+    you can also access them via `params`.
+
+    Explanation of options saved in `params`:
+
+        - best_train_loss: Best training loss so far. It is used to select
+                           the model that has the lowest training loss. It is
+                           updated during the training.
+
+        - best_valid_loss: Best validation loss so far. It is used to select
+                           the model that has the lowest validation loss. It is
+                           updated during the training.
+
+        - best_train_epoch: It is the epoch that has the best training loss.
+
+        - best_valid_epoch: It is the epoch that has the best validation loss.
+
+        - batch_idx_train: Used to writing statistics to tensorboard. It
+                           contains number of batches trained so far across
+                           epochs.
+
+        - log_interval:  Print training loss if batch_idx % log_interval` is 0
+
+        - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+        - valid_interval:  Run validation if batch_idx % valid_interval is 0
+
+        - feature_dim: The model input dim. It has to match the one used
+                       in computing features.
+
+        - subsampling_factor:  The subsampling factor for the model.
+
+        - encoder_dim: Hidden dim for multi-head attention model.
+
+        - num_decoder_layers: Number of decoder layer of transformer decoder.
+
+        - warm_step: The warmup period that dictates the decay of the
+              scale on "simple" (un-pruned) loss.
+    """
+    params = AttributeDict(
+        {
+            "best_train_loss": float("inf"),
+            "best_valid_loss": float("inf"),
+            "best_train_epoch": -1,
+            "best_valid_epoch": -1,
+            "batch_idx_train": 0,
+            "log_interval": 100,
+            "reset_interval": 200,
+            "valid_interval": 3000,  # For the 100h subset, use 800
+            # parameters for zipformer
+            "feature_dim": 80,
+            "subsampling_factor": 4,  # not passed in, this is fixed.
+            "warm_step": 2000,
+            "env_info": get_env_info(),
+        }
+    )
+
+    return params
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+    # TODO: We can add an option to switch between Zipformer and Transformer
+    def to_int_tuple(s: str):
+        return tuple(map(int, s.split(",")))
+
+    encoder = Zipformer(
+        num_features=params.feature_dim,
+        output_downsampling_factor=2,
+        zipformer_downsampling_factors=to_int_tuple(
+            params.zipformer_downsampling_factors
+        ),
+        encoder_dims=to_int_tuple(params.encoder_dims),
+        attention_dim=to_int_tuple(params.attention_dims),
+        encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims),
+        nhead=to_int_tuple(params.nhead),
+        feedforward_dim=to_int_tuple(params.feedforward_dims),
+        cnn_module_kernels=to_int_tuple(params.cnn_module_kernels),
+        num_encoder_layers=to_int_tuple(params.num_encoder_layers),
+    )
+    return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+    decoder = Decoder(
+        vocab_size=params.vocab_size,
+        decoder_dim=params.decoder_dim,
+        blank_id=params.blank_id,
+        context_size=params.context_size,
+    )
+    return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+    joiner = Joiner(
+        encoder_dim=int(params.encoder_dims.split(",")[-1]),
+        decoder_dim=params.decoder_dim,
+        joiner_dim=params.joiner_dim,
+        vocab_size=params.vocab_size,
+    )
+    return joiner
+
+
+def get_transducer_model(params: AttributeDict) -> nn.Module:
+    encoder = get_encoder_model(params)
+    decoder = get_decoder_model(params)
+    joiner = get_joiner_model(params)
+
+    model = Transducer(
+        encoder=encoder,
+        decoder=decoder,
+        joiner=joiner,
+        encoder_dim=int(params.encoder_dims.split(",")[-1]),
+        decoder_dim=params.decoder_dim,
+        joiner_dim=params.joiner_dim,
+        vocab_size=params.vocab_size,
+    )
+    return model
+
+
+def load_checkpoint_if_available(
+    params: AttributeDict,
+    model: nn.Module,
+    model_avg: nn.Module = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+    """Load checkpoint from file.
+
+    If params.start_batch is positive, it will load the checkpoint from
+    `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+    params.start_epoch is larger than 1, it will load the checkpoint from
+    `params.start_epoch - 1`.
+
+    Apart from loading state dict for `model` and `optimizer` it also updates
+    `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+    and `best_valid_loss` in `params`.
+
+    Args:
+      params:
+        The return value of :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer that we are using.
+      scheduler:
+        The scheduler that we are using.
+    Returns:
+      Return a dict containing previously saved training info.
+    """
+    if params.start_batch > 0:
+        filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+    elif params.start_epoch > 1:
+        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+    else:
+        return None
+
+    assert filename.is_file(), f"{filename} does not exist!"
+
+    saved_params = load_checkpoint(
+        filename,
+        model=model,
+        model_avg=model_avg,
+        optimizer=optimizer,
+        scheduler=scheduler,
+    )
+
+    keys = [
+        "best_train_epoch",
+        "best_valid_epoch",
+        "batch_idx_train",
+        "best_train_loss",
+        "best_valid_loss",
+    ]
+    for k in keys:
+        params[k] = saved_params[k]
+
+    if params.start_batch > 0:
+        if "cur_epoch" in saved_params:
+            params["start_epoch"] = saved_params["cur_epoch"]
+
+        if "cur_batch_idx" in saved_params:
+            params["cur_batch_idx"] = saved_params["cur_batch_idx"]
+
+    return saved_params
+
+
+def save_checkpoint(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    model_avg: Optional[nn.Module] = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+    sampler: Optional[CutSampler] = None,
+    scaler: Optional[GradScaler] = None,
+    rank: int = 0,
+) -> None:
+    """Save model, optimizer, scheduler and training stats to file.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer used in the training.
+      sampler:
+       The sampler for the training dataset.
+      scaler:
+        The scaler used for mix precision training.
+    """
+    if rank != 0:
+        return
+    filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+    save_checkpoint_impl(
+        filename=filename,
+        model=model,
+        model_avg=model_avg,
+        params=params,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        sampler=sampler,
+        scaler=scaler,
+        rank=rank,
+    )
+
+    if params.best_train_epoch == params.cur_epoch:
+        best_train_filename = params.exp_dir / "best-train-loss.pt"
+        copyfile(src=filename, dst=best_train_filename)
+
+    if params.best_valid_epoch == params.cur_epoch:
+        best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+        copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    sp: spm.SentencePieceProcessor,
+    batch: dict,
+    is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+    """
+    Compute transducer loss given the model and its inputs.
+
+    Args:
+      params:
+        Parameters for training. See :func:`get_params`.
+      model:
+        The model for training. It is an instance of Zipformer in our case.
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      is_training:
+        True for training. False for validation. When it is True, this
+        function enables autograd during computation; when it is False, it
+        disables autograd.
+     warmup: a floating point value which increases throughout training;
+        values >= 1.0 are fully warmed up and have all modules present.
+    """
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+    feature = batch["inputs"]
+    # at entry, feature is (N, T, C)
+    assert feature.ndim == 3
+    feature = feature.to(device)
+
+    supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"]
+
+    batch_idx_train = params.batch_idx_train
+    warm_step = params.warm_step
+
+    texts = supervisions["text"]
+    y = sp.encode(texts, out_type=int)
+    y = k2.RaggedTensor(y).to(device)
+
+    with torch.set_grad_enabled(is_training):
+        simple_loss, pruned_loss = model(
+            x=feature,
+            x_lens=feature_lens,
+            y=y,
+            prune_range=params.prune_range,
+            am_scale=params.am_scale,
+            lm_scale=params.lm_scale,
+        )
+
+        s = params.simple_loss_scale
+        # take down the scale on the simple loss from 1.0 at the start
+        # to params.simple_loss scale by warm_step.
+        simple_loss_scale = (
+            s
+            if batch_idx_train >= warm_step
+            else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
+        )
+        pruned_loss_scale = (
+            1.0
+            if batch_idx_train >= warm_step
+            else 0.1 + 0.9 * (batch_idx_train / warm_step)
+        )
+
+        loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+
+    assert loss.requires_grad == is_training
+
+    info = MetricsTracker()
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        info["frames"] = ((feature_lens - 7) // 2).sum().item()
+
+    # Note: We use reduction=sum while computing the loss.
+    info["loss"] = loss.detach().cpu().item()
+    info["simple_loss"] = simple_loss.detach().cpu().item()
+    info["pruned_loss"] = pruned_loss.detach().cpu().item()
+
+    return loss, info
+
+
+def compute_validation_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    sp: spm.SentencePieceProcessor,
+    valid_dl: torch.utils.data.DataLoader,
+    world_size: int = 1,
+) -> MetricsTracker:
+    """Run the validation process."""
+    model.eval()
+
+    tot_loss = MetricsTracker()
+
+    for batch_idx, batch in enumerate(valid_dl):
+        loss, loss_info = compute_loss(
+            params=params,
+            model=model,
+            sp=sp,
+            batch=batch,
+            is_training=False,
+        )
+        assert loss.requires_grad is False
+        tot_loss = tot_loss + loss_info
+
+    if world_size > 1:
+        tot_loss.reduce(loss.device)
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    if loss_value < params.best_valid_loss:
+        params.best_valid_epoch = params.cur_epoch
+        params.best_valid_loss = loss_value
+
+    return tot_loss
+
+
+def train_one_epoch(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    optimizer: torch.optim.Optimizer,
+    scheduler: LRSchedulerType,
+    sp: spm.SentencePieceProcessor,
+    train_dl: torch.utils.data.DataLoader,
+    valid_dl: torch.utils.data.DataLoader,
+    scaler: GradScaler,
+    model_avg: Optional[nn.Module] = None,
+    tb_writer: Optional[SummaryWriter] = None,
+    world_size: int = 1,
+    rank: int = 0,
+) -> None:
+    """Train the model for one epoch.
+
+    The training loss from the mean of all frames is saved in
+    `params.train_loss`. It runs the validation process every
+    `params.valid_interval` batches.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The model for training.
+      optimizer:
+        The optimizer we are using.
+      scheduler:
+        The learning rate scheduler, we call step() every step.
+      train_dl:
+        Dataloader for the training dataset.
+      valid_dl:
+        Dataloader for the validation dataset.
+      scaler:
+        The scaler used for mix precision training.
+      model_avg:
+        The stored model averaged from the start of training.
+      tb_writer:
+        Writer to write log messages to tensorboard.
+      world_size:
+        Number of nodes in DDP training. If it is 1, DDP is disabled.
+      rank:
+        The rank of the node in DDP training. If no DDP is used, it should
+        be set to 0.
+    """
+    model.train()
+
+    tot_loss = MetricsTracker()
+
+    cur_batch_idx = params.get("cur_batch_idx", 0)
+
+    for batch_idx, batch in enumerate(train_dl):
+        if batch_idx < cur_batch_idx:
+            continue
+        cur_batch_idx = batch_idx
+
+        params.batch_idx_train += 1
+        batch_size = len(batch["supervisions"]["text"])
+
+        try:
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                loss, loss_info = compute_loss(
+                    params=params,
+                    model=model,
+                    sp=sp,
+                    batch=batch,
+                    is_training=True,
+                )
+            # summary stats
+            tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+            # NOTE: We use reduction==sum and loss is computed over utterances
+            # in the batch and there is no normalization to it so far.
+            scaler.scale(loss).backward()
+            set_batch_count(model, params.batch_idx_train)
+            scheduler.step_batch(params.batch_idx_train)
+
+            scaler.step(optimizer)
+            scaler.update()
+            optimizer.zero_grad()
+        except:  # noqa
+            display_and_save_batch(batch, params=params, sp=sp)
+            raise
+
+        if params.print_diagnostics and batch_idx == 5:
+            return
+
+        if (
+            rank == 0
+            and params.batch_idx_train > 0
+            and params.batch_idx_train % params.average_period == 0
+        ):
+            update_averaged_model(
+                params=params,
+                model_cur=model,
+                model_avg=model_avg,
+            )
+
+        if (
+            params.batch_idx_train > 0
+            and params.batch_idx_train % params.save_every_n == 0
+        ):
+            params.cur_batch_idx = batch_idx
+            save_checkpoint_with_global_batch_idx(
+                out_dir=params.exp_dir,
+                global_batch_idx=params.batch_idx_train,
+                model=model,
+                model_avg=model_avg,
+                params=params,
+                optimizer=optimizer,
+                scheduler=scheduler,
+                sampler=train_dl.sampler,
+                scaler=scaler,
+                rank=rank,
+            )
+            del params.cur_batch_idx
+            remove_checkpoints(
+                out_dir=params.exp_dir,
+                topk=params.keep_last_k,
+                rank=rank,
+            )
+
+        if batch_idx % 100 == 0 and params.use_fp16:
+            # If the grad scale was less than 1, try increasing it.    The _growth_interval
+            # of the grad scaler is configurable, but we can't configure it to have different
+            # behavior depending on the current grad scale.
+            cur_grad_scale = scaler._scale.item()
+            if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
+                scaler.update(cur_grad_scale * 2.0)
+            if cur_grad_scale < 0.01:
+                logging.warning(f"Grad scale is small: {cur_grad_scale}")
+            if cur_grad_scale < 1.0e-05:
+                raise RuntimeError(
+                    f"grad_scale is too small, exiting: {cur_grad_scale}"
+                )
+
+        if batch_idx % params.log_interval == 0:
+            cur_lr = scheduler.get_last_lr()[0]
+            cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+            logging.info(
+                f"Epoch {params.cur_epoch}, "
+                f"batch {batch_idx}, loss[{loss_info}], "
+                f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+                f"lr: {cur_lr:.2e}, "
+                + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+            )
+
+            if tb_writer is not None:
+                tb_writer.add_scalar(
+                    "train/learning_rate", cur_lr, params.batch_idx_train
+                )
+
+                loss_info.write_summary(
+                    tb_writer, "train/current_", params.batch_idx_train
+                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                if params.use_fp16:
+                    tb_writer.add_scalar(
+                        "train/grad_scale", cur_grad_scale, params.batch_idx_train
+                    )
+
+        if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+            logging.info("Computing validation loss")
+            valid_info = compute_validation_loss(
+                params=params,
+                model=model,
+                sp=sp,
+                valid_dl=valid_dl,
+                world_size=world_size,
+            )
+            model.train()
+            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+            logging.info(
+                f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+            )
+            if tb_writer is not None:
+                valid_info.write_summary(
+                    tb_writer, "train/valid_", params.batch_idx_train
+                )
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    params.train_loss = loss_value
+    if params.train_loss < params.best_train_loss:
+        params.best_train_epoch = params.cur_epoch
+        params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+    """
+    Args:
+      rank:
+        It is a value between 0 and `world_size-1`, which is
+        passed automatically by `mp.spawn()` in :func:`main`.
+        The node with rank 0 is responsible for saving checkpoint.
+      world_size:
+        Number of GPUs for DDP training.
+      args:
+        The return value of get_parser().parse_args()
+    """
+    params = get_params()
+    params.update(vars(args))
+
+    fix_random_seed(params.seed)
+    if world_size > 1:
+        setup_dist(rank, world_size, params.master_port)
+
+    setup_logger(f"{params.exp_dir}/log/log-train")
+    logging.info("Training started")
+
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    #  is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    assert params.save_every_n >= params.average_period
+    model_avg: Optional[nn.Module] = None
+    if rank == 0:
+        # model_avg is only used with rank 0
+        model_avg = copy.deepcopy(model).to(torch.float64)
+
+    assert params.start_epoch > 0, params.start_epoch
+    checkpoints = load_checkpoint_if_available(
+        params=params, model=model, model_avg=model_avg
+    )
+
+    model.to(device)
+    if world_size > 1:
+        logging.info("Using DDP")
+        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+    optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=2.0)
+
+    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+    if checkpoints and "optimizer" in checkpoints:
+        logging.info("Loading optimizer state dict")
+        optimizer.load_state_dict(checkpoints["optimizer"])
+
+    if (
+        checkpoints
+        and "scheduler" in checkpoints
+        and checkpoints["scheduler"] is not None
+    ):
+        logging.info("Loading scheduler state dict")
+        scheduler.load_state_dict(checkpoints["scheduler"])
+
+    if params.print_diagnostics:
+        opts = diagnostics.TensorDiagnosticOptions(
+            2**22
+        )  # allow 4 megabytes per sub-module
+        diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+    if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+        # We only load the sampler's state dict when it loads a checkpoint
+        # saved in the middle of an epoch
+        sampler_state_dict = checkpoints["sampler"]
+    else:
+        sampler_state_dict = None
+
+    if params.inf_check:
+        register_inf_check_hooks(model)
+
+    ami = AmiAsrDataModule(args)
+
+    # Here is the duration statistics of the training set.
+    # Cuts count: 1230033
+    # Total duration (hh:mm:ss): 904:25:34
+    # Speech duration (hh:mm:ss): 904:25:34 (100.0%)
+    # Duration statistics (seconds):
+    # mean	2.6
+    # std	2.8
+    # min	0.0
+    # 25%	0.6
+    # 50%	1.6
+    # 75%	3.8
+    # 99%	12.3
+    # 99.5%	13.9
+    # 99.9%	18.3
+    # max	36.8
+
+    train_cuts = ami.train_cuts(sp=sp)
+    train_dl = ami.train_dataloaders(train_cuts, sampler_state_dict=sampler_state_dict)
+
+    valid_cuts = ami.dev_ihm_cuts()
+    valid_dl = ami.valid_dataloaders(valid_cuts)
+
+    if not params.print_diagnostics:
+        scan_pessimistic_batches_for_oom(
+            model=model,
+            train_dl=train_dl,
+            optimizer=optimizer,
+            sp=sp,
+            params=params,
+        )
+
+    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+    if checkpoints and "grad_scaler" in checkpoints:
+        logging.info("Loading grad scaler state dict")
+        scaler.load_state_dict(checkpoints["grad_scaler"])
+
+    for epoch in range(params.start_epoch, params.num_epochs + 1):
+        scheduler.step_epoch(epoch - 1)
+        fix_random_seed(params.seed + epoch - 1)
+        train_dl.sampler.set_epoch(epoch - 1)
+
+        if tb_writer is not None:
+            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+        params.cur_epoch = epoch
+
+        train_one_epoch(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            sp=sp,
+            train_dl=train_dl,
+            valid_dl=valid_dl,
+            scaler=scaler,
+            tb_writer=tb_writer,
+            world_size=world_size,
+            rank=rank,
+        )
+
+        if params.print_diagnostics:
+            diagnostic.print_diagnostics()
+            break
+
+        save_checkpoint(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            sampler=train_dl.sampler,
+            scaler=scaler,
+            rank=rank,
+        )
+
+    logging.info("Done!")
+
+    if world_size > 1:
+        torch.distributed.barrier()
+        cleanup_dist()
+
+
+def display_and_save_batch(
+    batch: dict,
+    params: AttributeDict,
+    sp: spm.SentencePieceProcessor,
+) -> None:
+    """Display the batch statistics and save the batch into disk.
+
+    Args:
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      params:
+        Parameters for training. See :func:`get_params`.
+      sp:
+        The BPE model.
+    """
+    from lhotse.utils import uuid4
+
+    filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+    logging.info(f"Saving batch to {filename}")
+    torch.save(batch, filename)
+
+    supervisions = batch["supervisions"]
+    features = batch["inputs"]
+
+    logging.info(f"features shape: {features.shape}")
+
+    y = sp.encode(supervisions["text"], out_type=int)
+    num_tokens = sum(len(i) for i in y)
+    logging.info(f"num tokens: {num_tokens}")
+
+
+def scan_pessimistic_batches_for_oom(
+    model: Union[nn.Module, DDP],
+    train_dl: torch.utils.data.DataLoader,
+    optimizer: torch.optim.Optimizer,
+    sp: spm.SentencePieceProcessor,
+    params: AttributeDict,
+):
+    from lhotse.dataset import find_pessimistic_batches
+
+    logging.info(
+        "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+    )
+    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+    for criterion, cuts in batches.items():
+        batch = train_dl.dataset[cuts]
+        try:
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                loss, _ = compute_loss(
+                    params=params,
+                    model=model,
+                    sp=sp,
+                    batch=batch,
+                    is_training=True,
+                )
+            loss.backward()
+            optimizer.zero_grad()
+        except Exception as e:
+            if "CUDA out of memory" in str(e):
+                logging.error(
+                    "Your GPU ran out of memory with the current "
+                    "max_duration setting. We recommend decreasing "
+                    "max_duration and trying again.\n"
+                    f"Failing criterion: {criterion} "
+                    f"(={crit_values[criterion]}) ..."
+                )
+            display_and_save_batch(batch, params=params, sp=sp)
+            raise
+        logging.info(
+            f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+        )
+
+
+def main():
+    parser = get_parser()
+    AmiAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    world_size = args.world_size
+    assert world_size >= 1
+    if world_size > 1:
+        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+    else:
+        run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/ami/ASR/pruned_transducer_stateless7/zipformer.py b/egs/ami/ASR/pruned_transducer_stateless7/zipformer.py
new file mode 120000
index 000000000..f2f66041e
--- /dev/null
+++ b/egs/ami/ASR/pruned_transducer_stateless7/zipformer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/zipformer.py
\ No newline at end of file
diff --git a/egs/ami/ASR/shared b/egs/ami/ASR/shared
new file mode 120000
index 000000000..4cbd91a7e
--- /dev/null
+++ b/egs/ami/ASR/shared
@@ -0,0 +1 @@
+../../../icefall/shared
\ No newline at end of file

From 61032e70e097aea63d191183466d0f1b16f9e16e Mon Sep 17 00:00:00 2001
From: abb128 <65567823+abb128@users.noreply.github.com>
Date: Sat, 26 Nov 2022 04:10:37 +0200
Subject: [PATCH 020/174] Fix exception in find_checkpoints (#668)

---
 icefall/checkpoint.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py
index 8aa0a8eeb..f0663a1df 100644
--- a/icefall/checkpoint.py
+++ b/icefall/checkpoint.py
@@ -292,7 +292,15 @@ def find_checkpoints(out_dir: Path, iteration: int = 0) -> List[str]:
     """
     checkpoints = list(glob.glob(f"{out_dir}/checkpoint-[0-9]*.pt"))
     pattern = re.compile(r"checkpoint-([0-9]+).pt")
-    iter_checkpoints = [(int(pattern.search(c).group(1)), c) for c in checkpoints]
+    iter_checkpoints = []
+    for c in checkpoints:
+        result = pattern.search(c)
+        if not result:
+            logging.warn(f"Invalid checkpoint filename {c}")
+            continue
+        
+        iter_checkpoints.append((int(result.group(1)), c))
+
     # iter_checkpoints is a list of tuples. Each tuple contains
     # two elements: (iteration_number, checkpoint-iteration_number.pt)
 

From 9cf79cac3f23757e25b499947f045efb0f4d71a6 Mon Sep 17 00:00:00 2001
From: Guo Liyong 
Date: Sat, 26 Nov 2022 21:48:17 +0800
Subject: [PATCH 021/174] message formatting

---
 .../ASR/pruned_transducer_stateless7/optim.py | 76 +++++++++++--------
 .../ASR/pruned_transducer_stateless7/train.py | 10 +--
 2 files changed, 45 insertions(+), 41 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index ab55381d7..790752fe1 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -42,7 +42,7 @@ class BatchedOptimizer(Optimizer):
         super(BatchedOptimizer, self).__init__(params, defaults)
 
     @contextlib.contextmanager
-    def batched_params(self, param_group, group_params_names=None):
+    def batched_params(self, param_group, group_params_names):
         """
         This function returns (technically, yields) a list of
           of tuples (p, state), where
@@ -85,7 +85,9 @@ class BatchedOptimizer(Optimizer):
             batches_names[key].append(named_p)
 
         batches_names_keys = list(batches_names.keys())
-        sorted_idx = sorted(range(len(batches_names)), key=lambda i: batches_names_keys[i])
+        sorted_idx = sorted(
+            range(len(batches_names)), key=lambda i: batches_names_keys[i]
+        )
         batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx]
         batches = [batches[batches_names_keys[idx]] for idx in sorted_idx]
 
@@ -174,7 +176,7 @@ class ScaledAdam(BatchedOptimizer):
         size_update_period=4,
         clipping_update_period=100,
         parameters_names=None,
-        show_dominant_parameters=False,
+        show_dominant_parameters=True,
     ):
 
         defaults = dict(
@@ -211,7 +213,7 @@ class ScaledAdam(BatchedOptimizer):
                 loss = closure()
 
         batch = True
-        assert len(self.param_groups)  == len(self.parameters_names)
+        assert len(self.param_groups) == len(self.parameters_names)
 
         for group, group_params_names in zip(self.param_groups, self.parameters_names):
 
@@ -381,42 +383,52 @@ class ScaledAdam(BatchedOptimizer):
             return ans
 
     def _show_gradient_dominating_parameter(self, pairs, tot_sumsq):
-        # ori means calculated with state["param_rms"]
-        # cur means calculated with "param_rms" of current param.
-        # bt is short batch
-        # all_sumsq_ori_rms
-        all_sumsq_ori = {}
-        all_sumsq_cur = {}
+        all_sumsq_orig = {}
         for (p, state, batch_param_names) in pairs:
             # p is a stacked batch parameters.
-            grad = p.grad
+            batch_grad = p.grad
             if p.numel() == p.shape[0]:  # a batch of scalars
-                batch_sumsq_ori = grad**2  # sum() to change shape [1] to []
-                batch_sumsq_cur = batch_sumsq_ori  # sum() to change shape [1] to []
+                batch_sumsq_orig = batch_grad**2
                 # Dummpy values used by following `zip` statement.
-                batch_rms_ori = torch.zeros(p.shape[0])
-                batch_rms_cur = batch_rms_ori
+                batch_rms_orig = torch.ones(p.shape[0])
             else:
-                batch_rms_ori = state["param_rms"]
-                batch_sumsq_ori = ((grad * batch_rms_ori) ** 2).sum(dim=list(range(1, grad.ndim)))
+                batch_rms_orig = state["param_rms"]
+                batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2).sum(
+                    dim=list(range(1, batch_grad.ndim))
+                )
 
-                batch_rms_cur = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
-                batch_sumsq_cur = ((grad * batch_rms_cur) ** 2).sum(dim=list(range(1, grad.ndim)))
+            for name, sumsq_orig, rms, grad in zip(
+                batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad
+            ):
 
-            for name, sumsq_ori, sumsq_cur in zip(
-                    batch_param_names, batch_sumsq_ori, batch_sumsq_cur):
+                proportion_orig = sumsq_orig / tot_sumsq
+                all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)
 
-                proportion_ori = sumsq_ori / tot_sumsq
-                proportion_cur = sumsq_cur / tot_sumsq
-
-                all_sumsq_ori[name] = (proportion_ori, sumsq_ori)
-                all_sumsq_cur[name] = (proportion_cur, sumsq_cur)
-
-        for rms_type, all_sumsq in zip(("ori", "cur"), (all_sumsq_ori, all_sumsq_cur)):
-            sorted_by_proportion = {k: v for k, v in sorted(all_sumsq.items(), key=lambda item: item[1][0], reverse=True)}
-            dominant_param_name = next(iter(sorted_by_proportion))
-            dominant_proportion, dominant_sumsq = sorted_by_proportion[dominant_param_name]
-            logging.info(f"Dominant sumsq with {rms_type}_rms: {dominant_param_name} {dominant_proportion}  {dominant_sumsq} {tot_sumsq}")
+        assert torch.isclose(
+            sum([value[0] for value in all_sumsq_orig.values()]).cpu(),
+            torch.tensor(1.0),
+        )
+        sorted_by_proportion = {
+            k: v
+            for k, v in sorted(
+                all_sumsq_orig.items(), key=lambda item: item[1][0], reverse=True
+            )
+        }
+        dominant_param_name = next(iter(sorted_by_proportion))
+        (
+            dominant_proportion,
+            dominant_sumsq,
+            dominant_rms,
+            dominant_grad,
+        ) = sorted_by_proportion[dominant_param_name]
+        logging.info(
+            f"Parameter Dominanting tot_sumsq {dominant_param_name}"
+            f" with proportion {dominant_proportion:.2f},"
+            f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
+            f"={dominant_sumsq:.3e},"
+            f" grad_sumsq = {(dominant_grad**2).sum():.3e},"
+            f" orig_rms_sq={(dominant_rms**2).item():.3e}"
+        )
 
     def _step_one_batch(
         self, group: dict, p: Tensor, state: dict, clipping_scale: float
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index 8375b1a18..e5a3e68df 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -368,13 +368,6 @@ def get_parser():
         help="Whether to use half precision training.",
     )
 
-    parser.add_argument(
-        "--show-dominant-parameters",
-        type=str2bool,
-        default=False,
-        help="Whether to show dominant parameters.",
-    )
-
     add_model_arguments(parser)
 
     return parser
@@ -998,8 +991,7 @@ def run(rank, world_size, args):
     parameters_names = []
     parameters_names.append([name_param_pair[0] for name_param_pair in model.named_parameters()])
     optimizer = ScaledAdam(model.parameters(), lr=params.base_lr,
-            clipping_scale=2.0, parameters_names=parameters_names,
-            show_dominant_parameters=params.show_dominant_parameters)
+            clipping_scale=2.0, parameters_names=parameters_names)
 
     scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
 

From 6693d907d3ddd5c5eade144b55a57c8831d6d9b2 Mon Sep 17 00:00:00 2001
From: huangruizhe 
Date: Sat, 26 Nov 2022 22:26:09 -0500
Subject: [PATCH 022/174] shuffle full Librispeech data (#574)

* shuffled full/partial librispeech data

* fixed the code style issue

* Shuffled full librispeech data off-line

* Fixed style, addressed comments, and removed redandunt codes

* Used the suggested version of black

* Propagated the changes to other folders for librispeech (except
conformer_mmi and streaming_conformer_ctc)
---
 egs/librispeech/ASR/conformer_ctc/train.py             |  6 +++---
 egs/librispeech/ASR/conformer_ctc2/train.py            |  6 +++---
 .../ASR/conv_emformer_transducer_stateless/train.py    |  6 +++---
 .../ASR/conv_emformer_transducer_stateless2/train.py   |  6 +++---
 egs/librispeech/ASR/lstm_transducer_stateless/train.py |  6 +++---
 .../ASR/lstm_transducer_stateless2/train.py            |  6 +++---
 egs/librispeech/ASR/prepare.sh                         |  5 +++++
 .../ASR/pruned_stateless_emformer_rnnt2/train.py       |  6 +++---
 .../ASR/pruned_transducer_stateless/train.py           |  6 +++---
 .../ASR/pruned_transducer_stateless2/train.py          |  6 +++---
 .../ASR/pruned_transducer_stateless3/train.py          |  6 +++---
 .../ASR/pruned_transducer_stateless4/train.py          |  6 +++---
 .../ASR/pruned_transducer_stateless5/train.py          |  6 +++---
 .../ASR/pruned_transducer_stateless6/train.py          |  6 +++---
 egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py    | 10 ++++++++++
 egs/librispeech/ASR/tdnn_lstm_ctc/train.py             |  8 ++++----
 egs/librispeech/ASR/transducer/train.py                |  6 +++---
 egs/librispeech/ASR/transducer_lstm/train.py           |  6 +++---
 egs/librispeech/ASR/transducer_stateless/train.py      |  6 +++---
 egs/librispeech/ASR/transducer_stateless2/train.py     |  6 +++---
 .../ASR/transducer_stateless_multi_datasets/train.py   |  6 +++---
 21 files changed, 73 insertions(+), 58 deletions(-)

diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py
index 1449bc310..99fe64793 100755
--- a/egs/librispeech/ASR/conformer_ctc/train.py
+++ b/egs/librispeech/ASR/conformer_ctc/train.py
@@ -687,10 +687,10 @@ def run(rank, world_size, args):
 
     librispeech = LibriSpeechAsrDataModule(args)
 
-    train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+        train_cuts = librispeech.train_all_shuf_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
diff --git a/egs/librispeech/ASR/conformer_ctc2/train.py b/egs/librispeech/ASR/conformer_ctc2/train.py
index ceea0c22c..121fdb256 100755
--- a/egs/librispeech/ASR/conformer_ctc2/train.py
+++ b/egs/librispeech/ASR/conformer_ctc2/train.py
@@ -928,10 +928,10 @@ def run(rank, world_size, args):
 
     librispeech = LibriSpeechAsrDataModule(args)
 
-    train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+        train_cuts = librispeech.train_all_shuf_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py
index 213115854..6bb5505aa 100755
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py
@@ -970,10 +970,10 @@ def run(rank, world_size, args):
 
     librispeech = LibriSpeechAsrDataModule(args)
 
-    train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+        train_cuts = librispeech.train_all_shuf_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py
index 6a019fd63..8462ae92a 100755
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py
@@ -970,10 +970,10 @@ def run(rank, world_size, args):
 
     librispeech = LibriSpeechAsrDataModule(args)
 
-    train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+        train_cuts = librispeech.train_all_shuf_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/train.py b/egs/librispeech/ASR/lstm_transducer_stateless/train.py
index a54108f6d..feb81d500 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless/train.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless/train.py
@@ -954,10 +954,10 @@ def run(rank, world_size, args):
 
     librispeech = LibriSpeechAsrDataModule(args)
 
-    train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+        train_cuts = librispeech.train_all_shuf_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py
index 8736384b4..4fc4fa7f8 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py
@@ -1108,10 +1108,10 @@ def run(rank, world_size, args):
 
     librispeech = LibriSpeech(manifest_dir=args.manifest_dir)
 
-    train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+        train_cuts = librispeech.train_all_shuf_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
 
     train_cuts = filter_short_and_long_utterances(train_cuts, sp)
 
diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh
index 8668af0e4..542bbcdd8 100755
--- a/egs/librispeech/ASR/prepare.sh
+++ b/egs/librispeech/ASR/prepare.sh
@@ -123,6 +123,11 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
     touch data/fbank/.librispeech.done
   fi
 
+  cat <(gunzip -c data/fbank/librispeech_cuts_train-clean-100.jsonl.gz) \
+    <(gunzip -c data/fbank/librispeech_cuts_train-clean-360.jsonl.gz) \
+    <(gunzip -c data/fbank/librispeech_cuts_train-other-500.jsonl.gz) | \
+    shuf | gzip -c > data/fbank/librispeech_cuts_train-all-shuf.jsonl.gz
+
   if [ ! -e data/fbank/.librispeech-validated.done ]; then
     log "Validating data/fbank for LibriSpeech"
     parts=(
diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py
index ed3fa1521..3601e1e11 100755
--- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py
+++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py
@@ -882,10 +882,10 @@ def run(rank, world_size, args):
 
     librispeech = LibriSpeechAsrDataModule(args)
 
-    train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+        train_cuts = librispeech.train_all_shuf_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/train.py b/egs/librispeech/ASR/pruned_transducer_stateless/train.py
index 4dabbccc1..cf4032027 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/train.py
@@ -873,10 +873,10 @@ def run(rank, world_size, args):
 
     librispeech = LibriSpeechAsrDataModule(args)
 
-    train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+        train_cuts = librispeech.train_all_shuf_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
index 86333fc97..6c19f2cb0 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
@@ -931,10 +931,10 @@ def run(rank, world_size, args):
 
     librispeech = LibriSpeechAsrDataModule(args)
 
-    train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+        train_cuts = librispeech.train_all_shuf_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py
index 281ba4650..fdafa5a87 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py
@@ -1065,10 +1065,10 @@ def run(rank, world_size, args):
 
     librispeech = LibriSpeech(manifest_dir=args.manifest_dir)
 
-    train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+        train_cuts = librispeech.train_all_shuf_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
 
     train_cuts = filter_short_and_long_utterances(train_cuts, sp)
 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py
index cb56c8294..9bd7df401 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py
@@ -978,10 +978,10 @@ def run(rank, world_size, args):
 
     librispeech = LibriSpeechAsrDataModule(args)
 
-    train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+        train_cuts = librispeech.train_all_shuf_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py
index 436620744..847c80ab0 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py
@@ -1009,10 +1009,10 @@ def run(rank, world_size, args):
 
     librispeech = LibriSpeechAsrDataModule(args)
 
-    train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+        train_cuts = librispeech.train_all_shuf_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py
index 8f4d3b879..57753599a 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py
@@ -970,10 +970,10 @@ def run(rank, world_size, args):
 
     librispeech = LibriSpeechAsrDataModule(args)
 
-    train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+        train_cuts = librispeech.train_all_shuf_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
index 95d1b273a..c5787835d 100644
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -414,6 +414,16 @@ class LibriSpeechAsrDataModule:
             self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz"
         )
 
+    @lru_cache()
+    def train_all_shuf_cuts(self) -> CutSet:
+        logging.info(
+            "About to get the shuffled train-clean-100, \
+            train-clean-360 and train-other-500 cuts"
+        )
+        return load_manifest_lazy(
+            self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz"
+        )
+
     @lru_cache()
     def dev_clean_cuts(self) -> CutSet:
         logging.info("About to get dev-clean cuts")
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
index 071ac792b..0aa1587ba 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
@@ -173,7 +173,7 @@ def get_params() -> AttributeDict:
         {
             "exp_dir": Path("tdnn_lstm_ctc/exp"),
             "lang_dir": Path("data/lang_phone"),
-            "lr": 1e-3,
+            "lr": 1e-4,
             "feature_dim": 80,
             "weight_decay": 5e-4,
             "subsampling_factor": 3,
@@ -557,10 +557,10 @@ def run(rank, world_size, args):
 
     librispeech = LibriSpeechAsrDataModule(args)
 
-    train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+        train_cuts = librispeech.train_all_shuf_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
diff --git a/egs/librispeech/ASR/transducer/train.py b/egs/librispeech/ASR/transducer/train.py
index 674ea10a6..29625754e 100755
--- a/egs/librispeech/ASR/transducer/train.py
+++ b/egs/librispeech/ASR/transducer/train.py
@@ -614,10 +614,10 @@ def run(rank, world_size, args):
 
     librispeech = LibriSpeechAsrDataModule(args)
 
-    train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+        train_cuts = librispeech.train_all_shuf_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
diff --git a/egs/librispeech/ASR/transducer_lstm/train.py b/egs/librispeech/ASR/transducer_lstm/train.py
index 57bda63fd..792708bc0 100755
--- a/egs/librispeech/ASR/transducer_lstm/train.py
+++ b/egs/librispeech/ASR/transducer_lstm/train.py
@@ -620,10 +620,10 @@ def run(rank, world_size, args):
 
     librispeech = LibriSpeechAsrDataModule(args)
 
-    train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+        train_cuts = librispeech.train_all_shuf_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py
index bcb883fa5..8db9b59e7 100755
--- a/egs/librispeech/ASR/transducer_stateless/train.py
+++ b/egs/librispeech/ASR/transducer_stateless/train.py
@@ -641,10 +641,10 @@ def run(rank, world_size, args):
     if params.print_diagnostics:
         diagnostic = diagnostics.attach_diagnostics(model)
 
-    train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+        train_cuts = librispeech.train_all_shuf_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
diff --git a/egs/librispeech/ASR/transducer_stateless2/train.py b/egs/librispeech/ASR/transducer_stateless2/train.py
index 68e247f23..1c3a33870 100755
--- a/egs/librispeech/ASR/transducer_stateless2/train.py
+++ b/egs/librispeech/ASR/transducer_stateless2/train.py
@@ -629,10 +629,10 @@ def run(rank, world_size, args):
     if params.print_diagnostics:
         diagnostic = diagnostics.attach_diagnostics(model)
 
-    train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+        train_cuts = librispeech.train_all_shuf_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py
index 88987d91c..dafccd088 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py
@@ -752,10 +752,10 @@ def run(rank, world_size, args):
 
     librispeech = LibriSpeech(manifest_dir=args.manifest_dir)
 
-    train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+        train_cuts = librispeech.train_all_shuf_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
 
     train_cuts = filter_short_and_long_utterances(train_cuts)
 

From 4fee3e7f1ea6c2aefe7594e325ede1e530e54d3d Mon Sep 17 00:00:00 2001
From: Guo Liyong 
Date: Mon, 28 Nov 2022 16:55:18 +0800
Subject: [PATCH 023/174] impove comment

---
 .../ASR/pruned_transducer_stateless7/optim.py | 63 +++++++++++++------
 .../ASR/pruned_transducer_stateless7/train.py | 12 +++-
 2 files changed, 54 insertions(+), 21 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index 790752fe1..ff8fbb32c 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -64,13 +64,15 @@ class BatchedOptimizer(Optimizer):
         you can do:
         
           with self.batched_params(group["params"]) as batches:
-             for p, state in batches:
+             for p, state, p_names in batches:
                  ...
         
 
         Args:
           group: a parameter group, which is a list of parameters; should be
-                one of self.groups.
+                one of self.param_groups.
+          group_params_names: name for each parameter in group,
+                which is List[str].
         """
         batches = defaultdict(
             list
@@ -79,6 +81,7 @@ class BatchedOptimizer(Optimizer):
             list
         )  # `batches` maps from tuple (dtype_as_str,*shape) to list of str
 
+        assert len(param_group) == len(group_params_names)
         for p, named_p in zip(param_group, group_params_names):
             key = (str(p.dtype), *p.shape)
             batches[key].append(p)
@@ -94,9 +97,9 @@ class BatchedOptimizer(Optimizer):
         stacked_params_dict = dict()
 
         # turn batches into a list, in deterministic order.
-        # pairs will contain pairs of (stacked_param, state), one for each batch
-        # in `batches`.
-        pairs = []
+        # tuples will contain tuples of (stacked_param, state, stacked_params_names),
+        # one for each batch in `batches`.
+        tuples = []
 
         for batch, batch_names in zip(batches, batches_names):
             p = batch[0]
@@ -110,11 +113,11 @@ class BatchedOptimizer(Optimizer):
             )
             p_stacked.grad = grad
             stacked_params_dict[key] = p_stacked
-            pairs.append((p_stacked, state, batch_names))
+            tuples.append((p_stacked, state, batch_names))
 
-        yield pairs  # <-- calling code will do the actual optimization here!
+        yield tuples  # <-- calling code will do the actual optimization here!
 
-        for ((stacked_params, _state, _names), batch) in zip(pairs, batches):
+        for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
             for i, p in enumerate(batch):  # batch is list of Parameter
                 p.copy_(stacked_params[i])
 
@@ -179,6 +182,11 @@ class ScaledAdam(BatchedOptimizer):
         show_dominant_parameters=True,
     ):
 
+        assert parameters_names is not None, (
+            "Please prepare parameters_names,"
+            "which is a List[List[str]]. Each List[str] is for a group"
+            "and each str is for a parameter"
+        )
         defaults = dict(
             lr=lr,
             clipping_scale=clipping_scale,
@@ -193,6 +201,7 @@ class ScaledAdam(BatchedOptimizer):
         )
 
         super(ScaledAdam, self).__init__(params, defaults)
+        assert len(self.param_groups) == len(parameters_names)
         self.parameters_names = parameters_names
         self.show_dominant_parameters = show_dominant_parameters
 
@@ -213,7 +222,6 @@ class ScaledAdam(BatchedOptimizer):
                 loss = closure()
 
         batch = True
-        assert len(self.param_groups) == len(self.parameters_names)
 
         for group, group_params_names in zip(self.param_groups, self.parameters_names):
 
@@ -292,7 +300,7 @@ class ScaledAdam(BatchedOptimizer):
         state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
 
     def _get_clipping_scale(
-        self, group: dict, pairs: List[Tuple[Tensor, dict, List[str]]]
+        self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]]
     ) -> float:
         """
         Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients
@@ -300,12 +308,16 @@ class ScaledAdam(BatchedOptimizer):
 
         Args:
            group: the parameter group, an item in self.param_groups
-           pairs: a list of pairs of (param, state) where param is a batched set of parameters, with a .grad
-                (1st dim is batch dim) and state is the state-dict where optimization parameters are kept.
+           tuples: a list of tuples of (param, state, param_names)
+                where param is a batched set of parameters,
+                with a .grad (1st dim is batch dim)
+                and state is the state-dict where optimization parameters are kept.
+                param_names is a List[str] while each str is name for a parameter
+                in batched set of parameters "param".
         """
-        assert len(pairs) >= 1
+        assert len(tuples) >= 1
         clipping_scale = group["clipping_scale"]
-        (first_p, first_state, _) = pairs[0]
+        (first_p, first_state, _) = tuples[0]
         step = first_state["step"]
         if clipping_scale is None or step == 0:
             # no clipping.  return early on step == 0 because the other
@@ -314,7 +326,7 @@ class ScaledAdam(BatchedOptimizer):
         clipping_update_period = group["clipping_update_period"]
 
         tot_sumsq = torch.tensor(0.0, device=first_p.device)
-        for (p, state, param_names) in pairs:
+        for (p, state, param_names) in tuples:
             grad = p.grad
             if grad.is_sparse:
                 raise RuntimeError(
@@ -379,12 +391,27 @@ class ScaledAdam(BatchedOptimizer):
                 )
                 if self.show_dominant_parameters:
                     assert p.shape[0] == len(param_names)
-                    self._show_gradient_dominating_parameter(pairs, tot_sumsq)
+                    self._show_gradient_dominating_parameter(tuples, tot_sumsq)
             return ans
 
-    def _show_gradient_dominating_parameter(self, pairs, tot_sumsq):
+    def _show_gradient_dominating_parameter(
+        self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor
+    ):
+        """
+        Show information of parameter wihch dominanting tot_sumsq.
+
+        Args:
+           tuples: a list of tuples of (param, state, param_names)
+                where param is a batched set of parameters,
+                with a .grad (1st dim is batch dim)
+                and state is the state-dict where optimization parameters are kept.
+                param_names is a List[str] while each str is name for a parameter
+                in batched set of parameters "param".
+            tot_sumsq: sumsq of all parameters. Though it's could be calculated
+                from tuples, we still pass it to save some time.
+        """
         all_sumsq_orig = {}
-        for (p, state, batch_param_names) in pairs:
+        for (p, state, batch_param_names) in tuples:
             # p is a stacked batch parameters.
             batch_grad = p.grad
             if p.numel() == p.shape[0]:  # a batch of scalars
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index e5a3e68df..31a3a0505 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -989,9 +989,15 @@ def run(rank, world_size, args):
         model = DDP(model, device_ids=[rank], find_unused_parameters=True)
 
     parameters_names = []
-    parameters_names.append([name_param_pair[0] for name_param_pair in model.named_parameters()])
-    optimizer = ScaledAdam(model.parameters(), lr=params.base_lr,
-            clipping_scale=2.0, parameters_names=parameters_names)
+    parameters_names.append(
+        [name_param_pair[0] for name_param_pair in model.named_parameters()]
+    )
+    optimizer = ScaledAdam(
+        model.parameters(),
+        lr=params.base_lr,
+        clipping_scale=2.0,
+        parameters_names=parameters_names,
+    )
 
     scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
 

From ece728d895c11545eb3232caa4f6a1c907192064 Mon Sep 17 00:00:00 2001
From: Zengwei Yao 
Date: Mon, 28 Nov 2022 22:34:02 +0800
Subject: [PATCH 024/174] Apply delay penalty on k2 ctc loss (#669)

* add init files

* fix bug, apply delay penalty

* fix decoding code and getting timestamps

* add option applying delay penalty on ctc log-prob

* fix bug of streaming decoding

* minor change for bpe-based case

* add test_model.py

* add README.md

* add CI
---
 .flake8                                       |    2 +-
 ...n-librispeech-conformer-ctc3-2022-11-28.sh |  119 ++
 ...-librispeech-conformer-ctc3-2022-11-28.yml |  151 +++
 egs/librispeech/ASR/RESULTS.md                |  102 +-
 .../ASR/conformer_ctc3/__init__.py            |    1 +
 .../ASR/conformer_ctc3/asr_datamodule.py      |    1 +
 .../ASR/conformer_ctc3/conformer.py           |    1 +
 egs/librispeech/ASR/conformer_ctc3/decode.py  | 1004 +++++++++++++++
 .../ASR/conformer_ctc3/encoder_interface.py   |    1 +
 egs/librispeech/ASR/conformer_ctc3/export.py  |  292 +++++
 .../ASR/conformer_ctc3/jit_pretrained.py      |  406 ++++++
 egs/librispeech/ASR/conformer_ctc3/lstmp.py   |    1 +
 egs/librispeech/ASR/conformer_ctc3/model.py   |  122 ++
 egs/librispeech/ASR/conformer_ctc3/optim.py   |    1 +
 .../ASR/conformer_ctc3/pretrained.py          |  458 +++++++
 egs/librispeech/ASR/conformer_ctc3/scaling.py |    1 +
 .../ASR/conformer_ctc3/scaling_converter.py   |    1 +
 .../ASR/conformer_ctc3/test_model.py          |   82 ++
 egs/librispeech/ASR/conformer_ctc3/train.py   | 1108 +++++++++++++++++
 icefall/bpe_graph_compiler.py                 |    5 +-
 icefall/char_graph_compiler.py                |    3 +-
 icefall/checkpoint.py                         |    2 +-
 icefall/graph_compiler.py                     |    4 +
 icefall/utils.py                              |   51 +-
 24 files changed, 3876 insertions(+), 43 deletions(-)
 create mode 100755 .github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh
 create mode 100644 .github/workflows/run-librispeech-conformer-ctc3-2022-11-28.yml
 create mode 120000 egs/librispeech/ASR/conformer_ctc3/__init__.py
 create mode 120000 egs/librispeech/ASR/conformer_ctc3/asr_datamodule.py
 create mode 120000 egs/librispeech/ASR/conformer_ctc3/conformer.py
 create mode 100755 egs/librispeech/ASR/conformer_ctc3/decode.py
 create mode 120000 egs/librispeech/ASR/conformer_ctc3/encoder_interface.py
 create mode 100755 egs/librispeech/ASR/conformer_ctc3/export.py
 create mode 100755 egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py
 create mode 120000 egs/librispeech/ASR/conformer_ctc3/lstmp.py
 create mode 100644 egs/librispeech/ASR/conformer_ctc3/model.py
 create mode 120000 egs/librispeech/ASR/conformer_ctc3/optim.py
 create mode 100755 egs/librispeech/ASR/conformer_ctc3/pretrained.py
 create mode 120000 egs/librispeech/ASR/conformer_ctc3/scaling.py
 create mode 120000 egs/librispeech/ASR/conformer_ctc3/scaling_converter.py
 create mode 100755 egs/librispeech/ASR/conformer_ctc3/test_model.py
 create mode 100755 egs/librispeech/ASR/conformer_ctc3/train.py

diff --git a/.flake8 b/.flake8
index 609fa2c03..a0f44263c 100644
--- a/.flake8
+++ b/.flake8
@@ -11,7 +11,7 @@ per-file-ignores =
     egs/*/ASR/*/scaling.py: E501,
     egs/librispeech/ASR/lstm_transducer_stateless*/*.py: E501, E203
     egs/librispeech/ASR/conv_emformer_transducer_stateless*/*.py: E501, E203
-    egs/librispeech/ASR/conformer_ctc2/*py: E501,
+    egs/librispeech/ASR/conformer_ctc*/*py: E501,
     egs/librispeech/ASR/RESULTS.md: E999,
 
     # invalid escape sequence (cause by tex formular), W605
diff --git a/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh b/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh
new file mode 100755
index 000000000..27944807f
--- /dev/null
+++ b/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh
@@ -0,0 +1,119 @@
+#!/usr/bin/env bash
+
+set -e
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+cd egs/librispeech/ASR
+
+repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-conformer-ctc3-2022-11-27
+
+log "Downloading pre-trained model from $repo_url"
+git lfs install
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+repo=$(basename $repo_url)
+
+log "Display test files"
+tree $repo/
+soxi $repo/test_wavs/*.wav
+ls -lh $repo/test_wavs/*.wav
+
+pushd $repo/exp
+git lfs pull --include "data/*"
+git lfs pull --include "exp/jit_trace.pt"
+git lfs pull --include "exp/pretrained.pt"
+ln -s pretrained.pt epoch-99.pt
+ls -lh *.pt
+popd
+
+log "Decode with models exported by torch.jit.trace()"
+
+for m in ctc-decoding 1best; do
+  ./conformer_ctc3/jit_pretrained.py \
+    --model-filename $repo/exp/jit_trace.pt \
+    --words-file $repo/data/lang_bpe_500/words.txt  \
+    --HLG $repo/data/lang_bpe_500/HLG.pt \
+    --bpe-model $repo/data/lang_bpe_500/bpe.model \
+    --G $repo/data/lm/G_4_gram.pt \
+    --method $m \
+    --sample-rate 16000 \
+    $repo/test_wavs/1089-134686-0001.wav \
+    $repo/test_wavs/1221-135766-0001.wav \
+    $repo/test_wavs/1221-135766-0002.wav
+done
+
+log "Export to torchscript model"
+
+./conformer_ctc3/export.py \
+  --exp-dir $repo/exp \
+  --lang-dir $repo/data/lang_bpe_500 \
+  --jit-trace 1 \
+  --epoch 99 \
+  --avg 1 \
+  --use-averaged-model 0
+
+ls -lh $repo/exp/*.pt
+
+log "Decode with models exported by torch.jit.trace()"
+
+for m in ctc-decoding 1best; do
+  ./conformer_ctc3/jit_pretrained.py \
+    --model-filename $repo/exp/jit_trace.pt \
+    --words-file $repo/data/lang_bpe_500/words.txt  \
+    --HLG $repo/data/lang_bpe_500/HLG.pt \
+    --bpe-model $repo/data/lang_bpe_500/bpe.model \
+    --G $repo/data/lm/G_4_gram.pt \
+    --method $m \
+    --sample-rate 16000 \
+    $repo/test_wavs/1089-134686-0001.wav \
+    $repo/test_wavs/1221-135766-0001.wav \
+    $repo/test_wavs/1221-135766-0002.wav
+done
+
+for m in ctc-decoding 1best; do
+  ./conformer_ctc3/pretrained.py \
+    --checkpoint $repo/exp/pretrained.pt \
+    --words-file $repo/data/lang_bpe_500/words.txt  \
+    --HLG $repo/data/lang_bpe_500/HLG.pt \
+    --bpe-model $repo/data/lang_bpe_500/bpe.model \
+    --G $repo/data/lm/G_4_gram.pt \
+    --method $m \
+    --sample-rate 16000 \
+    $repo/test_wavs/1089-134686-0001.wav \
+    $repo/test_wavs/1221-135766-0001.wav \
+    $repo/test_wavs/1221-135766-0002.wav
+done
+
+echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
+echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
+if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode"  ]]; then
+  mkdir -p conformer_ctc3/exp
+  ln -s $PWD/$repo/exp/pretrained.pt conformer_ctc3/exp/epoch-999.pt
+  ln -s $PWD/$repo/data/lang_bpe_500 data/
+
+  ls -lh data
+  ls -lh conformer_ctc3/exp
+
+  log "Decoding test-clean and test-other"
+
+  # use a small value for decoding with CPU
+  max_duration=100
+
+  for method in ctc-decoding 1best; do
+    log "Decoding with $method"
+    ./conformer_ctc3/decode.py \
+      --epoch 999 \
+      --avg 1 \
+      --use-averaged-model 0 \
+      --exp-dir conformer_ctc3/exp/ \
+      --max-duration $max_duration \
+      --decoding-method $method \
+      --lm-dir data/lm
+  done
+
+  rm conformer_ctc3/exp/*.pt
+fi
diff --git a/.github/workflows/run-librispeech-conformer-ctc3-2022-11-28.yml b/.github/workflows/run-librispeech-conformer-ctc3-2022-11-28.yml
new file mode 100644
index 000000000..21f396c32
--- /dev/null
+++ b/.github/workflows/run-librispeech-conformer-ctc3-2022-11-28.yml
@@ -0,0 +1,151 @@
+# Copyright      2022  Fangjun Kuang (csukuangfj@gmail.com)
+
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+name: run-librispeech-conformer-ctc3-2022-11-28
+# zipformer
+
+on:
+  push:
+    branches:
+      - master
+  pull_request:
+    types: [labeled]
+
+  schedule:
+    # minute (0-59)
+    # hour (0-23)
+    # day of the month (1-31)
+    # month (1-12)
+    # day of the week (0-6)
+    # nightly build at 15:50 UTC time every day
+    - cron: "50 15 * * *"
+
+jobs:
+  run_librispeech_2022_11_28_conformer_ctc3:
+    if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
+    runs-on: ${{ matrix.os }}
+    strategy:
+      matrix:
+        os: [ubuntu-latest]
+        python-version: [3.8]
+
+      fail-fast: false
+
+    steps:
+      - uses: actions/checkout@v2
+        with:
+          fetch-depth: 0
+
+      - name: Setup Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v2
+        with:
+          python-version: ${{ matrix.python-version }}
+          cache: 'pip'
+          cache-dependency-path: '**/requirements-ci.txt'
+
+      - name: Install Python dependencies
+        run: |
+          grep -v '^#' ./requirements-ci.txt  | xargs -n 1 -L 1 pip install
+          pip uninstall -y protobuf
+          pip install --no-binary protobuf protobuf
+
+      - name: Cache kaldifeat
+        id: my-cache
+        uses: actions/cache@v2
+        with:
+          path: |
+            ~/tmp/kaldifeat
+          key: cache-tmp-${{ matrix.python-version }}-2022-09-25
+
+      - name: Install kaldifeat
+        if: steps.my-cache.outputs.cache-hit != 'true'
+        shell: bash
+        run: |
+          .github/scripts/install-kaldifeat.sh
+
+      - name: Cache LibriSpeech test-clean and test-other datasets
+        id: libri-test-clean-and-test-other-data
+        uses: actions/cache@v2
+        with:
+          path: |
+            ~/tmp/download
+          key: cache-libri-test-clean-and-test-other
+
+      - name: Download LibriSpeech test-clean and test-other
+        if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
+        shell: bash
+        run: |
+          .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
+
+      - name: Prepare manifests for LibriSpeech test-clean and test-other
+        shell: bash
+        run: |
+          .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
+
+      - name: Cache LibriSpeech test-clean and test-other fbank features
+        id: libri-test-clean-and-test-other-fbank
+        uses: actions/cache@v2
+        with:
+          path: |
+            ~/tmp/fbank-libri
+          key: cache-libri-fbank-test-clean-and-test-other-v2
+
+      - name: Compute fbank for LibriSpeech test-clean and test-other
+        if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
+        shell: bash
+        run: |
+          .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
+
+      - name: Inference with pre-trained model
+        shell: bash
+        env:
+          GITHUB_EVENT_NAME: ${{ github.event_name }}
+          GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
+        run: |
+          mkdir -p egs/librispeech/ASR/data
+          ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
+          ls -lh egs/librispeech/ASR/data/*
+
+          sudo apt-get -qq install git-lfs tree sox
+          export PYTHONPATH=$PWD:$PYTHONPATH
+          export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
+          export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
+
+          .github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh
+
+      - name: Display decoding results for librispeech conformer_ctc3
+        if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+        shell: bash
+        run: |
+          cd egs/librispeech/ASR/
+          tree ./conformer_ctc3/exp
+
+          cd conformer_ctc3
+          echo "results for conformer_ctc3"
+          echo "===ctc-decoding==="
+          find exp/ctc-decoding -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+          find exp/ctc-decoding -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+          echo "===1best==="
+          find exp/1best -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+          find exp/1best -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+      - name: Upload decoding results for librispeech conformer_ctc3
+        uses: actions/upload-artifact@v2
+        if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+        with:
+          name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-conformer_ctc3-2022-11-28
+          path: egs/librispeech/ASR/conformer_ctc3/exp/
diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md
index 030e47b86..efd60ba81 100644
--- a/egs/librispeech/ASR/RESULTS.md
+++ b/egs/librispeech/ASR/RESULTS.md
@@ -1,5 +1,106 @@
 ## Results
 
+### LibriSpeech BPE training results (Conformer CTC, supporting delay penalty)
+
+#### [conformer_ctc3](./conformer_ctc3)
+
+It implements Conformer model training with CTC loss.
+For streaming mode, it supports symbol delay penalty.
+
+See  for more details.
+
+##### training on full librispeech
+
+This model contains 12 encoder layers. The number of model parameters is 77352694.
+
+The WERs are:
+
+|                                     | test-clean | test-other | comment              |
+|-------------------------------------|------------|------------|----------------------|
+| ctc-decoding                        | 3.09       | 7.62       | --epoch 25 --avg 7   |
+| 1best                               | 2.87       | 6.44       | --epoch 25 --avg 7   |
+| nbest                               | 2.88       | 6.5        | --epoch 25 --avg 7   |
+| nbest-rescoring                     | 2.71       | 6.1        | --epoch 25 --avg 7   |
+| whole-lattice-rescoring             | 2.71       | 6.04       | --epoch 25 --avg 7   |
+
+The training command is:
+
+```bash
+./conformer_ctc3/train.py \
+  --world-size 4 \
+  --num-epochs 25 \
+  --start-epoch 1 \
+  --exp-dir conformer_ctc3/full \
+  --full-libri 1 \
+  --max-duration 300 \
+  --master-port 12345
+```
+
+The tensorboard log can be found at
+
+
+The decoding command using different methods is:
+```bash
+for method in ctc-decoding 1best nbest nbest-rescoring whole-lattice-rescoring; do
+  ./conformer_ctc3/decode.py \
+    --epoch 25 \
+    --avg 7 \
+    --exp-dir conformer_ctc3/exp \
+    --max-duration 300 \
+    --decoding-method $method \
+    --manifest-dir data/fbank \
+    --lm-dir data/lm \
+done
+```
+
+Pretrained models, training logs, decoding logs, and decoding results
+are available at
+
+
+The command to train a streaming model with symbol delay penalty is:
+```bash
+./conformer_ctc3/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --exp-dir conformer_ctc3/exp \
+  --full-libri 1 \
+  --dynamic-chunk-training 1 \
+  --causal-convolution 1 \
+  --short-chunk-size 25 \
+  --num-left-chunks 4 \
+  --max-duration 300 \
+  --delay-penalty 0.1
+```
+To evaluate symbol delay, you should:
+(1) Generate cuts with word-time alignments:
+```bash
+./local/add_alignment_librispeech.py \
+  --alignments-dir data/alignment \
+  --cuts-in-dir data/fbank \
+  --cuts-out-dir data/fbank_ali
+```
+(2) Set the argument "--manifest-dir data/fbank_ali" while decoding.
+For example:
+```bash
+./conformer_ctc3/decode.py \
+  --epoch 25 \
+  --avg 7 \
+  --exp-dir ./conformer_ctc3/exp \
+  --max-duration 300 \
+  --decoding-method ctc-decoding \
+  --simulate-streaming 1 \
+  --causal-convolution 1 \
+  --decode-chunk-size 16 \
+  --left-context 64 \
+  --manifest-dir data/fbank_ali
+```
+Note: It supports to calculate symbol delay with following decoding methods:
+  - ctc-greedy-search
+  - ctc-decoding
+  - 1best
+
+
 ### pruned_transducer_stateless8 (zipformer + multidataset)
 
 See  for more details.
@@ -115,7 +216,6 @@ done
 ```
 
 
-
 ### LibriSpeech BPE training results (Pruned Stateless LSTM RNN-T + gradient filter)
 
 #### [lstm_transducer_stateless3](./lstm_transducer_stateless3)
diff --git a/egs/librispeech/ASR/conformer_ctc3/__init__.py b/egs/librispeech/ASR/conformer_ctc3/__init__.py
new file mode 120000
index 000000000..b24e5e357
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc3/__init__.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless2/__init__.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/conformer_ctc3/asr_datamodule.py b/egs/librispeech/ASR/conformer_ctc3/asr_datamodule.py
new file mode 120000
index 000000000..a074d6085
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc3/asr_datamodule.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless2/asr_datamodule.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/conformer_ctc3/conformer.py b/egs/librispeech/ASR/conformer_ctc3/conformer.py
new file mode 120000
index 000000000..3b84b9573
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc3/conformer.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless2/conformer.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/conformer_ctc3/decode.py b/egs/librispeech/ASR/conformer_ctc3/decode.py
new file mode 100755
index 000000000..8eca2ae02
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc3/decode.py
@@ -0,0 +1,1004 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
+#                                                 Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) decode in non-streaming mode (take ctc-decoding as an example)
+./conformer_ctc3/decode.py \
+    --epoch 30 \
+    --avg 15 \
+    --exp-dir ./conformer_ctc3/exp \
+    --max-duration 600 \
+    --decoding-method ctc-decoding
+
+(2) decode in streaming mode (take ctc-decoding as an example)
+./conformer_ctc3/decode.py \
+    --epoch 30 \
+    --avg 15 \
+    --simulate-streaming 1 \
+    --causal-convolution 1 \
+    --decode-chunk-size 16 \
+    --left-context 64 \
+    --exp-dir ./conformer_ctc3/exp \
+    --max-duration 600 \
+    --decoding-method ctc-decoding
+
+To evaluate symbol delay, you should:
+(1) Generate cuts with word-time alignments:
+./local/add_alignment_librispeech.py \
+    --alignments-dir data/alignment \
+    --cuts-in-dir data/fbank \
+    --cuts-out-dir data/fbank_ali
+(2) Set the argument "--manifest-dir data/fbank_ali" while decoding.
+For example:
+./conformer_ctc3/decode.py \
+    --epoch 30 \
+    --avg 15 \
+    --exp-dir ./conformer_ctc3/exp \
+    --max-duration 600 \
+    --decoding-method ctc-decoding \
+    --simulate-streaming 1 \
+    --causal-convolution 1 \
+    --decode-chunk-size 16 \
+    --left-context 64 \
+    --manifest-dir data/fbank_ali
+Note: It supports calculating symbol delay with following decoding methods:
+    - ctc-greedy-search
+    - ctc-decoding
+    - 1best
+"""
+
+
+import argparse
+import logging
+import math
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
+from train import add_model_arguments, get_ctc_model, get_params
+
+from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
+from icefall.checkpoint import (
+    average_checkpoints,
+    average_checkpoints_with_averaged_model,
+    find_checkpoints,
+    load_checkpoint,
+)
+from icefall.decode import (
+    get_lattice,
+    nbest_decoding,
+    nbest_oracle,
+    one_best_decoding,
+    rescore_with_n_best_list,
+    rescore_with_whole_lattice,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+    AttributeDict,
+    DecodingResults,
+    get_texts,
+    get_texts_with_timestamp,
+    make_pad_mask,
+    parse_hyp_and_timestamp,
+    setup_logger,
+    store_transcripts_and_timestamps,
+    str2bool,
+    write_error_stats_with_timestamps,
+)
+
+LOG_EPS = math.log(1e-10)
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=30,
+        help="""It specifies the checkpoint to use for decoding.
+        Note: Epoch counts from 1.
+        You can specify --avg to use more checkpoints for model averaging.""",
+    )
+
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=15,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
+    )
+
+    parser.add_argument(
+        "--use-averaged-model",
+        type=str2bool,
+        default=True,
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="pruned_transducer_stateless4/exp",
+        help="The experiment dir",
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=Path,
+        default="data/lang_bpe_500",
+        help="The lang dir containing word table and LG graph",
+    )
+
+    parser.add_argument(
+        "--decoding-method",
+        type=str,
+        default="ctc-decoding",
+        help="""Decoding method.
+        Supported values are:
+        - (0) ctc-decoding. Use CTC decoding. It uses a sentence piece
+          model, i.e., lang_dir/bpe.model, to convert word pieces to words.
+          It needs neither a lexicon nor an n-gram LM.
+        - (1) ctc-greedy-search. It only use CTC output and a sentence piece
+          model for decoding. It produces the same results with ctc-decoding.
+        - (2) 1best. Extract the best path from the decoding lattice as the
+          decoding result.
+        - (3) nbest. Extract n paths from the decoding lattice; the path
+          with the highest score is the decoding result.
+        - (4) nbest-rescoring. Extract n paths from the decoding lattice,
+          rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
+          the highest score is the decoding result.
+        - (5) whole-lattice-rescoring. Rescore the decoding lattice with an
+          n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
+          is the decoding result.
+          you have trained an RNN LM using ./rnn_lm/train.py
+        - (6) nbest-oracle. Its WER is the lower bound of any n-best
+          rescoring method can achieve. Useful for debugging n-best
+          rescoring method.
+        """,
+    )
+
+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=100,
+        help="""Number of paths for n-best based decoding method.
+        Used only when "method" is one of the following values:
+        nbest, nbest-rescoring, and nbest-oracle
+        """,
+    )
+
+    parser.add_argument(
+        "--nbest-scale",
+        type=float,
+        default=0.5,
+        help="""The scale to be applied to `lattice.scores`.
+        It's needed if you use any kinds of n-best based rescoring.
+        Used only when "method" is one of the following values:
+        nbest, nbest-rescoring, and nbest-oracle
+        A smaller value results in more unique paths.
+        """,
+    )
+
+    parser.add_argument(
+        "--lm-dir",
+        type=str,
+        default="data/lm",
+        help="""The n-gram LM dir.
+        It should contain either G_4_gram.pt or G_4_gram.fst.txt
+        """,
+    )
+
+    parser.add_argument(
+        "--simulate-streaming",
+        type=str2bool,
+        default=False,
+        help="""Whether to simulate streaming in decoding, this is a good way to
+        test a streaming model.
+        """,
+    )
+
+    parser.add_argument(
+        "--decode-chunk-size",
+        type=int,
+        default=16,
+        help="The chunk size for decoding (in frames after subsampling)",
+    )
+
+    parser.add_argument(
+        "--left-context",
+        type=int,
+        default=64,
+        help="left context can be seen during decoding (in frames after subsampling)",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def get_decoding_params() -> AttributeDict:
+    """Parameters for decoding."""
+    params = AttributeDict(
+        {
+            "frame_shift_ms": 10,
+            "search_beam": 20,
+            "output_beam": 8,
+            "min_active_states": 30,
+            "max_active_states": 10000,
+            "use_double_scores": True,
+        }
+    )
+    return params
+
+
+def ctc_greedy_search(
+    ctc_probs: torch.Tensor,
+    nnet_output_lens: torch.Tensor,
+) -> List[List[int]]:
+    """Apply CTC greedy search
+    Args:
+      ctc_probs (torch.Tensor): (batch, max_len, feat_dim)
+      nnet_output_lens (torch.Tensor): (batch, )
+    Returns:
+      List[List[int]]: best path result
+    """
+    topk_prob, topk_index = ctc_probs.topk(1, dim=2)  # (B, maxlen, 1)
+    topk_index = topk_index.squeeze(2)  # (B, maxlen)
+    mask = make_pad_mask(nnet_output_lens)
+    topk_index = topk_index.masked_fill_(mask, 0)  # (B, maxlen)
+    hyps = [hyp.tolist() for hyp in topk_index]
+    scores = topk_prob.max(1)
+    ret_hyps = []
+    timestamps = []
+    for i in range(len(hyps)):
+        hyp, time = remove_duplicates_and_blank(hyps[i])
+        ret_hyps.append(hyp)
+        timestamps.append(time)
+    return ret_hyps, timestamps, scores
+
+
+def remove_duplicates_and_blank(hyp: List[int]) -> Tuple[List[int], List[int]]:
+    # modified from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/common.py
+    new_hyp: List[int] = []
+    time: List[int] = []
+    cur = 0
+    while cur < len(hyp):
+        if hyp[cur] != 0:
+            new_hyp.append(hyp[cur])
+            time.append(cur)
+        prev = cur
+        while cur < len(hyp) and hyp[cur] == hyp[prev]:
+            cur += 1
+    return new_hyp, time
+
+
+def decode_one_batch(
+    params: AttributeDict,
+    model: nn.Module,
+    HLG: Optional[k2.Fsa],
+    H: Optional[k2.Fsa],
+    bpe_model: Optional[spm.SentencePieceProcessor],
+    batch: dict,
+    word_table: k2.SymbolTable,
+    sos_id: int,
+    eos_id: int,
+    G: Optional[k2.Fsa] = None,
+) -> Dict[str, Tuple[List[List[str]], List[List[float]]]]:
+    """Decode one batch and return the result in a dict. The dict has the
+    following format:
+    - key: It indicates the setting used for decoding. For example,
+           if no rescoring is used, the key is the string `no_rescore`.
+           If LM rescoring is used, the key is the string `lm_scale_xxx`,
+           where `xxx` is the value of `lm_scale`. An example key is
+           `lm_scale_0.7`
+    - value: It contains the decoding result. `len(value)` equals to
+             batch size. `value[i]` is the decoding result for the i-th
+             utterance in the given batch.
+
+    Args:
+      params:
+        It's the return value of :func:`get_params`.
+
+        - params.decoding_method is "1best", it uses 1best decoding without LM rescoring.
+        - params.decoding_method is "nbest", it uses nbest decoding without LM rescoring.
+        - params.decoding_method is "nbest-rescoring", it uses nbest LM rescoring.
+        - params.decoding_method is "whole-lattice-rescoring", it uses whole lattice LM
+          rescoring.
+
+      model:
+        The neural model.
+      HLG:
+        The decoding graph. Used only when params.decoding_method is NOT ctc-decoding.
+      H:
+        The ctc topo. Used only when params.decoding_method is ctc-decoding.
+      bpe_model:
+        The BPE model. Used only when params.decoding_method is ctc-decoding.
+      batch:
+        It is the return value from iterating
+        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+        for the format of the `batch`.
+      word_table:
+        The word symbol table.
+      sos_id:
+        The token ID of the SOS.
+      eos_id:
+        The token ID of the EOS.
+      G:
+        An LM. It is not None when params.decoding_method is "nbest-rescoring"
+        or "whole-lattice-rescoring". In general, the G in HLG
+        is a 3-gram LM, while this G is a 4-gram LM.
+    Returns:
+      Return the decoding result. See above description for the format of
+      the returned dict. Note: If it decodes to nothing, then return None.
+    """
+    if HLG is not None:
+        device = HLG.device
+    else:
+        device = H.device
+    feature = batch["inputs"]
+    assert feature.ndim == 3
+    feature = feature.to(device)
+    # at entry, feature is (N, T, C)
+
+    supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)
+
+    if params.simulate_streaming:
+        feature_lens += params.left_context
+        feature = torch.nn.functional.pad(
+            feature,
+            pad=(0, 0, 0, params.left_context),
+            value=LOG_EPS,
+        )
+        encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
+            x=feature,
+            x_lens=feature_lens,
+            chunk_size=params.decode_chunk_size,
+            left_context=params.left_context,
+            simulate_streaming=True,
+        )
+    else:
+        encoder_out, encoder_out_lens = model.encoder(feature, feature_lens)
+
+    nnet_output = model.get_ctc_output(encoder_out)
+    # nnet_output is (N, T, C)
+
+    if params.decoding_method == "ctc-greedy-search":
+        hyps, timestamps, _ = ctc_greedy_search(
+            nnet_output,
+            encoder_out_lens,
+        )
+        res = DecodingResults(hyps=hyps, timestamps=timestamps)
+        hyps, timestamps = parse_hyp_and_timestamp(
+            res=res,
+            sp=bpe_model,
+            subsampling_factor=params.subsampling_factor,
+            frame_shift_ms=params.frame_shift_ms,
+        )
+        key = "ctc-greedy-search"
+        return {key: (hyps, timestamps)}
+
+    supervision_segments = torch.stack(
+        (
+            supervisions["sequence_idx"],
+            supervisions["start_frame"] // params.subsampling_factor,
+            supervisions["num_frames"] // params.subsampling_factor,
+        ),
+        1,
+    ).to(torch.int32)
+
+    if H is None:
+        assert HLG is not None
+        decoding_graph = HLG
+    else:
+        assert HLG is None
+        assert bpe_model is not None
+        decoding_graph = H
+
+    if params.decoding_method in ["1best", "nbest", "nbest-oracle"]:
+        hlg_scale_list = [0.2, 0.4, 0.6, 0.8, 1.0]
+
+        ori_scores = decoding_graph.scores.clone()
+
+        ans = {}
+        for hlg_scale in hlg_scale_list:
+            decoding_graph.scores = ori_scores * hlg_scale
+            lattice = get_lattice(
+                nnet_output=nnet_output,
+                decoding_graph=decoding_graph,
+                supervision_segments=supervision_segments,
+                search_beam=params.search_beam,
+                output_beam=params.output_beam,
+                min_active_states=params.min_active_states,
+                max_active_states=params.max_active_states,
+                subsampling_factor=params.subsampling_factor,
+            )
+            key_suffix = f"-HLG-scale-{hlg_scale}"
+
+            if params.decoding_method == "nbest-oracle":
+                # Note: You can also pass rescored lattices to it.
+                # We choose the HLG decoded lattice for speed reasons
+                # as HLG decoding is faster and the oracle WER
+                # is only slightly worse than that of rescored lattices.
+                best_path = nbest_oracle(
+                    lattice=lattice,
+                    num_paths=params.num_paths,
+                    ref_texts=supervisions["text"],
+                    word_table=word_table,
+                    nbest_scale=params.nbest_scale,
+                    oov="",
+                )
+                hyps = get_texts(best_path)
+                hyps = [[word_table[i] for i in ids] for ids in hyps]
+                key = f"oracle-{params.num_paths}-nbest-scale-{params.nbest_scale}"  # noqa
+                timestamps = [[] for _ in range(len(hyps))]
+                ans[key + key_suffix] = (hyps, timestamps)
+
+            elif params.decoding_method in ["1best", "nbest"]:
+                if params.decoding_method == "1best":
+                    best_path = one_best_decoding(
+                        lattice=lattice,
+                        use_double_scores=params.use_double_scores,
+                    )
+                    key = "no-rescore"
+                    res = get_texts_with_timestamp(best_path)
+                    hyps, timestamps = parse_hyp_and_timestamp(
+                        res=res,
+                        subsampling_factor=params.subsampling_factor,
+                        frame_shift_ms=params.frame_shift_ms,
+                        word_table=word_table,
+                    )
+                else:
+                    best_path = nbest_decoding(
+                        lattice=lattice,
+                        num_paths=params.num_paths,
+                        use_double_scores=params.use_double_scores,
+                        nbest_scale=params.nbest_scale,
+                    )
+                    key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}"  # noqa
+                    hyps = get_texts(best_path)
+                    hyps = [[word_table[i] for i in ids] for ids in hyps]
+                    timestamps = [[] for _ in range(len(hyps))]
+
+                ans[key + key_suffix] = (hyps, timestamps)
+
+        return ans
+
+    lattice = get_lattice(
+        nnet_output=nnet_output,
+        decoding_graph=decoding_graph,
+        supervision_segments=supervision_segments,
+        search_beam=params.search_beam,
+        output_beam=params.output_beam,
+        min_active_states=params.min_active_states,
+        max_active_states=params.max_active_states,
+        subsampling_factor=params.subsampling_factor,
+    )
+
+    if params.decoding_method == "ctc-decoding":
+        best_path = one_best_decoding(
+            lattice=lattice, use_double_scores=params.use_double_scores
+        )
+        # Note: `best_path.aux_labels` contains token IDs, not word IDs
+        # since we are using H, not HLG here.
+        #
+        # token_ids is a lit-of-list of IDs
+        res = get_texts_with_timestamp(best_path)
+        hyps, timestamps = parse_hyp_and_timestamp(
+            res=res,
+            sp=bpe_model,
+            subsampling_factor=params.subsampling_factor,
+            frame_shift_ms=params.frame_shift_ms,
+        )
+        key = "ctc-decoding"
+        return {key: (hyps, timestamps)}
+
+    assert params.decoding_method in [
+        "nbest-rescoring",
+        "whole-lattice-rescoring",
+    ]
+
+    lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
+    lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
+    lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
+
+    if params.decoding_method == "nbest-rescoring":
+        best_path_dict = rescore_with_n_best_list(
+            lattice=lattice,
+            G=G,
+            num_paths=params.num_paths,
+            lm_scale_list=lm_scale_list,
+            nbest_scale=params.nbest_scale,
+        )
+    elif params.decoding_method == "whole-lattice-rescoring":
+        best_path_dict = rescore_with_whole_lattice(
+            lattice=lattice,
+            G_with_epsilon_loops=G,
+            lm_scale_list=lm_scale_list,
+        )
+    else:
+        assert False, f"Unsupported decoding method: {params.decoding_method}"
+
+    ans = dict()
+    if best_path_dict is not None:
+        for lm_scale_str, best_path in best_path_dict.items():
+            hyps = get_texts(best_path)
+            hyps = [[word_table[i] for i in ids] for ids in hyps]
+            timestamps = [[] for _ in range(len(hyps))]
+            ans[lm_scale_str] = (hyps, timestamps)
+    else:
+        ans = None
+    return ans
+
+
+def decode_dataset(
+    dl: torch.utils.data.DataLoader,
+    params: AttributeDict,
+    model: nn.Module,
+    HLG: Optional[k2.Fsa],
+    H: Optional[k2.Fsa],
+    bpe_model: Optional[spm.SentencePieceProcessor],
+    word_table: k2.SymbolTable,
+    sos_id: int,
+    eos_id: int,
+    G: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[str, List[str], List[str], List[float], List[float]]]]:
+    """Decode dataset.
+
+    Args:
+      dl:
+        PyTorch's dataloader containing the dataset to decode.
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The neural model.
+      HLG:
+        The decoding graph. Used only when params.decoding_method is NOT ctc-decoding.
+      H:
+        The ctc topo. Used only when params.decoding_method is ctc-decoding.
+      bpe_model:
+        The BPE model. Used only when params.decoding_method is ctc-decoding.
+      word_table:
+        It is the word symbol table.
+      sos_id:
+        The token ID for SOS.
+      eos_id:
+        The token ID for EOS.
+      G:
+        An LM. It is not None when params.decoding_method is "nbest-rescoring"
+        or "whole-lattice-rescoring". In general, the G in HLG
+        is a 3-gram LM, while this G is a 4-gram LM.
+    Returns:
+      Return a dict, whose key may be "no-rescore" if no LM rescoring
+      is used, or it may be "lm_scale_0.7" if LM rescoring is used.
+      Its value is a list of tuples. Each tuple contains two elements:
+      The first is the reference transcript, and the second is the
+      predicted result.
+    """
+    num_cuts = 0
+
+    try:
+        num_batches = len(dl)
+    except TypeError:
+        num_batches = "?"
+
+    results = defaultdict(list)
+    for batch_idx, batch in enumerate(dl):
+        texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+        timestamps_ref = []
+        for cut in batch["supervisions"]["cut"]:
+            for s in cut.supervisions:
+                time = []
+                if s.alignment is not None and "word" in s.alignment:
+                    time = [
+                        aliword.start
+                        for aliword in s.alignment["word"]
+                        if aliword.symbol != ""
+                    ]
+                timestamps_ref.append(time)
+
+        hyps_dict = decode_one_batch(
+            params=params,
+            model=model,
+            HLG=HLG,
+            H=H,
+            bpe_model=bpe_model,
+            batch=batch,
+            word_table=word_table,
+            G=G,
+            sos_id=sos_id,
+            eos_id=eos_id,
+        )
+
+        for name, (hyps, timestamps_hyp) in hyps_dict.items():
+            this_batch = []
+            assert len(hyps) == len(texts) and len(timestamps_hyp) == len(
+                timestamps_ref
+            )
+            for cut_id, hyp_words, ref_text, time_hyp, time_ref in zip(
+                cut_ids, hyps, texts, timestamps_hyp, timestamps_ref
+            ):
+                ref_words = ref_text.split()
+                this_batch.append((cut_id, ref_words, hyp_words, time_ref, time_hyp))
+
+            results[name].extend(this_batch)
+
+        num_cuts += len(texts)
+
+        if batch_idx % 100 == 0:
+            batch_str = f"{batch_idx}/{num_batches}"
+
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+    return results
+
+
+def save_results(
+    params: AttributeDict,
+    test_set_name: str,
+    results_dict: Dict[
+        str,
+        List[Tuple[List[str], List[str], List[str], List[float], List[float]]],
+    ],
+):
+    test_set_wers = dict()
+    test_set_delays = dict()
+    for key, results in results_dict.items():
+        recog_path = (
+            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
+        results = sorted(results)
+        store_transcripts_and_timestamps(filename=recog_path, texts=results)
+        logging.info(f"The transcripts are stored in {recog_path}")
+
+        # The following prints out WERs, per-word error statistics and aligned
+        # ref/hyp pairs.
+        errs_filename = (
+            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
+        with open(errs_filename, "w") as f:
+            wer, mean_delay, var_delay = write_error_stats_with_timestamps(
+                f, f"{test_set_name}-{key}", results, enable_log=True
+            )
+            test_set_wers[key] = wer
+            test_set_delays[key] = (mean_delay, var_delay)
+
+        logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+    errs_info = (
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+    )
+    with open(errs_info, "w") as f:
+        print("settings\tWER", file=f)
+        for key, val in test_set_wers:
+            print("{}\t{}".format(key, val), file=f)
+
+    test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0])
+    delays_info = (
+        params.res_dir
+        / f"symbol-delay-summary-{test_set_name}-{key}-{params.suffix}.txt"
+    )
+    with open(delays_info, "w") as f:
+        print("settings\tsymbol-delay", file=f)
+        for key, val in test_set_delays:
+            print(
+                "{}\tmean: {}s, variance: {}".format(key, val[0], val[1]),
+                file=f,
+            )
+
+    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+    note = "\tbest for {}".format(test_set_name)
+    for key, val in test_set_wers:
+        s += "{}\t{}{}\n".format(key, val, note)
+        note = ""
+    logging.info(s)
+
+    s = "\nFor {}, symbol-delay of different settings are:\n".format(test_set_name)
+    note = "\tbest for {}".format(test_set_name)
+    for key, val in test_set_delays:
+        s += "{}\tmean: {}s, variance: {}{}\n".format(key, val[0], val[1], note)
+        note = ""
+    logging.info(s)
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    LibriSpeechAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+    args.lang_dir = Path(args.lang_dir)
+    args.lm_dir = Path(args.lm_dir)
+
+    params = get_params()
+    # add decoding params
+    params.update(get_decoding_params())
+    params.update(vars(args))
+
+    assert params.decoding_method in (
+        "ctc-greedy-search",
+        "ctc-decoding",
+        "1best",
+        "nbest",
+        "nbest-rescoring",
+        "whole-lattice-rescoring",
+        "nbest-oracle",
+    )
+    params.res_dir = params.exp_dir / params.decoding_method
+
+    if params.iter > 0:
+        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+    else:
+        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+    if params.simulate_streaming:
+        params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
+        params.suffix += f"-left-context-{params.left_context}"
+
+    if params.simulate_streaming:
+        assert (
+            params.causal_convolution
+        ), "Decoding in streaming requires causal convolution"
+
+    if params.use_averaged_model:
+        params.suffix += "-use-averaged-model"
+
+    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+    logging.info("Decoding started")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"Device: {device}")
+    logging.info(params)
+
+    lexicon = Lexicon(params.lang_dir)
+    max_token_id = max(lexicon.tokens)
+    num_classes = max_token_id + 1  # +1 for the blank
+
+    graph_compiler = BpeCtcTrainingGraphCompiler(
+        params.lang_dir,
+        device=device,
+        sos_token="",
+        eos_token="",
+    )
+    sos_id = graph_compiler.sos_id
+    eos_id = graph_compiler.eos_id
+
+    params.vocab_size = num_classes
+    params.sos_id = sos_id
+    params.eos_id = eos_id
+
+    if params.decoding_method in ["ctc-decoding", "ctc-greedy-search"]:
+        HLG = None
+        H = k2.ctc_topo(
+            max_token=max_token_id,
+            modified=False,
+            device=device,
+        )
+        bpe_model = spm.SentencePieceProcessor()
+        bpe_model.load(str(params.lang_dir / "bpe.model"))
+    else:
+        H = None
+        bpe_model = None
+        HLG = k2.Fsa.from_dict(
+            torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
+        )
+        assert HLG.requires_grad is False
+
+        if not hasattr(HLG, "lm_scores"):
+            HLG.lm_scores = HLG.scores.clone()
+
+    if params.decoding_method in (
+        "nbest-rescoring",
+        "whole-lattice-rescoring",
+    ):
+        if not (params.lm_dir / "G_4_gram.pt").is_file():
+            logging.info("Loading G_4_gram.fst.txt")
+            logging.warning("It may take 8 minutes.")
+            with open(params.lm_dir / "G_4_gram.fst.txt") as f:
+                first_word_disambig_id = lexicon.word_table["#0"]
+
+                G = k2.Fsa.from_openfst(f.read(), acceptor=False)
+                # G.aux_labels is not needed in later computations, so
+                # remove it here.
+                del G.aux_labels
+                # CAUTION: The following line is crucial.
+                # Arcs entering the back-off state have label equal to #0.
+                # We have to change it to 0 here.
+                G.labels[G.labels >= first_word_disambig_id] = 0
+                # See https://github.com/k2-fsa/k2/issues/874
+                # for why we need to set G.properties to None
+                G.__dict__["_properties"] = None
+                G = k2.Fsa.from_fsas([G]).to(device)
+                G = k2.arc_sort(G)
+                # Save a dummy value so that it can be loaded in C++.
+                # See https://github.com/pytorch/pytorch/issues/67902
+                # for why we need to do this.
+                G.dummy = 1
+
+                torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
+        else:
+            logging.info("Loading pre-compiled G_4_gram.pt")
+            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
+            G = k2.Fsa.from_dict(d)
+
+        if params.decoding_method == "whole-lattice-rescoring":
+            # Add epsilon self-loops to G as we will compose
+            # it with the whole lattice later
+            G = k2.add_epsilon_self_loops(G)
+            G = k2.arc_sort(G)
+            G = G.to(device)
+
+        # G.lm_scores is used to replace HLG.lm_scores during
+        # LM rescoring.
+        G.lm_scores = G.scores.clone()
+    else:
+        G = None
+
+    logging.info("About to create model")
+    model = get_ctc_model(params)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+
+    model.to(device)
+    model.eval()
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
+    librispeech = LibriSpeechAsrDataModule(args)
+
+    test_clean_cuts = librispeech.test_clean_cuts()
+    test_other_cuts = librispeech.test_other_cuts()
+
+    test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
+    test_other_dl = librispeech.test_dataloaders(test_other_cuts)
+
+    test_sets = ["test-clean", "test-other"]
+    test_dl = [test_clean_dl, test_other_dl]
+
+    for test_set, test_dl in zip(test_sets, test_dl):
+        results_dict = decode_dataset(
+            dl=test_dl,
+            params=params,
+            model=model,
+            HLG=HLG,
+            H=H,
+            bpe_model=bpe_model,
+            word_table=lexicon.word_table,
+            G=G,
+            sos_id=sos_id,
+            eos_id=eos_id,
+        )
+
+        save_results(
+            params=params,
+            test_set_name=test_set,
+            results_dict=results_dict,
+        )
+
+    logging.info("Done!")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/librispeech/ASR/conformer_ctc3/encoder_interface.py b/egs/librispeech/ASR/conformer_ctc3/encoder_interface.py
new file mode 120000
index 000000000..b9aa0ae08
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc3/encoder_interface.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless2/encoder_interface.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/conformer_ctc3/export.py b/egs/librispeech/ASR/conformer_ctc3/export.py
new file mode 100755
index 000000000..c5b95d981
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc3/export.py
@@ -0,0 +1,292 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This script converts several saved checkpoints
+# to a single one using model averaging.
+"""
+Usage:
+
+(1) Export to torchscript model using torch.jit.trace()
+
+./conformer_ctc3/export.py \
+  --exp-dir ./conformer_ctc3/exp \
+  --lang-dir data/lang_bpe_500 \
+  --epoch 20 \
+  --avg 10 \
+  --jit-trace 1
+
+It will generates the file: `jit_trace.pt`.
+
+(2) Export `model.state_dict()`
+
+./conformer_ctc3/export.py \
+  --exp-dir ./conformer_ctc3/exp \
+  --lang-dir data/lang_bpe_500 \
+  --epoch 20 \
+  --avg 10
+
+It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
+load it by `icefall.checkpoint.load_checkpoint()`.
+
+To use the generated file with `conformer_ctc3/decode.py`,
+you can do:
+
+    cd /path/to/exp_dir
+    ln -s pretrained.pt epoch-9999.pt
+
+    cd /path/to/egs/librispeech/ASR
+    ./conformer_ctc3/decode.py \
+        --exp-dir ./conformer_ctc3/exp \
+        --epoch 9999 \
+        --avg 1 \
+        --max-duration 100 \
+        --lang-dir data/lang_bpe_500
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+import torch
+from scaling_converter import convert_scaled_to_non_scaled
+from train import add_model_arguments, get_ctc_model, get_params
+
+from icefall.checkpoint import (
+    average_checkpoints,
+    average_checkpoints_with_averaged_model,
+    find_checkpoints,
+    load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import str2bool
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=28,
+        help="""It specifies the checkpoint to use for averaging.
+        Note: Epoch counts from 0.
+        You can specify --avg to use more checkpoints for model averaging.""",
+    )
+
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=15,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
+    )
+
+    parser.add_argument(
+        "--use-averaged-model",
+        type=str2bool,
+        default=True,
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="pruned_transducer_stateless4/exp",
+        help="""It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=Path,
+        default="data/lang_bpe_500",
+        help="The lang dir containing word table and LG graph",
+    )
+
+    parser.add_argument(
+        "--jit-trace",
+        type=str2bool,
+        default=False,
+        help="""True to save a model after applying torch.jit.script.
+        """,
+    )
+
+    parser.add_argument(
+        "--streaming-model",
+        type=str2bool,
+        default=False,
+        help="""Whether to export a streaming model, if the models in exp-dir
+        are streaming model, this should be True.
+        """,
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def main():
+    args = get_parser().parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    lexicon = Lexicon(params.lang_dir)
+    max_token_id = max(lexicon.tokens)
+    num_classes = max_token_id + 1  # +1 for the blank
+    params.vocab_size = num_classes
+
+    if params.streaming_model:
+        assert params.causal_convolution
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_ctc_model(params)
+
+    model.to(device)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+
+    model.to("cpu")
+    model.eval()
+
+    if params.jit_trace:
+        # TODO: will support streaming mode
+        assert not params.streaming_model
+        convert_scaled_to_non_scaled(model, inplace=True)
+
+        logging.info("Using torch.jit.trace()")
+
+        x = torch.zeros(1, 100, 80, dtype=torch.float32)
+        x_lens = torch.tensor([100], dtype=torch.int64)
+        traced_model = torch.jit.trace(model, (x, x_lens))
+
+        filename = params.exp_dir / "jit_trace.pt"
+        traced_model.save(str(filename))
+        logging.info(f"Saved to {filename}")
+    else:
+        logging.info("Not using torch.jit.trace()")
+        # Save it using a format so that it can be loaded
+        # by :func:`load_checkpoint`
+        filename = params.exp_dir / "pretrained.pt"
+        torch.save({"model": model.state_dict()}, str(filename))
+        logging.info(f"Saved to {filename}")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py b/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py
new file mode 100755
index 000000000..c96defd23
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py
@@ -0,0 +1,406 @@
+#!/usr/bin/env python3
+# Copyright      2021  Xiaomi Corp.        (authors: Fangjun Kuang,
+#                                                    Mingshuang Luo,)
+#                                                    Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+Usage (for non-streaming mode):
+
+(1) ctc-decoding
+./conformer_ctc3/pretrained.py \
+  --checkpoint conformer_ctc3/exp/pretrained.pt \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --method ctc-decoding \
+  --sample-rate 16000 \
+  test_wavs/1089-134686-0001.wav
+
+(2) 1best
+./conformer_ctc3/pretrained.py \
+  --checkpoint conformer_ctc3/exp/pretrained.pt \
+  --HLG data/lang_bpe_500/HLG.pt \
+  --words-file data/lang_bpe_500/words.txt  \
+  --method 1best \
+  --sample-rate 16000 \
+  test_wavs/1089-134686-0001.wav
+
+(3) nbest-rescoring
+./conformer_ctc3/pretrained.py \
+  --checkpoint conformer_ctc3/exp/pretrained.pt \
+  --HLG data/lang_bpe_500/HLG.pt \
+  --words-file data/lang_bpe_500/words.txt  \
+  --G data/lm/G_4_gram.pt \
+  --method nbest-rescoring \
+  --sample-rate 16000 \
+  test_wavs/1089-134686-0001.wav
+
+(4) whole-lattice-rescoring
+./conformer_ctc3/pretrained.py \
+  --checkpoint conformer_ctc3/exp/pretrained.pt \
+  --HLG data/lang_bpe_500/HLG.pt \
+  --words-file data/lang_bpe_500/words.txt  \
+  --G data/lm/G_4_gram.pt \
+  --method whole-lattice-rescoring \
+  --sample-rate 16000 \
+  test_wavs/1089-134686-0001.wav
+"""
+
+
+import argparse
+import logging
+import math
+from typing import List
+
+import k2
+import kaldifeat
+import sentencepiece as spm
+import torch
+import torchaudio
+from decode import get_decoding_params
+from torch.nn.utils.rnn import pad_sequence
+from train import add_model_arguments, get_params
+
+from icefall.decode import (
+    get_lattice,
+    one_best_decoding,
+    rescore_with_n_best_list,
+    rescore_with_whole_lattice,
+)
+from icefall.utils import get_texts
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--model-filename",
+        type=str,
+        required=True,
+        help="Path to the torchscript model.",
+    )
+
+    parser.add_argument(
+        "--words-file",
+        type=str,
+        help="""Path to words.txt.
+        Used only when method is not ctc-decoding.
+        """,
+    )
+
+    parser.add_argument(
+        "--HLG",
+        type=str,
+        help="""Path to HLG.pt.
+        Used only when method is not ctc-decoding.
+        """,
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        help="""Path to bpe.model.
+        Used only when method is ctc-decoding.
+        """,
+    )
+
+    parser.add_argument(
+        "--method",
+        type=str,
+        default="1best",
+        help="""Decoding method.
+        Possible values are:
+        (0) ctc-decoding - Use CTC decoding. It uses a sentence
+            piece model, i.e., lang_dir/bpe.model, to convert
+            word pieces to words. It needs neither a lexicon
+            nor an n-gram LM.
+        (1) 1best - Use the best path as decoding output. Only
+            the transformer encoder output is used for decoding.
+            We call it HLG decoding.
+        (2) nbest-rescoring. Extract n paths from the decoding lattice,
+            rescore them with an LM, the path with
+            the highest score is the decoding result.
+            We call it HLG decoding + n-gram LM rescoring.
+        (3) whole-lattice-rescoring - Use an LM to rescore the
+            decoding lattice and then use 1best to decode the
+            rescored lattice.
+            We call it HLG decoding + n-gram LM rescoring.
+        """,
+    )
+
+    parser.add_argument(
+        "--G",
+        type=str,
+        help="""An LM for rescoring.
+        Used only when method is
+        whole-lattice-rescoring or nbest-rescoring.
+        It's usually a 4-gram LM.
+        """,
+    )
+
+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=100,
+        help="""
+        Used only when method is attention-decoder.
+        It specifies the size of n-best list.""",
+    )
+
+    parser.add_argument(
+        "--ngram-lm-scale",
+        type=float,
+        default=1.3,
+        help="""
+        Used only when method is whole-lattice-rescoring and nbest-rescoring.
+        It specifies the scale for n-gram LM scores.
+        (Note: You need to tune it on a dataset.)
+        """,
+    )
+
+    parser.add_argument(
+        "--nbest-scale",
+        type=float,
+        default=0.5,
+        help="""
+        Used only when method is nbest-rescoring.
+        It specifies the scale for lattice.scores when
+        extracting n-best lists. A smaller value results in
+        more unique number of paths with the risk of missing
+        the best path.
+        """,
+    )
+
+    parser.add_argument(
+        "--num-classes",
+        type=int,
+        default=500,
+        help="""
+        Vocab size in the BPE model.
+        """,
+    )
+
+    parser.add_argument(
+        "--sample-rate",
+        type=int,
+        default=16000,
+        help="The sample rate of the input sound file",
+    )
+
+    parser.add_argument(
+        "sound_files",
+        type=str,
+        nargs="+",
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def read_sound_files(
+    filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+    """Read a list of sound files into a list 1-D float32 torch tensors.
+    Args:
+      filenames:
+        A list of sound filenames.
+      expected_sample_rate:
+        The expected sample rate of the sound files.
+    Returns:
+      Return a list of 1-D float32 torch tensors.
+    """
+    ans = []
+    for f in filenames:
+        wave, sample_rate = torchaudio.load(f)
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
+        )
+        # We use only the first channel
+        ans.append(wave[0])
+    return ans
+
+
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+
+    params = get_params()
+    # add decoding params
+    params.update(get_decoding_params())
+    params.update(vars(args))
+    params.vocab_size = params.num_classes
+
+    logging.info(f"{params}")
+
+    device = torch.device("cpu")
+
+    logging.info(f"device: {device}")
+
+    model = torch.jit.load(args.model_filename)
+    model.to(device)
+    model.eval()
+
+    logging.info("Constructing Fbank computer")
+    opts = kaldifeat.FbankOptions()
+    opts.device = device
+    opts.frame_opts.dither = 0
+    opts.frame_opts.snip_edges = False
+    opts.frame_opts.samp_freq = params.sample_rate
+    opts.mel_opts.num_bins = params.feature_dim
+
+    fbank = kaldifeat.Fbank(opts)
+
+    logging.info(f"Reading sound files: {params.sound_files}")
+    waves = read_sound_files(
+        filenames=params.sound_files, expected_sample_rate=params.sample_rate
+    )
+    waves = [w.to(device) for w in waves]
+
+    logging.info("Decoding started")
+    features = fbank(waves)
+    feature_lengths = [f.size(0) for f in features]
+
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    feature_lengths = torch.tensor(feature_lengths, device=device)
+
+    nnet_output, _ = model(features, feature_lengths)
+
+    batch_size = nnet_output.shape[0]
+    supervision_segments = torch.tensor(
+        [[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
+        dtype=torch.int32,
+    )
+
+    if params.method == "ctc-decoding":
+        logging.info("Use CTC decoding")
+        bpe_model = spm.SentencePieceProcessor()
+        bpe_model.load(params.bpe_model)
+        max_token_id = params.num_classes - 1
+
+        H = k2.ctc_topo(
+            max_token=max_token_id,
+            modified=False,
+            device=device,
+        )
+
+        lattice = get_lattice(
+            nnet_output=nnet_output,
+            decoding_graph=H,
+            supervision_segments=supervision_segments,
+            search_beam=params.search_beam,
+            output_beam=params.output_beam,
+            min_active_states=params.min_active_states,
+            max_active_states=params.max_active_states,
+            subsampling_factor=params.subsampling_factor,
+        )
+
+        best_path = one_best_decoding(
+            lattice=lattice, use_double_scores=params.use_double_scores
+        )
+        token_ids = get_texts(best_path)
+        hyps = bpe_model.decode(token_ids)
+        hyps = [s.split() for s in hyps]
+    elif params.method in [
+        "1best",
+        "nbest-rescoring",
+        "whole-lattice-rescoring",
+    ]:
+        logging.info(f"Loading HLG from {params.HLG}")
+        HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+        HLG = HLG.to(device)
+        if not hasattr(HLG, "lm_scores"):
+            # For whole-lattice-rescoring and attention-decoder
+            HLG.lm_scores = HLG.scores.clone()
+
+        if params.method in [
+            "nbest-rescoring",
+            "whole-lattice-rescoring",
+        ]:
+            logging.info(f"Loading G from {params.G}")
+            G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+            G = G.to(device)
+            if params.method == "whole-lattice-rescoring":
+                # Add epsilon self-loops to G as we will compose
+                # it with the whole lattice later
+                G = k2.add_epsilon_self_loops(G)
+                G = k2.arc_sort(G)
+
+            # G.lm_scores is used to replace HLG.lm_scores during
+            # LM rescoring.
+            G.lm_scores = G.scores.clone()
+
+        lattice = get_lattice(
+            nnet_output=nnet_output,
+            decoding_graph=HLG,
+            supervision_segments=supervision_segments,
+            search_beam=params.search_beam,
+            output_beam=params.output_beam,
+            min_active_states=params.min_active_states,
+            max_active_states=params.max_active_states,
+            subsampling_factor=params.subsampling_factor,
+        )
+
+        if params.method == "1best":
+            logging.info("Use HLG decoding")
+            best_path = one_best_decoding(
+                lattice=lattice, use_double_scores=params.use_double_scores
+            )
+        if params.method == "nbest-rescoring":
+            logging.info("Use HLG decoding + LM rescoring")
+            best_path_dict = rescore_with_n_best_list(
+                lattice=lattice,
+                G=G,
+                num_paths=params.num_paths,
+                lm_scale_list=[params.ngram_lm_scale],
+                nbest_scale=params.nbest_scale,
+            )
+            best_path = next(iter(best_path_dict.values()))
+        elif params.method == "whole-lattice-rescoring":
+            logging.info("Use HLG decoding + LM rescoring")
+            best_path_dict = rescore_with_whole_lattice(
+                lattice=lattice,
+                G_with_epsilon_loops=G,
+                lm_scale_list=[params.ngram_lm_scale],
+            )
+            best_path = next(iter(best_path_dict.values()))
+
+        hyps = get_texts(best_path)
+        word_sym_table = k2.SymbolTable.from_file(params.words_file)
+        hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
+    else:
+        raise ValueError(f"Unsupported decoding method: {params.method}")
+
+    s = "\n"
+    for filename, hyp in zip(params.sound_files, hyps):
+        words = " ".join(hyp)
+        s += f"{filename}:\n{words}\n\n"
+    logging.info(s)
+
+    logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/librispeech/ASR/conformer_ctc3/lstmp.py b/egs/librispeech/ASR/conformer_ctc3/lstmp.py
new file mode 120000
index 000000000..4f377cd01
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc3/lstmp.py
@@ -0,0 +1 @@
+../lstm_transducer_stateless2/lstmp.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/conformer_ctc3/model.py b/egs/librispeech/ASR/conformer_ctc3/model.py
new file mode 100644
index 000000000..f56df2006
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc3/model.py
@@ -0,0 +1,122 @@
+# Copyright  2021-2022  Xiaomi Corp.     (authors: Fangjun Kuang,
+#                                                  Wei Kang,
+#                                                  Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import math
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+from encoder_interface import EncoderInterface
+from scaling import ScaledLinear
+
+
+class CTCModel(nn.Module):
+    """It implements https://www.cs.toronto.edu/~graves/icml_2006.pdf
+    "Connectionist Temporal Classification: Labelling Unsegmented
+    Sequence Data with Recurrent Neural Networks"
+    """
+
+    def __init__(
+        self,
+        encoder: EncoderInterface,
+        encoder_dim: int,
+        vocab_size: int,
+    ):
+        """
+        Args:
+          encoder:
+            It is the transcription network in the paper. Its accepts
+            two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
+            It returns two tensors: `logits` of shape (N, T, encoder_dm) and
+            `logit_lens` of shape (N,).
+          encoder_dim:
+            The feature embedding dimension.
+          vocab_size:
+            The vocabulary size.
+        """
+        super().__init__()
+        assert isinstance(encoder, EncoderInterface), type(encoder)
+
+        self.encoder = encoder
+        self.ctc_output_module = nn.Sequential(
+            nn.Dropout(p=0.1),
+            ScaledLinear(encoder_dim, vocab_size),
+        )
+
+    def get_ctc_output(
+        self,
+        encoder_out: torch.Tensor,
+        delay_penalty: float = 0.0,
+        blank_threshold: float = 0.99,
+    ):
+        """Compute ctc log-prob and optionally (delay_penalty > 0) apply delay penalty.
+        We first split utterance into sub-utterances according to the
+        blank probs, and then add sawtooth-like "blank-bonus" values to
+        the blank probs.
+        See https://github.com/k2-fsa/icefall/pull/669 for details.
+
+        Args:
+          encoder_out:
+            A tensor with shape of (N, T, C).
+          delay_penalty:
+            A constant used to scale the delay penalty score.
+          blank_threshold:
+            The threshold used to split utterance into sub-utterances.
+        """
+        output = self.ctc_output_module(encoder_out)
+        log_prob = nn.functional.log_softmax(output, dim=-1)
+
+        if self.training and delay_penalty > 0:
+            T_arange = torch.arange(encoder_out.shape[1]).to(device=encoder_out.device)
+            # split into sub-utterances using the blank-id
+            mask = log_prob[:, :, 0] >= math.log(blank_threshold)  # (B, T)
+            mask[:, 0] = True
+            cummax_out = (T_arange * mask).cummax(dim=-1)[0]  # (B, T)
+            # the sawtooth "blank-bonus" value
+            penalty = T_arange - cummax_out  # (B, T)
+            penalty_all = torch.zeros_like(log_prob)
+            penalty_all[:, :, 0] = delay_penalty * penalty
+            # apply latency penalty on probs
+            log_prob = log_prob + penalty_all
+
+        return log_prob
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        x_lens: torch.Tensor,
+        warmup: float = 1.0,
+        delay_penalty: float = 0.0,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Args:
+          x:
+            A 3-D tensor of shape (N, T, C).
+          x_lens:
+            A 1-D tensor of shape (N,). It contains the number of frames in `x`
+            before padding.
+          warmup: a floating point value which increases throughout training;
+            values >= 1.0 are fully warmed up and have all modules present.
+          delay_penalty:
+            A constant used to scale the delay penalty score.
+        """
+        encoder_out, encoder_out_lens = self.encoder(x, x_lens, warmup=warmup)
+        assert torch.all(encoder_out_lens > 0)
+        nnet_output = self.get_ctc_output(encoder_out, delay_penalty=delay_penalty)
+        return nnet_output, encoder_out_lens
diff --git a/egs/librispeech/ASR/conformer_ctc3/optim.py b/egs/librispeech/ASR/conformer_ctc3/optim.py
new file mode 120000
index 000000000..e2deb4492
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc3/optim.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless2/optim.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/conformer_ctc3/pretrained.py b/egs/librispeech/ASR/conformer_ctc3/pretrained.py
new file mode 100755
index 000000000..3628d6a5f
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc3/pretrained.py
@@ -0,0 +1,458 @@
+#!/usr/bin/env python3
+# Copyright      2021  Xiaomi Corp.        (authors: Fangjun Kuang,
+#                                                    Mingshuang Luo,)
+#                                                    Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+Usage (for non-streaming mode):
+
+(1) ctc-decoding
+./conformer_ctc3/pretrained.py \
+  --checkpoint conformer_ctc3/exp/pretrained.pt \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --method ctc-decoding \
+  --sample-rate 16000 \
+  test_wavs/1089-134686-0001.wav
+
+(2) 1best
+./conformer_ctc3/pretrained.py \
+  --checkpoint conformer_ctc3/exp/pretrained.pt \
+  --HLG data/lang_bpe_500/HLG.pt \
+  --words-file data/lang_bpe_500/words.txt  \
+  --method 1best \
+  --sample-rate 16000 \
+  test_wavs/1089-134686-0001.wav
+
+(3) nbest-rescoring
+./conformer_ctc3/pretrained.py \
+  --checkpoint conformer_ctc3/exp/pretrained.pt \
+  --HLG data/lang_bpe_500/HLG.pt \
+  --words-file data/lang_bpe_500/words.txt  \
+  --G data/lm/G_4_gram.pt \
+  --method nbest-rescoring \
+  --sample-rate 16000 \
+  test_wavs/1089-134686-0001.wav
+
+(4) whole-lattice-rescoring
+./conformer_ctc3/pretrained.py \
+  --checkpoint conformer_ctc3/exp/pretrained.pt \
+  --HLG data/lang_bpe_500/HLG.pt \
+  --words-file data/lang_bpe_500/words.txt  \
+  --G data/lm/G_4_gram.pt \
+  --method whole-lattice-rescoring \
+  --sample-rate 16000 \
+  test_wavs/1089-134686-0001.wav
+"""
+
+
+import argparse
+import logging
+import math
+from typing import List
+
+import k2
+import kaldifeat
+import sentencepiece as spm
+import torch
+import torchaudio
+from decode import get_decoding_params
+from torch.nn.utils.rnn import pad_sequence
+from train import add_model_arguments, get_ctc_model, get_params
+
+from icefall.decode import (
+    get_lattice,
+    one_best_decoding,
+    rescore_with_n_best_list,
+    rescore_with_whole_lattice,
+)
+from icefall.utils import get_texts, str2bool
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--checkpoint",
+        type=str,
+        required=True,
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
+    )
+
+    parser.add_argument(
+        "--words-file",
+        type=str,
+        help="""Path to words.txt.
+        Used only when method is not ctc-decoding.
+        """,
+    )
+
+    parser.add_argument(
+        "--HLG",
+        type=str,
+        help="""Path to HLG.pt.
+        Used only when method is not ctc-decoding.
+        """,
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        help="""Path to bpe.model.
+        Used only when method is ctc-decoding.
+        """,
+    )
+
+    parser.add_argument(
+        "--method",
+        type=str,
+        default="1best",
+        help="""Decoding method.
+        Possible values are:
+        (0) ctc-decoding - Use CTC decoding. It uses a sentence
+            piece model, i.e., lang_dir/bpe.model, to convert
+            word pieces to words. It needs neither a lexicon
+            nor an n-gram LM.
+        (1) 1best - Use the best path as decoding output. Only
+            the transformer encoder output is used for decoding.
+            We call it HLG decoding.
+        (2) nbest-rescoring. Extract n paths from the decoding lattice,
+            rescore them with an LM, the path with
+            the highest score is the decoding result.
+            We call it HLG decoding + n-gram LM rescoring.
+        (3) whole-lattice-rescoring - Use an LM to rescore the
+            decoding lattice and then use 1best to decode the
+            rescored lattice.
+            We call it HLG decoding + n-gram LM rescoring.
+        """,
+    )
+
+    parser.add_argument(
+        "--G",
+        type=str,
+        help="""An LM for rescoring.
+        Used only when method is
+        whole-lattice-rescoring or nbest-rescoring.
+        It's usually a 4-gram LM.
+        """,
+    )
+
+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=100,
+        help="""
+        Used only when method is attention-decoder.
+        It specifies the size of n-best list.""",
+    )
+
+    parser.add_argument(
+        "--ngram-lm-scale",
+        type=float,
+        default=1.3,
+        help="""
+        Used only when method is whole-lattice-rescoring and nbest-rescoring.
+        It specifies the scale for n-gram LM scores.
+        (Note: You need to tune it on a dataset.)
+        """,
+    )
+
+    parser.add_argument(
+        "--nbest-scale",
+        type=float,
+        default=0.5,
+        help="""
+        Used only when method is nbest-rescoring.
+        It specifies the scale for lattice.scores when
+        extracting n-best lists. A smaller value results in
+        more unique number of paths with the risk of missing
+        the best path.
+        """,
+    )
+
+    parser.add_argument(
+        "--num-classes",
+        type=int,
+        default=500,
+        help="""
+        Vocab size in the BPE model.
+        """,
+    )
+
+    parser.add_argument(
+        "--simulate-streaming",
+        type=str2bool,
+        default=False,
+        help="""Whether to simulate streaming in decoding, this is a good way to
+        test a streaming model.
+        """,
+    )
+
+    parser.add_argument(
+        "--decode-chunk-size",
+        type=int,
+        default=16,
+        help="The chunk size for decoding (in frames after subsampling)",
+    )
+
+    parser.add_argument(
+        "--left-context",
+        type=int,
+        default=64,
+        help="left context can be seen during decoding (in frames after subsampling)",
+    )
+
+    parser.add_argument(
+        "--sample-rate",
+        type=int,
+        default=16000,
+        help="The sample rate of the input sound file",
+    )
+
+    parser.add_argument(
+        "sound_files",
+        type=str,
+        nargs="+",
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def read_sound_files(
+    filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+    """Read a list of sound files into a list 1-D float32 torch tensors.
+    Args:
+      filenames:
+        A list of sound filenames.
+      expected_sample_rate:
+        The expected sample rate of the sound files.
+    Returns:
+      Return a list of 1-D float32 torch tensors.
+    """
+    ans = []
+    for f in filenames:
+        wave, sample_rate = torchaudio.load(f)
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
+        )
+        # We use only the first channel
+        ans.append(wave[0])
+    return ans
+
+
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+
+    params = get_params()
+    # add decoding params
+    params.update(get_decoding_params())
+    params.update(vars(args))
+    params.vocab_size = params.num_classes
+
+    if params.simulate_streaming:
+        assert (
+            params.causal_convolution
+        ), "Decoding in streaming requires causal convolution"
+
+    logging.info(f"{params}")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    logging.info("About to create model")
+    model = get_ctc_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    model.load_state_dict(checkpoint["model"], strict=False)
+    model.to(device)
+    model.eval()
+
+    logging.info("Constructing Fbank computer")
+    opts = kaldifeat.FbankOptions()
+    opts.device = device
+    opts.frame_opts.dither = 0
+    opts.frame_opts.snip_edges = False
+    opts.frame_opts.samp_freq = params.sample_rate
+    opts.mel_opts.num_bins = params.feature_dim
+
+    fbank = kaldifeat.Fbank(opts)
+
+    logging.info(f"Reading sound files: {params.sound_files}")
+    waves = read_sound_files(
+        filenames=params.sound_files, expected_sample_rate=params.sample_rate
+    )
+    waves = [w.to(device) for w in waves]
+
+    logging.info("Decoding started")
+    features = fbank(waves)
+    feature_lengths = [f.size(0) for f in features]
+
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    feature_lengths = torch.tensor(feature_lengths, device=device)
+
+    # model forward
+    if params.simulate_streaming:
+        encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
+            x=features,
+            x_lens=feature_lengths,
+            chunk_size=params.decode_chunk_size,
+            left_context=params.left_context,
+            simulate_streaming=True,
+        )
+    else:
+        encoder_out, encoder_out_lens = model.encoder(
+            x=features, x_lens=feature_lengths
+        )
+    nnet_output = model.get_ctc_output(encoder_out)
+
+    batch_size = nnet_output.shape[0]
+    supervision_segments = torch.tensor(
+        [[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
+        dtype=torch.int32,
+    )
+
+    if params.method == "ctc-decoding":
+        logging.info("Use CTC decoding")
+        bpe_model = spm.SentencePieceProcessor()
+        bpe_model.load(params.bpe_model)
+        max_token_id = params.num_classes - 1
+
+        H = k2.ctc_topo(
+            max_token=max_token_id,
+            modified=False,
+            device=device,
+        )
+
+        lattice = get_lattice(
+            nnet_output=nnet_output,
+            decoding_graph=H,
+            supervision_segments=supervision_segments,
+            search_beam=params.search_beam,
+            output_beam=params.output_beam,
+            min_active_states=params.min_active_states,
+            max_active_states=params.max_active_states,
+            subsampling_factor=params.subsampling_factor,
+        )
+
+        best_path = one_best_decoding(
+            lattice=lattice, use_double_scores=params.use_double_scores
+        )
+        token_ids = get_texts(best_path)
+        hyps = bpe_model.decode(token_ids)
+        hyps = [s.split() for s in hyps]
+    elif params.method in [
+        "1best",
+        "nbest-rescoring",
+        "whole-lattice-rescoring",
+    ]:
+        logging.info(f"Loading HLG from {params.HLG}")
+        HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+        HLG = HLG.to(device)
+        if not hasattr(HLG, "lm_scores"):
+            # For whole-lattice-rescoring and attention-decoder
+            HLG.lm_scores = HLG.scores.clone()
+
+        if params.method in [
+            "nbest-rescoring",
+            "whole-lattice-rescoring",
+        ]:
+            logging.info(f"Loading G from {params.G}")
+            G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+            G = G.to(device)
+            if params.method == "whole-lattice-rescoring":
+                # Add epsilon self-loops to G as we will compose
+                # it with the whole lattice later
+                G = k2.add_epsilon_self_loops(G)
+                G = k2.arc_sort(G)
+
+            # G.lm_scores is used to replace HLG.lm_scores during
+            # LM rescoring.
+            G.lm_scores = G.scores.clone()
+
+        lattice = get_lattice(
+            nnet_output=nnet_output,
+            decoding_graph=HLG,
+            supervision_segments=supervision_segments,
+            search_beam=params.search_beam,
+            output_beam=params.output_beam,
+            min_active_states=params.min_active_states,
+            max_active_states=params.max_active_states,
+            subsampling_factor=params.subsampling_factor,
+        )
+
+        if params.method == "1best":
+            logging.info("Use HLG decoding")
+            best_path = one_best_decoding(
+                lattice=lattice, use_double_scores=params.use_double_scores
+            )
+        if params.method == "nbest-rescoring":
+            logging.info("Use HLG decoding + LM rescoring")
+            best_path_dict = rescore_with_n_best_list(
+                lattice=lattice,
+                G=G,
+                num_paths=params.num_paths,
+                lm_scale_list=[params.ngram_lm_scale],
+                nbest_scale=params.nbest_scale,
+            )
+            best_path = next(iter(best_path_dict.values()))
+        elif params.method == "whole-lattice-rescoring":
+            logging.info("Use HLG decoding + LM rescoring")
+            best_path_dict = rescore_with_whole_lattice(
+                lattice=lattice,
+                G_with_epsilon_loops=G,
+                lm_scale_list=[params.ngram_lm_scale],
+            )
+            best_path = next(iter(best_path_dict.values()))
+
+        hyps = get_texts(best_path)
+        word_sym_table = k2.SymbolTable.from_file(params.words_file)
+        hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
+    else:
+        raise ValueError(f"Unsupported decoding method: {params.method}")
+
+    s = "\n"
+    for filename, hyp in zip(params.sound_files, hyps):
+        words = " ".join(hyp)
+        s += f"{filename}:\n{words}\n\n"
+    logging.info(s)
+
+    logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/librispeech/ASR/conformer_ctc3/scaling.py b/egs/librispeech/ASR/conformer_ctc3/scaling.py
new file mode 120000
index 000000000..09d802cc4
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc3/scaling.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless2/scaling.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/conformer_ctc3/scaling_converter.py b/egs/librispeech/ASR/conformer_ctc3/scaling_converter.py
new file mode 120000
index 000000000..3b667058d
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc3/scaling_converter.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless3/scaling_converter.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/conformer_ctc3/test_model.py b/egs/librispeech/ASR/conformer_ctc3/test_model.py
new file mode 100755
index 000000000..b97b7eed8
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc3/test_model.py
@@ -0,0 +1,82 @@
+#!/usr/bin/env python3
+# Copyright    2022  Xiaomi Corp.        (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+To run this file, do:
+
+    cd icefall/egs/librispeech/ASR
+    python ./conformer_ctc3/test_model.py
+"""
+
+import torch
+
+from train import get_params, get_ctc_model
+
+
+def test_model():
+    params = get_params()
+    params.vocab_size = 500
+    params.blank_id = 0
+    params.context_size = 2
+    params.unk_id = 2
+
+    params.dynamic_chunk_training = False
+    params.short_chunk_size = 25
+    params.num_left_chunks = 4
+    params.causal_convolution = False
+
+    model = get_ctc_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    print(f"Number of model parameters: {num_param}")
+
+    features = torch.randn(2, 100, 80)
+    feature_lengths = torch.full((2,), 100)
+    model(x=features, x_lens=feature_lengths)
+
+
+def test_model_streaming():
+    params = get_params()
+    params.vocab_size = 500
+    params.blank_id = 0
+    params.context_size = 2
+    params.unk_id = 2
+
+    params.dynamic_chunk_training = True
+    params.short_chunk_size = 25
+    params.num_left_chunks = 4
+    params.causal_convolution = True
+
+    model = get_ctc_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    print(f"Number of model parameters: {num_param}")
+
+    features = torch.randn(2, 100, 80)
+    feature_lengths = torch.full((2,), 100)
+    encoder_out, _ = model.encoder(x=features, x_lens=feature_lengths)
+    model.get_ctc_output(encoder_out)
+
+
+def main():
+    test_model()
+    test_model_streaming()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/librispeech/ASR/conformer_ctc3/train.py b/egs/librispeech/ASR/conformer_ctc3/train.py
new file mode 100755
index 000000000..fb3b740c1
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc3/train.py
@@ -0,0 +1,1108 @@
+#!/usr/bin/env python3
+# Copyright    2021  Xiaomi Corp.        (authors: Fangjun Kuang,
+#                                                  Wei Kang,
+#                                                  Mingshuang Luo,)
+#                                                  Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./conformer_ctc3/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --exp-dir conformer_ctc3/exp \
+  --full-libri 1 \
+  --max-duration 300
+
+# For mix precision training:
+
+./conformer_ctc3/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --use-fp16 1 \
+  --exp-dir conformer_ctc3/exp \
+  --full-libri 1 \
+  --max-duration 550
+
+# train a streaming model
+./conformer_ctc3/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --exp-dir conformer_ctc3/exp \
+  --full-libri 1 \
+  --dynamic-chunk-training 1 \
+  --causal-convolution 1 \
+  --short-chunk-size 25 \
+  --num-left-chunks 4 \
+  --max-duration 300 \
+  --delay-penalty 0.0
+"""
+
+import argparse
+import copy
+import logging
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
+from conformer import Conformer
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import CTCModel
+from optim import Eden, Eve
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+
+from icefall import diagnostics
+from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+    save_checkpoint_with_global_batch_idx,
+    update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.graph_compiler import CtcTrainingGraphCompiler
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    encode_supervisions,
+    setup_logger,
+    str2bool,
+)
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+    parser.add_argument(
+        "--dynamic-chunk-training",
+        type=str2bool,
+        default=False,
+        help="""Whether to use dynamic_chunk_training, if you want a streaming
+        model, this requires to be True.
+        """,
+    )
+
+    parser.add_argument(
+        "--causal-convolution",
+        type=str2bool,
+        default=False,
+        help="""Whether to use causal convolution, this requires to be True when
+        using dynamic_chunk_training.
+        """,
+    )
+
+    parser.add_argument(
+        "--short-chunk-size",
+        type=int,
+        default=25,
+        help="""Chunk length of dynamic training, the chunk size would be either
+        max sequence length of current batch or uniformly sampled from (1, short_chunk_size).
+        """,
+    )
+
+    parser.add_argument(
+        "--num-left-chunks",
+        type=int,
+        default=4,
+        help="How many left context can be seen in chunks when calculating attention.",
+    )
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--world-size",
+        type=int,
+        default=1,
+        help="Number of GPUs for DDP training.",
+    )
+
+    parser.add_argument(
+        "--master-port",
+        type=int,
+        default=12354,
+        help="Master port to use for DDP training.",
+    )
+
+    parser.add_argument(
+        "--tensorboard",
+        type=str2bool,
+        default=True,
+        help="Should various information be logged in tensorboard.",
+    )
+
+    parser.add_argument(
+        "--num-epochs",
+        type=int,
+        default=30,
+        help="Number of epochs to train.",
+    )
+
+    parser.add_argument(
+        "--start-epoch",
+        type=int,
+        default=1,
+        help="""Resume training from this epoch. It should be positive.
+        If larger than 1, it will load checkpoint from
+        exp-dir/epoch-{start_epoch-1}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--start-batch",
+        type=int,
+        default=0,
+        help="""If positive, --start-epoch is ignored and
+        it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="conformer_ctc3/exp",
+        help="""The experiment dir.
+        It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_bpe_500",
+        help="""The lang dir
+        It contains language related input files such as
+        "lexicon.txt"
+        """,
+    )
+
+    parser.add_argument(
+        "--initial-lr",
+        type=float,
+        default=0.003,
+        help="""The initial learning rate. This value should not need to be
+        changed.""",
+    )
+
+    parser.add_argument(
+        "--lr-batches",
+        type=float,
+        default=5000,
+        help="""Number of steps that affects how rapidly the learning rate decreases.
+        We suggest not to change this.""",
+    )
+
+    parser.add_argument(
+        "--lr-epochs",
+        type=float,
+        default=6,
+        help="""Number of epochs that affects how rapidly the learning rate decreases.
+        """,
+    )
+
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="The seed for random generators intended for reproducibility",
+    )
+
+    parser.add_argument(
+        "--print-diagnostics",
+        type=str2bool,
+        default=False,
+        help="Accumulate stats on activations, print them and exit.",
+    )
+
+    parser.add_argument(
+        "--save-every-n",
+        type=int,
+        default=8000,
+        help="""Save checkpoint after processing this number of batches"
+        periodically. We save checkpoint to exp-dir/ whenever
+        params.batch_idx_train % save_every_n == 0. The checkpoint filename
+        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+        Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+        end of each epoch where `xxx` is the epoch number counting from 0.
+        """,
+    )
+
+    parser.add_argument(
+        "--keep-last-k",
+        type=int,
+        default=20,
+        help="""Only keep this number of checkpoints on disk.
+        For instance, if it is 3, there are only 3 checkpoints
+        in the exp-dir with filenames `checkpoint-xxx.pt`.
+        It does not affect checkpoints with name `epoch-xxx.pt`.
+        """,
+    )
+
+    parser.add_argument(
+        "--average-period",
+        type=int,
+        default=100,
+        help="""Update the averaged model, namely `model_avg`, after processing
+        this number of batches. `model_avg` is a separate version of model,
+        in which each floating-point parameter is the average of all the
+        parameters from the start of training. Each time we take the average,
+        we do: `model_avg = model * (average_period / batch_idx_train) +
+            model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+        """,
+    )
+
+    parser.add_argument(
+        "--use-fp16",
+        type=str2bool,
+        default=False,
+        help="Whether to use half precision training.",
+    )
+
+    parser.add_argument(
+        "--delay-penalty",
+        type=float,
+        default=0.0,
+        help="""A constant used to scale the symbol delay penalty,
+        to encourage symbol emit earlier for streaming models.
+        It is almost the same as the `delay_penalty` in our `rnnt_loss`, See
+        https://github.com/k2-fsa/k2/issues/955 and
+        https://arxiv.org/pdf/2211.00490.pdf for more details.""",
+    )
+
+    parser.add_argument(
+        "--nnet-delay-penalty",
+        type=float,
+        default=0.0,
+        help="""A constant to penalize symbol delay, which is applied on
+        the nnet_output after log-softmax.
+        We recommend using --delay-penalty instead.
+        See https://github.com/k2-fsa/icefall/pull/669 for details.""",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def get_params() -> AttributeDict:
+    """Return a dict containing training parameters.
+
+    All training related parameters that are not passed from the commandline
+    are saved in the variable `params`.
+
+    Commandline options are merged into `params` after they are parsed, so
+    you can also access them via `params`.
+
+    Explanation of options saved in `params`:
+
+        - best_train_loss: Best training loss so far. It is used to select
+                           the model that has the lowest training loss. It is
+                           updated during the training.
+
+        - best_valid_loss: Best validation loss so far. It is used to select
+                           the model that has the lowest validation loss. It is
+                           updated during the training.
+
+        - best_train_epoch: It is the epoch that has the best training loss.
+
+        - best_valid_epoch: It is the epoch that has the best validation loss.
+
+        - batch_idx_train: Used to writing statistics to tensorboard. It
+                           contains number of batches trained so far across
+                           epochs.
+
+        - log_interval:  Print training loss if batch_idx % log_interval` is 0
+
+        - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+        - valid_interval:  Run validation if batch_idx % valid_interval is 0
+
+        - feature_dim: The model input dim. It has to match the one used
+                       in computing features.
+
+        - subsampling_factor:  The subsampling factor for the model.
+
+        - encoder_dim: Hidden dim for multi-head attention model.
+
+        - num_decoder_layers: Number of decoder layer of transformer decoder.
+
+        - warm_step: The warm_step for Noam optimizer.
+    """
+    params = AttributeDict(
+        {
+            "best_train_loss": float("inf"),
+            "best_valid_loss": float("inf"),
+            "best_train_epoch": -1,
+            "best_valid_epoch": -1,
+            "batch_idx_train": 0,
+            "log_interval": 50,
+            "reset_interval": 200,
+            "valid_interval": 3000,  # For the 100h subset, use 800
+            # parameters for conformer
+            "feature_dim": 80,
+            "subsampling_factor": 4,
+            "encoder_dim": 512,
+            "nhead": 8,
+            "dim_feedforward": 2048,
+            "num_encoder_layers": 12,
+            # parameters for loss
+            "beam_size": 10,
+            "reduction": "sum",
+            "use_double_scores": True,
+            # parameters for Noam
+            "model_warm_step": 3000,  # arg given to model, not for lrate
+            "env_info": get_env_info(),
+        }
+    )
+
+    return params
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+    # TODO: We can add an option to switch between Conformer and Transformer
+    encoder = Conformer(
+        num_features=params.feature_dim,
+        subsampling_factor=params.subsampling_factor,
+        d_model=params.encoder_dim,
+        nhead=params.nhead,
+        dim_feedforward=params.dim_feedforward,
+        num_encoder_layers=params.num_encoder_layers,
+        dynamic_chunk_training=params.dynamic_chunk_training,
+        short_chunk_size=params.short_chunk_size,
+        num_left_chunks=params.num_left_chunks,
+        causal=params.causal_convolution,
+    )
+    return encoder
+
+
+def get_ctc_model(params: AttributeDict) -> nn.Module:
+    encoder = get_encoder_model(params)
+    model = CTCModel(
+        encoder=encoder,
+        encoder_dim=params.encoder_dim,
+        vocab_size=params.vocab_size,
+    )
+    return model
+
+
+def load_checkpoint_if_available(
+    params: AttributeDict,
+    model: nn.Module,
+    model_avg: nn.Module = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+    """Load checkpoint from file.
+
+    If params.start_batch is positive, it will load the checkpoint from
+    `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+    params.start_epoch is larger than 1, it will load the checkpoint from
+    `params.start_epoch - 1`.
+
+    Apart from loading state dict for `model` and `optimizer` it also updates
+    `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+    and `best_valid_loss` in `params`.
+
+    Args:
+      params:
+        The return value of :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer that we are using.
+      scheduler:
+        The scheduler that we are using.
+    Returns:
+      Return a dict containing previously saved training info.
+    """
+    if params.start_batch > 0:
+        filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+    elif params.start_epoch > 1:
+        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+    else:
+        return None
+
+    assert filename.is_file(), f"{filename} does not exist!"
+
+    saved_params = load_checkpoint(
+        filename,
+        model=model,
+        model_avg=model_avg,
+        optimizer=optimizer,
+        scheduler=scheduler,
+    )
+
+    keys = [
+        "best_train_epoch",
+        "best_valid_epoch",
+        "batch_idx_train",
+        "best_train_loss",
+        "best_valid_loss",
+    ]
+    for k in keys:
+        params[k] = saved_params[k]
+
+    if params.start_batch > 0:
+        if "cur_epoch" in saved_params:
+            params["start_epoch"] = saved_params["cur_epoch"]
+
+    return saved_params
+
+
+def save_checkpoint(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    model_avg: Optional[nn.Module] = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+    sampler: Optional[CutSampler] = None,
+    scaler: Optional[GradScaler] = None,
+    rank: int = 0,
+) -> None:
+    """Save model, optimizer, scheduler and training stats to file.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer used in the training.
+      sampler:
+       The sampler for the training dataset.
+      scaler:
+        The scaler used for mix precision training.
+    """
+    if rank != 0:
+        return
+    filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+    save_checkpoint_impl(
+        filename=filename,
+        model=model,
+        model_avg=model_avg,
+        params=params,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        sampler=sampler,
+        scaler=scaler,
+        rank=rank,
+    )
+
+    if params.best_train_epoch == params.cur_epoch:
+        best_train_filename = params.exp_dir / "best-train-loss.pt"
+        copyfile(src=filename, dst=best_train_filename)
+
+    if params.best_valid_epoch == params.cur_epoch:
+        best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+        copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    graph_compiler: Union[BpeCtcTrainingGraphCompiler, CtcTrainingGraphCompiler],
+    batch: dict,
+    is_training: bool,
+    warmup: float = 1.0,
+) -> Tuple[Tensor, MetricsTracker]:
+    """
+    Compute RNN-T loss given the model and its inputs.
+
+    Args:
+      params:
+        Parameters for training. See :func:`get_params`.
+      model:
+        The model for training. It is an instance of Conformer in our case.
+      graph_compiler:
+        It is used to build a decoding graph from a ctc topo and training
+        transcript. The training transcript is contained in the given `batch`,
+        while the ctc topo is built when this compiler is instantiated.
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      is_training:
+        True for training. False for validation. When it is True, this
+        function enables autograd during computation; when it is False, it
+        disables autograd.
+     warmup: a floating point value which increases throughout training;
+        values >= 1.0 are fully warmed up and have all modules present.
+    """
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+    feature = batch["inputs"]
+    # at entry, feature is (N, T, C)
+    assert feature.ndim == 3
+    feature = feature.to(device)
+
+    supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)
+
+    with torch.set_grad_enabled(is_training):
+        nnet_output, encoder_out_lens = model(
+            feature,
+            feature_lens,
+            warmup=warmup,
+            delay_penalty=params.nnet_delay_penalty if warmup >= 1.0 else 0,
+        )
+        assert torch.all(encoder_out_lens > 0)
+
+    # NOTE: We need `encode_supervisions` to sort sequences with
+    # different duration in decreasing order, required by
+    # `k2.intersect_dense` called in `k2.ctc_loss`
+    supervision_segments, texts = encode_supervisions(
+        supervisions, subsampling_factor=params.subsampling_factor
+    )
+
+    if isinstance(graph_compiler, BpeCtcTrainingGraphCompiler):
+        # Works with a BPE model
+        token_ids = graph_compiler.texts_to_ids(texts)
+        decoding_graph = graph_compiler.compile(token_ids)
+    elif isinstance(graph_compiler, CtcTrainingGraphCompiler):
+        # Works with a phone lexicon
+        decoding_graph = graph_compiler.compile(texts)
+    else:
+        raise ValueError(f"Unsupported type of graph compiler: {type(graph_compiler)}")
+
+    dense_fsa_vec = k2.DenseFsaVec(
+        nnet_output,
+        supervision_segments,
+        allow_truncate=params.subsampling_factor - 1,
+    )
+
+    ctc_loss = k2.ctc_loss(
+        decoding_graph=decoding_graph,
+        dense_fsa_vec=dense_fsa_vec,
+        output_beam=params.beam_size,
+        delay_penalty=params.delay_penalty if warmup >= 1.0 else 0.0,
+        reduction=params.reduction,
+        use_double_scores=params.use_double_scores,
+    )
+    ctc_loss_is_finite = torch.isfinite(ctc_loss)
+    if not torch.all(ctc_loss_is_finite):
+        logging.info("Not all losses are finite!\n" f"ctc_loss: {ctc_loss}")
+        ctc_loss = ctc_loss[ctc_loss_is_finite]
+
+        # If either all simple_loss or pruned_loss is inf or nan,
+        # we stop the training process by raising an exception
+        if torch.all(~ctc_loss_is_finite):
+            raise ValueError(
+                "There are too many utterances in this batch "
+                "leading to inf or nan losses."
+            )
+    loss = ctc_loss.sum()
+
+    assert loss.requires_grad == is_training
+
+    info = MetricsTracker()
+    # info["frames"] is an approximate number for two reasons:
+    # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2
+    # (2) If some utterances in the batch lead to inf/nan loss, they
+    #     are filtered out.
+    info["frames"] = supervision_segments[:, 2].sum().item()
+    # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances`  # noqa
+    info["utterances"] = feature.size(0)
+    # averaged input duration in frames over utterances
+    info["utt_duration"] = feature_lens.sum().item()
+    # averaged padding proportion over utterances
+    info["utt_pad_proportion"] = (
+        ((feature.size(1) - feature_lens) / feature.size(1)).sum().item()
+    )
+
+    # Note: We use reduction=sum while computing the loss.
+    info["loss"] = loss.detach().cpu().item()
+
+    return loss, info
+
+
+def compute_validation_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    graph_compiler: Union[BpeCtcTrainingGraphCompiler, CtcTrainingGraphCompiler],
+    valid_dl: torch.utils.data.DataLoader,
+    world_size: int = 1,
+) -> MetricsTracker:
+    """Run the validation process."""
+    model.eval()
+
+    tot_loss = MetricsTracker()
+
+    for batch_idx, batch in enumerate(valid_dl):
+        loss, loss_info = compute_loss(
+            params=params,
+            model=model,
+            graph_compiler=graph_compiler,
+            batch=batch,
+            is_training=False,
+        )
+        assert loss.requires_grad is False
+        tot_loss = tot_loss + loss_info
+
+    if world_size > 1:
+        tot_loss.reduce(loss.device)
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    if loss_value < params.best_valid_loss:
+        params.best_valid_epoch = params.cur_epoch
+        params.best_valid_loss = loss_value
+
+    return tot_loss
+
+
+def train_one_epoch(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    optimizer: torch.optim.Optimizer,
+    scheduler: LRSchedulerType,
+    graph_compiler: Union[BpeCtcTrainingGraphCompiler, CtcTrainingGraphCompiler],
+    train_dl: torch.utils.data.DataLoader,
+    valid_dl: torch.utils.data.DataLoader,
+    scaler: GradScaler,
+    model_avg: Optional[nn.Module] = None,
+    tb_writer: Optional[SummaryWriter] = None,
+    world_size: int = 1,
+    rank: int = 0,
+) -> None:
+    """Train the model for one epoch.
+
+    The training loss from the mean of all frames is saved in
+    `params.train_loss`. It runs the validation process every
+    `params.valid_interval` batches.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The model for training.
+      optimizer:
+        The optimizer we are using.
+      scheduler:
+        The learning rate scheduler, we call step() every step.
+      graph_compiler:
+        It is used to build a decoding graph from a ctc topo and training
+        transcript. The training transcript is contained in the given `batch`,
+        while the ctc topo is built when this compiler is instantiated.
+      train_dl:
+        Dataloader for the training dataset.
+      valid_dl:
+        Dataloader for the validation dataset.
+      scaler:
+        The scaler used for mix precision training.
+      model_avg:
+        The stored model averaged from the start of training.
+      tb_writer:
+        Writer to write log messages to tensorboard.
+      world_size:
+        Number of nodes in DDP training. If it is 1, DDP is disabled.
+      rank:
+        The rank of the node in DDP training. If no DDP is used, it should
+        be set to 0.
+    """
+    model.train()
+
+    tot_loss = MetricsTracker()
+
+    for batch_idx, batch in enumerate(train_dl):
+        params.batch_idx_train += 1
+        batch_size = len(batch["supervisions"]["text"])
+
+        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            loss, loss_info = compute_loss(
+                params=params,
+                model=model,
+                graph_compiler=graph_compiler,
+                batch=batch,
+                is_training=True,
+                warmup=(params.batch_idx_train / params.model_warm_step),
+            )
+        # summary stats
+        tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+        # NOTE: We use reduction==sum and loss is computed over utterances
+        # in the batch and there is no normalization to it so far.
+        scaler.scale(loss).backward()
+        scheduler.step_batch(params.batch_idx_train)
+        scaler.step(optimizer)
+        scaler.update()
+        optimizer.zero_grad()
+
+        if params.print_diagnostics and batch_idx == 30:
+            return
+
+        if (
+            rank == 0
+            and params.batch_idx_train > 0
+            and params.batch_idx_train % params.average_period == 0
+        ):
+            update_averaged_model(
+                params=params,
+                model_cur=model,
+                model_avg=model_avg,
+            )
+
+        if (
+            params.batch_idx_train > 0
+            and params.batch_idx_train % params.save_every_n == 0
+        ):
+            save_checkpoint_with_global_batch_idx(
+                out_dir=params.exp_dir,
+                global_batch_idx=params.batch_idx_train,
+                model=model,
+                model_avg=model_avg,
+                params=params,
+                optimizer=optimizer,
+                scheduler=scheduler,
+                sampler=train_dl.sampler,
+                scaler=scaler,
+                rank=rank,
+            )
+            remove_checkpoints(
+                out_dir=params.exp_dir,
+                topk=params.keep_last_k,
+                rank=rank,
+            )
+
+        if batch_idx % params.log_interval == 0:
+            cur_lr = scheduler.get_last_lr()[0]
+            logging.info(
+                f"Epoch {params.cur_epoch}, "
+                f"batch {batch_idx}, loss[{loss_info}], "
+                f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+                f"lr: {cur_lr:.2e}"
+            )
+
+            if tb_writer is not None:
+                tb_writer.add_scalar(
+                    "train/learning_rate", cur_lr, params.batch_idx_train
+                )
+
+                loss_info.write_summary(
+                    tb_writer, "train/current_", params.batch_idx_train
+                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+
+        if batch_idx > 0 and batch_idx % params.valid_interval == 0:
+            logging.info("Computing validation loss")
+            valid_info = compute_validation_loss(
+                params=params,
+                model=model,
+                graph_compiler=graph_compiler,
+                valid_dl=valid_dl,
+                world_size=world_size,
+            )
+            model.train()
+            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+            if tb_writer is not None:
+                valid_info.write_summary(
+                    tb_writer, "train/valid_", params.batch_idx_train
+                )
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    params.train_loss = loss_value
+    if params.train_loss < params.best_train_loss:
+        params.best_train_epoch = params.cur_epoch
+        params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+    """
+    Args:
+      rank:
+        It is a value between 0 and `world_size-1`, which is
+        passed automatically by `mp.spawn()` in :func:`main`.
+        The node with rank 0 is responsible for saving checkpoint.
+      world_size:
+        Number of GPUs for DDP training.
+      args:
+        The return value of get_parser().parse_args()
+    """
+    params = get_params()
+    params.update(vars(args))
+    if params.full_libri is False:
+        params.valid_interval = 1600
+
+    fix_random_seed(params.seed)
+    if world_size > 1:
+        setup_dist(rank, world_size, params.master_port)
+
+    setup_logger(f"{params.exp_dir}/log/log-train")
+    logging.info("Training started")
+
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
+    lexicon = Lexicon(params.lang_dir)
+    max_token_id = max(lexicon.tokens)
+    params.vocab_size = max_token_id + 1  # +1 for the blank
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+    logging.info(f"Device: {device}")
+
+    if "lang_bpe" in str(params.lang_dir):
+        graph_compiler = BpeCtcTrainingGraphCompiler(
+            params.lang_dir,
+            device=device,
+            sos_token="",
+            eos_token="",
+        )
+    elif "lang_phone" in str(params.lang_dir):
+        graph_compiler = CtcTrainingGraphCompiler(
+            lexicon,
+            device=device,
+        )
+        # Manually add the sos/eos ID with their default values
+        # from the BPE recipe which we're adapting here.
+        graph_compiler.sos_id = 1
+        graph_compiler.eos_id = 1
+    else:
+        raise ValueError(
+            f"Unsupported type of lang dir (we expected it to have "
+            f"'lang_bpe' or 'lang_phone' in its name): {params.lang_dir}"
+        )
+
+    if params.dynamic_chunk_training:
+        assert (
+            params.causal_convolution
+        ), "dynamic_chunk_training requires causal convolution"
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_ctc_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    assert params.save_every_n >= params.average_period
+    model_avg: Optional[nn.Module] = None
+    if rank == 0:
+        # model_avg is only used with rank 0
+        model_avg = copy.deepcopy(model)
+
+    assert params.start_epoch > 0, params.start_epoch
+    checkpoints = load_checkpoint_if_available(
+        params=params, model=model, model_avg=model_avg
+    )
+
+    model.to(device)
+    if world_size > 1:
+        logging.info("Using DDP")
+        model = DDP(model, device_ids=[rank])
+
+    optimizer = Eve(model.parameters(), lr=params.initial_lr)
+
+    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+    if checkpoints and "optimizer" in checkpoints:
+        logging.info("Loading optimizer state dict")
+        optimizer.load_state_dict(checkpoints["optimizer"])
+
+    if (
+        checkpoints
+        and "scheduler" in checkpoints
+        and checkpoints["scheduler"] is not None
+    ):
+        logging.info("Loading scheduler state dict")
+        scheduler.load_state_dict(checkpoints["scheduler"])
+
+    if params.print_diagnostics:
+        diagnostic = diagnostics.attach_diagnostics(model)
+
+    librispeech = LibriSpeechAsrDataModule(args)
+
+    train_cuts = librispeech.train_clean_100_cuts()
+    if params.full_libri:
+        train_cuts += librispeech.train_clean_360_cuts()
+        train_cuts += librispeech.train_other_500_cuts()
+
+    def remove_short_and_long_utt(c: Cut):
+        # Keep only utterances with duration between 1 second and 20 seconds
+        #
+        # Caution: There is a reason to select 20.0 here. Please see
+        # ../local/display_manifest_statistics.py
+        #
+        # You should use ../local/display_manifest_statistics.py to get
+        # an utterance duration distribution for your dataset to select
+        # the threshold
+        return 1.0 <= c.duration <= 20.0
+
+    train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+    if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+        # We only load the sampler's state dict when it loads a checkpoint
+        # saved in the middle of an epoch
+        sampler_state_dict = checkpoints["sampler"]
+    else:
+        sampler_state_dict = None
+
+    train_dl = librispeech.train_dataloaders(
+        train_cuts, sampler_state_dict=sampler_state_dict
+    )
+
+    valid_cuts = librispeech.dev_clean_cuts()
+    valid_cuts += librispeech.dev_other_cuts()
+    valid_dl = librispeech.valid_dataloaders(valid_cuts)
+
+    if params.start_batch <= 0 and not params.print_diagnostics:
+        scan_pessimistic_batches_for_oom(
+            model=model,
+            train_dl=train_dl,
+            optimizer=optimizer,
+            graph_compiler=graph_compiler,
+            params=params,
+            warmup=0.0 if params.start_epoch == 1 else 1.0,
+        )
+
+    scaler = GradScaler(enabled=params.use_fp16)
+    if checkpoints and "grad_scaler" in checkpoints:
+        logging.info("Loading grad scaler state dict")
+        scaler.load_state_dict(checkpoints["grad_scaler"])
+
+    for epoch in range(params.start_epoch, params.num_epochs + 1):
+        scheduler.step_epoch(epoch - 1)
+        fix_random_seed(params.seed + epoch - 1)
+        train_dl.sampler.set_epoch(epoch - 1)
+
+        if tb_writer is not None:
+            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+        params.cur_epoch = epoch
+
+        train_one_epoch(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            graph_compiler=graph_compiler,
+            train_dl=train_dl,
+            valid_dl=valid_dl,
+            scaler=scaler,
+            tb_writer=tb_writer,
+            world_size=world_size,
+            rank=rank,
+        )
+
+        if params.print_diagnostics:
+            diagnostic.print_diagnostics()
+            break
+
+        save_checkpoint(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            sampler=train_dl.sampler,
+            scaler=scaler,
+            rank=rank,
+        )
+
+    logging.info("Done!")
+
+    if world_size > 1:
+        torch.distributed.barrier()
+        cleanup_dist()
+
+
+def scan_pessimistic_batches_for_oom(
+    model: Union[nn.Module, DDP],
+    train_dl: torch.utils.data.DataLoader,
+    optimizer: torch.optim.Optimizer,
+    graph_compiler: Union[BpeCtcTrainingGraphCompiler, CtcTrainingGraphCompiler],
+    params: AttributeDict,
+    warmup: float,
+):
+    from lhotse.dataset import find_pessimistic_batches
+
+    logging.info(
+        "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+    )
+    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+    for criterion, cuts in batches.items():
+        batch = train_dl.dataset[cuts]
+        try:
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                loss, _ = compute_loss(
+                    params=params,
+                    model=model,
+                    graph_compiler=graph_compiler,
+                    batch=batch,
+                    is_training=True,
+                    warmup=warmup,
+                )
+            loss.backward()
+            optimizer.step()
+            optimizer.zero_grad()
+        except RuntimeError as e:
+            if "CUDA out of memory" in str(e):
+                logging.error(
+                    "Your GPU ran out of memory with the current "
+                    "max_duration setting. We recommend decreasing "
+                    "max_duration and trying again.\n"
+                    f"Failing criterion: {criterion} "
+                    f"(={crit_values[criterion]}) ..."
+                )
+            raise
+
+
+def main():
+    parser = get_parser()
+    LibriSpeechAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    world_size = args.world_size
+    assert world_size >= 1
+    if world_size > 1:
+        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+    else:
+        run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+    main()
diff --git a/icefall/bpe_graph_compiler.py b/icefall/bpe_graph_compiler.py
index e76b7ea32..d9659c2dd 100644
--- a/icefall/bpe_graph_compiler.py
+++ b/icefall/bpe_graph_compiler.py
@@ -83,11 +83,12 @@ class BpeCtcTrainingGraphCompiler(object):
         Args:
           piece_ids:
             It is a list-of-list integer IDs.
-         modified:
+          modified:
            See :func:`k2.ctc_graph` for its meaning.
         Return:
           Return an FsaVec, which is the result of composing a
           CTC topology with linear FSAs constructed from the given
           piece IDs.
         """
-        return k2.ctc_graph(piece_ids, modified=modified, device=self.device)
+        graph = k2.ctc_graph(piece_ids, modified=modified, device=self.device)
+        return graph
diff --git a/icefall/char_graph_compiler.py b/icefall/char_graph_compiler.py
index c31db6e4c..5f9571d42 100644
--- a/icefall/char_graph_compiler.py
+++ b/icefall/char_graph_compiler.py
@@ -117,4 +117,5 @@ class CharCtcTrainingGraphCompiler(object):
           CTC topology with linear FSAs constructed from the given
           piece IDs.
         """
-        return k2.ctc_graph(token_ids, modified=modified, device=self.device)
+        graph = k2.ctc_graph(token_ids, modified=modified, device=self.device)
+        return graph
diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py
index f0663a1df..c83c56a53 100644
--- a/icefall/checkpoint.py
+++ b/icefall/checkpoint.py
@@ -298,7 +298,7 @@ def find_checkpoints(out_dir: Path, iteration: int = 0) -> List[str]:
         if not result:
             logging.warn(f"Invalid checkpoint filename {c}")
             continue
-        
+
         iter_checkpoints.append((int(result.group(1)), c))
 
     # iter_checkpoints is a list of tuples. Each tuple contains
diff --git a/icefall/graph_compiler.py b/icefall/graph_compiler.py
index e2ff03f61..84be81254 100644
--- a/icefall/graph_compiler.py
+++ b/icefall/graph_compiler.py
@@ -79,6 +79,10 @@ class CtcTrainingGraphCompiler(object):
 
         fsa_with_self_loops = k2.arc_sort(fsa_with_self_loops)
 
+        self.ctc_topo._is_repeat_token_ = (
+            self.ctc_topo.labels != self.ctc_topo.aux_labels
+        )
+
         decoding_graph = k2.compose(
             self.ctc_topo, fsa_with_self_loops, treat_epsilons_specially=False
         )
diff --git a/icefall/utils.py b/icefall/utils.py
index b4d8e9a51..d852491c8 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -670,8 +670,8 @@ def write_error_stats_with_timestamps(
     all_delay = []
     for cut_id, ref, hyp, time_ref, time_hyp in results:
         ali = kaldialign.align(ref, hyp, ERR)
-        has_time_ref = len(time_ref) > 0
-        if has_time_ref:
+        has_time = len(time_ref) > 0 and len(time_hyp) > 0
+        if has_time:
             # pointer to timestamp_hyp
             p_hyp = 0
             # pointer to timestamp_ref
@@ -680,28 +680,28 @@ def write_error_stats_with_timestamps(
             if ref_word == ERR:
                 ins[hyp_word] += 1
                 words[hyp_word][3] += 1
-                if has_time_ref:
+                if has_time:
                     p_hyp += 1
             elif hyp_word == ERR:
                 dels[ref_word] += 1
                 words[ref_word][4] += 1
-                if has_time_ref:
+                if has_time:
                     p_ref += 1
             elif hyp_word != ref_word:
                 subs[(ref_word, hyp_word)] += 1
                 words[ref_word][1] += 1
                 words[hyp_word][2] += 1
-                if has_time_ref:
+                if has_time:
                     p_hyp += 1
                     p_ref += 1
             else:
                 words[ref_word][0] += 1
                 num_corr += 1
-                if has_time_ref:
+                if has_time:
                     all_delay.append(time_hyp[p_hyp] - time_ref[p_ref])
                     p_hyp += 1
                     p_ref += 1
-        if has_time_ref:
+        if has_time:
             assert p_hyp == len(hyp), (p_hyp, len(hyp))
             assert p_ref == len(ref), (p_ref, len(ref))
 
@@ -1327,10 +1327,9 @@ def parse_timestamp(tokens: List[str], timestamp: List[float]) -> List[float]:
 
 def parse_hyp_and_timestamp(
     res: DecodingResults,
-    decoding_method: str,
-    sp: spm.SentencePieceProcessor,
     subsampling_factor: int,
     frame_shift_ms: float = 10,
+    sp: Optional[spm.SentencePieceProcessor] = None,
     word_table: Optional[k2.SymbolTable] = None,
 ) -> Tuple[List[List[str]], List[List[float]]]:
     """Parse hypothesis and timestamp.
@@ -1338,51 +1337,29 @@ def parse_hyp_and_timestamp(
     Args:
       res:
         A DecodingResults object.
-      decoding_method:
-        Possible values are:
-          - greedy_search
-          - beam_search
-          - modified_beam_search
-          - fast_beam_search
-          - fast_beam_search_LG
-          - fast_beam_search_nbest
-          - fast_beam_search_nbest_oracle
-          - fast_beam_search_nbest_LG
-      sp:
-        The BPE model.
       subsampling_factor:
         The integer subsampling factor.
       frame_shift_ms:
         The float frame shift used for feature extraction.
+      sp:
+        The BPE model.
       word_table:
         The word symbol table.
 
     Returns:
        Return a list of hypothesis and timestamp.
     """
-    assert decoding_method in (
-        "greedy_search",
-        "beam_search",
-        "fast_beam_search",
-        "fast_beam_search_LG",
-        "fast_beam_search_nbest",
-        "fast_beam_search_nbest_LG",
-        "fast_beam_search_nbest_oracle",
-        "modified_beam_search",
-    )
-
     hyps = []
     timestamps = []
 
     N = len(res.hyps)
     assert len(res.timestamps) == N, (len(res.timestamps), N)
     use_word_table = False
-    if (
-        decoding_method == "fast_beam_search_nbest_LG"
-        and decoding_method == "fast_beam_search_LG"
-    ):
-        assert word_table is not None
+    if word_table is not None:
+        assert sp is None
         use_word_table = True
+    else:
+        assert sp is not None and word_table is None
 
     for i in range(N):
         time = convert_timestamp(res.timestamps[i], subsampling_factor, frame_shift_ms)

From 4b5bc480e8a5ac253dcd22b08dfa59083dadd6fd Mon Sep 17 00:00:00 2001
From: marcoyang1998 <45973641+marcoyang1998@users.noreply.github.com>
Date: Wed, 30 Nov 2022 17:26:05 +0800
Subject: [PATCH 025/174] Add low-order density ratio in RNNLM shallow fusion
 (#678)

* Support LODR in RNNLM shallow fusion

* fix style

* fix code style

* update workflow and CI

* update results

* propagate changes to stateless3

* add decoding results for stateless3+giga

* fix CI
---
 ...-lstm-transducer-stateless2-2022-09-03.yml |  67 ++++-
 ...-lstm-transducer-stateless2-2022-09-03.yml |  15 +-
 egs/librispeech/ASR/RESULTS.md                |  87 ++++++
 .../ASR/lstm_transducer_stateless2/decode.py  |  51 +++-
 .../beam_search.py                            | 264 ++++++++++++++++++
 .../pruned_transducer_stateless3/decode.py    | 181 +++++++++++-
 6 files changed, 646 insertions(+), 19 deletions(-)

diff --git a/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml b/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml
index 6ce92d022..ac5b15979 100755
--- a/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml
+++ b/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml
@@ -16,6 +16,7 @@ log "Downloading pre-trained model from $repo_url"
 git lfs install
 git clone $repo_url
 repo=$(basename $repo_url)
+abs_repo=$(realpath $repo)
 
 log "Display test files"
 tree $repo/
@@ -178,21 +179,27 @@ echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
 if [[ x"${GITHUB_EVENT_LABEL_NAME}" == x"shallow-fusion" ]]; then
   lm_repo_url=https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
   log "Download pre-trained RNN-LM model from ${lm_repo_url}"
-  git clone $lm_repo_url
+  GIT_LFS_SKIP_SMUDGE=1 git clone $lm_repo_url
   lm_repo=$(basename $lm_repo_url)
   pushd $lm_repo
   git lfs pull --include "exp/pretrained.pt"
-  cd exp
-  ln -s pretrained.pt epoch-88.pt
+  mv exp/pretrained.pt exp/epoch-88.pt
   popd
 
+  mkdir -p lstm_transducer_stateless2/exp
+  ln -sf $PWD/$repo/exp/pretrained.pt lstm_transducer_stateless2/exp/epoch-999.pt
+  ln -s $PWD/$repo/data/lang_bpe_500 data/
+
+  ls -lh data
+  ls -lh lstm_transducer_stateless2/exp
+
+  log "Decoding test-clean and test-other"
+
   ./lstm_transducer_stateless2/decode.py \
     --use-averaged-model 0 \
-    --epoch 99 \
+    --epoch 999 \
     --avg 1 \
-    --exp-dir $repo/exp \
-    --lang-dir $repo/data/lang_bpe_500 \
-    --bpe-model $repo/data/lang_bpe_500/bpe.model \
+    --exp-dir lstm_transducer_stateless2/exp \
     --max-duration 600 \
     --decoding-method modified_beam_search_rnnlm_shallow_fusion \
     --beam 4 \
@@ -204,6 +211,52 @@ if [[ x"${GITHUB_EVENT_LABEL_NAME}" == x"shallow-fusion" ]]; then
     --rnn-lm-tie-weights 1
 fi
 
+if [[ x"${GITHUB_EVENT_LABEL_NAME}" == x"LODR" ]]; then
+  bigram_repo_url=https://huggingface.co/marcoyang/librispeech_bigram
+  log "Download bi-gram LM from ${bigram_repo_url}"
+  GIT_LFS_SKIP_SMUDGE=1 git clone $bigram_repo_url
+  bigramlm_repo=$(basename $bigram_repo_url)
+  pushd $bigramlm_repo
+  git lfs pull --include "2gram.fst.txt"
+  cp 2gram.fst.txt $abs_repo/data/lang_bpe_500/.
+  popd
+
+  lm_repo_url=https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
+  log "Download pre-trained RNN-LM model from ${lm_repo_url}"
+  GIT_LFS_SKIP_SMUDGE=1 git clone $lm_repo_url
+  lm_repo=$(basename $lm_repo_url)
+  pushd $lm_repo
+  git lfs pull --include "exp/pretrained.pt"
+  mv exp/pretrained.pt exp/epoch-88.pt
+  popd
+
+  mkdir -p lstm_transducer_stateless2/exp
+  ln -sf $PWD/$repo/exp/pretrained.pt lstm_transducer_stateless2/exp/epoch-999.pt
+  ln -s $PWD/$repo/data/lang_bpe_500 data/
+
+  ls -lh data
+  ls -lh lstm_transducer_stateless2/exp
+
+  log "Decoding test-clean and test-other"
+
+  ./lstm_transducer_stateless2/decode.py \
+    --use-averaged-model 0 \
+    --epoch 999 \
+    --avg 1 \
+    --exp-dir lstm_transducer_stateless2/exp \
+    --max-duration 600 \
+    --decoding-method modified_beam_search_rnnlm_LODR \
+    --beam 4 \
+    --rnn-lm-scale 0.3 \
+    --rnn-lm-exp-dir $lm_repo/exp \
+    --rnn-lm-epoch 88 \
+    --rnn-lm-avg 1 \
+    --rnn-lm-num-layers 3 \
+    --rnn-lm-tie-weights 1 \
+    --tokens-ngram 2 \
+    --ngram-lm-scale -0.16
+fi
+
 if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" ]]; then
   mkdir -p lstm_transducer_stateless2/exp
   ln -s $PWD/$repo/exp/pretrained.pt lstm_transducer_stateless2/exp/epoch-999.pt
diff --git a/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml b/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml
index a90841fb6..5f0acf9b8 100644
--- a/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml
+++ b/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml
@@ -18,7 +18,7 @@ on:
 
 jobs:
   run_librispeech_lstm_transducer_stateless2_2022_09_03:
-    if: github.event.label.name == 'ready' || github.event.label.name == 'shallow-fusion' || github.event.label.name == 'ncnn' || github.event.label.name == 'onnx' || github.event_name == 'push' || github.event_name == 'schedule'
+    if: github.event.label.name == 'ready' || github.event.label.name == 'LODR' || github.event.label.name == 'shallow-fusion' || github.event.label.name == 'ncnn' || github.event.label.name == 'onnx' || github.event_name == 'push' || github.event_name == 'schedule'
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:
@@ -139,9 +139,20 @@ jobs:
           find modified_beam_search_rnnlm_shallow_fusion  -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
           find modified_beam_search_rnnlm_shallow_fusion  -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
 
+      - name: Display decoding results for lstm_transducer_stateless2
+        if: github.event.label.name == 'LODR'
+        shell: bash
+        run: |
+          cd egs/librispeech/ASR
+          tree lstm_transducer_stateless2/exp
+          cd lstm_transducer_stateless2/exp
+          echo "===modified_beam_search_rnnlm_LODR==="
+          find modified_beam_search_rnnlm_LODR  -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+          find modified_beam_search_rnnlm_LODR  -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
       - name: Upload decoding results for lstm_transducer_stateless2
         uses: actions/upload-artifact@v2
-        if: github.event_name == 'schedule' || github.event.label.name == 'shallow-fusion'
+        if: github.event_name == 'schedule' || github.event.label.name == 'shallow-fusion' || github.event.label.name == 'LODR'
         with:
           name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-lstm_transducer_stateless2-2022-09-03
           path: egs/librispeech/ASR/lstm_transducer_stateless2/exp/
diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md
index efd60ba81..c2ea3d050 100644
--- a/egs/librispeech/ASR/RESULTS.md
+++ b/egs/librispeech/ASR/RESULTS.md
@@ -318,6 +318,7 @@ The WERs are:
 | greedy search (max sym per frame 1) | 2.78       | 7.36       | --iter 468000 --avg 16  |
 | modified_beam_search                | 2.73       | 7.15       | --iter 468000 --avg 16  |
 | modified_beam_search + RNNLM shallow fusion   | 2.42     |  6.46      | --iter 468000 --avg 16  |
+| modified_beam_search + RNNLM shallow fusion   | 2.28     |  5.94      | --iter 468000 --avg 16  |
 | fast_beam_search                    | 2.76       | 7.31       | --iter 468000 --avg 16  |
 | greedy search (max sym per frame 1) | 2.77       | 7.35       | --iter 472000 --avg 18  |
 | modified_beam_search                | 2.75       | 7.08       | --iter 472000 --avg 18  |
@@ -393,6 +394,32 @@ for iter in 472000; do
     done
 done
 
+You may also decode using LODR + RNNLM shallow fusion. This decoding method is proposed in .
+It subtracts the internal language model score during shallow fusion, which is approximated by a bi-gram model. The bi-gram can be
+generated by `generate-lm.sh`, or you may download it from .
+
+The decoding command is as follows:
+
+for iter in 472000; do
+    for avg in 8 10 12 14 16 18; do
+        ./lstm_transducer_stateless2/decode.py \
+                --iter $iter \
+                --avg $avg \
+                --exp-dir ./lstm_transducer_stateless2/exp \
+                --max-duration 600 \
+                --decoding-method modified_beam_search_rnnlm_LODR \
+                --beam 4 \
+                --rnn-lm-scale 0.4 \
+                --rnn-lm-exp-dir /path/to/RNNLM \
+                --rnn-lm-epoch 99 \
+                --rnn-lm-avg 1 \
+                --rnn-lm-num-layers 3 \
+                --rnn-lm-tie-weights 1 \
+                --token-ngram 2 \
+                --ngram-lm-scale -0.16
+    done
+done
+
 Pretrained models, training logs, decoding logs, and decoding results
 are available at
 
@@ -1912,6 +1939,8 @@ subset so that the gigaspeech dataloader never exhausts.
 |-------------------------------------|------------|------------|---------------------------------------------|
 | greedy search (max sym per frame 1) | 2.03       | 4.70       | --iter 1224000 --avg 14  --max-duration 600 |
 | modified beam search                | 2.00       | 4.63       | --iter 1224000 --avg 14  --max-duration 600 |
+| modified beam search + rnnlm shallow fusion  | 1.94     |  4.2    | --iter 1224000 --avg 14  --max-duration 600 |
+| modified beam search + LODR         | 1.83       | 4.03       | --iter 1224000 --avg 14  --max-duration 600 |
 | fast beam search                    | 2.10       | 4.68       | --iter 1224000 --avg 14 --max-duration 600 |
 
 The training commands are:
@@ -1957,6 +1986,64 @@ for iter in 1224000; do
   done
 done
 ```
+You may also decode using shallow fusion with external RNNLM. To do so you need to
+download a well-trained RNNLM from this link 
+
+```bash
+rnn_lm_scale=0.3
+
+for iter in 1224000; do
+  for avg in 14; do
+    for method in modified_beam_search_rnnlm_shallow_fusion ; do
+      ./pruned_transducer_stateless3/decode.py \
+        --iter $iter \
+        --avg $avg \
+        --exp-dir ./pruned_transducer_stateless3/exp-0.9/ \
+        --max-duration 600 \
+        --decoding-method $method \
+        --max-sym-per-frame 1 \
+        --beam 4 \
+        --max-contexts 32 \
+        --rnn-lm-scale $rnn_lm_scale \
+        --rnn-lm-exp-dir /path/to/RNNLM \
+        --rnn-lm-epoch 99 \
+        --rnn-lm-avg 1 \
+        --rnn-lm-num-layers 3 \
+        --rnn-lm-tie-weights 1
+    done
+  done
+done
+```
+
+If you want to try out with LODR decoding, use the following command. This assums you have a bi-gram LM trained on LibriSpeech text. You can also download the bi-gram LM from here  and put it under the directory `data/lang_bpe_500`.
+
+```bash
+rnn_lm_scale=0.4
+
+for iter in 1224000; do
+  for avg in 14; do
+    for method in modified_beam_search_rnnlm_LODR ; do
+      ./pruned_transducer_stateless3/decode.py \
+        --iter $iter \
+        --avg $avg \
+        --exp-dir ./pruned_transducer_stateless3/exp-0.9/ \
+        --max-duration 600 \
+        --decoding-method $method \
+        --max-sym-per-frame 1 \
+        --beam 4 \
+        --max-contexts 32 \
+        --rnn-lm-scale $rnn_lm_scale \
+        --rnn-lm-exp-dir /path/to/RNNLM \
+        --rnn-lm-epoch 99 \
+        --rnn-lm-avg 1 \
+        --rnn-lm-num-layers 3 \
+        --rnn-lm-tie-weights 1 \
+        --tokens-ngram 2 \
+        --ngram-lm-scale -0.14
+    done
+  done
+done
+```
 
 The pretrained models, training logs, decoding logs, and decoding results
 can be found at
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py
index 69f695fef..fa5bf1825 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py
@@ -107,8 +107,25 @@ Usage:
     --rnn-lm-avg 1 \
     --rnn-lm-num-layers 3 \
     --rnn-lm-tie-weights 1
-"""
 
+(9) modified beam search with RNNLM shallow fusion + LODR
+./lstm_transducer_stateless2/decode.py \
+    --epoch 35 \
+    --avg 15 \
+    --max-duration 600 \
+    --exp-dir ./lstm_transducer_stateless2/exp \
+    --decoding-method modified_beam_search_rnnlm_LODR \
+    --beam 4 \
+    --max-contexts 4 \
+    --rnn-lm-scale 0.4 \
+    --rnn-lm-exp-dir /path/to/RNNLM/exp \
+    --rnn-lm-epoch 99 \
+    --rnn-lm-avg 1 \
+    --rnn-lm-num-layers 3 \
+    --rnn-lm-tie-weights 1 \
+    --tokens-ngram 2 \
+    --ngram-lm-scale -0.16 \
+"""
 
 import argparse
 import logging
@@ -132,6 +149,7 @@ from beam_search import (
     greedy_search_batch,
     modified_beam_search,
     modified_beam_search_ngram_rescoring,
+    modified_beam_search_rnnlm_LODR,
     modified_beam_search_rnnlm_shallow_fusion,
 )
 from librispeech import LibriSpeech
@@ -235,7 +253,8 @@ def get_parser():
           - fast_beam_search_nbest_oracle
           - fast_beam_search_nbest_LG
           - modified_beam_search_ngram_rescoring
-          - modified_beam_search_rnnlm_shallow_fusion # for rnn lm shallow fusion
+          - modified_beam_search_rnnlm_shallow_fusion
+          - modified_beam_search_rnnlm_LODR
         If you use fast_beam_search_nbest_LG, you have to specify
         `--lang-dir`, which should contain `LG.pt`.
         """,
@@ -394,7 +413,8 @@ def get_parser():
         type=int,
         default=3,
         help="""Token Ngram used for rescoring.
-            Used only when the decoding method is modified_beam_search_ngram_rescoring""",
+            Used only when the decoding method is
+            modified_beam_search_ngram_rescoring""",
     )
 
     parser.add_argument(
@@ -402,7 +422,8 @@ def get_parser():
         type=int,
         default=500,
         help="""ID of the backoff symbol.
-                Used only when the decoding method is modified_beam_search_ngram_rescoring""",
+                Used only when the decoding method is
+                modified_beam_search_ngram_rescoring""",
     )
 
     add_model_arguments(parser)
@@ -572,6 +593,20 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
+    elif params.decoding_method == "modified_beam_search_rnnlm_LODR":
+        hyp_tokens = modified_beam_search_rnnlm_LODR(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam_size,
+            sp=sp,
+            LODR_lm=ngram_lm,
+            LODR_lm_scale=ngram_lm_scale,
+            rnnlm=rnnlm,
+            rnnlm_scale=rnnlm_scale,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
     else:
         batch_size = encoder_out.size(0)
 
@@ -760,6 +795,7 @@ def main():
         "fast_beam_search_nbest_LG",
         "fast_beam_search_nbest_oracle",
         "modified_beam_search",
+        "modified_beam_search_rnnlm_LODR",
         "modified_beam_search_ngram_rescoring",
         "modified_beam_search_rnnlm_shallow_fusion",
     )
@@ -788,6 +824,9 @@ def main():
     if "rnnlm" in params.decoding_method:
         params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}"
 
+    if "LODR" in params.decoding_method:
+        params.suffix += "-LODR"
+
     if params.use_averaged_model:
         params.suffix += "-use-averaged-model"
 
@@ -901,7 +940,7 @@ def main():
     model.eval()
 
     # only load N-gram LM when needed
-    if "ngram" in params.decoding_method:
+    if "ngram" in params.decoding_method or "LODR" in params.decoding_method:
         lm_filename = f"{params.tokens_ngram}gram.fst.txt"
         logging.info(f"lm filename: {lm_filename}")
         ngram_lm = NgramLm(
@@ -910,6 +949,7 @@ def main():
             is_binary=False,
         )
         logging.info(f"num states: {ngram_lm.lm.num_states}")
+        ngram_lm_scale = params.ngram_lm_scale
     else:
         ngram_lm = None
         ngram_lm_scale = None
@@ -933,7 +973,6 @@ def main():
         )
         rnn_lm_model.to(device)
         rnn_lm_model.eval()
-
     else:
         rnn_lm_model = None
         rnn_lm_scale = 0.0
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
index 5e9428b60..59c8ed5b5 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
@@ -2083,3 +2083,267 @@ def modified_beam_search_rnnlm_shallow_fusion(
             tokens=ans,
             timestamps=ans_timestamps,
         )
+
+
+def modified_beam_search_rnnlm_LODR(
+    model: Transducer,
+    encoder_out: torch.Tensor,
+    encoder_out_lens: torch.Tensor,
+    sp: spm.SentencePieceProcessor,
+    LODR_lm: NgramLm,
+    LODR_lm_scale: float,
+    rnnlm: RnnLmModel,
+    rnnlm_scale: float,
+    beam: int = 4,
+) -> List[List[int]]:
+    """This function implements LODR (https://arxiv.org/abs/2203.16776) with
+    `modified_beam_search`. It uses a bi-gram language model as the estimate
+    of the internal language model and subtracts its score during shallow fusion
+    with an external language model. This implementation uses a RNNLM as the
+    external language model.
+
+    Args:
+        model (Transducer):
+            The transducer model
+        encoder_out (torch.Tensor):
+            Encoder output in (N,T,C)
+        encoder_out_lens (torch.Tensor):
+            A 1-D tensor of shape (N,), containing the number of
+            valid frames in encoder_out before padding.
+        sp:
+            Sentence piece generator.
+        LODR_lm:
+            A low order n-gram LM
+        LODR_lm_scale:
+            The scale of the LODR_lm
+        rnnlm (RnnLmModel):
+            RNNLM, the external language model
+        rnnlm_scale (float):
+            scale of RNNLM in shallow fusion
+        beam (int, optional):
+            Beam size. Defaults to 4.
+
+    Returns:
+      Return a list-of-list of token IDs. ans[i] is the decoding results
+      for the i-th utterance.
+
+    """
+    assert encoder_out.ndim == 3, encoder_out.shape
+    assert encoder_out.size(0) >= 1, encoder_out.size(0)
+    assert rnnlm is not None
+    lm_scale = rnnlm_scale
+    vocab_size = rnnlm.vocab_size
+
+    packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
+        input=encoder_out,
+        lengths=encoder_out_lens.cpu(),
+        batch_first=True,
+        enforce_sorted=False,
+    )
+
+    blank_id = model.decoder.blank_id
+    sos_id = sp.piece_to_id("")
+    unk_id = getattr(model, "unk_id", blank_id)
+    context_size = model.decoder.context_size
+    device = next(model.parameters()).device
+
+    batch_size_list = packed_encoder_out.batch_sizes.tolist()
+    N = encoder_out.size(0)
+    assert torch.all(encoder_out_lens > 0), encoder_out_lens
+    assert N == batch_size_list[0], (N, batch_size_list)
+
+    # get initial lm score and lm state by scoring the "sos" token
+    sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device)
+    init_score, init_states = rnnlm.score_token(sos_token)
+
+    B = [HypothesisList() for _ in range(N)]
+    for i in range(N):
+        B[i].add(
+            Hypothesis(
+                ys=[blank_id] * context_size,
+                log_prob=torch.zeros(1, dtype=torch.float32, device=device),
+                state=init_states,  # state of the RNNLM
+                lm_score=init_score.reshape(-1),
+                state_cost=NgramLmStateCost(
+                    LODR_lm
+                ),  # state of the source domain ngram
+            )
+        )
+
+    rnnlm.clean_cache()
+    encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
+
+    offset = 0
+    finalized_B = []
+    for batch_size in batch_size_list:
+        start = offset
+        end = offset + batch_size
+        current_encoder_out = encoder_out.data[start:end]  # get batch
+        current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
+        # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
+        offset = end
+
+        finalized_B = B[batch_size:] + finalized_B
+        B = B[:batch_size]
+
+        hyps_shape = get_hyps_shape(B).to(device)
+
+        A = [list(b) for b in B]
+        B = [HypothesisList() for _ in range(batch_size)]
+
+        ys_log_probs = torch.cat(
+            [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
+        )
+
+        decoder_input = torch.tensor(
+            [hyp.ys[-context_size:] for hyps in A for hyp in hyps],
+            device=device,
+            dtype=torch.int64,
+        )  # (num_hyps, context_size)
+
+        decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
+        decoder_out = model.joiner.decoder_proj(decoder_out)
+
+        current_encoder_out = torch.index_select(
+            current_encoder_out,
+            dim=0,
+            index=hyps_shape.row_ids(1).to(torch.int64),
+        )  # (num_hyps, 1, 1, encoder_out_dim)
+
+        logits = model.joiner(
+            current_encoder_out,
+            decoder_out,
+            project_input=False,
+        )  # (num_hyps, 1, 1, vocab_size)
+
+        logits = logits.squeeze(1).squeeze(1)  # (num_hyps, vocab_size)
+
+        log_probs = logits.log_softmax(dim=-1)  # (num_hyps, vocab_size)
+
+        log_probs.add_(ys_log_probs)
+
+        vocab_size = log_probs.size(-1)
+
+        log_probs = log_probs.reshape(-1)
+
+        row_splits = hyps_shape.row_splits(1) * vocab_size
+        log_probs_shape = k2.ragged.create_ragged_shape2(
+            row_splits=row_splits, cached_tot_size=log_probs.numel()
+        )
+        ragged_log_probs = k2.RaggedTensor(
+            shape=log_probs_shape, value=log_probs
+        )
+        """
+        for all hyps with a non-blank new token, score this token.
+        It is a little confusing here because this for-loop
+        looks very similar to the one below. Here, we go through all
+        top-k tokens and only add the non-blanks ones to the token_list.
+        The RNNLM will score those tokens given the LM states. Note that
+        the variable `scores` is the LM score after seeing the new
+        non-blank token.
+        """
+        token_list = []
+        hs = []
+        cs = []
+        for i in range(batch_size):
+            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
+
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore")
+                topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
+                topk_token_indexes = (topk_indexes % vocab_size).tolist()
+            for k in range(len(topk_hyp_indexes)):
+                hyp_idx = topk_hyp_indexes[k]
+                hyp = A[i][hyp_idx]
+
+                new_token = topk_token_indexes[k]
+                if new_token not in (blank_id, unk_id):
+                    assert new_token != 0, new_token
+                    token_list.append([new_token])
+                    # store the LSTM states
+                    hs.append(hyp.state[0])
+                    cs.append(hyp.state[1])
+
+        # forward RNNLM to get new states and scores
+        if len(token_list) != 0:
+            tokens_to_score = (
+                torch.tensor(token_list)
+                .to(torch.int64)
+                .to(device)
+                .reshape(-1, 1)
+            )
+
+            hs = torch.cat(hs, dim=1).to(device)
+            cs = torch.cat(cs, dim=1).to(device)
+            scores, lm_states = rnnlm.score_token(tokens_to_score, (hs, cs))
+
+        count = 0  # index, used to locate score and lm states
+        for i in range(batch_size):
+            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
+
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore")
+                topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
+                topk_token_indexes = (topk_indexes % vocab_size).tolist()
+
+            for k in range(len(topk_hyp_indexes)):
+                hyp_idx = topk_hyp_indexes[k]
+                hyp = A[i][hyp_idx]
+
+                ys = hyp.ys[:]
+
+                # current score of hyp
+                lm_score = hyp.lm_score
+                state = hyp.state
+
+                hyp_log_prob = topk_log_probs[k]  # get score of current hyp
+                new_token = topk_token_indexes[k]
+                if new_token not in (blank_id, unk_id):
+
+                    ys.append(new_token)
+                    state_cost = hyp.state_cost.forward_one_step(new_token)
+
+                    # calculate the score of the latest token
+                    current_ngram_score = (
+                        state_cost.lm_score - hyp.state_cost.lm_score
+                    )
+
+                    assert current_ngram_score <= 0.0, (
+                        state_cost.lm_score,
+                        hyp.state_cost.lm_score,
+                    )
+                    # score = score + RNNLM_score - LODR_score
+                    # LODR_LM_scale is a negative number here
+                    hyp_log_prob += (
+                        lm_score[new_token] * lm_scale
+                        + LODR_lm_scale * current_ngram_score
+                    )  # add the lm score
+
+                    lm_score = scores[count]
+                    state = (
+                        lm_states[0][:, count, :].unsqueeze(1),
+                        lm_states[1][:, count, :].unsqueeze(1),
+                    )
+                    count += 1
+                else:
+                    state_cost = hyp.state_cost
+
+                new_hyp = Hypothesis(
+                    ys=ys,
+                    log_prob=hyp_log_prob,
+                    state=state,
+                    lm_score=lm_score,
+                    state_cost=state_cost,
+                )
+                B[i].add(new_hyp)
+
+    B = B + finalized_B
+    best_hyps = [b.get_most_probable(length_norm=True) for b in B]
+
+    sorted_ans = [h.ys[context_size:] for h in best_hyps]
+    ans = []
+    unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
+    for i in range(N):
+        ans.append(sorted_ans[unsorted_indices[i]])
+
+    return ans
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py
index 03137501f..e00aab34a 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py
@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 #
-# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
+# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang
+#                                            Xiaoyu Yang)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -90,8 +91,40 @@ Usage:
     --beam 20.0 \
     --max-contexts 8 \
     --max-states 64
-"""
 
+(8) modified beam search (with RNNLM shallow fusion)
+./pruned_transducer_stateless3/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless3/exp \
+    --max-duration 600 \
+    --decoding-method modified_beam_search_rnnlm_shallow_fusion \
+    --beam 4 \
+    --rnn-lm-scale 0.3 \
+    --rnn-lm-exp-dir /path/to/RNNLM \
+    --rnn-lm-epoch 99 \
+    --rnn-lm-avg 1 \
+    --rnn-lm-num-layers 3 \
+    --rnn-lm-tie-weights 1
+
+(9) modified beam search with RNNLM shallow fusion + LODR
+./pruned_transducer_stateless3/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --max-duration 600 \
+    --exp-dir ./pruned_transducer_stateless3/exp \
+    --decoding-method modified_beam_search_rnnlm_LODR \
+    --beam 4 \
+    --max-contexts 4 \
+    --rnn-lm-scale 0.4 \
+    --rnn-lm-exp-dir /path/to/RNNLM/exp \
+    --rnn-lm-epoch 99 \
+    --rnn-lm-avg 1 \
+    --rnn-lm-num-layers 3 \
+    --rnn-lm-tie-weights 1 \
+    --tokens-ngram 2 \
+    --ngram-lm-scale -0.16 \
+"""
 
 import argparse
 import logging
@@ -116,10 +149,14 @@ from beam_search import (
     greedy_search,
     greedy_search_batch,
     modified_beam_search,
+    modified_beam_search_ngram_rescoring,
+    modified_beam_search_rnnlm_LODR,
+    modified_beam_search_rnnlm_shallow_fusion,
 )
 from librispeech import LibriSpeech
 from train import add_model_arguments, get_params, get_transducer_model
 
+from icefall import NgramLm
 from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
 from icefall.lexicon import Lexicon
 from icefall.rnn_lm.model import RnnLmModel
@@ -202,6 +239,9 @@ def get_parser():
           - fast_beam_search_nbest
           - fast_beam_search_nbest_oracle
           - fast_beam_search_nbest_LG
+          - modified_beam_search_ngram_rescoring
+          - modified_beam_search_rnnlm_shallow_fusion
+          - modified_beam_search_rnnlm_LODR
         If you use fast_beam_search_nbest_LG, you have to specify
         `--lang-dir`, which should contain `LG.pt`.
         """,
@@ -263,6 +303,7 @@ def get_parser():
         default=2,
         help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
+
     parser.add_argument(
         "--max-sym-per-frame",
         type=int,
@@ -341,6 +382,15 @@ def get_parser():
          """,
     )
 
+    parser.add_argument(
+        "--rnn-lm-scale",
+        type=float,
+        default=0.0,
+        help="""Used only when --method is modified-beam-search_rnnlm_shallow_fusion.
+        It specifies the path to RNN LM exp dir.
+        """,
+    )
+
     parser.add_argument(
         "--rnn-lm-exp-dir",
         type=str,
@@ -397,6 +447,24 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--tokens-ngram",
+        type=int,
+        default=3,
+        help="""Token Ngram used for rescoring.
+            Used only when the decoding method is
+            modified_beam_search_ngram_rescoring""",
+    )
+
+    parser.add_argument(
+        "--backoff-id",
+        type=int,
+        default=500,
+        help="""ID of the backoff symbol.
+                Used only when the decoding method is
+                modified_beam_search_ngram_rescoring""",
+    )
+
     add_model_arguments(parser)
 
     return parser
@@ -410,7 +478,10 @@ def decode_one_batch(
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
     G: Optional[k2.Fsa] = None,
-    rnn_lm_model: torch.nn.Module = None,
+    ngram_lm: Optional[NgramLm] = None,
+    ngram_lm_scale: float = 1.0,
+    rnn_lm_model: Optional[RnnLmModel] = None,
+    rnnlm_scale: float = 1.0,
 ) -> Dict[str, List[List[str]]]:
     """Decode one batch and return the result in a dict. The dict has the
     following format:
@@ -444,6 +515,14 @@ def decode_one_batch(
         fast_beam_search_nbest, fast_beam_search_nbest_oracle,
         or fast_beam_search_with_nbest_rescoring.
         It an FsaVec containing an acceptor.
+      rnn_lm_model:
+        A rnnlm which can be used for rescoring or shallow fusion
+      rnnlm_scale:
+        The scale of the rnnlm.
+      ngram_lm:
+        A ngram lm. Used in LODR decoding.
+      ngram_lm_scale:
+        The scale of the ngram language model.
     Returns:
       Return the decoding result. See above description for the format of
       the returned dict.
@@ -607,6 +686,43 @@ def decode_one_batch(
             nbest_scale=params.nbest_scale,
             temperature=params.temperature,
         )
+    elif params.decoding_method == "modified_beam_search_ngram_rescoring":
+        hyp_tokens = modified_beam_search_ngram_rescoring(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            ngram_lm=ngram_lm,
+            ngram_lm_scale=ngram_lm_scale,
+            beam=params.beam_size,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion":
+        hyp_tokens = modified_beam_search_rnnlm_shallow_fusion(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam_size,
+            sp=sp,
+            rnnlm=rnn_lm_model,
+            rnnlm_scale=rnnlm_scale,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.decoding_method == "modified_beam_search_rnnlm_LODR":
+        hyp_tokens = modified_beam_search_rnnlm_LODR(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam_size,
+            sp=sp,
+            LODR_lm=ngram_lm,
+            LODR_lm_scale=ngram_lm_scale,
+            rnnlm=rnn_lm_model,
+            rnnlm_scale=rnnlm_scale,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
     else:
         batch_size = encoder_out.size(0)
 
@@ -693,7 +809,10 @@ def decode_dataset(
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
     G: Optional[k2.Fsa] = None,
-    rnn_lm_model: torch.nn.Module = None,
+    ngram_lm: Optional[NgramLm] = None,
+    ngram_lm_scale: float = 1.0,
+    rnn_lm_model: Optional[RnnLmModel] = None,
+    rnnlm_scale: float = 1.0,
 ) -> Dict[str, List[Tuple[List[str], List[str]]]]:
     """Decode dataset.
 
@@ -749,7 +868,10 @@ def decode_dataset(
             decoding_graph=decoding_graph,
             batch=batch,
             G=G,
+            ngram_lm=ngram_lm,
+            ngram_lm_scale=ngram_lm_scale,
             rnn_lm_model=rnn_lm_model,
+            rnnlm_scale=rnnlm_scale,
         )
 
         for name, hyps in hyps_dict.items():
@@ -900,6 +1022,9 @@ def main():
         "modified_beam_search",
         "fast_beam_search_with_nbest_rescoring",
         "fast_beam_search_with_nbest_rnn_rescoring",
+        "modified_beam_search_rnnlm_LODR",
+        "modified_beam_search_ngram_rescoring",
+        "modified_beam_search_rnnlm_shallow_fusion",
     )
     params.res_dir = params.exp_dir / params.decoding_method
 
@@ -930,6 +1055,13 @@ def main():
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
         params.suffix += f"-temperature-{params.temperature}"
 
+    if "rnnlm" in params.decoding_method:
+        params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}"
+    if "LODR" in params.decoding_method:
+        params.suffix += "-LODR"
+    if "ngram" in params.decoding_method:
+        params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
+
     setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
     logging.info("Decoding started")
 
@@ -1048,6 +1180,44 @@ def main():
         word_table = None
         rnn_lm_model = None
 
+    # only load N-gram LM when needed
+    if "ngram" in params.decoding_method or "LODR" in params.decoding_method:
+        lm_filename = f"{params.tokens_ngram}gram.fst.txt"
+        logging.info(f"lm filename: {lm_filename}")
+        ngram_lm = NgramLm(
+            str(params.lang_dir / lm_filename),
+            backoff_id=params.backoff_id,
+            is_binary=False,
+        )
+        logging.info(f"num states: {ngram_lm.lm.num_states}")
+        ngram_lm_scale = params.ngram_lm_scale
+    else:
+        ngram_lm = None
+        ngram_lm_scale = None
+
+    # only load rnnlm if used
+    if "rnnlm" in params.decoding_method:
+        rnn_lm_scale = params.rnn_lm_scale
+
+        rnn_lm_model = RnnLmModel(
+            vocab_size=params.vocab_size,
+            embedding_dim=params.rnn_lm_embedding_dim,
+            hidden_dim=params.rnn_lm_hidden_dim,
+            num_layers=params.rnn_lm_num_layers,
+            tie_weights=params.rnn_lm_tie_weights,
+        )
+        assert params.rnn_lm_avg == 1
+
+        load_checkpoint(
+            f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt",
+            rnn_lm_model,
+        )
+        rnn_lm_model.to(device)
+        rnn_lm_model.eval()
+    else:
+        rnn_lm_model = None
+        rnn_lm_scale = 0.0
+
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
@@ -1074,7 +1244,10 @@ def main():
             word_table=word_table,
             decoding_graph=decoding_graph,
             G=G,
+            ngram_lm=ngram_lm,
+            ngram_lm_scale=ngram_lm_scale,
             rnn_lm_model=rnn_lm_model,
+            rnnlm_scale=rnn_lm_scale,
         )
 
         save_results(

From 556c63fbb741bcbc1669ec6848e06b08480d001f Mon Sep 17 00:00:00 2001
From: Fangjun Kuang 
Date: Thu, 1 Dec 2022 08:58:18 +0800
Subject: [PATCH 026/174] Describe how to fix segfault in doc (#719)

---
 docs/source/installation/index.rst | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/docs/source/installation/index.rst b/docs/source/installation/index.rst
index c4474c3d9..5b9fb2664 100644
--- a/docs/source/installation/index.rst
+++ b/docs/source/installation/index.rst
@@ -393,6 +393,17 @@ Now let us run the training part:
   We use ``export CUDA_VISIBLE_DEVICES=""`` so that ``icefall`` uses CPU
   even if there are GPUs available.
 
+.. hint::
+
+   In case you get a ``Segmentation fault (core dump)`` error, please use:
+
+      .. code-block:: bash
+
+        export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
+   See more at `` if you are
+   interested.
+
 The training log is given below:
 
 .. code-block::

From 2bca7032afb0d5b9eb60f7bcf3bc15ad1e8d8a83 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang 
Date: Thu, 1 Dec 2022 15:57:43 +0800
Subject: [PATCH 027/174] Update RNNLM training scripts (#720)

* Update RNNLM training scripts

* Fix a typo

* Fix CI
---
 .github/workflows/run-ptb-rnn-lm.yml         | 67 ++++++++++++++++++++
 egs/librispeech/ASR/local/train_bpe_model.py |  4 ++
 egs/ptb/LM/prepare.sh                        | 38 ++++++-----
 egs/ptb/LM/rnn_lm                            |  1 +
 egs/ptb/LM/train-rnn-lm.sh                   | 67 ++++++++++++++++++++
 icefall/rnn_lm/compute_perplexity.py         |  2 +-
 icefall/rnn_lm/dataset.py                    |  4 +-
 icefall/rnn_lm/train.py                      | 10 +--
 8 files changed, 170 insertions(+), 23 deletions(-)
 create mode 100644 .github/workflows/run-ptb-rnn-lm.yml
 create mode 120000 egs/ptb/LM/rnn_lm
 create mode 100755 egs/ptb/LM/train-rnn-lm.sh

diff --git a/.github/workflows/run-ptb-rnn-lm.yml b/.github/workflows/run-ptb-rnn-lm.yml
new file mode 100644
index 000000000..8ebc2e79b
--- /dev/null
+++ b/.github/workflows/run-ptb-rnn-lm.yml
@@ -0,0 +1,67 @@
+name: run-ptb-rnn-lm-training
+
+on:
+  push:
+    branches:
+      - master
+  pull_request:
+    types: [labeled]
+
+  schedule:
+    # minute (0-59)
+    # hour (0-23)
+    # day of the month (1-31)
+    # month (1-12)
+    # day of the week (0-6)
+    # nightly build at 15:50 UTC time every day
+    - cron: "50 15 * * *"
+
+jobs:
+  run_ptb_rnn_lm_training:
+    if: github.event.label.name == 'ready' || github.event.label.name == 'rnnlm' || github.event_name == 'push' || github.event_name == 'schedule'
+    runs-on: ${{ matrix.os }}
+    strategy:
+      matrix:
+        os: [ubuntu-latest]
+        python-version: ["3.8"]
+
+      fail-fast: false
+
+    steps:
+      - uses: actions/checkout@v2
+        with:
+          fetch-depth: 0
+
+      - name: Setup Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v2
+        with:
+          python-version: ${{ matrix.python-version }}
+          cache: 'pip'
+          cache-dependency-path: '**/requirements-ci.txt'
+
+      - name: Install Python dependencies
+        run: |
+          grep -v '^#' ./requirements-ci.txt  | grep -v kaldifst | xargs -n 1 -L 1 pip install
+          pip uninstall -y protobuf
+          pip install --no-binary protobuf protobuf
+
+      - name: Prepare data
+        shell: bash
+        run: |
+          export PYTHONPATH=$PWD:$PYTHONPATH
+          cd egs/ptb/LM
+          ./prepare.sh
+
+      - name: Run training
+        shell: bash
+        run: |
+          export PYTHONPATH=$PWD:$PYTHONPATH
+          cd egs/ptb/LM
+          ./train-rnn-lm.sh --world-size 1 --num-epochs 5 --use-epoch 4 --use-avg 2
+
+      - name: Upload pretrained models
+        uses: actions/upload-artifact@v2
+        if: github.event.label.name == 'ready' || github.event.label.name == 'rnnlm' || github.event_name == 'push' || github.event_name == 'schedule'
+        with:
+          name: python-${{ matrix.python-version }}-ubuntu-rnn-lm-ptb
+          path: egs/ptb/LM/my-rnnlm-exp/
diff --git a/egs/librispeech/ASR/local/train_bpe_model.py b/egs/librispeech/ASR/local/train_bpe_model.py
index 42aba9572..7f6f47e16 100755
--- a/egs/librispeech/ASR/local/train_bpe_model.py
+++ b/egs/librispeech/ASR/local/train_bpe_model.py
@@ -89,6 +89,10 @@ def main():
             bos_id=-1,
             eos_id=-1,
         )
+    else:
+        print(f"{model_file} exists - skipping")
+        return
+
 
     shutil.copyfile(model_file, f"{lang_dir}/bpe.model")
 
diff --git a/egs/ptb/LM/prepare.sh b/egs/ptb/LM/prepare.sh
index 91c3c667a..69fab999a 100755
--- a/egs/ptb/LM/prepare.sh
+++ b/egs/ptb/LM/prepare.sh
@@ -22,9 +22,9 @@ dl_dir=$PWD/download
 # if the array contains xxx, yyy
 vocab_sizes=(
   500
-  1000
-  2000
-  5000
+  # 1000
+  # 2000
+  # 5000
 )
 
 # All files generated by this script are saved in "data".
@@ -42,11 +42,14 @@ log "dl_dir: $dl_dir"
 
 if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
   log "Stage -1: Download data"
+
+  # Caution: The downloaded data has already been normalized for LM training.
+
   if [ ! -f $dl_dir/.complete ]; then
-    url=https://raw.githubusercontent.com/townie/PTB-dataset-from-Tomas-Mikolov-s-webpage/master/data/
-    wget --no-verbose --directory-prefix $dl_dir $url/ptb.train.txt
-    wget --no-verbose --directory-prefix $dl_dir $url/ptb.valid.txt
-    wget --no-verbose --directory-prefix $dl_dir $url/ptb.test.txt
+    url=http://raw.githubusercontent.com/townie/PTB-dataset-from-Tomas-Mikolov-s-webpage/master/data
+    wget --directory-prefix $dl_dir $url/ptb.train.txt
+    wget --directory-prefix $dl_dir $url/ptb.valid.txt
+    wget --directory-prefix $dl_dir $url/ptb.test.txt
     touch $dl_dir/.complete
   fi
 fi
@@ -54,11 +57,15 @@ fi
 if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
   log "Stage 0: Train BPE model"
 
+  # Caution: You have to use the same bpe model for training your acoustic model
+  # Caution: You have to use the same bpe model for training your acoustic model
+  # Caution: You have to use the same bpe model for training your acoustic model
+
   for vocab_size in ${vocab_sizes[@]}; do
-    out_dir=data/bpe_${vocab_size}
-    mkdir -p $out_dir
+    lang_dir=data/lang_bpe_${vocab_size}
+    mkdir -p $lang_dir
     ./local/train_bpe_model.py \
-      --out-dir $out_dir \
+      --lang-dir $lang_dir \
       --vocab-size $vocab_size \
       --transcript $dl_dir/ptb.train.txt
   done
@@ -69,20 +76,21 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
   # Note: ptb.train.txt has already been normalized
 
   for vocab_size in ${vocab_sizes[@]}; do
-    out_dir=data/bpe_${vocab_size}
+    lang_dir=data/lang_bpe_${vocab_size}
+    out_dir=data/lm_training_bpe_${vocab_size}
     mkdir -p $out_dir
     ./local/prepare_lm_training_data.py \
-      --bpe-model $out_dir/bpe.model \
+      --bpe-model $lang_dir/bpe.model \
       --lm-data $dl_dir/ptb.train.txt \
       --lm-archive $out_dir/lm_data.pt
 
     ./local/prepare_lm_training_data.py \
-      --bpe-model $out_dir/bpe.model \
+      --bpe-model $lang_dir/bpe.model \
       --lm-data $dl_dir/ptb.valid.txt \
       --lm-archive $out_dir/lm_data-valid.pt
 
     ./local/prepare_lm_training_data.py \
-      --bpe-model $out_dir/bpe.model \
+      --bpe-model $lang_dir/bpe.model \
       --lm-data $dl_dir/ptb.test.txt \
       --lm-archive $out_dir/lm_data-test.pt
   done
@@ -98,7 +106,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
   # in a sentence.
 
   for vocab_size in ${vocab_sizes[@]}; do
-    out_dir=data/bpe_${vocab_size}
+    out_dir=data/lm_training_bpe_${vocab_size}
     mkdir -p $out_dir
     ./local/sort_lm_training_data.py \
       --in-lm-data $out_dir/lm_data.pt \
diff --git a/egs/ptb/LM/rnn_lm b/egs/ptb/LM/rnn_lm
new file mode 120000
index 000000000..87f29771e
--- /dev/null
+++ b/egs/ptb/LM/rnn_lm
@@ -0,0 +1 @@
+../../../icefall/rnn_lm
\ No newline at end of file
diff --git a/egs/ptb/LM/train-rnn-lm.sh b/egs/ptb/LM/train-rnn-lm.sh
new file mode 100755
index 000000000..29c609ee1
--- /dev/null
+++ b/egs/ptb/LM/train-rnn-lm.sh
@@ -0,0 +1,67 @@
+#!/usr/bin/env bash
+
+# Please run ./prepare.sh first
+
+stage=-1
+stop_stage=100
+
+# Number of GPUs to use for training
+world_size=1
+
+# Number of epochs to train
+num_epochs=20
+
+# Use this epoch for computing ppl
+use_epoch=19
+
+# number of models to average for computing ppl
+use_avg=2
+
+exp_dir=./my-rnnlm-exp
+
+. shared/parse_options.sh || exit 1
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+  log "Training RNN LM"
+
+  ./rnn_lm/train.py \
+    --exp-dir $exp_dir \
+    --start-epoch 0 \
+    --num-epochs $num_epochs \
+    --world-size $world_size \
+    --use-fp16 0 \
+    --vocab-size 500 \
+    \
+    --lm-data ./data/lm_training_bpe_500/sorted_lm_data.pt \
+    --lm-data-valid ./data/lm_training_bpe_500/sorted_lm_data-valid.pt \
+    \
+    --embedding-dim 800 \
+    --hidden-dim 200 \
+    --num-layers 2 \
+    --tie-weights false \
+    --batch-size 50
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+  log "Computing perplexity"
+
+  ./rnn_lm/compute_perplexity.py \
+    --exp-dir $exp_dir \
+    --epoch $use_epoch \
+    --avg $use_avg \
+    --vocab-size 500 \
+    \
+    --lm-data ./data/lm_training_bpe_500/sorted_lm_data-test.pt \
+    \
+    --embedding-dim 800 \
+    --hidden-dim 200 \
+    --num-layers 2 \
+    --tie-weights false \
+    --batch-size 50
+fi
diff --git a/icefall/rnn_lm/compute_perplexity.py b/icefall/rnn_lm/compute_perplexity.py
index 550801a8f..f75a89590 100755
--- a/icefall/rnn_lm/compute_perplexity.py
+++ b/icefall/rnn_lm/compute_perplexity.py
@@ -20,7 +20,7 @@ Usage:
   ./rnn_lm/compute_perplexity.py \
     --epoch 4 \
     --avg 2 \
-    --lm-data ./data/bpe_500/sorted_lm_data-test.pt
+    --lm-data ./data/lm_training_bpe_500/sorted_lm_data-test.pt
 
 """
 
diff --git a/icefall/rnn_lm/dataset.py b/icefall/rnn_lm/dataset.py
index 4bf982503..53be53f64 100644
--- a/icefall/rnn_lm/dataset.py
+++ b/icefall/rnn_lm/dataset.py
@@ -1,4 +1,4 @@
-# Copyright (c)  2021  Xiaomi Corporation (authors: Fangjun Kuang)
+# Copyright (c)  2021  Xiaomi Corporation (authors: Daniel Povey, Fangjun Kuang)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -194,7 +194,7 @@ def get_dataloader(
         batch_size=params.batch_size,
     )
     if is_distributed:
-        sampler = DistributedSampler(dataset, shuffle=True, drop_last=False)
+        sampler = DistributedSampler(dataset, shuffle=True, drop_last=True)
     else:
         sampler = None
 
diff --git a/icefall/rnn_lm/train.py b/icefall/rnn_lm/train.py
index 3ba5bfbee..803da99d6 100755
--- a/icefall/rnn_lm/train.py
+++ b/icefall/rnn_lm/train.py
@@ -24,7 +24,7 @@ Usage:
         --use-fp16 0 \
         --embedding-dim 800 \
         --hidden-dim 200 \
-        --num-layers 2\
+        --num-layers 2 \
         --batch-size 400
 
 """
@@ -83,7 +83,7 @@ def get_parser():
     parser.add_argument(
         "--num-epochs",
         type=int,
-        default=10,
+        default=30,
         help="Number of epochs to train.",
     )
 
@@ -110,14 +110,14 @@ def get_parser():
     parser.add_argument(
         "--use-fp16",
         type=str2bool,
-        default=False,
+        default=True,
         help="Whether to use half precision training.",
     )
 
     parser.add_argument(
         "--batch-size",
         type=int,
-        default=50,
+        default=400,
     )
 
     parser.add_argument(
@@ -165,7 +165,7 @@ def get_parser():
     parser.add_argument(
         "--tie-weights",
         type=str2bool,
-        default=False,
+        default=True,
         help="""True to share the weights between the input embedding layer and the
         last output linear layer
         """,

From 04c9fc9c9f9e481cbfae18bb34252b878ff51f6a Mon Sep 17 00:00:00 2001
From: Fangjun Kuang 
Date: Fri, 2 Dec 2022 09:18:28 +0800
Subject: [PATCH 028/174] Fix for older versions of k2 (#725)

---
 icefall/graph_compiler.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/icefall/graph_compiler.py b/icefall/graph_compiler.py
index 84be81254..0dcd777ad 100644
--- a/icefall/graph_compiler.py
+++ b/icefall/graph_compiler.py
@@ -81,7 +81,7 @@ class CtcTrainingGraphCompiler(object):
 
         self.ctc_topo._is_repeat_token_ = (
             self.ctc_topo.labels != self.ctc_topo.aux_labels
-        )
+        ).int()
 
         decoding_graph = k2.compose(
             self.ctc_topo, fsa_with_self_loops, treat_epsilons_specially=False

From 6533f359c998cee6fcb618f7b221cbfee05512e8 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang 
Date: Fri, 2 Dec 2022 10:53:06 +0800
Subject: [PATCH 029/174] Fix CI (#726)

* Fix CI

* Disable shuffle for yesno.

See https://github.com/k2-fsa/icefall/issues/197
---
 .github/workflows/build-doc.yml               |  4 ++
 .github/workflows/run-aishell-2022-06-20.yml  |  4 ++
 .../workflows/run-gigaspeech-2022-05-13.yml   |  4 ++
 .../workflows/run-librispeech-2022-03-12.yml  |  4 ++
 .../workflows/run-librispeech-2022-04-29.yml  |  4 ++
 .../workflows/run-librispeech-2022-05-13.yml  |  4 ++
 .../run-librispeech-2022-11-11-stateless7.yml |  4 ++
 .../run-librispeech-2022-11-14-stateless8.yml |  4 ++
 ...-librispeech-conformer-ctc3-2022-11-28.yml |  4 ++
 ...-lstm-transducer-stateless2-2022-09-03.yml |  4 ++
 ...runed-transducer-stateless3-2022-05-13.yml |  4 ++
 ...aming-transducer-stateless2-2022-06-26.yml |  4 ++
 ...peech-transducer-stateless2-2022-04-19.yml |  4 ++
 .../run-pretrained-conformer-ctc.yml          |  4 ++
 ...-transducer-stateless-librispeech-100h.yml |  4 ++
 ...r-stateless-librispeech-multi-datasets.yml |  4 ++
 ...ransducer-stateless-modified-2-aishell.yml |  4 ++
 ...-transducer-stateless-modified-aishell.yml |  4 ++
 .../run-pretrained-transducer-stateless.yml   |  4 ++
 .../workflows/run-pretrained-transducer.yml   |  4 ++
 .github/workflows/run-ptb-rnn-lm.yml          |  4 ++
 ...netspeech-pruned-transducer-stateless2.yml |  6 +-
 .github/workflows/run-yesno-recipe.yml        | 10 +++-
 .github/workflows/style_check.yml             |  4 ++
 .github/workflows/test.yml                    | 60 ++++++++-----------
 egs/librispeech/ASR/local/train_bpe_model.py  |  1 -
 .../beam_search.py                            | 13 +---
 .../test_scaling.py                           |  8 ---
 egs/yesno/ASR/tdnn/asr_datamodule.py          |  2 +-
 29 files changed, 128 insertions(+), 60 deletions(-)

diff --git a/.github/workflows/build-doc.yml b/.github/workflows/build-doc.yml
index dd0969f51..d7fe2c964 100644
--- a/.github/workflows/build-doc.yml
+++ b/.github/workflows/build-doc.yml
@@ -26,6 +26,10 @@ on:
   pull_request:
     types: [labeled]
 
+concurrency:
+  group: build_doc-${{ github.ref }}
+  cancel-in-progress: true
+
 jobs:
   build-doc:
     if: github.event.label.name == 'doc' || github.event_name == 'push'
diff --git a/.github/workflows/run-aishell-2022-06-20.yml b/.github/workflows/run-aishell-2022-06-20.yml
index e46b01a08..1865a0da8 100644
--- a/.github/workflows/run-aishell-2022-06-20.yml
+++ b/.github/workflows/run-aishell-2022-06-20.yml
@@ -34,6 +34,10 @@ on:
     # nightly build at 15:50 UTC time every day
     - cron: "50 15 * * *"
 
+concurrency:
+  group: run_aishell_2022_06_20-${{ github.ref }}
+  cancel-in-progress: true
+
 jobs:
   run_aishell_2022_06_20:
     if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
diff --git a/.github/workflows/run-gigaspeech-2022-05-13.yml b/.github/workflows/run-gigaspeech-2022-05-13.yml
index c631927fa..e438c5dba 100644
--- a/.github/workflows/run-gigaspeech-2022-05-13.yml
+++ b/.github/workflows/run-gigaspeech-2022-05-13.yml
@@ -33,6 +33,10 @@ on:
     # nightly build at 15:50 UTC time every day
     - cron: "50 15 * * *"
 
+concurrency:
+  group: run_gigaspeech_2022_05_13-${{ github.ref }}
+  cancel-in-progress: true
+
 jobs:
   run_gigaspeech_2022_05_13:
     if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
diff --git a/.github/workflows/run-librispeech-2022-03-12.yml b/.github/workflows/run-librispeech-2022-03-12.yml
index 5df710006..3ba6850cd 100644
--- a/.github/workflows/run-librispeech-2022-03-12.yml
+++ b/.github/workflows/run-librispeech-2022-03-12.yml
@@ -33,6 +33,10 @@ on:
     # nightly build at 15:50 UTC time every day
     - cron: "50 15 * * *"
 
+concurrency:
+  group: run_librispeech_2022_03_12-${{ github.ref }}
+  cancel-in-progress: true
+
 jobs:
   run_librispeech_2022_03_12:
     if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
diff --git a/.github/workflows/run-librispeech-2022-04-29.yml b/.github/workflows/run-librispeech-2022-04-29.yml
index 24c062442..595b410b8 100644
--- a/.github/workflows/run-librispeech-2022-04-29.yml
+++ b/.github/workflows/run-librispeech-2022-04-29.yml
@@ -33,6 +33,10 @@ on:
     # nightly build at 15:50 UTC time every day
     - cron: "50 15 * * *"
 
+concurrency:
+  group: run_librispeech_2022_04_29-${{ github.ref }}
+  cancel-in-progress: true
+
 jobs:
   run_librispeech_2022_04_29:
     if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
diff --git a/.github/workflows/run-librispeech-2022-05-13.yml b/.github/workflows/run-librispeech-2022-05-13.yml
index 29215ec25..eb0b06a2d 100644
--- a/.github/workflows/run-librispeech-2022-05-13.yml
+++ b/.github/workflows/run-librispeech-2022-05-13.yml
@@ -33,6 +33,10 @@ on:
     # nightly build at 15:50 UTC time every day
     - cron: "50 15 * * *"
 
+concurrency:
+  group: run_librispeech_2022_05_13-${{ github.ref }}
+  cancel-in-progress: true
+
 jobs:
   run_librispeech_2022_05_13:
     if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
diff --git a/.github/workflows/run-librispeech-2022-11-11-stateless7.yml b/.github/workflows/run-librispeech-2022-11-11-stateless7.yml
index 3b98b500e..365e2761a 100644
--- a/.github/workflows/run-librispeech-2022-11-11-stateless7.yml
+++ b/.github/workflows/run-librispeech-2022-11-11-stateless7.yml
@@ -33,6 +33,10 @@ on:
     # nightly build at 15:50 UTC time every day
     - cron: "50 15 * * *"
 
+concurrency:
+  group: run_librispeech_2022_11_11_zipformer-${{ github.ref }}
+  cancel-in-progress: true
+
 jobs:
   run_librispeech_2022_11_11_zipformer:
     if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
diff --git a/.github/workflows/run-librispeech-2022-11-14-stateless8.yml b/.github/workflows/run-librispeech-2022-11-14-stateless8.yml
index eaab35189..acb11a8f4 100644
--- a/.github/workflows/run-librispeech-2022-11-14-stateless8.yml
+++ b/.github/workflows/run-librispeech-2022-11-14-stateless8.yml
@@ -33,6 +33,10 @@ on:
     # nightly build at 15:50 UTC time every day
     - cron: "50 15 * * *"
 
+concurrency:
+  group: run_librispeech_2022_11_14_zipformer_stateless8-${{ github.ref }}
+  cancel-in-progress: true
+
 jobs:
   run_librispeech_2022_11_14_zipformer_stateless8:
     if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
diff --git a/.github/workflows/run-librispeech-conformer-ctc3-2022-11-28.yml b/.github/workflows/run-librispeech-conformer-ctc3-2022-11-28.yml
index 21f396c32..d763fb1c5 100644
--- a/.github/workflows/run-librispeech-conformer-ctc3-2022-11-28.yml
+++ b/.github/workflows/run-librispeech-conformer-ctc3-2022-11-28.yml
@@ -33,6 +33,10 @@ on:
     # nightly build at 15:50 UTC time every day
     - cron: "50 15 * * *"
 
+concurrency:
+  group: run_librispeech_2022_11_28_conformer_ctc3-${{ github.ref }}
+  cancel-in-progress: true
+
 jobs:
   run_librispeech_2022_11_28_conformer_ctc3:
     if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
diff --git a/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml b/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml
index 5f0acf9b8..59f116fde 100644
--- a/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml
+++ b/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml
@@ -16,6 +16,10 @@ on:
     # nightly build at 15:50 UTC time every day
     - cron: "50 15 * * *"
 
+concurrency:
+  group: run_librispeech_lstm_transducer_stateless2_2022_09_03-${{ github.ref }}
+  cancel-in-progress: true
+
 jobs:
   run_librispeech_lstm_transducer_stateless2_2022_09_03:
     if: github.event.label.name == 'ready' || github.event.label.name == 'LODR' || github.event.label.name == 'shallow-fusion' || github.event.label.name == 'ncnn' || github.event.label.name == 'onnx' || github.event_name == 'push' || github.event_name == 'schedule'
diff --git a/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml b/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml
index 66a2c240b..2c2bcab0c 100644
--- a/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml
+++ b/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml
@@ -33,6 +33,10 @@ on:
     # nightly build at 15:50 UTC time every day
     - cron: "50 15 * * *"
 
+concurrency:
+  group: run_librispeech_pruned_transducer_stateless3_2022_05_13-${{ github.ref }}
+  cancel-in-progress: true
+
 jobs:
   run_librispeech_pruned_transducer_stateless3_2022_05_13:
     if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
diff --git a/.github/workflows/run-librispeech-streaming-transducer-stateless2-2022-06-26.yml b/.github/workflows/run-librispeech-streaming-transducer-stateless2-2022-06-26.yml
index 55428861c..ac7e58b20 100644
--- a/.github/workflows/run-librispeech-streaming-transducer-stateless2-2022-06-26.yml
+++ b/.github/workflows/run-librispeech-streaming-transducer-stateless2-2022-06-26.yml
@@ -33,6 +33,10 @@ on:
     # nightly build at 15:50 UTC time every day
     - cron: "50 15 * * *"
 
+concurrency:
+  group: run_librispeech_streaming_2022_06_26-${{ github.ref }}
+  cancel-in-progress: true
+
 jobs:
   run_librispeech_streaming_2022_06_26:
     if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
diff --git a/.github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml b/.github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml
index f520405e1..575727e22 100644
--- a/.github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml
+++ b/.github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml
@@ -33,6 +33,10 @@ on:
     # nightly build at 15:50 UTC time every day
     - cron: "50 15 * * *"
 
+concurrency:
+  group: run_librispeech_2022_04_19-${{ github.ref }}
+  cancel-in-progress: true
+
 jobs:
   run_librispeech_2022_04_19:
     if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
diff --git a/.github/workflows/run-pretrained-conformer-ctc.yml b/.github/workflows/run-pretrained-conformer-ctc.yml
index 9bc6a481f..7dbfd2bd9 100644
--- a/.github/workflows/run-pretrained-conformer-ctc.yml
+++ b/.github/workflows/run-pretrained-conformer-ctc.yml
@@ -23,6 +23,10 @@ on:
   pull_request:
     types: [labeled]
 
+concurrency:
+  group: run_pre_trained_conformer_ctc-${{ github.ref }}
+  cancel-in-progress: true
+
 jobs:
   run_pre_trained_conformer_ctc:
     if: github.event.label.name == 'ready' || github.event_name == 'push'
diff --git a/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml b/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml
index 7a0f30b0f..d6b3de8d4 100644
--- a/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml
+++ b/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml
@@ -32,6 +32,10 @@ on:
     # nightly build at 15:50 UTC time every day
     - cron: "50 15 * * *"
 
+concurrency:
+  group: run_pre_trained_transducer_stateless_multi_datasets_librispeech_100h-${{ github.ref }}
+  cancel-in-progress: true
+
 jobs:
   run_pre_trained_transducer_stateless_multi_datasets_librispeech_100h:
     if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
diff --git a/.github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml b/.github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml
index 797f3fe50..749fb3fca 100644
--- a/.github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml
+++ b/.github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml
@@ -32,6 +32,10 @@ on:
     # nightly build at 15:50 UTC time every day
     - cron: "50 15 * * *"
 
+concurrency:
+  group: run_pre_trained_transducer_stateless_multi_datasets_librispeech_960h-${{ github.ref }}
+  cancel-in-progress: true
+
 jobs:
   run_pre_trained_transducer_stateless_multi_datasets_librispeech_960h:
     if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
diff --git a/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml b/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml
index 29e665881..92bf6feb8 100644
--- a/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml
+++ b/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml
@@ -23,6 +23,10 @@ on:
   pull_request:
     types: [labeled]
 
+concurrency:
+  group: run_pre_trained_transducer_stateless_modified_2_aishell-${{ github.ref }}
+  cancel-in-progress: true
+
 jobs:
   run_pre_trained_transducer_stateless_modified_2_aishell:
     if: github.event.label.name == 'ready' || github.event_name == 'push'
diff --git a/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml b/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml
index 6193f28e7..e51da8bd8 100644
--- a/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml
+++ b/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml
@@ -23,6 +23,10 @@ on:
   pull_request:
     types: [labeled]
 
+concurrency:
+  group: run_pre_trained_transducer_stateless_modified_aishell-${{ github.ref }}
+  cancel-in-progress: true
+
 jobs:
   run_pre_trained_transducer_stateless_modified_aishell:
     if: github.event.label.name == 'ready' || github.event_name == 'push'
diff --git a/.github/workflows/run-pretrained-transducer-stateless.yml b/.github/workflows/run-pretrained-transducer-stateless.yml
index 32208076c..2103d0510 100644
--- a/.github/workflows/run-pretrained-transducer-stateless.yml
+++ b/.github/workflows/run-pretrained-transducer-stateless.yml
@@ -32,6 +32,10 @@ on:
     # nightly build at 15:50 UTC time every day
     - cron: "50 15 * * *"
 
+concurrency:
+  group: run_pre_trained_transducer_stateless-${{ github.ref }}
+  cancel-in-progress: true
+
 jobs:
   run_pre_trained_transducer_stateless:
     if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
diff --git a/.github/workflows/run-pretrained-transducer.yml b/.github/workflows/run-pretrained-transducer.yml
index 965d0f655..902319b55 100644
--- a/.github/workflows/run-pretrained-transducer.yml
+++ b/.github/workflows/run-pretrained-transducer.yml
@@ -23,6 +23,10 @@ on:
   pull_request:
     types: [labeled]
 
+concurrency:
+  group: run_pre_trained_transducer-${{ github.ref }}
+  cancel-in-progress: true
+
 jobs:
   run_pre_trained_transducer:
     if: github.event.label.name == 'ready' || github.event_name == 'push'
diff --git a/.github/workflows/run-ptb-rnn-lm.yml b/.github/workflows/run-ptb-rnn-lm.yml
index 8ebc2e79b..47ed958f2 100644
--- a/.github/workflows/run-ptb-rnn-lm.yml
+++ b/.github/workflows/run-ptb-rnn-lm.yml
@@ -16,6 +16,10 @@ on:
     # nightly build at 15:50 UTC time every day
     - cron: "50 15 * * *"
 
+concurrency:
+  group: run_ptb_rnn_lm_training-${{ github.ref }}
+  cancel-in-progress: true
+
 jobs:
   run_ptb_rnn_lm_training:
     if: github.event.label.name == 'ready' || github.event.label.name == 'rnnlm' || github.event_name == 'push' || github.event_name == 'schedule'
diff --git a/.github/workflows/run-wenetspeech-pruned-transducer-stateless2.yml b/.github/workflows/run-wenetspeech-pruned-transducer-stateless2.yml
index d96a3bfe6..8a7be0b80 100644
--- a/.github/workflows/run-wenetspeech-pruned-transducer-stateless2.yml
+++ b/.github/workflows/run-wenetspeech-pruned-transducer-stateless2.yml
@@ -23,8 +23,12 @@ on:
   pull_request:
     types: [labeled]
 
+concurrency:
+  group: run_wenetspeech_pruned_transducer_stateless2-${{ github.ref }}
+  cancel-in-progress: true
+
 jobs:
-  run_librispeech_pruned_transducer_stateless3_2022_05_13:
+  run_wenetspeech_pruned_transducer_stateless2:
     if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event_name == 'push' || github.event.label.name == 'wenetspeech'
     runs-on: ${{ matrix.os }}
     strategy:
diff --git a/.github/workflows/run-yesno-recipe.yml b/.github/workflows/run-yesno-recipe.yml
index ce77c47df..ed343aee5 100644
--- a/.github/workflows/run-yesno-recipe.yml
+++ b/.github/workflows/run-yesno-recipe.yml
@@ -21,11 +21,15 @@ on:
     branches:
       - master
   pull_request:
-    types: [labeled]
+    branches:
+      - master
+
+concurrency:
+  group: run-yesno-recipe-${{ github.ref }}
+  cancel-in-progress: true
 
 jobs:
   run-yesno-recipe:
-    if: github.event.label.name == 'ready' || github.event_name == 'push'
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:
@@ -61,7 +65,7 @@ jobs:
 
       - name: Install Python dependencies
         run: |
-          grep -v '^#' ./requirements-ci.txt  | xargs -n 1 -L 1 pip install
+          grep -v '^#' ./requirements-ci.txt  | grep -v kaldifst | xargs -n 1 -L 1 pip install
           pip uninstall -y protobuf
           pip install --no-binary protobuf protobuf
 
diff --git a/.github/workflows/style_check.yml b/.github/workflows/style_check.yml
index 45d261ccc..fc1dcbfd4 100644
--- a/.github/workflows/style_check.yml
+++ b/.github/workflows/style_check.yml
@@ -24,6 +24,10 @@ on:
     branches:
       - master
 
+concurrency:
+  group: style_check-${{ github.ref }}
+  cancel-in-progress: true
+
 jobs:
   style_check:
     runs-on: ${{ matrix.os }}
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 04fc0265f..4dbe99827 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -21,26 +21,23 @@ on:
     branches:
       - master
   pull_request:
-    types: [labeled]
+    branches:
+      - master
+
+concurrency:
+  group: test-${{ github.ref }}
+  cancel-in-progress: true
 
 jobs:
   test:
-    if: github.event.label.name == 'ready' || github.event_name == 'push'
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:
-        # os: [ubuntu-18.04, macos-10.15]
-        # disable macOS test for now.
-        os: [ubuntu-18.04]
-        python-version: [3.7, 3.8]
-        torch: ["1.8.0", "1.11.0"]
-        torchaudio: ["0.8.0", "0.11.0"]
-        k2-version: ["1.15.1.dev20220427"]
-        exclude:
-          - torch: "1.8.0"
-            torchaudio: "0.11.0"
-          - torch: "1.11.0"
-            torchaudio: "0.8.0"
+        os: [ubuntu-latest]
+        python-version: ["3.8"]
+        torch: ["1.10.0"]
+        torchaudio: ["0.10.0"]
+        k2-version: ["1.23.2.dev20221201"]
 
       fail-fast: false
 
@@ -67,11 +64,7 @@ jobs:
           # numpy 1.20.x does not support python 3.6
           pip install numpy==1.19
           pip install torch==${{ matrix.torch }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
-          if [[ ${{ matrix.torchaudio }} == "0.11.0" ]]; then
-            pip install torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
-          else
-            pip install torchaudio==${{ matrix.torchaudio }}
-          fi
+          pip install torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
 
           pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/
           pip install git+https://github.com/lhotse-speech/lhotse
@@ -81,7 +74,6 @@ jobs:
 
           pip install kaldifst
           pip install onnxruntime
-
           pip install -r requirements.txt
 
       - name: Install graphviz
@@ -124,16 +116,14 @@ jobs:
           cd ../transducer_stateless
           pytest -v -s
 
-          if [[ ${{ matrix.torchaudio }} == "0.10.0" ]]; then
-            cd ../transducer
-            pytest -v -s
+          cd ../transducer
+          pytest -v -s
 
-            cd ../transducer_stateless2
-            pytest -v -s
+          cd ../transducer_stateless2
+          pytest -v -s
 
-            cd ../transducer_lstm
-            pytest -v -s
-          fi
+          cd ../transducer_lstm
+          pytest -v -s
 
       - name: Run tests
         if: startsWith(matrix.os, 'macos')
@@ -164,13 +154,11 @@ jobs:
           cd ../transducer_stateless
           pytest -v -s
 
-          if [[ ${{ matrix.torchaudio }} == "0.10.0" ]]; then
-            cd ../transducer
-            pytest -v -s
+          cd ../transducer
+          pytest -v -s
 
-            cd ../transducer_stateless2
-            pytest -v -s
+          cd ../transducer_stateless2
+          pytest -v -s
 
-            cd ../transducer_lstm
-            pytest -v -s
-          fi
+          cd ../transducer_lstm
+          pytest -v -s
diff --git a/egs/librispeech/ASR/local/train_bpe_model.py b/egs/librispeech/ASR/local/train_bpe_model.py
index 7f6f47e16..43142aee4 100755
--- a/egs/librispeech/ASR/local/train_bpe_model.py
+++ b/egs/librispeech/ASR/local/train_bpe_model.py
@@ -93,7 +93,6 @@ def main():
         print(f"{model_file} exists - skipping")
         return
 
-
     shutil.copyfile(model_file, f"{lang_dir}/bpe.model")
 
 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
index 59c8ed5b5..b324cc9b7 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
@@ -2230,9 +2230,7 @@ def modified_beam_search_rnnlm_LODR(
         log_probs_shape = k2.ragged.create_ragged_shape2(
             row_splits=row_splits, cached_tot_size=log_probs.numel()
         )
-        ragged_log_probs = k2.RaggedTensor(
-            shape=log_probs_shape, value=log_probs
-        )
+        ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)
         """
         for all hyps with a non-blank new token, score this token.
         It is a little confusing here because this for-loop
@@ -2267,10 +2265,7 @@ def modified_beam_search_rnnlm_LODR(
         # forward RNNLM to get new states and scores
         if len(token_list) != 0:
             tokens_to_score = (
-                torch.tensor(token_list)
-                .to(torch.int64)
-                .to(device)
-                .reshape(-1, 1)
+                torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1)
             )
 
             hs = torch.cat(hs, dim=1).to(device)
@@ -2304,9 +2299,7 @@ def modified_beam_search_rnnlm_LODR(
                     state_cost = hyp.state_cost.forward_one_step(new_token)
 
                     # calculate the score of the latest token
-                    current_ngram_score = (
-                        state_cost.lm_score - hyp.state_cost.lm_score
-                    )
+                    current_ngram_score = state_cost.lm_score - hyp.state_cost.lm_score
 
                     assert current_ngram_score <= 0.0, (
                         state_cost.lm_score,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/test_scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless3/test_scaling.py
index e9dfe6d5e..42de2410a 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/test_scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/test_scaling.py
@@ -52,17 +52,9 @@ def test_scaled_conv2d():
         torch.jit.script(conv2d)
 
 
-def test_activation_balancer():
-    act = ActivationBalancer(
-        channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0
-    )
-    torch.jit.script(act)
-
-
 def main():
     test_scaled_conv1d()
     test_scaled_conv2d()
-    test_activation_balancer()
 
 
 if __name__ == "__main__":
diff --git a/egs/yesno/ASR/tdnn/asr_datamodule.py b/egs/yesno/ASR/tdnn/asr_datamodule.py
index 85e5f1358..3c1682fa1 100644
--- a/egs/yesno/ASR/tdnn/asr_datamodule.py
+++ b/egs/yesno/ASR/tdnn/asr_datamodule.py
@@ -121,7 +121,7 @@ class YesNoAsrDataModule(DataModule):
         group.add_argument(
             "--shuffle",
             type=str2bool,
-            default=True,
+            default=False,
             help="When enabled (=default), the examples will be "
             "shuffled for each epoch.",
         )

From 6f719816673761ceda0bfe6bece5a44b151ead46 Mon Sep 17 00:00:00 2001
From: Amir Hussein <36240131+AmirHussein96@users.noreply.github.com>
Date: Thu, 1 Dec 2022 21:58:34 -0500
Subject: [PATCH 030/174] MGB2 (#396)

* mgb2

* mgb2

* adding pruned transducer stateless to mgb2

* update display_manifest_statistics.py

* .

* stateless transducer MGB-2

* Update README.md

* Update RESULTS.md

* Update prepare_lang_bpe.py

* Update asr_datamodule.py

* .nfs removed

* Adding symlink

* .

* resolving conflicts

* Update .gitignore

* black formatting

* Update compile_hlg.py

* Update compute_fbank_musan.py

* Update convert_transcript_words_to_tokens.py

* Update download_lm.py

* Update generate_unique_lexicon.py

* adding simlinks

* fixing symbolic links
---
 .gitignore                                    |   20 +
 egs/mgb2/ASR/README.md                        |   43 +
 egs/mgb2/ASR/RESULTS.md                       |  236 ++++
 egs/mgb2/ASR/conformer_ctc/__init__.py        |    0
 egs/mgb2/ASR/conformer_ctc/ali.py             |  395 ++++++
 egs/mgb2/ASR/conformer_ctc/asr_datamodule.py  |  372 ++++++
 egs/mgb2/ASR/conformer_ctc/compile_hlg.py     |    1 +
 .../ASR/conformer_ctc/compute_fbank_musan.py  |    1 +
 egs/mgb2/ASR/conformer_ctc/conformer.py       |    1 +
 .../convert_transcript_words_to_tokens.py     |    1 +
 egs/mgb2/ASR/conformer_ctc/decode.py          |  695 ++++++++++
 egs/mgb2/ASR/conformer_ctc/download_lm.py     |    1 +
 egs/mgb2/ASR/conformer_ctc/export.py          |    1 +
 .../conformer_ctc/generate_unique_lexicon.py  |    1 +
 egs/mgb2/ASR/conformer_ctc/label_smoothing.py |    1 +
 egs/mgb2/ASR/conformer_ctc/pretrained.py      |  430 ++++++
 egs/mgb2/ASR/conformer_ctc/subsampling.py     |    1 +
 .../ASR/conformer_ctc/test_label_smoothing.py |    1 +
 .../ASR/conformer_ctc/test_subsampling.py     |    1 +
 .../ASR/conformer_ctc/test_transformer.py     |    1 +
 egs/mgb2/ASR/conformer_ctc/train.py           |  766 +++++++++++
 egs/mgb2/ASR/conformer_ctc/transformer.py     |    1 +
 egs/mgb2/ASR/local/__init__.py                |    0
 egs/mgb2/ASR/local/compile_hlg.py             |    1 +
 egs/mgb2/ASR/local/compute_fbank_mgb2.py      |  101 ++
 egs/mgb2/ASR/local/compute_fbank_musan.py     |  108 ++
 .../convert_transcript_words_to_tokens.py     |  103 ++
 .../ASR/local/display_manifest_statistics.py  |   97 ++
 egs/mgb2/ASR/local/generate_unique_lexicon.py |    1 +
 egs/mgb2/ASR/local/prep_mgb2_lexicon.sh       |   30 +
 egs/mgb2/ASR/local/prepare_lang.py            |    1 +
 egs/mgb2/ASR/local/prepare_lang_bpe.py        |    1 +
 egs/mgb2/ASR/local/prepare_mgb2_lexicon.py    |   37 +
 egs/mgb2/ASR/local/test_prepare_lang.py       |    1 +
 egs/mgb2/ASR/prepare.sh                       |  234 ++++
 .../pruned_transducer_stateless5/__init__.py  |    0
 .../asr_datamodule.py                         |    1 +
 .../beam_search.py                            |    1 +
 .../pruned_transducer_stateless5/conformer.py |    1 +
 .../pruned_transducer_stateless5/decode.py    |  625 +++++++++
 .../pruned_transducer_stateless5/decoder.py   |    1 +
 .../encoder_interface.py                      |    1 +
 .../pruned_transducer_stateless5/export.py    |  272 ++++
 .../pruned_transducer_stateless5/joiner.py    |    1 +
 .../ASR/pruned_transducer_stateless5/model.py |    1 +
 .../ASR/pruned_transducer_stateless5/optim.py |    1 +
 .../pretrained.py                             |  344 +++++
 .../pruned_transducer_stateless5/scaling.py   |    1 +
 .../test_model.py                             |    1 +
 .../ASR/pruned_transducer_stateless5/train.py | 1176 +++++++++++++++++
 egs/mgb2/ASR/shared                           |    1 +
 icefall/diagnostics.py                        |    2 +-
 52 files changed, 6114 insertions(+), 1 deletion(-)
 create mode 100644 egs/mgb2/ASR/README.md
 create mode 100644 egs/mgb2/ASR/RESULTS.md
 create mode 100644 egs/mgb2/ASR/conformer_ctc/__init__.py
 create mode 100755 egs/mgb2/ASR/conformer_ctc/ali.py
 create mode 100644 egs/mgb2/ASR/conformer_ctc/asr_datamodule.py
 create mode 120000 egs/mgb2/ASR/conformer_ctc/compile_hlg.py
 create mode 120000 egs/mgb2/ASR/conformer_ctc/compute_fbank_musan.py
 create mode 120000 egs/mgb2/ASR/conformer_ctc/conformer.py
 create mode 120000 egs/mgb2/ASR/conformer_ctc/convert_transcript_words_to_tokens.py
 create mode 100755 egs/mgb2/ASR/conformer_ctc/decode.py
 create mode 120000 egs/mgb2/ASR/conformer_ctc/download_lm.py
 create mode 120000 egs/mgb2/ASR/conformer_ctc/export.py
 create mode 120000 egs/mgb2/ASR/conformer_ctc/generate_unique_lexicon.py
 create mode 120000 egs/mgb2/ASR/conformer_ctc/label_smoothing.py
 create mode 100755 egs/mgb2/ASR/conformer_ctc/pretrained.py
 create mode 120000 egs/mgb2/ASR/conformer_ctc/subsampling.py
 create mode 120000 egs/mgb2/ASR/conformer_ctc/test_label_smoothing.py
 create mode 120000 egs/mgb2/ASR/conformer_ctc/test_subsampling.py
 create mode 120000 egs/mgb2/ASR/conformer_ctc/test_transformer.py
 create mode 100755 egs/mgb2/ASR/conformer_ctc/train.py
 create mode 120000 egs/mgb2/ASR/conformer_ctc/transformer.py
 create mode 100644 egs/mgb2/ASR/local/__init__.py
 create mode 120000 egs/mgb2/ASR/local/compile_hlg.py
 create mode 100755 egs/mgb2/ASR/local/compute_fbank_mgb2.py
 create mode 100755 egs/mgb2/ASR/local/compute_fbank_musan.py
 create mode 100755 egs/mgb2/ASR/local/convert_transcript_words_to_tokens.py
 create mode 100755 egs/mgb2/ASR/local/display_manifest_statistics.py
 create mode 120000 egs/mgb2/ASR/local/generate_unique_lexicon.py
 create mode 100755 egs/mgb2/ASR/local/prep_mgb2_lexicon.sh
 create mode 120000 egs/mgb2/ASR/local/prepare_lang.py
 create mode 120000 egs/mgb2/ASR/local/prepare_lang_bpe.py
 create mode 100755 egs/mgb2/ASR/local/prepare_mgb2_lexicon.py
 create mode 120000 egs/mgb2/ASR/local/test_prepare_lang.py
 create mode 100755 egs/mgb2/ASR/prepare.sh
 create mode 100644 egs/mgb2/ASR/pruned_transducer_stateless5/__init__.py
 create mode 120000 egs/mgb2/ASR/pruned_transducer_stateless5/asr_datamodule.py
 create mode 120000 egs/mgb2/ASR/pruned_transducer_stateless5/beam_search.py
 create mode 120000 egs/mgb2/ASR/pruned_transducer_stateless5/conformer.py
 create mode 100755 egs/mgb2/ASR/pruned_transducer_stateless5/decode.py
 create mode 120000 egs/mgb2/ASR/pruned_transducer_stateless5/decoder.py
 create mode 120000 egs/mgb2/ASR/pruned_transducer_stateless5/encoder_interface.py
 create mode 100755 egs/mgb2/ASR/pruned_transducer_stateless5/export.py
 create mode 120000 egs/mgb2/ASR/pruned_transducer_stateless5/joiner.py
 create mode 120000 egs/mgb2/ASR/pruned_transducer_stateless5/model.py
 create mode 120000 egs/mgb2/ASR/pruned_transducer_stateless5/optim.py
 create mode 100755 egs/mgb2/ASR/pruned_transducer_stateless5/pretrained.py
 create mode 120000 egs/mgb2/ASR/pruned_transducer_stateless5/scaling.py
 create mode 120000 egs/mgb2/ASR/pruned_transducer_stateless5/test_model.py
 create mode 100755 egs/mgb2/ASR/pruned_transducer_stateless5/train.py
 create mode 120000 egs/mgb2/ASR/shared

diff --git a/.gitignore b/.gitignore
index 406deff6a..583410f45 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,5 +11,25 @@ log
 *.bak
 *-bak
 *bak.py
+
+# Ignore Mac system files
+.DS_store
+
+# Ignore node_modules folder
+node_modules
+
+# ignore .nfs
+
+.nfs*
+
+# Ignore all text files
+*.txt
+
+# Ignore files related to API keys
+.env
+
+# Ignore SASS config files
+.sass-cache
+
 *.param
 *.bin
diff --git a/egs/mgb2/ASR/README.md b/egs/mgb2/ASR/README.md
new file mode 100644
index 000000000..2bc4b000b
--- /dev/null
+++ b/egs/mgb2/ASR/README.md
@@ -0,0 +1,43 @@
+# MGB2
+
+The Multi-Dialect Broadcast News Arabic Speech Recognition (MGB-2):
+The second edition of the Multi-Genre Broadcast (MGB-2) Challenge is
+an evaluation of speech recognition and lightly supervised alignment
+using TV recordings in Arabic. The speech data is broad and multi-genre,
+spanning the whole range of TV output, and represents a challenging task for
+speech technology. In 2016, the challenge featured two new Arabic tracks based
+on TV data from Aljazeera. It was an official challenge at the 2016 IEEE
+Workshop on Spoken Language Technology. The 1,200 hours MGB-2: from Aljazeera
+TV programs have been manually captioned with no timing information.
+QCRI Arabic ASR system has been used to recognize all programs. The ASR output
+was used to align the manual captioning and produce speech segments for
+training speech recognition. More than 20 hours from 2015 programs have been
+transcribed verbatim and manually segmented. This data is split into a
+development set of 10 hours, and a similar evaluation set of 10 hours.
+Both the development and evaluation data have been released in the 2016 MGB
+challenge
+
+Official reference:
+
+Ali, Ahmed, et al. "The MGB-2 challenge: Arabic multi-dialect broadcast media recognition." 
+2016 IEEE Spoken Language Technology Workshop (SLT). IEEE, 2016.
+
+IEEE link: https://ieeexplore.ieee.org/abstract/document/7846277
+
+## Stateless Pruned Transducer Performance Record (after 30 epochs)
+
+|                                    |     dev    |    test    | comment                                  |
+|------------------------------------|------------|------------|------------------------------------------|
+|          greedy search             | 15.52      | 15.28      | --epoch 18, --avg 5, --max-duration 200  |
+| modified beam search               | 13.88      | 13.7       | --epoch 18, --avg 5, --max-duration 200  |
+| fast beam search                   | 14.62      | 14.36      | --epoch 18, --avg 5, --max-duration 200  |
+
+## Conformer-CTC Performance Record (after 40 epochs)
+
+| Decoding method           | dev WER    | test WER |
+|---------------------------|------------|---------|
+| attention-decoder         | 15.62      |  15.01  |
+| whole-lattice-rescoring   | 15.89      |  15.08  |
+
+
+See [RESULTS](/egs/mgb2/ASR/RESULTS.md) for details.
diff --git a/egs/mgb2/ASR/RESULTS.md b/egs/mgb2/ASR/RESULTS.md
new file mode 100644
index 000000000..2a7ea7664
--- /dev/null
+++ b/egs/mgb2/ASR/RESULTS.md
@@ -0,0 +1,236 @@
+# Results
+
+
+### MGB2 all data BPE training results (Stateless Pruned Transducer)
+
+#### 2022-09-07
+
+The WERs are
+
+|                                    |     dev    |    test    | comment                                  |
+|------------------------------------|------------|------------|------------------------------------------|
+|          greedy search             | 15.52      | 15.28      | --epoch 18, --avg 5, --max-duration 200 |
+| modified beam search               | 13.88      | 13.7       | --epoch 18, --avg 5, --max-duration 200 |
+| fast beam search                   | 14.62      | 14.36      | --epoch 18, --avg 5, --max-duration 200|
+
+The training command for reproducing is given below:
+
+```
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+
+  
+./pruned_transducer_stateless5/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --exp-dir pruned_transducer_stateless5/exp \
+  --max-duration 300 \
+  --num-buckets 50
+```
+
+The tensorboard training log can be found at
+https://tensorboard.dev/experiment/YyNv45pfQ0GqWzZ898WOlw/#scalars
+
+The decoding command is:
+```
+epoch=18
+avg=5
+for method in greedy_search modified_beam_search fast_beam_search; do
+  ./pruned_transducer_stateless5/decode.py \
+    --epoch $epoch \
+	--beam-size 10 \
+    --avg $avg \
+    --exp-dir ./pruned_transducer_stateless5/exp \
+    --max-duration 200 \
+    --decoding-method $method \
+    --max-sym-per-frame 1 \
+    --num-encoder-layers 12 \
+    --dim-feedforward 2048 \
+    --nhead 8 \
+    --encoder-dim 512 \
+    --decoder-dim 512 \
+    --joiner-dim 512 \
+    --use-averaged-model True
+done
+```
+
+### MGB2 all data BPE training results (Conformer-CTC) (after 40 epochs)
+
+#### 2022-06-04
+
+You can find a pretrained model, training logs, decoding logs, and decoding results at:
+https://huggingface.co/AmirHussein/icefall-asr-mgb2-conformer_ctc-2022-27-06
+
+The best WER, as of 2022-06-04, for the MGB2 test dataset is below
+
+Using whole lattice HLG decoding + n-gram LM rescoring 
+
+|     | dev        | test       |
+|-----|------------|------------|
+| WER | 15.62      |  15.01     |
+
+Scale values used in n-gram LM rescoring and attention rescoring for the best WERs are:
+| ngram_lm_scale | attention_scale |
+|----------------|-----------------|
+| 0.1            | -            |
+
+
+Using n-best (n=0.5) attention decoder rescoring
+
+|     | dev        | test       |
+|-----|------------|------------|
+| WER |    15.89   |  15.08     |
+
+Scale values used in n-gram LM rescoring and attention rescoring for the best WERs are:
+| ngram_lm_scale | attention_scale |
+|----------------|-----------------|
+| 0.01           | 0.5             |
+
+
+To reproduce the above result, use the following commands for training:
+
+# Note: the model was trained on V-100 32GB GPU
+
+```
+cd egs/mgb2/ASR
+. ./path.sh
+./prepare.sh
+export CUDA_VISIBLE_DEVICES="0,1"
+./conformer_ctc/train.py \
+  --lang-dir data/lang_bpe_5000 \
+  --att-rate 0.8 \
+  --lr-factor 10 \
+  --max-duration  \
+  --concatenate-cuts 0 \
+  --world-size 2 \
+  --bucketing-sampler 1 \
+  --max-duration 100 \ 
+  --start-epoch 0 \
+  --num-epochs 40
+  
+```
+
+and the following command for nbest decoding
+
+```
+./conformer_ctc/decode.py \
+  --lang-dir data/lang_bpe_5000 \
+  --max-duration 30 \
+  --concatenate-cuts 0 \
+  --bucketing-sampler 1 \
+  --num-paths 1000 \
+  --epoch 40 \
+  --avg 5 \
+  --method attention-decoder \
+  --nbest-scale 0.5
+```
+
+and the following command for whole-lattice decoding
+
+```
+./conformer_ctc/decode.py \
+  --epoch 40 \
+  --avg 5 \
+  --exp-dir conformer_ctc/exp_5000_att0.8 \
+  --lang-dir data/lang_bpe_5000 \
+  --max-duration 30 \
+  --concatenate-cuts 0 \
+  --bucketing-sampler 1 \
+  --num-paths 1000 \
+  --method  whole-lattice-rescoring
+```
+
+
+The tensorboard log for training is available at
+https://tensorboard.dev/experiment/QYNzOi52RwOX8yvtpl3hMw/#scalars
+
+
+### MGB2 100h BPE training results (Conformer-CTC) (after 33 epochs)
+
+#### 2022-06-04
+
+The best WER, as of 2022-06-04, for the MGB2 test dataset is below
+
+Using whole lattice HLG decoding + n-gram LM rescoring 
+
+|     | dev        | test       |
+|-----|------------|------------|
+| WER | 25.32      |  23.53     |
+
+Scale values used in n-gram LM rescoring and attention rescoring for the best WERs are:
+| ngram_lm_scale | attention_scale |
+|----------------|-----------------|
+| 0.1            | -            |
+
+
+Using n-best (n=0.5) HLG decoding + n-gram LM rescoring + attention decoder rescoring:
+
+|     | dev        | test       |
+|-----|------------|------------|
+| WER |    27.87   |  26.12     |
+
+Scale values used in n-gram LM rescoring and attention rescoring for the best WERs are:
+| ngram_lm_scale | attention_scale |
+|----------------|-----------------|
+| 0.01           | 0.3             |
+
+
+To reproduce the above result, use the following commands for training:
+
+# Note: the model was trained on V-100 32GB GPU
+
+```
+cd egs/mgb2/ASR
+. ./path.sh
+./prepare.sh
+export CUDA_VISIBLE_DEVICES="0,1"
+./conformer_ctc/train.py \
+  --lang-dir data/lang_bpe_5000 \
+  --att-rate 0.8 \
+  --lr-factor 10 \
+  --max-duration  \
+  --concatenate-cuts 0 \
+  --world-size 2 \
+  --bucketing-sampler 1 \
+  --max-duration 100 \ 
+  --start-epoch 0 \
+  --num-epochs 40
+  
+```
+
+and the following command for nbest decoding
+
+```
+./conformer_ctc/decode.py \
+  --lang-dir data/lang_bpe_5000 \
+  --max-duration 30 \
+  --concatenate-cuts 0 \
+  --bucketing-sampler 1 \
+  --num-paths 1000 \
+  --epoch 40 \
+  --avg 5 \
+  --method attention-decoder \
+  --nbest-scale 0.5
+```
+
+and the following command for whole-lattice decoding
+
+```
+./conformer_ctc/decode.py \
+  --lang-dir data/lang_bpe_5000 \
+  --max-duration 30 \
+  --concatenate-cuts 0 \
+  --bucketing-sampler 1 \
+  --num-paths 1000 \
+  --epoch 40 \
+  --avg 5 \
+  --method  whole-lattice-rescoring
+```
+
+The tensorboard log for training is available at
+
+
+
+
+
diff --git a/egs/mgb2/ASR/conformer_ctc/__init__.py b/egs/mgb2/ASR/conformer_ctc/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/egs/mgb2/ASR/conformer_ctc/ali.py b/egs/mgb2/ASR/conformer_ctc/ali.py
new file mode 100755
index 000000000..aea962dcd
--- /dev/null
+++ b/egs/mgb2/ASR/conformer_ctc/ali.py
@@ -0,0 +1,395 @@
+#!/usr/bin/env python3
+# Copyright    2021  Xiaomi Corp.        (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Usage:
+    ./conformer_ctc/ali.py \
+            --exp-dir ./conformer_ctc/exp \
+            --lang-dir ./data/lang_bpe_500 \
+            --epoch 20 \
+            --avg 10 \
+            --max-duration 300 \
+            --dataset train-clean-100 \
+            --out-dir data/ali
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+import k2
+import numpy as np
+import torch
+from asr_datamodule import LibriSpeechAsrDataModule
+from conformer import Conformer
+from lhotse import CutSet
+from lhotse.features.io import FeaturesWriter, NumpyHdf5Writer
+
+from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
+from icefall.checkpoint import average_checkpoints, load_checkpoint
+from icefall.decode import one_best_decoding
+from icefall.env import get_env_info
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+    AttributeDict,
+    encode_supervisions,
+    get_alignments,
+    setup_logger,
+)
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=34,
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
+    )
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=20,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_bpe_500",
+        help="The lang dir",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="conformer_ctc/exp",
+        help="The experiment dir",
+    )
+
+    parser.add_argument(
+        "--out-dir",
+        type=str,
+        required=True,
+        help="""Output directory.
+        It contains 3 generated files:
+
+        - labels_xxx.h5
+        - aux_labels_xxx.h5
+        - cuts_xxx.json.gz
+
+        where xxx is the value of `--dataset`. For instance, if
+        `--dataset` is `train-clean-100`, it will contain 3 files:
+
+        - `labels_train-clean-100.h5`
+        - `aux_labels_train-clean-100.h5`
+        - `cuts_train-clean-100.json.gz`
+
+        Note: Both labels_xxx.h5 and aux_labels_xxx.h5 contain framewise
+        alignment. The difference is that labels_xxx.h5 contains repeats.
+        """,
+    )
+
+    parser.add_argument(
+        "--dataset",
+        type=str,
+        required=True,
+        help="""The name of the dataset to compute alignments for.
+        Possible values are:
+            - test-clean.
+            - test-other
+            - train-clean-100
+            - train-clean-360
+            - train-other-500
+            - dev-clean
+            - dev-other
+        """,
+    )
+    return parser
+
+
+def get_params() -> AttributeDict:
+    params = AttributeDict(
+        {
+            "lm_dir": Path("data/lm"),
+            "feature_dim": 80,
+            "nhead": 8,
+            "attention_dim": 512,
+            "subsampling_factor": 4,
+            # Set it to 0 since attention decoder
+            # is not used for computing alignments
+            "num_decoder_layers": 0,
+            "vgg_frontend": False,
+            "use_feat_batchnorm": True,
+            "output_beam": 10,
+            "use_double_scores": True,
+            "env_info": get_env_info(),
+        }
+    )
+    return params
+
+
+def compute_alignments(
+    model: torch.nn.Module,
+    dl: torch.utils.data.DataLoader,
+    labels_writer: FeaturesWriter,
+    aux_labels_writer: FeaturesWriter,
+    params: AttributeDict,
+    graph_compiler: BpeCtcTrainingGraphCompiler,
+) -> CutSet:
+    """Compute the framewise alignments of a dataset.
+
+    Args:
+      model:
+        The neural network model.
+      dl:
+        Dataloader containing the dataset.
+      params:
+        Parameters for computing alignments.
+      graph_compiler:
+        It converts token IDs to decoding graphs.
+    Returns:
+      Return a CutSet. Each cut has two custom fields: labels_alignment
+      and aux_labels_alignment, containing framewise alignments information.
+      Both are of type `lhotse.array.TemporalArray`. The difference between
+      the two alignments is that `labels_alignment` contain repeats.
+    """
+    try:
+        num_batches = len(dl)
+    except TypeError:
+        num_batches = "?"
+    num_cuts = 0
+
+    device = graph_compiler.device
+    cuts = []
+    for batch_idx, batch in enumerate(dl):
+        feature = batch["inputs"]
+
+        # at entry, feature is [N, T, C]
+        assert feature.ndim == 3
+        feature = feature.to(device)
+
+        supervisions = batch["supervisions"]
+        cut_list = supervisions["cut"]
+
+        for cut in cut_list:
+            assert len(cut.supervisions) == 1, f"{len(cut.supervisions)}"
+
+        nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
+        # nnet_output is [N, T, C]
+        supervision_segments, texts = encode_supervisions(
+            supervisions, subsampling_factor=params.subsampling_factor
+        )
+        # we need also to sort cut_ids as encode_supervisions()
+        # reorders "texts".
+        # In general, new2old is an identity map since lhotse sorts the returned
+        # cuts by duration in descending order
+        new2old = supervision_segments[:, 0].tolist()
+
+        cut_list = [cut_list[i] for i in new2old]
+
+        token_ids = graph_compiler.texts_to_ids(texts)
+        decoding_graph = graph_compiler.compile(token_ids)
+
+        dense_fsa_vec = k2.DenseFsaVec(
+            nnet_output,
+            supervision_segments,
+            allow_truncate=params.subsampling_factor - 1,
+        )
+
+        lattice = k2.intersect_dense(
+            decoding_graph,
+            dense_fsa_vec,
+            params.output_beam,
+        )
+
+        best_path = one_best_decoding(
+            lattice=lattice,
+            use_double_scores=params.use_double_scores,
+        )
+
+        labels_ali = get_alignments(best_path, kind="labels")
+        aux_labels_ali = get_alignments(best_path, kind="aux_labels")
+        assert len(labels_ali) == len(aux_labels_ali) == len(cut_list)
+        for cut, labels, aux_labels in zip(cut_list, labels_ali, aux_labels_ali):
+            cut.labels_alignment = labels_writer.store_array(
+                key=cut.id,
+                value=np.asarray(labels, dtype=np.int32),
+                # frame shift is 0.01s, subsampling_factor is 4
+                frame_shift=0.04,
+                temporal_dim=0,
+                start=0,
+            )
+            cut.aux_labels_alignment = aux_labels_writer.store_array(
+                key=cut.id,
+                value=np.asarray(aux_labels, dtype=np.int32),
+                # frame shift is 0.01s, subsampling_factor is 4
+                frame_shift=0.04,
+                temporal_dim=0,
+                start=0,
+            )
+
+        cuts += cut_list
+
+        num_cuts += len(cut_list)
+
+        if batch_idx % 100 == 0:
+            batch_str = f"{batch_idx}/{num_batches}"
+
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+
+    return CutSet.from_cuts(cuts)
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    LibriSpeechAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+
+    args.enable_spec_aug = False
+    args.enable_musan = False
+    args.return_cuts = True
+    args.concatenate_cuts = False
+
+    params = get_params()
+    params.update(vars(args))
+
+    setup_logger(f"{params.exp_dir}/log-ali")
+
+    logging.info(f"Computing alignments for {params.dataset} - started")
+    logging.info(params)
+
+    out_dir = Path(params.out_dir)
+    out_dir.mkdir(exist_ok=True)
+
+    out_labels_ali_filename = out_dir / f"labels_{params.dataset}.h5"
+    out_aux_labels_ali_filename = out_dir / f"aux_labels_{params.dataset}.h5"
+    out_manifest_filename = out_dir / f"cuts_{params.dataset}.json.gz"
+
+    for f in (
+        out_labels_ali_filename,
+        out_aux_labels_ali_filename,
+        out_manifest_filename,
+    ):
+        if f.exists():
+            logging.info(f"{f} exists - skipping")
+            return
+
+    lexicon = Lexicon(params.lang_dir)
+    max_token_id = max(lexicon.tokens)
+    num_classes = max_token_id + 1  # +1 for the blank
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+    logging.info(f"device: {device}")
+
+    graph_compiler = BpeCtcTrainingGraphCompiler(
+        params.lang_dir,
+        device=device,
+        sos_token="",
+        eos_token="",
+    )
+
+    logging.info("About to create model")
+    model = Conformer(
+        num_features=params.feature_dim,
+        nhead=params.nhead,
+        d_model=params.attention_dim,
+        num_classes=num_classes,
+        subsampling_factor=params.subsampling_factor,
+        num_decoder_layers=params.num_decoder_layers,
+        vgg_frontend=params.vgg_frontend,
+        use_feat_batchnorm=params.use_feat_batchnorm,
+    )
+    model.to(device)
+
+    if params.avg == 1:
+        load_checkpoint(
+            f"{params.exp_dir}/epoch-{params.epoch}.pt", model, strict=False
+        )
+    else:
+        start = params.epoch - params.avg + 1
+        filenames = []
+        for i in range(start, params.epoch + 1):
+            if start >= 0:
+                filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+        logging.info(f"averaging {filenames}")
+        model.load_state_dict(
+            average_checkpoints(filenames, device=device), strict=False
+        )
+
+    model.eval()
+
+    librispeech = LibriSpeechAsrDataModule(args)
+    if params.dataset == "test-clean":
+        test_clean_cuts = librispeech.test_clean_cuts()
+        dl = librispeech.test_dataloaders(test_clean_cuts)
+    elif params.dataset == "test-other":
+        test_other_cuts = librispeech.test_other_cuts()
+        dl = librispeech.test_dataloaders(test_other_cuts)
+    elif params.dataset == "train-clean-100":
+        train_clean_100_cuts = librispeech.train_clean_100_cuts()
+        dl = librispeech.train_dataloaders(train_clean_100_cuts)
+    elif params.dataset == "train-clean-360":
+        train_clean_360_cuts = librispeech.train_clean_360_cuts()
+        dl = librispeech.train_dataloaders(train_clean_360_cuts)
+    elif params.dataset == "train-other-500":
+        train_other_500_cuts = librispeech.train_other_500_cuts()
+        dl = librispeech.train_dataloaders(train_other_500_cuts)
+    elif params.dataset == "dev-clean":
+        dev_clean_cuts = librispeech.dev_clean_cuts()
+        dl = librispeech.valid_dataloaders(dev_clean_cuts)
+    else:
+        assert params.dataset == "dev-other", f"{params.dataset}"
+        dev_other_cuts = librispeech.dev_other_cuts()
+        dl = librispeech.valid_dataloaders(dev_other_cuts)
+
+    logging.info(f"Processing {params.dataset}")
+    with NumpyHdf5Writer(out_labels_ali_filename) as labels_writer:
+        with NumpyHdf5Writer(out_aux_labels_ali_filename) as aux_labels_writer:
+            cut_set = compute_alignments(
+                model=model,
+                dl=dl,
+                labels_writer=labels_writer,
+                aux_labels_writer=aux_labels_writer,
+                params=params,
+                graph_compiler=graph_compiler,
+            )
+
+    cut_set.to_file(out_manifest_filename)
+
+    logging.info(
+        f"For dataset {params.dataset}, its alignments with repeats are "
+        f"saved to {out_labels_ali_filename}, the alignments without repeats "
+        f"are saved to {out_aux_labels_ali_filename}, and the cut manifest "
+        f"file is {out_manifest_filename}. Number of cuts: {len(cut_set)}"
+    )
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/mgb2/ASR/conformer_ctc/asr_datamodule.py b/egs/mgb2/ASR/conformer_ctc/asr_datamodule.py
new file mode 100644
index 000000000..8242e986d
--- /dev/null
+++ b/egs/mgb2/ASR/conformer_ctc/asr_datamodule.py
@@ -0,0 +1,372 @@
+# Copyright 2022 Johns Hopkins University  (Amir Hussein)
+# Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
+
+
+import argparse
+import inspect
+import logging
+from functools import lru_cache
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+import torch
+from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
+from lhotse.dataset import (
+    CutConcatenate,
+    CutMix,
+    DynamicBucketingSampler,
+    K2SpeechRecognitionDataset,
+    PrecomputedFeatures,
+    SingleCutSampler,
+    SpecAugment,
+)
+from lhotse.dataset.input_strategies import OnTheFlyFeatures
+from lhotse.utils import fix_random_seed
+from torch.utils.data import DataLoader
+
+from icefall.utils import str2bool
+
+
+class _SeedWorkers:
+    def __init__(self, seed: int):
+        self.seed = seed
+
+    def __call__(self, worker_id: int):
+        fix_random_seed(self.seed + worker_id)
+
+
+class MGB2AsrDataModule:
+
+    """
+    DataModule for k2 ASR experiments.
+    It assumes there is always one train and valid dataloader,
+    but there can be multiple test dataloaders
+
+    It contains all the common data pipeline modules used in ASR
+    experiments, e.g.:
+    - dynamic batch size,
+    - bucketing samplers,
+    - cut concatenation,
+    - augmentation,
+    - on-the-fly feature extraction
+
+    This class should be derived for specific corpora used in ASR tasks.
+    """
+
+    def __init__(self, args: argparse.Namespace):
+        self.args = args
+
+    @classmethod
+    def add_arguments(cls, parser: argparse.ArgumentParser):
+        group = parser.add_argument_group(
+            title="ASR data related options",
+            description="These options are used for the preparation of "
+            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+            "effective batch sizes, sampling strategies, applied data "
+            "augmentations, etc.",
+        )
+        group.add_argument(
+            "--manifest-dir",
+            type=Path,
+            default=Path("data/fbank"),
+            help="Path to directory with train/valid/test cuts.",
+        )
+        group.add_argument(
+            "--max-duration",
+            type=int,
+            default=200.0,
+            help="Maximum pooled recordings duration (seconds) in a "
+            "single batch. You can reduce it if it causes CUDA OOM.",
+        )
+        group.add_argument(
+            "--bucketing-sampler",
+            type=str2bool,
+            default=True,
+            help="When enabled, the batches will come from buckets of "
+            "similar duration (saves padding frames).",
+        )
+        group.add_argument(
+            "--num-buckets",
+            type=int,
+            default=30,
+            help="The number of buckets for the DynamicBucketingSampler"
+            "(you might want to increase it for larger datasets).",
+        )
+        group.add_argument(
+            "--concatenate-cuts",
+            type=str2bool,
+            default=False,
+            help="When enabled, utterances (cuts) will be concatenated "
+            "to minimize the amount of padding.",
+        )
+        group.add_argument(
+            "--duration-factor",
+            type=float,
+            default=1.0,
+            help="Determines the maximum duration of a concatenated cut "
+            "relative to the duration of the longest cut in a batch.",
+        )
+        group.add_argument(
+            "--gap",
+            type=float,
+            default=1.0,
+            help="The amount of padding (in seconds) inserted between "
+            "concatenated cuts. This padding is filled with noise when "
+            "noise augmentation is used.",
+        )
+        group.add_argument(
+            "--on-the-fly-feats",
+            type=str2bool,
+            default=False,
+            help="When enabled, use on-the-fly cut mixing and feature "
+            "extraction. Will drop existing precomputed feature manifests "
+            "if available.",
+        )
+        group.add_argument(
+            "--shuffle",
+            type=str2bool,
+            default=True,
+            help="When enabled (=default), the examples will be "
+            "shuffled for each epoch.",
+        )
+        group.add_argument(
+            "--drop-last",
+            type=str2bool,
+            default=True,
+            help="Whether to drop last batch. Used by sampler.",
+        )
+        group.add_argument(
+            "--return-cuts",
+            type=str2bool,
+            default=True,
+            help="When enabled, each batch will have the "
+            "field: batch['supervisions']['cut'] with the cuts that "
+            "were used to construct it.",
+        )
+
+        group.add_argument(
+            "--num-workers",
+            type=int,
+            default=1,
+            help="The number of training dataloader workers that "
+            "collect the batches.",
+        )
+
+        group.add_argument(
+            "--enable-spec-aug",
+            type=str2bool,
+            default=True,
+            help="When enabled, use SpecAugment for training dataset.",
+        )
+
+        group.add_argument(
+            "--spec-aug-time-warp-factor",
+            type=int,
+            default=80,
+            help="Used only when --enable-spec-aug is True. "
+            "It specifies the factor for time warping in SpecAugment. "
+            "Larger values mean more warping. "
+            "A value less than 1 means to disable time warp.",
+        )
+
+        group.add_argument(
+            "--enable-musan",
+            type=str2bool,
+            default=True,
+            help="When enabled, select noise from MUSAN and mix it"
+            "with training dataset. ",
+        )
+
+    def train_dataloaders(
+        self,
+        cuts_train: CutSet,
+        sampler_state_dict: Optional[Dict[str, Any]] = None,
+    ) -> DataLoader:
+
+        transforms = []
+        if self.args.enable_musan:
+            logging.info("Enable MUSAN")
+            logging.info("About to get Musan cuts")
+            cuts_musan = load_manifest(self.args.manifest_dir / "cuts_musan.jsonl.gz")
+
+            transforms.append(
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+            )
+        else:
+            logging.info("Disable MUSAN")
+
+        if self.args.concatenate_cuts:
+            logging.info(
+                f"Using cut concatenation with duration factor "
+                f"{self.args.duration_factor} and gap {self.args.gap}."
+            )
+            # Cut concatenation should be the first transform in the list,
+            # so that if we e.g. mix noise in, it will fill the gaps between
+            # different utterances.
+            transforms = [
+                CutConcatenate(
+                    duration_factor=self.args.duration_factor, gap=self.args.gap
+                )
+            ] + transforms
+
+        input_transforms = []
+        if self.args.enable_spec_aug:
+            logging.info("Enable SpecAugment")
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+            # Set the value of num_frame_masks according to Lhotse's version.
+            # In different Lhotse's versions, the default of num_frame_masks is
+            # different.
+            num_frame_masks = 10
+            num_frame_masks_parameter = inspect.signature(
+                SpecAugment.__init__
+            ).parameters["num_frame_masks"]
+            if num_frame_masks_parameter.default == 1:
+                num_frame_masks = 2
+            logging.info(f"Num frame mask: {num_frame_masks}")
+            input_transforms.append(
+                SpecAugment(
+                    time_warp_factor=self.args.spec_aug_time_warp_factor,
+                    num_frame_masks=num_frame_masks,
+                    features_mask_size=27,
+                    num_feature_masks=2,
+                    frames_mask_size=100,
+                )
+            )
+        else:
+            logging.info("Disable SpecAugment")
+
+        logging.info("About to create train dataset")
+        train = K2SpeechRecognitionDataset(
+            cut_transforms=transforms,
+            input_transforms=input_transforms,
+            return_cuts=self.args.return_cuts,
+        )
+
+        if self.args.on_the_fly_feats:
+            # NOTE: the PerturbSpeed transform should be added only if we
+            # remove it from data prep stage.
+            # Add on-the-fly speed perturbation; since originally it would
+            # have increased epoch size by 3, we will apply prob 2/3 and use
+            # 3x more epochs.
+            # Speed perturbation probably should come first before
+            # concatenation, but in principle the transforms order doesn't have
+            # to be strict (e.g. could be randomized)
+            # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms   # noqa
+            # Drop feats to be on the safe side.
+            train = K2SpeechRecognitionDataset(
+                cut_transforms=transforms,
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_transforms=input_transforms,
+                return_cuts=self.args.return_cuts,
+            )
+
+        if self.args.bucketing_sampler:
+            logging.info("Using DynamicBucketingSampler.")
+            train_sampler = DynamicBucketingSampler(
+                cuts_train,
+                max_duration=self.args.max_duration,
+                shuffle=self.args.shuffle,
+                num_buckets=self.args.num_buckets,
+                drop_last=self.args.drop_last,
+            )
+        else:
+            logging.info("Using SingleCutSampler.")
+            train_sampler = SingleCutSampler(
+                cuts_train,
+                max_duration=self.args.max_duration,
+                shuffle=self.args.shuffle,
+            )
+        logging.info("About to create train dataloader")
+
+        if sampler_state_dict is not None:
+            logging.info("Loading sampler state dict")
+            train_sampler.load_state_dict(sampler_state_dict)
+        # 'seed' is derived from the current random state, which will have
+        # previously been set in the main process.
+        seed = torch.randint(0, 100000, ()).item()
+        worker_init_fn = _SeedWorkers(seed)
+
+        train_dl = DataLoader(
+            train,
+            sampler=train_sampler,
+            batch_size=None,
+            num_workers=self.args.num_workers,
+            persistent_workers=False,
+            worker_init_fn=worker_init_fn,
+        )
+
+        return train_dl
+
+    def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
+        transforms = []
+        if self.args.concatenate_cuts:
+            transforms = [
+                CutConcatenate(
+                    duration_factor=self.args.duration_factor, gap=self.args.gap
+                )
+            ] + transforms
+
+        logging.info("About to create dev dataset")
+        if self.args.on_the_fly_feats:
+            validate = K2SpeechRecognitionDataset(
+                cut_transforms=transforms,
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                return_cuts=self.args.return_cuts,
+            )
+        else:
+            validate = K2SpeechRecognitionDataset(
+                cut_transforms=transforms,
+                return_cuts=self.args.return_cuts,
+            )
+        valid_sampler = DynamicBucketingSampler(
+            cuts_valid,
+            max_duration=self.args.max_duration,
+            shuffle=False,
+        )
+        logging.info("About to create dev dataloader")
+        valid_dl = DataLoader(
+            validate,
+            sampler=valid_sampler,
+            batch_size=None,
+            num_workers=2,
+            persistent_workers=False,
+        )
+
+        return valid_dl
+
+    def test_dataloaders(self, cuts: CutSet) -> DataLoader:
+        logging.debug("About to create test dataset")
+        test = K2SpeechRecognitionDataset(
+            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
+            if self.args.on_the_fly_feats
+            else PrecomputedFeatures(),
+            return_cuts=self.args.return_cuts,
+        )
+        sampler = DynamicBucketingSampler(
+            cuts, max_duration=self.args.max_duration, shuffle=False
+        )
+        logging.debug("About to create test dataloader")
+        test_dl = DataLoader(
+            test,
+            batch_size=None,
+            sampler=sampler,
+            num_workers=self.args.num_workers,
+        )
+        return test_dl
+
+    @lru_cache()
+    def train_cuts(self) -> CutSet:
+        logging.info("About to get train cuts")
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_train_shuf.jsonl.gz")
+
+    @lru_cache()
+    def dev_cuts(self) -> CutSet:
+        logging.info("About to get dev cuts")
+
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_dev.jsonl.gz")
+
+    @lru_cache()
+    def test_cuts(self) -> CutSet:
+        logging.info("About to get test cuts")
+
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_test.jsonl.gz")
diff --git a/egs/mgb2/ASR/conformer_ctc/compile_hlg.py b/egs/mgb2/ASR/conformer_ctc/compile_hlg.py
new file mode 120000
index 000000000..471aa7fb4
--- /dev/null
+++ b/egs/mgb2/ASR/conformer_ctc/compile_hlg.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/compile_hlg.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/conformer_ctc/compute_fbank_musan.py b/egs/mgb2/ASR/conformer_ctc/compute_fbank_musan.py
new file mode 120000
index 000000000..5833f2484
--- /dev/null
+++ b/egs/mgb2/ASR/conformer_ctc/compute_fbank_musan.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/compute_fbank_musan.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/conformer_ctc/conformer.py b/egs/mgb2/ASR/conformer_ctc/conformer.py
new file mode 120000
index 000000000..d1f4209d7
--- /dev/null
+++ b/egs/mgb2/ASR/conformer_ctc/conformer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/conformer_ctc/conformer.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/conformer_ctc/convert_transcript_words_to_tokens.py b/egs/mgb2/ASR/conformer_ctc/convert_transcript_words_to_tokens.py
new file mode 120000
index 000000000..2ce13fd69
--- /dev/null
+++ b/egs/mgb2/ASR/conformer_ctc/convert_transcript_words_to_tokens.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/convert_transcript_words_to_tokens.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/conformer_ctc/decode.py b/egs/mgb2/ASR/conformer_ctc/decode.py
new file mode 100755
index 000000000..f771d7f1e
--- /dev/null
+++ b/egs/mgb2/ASR/conformer_ctc/decode.py
@@ -0,0 +1,695 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import logging
+import pdb
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import MGB2AsrDataModule
+from conformer import Conformer
+
+from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
+from icefall.checkpoint import average_checkpoints, load_checkpoint
+from icefall.decode import (
+    get_lattice,
+    nbest_decoding,
+    nbest_oracle,
+    one_best_decoding,
+    rescore_with_attention_decoder,
+    rescore_with_n_best_list,
+    rescore_with_whole_lattice,
+)
+from icefall.env import get_env_info
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+    AttributeDict,
+    get_texts,
+    setup_logger,
+    store_transcripts,
+    write_error_stats,
+)
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=50,
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
+    )
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=5,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
+    )
+
+    parser.add_argument(
+        "--method",
+        type=str,
+        default="attention-decoder",
+        help="""Decoding method.
+        Supported values are:
+            - (0) ctc-decoding. Use CTC decoding. It uses a sentence piece
+              model, i.e., lang_dir/bpe.model, to convert word pieces to words.
+              It needs neither a lexicon nor an n-gram LM.
+            - (1) 1best. Extract the best path from the decoding lattice as the
+              decoding result.
+            - (2) nbest. Extract n paths from the decoding lattice; the path
+              with the highest score is the decoding result.
+            - (3) nbest-rescoring. Extract n paths from the decoding lattice,
+              rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
+              the highest score is the decoding result.
+            - (4) whole-lattice-rescoring. Rescore the decoding lattice with an
+              n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
+              is the decoding result.
+            - (5) attention-decoder. Extract n paths from the LM rescored
+              lattice, the path with the highest score is the decoding result.
+            - (6) nbest-oracle. Its WER is the lower bound of any n-best
+              rescoring method can achieve. Useful for debugging n-best
+              rescoring method.
+        """,
+    )
+
+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=20,
+        help="""Number of paths for n-best based decoding method.
+        Used only when "method" is one of the following values:
+        nbest, nbest-rescoring, attention-decoder, and nbest-oracle
+        """,
+    )
+
+    parser.add_argument(
+        "--nbest-scale",
+        type=float,
+        default=0.5,
+        help="""The scale to be applied to `lattice.scores`.
+        It's needed if you use any kinds of n-best based rescoring.
+        Used only when "method" is one of the following values:
+        nbest, nbest-rescoring, attention-decoder, and nbest-oracle
+        A smaller value results in more unique paths.
+        """,
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="conformer_ctc/exp",
+        help="The experiment dir",
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_bpe_500",
+        help="The lang dir",
+    )
+
+    parser.add_argument(
+        "--lm-dir",
+        type=str,
+        default="data/lm",
+        help="""The LM dir.
+        It should contain either G_4_gram.pt or G_4_gram.fst.txt
+        """,
+    )
+
+    return parser
+
+
+def get_params() -> AttributeDict:
+    params = AttributeDict(
+        {
+            # parameters for conformer
+            "subsampling_factor": 4,
+            "vgg_frontend": False,
+            "use_feat_batchnorm": True,
+            "feature_dim": 80,
+            "nhead": 8,
+            "attention_dim": 512,
+            "num_decoder_layers": 6,
+            # parameters for decoding
+            "search_beam": 20,
+            "output_beam": 8,
+            "min_active_states": 30,
+            "max_active_states": 10000,
+            "use_double_scores": True,
+            "env_info": get_env_info(),
+        }
+    )
+    return params
+
+
+def decode_one_batch(
+    params: AttributeDict,
+    model: nn.Module,
+    HLG: Optional[k2.Fsa],
+    H: Optional[k2.Fsa],
+    bpe_model: Optional[spm.SentencePieceProcessor],
+    batch: dict,
+    word_table: k2.SymbolTable,
+    sos_id: int,
+    eos_id: int,
+    G: Optional[k2.Fsa] = None,
+) -> Dict[str, List[List[str]]]:
+    """Decode one batch and return the result in a dict. The dict has the
+    following format:
+
+        - key: It indicates the setting used for decoding. For example,
+               if no rescoring is used, the key is the string `no_rescore`.
+               If LM rescoring is used, the key is the string `lm_scale_xxx`,
+               where `xxx` is the value of `lm_scale`. An example key is
+               `lm_scale_0.7`
+        - value: It contains the decoding result. `len(value)` equals to
+                 batch size. `value[i]` is the decoding result for the i-th
+                 utterance in the given batch.
+    Args:
+      params:
+        It's the return value of :func:`get_params`.
+
+        - params.method is "1best", it uses 1best decoding without LM rescoring.
+        - params.method is "nbest", it uses nbest decoding without LM rescoring.
+        - params.method is "nbest-rescoring", it uses nbest LM rescoring.
+        - params.method is "whole-lattice-rescoring", it uses whole lattice LM
+          rescoring.
+
+      model:
+        The neural model.
+      HLG:
+        The decoding graph. Used only when params.method is NOT ctc-decoding.
+      H:
+        The ctc topo. Used only when params.method is ctc-decoding.
+      bpe_model:
+        The BPE model. Used only when params.method is ctc-decoding.
+      batch:
+        It is the return value from iterating
+        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+        for the format of the `batch`.
+      word_table:
+        The word symbol table.
+      sos_id:
+        The token ID of the SOS.
+      eos_id:
+        The token ID of the EOS.
+      G:
+        An LM. It is not None when params.method is "nbest-rescoring"
+        or "whole-lattice-rescoring". In general, the G in HLG
+        is a 3-gram LM, while this G is a 4-gram LM.
+    Returns:
+      Return the decoding result. See above description for the format of
+      the returned dict. Note: If it decodes to nothing, then return None.
+    """
+    if HLG is not None:
+        device = HLG.device
+    else:
+        device = H.device
+    feature = batch["inputs"]
+    assert feature.ndim == 3
+    feature = feature.to(device)
+    # at entry, feature is (N, T, C)
+
+    supervisions = batch["supervisions"]
+
+    nnet_output, memory, memory_key_padding_mask = model(feature, supervisions)
+    # nnet_output is (N, T, C)
+
+    supervision_segments = torch.stack(
+        (
+            supervisions["sequence_idx"],
+            supervisions["start_frame"] // params.subsampling_factor,
+            supervisions["num_frames"] // params.subsampling_factor,
+        ),
+        1,
+    ).to(torch.int32)
+
+    if H is None:
+        assert HLG is not None
+        decoding_graph = HLG
+    else:
+        assert HLG is None
+        assert bpe_model is not None
+        decoding_graph = H
+
+    lattice = get_lattice(
+        nnet_output=nnet_output,
+        decoding_graph=decoding_graph,
+        supervision_segments=supervision_segments,
+        search_beam=params.search_beam,
+        output_beam=params.output_beam,
+        min_active_states=params.min_active_states,
+        max_active_states=params.max_active_states,
+        subsampling_factor=params.subsampling_factor,
+    )
+
+    if params.method == "ctc-decoding":
+        best_path = one_best_decoding(
+            lattice=lattice, use_double_scores=params.use_double_scores
+        )
+        # Note: `best_path.aux_labels` contains token IDs, not word IDs
+        # since we are using H, not HLG here.
+        #
+        # token_ids is a lit-of-list of IDs
+        token_ids = get_texts(best_path)
+
+        # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
+        hyps = bpe_model.decode(token_ids)
+
+        # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
+        hyps = [s.split() for s in hyps]
+        key = "ctc-decoding"
+        return {key: hyps}
+
+    if params.method == "nbest-oracle":
+        # Note: You can also pass rescored lattices to it.
+        # We choose the HLG decoded lattice for speed reasons
+        # as HLG decoding is faster and the oracle WER
+        # is only slightly worse than that of rescored lattices.
+        best_path = nbest_oracle(
+            lattice=lattice,
+            num_paths=params.num_paths,
+            ref_texts=supervisions["text"],
+            word_table=word_table,
+            nbest_scale=params.nbest_scale,
+            oov="",
+        )
+        hyps = get_texts(best_path)
+        hyps = [[word_table[i] for i in ids] for ids in hyps]
+        key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}"  # noqa
+        return {key: hyps}
+
+    if params.method in ["1best", "nbest"]:
+        if params.method == "1best":
+            best_path = one_best_decoding(
+                lattice=lattice, use_double_scores=params.use_double_scores
+            )
+            key = "no_rescore"
+        else:
+            best_path = nbest_decoding(
+                lattice=lattice,
+                num_paths=params.num_paths,
+                use_double_scores=params.use_double_scores,
+                nbest_scale=params.nbest_scale,
+            )
+            key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}"  # noqa
+
+        hyps = get_texts(best_path)
+        hyps = [[word_table[i] for i in ids] for ids in hyps]
+        return {key: hyps}
+
+    assert params.method in [
+        "nbest-rescoring",
+        "whole-lattice-rescoring",
+        "attention-decoder",
+    ]
+
+    lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
+    lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
+    lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
+
+    if params.method == "nbest-rescoring":
+        best_path_dict = rescore_with_n_best_list(
+            lattice=lattice,
+            G=G,
+            num_paths=params.num_paths,
+            lm_scale_list=lm_scale_list,
+            nbest_scale=params.nbest_scale,
+        )
+    elif params.method == "whole-lattice-rescoring":
+        best_path_dict = rescore_with_whole_lattice(
+            lattice=lattice,
+            G_with_epsilon_loops=G,
+            lm_scale_list=lm_scale_list,
+        )
+    elif params.method == "attention-decoder":
+        # lattice uses a 3-gram Lm. We rescore it with a 4-gram LM.
+        rescored_lattice = rescore_with_whole_lattice(
+            lattice=lattice,
+            G_with_epsilon_loops=G,
+            lm_scale_list=None,
+        )
+        # TODO: pass `lattice` instead of `rescored_lattice` to
+        # `rescore_with_attention_decoder`
+
+        best_path_dict = rescore_with_attention_decoder(
+            lattice=rescored_lattice,
+            num_paths=params.num_paths,
+            model=model,
+            memory=memory,
+            memory_key_padding_mask=memory_key_padding_mask,
+            sos_id=sos_id,
+            eos_id=eos_id,
+            nbest_scale=params.nbest_scale,
+        )
+    else:
+        assert False, f"Unsupported decoding method: {params.method}"
+
+    ans = dict()
+    if best_path_dict is not None:
+        for lm_scale_str, best_path in best_path_dict.items():
+            hyps = get_texts(best_path)
+            hyps = [[word_table[i] for i in ids] for ids in hyps]
+            ans[lm_scale_str] = hyps
+    else:
+        ans = None
+    return ans
+
+
+def decode_dataset(
+    dl: torch.utils.data.DataLoader,
+    params: AttributeDict,
+    model: nn.Module,
+    HLG: Optional[k2.Fsa],
+    H: Optional[k2.Fsa],
+    bpe_model: Optional[spm.SentencePieceProcessor],
+    word_table: k2.SymbolTable,
+    sos_id: int,
+    eos_id: int,
+    G: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[List[str], List[str]]]]:
+    """Decode dataset.
+
+    Args:
+      dl:
+        PyTorch's dataloader containing the dataset to decode.
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The neural model.
+      HLG:
+        The decoding graph. Used only when params.method is NOT ctc-decoding.
+      H:
+        The ctc topo. Used only when params.method is ctc-decoding.
+      bpe_model:
+        The BPE model. Used only when params.method is ctc-decoding.
+      word_table:
+        It is the word symbol table.
+      sos_id:
+        The token ID for SOS.
+      eos_id:
+        The token ID for EOS.
+      G:
+        An LM. It is not None when params.method is "nbest-rescoring"
+        or "whole-lattice-rescoring". In general, the G in HLG
+        is a 3-gram LM, while this G is a 4-gram LM.
+    Returns:
+      Return a dict, whose key may be "no-rescore" if no LM rescoring
+      is used, or it may be "lm_scale_0.7" if LM rescoring is used.
+      Its value is a list of tuples. Each tuple contains two elements:
+      The first is the reference transcript, and the second is the
+      predicted result.
+    """
+    num_cuts = 0
+
+    try:
+        num_batches = len(dl)
+    except TypeError:
+        num_batches = "?"
+
+    results = defaultdict(list)
+    for batch_idx, batch in enumerate(dl):
+        # pdb.set_trace()
+        texts = batch["supervisions"]["text"]
+
+        hyps_dict = decode_one_batch(
+            params=params,
+            model=model,
+            HLG=HLG,
+            H=H,
+            bpe_model=bpe_model,
+            batch=batch,
+            word_table=word_table,
+            G=G,
+            sos_id=sos_id,
+            eos_id=eos_id,
+        )
+
+        if hyps_dict is not None:
+            for lm_scale, hyps in hyps_dict.items():
+                this_batch = []
+                assert len(hyps) == len(texts)
+                for hyp_words, ref_text in zip(hyps, texts):
+                    ref_words = ref_text.split()
+                    this_batch.append((ref_words, hyp_words))
+
+                results[lm_scale].extend(this_batch)
+        else:
+            assert len(results) > 0, "It should not decode to empty in the first batch!"
+            this_batch = []
+            hyp_words = []
+            for ref_text in texts:
+                ref_words = ref_text.split()
+                this_batch.append((ref_words, hyp_words))
+
+            for lm_scale in results.keys():
+                results[lm_scale].extend(this_batch)
+
+        num_cuts += len(texts)
+
+        if batch_idx % 100 == 0:
+            batch_str = f"{batch_idx}/{num_batches}"
+
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+    return results
+
+
+def save_results(
+    params: AttributeDict,
+    test_set_name: str,
+    results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
+):
+    if params.method == "attention-decoder":
+        # Set it to False since there are too many logs.
+        enable_log = False
+    else:
+        enable_log = True
+    test_set_wers = dict()
+    for key, results in results_dict.items():
+        recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
+        store_transcripts(filename=recog_path, texts=results)
+        if enable_log:
+            logging.info(f"The transcripts are stored in {recog_path}")
+
+        # The following prints out WERs, per-word error statistics and aligned
+        # ref/hyp pairs.
+        errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt"
+        with open(errs_filename, "w") as f:
+            wer = write_error_stats(
+                f, f"{test_set_name}-{key}", results, enable_log=enable_log
+            )
+            test_set_wers[key] = wer
+
+        if enable_log:
+            logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+    errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"
+    with open(errs_info, "w") as f:
+        print("settings\tWER", file=f)
+        for key, val in test_set_wers:
+            print("{}\t{}".format(key, val), file=f)
+
+    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+    note = "\tbest for {}".format(test_set_name)
+    for key, val in test_set_wers:
+        s += "{}\t{}{}\n".format(key, val, note)
+        note = ""
+    logging.info(s)
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    MGB2AsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+    args.lang_dir = Path(args.lang_dir)
+    args.lm_dir = Path(args.lm_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode")
+    logging.info("Decoding started")
+    logging.info(params)
+
+    lexicon = Lexicon(params.lang_dir)
+    max_token_id = max(lexicon.tokens)
+    num_classes = max_token_id + 1  # +1 for the blank
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    graph_compiler = BpeCtcTrainingGraphCompiler(
+        params.lang_dir,
+        device=device,
+        sos_token="",
+        eos_token="",
+    )
+    sos_id = graph_compiler.sos_id
+    eos_id = graph_compiler.eos_id
+
+    if params.method == "ctc-decoding":
+        HLG = None
+        H = k2.ctc_topo(
+            max_token=max_token_id,
+            modified=False,
+            device=device,
+        )
+        bpe_model = spm.SentencePieceProcessor()
+        bpe_model.load(str(params.lang_dir / "bpe.model"))
+    else:
+        H = None
+        bpe_model = None
+        HLG = k2.Fsa.from_dict(
+            torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
+        )
+        assert HLG.requires_grad is False
+
+        if not hasattr(HLG, "lm_scores"):
+            HLG.lm_scores = HLG.scores.clone()
+
+    if params.method in (
+        "nbest-rescoring",
+        "whole-lattice-rescoring",
+        "attention-decoder",
+    ):
+        if not (params.lm_dir / "G_4_gram.pt").is_file():
+            logging.info("Loading G_4_gram.fst.txt")
+            logging.warning("It may take 8 minutes.")
+            with open(params.lm_dir / "G_4_gram.fst.txt") as f:
+                first_word_disambig_id = lexicon.word_table["#0"]
+
+                G = k2.Fsa.from_openfst(f.read(), acceptor=False)
+                # G.aux_labels is not needed in later computations, so
+                # remove it here.
+                del G.aux_labels
+                # CAUTION: The following line is crucial.
+                # Arcs entering the back-off state have label equal to #0.
+                # We have to change it to 0 here.
+                G.labels[G.labels >= first_word_disambig_id] = 0
+                # See https://github.com/k2-fsa/k2/issues/874
+                # for why we need to set G.properties to None
+                G.__dict__["_properties"] = None
+                G = k2.Fsa.from_fsas([G]).to(device)
+                G = k2.arc_sort(G)
+                # Save a dummy value so that it can be loaded in C++.
+                # See https://github.com/pytorch/pytorch/issues/67902
+                # for why we need to do this.
+                G.dummy = 1
+
+                torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
+        else:
+            logging.info("Loading pre-compiled G_4_gram.pt")
+            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
+            G = k2.Fsa.from_dict(d)
+
+        if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
+            # Add epsilon self-loops to G as we will compose
+            # it with the whole lattice later
+            G = k2.add_epsilon_self_loops(G)
+            G = k2.arc_sort(G)
+            G = G.to(device)
+
+        # G.lm_scores is used to replace HLG.lm_scores during
+        # LM rescoring.
+        G.lm_scores = G.scores.clone()
+    else:
+        G = None
+
+    model = Conformer(
+        num_features=params.feature_dim,
+        nhead=params.nhead,
+        d_model=params.attention_dim,
+        num_classes=num_classes,
+        subsampling_factor=params.subsampling_factor,
+        num_decoder_layers=params.num_decoder_layers,
+        vgg_frontend=params.vgg_frontend,
+        use_feat_batchnorm=params.use_feat_batchnorm,
+    )
+
+    if params.avg == 1:
+        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+    else:
+        start = params.epoch - params.avg + 1
+        filenames = []
+        for i in range(start, params.epoch + 1):
+            if start >= 0:
+                filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+        logging.info(f"averaging {filenames}")
+        model.to(device)
+        model.load_state_dict(average_checkpoints(filenames, device=device))
+
+    model.to(device)
+    model.eval()
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    MGB2 = MGB2AsrDataModule(args)
+
+    test_cuts = MGB2.test_cuts()
+    dev_cuts = MGB2.dev_cuts()
+
+    test_dl = MGB2.test_dataloaders(test_cuts)
+    dev_dl = MGB2.test_dataloaders(dev_cuts)
+
+    test_sets = ["test", "dev"]
+    test_all_dl = [test_dl, dev_dl]
+
+    for test_set, test_dl in zip(test_sets, test_all_dl):
+        results_dict = decode_dataset(
+            dl=test_dl,
+            params=params,
+            model=model,
+            HLG=HLG,
+            H=H,
+            bpe_model=bpe_model,
+            word_table=lexicon.word_table,
+            G=G,
+            sos_id=sos_id,
+            eos_id=eos_id,
+        )
+
+        save_results(params=params, test_set_name=test_set, results_dict=results_dict)
+
+    logging.info("Done!")
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/mgb2/ASR/conformer_ctc/download_lm.py b/egs/mgb2/ASR/conformer_ctc/download_lm.py
new file mode 120000
index 000000000..c9668bd2d
--- /dev/null
+++ b/egs/mgb2/ASR/conformer_ctc/download_lm.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/download_lm.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/conformer_ctc/export.py b/egs/mgb2/ASR/conformer_ctc/export.py
new file mode 120000
index 000000000..60e314d9d
--- /dev/null
+++ b/egs/mgb2/ASR/conformer_ctc/export.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/conformer_ctc/export.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/conformer_ctc/generate_unique_lexicon.py b/egs/mgb2/ASR/conformer_ctc/generate_unique_lexicon.py
new file mode 120000
index 000000000..c0aea1403
--- /dev/null
+++ b/egs/mgb2/ASR/conformer_ctc/generate_unique_lexicon.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/generate_unique_lexicon.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/conformer_ctc/label_smoothing.py b/egs/mgb2/ASR/conformer_ctc/label_smoothing.py
new file mode 120000
index 000000000..e9d239fff
--- /dev/null
+++ b/egs/mgb2/ASR/conformer_ctc/label_smoothing.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/conformer_ctc/label_smoothing.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/conformer_ctc/pretrained.py b/egs/mgb2/ASR/conformer_ctc/pretrained.py
new file mode 100755
index 000000000..d30ca98d8
--- /dev/null
+++ b/egs/mgb2/ASR/conformer_ctc/pretrained.py
@@ -0,0 +1,430 @@
+#!/usr/bin/env python3
+# Copyright      2021  Xiaomi Corp.        (authors: Fangjun Kuang,
+#                                                    Mingshuang Luo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import logging
+import math
+from typing import List
+
+import k2
+import kaldifeat
+import sentencepiece as spm
+import torch
+import torchaudio
+from conformer import Conformer
+from torch.nn.utils.rnn import pad_sequence
+
+from icefall.decode import (
+    get_lattice,
+    one_best_decoding,
+    rescore_with_attention_decoder,
+    rescore_with_whole_lattice,
+)
+from icefall.utils import AttributeDict, get_texts
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--checkpoint",
+        type=str,
+        required=True,
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
+    )
+
+    parser.add_argument(
+        "--words-file",
+        type=str,
+        help="""Path to words.txt.
+        Used only when method is not ctc-decoding.
+        """,
+    )
+
+    parser.add_argument(
+        "--HLG",
+        type=str,
+        help="""Path to HLG.pt.
+        Used only when method is not ctc-decoding.
+        """,
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        help="""Path to bpe.model.
+        Used only when method is ctc-decoding.
+        """,
+    )
+
+    parser.add_argument(
+        "--method",
+        type=str,
+        default="1best",
+        help="""Decoding method.
+        Possible values are:
+        (0) ctc-decoding - Use CTC decoding. It uses a sentence
+            piece model, i.e., lang_dir/bpe.model, to convert
+            word pieces to words. It needs neither a lexicon
+            nor an n-gram LM.
+        (1) 1best - Use the best path as decoding output. Only
+            the transformer encoder output is used for decoding.
+            We call it HLG decoding.
+        (2) whole-lattice-rescoring - Use an LM to rescore the
+            decoding lattice and then use 1best to decode the
+            rescored lattice.
+            We call it HLG decoding + n-gram LM rescoring.
+        (3) attention-decoder - Extract n paths from the rescored
+            lattice and use the transformer attention decoder for
+            rescoring.
+            We call it HLG decoding + n-gram LM rescoring + attention
+            decoder rescoring.
+        """,
+    )
+
+    parser.add_argument(
+        "--G",
+        type=str,
+        help="""An LM for rescoring.
+        Used only when method is
+        whole-lattice-rescoring or attention-decoder.
+        It's usually a 4-gram LM.
+        """,
+    )
+
+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=100,
+        help="""
+        Used only when method is attention-decoder.
+        It specifies the size of n-best list.""",
+    )
+
+    parser.add_argument(
+        "--ngram-lm-scale",
+        type=float,
+        default=1.3,
+        help="""
+        Used only when method is whole-lattice-rescoring and attention-decoder.
+        It specifies the scale for n-gram LM scores.
+        (Note: You need to tune it on a dataset.)
+        """,
+    )
+
+    parser.add_argument(
+        "--attention-decoder-scale",
+        type=float,
+        default=1.2,
+        help="""
+        Used only when method is attention-decoder.
+        It specifies the scale for attention decoder scores.
+        (Note: You need to tune it on a dataset.)
+        """,
+    )
+
+    parser.add_argument(
+        "--nbest-scale",
+        type=float,
+        default=0.5,
+        help="""
+        Used only when method is attention-decoder.
+        It specifies the scale for lattice.scores when
+        extracting n-best lists. A smaller value results in
+        more unique number of paths with the risk of missing
+        the best path.
+        """,
+    )
+
+    parser.add_argument(
+        "--sos-id",
+        type=int,
+        default=1,
+        help="""
+        Used only when method is attention-decoder.
+        It specifies ID for the SOS token.
+        """,
+    )
+
+    parser.add_argument(
+        "--num-classes",
+        type=int,
+        default=500,
+        help="""
+        Vocab size in the BPE model.
+        """,
+    )
+
+    parser.add_argument(
+        "--eos-id",
+        type=int,
+        default=1,
+        help="""
+        Used only when method is attention-decoder.
+        It specifies ID for the EOS token.
+        """,
+    )
+
+    parser.add_argument(
+        "sound_files",
+        type=str,
+        nargs="+",
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
+    )
+
+    return parser
+
+
+def get_params() -> AttributeDict:
+    params = AttributeDict(
+        {
+            "sample_rate": 16000,
+            # parameters for conformer
+            "subsampling_factor": 4,
+            "vgg_frontend": False,
+            "use_feat_batchnorm": True,
+            "feature_dim": 80,
+            "nhead": 8,
+            "attention_dim": 512,
+            "num_decoder_layers": 6,
+            # parameters for decoding
+            "search_beam": 20,
+            "output_beam": 8,
+            "min_active_states": 30,
+            "max_active_states": 10000,
+            "use_double_scores": True,
+        }
+    )
+    return params
+
+
+def read_sound_files(
+    filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+    """Read a list of sound files into a list 1-D float32 torch tensors.
+    Args:
+      filenames:
+        A list of sound filenames.
+      expected_sample_rate:
+        The expected sample rate of the sound files.
+    Returns:
+      Return a list of 1-D float32 torch tensors.
+    """
+    ans = []
+    for f in filenames:
+        wave, sample_rate = torchaudio.load(f)
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
+        )
+        # We use only the first channel
+        ans.append(wave[0])
+    return ans
+
+
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+
+    params = get_params()
+    if args.method != "attention-decoder":
+        # to save memory as the attention decoder
+        # will not be used
+        params.num_decoder_layers = 0
+
+    params.update(vars(args))
+    logging.info(f"{params}")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    logging.info("Creating model")
+    model = Conformer(
+        num_features=params.feature_dim,
+        nhead=params.nhead,
+        d_model=params.attention_dim,
+        num_classes=params.num_classes,
+        subsampling_factor=params.subsampling_factor,
+        num_decoder_layers=params.num_decoder_layers,
+        vgg_frontend=params.vgg_frontend,
+        use_feat_batchnorm=params.use_feat_batchnorm,
+    )
+
+    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    model.load_state_dict(checkpoint["model"], strict=False)
+    model.to(device)
+    model.eval()
+
+    logging.info("Constructing Fbank computer")
+    opts = kaldifeat.FbankOptions()
+    opts.device = device
+    opts.frame_opts.dither = 0
+    opts.frame_opts.snip_edges = False
+    opts.frame_opts.samp_freq = params.sample_rate
+    opts.mel_opts.num_bins = params.feature_dim
+
+    fbank = kaldifeat.Fbank(opts)
+
+    logging.info(f"Reading sound files: {params.sound_files}")
+    waves = read_sound_files(
+        filenames=params.sound_files, expected_sample_rate=params.sample_rate
+    )
+    waves = [w.to(device) for w in waves]
+
+    logging.info("Decoding started")
+    features = fbank(waves)
+
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+
+    # Note: We don't use key padding mask for attention during decoding
+    with torch.no_grad():
+        nnet_output, memory, memory_key_padding_mask = model(features)
+
+    batch_size = nnet_output.shape[0]
+    supervision_segments = torch.tensor(
+        [[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
+        dtype=torch.int32,
+    )
+
+    if params.method == "ctc-decoding":
+        logging.info("Use CTC decoding")
+        bpe_model = spm.SentencePieceProcessor()
+        bpe_model.load(params.bpe_model)
+        max_token_id = params.num_classes - 1
+
+        H = k2.ctc_topo(
+            max_token=max_token_id,
+            modified=False,
+            device=device,
+        )
+
+        lattice = get_lattice(
+            nnet_output=nnet_output,
+            decoding_graph=H,
+            supervision_segments=supervision_segments,
+            search_beam=params.search_beam,
+            output_beam=params.output_beam,
+            min_active_states=params.min_active_states,
+            max_active_states=params.max_active_states,
+            subsampling_factor=params.subsampling_factor,
+        )
+
+        best_path = one_best_decoding(
+            lattice=lattice, use_double_scores=params.use_double_scores
+        )
+        token_ids = get_texts(best_path)
+        hyps = bpe_model.decode(token_ids)
+        hyps = [s.split() for s in hyps]
+    elif params.method in [
+        "1best",
+        "whole-lattice-rescoring",
+        "attention-decoder",
+    ]:
+        logging.info(f"Loading HLG from {params.HLG}")
+        HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+        HLG = HLG.to(device)
+        if not hasattr(HLG, "lm_scores"):
+            # For whole-lattice-rescoring and attention-decoder
+            HLG.lm_scores = HLG.scores.clone()
+
+        if params.method in [
+            "whole-lattice-rescoring",
+            "attention-decoder",
+        ]:
+            logging.info(f"Loading G from {params.G}")
+            G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+            # Add epsilon self-loops to G as we will compose
+            # it with the whole lattice later
+            G = G.to(device)
+            G = k2.add_epsilon_self_loops(G)
+            G = k2.arc_sort(G)
+            G.lm_scores = G.scores.clone()
+
+        lattice = get_lattice(
+            nnet_output=nnet_output,
+            decoding_graph=HLG,
+            supervision_segments=supervision_segments,
+            search_beam=params.search_beam,
+            output_beam=params.output_beam,
+            min_active_states=params.min_active_states,
+            max_active_states=params.max_active_states,
+            subsampling_factor=params.subsampling_factor,
+        )
+
+        if params.method == "1best":
+            logging.info("Use HLG decoding")
+            best_path = one_best_decoding(
+                lattice=lattice, use_double_scores=params.use_double_scores
+            )
+        elif params.method == "whole-lattice-rescoring":
+            logging.info("Use HLG decoding + LM rescoring")
+            best_path_dict = rescore_with_whole_lattice(
+                lattice=lattice,
+                G_with_epsilon_loops=G,
+                lm_scale_list=[params.ngram_lm_scale],
+            )
+            best_path = next(iter(best_path_dict.values()))
+        elif params.method == "attention-decoder":
+            logging.info("Use HLG + LM rescoring + attention decoder rescoring")
+            rescored_lattice = rescore_with_whole_lattice(
+                lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None
+            )
+            best_path_dict = rescore_with_attention_decoder(
+                lattice=rescored_lattice,
+                num_paths=params.num_paths,
+                model=model,
+                memory=memory,
+                memory_key_padding_mask=memory_key_padding_mask,
+                sos_id=params.sos_id,
+                eos_id=params.eos_id,
+                nbest_scale=params.nbest_scale,
+                ngram_lm_scale=params.ngram_lm_scale,
+                attention_scale=params.attention_decoder_scale,
+            )
+            best_path = next(iter(best_path_dict.values()))
+
+        hyps = get_texts(best_path)
+        word_sym_table = k2.SymbolTable.from_file(params.words_file)
+        hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
+    else:
+        raise ValueError(f"Unsupported decoding method: {params.method}")
+
+    s = "\n"
+    for filename, hyp in zip(params.sound_files, hyps):
+        words = " ".join(hyp)
+        s += f"{filename}:\n{words}\n\n"
+    logging.info(s)
+
+    logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/mgb2/ASR/conformer_ctc/subsampling.py b/egs/mgb2/ASR/conformer_ctc/subsampling.py
new file mode 120000
index 000000000..16354dc73
--- /dev/null
+++ b/egs/mgb2/ASR/conformer_ctc/subsampling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/conformer_ctc/subsampling.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/conformer_ctc/test_label_smoothing.py b/egs/mgb2/ASR/conformer_ctc/test_label_smoothing.py
new file mode 120000
index 000000000..04b959ecf
--- /dev/null
+++ b/egs/mgb2/ASR/conformer_ctc/test_label_smoothing.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/conformer_ctc/test_label_smoothing.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/conformer_ctc/test_subsampling.py b/egs/mgb2/ASR/conformer_ctc/test_subsampling.py
new file mode 120000
index 000000000..98c3be3e6
--- /dev/null
+++ b/egs/mgb2/ASR/conformer_ctc/test_subsampling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/conformer_ctc/test_subsampling.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/conformer_ctc/test_transformer.py b/egs/mgb2/ASR/conformer_ctc/test_transformer.py
new file mode 120000
index 000000000..8b0990ec6
--- /dev/null
+++ b/egs/mgb2/ASR/conformer_ctc/test_transformer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/conformer_ctc/test_transformer.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/conformer_ctc/train.py b/egs/mgb2/ASR/conformer_ctc/train.py
new file mode 100755
index 000000000..08ffee210
--- /dev/null
+++ b/egs/mgb2/ASR/conformer_ctc/train.py
@@ -0,0 +1,766 @@
+#!/usr/bin/env python3
+# Copyright 2022 Johns Hopkins University  (Amir Hussein)
+# Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
+
+
+import argparse
+import logging
+from pathlib import Path
+from shutil import copyfile
+from typing import Optional, Tuple
+
+import k2
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import MGB2AsrDataModule
+from conformer import Conformer
+from lhotse.cut import Cut
+from lhotse.utils import fix_random_seed
+from torch import Tensor
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.nn.utils import clip_grad_norm_
+from torch.utils.tensorboard import SummaryWriter
+from transformer import Noam
+
+from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
+from icefall.checkpoint import load_checkpoint
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    encode_supervisions,
+    setup_logger,
+    str2bool,
+)
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--world-size",
+        type=int,
+        default=1,
+        help="Number of GPUs for DDP training.",
+    )
+
+    parser.add_argument(
+        "--master-port",
+        type=int,
+        default=12354,
+        help="Master port to use for DDP training.",
+    )
+
+    parser.add_argument(
+        "--tensorboard",
+        type=str2bool,
+        default=True,
+        help="Should various information be logged in tensorboard.",
+    )
+
+    parser.add_argument(
+        "--num-epochs",
+        type=int,
+        default=50,
+        help="Number of epochs to train.",
+    )
+
+    parser.add_argument(
+        "--start-epoch",
+        type=int,
+        default=0,
+        help="""Resume training from from this epoch.
+        If it is positive, it will load checkpoint from
+        conformer_ctc/exp/epoch-{start_epoch-1}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="conformer_ctc/exp",
+        help="""The experiment dir.
+        It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_bpe_500",
+        help="""The lang dir
+        It contains language related input files such as
+        "lexicon.txt"
+        """,
+    )
+
+    parser.add_argument(
+        "--att-rate",
+        type=float,
+        default=0.8,
+        help="""The attention rate.
+        The total loss is (1 -  att_rate) * ctc_loss + att_rate * att_loss
+        """,
+    )
+
+    parser.add_argument(
+        "--num-decoder-layers",
+        type=int,
+        default=6,
+        help="""Number of decoder layer of transformer decoder.
+        Setting this to 0 will not create the decoder at all (pure CTC model)
+        """,
+    )
+
+    parser.add_argument(
+        "--lr-factor",
+        type=float,
+        default=5.0,
+        help="The lr_factor for Noam optimizer",
+    )
+
+    return parser
+
+
+def get_params() -> AttributeDict:
+    """Return a dict containing training parameters.
+
+    All training related parameters that are not passed from the commandline
+    are saved in the variable `params`.
+
+    Commandline options are merged into `params` after they are parsed, so
+    you can also access them via `params`.
+
+    Explanation of options saved in `params`:
+
+        - best_train_loss: Best training loss so far. It is used to select
+                           the model that has the lowest training loss. It is
+                           updated during the training.
+
+        - best_valid_loss: Best validation loss so far. It is used to select
+                           the model that has the lowest validation loss. It is
+                           updated during the training.
+
+        - best_train_epoch: It is the epoch that has the best training loss.
+
+        - best_valid_epoch: It is the epoch that has the best validation loss.
+
+        - batch_idx_train: Used to writing statistics to tensorboard. It
+                           contains number of batches trained so far across
+                           epochs.
+
+        - log_interval:  Print training loss if batch_idx % log_interval` is 0
+
+        - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+        - valid_interval:  Run validation if batch_idx % valid_interval is 0
+
+        - feature_dim: The model input dim. It has to match the one used
+                       in computing features.
+
+        - subsampling_factor:  The subsampling factor for the model.
+
+        - use_feat_batchnorm: Normalization for the input features, can be a
+                              boolean indicating whether to do batch
+                              normalization, or a float which means just scaling
+                              the input features with this float value.
+                              If given a float value, we will remove batchnorm
+                              layer in `ConvolutionModule` as well.
+
+        - attention_dim: Hidden dim for multi-head attention model.
+
+        - head: Number of heads of multi-head attention model.
+
+        - num_decoder_layers: Number of decoder layer of transformer decoder.
+
+        - beam_size: It is used in k2.ctc_loss
+
+        - reduction: It is used in k2.ctc_loss
+
+        - use_double_scores: It is used in k2.ctc_loss
+
+        - weight_decay:  The weight_decay for the optimizer.
+
+        - warm_step: The warm_step for Noam optimizer.
+    """
+    params = AttributeDict(
+        {
+            "best_train_loss": float("inf"),
+            "best_valid_loss": float("inf"),
+            "best_train_epoch": -1,
+            "best_valid_epoch": -1,
+            "batch_idx_train": 0,
+            "log_interval": 50,
+            "reset_interval": 200,
+            "valid_interval": 3000,
+            # parameters for conformer
+            "feature_dim": 80,
+            "subsampling_factor": 4,
+            "use_feat_batchnorm": True,
+            "attention_dim": 512,
+            "nhead": 8,
+            "num_decoder_layers": 6,
+            # parameters for loss
+            "beam_size": 10,
+            "reduction": "sum",
+            "use_double_scores": True,
+            # parameters for Noam
+            "weight_decay": 1e-6,
+            "warm_step": 80000,
+            "env_info": get_env_info(),
+        }
+    )
+
+    return params
+
+
+def load_checkpoint_if_available(
+    params: AttributeDict,
+    model: nn.Module,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
+) -> None:
+    """Load checkpoint from file.
+
+    If params.start_epoch is positive, it will load the checkpoint from
+    `params.start_epoch - 1`. Otherwise, this function does nothing.
+
+    Apart from loading state dict for `model`, `optimizer` and `scheduler`,
+    it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+    and `best_valid_loss` in `params`.
+
+    Args:
+      params:
+        The return value of :func:`get_params`.
+      model:
+        The training model.
+      optimizer:
+        The optimizer that we are using.
+      scheduler:
+        The learning rate scheduler we are using.
+    Returns:
+      Return None.
+    """
+    if params.start_epoch <= 0:
+        return
+
+    filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+    saved_params = load_checkpoint(
+        filename,
+        model=model,
+        optimizer=optimizer,
+        scheduler=scheduler,
+    )
+
+    keys = [
+        "best_train_epoch",
+        "best_valid_epoch",
+        "batch_idx_train",
+        "best_train_loss",
+        "best_valid_loss",
+    ]
+    for k in keys:
+        params[k] = saved_params[k]
+
+    return saved_params
+
+
+def save_checkpoint(
+    params: AttributeDict,
+    model: nn.Module,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
+    rank: int = 0,
+) -> None:
+    """Save model, optimizer, scheduler and training stats to file.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The training model.
+    """
+    if rank != 0:
+        return
+    filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+    save_checkpoint_impl(
+        filename=filename,
+        model=model,
+        params=params,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        rank=rank,
+    )
+
+    if params.best_train_epoch == params.cur_epoch:
+        best_train_filename = params.exp_dir / "best-train-loss.pt"
+        copyfile(src=filename, dst=best_train_filename)
+
+    if params.best_valid_epoch == params.cur_epoch:
+        best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+        copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+    params: AttributeDict,
+    model: nn.Module,
+    batch: dict,
+    graph_compiler: BpeCtcTrainingGraphCompiler,
+    is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+    """
+    Compute CTC loss given the model and its inputs.
+
+    Args:
+      params:
+        Parameters for training. See :func:`get_params`.
+      model:
+        The model for training. It is an instance of Conformer in our case.
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      graph_compiler:
+        It is used to build a decoding graph from a ctc topo and training
+        transcript. The training transcript is contained in the given `batch`,
+        while the ctc topo is built when this compiler is instantiated.
+      is_training:
+        True for training. False for validation. When it is True, this
+        function enables autograd during computation; when it is False, it
+        disables autograd.
+    """
+    device = graph_compiler.device
+    feature = batch["inputs"]
+    # at entry, feature is (N, T, C)
+    assert feature.ndim == 3
+    feature = feature.to(device)
+
+    supervisions = batch["supervisions"]
+    with torch.set_grad_enabled(is_training):
+        nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
+        # nnet_output is (N, T, C)
+
+    # NOTE: We need `encode_supervisions` to sort sequences with
+    # different duration in decreasing order, required by
+    # `k2.intersect_dense` called in `k2.ctc_loss`
+    supervision_segments, texts = encode_supervisions(
+        supervisions, subsampling_factor=params.subsampling_factor
+    )
+
+    token_ids = graph_compiler.texts_to_ids(texts)
+
+    decoding_graph = graph_compiler.compile(token_ids)
+
+    dense_fsa_vec = k2.DenseFsaVec(
+        nnet_output,
+        supervision_segments,
+        allow_truncate=params.subsampling_factor - 1,
+    )
+
+    ctc_loss = k2.ctc_loss(
+        decoding_graph=decoding_graph,
+        dense_fsa_vec=dense_fsa_vec,
+        output_beam=params.beam_size,
+        reduction="none",
+        use_double_scores=params.use_double_scores,
+    )
+    # filter inf from ctc_loss
+    ctc_loss = torch.sum(
+        torch.where(
+            ctc_loss != float("inf"),
+            ctc_loss,
+            torch.tensor(0, dtype=torch.float32).to(device),
+        )
+    )
+
+    if params.att_rate != 0.0:
+        with torch.set_grad_enabled(is_training):
+            mmodel = model.module if hasattr(model, "module") else model
+            # Note: We need to generate an unsorted version of token_ids
+            # `encode_supervisions()` called above sorts text, but
+            # encoder_memory and memory_mask are not sorted, so we
+            # use an unsorted version `supervisions["text"]` to regenerate
+            # the token_ids
+            #
+            # See https://github.com/k2-fsa/icefall/issues/97
+            # for more details
+            unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"])
+
+            att_loss = mmodel.decoder_forward(
+                encoder_memory,
+                memory_mask,
+                token_ids=unsorted_token_ids,
+                sos_id=graph_compiler.sos_id,
+                eos_id=graph_compiler.eos_id,
+            )
+        loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss
+    else:
+        loss = ctc_loss
+        att_loss = torch.tensor([0])
+
+    assert loss.requires_grad == is_training
+
+    info = MetricsTracker()
+    info["frames"] = supervision_segments[:, 2].sum().item()
+    info["ctc_loss"] = ctc_loss.detach().cpu().item()
+    if params.att_rate != 0.0:
+        info["att_loss"] = att_loss.detach().cpu().item()
+
+    info["loss"] = loss.detach().cpu().item()
+
+    return loss, info
+
+
+def compute_validation_loss(
+    params: AttributeDict,
+    model: nn.Module,
+    graph_compiler: BpeCtcTrainingGraphCompiler,
+    valid_dl: torch.utils.data.DataLoader,
+    world_size: int = 1,
+) -> MetricsTracker:
+    """Run the validation process."""
+    model.eval()
+
+    tot_loss = MetricsTracker()
+
+    for batch_idx, batch in enumerate(valid_dl):
+        loss, loss_info = compute_loss(
+            params=params,
+            model=model,
+            batch=batch,
+            graph_compiler=graph_compiler,
+            is_training=False,
+        )
+        assert loss.requires_grad is False
+        tot_loss = tot_loss + loss_info
+
+    if world_size > 1:
+        tot_loss.reduce(loss.device)
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    if loss_value < params.best_valid_loss:
+        params.best_valid_epoch = params.cur_epoch
+        params.best_valid_loss = loss_value
+
+    return tot_loss
+
+
+def train_one_epoch(
+    params: AttributeDict,
+    model: nn.Module,
+    optimizer: torch.optim.Optimizer,
+    graph_compiler: BpeCtcTrainingGraphCompiler,
+    train_dl: torch.utils.data.DataLoader,
+    valid_dl: torch.utils.data.DataLoader,
+    tb_writer: Optional[SummaryWriter] = None,
+    world_size: int = 1,
+) -> None:
+    """Train the model for one epoch.
+
+    The training loss from the mean of all frames is saved in
+    `params.train_loss`. It runs the validation process every
+    `params.valid_interval` batches.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The model for training.
+      optimizer:
+        The optimizer we are using.
+      graph_compiler:
+        It is used to convert transcripts to FSAs.
+      train_dl:
+        Dataloader for the training dataset.
+      valid_dl:
+        Dataloader for the validation dataset.
+      tb_writer:
+        Writer to write log messages to tensorboard.
+      world_size:
+        Number of nodes in DDP training. If it is 1, DDP is disabled.
+    """
+
+    model.train()
+
+    tot_loss = MetricsTracker()
+
+    for batch_idx, batch in enumerate(train_dl):
+        if batch["inputs"].shape[0] == len(batch["supervisions"]["text"]):
+            params.batch_idx_train += 1
+            batch_size = len(batch["supervisions"]["text"])
+
+            loss, loss_info = compute_loss(
+                params=params,
+                model=model,
+                batch=batch,
+                graph_compiler=graph_compiler,
+                is_training=True,
+            )
+            # summary stats
+            tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+            # if tot_loss is None:
+            #     logging.warning("Batch mismatch. Skipping ...")
+            #     del batch
+            #     del tot_loss
+            #     continue;
+            # elif tot_loss.isinf() or tot_loss.isnan():
+            #     logging.warning("NaN or Inf loss. Skipping ...")
+            #     del batch
+            #     del tot_loss
+            #     continue;
+            # NOTE: We use reduction==sum and loss is computed over utterances
+            # in the batch and there is no normalization to it so far.
+
+            optimizer.zero_grad()
+            loss.backward()
+            clip_grad_norm_(model.parameters(), 5.0, 2.0)
+            optimizer.step()
+
+            if batch_idx % params.log_interval == 0:
+                logging.info(
+                    f"Epoch {params.cur_epoch}, "
+                    f"batch {batch_idx}, loss[{loss_info}], "
+                    f"tot_loss[{tot_loss}], batch size: {batch_size}"
+                )
+
+            if batch_idx % params.log_interval == 0:
+
+                if tb_writer is not None:
+                    loss_info.write_summary(
+                        tb_writer, "train/current_", params.batch_idx_train
+                    )
+                    tot_loss.write_summary(
+                        tb_writer, "train/tot_", params.batch_idx_train
+                    )
+
+            if batch_idx > 0 and batch_idx % params.valid_interval == 0:
+                logging.info("Computing validation loss")
+                valid_info = compute_validation_loss(
+                    params=params,
+                    model=model,
+                    graph_compiler=graph_compiler,
+                    valid_dl=valid_dl,
+                    world_size=world_size,
+                )
+                model.train()
+                logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+                if tb_writer is not None:
+                    valid_info.write_summary(
+                        tb_writer, "train/valid_", params.batch_idx_train
+                    )
+        else:
+            logging.warning(
+                f"Batch {batch_idx} mismatch in dimentions between the input and the output. Skipping ..."
+            )
+            continue
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    params.train_loss = loss_value
+    if params.train_loss < params.best_train_loss:
+        params.best_train_epoch = params.cur_epoch
+        params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+    """
+    Args:
+      rank:
+        It is a value between 0 and `world_size-1`, which is
+        passed automatically by `mp.spawn()` in :func:`main`.
+        The node with rank 0 is responsible for saving checkpoint.
+      world_size:
+        Number of GPUs for DDP training.
+      args:
+        The return value of get_parser().parse_args()
+    """
+    params = get_params()
+    params.update(vars(args))
+
+    fix_random_seed(42)
+    if world_size > 1:
+        setup_dist(rank, world_size, params.master_port)
+
+    setup_logger(f"{params.exp_dir}/log/log-train")
+    logging.info("Training started")
+    logging.info(params)
+
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
+    lexicon = Lexicon(params.lang_dir)
+    max_token_id = max(lexicon.tokens)
+    num_classes = max_token_id + 1  # +1 for the blank
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+
+    graph_compiler = BpeCtcTrainingGraphCompiler(
+        params.lang_dir,
+        device=device,
+        sos_token="",
+        eos_token="",
+    )
+
+    logging.info("About to create model")
+    model = Conformer(
+        num_features=params.feature_dim,
+        nhead=params.nhead,
+        d_model=params.attention_dim,
+        num_classes=num_classes,
+        subsampling_factor=params.subsampling_factor,
+        num_decoder_layers=params.num_decoder_layers,
+        vgg_frontend=False,
+        use_feat_batchnorm=params.use_feat_batchnorm,
+    )
+
+    checkpoints = load_checkpoint_if_available(params=params, model=model)
+
+    model.to(device)
+    if world_size > 1:
+        model = DDP(model, device_ids=[rank])
+
+    optimizer = Noam(
+        model.parameters(),
+        model_size=params.attention_dim,
+        factor=params.lr_factor,
+        warm_step=params.warm_step,
+        weight_decay=params.weight_decay,
+    )
+
+    if checkpoints:
+        optimizer.load_state_dict(checkpoints["optimizer"])
+
+    MGB2 = MGB2AsrDataModule(args)
+
+    train_cuts = MGB2.train_cuts()
+
+    def remove_short_and_long_utt(c: Cut):
+        # Keep only utterances with duration between 1 second and 20 seconds
+        #
+        # Caution: There is a reason to select 20.0 here. Please see
+        # ../local/display_manifest_statistics.py
+        #
+        # You should use ../local/display_manifest_statistics.py to get
+        # an utterance duration distribution for your dataset to select
+        # the threshold
+        return 0.5 <= c.duration <= 30.0
+
+    train_cuts = train_cuts.filter(remove_short_and_long_utt)
+    train_dl = MGB2.train_dataloaders(train_cuts)
+
+    valid_cuts = MGB2.dev_cuts()
+    valid_dl = MGB2.test_dataloaders(valid_cuts)
+
+    scan_pessimistic_batches_for_oom(
+        model=model,
+        train_dl=train_dl,
+        optimizer=optimizer,
+        graph_compiler=graph_compiler,
+        params=params,
+    )
+
+    for epoch in range(params.start_epoch, params.num_epochs):
+        train_dl.sampler.set_epoch(epoch)
+
+        cur_lr = optimizer._rate
+        if tb_writer is not None:
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
+            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+        if rank == 0:
+            logging.info("epoch {}, learning rate {}".format(epoch, cur_lr))
+
+        params.cur_epoch = epoch
+
+        train_one_epoch(
+            params=params,
+            model=model,
+            optimizer=optimizer,
+            graph_compiler=graph_compiler,
+            train_dl=train_dl,
+            valid_dl=valid_dl,
+            tb_writer=tb_writer,
+            world_size=world_size,
+        )
+
+        save_checkpoint(
+            params=params,
+            model=model,
+            optimizer=optimizer,
+            rank=rank,
+        )
+
+    logging.info("Done!")
+
+    if world_size > 1:
+        torch.distributed.barrier()
+        cleanup_dist()
+
+
+def scan_pessimistic_batches_for_oom(
+    model: nn.Module,
+    train_dl: torch.utils.data.DataLoader,
+    optimizer: torch.optim.Optimizer,
+    graph_compiler: BpeCtcTrainingGraphCompiler,
+    params: AttributeDict,
+):
+    from lhotse.dataset import find_pessimistic_batches
+
+    logging.info(
+        "Sanity check -- see if any of the batches in epoch 0 would cause OOM."
+    )
+    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+    for criterion, cuts in batches.items():
+        batch = train_dl.dataset[cuts]
+        try:
+            optimizer.zero_grad()
+            loss, _ = compute_loss(
+                params=params,
+                model=model,
+                batch=batch,
+                graph_compiler=graph_compiler,
+                is_training=True,
+            )
+            loss.backward()
+            clip_grad_norm_(model.parameters(), 5.0, 2.0)
+            optimizer.step()
+        except RuntimeError as e:
+            if "CUDA out of memory" in str(e):
+                logging.error(
+                    "Your GPU ran out of memory with the current "
+                    "max_duration setting. We recommend decreasing "
+                    "max_duration and trying again.\n"
+                    f"Failing criterion: {criterion} "
+                    f"(={crit_values[criterion]}) ..."
+                )
+            raise
+
+
+def main():
+    parser = get_parser()
+    MGB2AsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+    args.lang_dir = Path(args.lang_dir)
+
+    world_size = args.world_size
+    assert world_size >= 1
+    if world_size > 1:
+        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+    else:
+        run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/mgb2/ASR/conformer_ctc/transformer.py b/egs/mgb2/ASR/conformer_ctc/transformer.py
new file mode 120000
index 000000000..1c3f43fcf
--- /dev/null
+++ b/egs/mgb2/ASR/conformer_ctc/transformer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/conformer_ctc/transformer.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/local/__init__.py b/egs/mgb2/ASR/local/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/egs/mgb2/ASR/local/compile_hlg.py b/egs/mgb2/ASR/local/compile_hlg.py
new file mode 120000
index 000000000..471aa7fb4
--- /dev/null
+++ b/egs/mgb2/ASR/local/compile_hlg.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/compile_hlg.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/local/compute_fbank_mgb2.py b/egs/mgb2/ASR/local/compute_fbank_mgb2.py
new file mode 100755
index 000000000..6cae69e41
--- /dev/null
+++ b/egs/mgb2/ASR/local/compute_fbank_mgb2.py
@@ -0,0 +1,101 @@
+#!/usr/bin/env python3
+# Copyright 2022 Johns Hopkins University  (Amir Hussein)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file computes fbank features of the MGB2 dataset.
+It looks for manifests in the directory data/manifests.
+
+The generated fbank features are saved in data/fbank.
+"""
+
+import logging
+import os
+from pathlib import Path
+
+import torch
+from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
+from lhotse.recipes.utils import read_manifests_if_cached
+
+from icefall.utils import get_executor
+
+# Torch's multithreaded behavior needs to be disabled or
+# it wastes a lot of CPU and slow things down.
+# Do this outside of main() in case it needs to take effect
+# even when we are not invoking the main (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+
+def compute_fbank_mgb2():
+    src_dir = Path("data/manifests")
+    output_dir = Path("data/fbank")
+    num_jobs = min(15, os.cpu_count())
+    num_mel_bins = 80
+
+    dataset_parts = (
+        "train",
+        "test",
+        "dev",
+    )
+    manifests = read_manifests_if_cached(
+        prefix="mgb2", dataset_parts=dataset_parts, output_dir=src_dir
+    )
+    assert manifests is not None
+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
+    extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
+
+    with get_executor() as ex:  # Initialize the executor only once.
+        for partition, m in manifests.items():
+            if (output_dir / f"cuts_{partition}.json.gz").is_file():
+                logging.info(f"{partition} already exists - skipping.")
+                continue
+            logging.info(f"Processing {partition}")
+            cut_set = CutSet.from_manifests(
+                recordings=m["recordings"],
+                supervisions=m["supervisions"],
+            )
+            if "train" in partition:
+                cut_set = (
+                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+                )
+            cut_set = cut_set.compute_and_store_features(
+                extractor=extractor,
+                storage_path=f"{output_dir}/feats_{partition}",
+                # when an executor is specified, make more partitions
+                num_jobs=num_jobs if ex is None else 80,
+                executor=ex,
+                storage_type=LilcomChunkyWriter,
+            )
+            logging.info("About to split cuts into smaller chunks.")
+            cut_set = cut_set.trim_to_supervisions(
+                keep_overlapping=False, min_duration=None
+            )
+            cut_set.to_file(output_dir / f"cuts_{partition}.jsonl.gz")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
+    compute_fbank_mgb2()
diff --git a/egs/mgb2/ASR/local/compute_fbank_musan.py b/egs/mgb2/ASR/local/compute_fbank_musan.py
new file mode 100755
index 000000000..5d0d69a13
--- /dev/null
+++ b/egs/mgb2/ASR/local/compute_fbank_musan.py
@@ -0,0 +1,108 @@
+#!/usr/bin/env python3
+# Copyright    2021  Xiaomi Corp.        (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file computes fbank features of the musan dataset.
+It looks for manifests in the directory data/manifests.
+The generated fbank features are saved in data/fbank.
+"""
+
+import logging
+import os
+from pathlib import Path
+
+import torch
+from lhotse import (
+    ChunkedLilcomHdf5Writer,
+    CutSet,
+    Fbank,
+    FbankConfig,
+    LilcomChunkyWriter,
+    combine,
+)
+from lhotse.recipes.utils import read_manifests_if_cached
+
+from icefall.utils import get_executor
+
+# Torch's multithreaded behavior needs to be disabled or
+# it wastes a lot of CPU and slow things down.
+# Do this outside of main() in case it needs to take effect
+# even when we are not invoking the main (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+
+def compute_fbank_musan():
+    src_dir = Path("data/manifests")
+    output_dir = Path("data/fbank")
+    num_jobs = min(15, os.cpu_count())
+    num_mel_bins = 80
+
+    dataset_parts = (
+        "music",
+        "speech",
+        "noise",
+    )
+    prefix = "musan"
+    suffix = "jsonl.gz"
+    manifests = read_manifests_if_cached(
+        prefix=prefix,
+        dataset_parts=dataset_parts,
+        output_dir=src_dir,
+        suffix=suffix,
+    )
+    assert manifests is not None
+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+    )
+
+    musan_cuts_path = output_dir / "cuts_musan.jsonl.gz"
+
+    if musan_cuts_path.is_file():
+        logging.info(f"{musan_cuts_path} already exists - skipping")
+        return
+
+    logging.info("Extracting features for Musan")
+
+    extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
+
+    with get_executor() as ex:  # Initialize the executor only once.
+        # create chunks of Musan with duration 5 - 10 seconds
+        musan_cuts = (
+            CutSet.from_manifests(
+                recordings=combine(part["recordings"] for part in manifests.values())
+            )
+            .cut_into_windows(10.0)
+            .filter(lambda c: c.duration > 5)
+            .compute_and_store_features(
+                extractor=extractor,
+                storage_path=f"{output_dir}/feats_musan",
+                num_jobs=num_jobs if ex is None else 80,
+                executor=ex,
+                storage_type=LilcomChunkyWriter,
+            )
+        )
+        musan_cuts.to_file(musan_cuts_path)
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    compute_fbank_musan()
diff --git a/egs/mgb2/ASR/local/convert_transcript_words_to_tokens.py b/egs/mgb2/ASR/local/convert_transcript_words_to_tokens.py
new file mode 100755
index 000000000..a8d5117c9
--- /dev/null
+++ b/egs/mgb2/ASR/local/convert_transcript_words_to_tokens.py
@@ -0,0 +1,103 @@
+#!/usr/bin/env python3
+
+# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
+"""
+Convert a transcript file containing words to a corpus file containing tokens
+for LM training with the help of a lexicon.
+
+If the lexicon contains phones, the resulting LM will be a phone LM; If the
+lexicon contains word pieces, the resulting LM will be a word piece LM.
+
+If a word has multiple pronunciations, the one that appears first in the lexicon
+is kept; others are removed.
+
+If the input transcript is:
+
+    hello zoo world hello
+    world zoo
+    foo zoo world hellO
+
+and if the lexicon is
+
+     SPN
+    hello h e l l o 2
+    hello h e l l o
+    world w o r l d
+    zoo z o o
+
+Then the output is
+
+    h e l l o 2 z o o w o r l d h e l l o 2
+    w o r l d z o o
+    SPN z o o w o r l d SPN
+"""
+
+import argparse
+from pathlib import Path
+from typing import Dict, List
+
+from generate_unique_lexicon import filter_multiple_pronunications
+
+from icefall.lexicon import read_lexicon
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--transcript",
+        type=str,
+        help="The input transcript file."
+        "We assume that the transcript file consists of "
+        "lines. Each line consists of space separated words.",
+    )
+    parser.add_argument("--lexicon", type=str, help="The input lexicon file.")
+    parser.add_argument("--oov", type=str, default="", help="The OOV word.")
+
+    return parser.parse_args()
+
+
+def process_line(lexicon: Dict[str, List[str]], line: str, oov_token: str) -> None:
+    """
+    Args:
+      lexicon:
+        A dict containing pronunciations. Its keys are words and values
+        are pronunciations (i.e., tokens).
+      line:
+        A line of transcript consisting of space(s) separated words.
+      oov_token:
+        The pronunciation of the oov word if a word in `line` is not present
+        in the lexicon.
+    Returns:
+      Return None.
+    """
+    s = ""
+    words = line.strip().split()
+    for i, w in enumerate(words):
+        tokens = lexicon.get(w, oov_token)
+        s += " ".join(tokens)
+        s += " "
+    print(s.strip())
+
+
+def main():
+    args = get_args()
+    assert Path(args.lexicon).is_file()
+    assert Path(args.transcript).is_file()
+    assert len(args.oov) > 0
+
+    # Only the first pronunciation of a word is kept
+    lexicon = filter_multiple_pronunications(read_lexicon(args.lexicon))
+
+    lexicon = dict(lexicon)
+
+    assert args.oov in lexicon
+
+    oov_token = lexicon[args.oov]
+
+    with open(args.transcript) as f:
+        for line in f:
+            process_line(lexicon=lexicon, line=line, oov_token=oov_token)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/mgb2/ASR/local/display_manifest_statistics.py b/egs/mgb2/ASR/local/display_manifest_statistics.py
new file mode 100755
index 000000000..d3e224905
--- /dev/null
+++ b/egs/mgb2/ASR/local/display_manifest_statistics.py
@@ -0,0 +1,97 @@
+#!/usr/bin/env python3
+# Copyright    2021  Xiaomi Corp.        (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This file displays duration statistics of utterances in a manifest.
+You can use the displayed value to choose minimum/maximum duration
+to remove short and long utterances during the training.
+
+See the function `remove_short_and_long_utt()` in transducer/train.py
+for usage.
+"""
+
+
+from lhotse import load_manifest
+
+
+def main():
+    # path = "./data/fbank/cuts_train.jsonl.gz"
+    path = "./data/fbank/cuts_dev.jsonl.gz"
+    # path = "./data/fbank/cuts_test.jsonl.gz"
+
+    cuts = load_manifest(path)
+    cuts.describe()
+
+
+if __name__ == "__main__":
+    main()
+
+"""
+# train
+
+Cuts count: 1125309
+Total duration (hours): 3403.9
+Speech duration (hours): 3403.9 (100.0%)
+***
+Duration statistics (seconds):
+mean    10.9
+std     10.1
+min     0.2
+25%     5.2
+50%     7.8
+75%     12.7
+99%     52.0
+99.5%   65.1
+99.9%   99.5
+max     228.9
+
+
+# test
+Cuts count: 5365
+Total duration (hours): 9.6
+Speech duration (hours): 9.6 (100.0%)
+***
+Duration statistics (seconds):
+mean    6.4
+std     1.5
+min     1.6
+25%     5.3
+50%     6.5
+75%     7.6
+99%     9.5
+99.5%   9.7
+99.9%   10.3
+max     12.4
+
+# dev
+Cuts count: 5002
+Total duration (hours): 8.5
+Speech duration (hours): 8.5 (100.0%)
+***
+Duration statistics (seconds):
+mean    6.1
+std     1.7
+min     1.5
+25%     4.8
+50%     6.2
+75%     7.4
+99%     9.5
+99.5%   9.7
+99.9%   10.1
+max     20.3
+
+"""
diff --git a/egs/mgb2/ASR/local/generate_unique_lexicon.py b/egs/mgb2/ASR/local/generate_unique_lexicon.py
new file mode 120000
index 000000000..c0aea1403
--- /dev/null
+++ b/egs/mgb2/ASR/local/generate_unique_lexicon.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/generate_unique_lexicon.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/local/prep_mgb2_lexicon.sh b/egs/mgb2/ASR/local/prep_mgb2_lexicon.sh
new file mode 100755
index 000000000..3b673db6f
--- /dev/null
+++ b/egs/mgb2/ASR/local/prep_mgb2_lexicon.sh
@@ -0,0 +1,30 @@
+#!/usr/bin/env bash
+
+# Copyright 2022 QCRI (author: Amir Hussein)
+# Apache 2.0
+# This script prepares the graphemic lexicon.
+
+dir=data/local/dict
+lexicon_url1="https://arabicspeech.org/arabicspeech-portal-resources/lexicon/ar-ar_grapheme_lexicon_20160209.bz2";
+lexicon_url2="https://arabicspeech.org/arabicspeech-portal-resources/lexicon/ar-ar_phoneme_lexicon_20140317.bz2";
+stage=0
+lang_dir=download/lm
+mkdir -p $lang_dir
+
+if [ $stage -le 0 ]; then
+  echo "$0: Downloading text for lexicon... $(date)."
+  wget --no-check-certificate -P $lang_dir $lexicon_url1
+  wget --no-check-certificate -P $lang_dir $lexicon_url2
+  bzcat $lang_dir/ar-ar_grapheme_lexicon_20160209.bz2 | sed '1,3d' | awk '{print $1}'  >  $lang_dir/grapheme_lexicon
+  bzcat $lang_dir/ar-ar_phoneme_lexicon_20140317.bz2 | sed '1,3d' | awk '{print $1}' >>  $lang_dir/phoneme_lexicon
+  cat download/lm/train/text | cut -d ' ' -f 2- | tr -s " " "\n" | sort -u >> $lang_dir/uniq_words
+fi
+
+
+if [ $stage -le 0 ]; then
+  echo "$0: processing lexicon text and creating lexicon... $(date)."
+  # remove vowels and  rare alef wasla
+  cat $lang_dir/uniq_words |  sed -e 's:[FNKaui\~o\`]::g' -e 's:{:}:g' | sed -r '/^\s*$/d' | sort -u > $lang_dir/grapheme_lexicon.txt
+fi
+
+echo "$0: Lexicon preparation succeeded"
diff --git a/egs/mgb2/ASR/local/prepare_lang.py b/egs/mgb2/ASR/local/prepare_lang.py
new file mode 120000
index 000000000..747f2ab39
--- /dev/null
+++ b/egs/mgb2/ASR/local/prepare_lang.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/prepare_lang.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/local/prepare_lang_bpe.py b/egs/mgb2/ASR/local/prepare_lang_bpe.py
new file mode 120000
index 000000000..36b40e7fc
--- /dev/null
+++ b/egs/mgb2/ASR/local/prepare_lang_bpe.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/prepare_lang_bpe.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/local/prepare_mgb2_lexicon.py b/egs/mgb2/ASR/local/prepare_mgb2_lexicon.py
new file mode 100755
index 000000000..99e1fa34d
--- /dev/null
+++ b/egs/mgb2/ASR/local/prepare_mgb2_lexicon.py
@@ -0,0 +1,37 @@
+#!/usr/bin/env python3
+
+# Copyright      2022  Amir Hussein
+# Apache 2.0
+
+# This script prepares givel a column of words lexicon.
+
+import argparse
+
+
+def get_args():
+    parser = argparse.ArgumentParser(
+        description="""Creates the list of characters and words in lexicon"""
+    )
+    parser.add_argument("input", type=str, help="""Input list of words file""")
+    parser.add_argument("output", type=str, help="""output graphemic lexicon""")
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    lex = {}
+    args = get_args()
+    with open(args.input, "r", encoding="utf-8") as f:
+        for line in f:
+            line = line.strip()
+            characters = list(line)
+            characters = " ".join(["V" if char == "*" else char for char in characters])
+            lex[line] = characters
+
+    with open(args.output, "w", encoding="utf-8") as fp:
+        for key in sorted(lex):
+            fp.write(key + "  " + lex[key] + "\n")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/mgb2/ASR/local/test_prepare_lang.py b/egs/mgb2/ASR/local/test_prepare_lang.py
new file mode 120000
index 000000000..f0f864998
--- /dev/null
+++ b/egs/mgb2/ASR/local/test_prepare_lang.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/test_prepare_lang.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/prepare.sh b/egs/mgb2/ASR/prepare.sh
new file mode 100755
index 000000000..899d15d97
--- /dev/null
+++ b/egs/mgb2/ASR/prepare.sh
@@ -0,0 +1,234 @@
+#!/usr/bin/env bash
+# Copyright 2022 Johns Hopkins University  (Amir Hussein)
+# Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
+
+set -eou pipefail
+nj=30
+stage=7
+stop_stage=1000
+
+# We assume dl_dir (download dir) contains the following
+# directories and files. 
+#
+#  - $dl_dir/mgb2
+#      
+#      You can download the data from 
+#
+#
+#  - $dl_dir/musan
+#      This directory contains the following directories downloaded from
+#       http://www.openslr.org/17/
+#
+#     - music
+#     - noise
+#     - speech
+#
+# Note: MGB2 is not available for direct 
+# download, however you can fill out the form and  
+# download it from https://arabicspeech.org/mgb2 
+
+dl_dir=$PWD/download
+
+. shared/parse_options.sh || exit 1
+
+# vocab size for sentence piece models.
+# It will generate data/lang_bpe_xxx,
+# data/lang_bpe_yyy if the array contains xxx, yyy
+vocab_sizes=(
+  5000
+)
+
+# All files generated by this script are saved in "data".
+# You can safely remove "data" and rerun this script to regenerate it.
+mkdir -p data
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+log "dl_dir: $dl_dir"
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+  log "Stage 0: Download data"
+
+  # If you have pre-downloaded it to /path/to/MGB2,
+  # you can create a symlink
+  #
+  #   ln -sfv /path/to/mgb2 $dl_dir/MGB2
+
+  # If you have pre-downloaded it to /path/to/musan,
+  # you can create a symlink
+  #
+  #   ln -sfv /path/to/musan $dl_dir/
+  #
+  if [ ! -d $dl_dir/musan ]; then
+    lhotse download musan $dl_dir
+  fi
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+  log "Stage 1: Prepare mgb2 manifest"
+  # We assume that you have downloaded the mgb2 corpus
+  # to $dl_dir/mgb2
+  mkdir -p data/manifests
+
+  lhotse prepare mgb2 $dl_dir/mgb2 data/manifests
+  
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+  log "Stage 2: Prepare musan manifest"
+  # We assume that you have downloaded the musan corpus
+  # to data/musan
+  mkdir -p data/manifests
+  lhotse prepare musan $dl_dir/musan data/manifests
+fi
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+  log "Stage 3: Compute fbank for mgb2"
+  mkdir -p data/fbank
+  ./local/compute_fbank_mgb2.py
+   # shufling the data
+  gunzip -c data/fbank/cuts_train.jsonl.gz | shuf | gzip -c > data/fbank/cuts_train_shuf.jsonl.gz
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+  log "Stage 4: Compute fbank for musan"
+  mkdir -p data/fbank
+  ./local/compute_fbank_musan.py
+fi
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+  log "Stage 5: Prepare phone based lang"
+  if [[ ! -e download/lm/train/text ]]; then 
+  # export train text file to build grapheme lexicon 
+  lhotse kaldi export \
+    data/manifests/mgb2_recordings_train.jsonl.gz \
+    data/manifests/mgb2_supervisions_train.jsonl.gz  \
+    download/lm/train
+  fi
+
+  lang_dir=data/lang_phone
+  mkdir -p $lang_dir
+  ./local/prep_mgb2_lexicon.sh 
+  python local/prepare_mgb2_lexicon.py  $dl_dir/lm/grapheme_lexicon.txt  $dl_dir/lm/lexicon.txt
+  (echo '!SIL SIL'; echo ' SPN'; echo ' SPN'; ) |
+    cat - $dl_dir/lm/lexicon.txt |
+    sort | uniq > $lang_dir/lexicon.txt
+
+  if [ ! -f $lang_dir/L_disambig.pt ]; then
+    ./local/prepare_lang.py --lang-dir $lang_dir
+  fi
+fi
+
+
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+  log "Stage 6: Prepare BPE based lang"
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    lang_dir=data/lang_bpe_${vocab_size}
+    mkdir -p $lang_dir
+    # We reuse words.txt from phone based lexicon
+    # so that the two can share G.pt later.
+    cp data/lang_phone/words.txt $lang_dir
+
+    if [ ! -f $lang_dir/transcript_words.txt ]; then
+      log "Generate data for BPE training"
+      files=$(
+        find "$dl_dir/lm/train" -name "text"
+      )
+      for f in ${files[@]}; do
+        cat $f | cut -d " " -f 2- | sed -r '/^\s*$/d'
+      done > $lang_dir/transcript_words.txt
+    fi
+
+    ./local/train_bpe_model.py \
+      --lang-dir $lang_dir \
+      --vocab-size $vocab_size \
+      --transcript $lang_dir/transcript_words.txt
+
+    if [ ! -f $lang_dir/L_disambig.pt ]; then
+      ./local/prepare_lang_bpe.py --lang-dir $lang_dir
+    fi
+  done
+fi
+
+if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+  log "Stage 7: Prepare bigram P"
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    lang_dir=data/lang_bpe_${vocab_size}
+
+    if [ ! -f $lang_dir/transcript_tokens.txt ]; then
+      ./local/convert_transcript_words_to_tokens.py \
+        --lexicon $lang_dir/lexicon.txt \
+        --transcript $lang_dir/transcript_words.txt \
+        --oov "" \
+        > $lang_dir/transcript_tokens.txt
+    fi
+
+    if [ ! -f $lang_dir/P.arpa ]; then
+      ./shared/make_kn_lm.py \
+        -ngram-order 2 \
+        -text $lang_dir/transcript_tokens.txt \
+        -lm $lang_dir/P.arpa
+    fi
+
+    if [ ! -f $lang_dir/P.fst.txt ]; then
+      python3 -m kaldilm \
+        --read-symbol-table="$lang_dir/tokens.txt" \
+        --disambig-symbol='#0' \
+        --max-order=2 \
+        $lang_dir/P.arpa > $lang_dir/P.fst.txt
+    fi
+  done
+fi
+
+if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
+  log "Stage 8: Prepare G"
+  # We assume you have install kaldilm, if not, please install
+  # it using: pip install kaldilm
+  for vocab_size in ${vocab_sizes[@]}; do
+    lang_dir=data/lang_bpe_${vocab_size}
+    mkdir -p data/lm
+    if [ ! -f data/lm/G_3_gram.fst.txt ]; then
+      # It is used in building HLG
+      ./shared/make_kn_lm.py \
+          -ngram-order 3 \
+          -text $lang_dir/transcript_words.txt \
+          -lm $lang_dir/G.arpa
+
+      python3 -m kaldilm \
+        --read-symbol-table="data/lang_phone/words.txt" \
+        --disambig-symbol='#0' \
+        --max-order=3 \
+        $lang_dir/G.arpa > data/lm/G_3_gram.fst.txt
+    fi
+
+    if [ ! -f data/lm/G_4_gram.fst.txt ]; then
+      # It is used for LM rescoring
+      ./shared/make_kn_lm.py \
+          -ngram-order 4 \
+          -text $lang_dir/transcript_words.txt \
+          -lm $lang_dir/4-gram.arpa
+
+      python3 -m kaldilm \
+        --read-symbol-table="data/lang_phone/words.txt" \
+        --disambig-symbol='#0' \
+        --max-order=4 \
+        $lang_dir/4-gram.arpa > data/lm/G_4_gram.fst.txt
+    fi
+  done
+fi
+
+if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
+  log "Stage 9: Compile HLG"
+  ./local/compile_hlg.py --lang-dir data/lang_phone
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    lang_dir=data/lang_bpe_${vocab_size}
+    ./local/compile_hlg.py --lang-dir $lang_dir
+  done
+fi
diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/__init__.py b/egs/mgb2/ASR/pruned_transducer_stateless5/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/mgb2/ASR/pruned_transducer_stateless5/asr_datamodule.py
new file mode 120000
index 000000000..a73848de9
--- /dev/null
+++ b/egs/mgb2/ASR/pruned_transducer_stateless5/asr_datamodule.py
@@ -0,0 +1 @@
+../conformer_ctc/asr_datamodule.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/beam_search.py b/egs/mgb2/ASR/pruned_transducer_stateless5/beam_search.py
new file mode 120000
index 000000000..02d01b343
--- /dev/null
+++ b/egs/mgb2/ASR/pruned_transducer_stateless5/beam_search.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless5/beam_search.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/conformer.py b/egs/mgb2/ASR/pruned_transducer_stateless5/conformer.py
new file mode 120000
index 000000000..c7c1a4b6e
--- /dev/null
+++ b/egs/mgb2/ASR/pruned_transducer_stateless5/conformer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless5/conformer.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/decode.py b/egs/mgb2/ASR/pruned_transducer_stateless5/decode.py
new file mode 100755
index 000000000..1463f8f67
--- /dev/null
+++ b/egs/mgb2/ASR/pruned_transducer_stateless5/decode.py
@@ -0,0 +1,625 @@
+#!/usr/bin/env python3
+# Copyright    2022  Johns Hopkins        (authors: Amir Hussein)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) greedy search
+./pruned_transducer_stateless5/decode.py \
+    --epoch 18 \
+    --avg 5 \
+    --exp-dir ./pruned_transducer_stateless5/exp \
+    --max-duration 200 \
+    --decoding-method greedy_search
+
+(2) beam search (not recommended)
+./pruned_transducer_stateless5/decode.py \
+    --epoch 18 \
+    --avg 5 \
+    --exp-dir ./pruned_transducer_stateless5/exp \
+    --max-duration 200 \
+    --decoding-method beam_search \
+    --beam-size 10
+
+(3) modified beam search
+./pruned_transducer_stateless5/decode.py \
+    --epoch 18 \
+    --avg 5 \
+    --exp-dir ./pruned_transducer_stateless5/exp \
+    --max-duration 600 \
+    --decoding-method modified_beam_search \
+    --beam-size 10
+
+(4) fast beam search
+./pruned_transducer_stateless5/decode.py \
+    --epoch 18 \
+    --avg 5 \
+    --exp-dir ./pruned_transducer_stateless5/exp \
+    --max-duration 200 \
+    --decoding-method fast_beam_search \
+    --beam-size 10 \
+    --max-contexts 4 \
+    --max-states 8
+"""
+
+
+import argparse
+import logging
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import MGB2AsrDataModule
+from beam_search import (
+    beam_search,
+    fast_beam_search_one_best,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+)
+from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall.checkpoint import (
+    average_checkpoints,
+    average_checkpoints_with_averaged_model,
+    find_checkpoints,
+    load_checkpoint,
+)
+from icefall.utils import (
+    AttributeDict,
+    setup_logger,
+    store_transcripts,
+    str2bool,
+    write_error_stats,
+)
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=30,
+        help="""It specifies the checkpoint to use for decoding.
+        Note: Epoch counts from 1.
+        You can specify --avg to use more checkpoints for model averaging.""",
+    )
+
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=15,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
+    )
+
+    parser.add_argument(
+        "--use-averaged-model",
+        type=str2bool,
+        default=False,
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="pruned_transducer_stateless5/exp",
+        help="The experiment dir",
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="data/lang_bpe_2000/bpe.model",
+        help="Path to the BPE model",
+    )
+
+    parser.add_argument(
+        "--decoding-method",
+        type=str,
+        default="greedy_search",
+        help="""Possible values are:
+          - greedy_search
+          - beam_search
+          - modified_beam_search
+          - fast_beam_search
+        """,
+    )
+
+    parser.add_argument(
+        "--beam-size",
+        type=int,
+        default=4,
+        help="""An integer indicating how many candidates we will keep for each
+        frame. Used only when --decoding-method is beam_search or
+        modified_beam_search.""",
+    )
+
+    parser.add_argument(
+        "--beam",
+        type=float,
+        default=4,
+        help="""A floating point value to calculate the cutoff score during beam
+        search (i.e., `cutoff = max-score - beam`), which is the same as the
+        `beam` in Kaldi.
+        Used only when --decoding-method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--max-contexts",
+        type=int,
+        default=4,
+        help="""Used only when --decoding-method is
+        fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--max-states",
+        type=int,
+        default=8,
+        help="""Used only when --decoding-method is
+        fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+    )
+    parser.add_argument(
+        "--max-sym-per-frame",
+        type=int,
+        default=1,
+        help="""Maximum number of symbols per frame.
+        Used only when --decoding_method is greedy_search""",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def decode_one_batch(
+    params: AttributeDict,
+    model: nn.Module,
+    sp: spm.SentencePieceProcessor,
+    batch: dict,
+    decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[List[str]]]:
+    """Decode one batch and return the result in a dict. The dict has the
+    following format:
+
+        - key: It indicates the setting used for decoding. For example,
+               if greedy_search is used, it would be "greedy_search"
+               If beam search with a beam size of 7 is used, it would be
+               "beam_7"
+        - value: It contains the decoding result. `len(value)` equals to
+                 batch size. `value[i]` is the decoding result for the i-th
+                 utterance in the given batch.
+    Args:
+      params:
+        It's the return value of :func:`get_params`.
+      model:
+        The neural model.
+      sp:
+        The BPE model.
+      batch:
+        It is the return value from iterating
+        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+        for the format of the `batch`.
+      decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
+        only when --decoding_method is fast_beam_search.
+    Returns:
+      Return the decoding result. See above description for the format of
+      the returned dict.
+    """
+    device = next(model.parameters()).device
+    feature = batch["inputs"]
+    assert feature.ndim == 3
+
+    feature = feature.to(device)
+    # at entry, feature is (N, T, C)
+
+    supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)
+
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    hyps = []
+
+    if params.decoding_method == "fast_beam_search":
+        hyp_tokens = fast_beam_search_one_best(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+
+        hyp_tokens = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+        )
+
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.decoding_method == "modified_beam_search":
+        hyp_tokens = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam_size,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    else:
+        batch_size = encoder_out.size(0)
+
+        for i in range(batch_size):
+            # fmt: off
+            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+            # fmt: on
+            if params.decoding_method == "greedy_search":
+                hyp = greedy_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    max_sym_per_frame=params.max_sym_per_frame,
+                )
+            elif params.decoding_method == "beam_search":
+                hyp = beam_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
+                )
+            else:
+                raise ValueError(
+                    f"Unsupported decoding method: {params.decoding_method}"
+                )
+            hyps.append(sp.decode(hyp).split())
+
+    if params.decoding_method == "greedy_search":
+        return {"greedy_search": hyps}
+    elif params.decoding_method == "fast_beam_search":
+        return {
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}"
+            ): hyps
+        }
+    else:
+        return {f"beam_size_{params.beam_size}": hyps}
+
+
+def decode_dataset(
+    dl: torch.utils.data.DataLoader,
+    params: AttributeDict,
+    model: nn.Module,
+    sp: spm.SentencePieceProcessor,
+    decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[List[str], List[str]]]]:
+    """Decode dataset.
+
+    Args:
+      dl:
+        PyTorch's dataloader containing the dataset to decode.
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The neural model.
+      sp:
+        The BPE model.
+      decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
+        only when --decoding_method is fast_beam_search.
+    Returns:
+      Return a dict, whose key may be "greedy_search" if greedy search
+      is used, or it may be "beam_7" if beam size of 7 is used.
+      Its value is a list of tuples. Each tuple contains two elements:
+      The first is the reference transcript, and the second is the
+      predicted result.
+    """
+    num_cuts = 0
+
+    try:
+        num_batches = len(dl)
+    except TypeError:
+        num_batches = "?"
+
+    if params.decoding_method == "greedy_search":
+        log_interval = 50
+    else:
+        log_interval = 20
+
+    results = defaultdict(list)
+    for batch_idx, batch in enumerate(dl):
+        texts = batch["supervisions"]["text"]
+
+        hyps_dict = decode_one_batch(
+            params=params,
+            model=model,
+            sp=sp,
+            decoding_graph=decoding_graph,
+            batch=batch,
+        )
+
+        for name, hyps in hyps_dict.items():
+            this_batch = []
+            assert len(hyps) == len(texts)
+            for hyp_words, ref_text in zip(hyps, texts):
+
+                ref_words = ref_text.split()
+                this_batch.append((ref_words, hyp_words))
+
+            results[name].extend(this_batch)
+
+        num_cuts += len(texts)
+
+        if batch_idx % log_interval == 0:
+            batch_str = f"{batch_idx}/{num_batches}"
+
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+    return results
+
+
+def save_results(
+    params: AttributeDict,
+    test_set_name: str,
+    results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
+):
+    test_set_wers = dict()
+    for key, results in results_dict.items():
+        recog_path = (
+            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
+        store_transcripts(filename=recog_path, texts=results)
+        logging.info(f"The transcripts are stored in {recog_path}")
+
+        # The following prints out WERs, per-word error statistics and aligned
+        # ref/hyp pairs.
+        errs_filename = (
+            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
+        with open(errs_filename, "w") as f:
+            wer = write_error_stats(
+                f, f"{test_set_name}-{key}", results, enable_log=True
+            )
+            test_set_wers[key] = wer
+
+        logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+    errs_info = (
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+    )
+    with open(errs_info, "w") as f:
+        print("settings\tWER", file=f)
+        for key, val in test_set_wers:
+            print("{}\t{}".format(key, val), file=f)
+
+    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+    note = "\tbest for {}".format(test_set_name)
+    for key, val in test_set_wers:
+        s += "{}\t{}{}\n".format(key, val, note)
+        note = ""
+    logging.info(s)
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    MGB2AsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    assert params.decoding_method in (
+        "greedy_search",
+        "beam_search",
+        "fast_beam_search",
+        "modified_beam_search",
+    )
+    params.res_dir = params.exp_dir / params.decoding_method
+
+    if params.iter > 0:
+        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+    else:
+        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+    if "fast_beam_search" in params.decoding_method:
+        params.suffix += f"-beam-{params.beam}"
+        params.suffix += f"-max-contexts-{params.max_contexts}"
+        params.suffix += f"-max-states-{params.max_states}"
+    elif "beam_search" in params.decoding_method:
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+    else:
+        params.suffix += f"-context-{params.context_size}"
+        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+    if params.use_averaged_model:
+        params.suffix += "-use-averaged-model"
+
+    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+    logging.info("Decoding started")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    #  and  are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("")
+    params.unk_id = sp.piece_to_id("")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+
+    model.to(device)
+    model.eval()
+
+    if params.decoding_method == "fast_beam_search":
+        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+    else:
+        decoding_graph = None
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    MGB2 = MGB2AsrDataModule(args)
+
+    test_cuts = MGB2.test_cuts()
+    dev_cuts = MGB2.dev_cuts()
+
+    test_dl = MGB2.test_dataloaders(test_cuts)
+    dev_dl = MGB2.test_dataloaders(dev_cuts)
+
+    test_sets = ["test", "dev"]
+    test_all_dl = [test_dl, dev_dl]
+
+    for test_set, test_dl in zip(test_sets, test_all_dl):
+        results_dict = decode_dataset(
+            dl=test_dl,
+            params=params,
+            model=model,
+            sp=sp,
+            decoding_graph=decoding_graph,
+        )
+
+        save_results(
+            params=params,
+            test_set_name=test_set,
+            results_dict=results_dict,
+        )
+
+    logging.info("Done!")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/decoder.py b/egs/mgb2/ASR/pruned_transducer_stateless5/decoder.py
new file mode 120000
index 000000000..6775ee67e
--- /dev/null
+++ b/egs/mgb2/ASR/pruned_transducer_stateless5/decoder.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless5/decoder.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/encoder_interface.py b/egs/mgb2/ASR/pruned_transducer_stateless5/encoder_interface.py
new file mode 120000
index 000000000..972e44ca4
--- /dev/null
+++ b/egs/mgb2/ASR/pruned_transducer_stateless5/encoder_interface.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless5/encoder_interface.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/export.py b/egs/mgb2/ASR/pruned_transducer_stateless5/export.py
new file mode 100755
index 000000000..7a5d7f680
--- /dev/null
+++ b/egs/mgb2/ASR/pruned_transducer_stateless5/export.py
@@ -0,0 +1,272 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This script converts several saved checkpoints
+# to a single one using model averaging.
+"""
+Usage:
+./pruned_transducer_stateless5/export.py \
+  --exp-dir ./pruned_transducer_stateless5/exp \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --epoch 20 \
+  --avg 10
+
+It will generate a file exp_dir/pretrained.pt
+
+To use the generated file with `pruned_transducer_stateless5/decode.py`,
+you can do:
+
+    cd /path/to/exp_dir
+    ln -s pretrained.pt epoch-9999.pt
+
+    cd /path/to/egs/librispeech/ASR
+    ./pruned_transducer_stateless5/decode.py \
+        --exp-dir ./pruned_transducer_stateless5/exp \
+        --epoch 9999 \
+        --avg 1 \
+        --max-duration 600 \
+        --decoding-method greedy_search \
+        --bpe-model data/lang_bpe_500/bpe.model
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+import sentencepiece as spm
+import torch
+from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall.checkpoint import (
+    average_checkpoints,
+    average_checkpoints_with_averaged_model,
+    find_checkpoints,
+    load_checkpoint,
+)
+from icefall.utils import str2bool
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=28,
+        help="""It specifies the checkpoint to use for averaging.
+        Note: Epoch counts from 1.
+        You can specify --avg to use more checkpoints for model averaging.""",
+    )
+
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=15,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
+    )
+
+    parser.add_argument(
+        "--use-averaged-model",
+        type=str2bool,
+        default=False,
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="pruned_transducer_stateless5/exp",
+        help="""It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="data/lang_bpe_500/bpe.model",
+        help="Path to the BPE model",
+    )
+
+    parser.add_argument(
+        "--jit",
+        type=str2bool,
+        default=False,
+        help="""True to save a model after applying torch.jit.script.
+        """,
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def main():
+    args = get_parser().parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    assert args.jit is False, "Support torchscript will be added later"
+
+    params = get_params()
+    params.update(vars(args))
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    #  is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+
+    model.eval()
+
+    model.to("cpu")
+    model.eval()
+
+    if params.jit:
+        logging.info("Using torch.jit.script")
+        model = torch.jit.script(model)
+        filename = params.exp_dir / "cpu_jit.pt"
+        model.save(str(filename))
+        logging.info(f"Saved to {filename}")
+    else:
+        logging.info("Not using torch.jit.script")
+        # Save it using a format so that it can be loaded
+        # by :func:`load_checkpoint`
+        filename = params.exp_dir / "pretrained.pt"
+        torch.save({"model": model.state_dict()}, str(filename))
+        logging.info(f"Saved to {filename}")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/joiner.py b/egs/mgb2/ASR/pruned_transducer_stateless5/joiner.py
new file mode 120000
index 000000000..f5279e151
--- /dev/null
+++ b/egs/mgb2/ASR/pruned_transducer_stateless5/joiner.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless5/joiner.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/model.py b/egs/mgb2/ASR/pruned_transducer_stateless5/model.py
new file mode 120000
index 000000000..7b417fd89
--- /dev/null
+++ b/egs/mgb2/ASR/pruned_transducer_stateless5/model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless5/model.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/optim.py b/egs/mgb2/ASR/pruned_transducer_stateless5/optim.py
new file mode 120000
index 000000000..210374f22
--- /dev/null
+++ b/egs/mgb2/ASR/pruned_transducer_stateless5/optim.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless5/optim.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/pretrained.py b/egs/mgb2/ASR/pruned_transducer_stateless5/pretrained.py
new file mode 100755
index 000000000..77ba0873b
--- /dev/null
+++ b/egs/mgb2/ASR/pruned_transducer_stateless5/pretrained.py
@@ -0,0 +1,344 @@
+#!/usr/bin/env python3
+# Copyright      2021  Xiaomi Corp.        (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+(1) greedy search
+./pruned_transducer_stateless5/pretrained.py \
+    --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method greedy_search \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+
+(2) beam search
+./pruned_transducer_stateless5/pretrained.py \
+    --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method beam_search \
+    --beam-size 4 \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+
+(3) modified beam search
+./pruned_transducer_stateless5/pretrained.py \
+    --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method modified_beam_search \
+    --beam-size 4 \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+
+(4) fast beam search
+./pruned_transducer_stateless5/pretrained.py \
+    --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method fast_beam_search \
+    --beam-size 4 \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+
+You can also use `./pruned_transducer_stateless5/exp/epoch-xx.pt`.
+
+Note: ./pruned_transducer_stateless5/exp/pretrained.pt is generated by
+./pruned_transducer_stateless5/export.py
+"""
+
+
+import argparse
+import logging
+import math
+from typing import List
+
+import k2
+import kaldifeat
+import sentencepiece as spm
+import torch
+import torchaudio
+from beam_search import (
+    beam_search,
+    fast_beam_search_one_best,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+)
+from torch.nn.utils.rnn import pad_sequence
+from train import add_model_arguments, get_params, get_transducer_model
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--checkpoint",
+        type=str,
+        required=True,
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        help="""Path to bpe.model.""",
+    )
+
+    parser.add_argument(
+        "--method",
+        type=str,
+        default="greedy_search",
+        help="""Possible values are:
+          - greedy_search
+          - beam_search
+          - modified_beam_search
+          - fast_beam_search
+        """,
+    )
+
+    parser.add_argument(
+        "sound_files",
+        type=str,
+        nargs="+",
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
+    )
+
+    parser.add_argument(
+        "--sample-rate",
+        type=int,
+        default=16000,
+        help="The sample rate of the input sound file",
+    )
+
+    parser.add_argument(
+        "--beam-size",
+        type=int,
+        default=4,
+        help="""An integer indicating how many candidates we will keep for each
+        frame. Used only when --method is beam_search or
+        modified_beam_search.""",
+    )
+
+    parser.add_argument(
+        "--beam",
+        type=float,
+        default=4,
+        help="""A floating point value to calculate the cutoff score during beam
+        search (i.e., `cutoff = max-score - beam`), which is the same as the
+        `beam` in Kaldi.
+        Used only when --method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--max-contexts",
+        type=int,
+        default=4,
+        help="""Used only when --method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--max-states",
+        type=int,
+        default=8,
+        help="""Used only when --method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+    )
+    parser.add_argument(
+        "--max-sym-per-frame",
+        type=int,
+        default=1,
+        help="""Maximum number of symbols per frame. Used only when
+        --method is greedy_search.
+        """,
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def read_sound_files(
+    filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+    """Read a list of sound files into a list 1-D float32 torch tensors.
+    Args:
+      filenames:
+        A list of sound filenames.
+      expected_sample_rate:
+        The expected sample rate of the sound files.
+    Returns:
+      Return a list of 1-D float32 torch tensors.
+    """
+    ans = []
+    for f in filenames:
+        wave, sample_rate = torchaudio.load(f)
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
+        )
+        # We use only the first channel
+        ans.append(wave[0])
+    return ans
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+
+    params = get_params()
+
+    params.update(vars(args))
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    #  is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("")
+    params.unk_id = sp.piece_to_id("")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(f"{params}")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    logging.info("Creating model")
+    model = get_transducer_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    model.load_state_dict(checkpoint["model"], strict=False)
+    model.to(device)
+    model.eval()
+    model.device = device
+
+    logging.info("Constructing Fbank computer")
+    opts = kaldifeat.FbankOptions()
+    opts.device = device
+    opts.frame_opts.dither = 0
+    opts.frame_opts.snip_edges = False
+    opts.frame_opts.samp_freq = params.sample_rate
+    opts.mel_opts.num_bins = params.feature_dim
+
+    fbank = kaldifeat.Fbank(opts)
+
+    logging.info(f"Reading sound files: {params.sound_files}")
+    waves = read_sound_files(
+        filenames=params.sound_files, expected_sample_rate=params.sample_rate
+    )
+    waves = [w.to(device) for w in waves]
+
+    logging.info("Decoding started")
+    features = fbank(waves)
+    feature_lengths = [f.size(0) for f in features]
+
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+
+    feature_lengths = torch.tensor(feature_lengths, device=device)
+
+    encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
+
+    num_waves = encoder_out.size(0)
+    hyps = []
+    msg = f"Using {params.method}"
+    if params.method == "beam_search":
+        msg += f" with beam size {params.beam_size}"
+    logging.info(msg)
+
+    if params.method == "fast_beam_search":
+        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+        hyp_tokens = fast_beam_search_one_best(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.method == "modified_beam_search":
+        hyp_tokens = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam_size,
+        )
+
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
+        hyp_tokens = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    else:
+        for i in range(num_waves):
+            # fmt: off
+            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+            # fmt: on
+            if params.method == "greedy_search":
+                hyp = greedy_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    max_sym_per_frame=params.max_sym_per_frame,
+                )
+            elif params.method == "beam_search":
+                hyp = beam_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
+                )
+            else:
+                raise ValueError(f"Unsupported method: {params.method}")
+
+            hyps.append(sp.decode(hyp).split())
+
+    s = "\n"
+    for filename, hyp in zip(params.sound_files, hyps):
+        words = " ".join(hyp)
+        s += f"{filename}:\n{words}\n\n"
+    logging.info(s)
+
+    logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/scaling.py b/egs/mgb2/ASR/pruned_transducer_stateless5/scaling.py
new file mode 120000
index 000000000..ff7bfeda9
--- /dev/null
+++ b/egs/mgb2/ASR/pruned_transducer_stateless5/scaling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless5/scaling.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/test_model.py b/egs/mgb2/ASR/pruned_transducer_stateless5/test_model.py
new file mode 120000
index 000000000..b71d7bb81
--- /dev/null
+++ b/egs/mgb2/ASR/pruned_transducer_stateless5/test_model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless5/test_model.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/train.py b/egs/mgb2/ASR/pruned_transducer_stateless5/train.py
new file mode 100755
index 000000000..e1b623353
--- /dev/null
+++ b/egs/mgb2/ASR/pruned_transducer_stateless5/train.py
@@ -0,0 +1,1176 @@
+#!/usr/bin/env python3
+# Copyright    2022  Johns Hopkins        (authors: Amir Hussein)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./pruned_transducer_stateless5/train.py \
+  --world-size 2 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --exp-dir pruned_transducer_stateless5/exp \
+  --max-duration 200 \
+  --num-buckets 50
+
+# For mix precision training:
+
+./pruned_transducer_stateless5/train.py \
+  --world-size 2 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --use-fp16 1 \
+  --exp-dir pruned_transducer_stateless5/exp \
+  --max-duration 200	\
+  --num-buckets 50
+
+"""
+
+# xxx
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import nvidia_smi
+import optim
+import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import MGB2AsrDataModule
+from conformer import Conformer
+from decoder import Decoder
+from joiner import Joiner
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import Transducer
+from optim import Eden, Eve
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.nn.utils import clip_grad_norm_
+from torch.utils.tensorboard import SummaryWriter
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+    save_checkpoint_with_global_batch_idx,
+    update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+    parser.add_argument(
+        "--num-encoder-layers",
+        type=int,
+        default=12,
+        help="Number of conformer encoder layers..",
+    )
+
+    parser.add_argument(
+        "--dim-feedforward",
+        type=int,
+        default=2048,
+        help="Feedforward dimension of the conformer encoder layer.",
+    )
+
+    parser.add_argument(
+        "--nhead",
+        type=int,
+        default=8,
+        help="Number of attention heads in the conformer encoder layer.",
+    )
+
+    parser.add_argument(
+        "--encoder-dim",
+        type=int,
+        default=512,
+        help="Attention dimension in the conformer encoder layer.",
+    )
+
+    parser.add_argument(
+        "--decoder-dim",
+        type=int,
+        default=512,
+        help="Embedding dimension in the decoder model.",
+    )
+
+    parser.add_argument(
+        "--joiner-dim",
+        type=int,
+        default=512,
+        help="""Dimension used in the joiner model.
+        Outputs from the encoder and decoder model are projected
+        to this dimension before adding.
+        """,
+    )
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--world-size",
+        type=int,
+        default=1,
+        help="Number of GPUs for DDP training.",
+    )
+
+    parser.add_argument(
+        "--master-port",
+        type=int,
+        default=12354,
+        help="Master port to use for DDP training.",
+    )
+
+    parser.add_argument(
+        "--tensorboard",
+        type=str2bool,
+        default=True,
+        help="Should various information be logged in tensorboard.",
+    )
+
+    parser.add_argument(
+        "--num-epochs",
+        type=int,
+        default=30,
+        help="Number of epochs to train.",
+    )
+
+    parser.add_argument(
+        "--start-epoch",
+        type=int,
+        default=1,
+        help="""Resume training from this epoch. It should be positive.
+        If larger than 1, it will load checkpoint from
+        exp-dir/epoch-{start_epoch-1}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--start-batch",
+        type=int,
+        default=0,
+        help="""If positive, --start-epoch is ignored and
+        it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="pruned_transducer_stateless5/exp",
+        help="""The experiment dir.
+        It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="data/lang_bpe_2000/bpe.model",
+        help="Path to the BPE model",
+    )
+
+    parser.add_argument(
+        "--initial-lr",
+        type=float,
+        default=0.003,
+        help="The initial learning rate.  This value should not need " "to be changed.",
+    )
+
+    parser.add_argument(
+        "--lr-batches",
+        type=float,
+        default=5000,
+        help="""Number of steps that affects how rapidly the learning rate
+        decreases. We suggest not to change this.""",
+    )
+
+    parser.add_argument(
+        "--lr-epochs",
+        type=float,
+        default=6,
+        help="""Number of epochs that affects how rapidly the learning rate decreases.
+        """,
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+    )
+
+    parser.add_argument(
+        "--prune-range",
+        type=int,
+        default=5,
+        help="The prune range for rnnt loss, it means how many symbols(context)"
+        "we are using to compute the loss",
+    )
+
+    parser.add_argument(
+        "--lm-scale",
+        type=float,
+        default=0.25,
+        help="The scale to smooth the loss with lm "
+        "(output of prediction network) part.",
+    )
+
+    parser.add_argument(
+        "--am-scale",
+        type=float,
+        default=0.0,
+        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+    )
+
+    parser.add_argument(
+        "--simple-loss-scale",
+        type=float,
+        default=0.5,
+        help="To get pruning ranges, we will calculate a simple version"
+        "loss(joiner is just addition), this simple loss also uses for"
+        "training (as a regularization item). We will scale the simple loss"
+        "with this parameter before adding to the final loss.",
+    )
+
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="The seed for random generators intended for reproducibility",
+    )
+
+    parser.add_argument(
+        "--print-diagnostics",
+        type=str2bool,
+        default=False,
+        help="Accumulate stats on activations, print them and exit.",
+    )
+
+    parser.add_argument(
+        "--save-every-n",
+        type=int,
+        default=8000,
+        help="""Save checkpoint after processing this number of batches"
+        periodically. We save checkpoint to exp-dir/ whenever
+        params.batch_idx_train % save_every_n == 0. The checkpoint filename
+        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+        Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+        end of each epoch where `xxx` is the epoch number counting from 0.
+        """,
+    )
+
+    parser.add_argument(
+        "--keep-last-k",
+        type=int,
+        default=10,
+        help="""Only keep this number of checkpoints on disk.
+        For instance, if it is 3, there are only 3 checkpoints
+        in the exp-dir with filenames `checkpoint-xxx.pt`.
+        It does not affect checkpoints with name `epoch-xxx.pt`.
+        """,
+    )
+
+    parser.add_argument(
+        "--average-period",
+        type=int,
+        default=100,
+        help="""Update the averaged model, namely `model_avg`, after processing
+        this number of batches. `model_avg` is a separate version of model,
+        in which each floating-point parameter is the average of all the
+        parameters from the start of training. Each time we take the average,
+        we do: `model_avg = model * (average_period / batch_idx_train) +
+            model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+        """,
+    )
+
+    parser.add_argument(
+        "--use-fp16",
+        type=str2bool,
+        default=True,
+        help="Whether to use half precision training.",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def get_params() -> AttributeDict:
+    """Return a dict containing training parameters.
+
+    All training related parameters that are not passed from the commandline
+    are saved in the variable `params`.
+
+    Commandline options are merged into `params` after they are parsed, so
+    you can also access them via `params`.
+
+    Explanation of options saved in `params`:
+
+        - best_train_loss: Best training loss so far. It is used to select
+                           the model that has the lowest training loss. It is
+                           updated during the training.
+
+        - best_valid_loss: Best validation loss so far. It is used to select
+                           the model that has the lowest validation loss. It is
+                           updated during the training.
+
+        - best_train_epoch: It is the epoch that has the best training loss.
+
+        - best_valid_epoch: It is the epoch that has the best validation loss.
+
+        - batch_idx_train: Used to writing statistics to tensorboard. It
+                           contains number of batches trained so far across
+                           epochs.
+
+        - log_interval:  Print training loss if batch_idx % log_interval` is 0
+
+        - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+        - valid_interval:  Run validation if batch_idx % valid_interval is 0
+
+        - feature_dim: The model input dim. It has to match the one used
+                       in computing features.
+
+        - subsampling_factor:  The subsampling factor for the model.
+
+        - encoder_dim: Hidden dim for multi-head attention model.
+
+        - num_decoder_layers: Number of decoder layer of transformer decoder.
+
+        - warm_step: The warm_step for Noam optimizer.
+    """
+    params = AttributeDict(
+        {
+            "best_train_loss": float("inf"),
+            "best_valid_loss": float("inf"),
+            "best_train_epoch": -1,
+            "best_valid_epoch": -1,
+            "batch_idx_train": 0,
+            "log_interval": 50,
+            "reset_interval": 200,
+            "valid_interval": 3000,  # For the 100h subset, use 800
+            # parameters for conformer
+            "feature_dim": 80,
+            "subsampling_factor": 4,
+            # parameters for Noam
+            "model_warm_step": 80000,  # arg given to model, not for lrate
+            "env_info": get_env_info(),
+        }
+    )
+
+    return params
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+    # TODO: We can add an option to switch between Conformer and Transformer
+    encoder = Conformer(
+        num_features=params.feature_dim,
+        subsampling_factor=params.subsampling_factor,
+        d_model=params.encoder_dim,
+        nhead=params.nhead,
+        dim_feedforward=params.dim_feedforward,
+        num_encoder_layers=params.num_encoder_layers,
+    )
+    return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+    decoder = Decoder(
+        vocab_size=params.vocab_size,
+        decoder_dim=params.decoder_dim,
+        blank_id=params.blank_id,
+        context_size=params.context_size,
+    )
+    return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+    joiner = Joiner(
+        encoder_dim=params.encoder_dim,
+        decoder_dim=params.decoder_dim,
+        joiner_dim=params.joiner_dim,
+        vocab_size=params.vocab_size,
+    )
+    return joiner
+
+
+def get_transducer_model(params: AttributeDict) -> nn.Module:
+    encoder = get_encoder_model(params)
+    decoder = get_decoder_model(params)
+    joiner = get_joiner_model(params)
+
+    model = Transducer(
+        encoder=encoder,
+        decoder=decoder,
+        joiner=joiner,
+        encoder_dim=params.encoder_dim,
+        decoder_dim=params.decoder_dim,
+        joiner_dim=params.joiner_dim,
+        vocab_size=params.vocab_size,
+    )
+    return model
+
+
+def load_checkpoint_if_available(
+    params: AttributeDict,
+    model: nn.Module,
+    model_avg: nn.Module = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+    """Load checkpoint from file.
+
+    If params.start_batch is positive, it will load the checkpoint from
+    `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+    params.start_epoch is larger than 1, it will load the checkpoint from
+    `params.start_epoch - 1`.
+
+    Apart from loading state dict for `model` and `optimizer` it also updates
+    `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+    and `best_valid_loss` in `params`.
+
+    Args:
+      params:
+        The return value of :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer that we are using.
+      scheduler:
+        The scheduler that we are using.
+    Returns:
+      Return a dict containing previously saved training info.
+    """
+    if params.start_batch > 0:
+        filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+    elif params.start_epoch > 1:
+        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+    else:
+        return None
+
+    assert filename.is_file(), f"{filename} does not exist!"
+
+    saved_params = load_checkpoint(
+        filename,
+        model=model,
+        model_avg=model_avg,
+        optimizer=optimizer,
+        scheduler=scheduler,
+    )
+
+    keys = [
+        "best_train_epoch",
+        "best_valid_epoch",
+        "batch_idx_train",
+        "best_train_loss",
+        "best_valid_loss",
+    ]
+    for k in keys:
+        params[k] = saved_params[k]
+
+    if params.start_batch > 0:
+        if "cur_epoch" in saved_params:
+            params["start_epoch"] = saved_params["cur_epoch"]
+
+        if "cur_batch_idx" in saved_params:
+            params["cur_batch_idx"] = saved_params["cur_batch_idx"]
+
+    return saved_params
+
+
+def save_checkpoint(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    model_avg: Optional[nn.Module] = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+    sampler: Optional[CutSampler] = None,
+    scaler: Optional[GradScaler] = None,
+    rank: int = 0,
+) -> None:
+    """Save model, optimizer, scheduler and training stats to file.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer used in the training.
+      sampler:
+       The sampler for the training dataset.
+      scaler:
+        The scaler used for mix precision training.
+    """
+    if rank != 0:
+        return
+    filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+    save_checkpoint_impl(
+        filename=filename,
+        model=model,
+        model_avg=model_avg,
+        params=params,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        sampler=sampler,
+        scaler=scaler,
+        rank=rank,
+    )
+
+    if params.best_train_epoch == params.cur_epoch:
+        best_train_filename = params.exp_dir / "best-train-loss.pt"
+        copyfile(src=filename, dst=best_train_filename)
+
+    if params.best_valid_epoch == params.cur_epoch:
+        best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+        copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    sp: spm.SentencePieceProcessor,
+    batch: dict,
+    is_training: bool,
+    warmup: float = 1.0,
+    reduction="none",
+) -> Tuple[Tensor, MetricsTracker]:
+    """
+    Compute CTC loss given the model and its inputs.
+
+    Args:
+      params:
+        Parameters for training. See :func:`get_params`.
+      model:
+        The model for training. It is an instance of Conformer in our case.
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      is_training:
+        True for training. False for validation. When it is True, this
+        function enables autograd during computation; when it is False, it
+        disables autograd.
+     warmup: a floating point value which increases throughout training;
+        values >= 1.0 are fully warmed up and have all modules present.
+    """
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+    feature = batch["inputs"]
+    # at entry, feature is (N, T, C)
+    assert feature.ndim == 3
+    feature = feature.to(device)
+
+    supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)
+
+    texts = batch["supervisions"]["text"]
+    y = sp.encode(texts, out_type=int)
+    y = k2.RaggedTensor(y).to(device)
+
+    with torch.set_grad_enabled(is_training):
+        simple_loss, pruned_loss = model(
+            x=feature,
+            x_lens=feature_lens,
+            y=y,
+            prune_range=params.prune_range,
+            am_scale=params.am_scale,
+            lm_scale=params.lm_scale,
+            warmup=warmup,
+            reduction="none",
+        )
+        simple_loss_is_finite = torch.isfinite(simple_loss)
+        pruned_loss_is_finite = torch.isfinite(pruned_loss)
+        is_finite = simple_loss_is_finite & pruned_loss_is_finite
+        inf_flag = False
+        if not torch.all(is_finite):
+            inf_flag = True
+            logging.info(
+                "Not all losses are finite!\n"
+                f"simple_loss: {simple_loss}\n"
+                f"pruned_loss: {pruned_loss}"
+            )
+            display_and_save_batch(batch, params=params, sp=sp)
+            simple_loss = simple_loss[simple_loss_is_finite]
+            pruned_loss = pruned_loss[pruned_loss_is_finite]
+
+        simple_loss = simple_loss.sum()
+        pruned_loss = pruned_loss.sum()
+
+        # after the main warmup step, we keep pruned_loss_scale small
+        # for the same amount of time (model_warm_step), to avoid
+        # overwhelming the simple_loss and causing it to diverge,
+        # in case it had not fully learned the alignment yet.
+        pruned_loss_scale = (
+            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+        )
+        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+
+    assert loss.requires_grad == is_training
+
+    info = MetricsTracker()
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+
+    # # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances`  # noqa
+    # info["utterances"] = feature.size(0)
+    # # averaged input duration in frames over utterances
+    # info["utt_duration"] = feature_lens.sum().item()
+    # # averaged padding proportion over utterances
+    # info["utt_pad_proportion"] = (
+    #     ((feature.size(1) - feature_lens) / feature.size(1)).sum().item()
+    # )
+
+    # Note: We use reduction=sum while computing the loss.
+    info["loss"] = loss.detach().cpu().item()
+    info["simple_loss"] = simple_loss.detach().cpu().item()
+    info["pruned_loss"] = pruned_loss.detach().cpu().item()
+
+    return loss, info, inf_flag
+
+
+def compute_validation_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    sp: spm.SentencePieceProcessor,
+    valid_dl: torch.utils.data.DataLoader,
+    world_size: int = 1,
+) -> MetricsTracker:
+    """Run the validation process."""
+    model.eval()
+
+    tot_loss = MetricsTracker()
+    with torch.no_grad():
+        for batch_idx, batch in enumerate(valid_dl):
+            loss, loss_info, inf_flag = compute_loss(
+                params=params,
+                model=model,
+                sp=sp,
+                batch=batch,
+                is_training=False,
+            )
+            assert loss.requires_grad is False
+            tot_loss = tot_loss + loss_info
+
+        if world_size > 1:
+            tot_loss.reduce(loss.device)
+
+        loss_value = tot_loss["loss"] / tot_loss["frames"]
+        if loss_value < params.best_valid_loss:
+            params.best_valid_epoch = params.cur_epoch
+            params.best_valid_loss = loss_value
+
+    return tot_loss
+
+
+def train_one_epoch(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    optimizer: torch.optim.Optimizer,
+    scheduler: LRSchedulerType,
+    sp: spm.SentencePieceProcessor,
+    train_dl: torch.utils.data.DataLoader,
+    valid_dl: torch.utils.data.DataLoader,
+    scaler: GradScaler,
+    model_avg: Optional[nn.Module] = None,
+    tb_writer: Optional[SummaryWriter] = None,
+    world_size: int = 1,
+    rank: int = 0,
+) -> None:
+    """Train the model for one epoch.
+
+    The training loss from the mean of all frames is saved in
+    `params.train_loss`. It runs the validation process every
+    `params.valid_interval` batches.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The model for training.
+      optimizer:
+        The optimizer we are using.
+      scheduler:
+        The learning rate scheduler, we call step() every step.
+      train_dl:
+        Dataloader for the training dataset.
+      valid_dl:
+        Dataloader for the validation dataset.
+      scaler:
+        The scaler used for mix precision training.
+      model_avg:
+        The stored model averaged from the start of training.
+      tb_writer:
+        Writer to write log messages to tensorboard.
+      world_size:
+        Number of nodes in DDP training. If it is 1, DDP is disabled.
+      rank:
+        The rank of the node in DDP training. If no DDP is used, it should
+        be set to 0.
+    """
+    model.train()
+
+    tot_loss = MetricsTracker()
+
+    cur_batch_idx = params.get("cur_batch_idx", 0)
+
+    for batch_idx, batch in enumerate(train_dl):
+
+        if batch["inputs"].shape[0] == len(batch["supervisions"]["text"]):
+            if batch_idx < cur_batch_idx:
+                continue
+            cur_batch_idx = batch_idx
+
+            params.batch_idx_train += 1
+            batch_size = len(batch["supervisions"]["text"])
+
+            try:
+                with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                    loss, loss_info, inf_flag = compute_loss(
+                        params=params,
+                        model=model,
+                        sp=sp,
+                        batch=batch,
+                        is_training=True,
+                        warmup=(params.batch_idx_train / params.model_warm_step),
+                    )
+                # summary stats
+                tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+                # NOTE: We use reduction==sum and loss is computed over utterances
+                # in the batch and there is no normalization to it so far.
+                if not inf_flag:
+                    scaler.scale(loss).backward()
+                    scheduler.step_batch(params.batch_idx_train)
+                    scaler.step(optimizer)
+                    scaler.update()
+                    optimizer.zero_grad()
+                else:
+                    continue
+            except:  # noqa
+                display_and_save_batch(batch, params=params, sp=sp)
+                raise
+
+            if params.print_diagnostics and batch_idx == 5:
+                return
+
+            if (
+                rank == 0
+                and params.batch_idx_train > 0
+                and params.batch_idx_train % params.average_period == 0
+            ):
+                update_averaged_model(
+                    params=params,
+                    model_cur=model,
+                    model_avg=model_avg,
+                )
+
+            if (
+                params.batch_idx_train > 0
+                and params.batch_idx_train % params.save_every_n == 0
+            ):
+                params.cur_batch_idx = batch_idx
+                save_checkpoint_with_global_batch_idx(
+                    out_dir=params.exp_dir,
+                    global_batch_idx=params.batch_idx_train,
+                    model=model,
+                    model_avg=model_avg,
+                    params=params,
+                    optimizer=optimizer,
+                    scheduler=scheduler,
+                    sampler=train_dl.sampler,
+                    scaler=scaler,
+                    rank=rank,
+                )
+                del params.cur_batch_idx
+                remove_checkpoints(
+                    out_dir=params.exp_dir,
+                    topk=params.keep_last_k,
+                    rank=rank,
+                )
+
+            if batch_idx % params.log_interval == 0:
+                cur_lr = scheduler.get_last_lr()[0]
+                # https://silpara.medium.com/check-gpu-memory-usage-from-python-ccca503322ea
+                memory_debugging()
+                logging.info(
+                    f"Epoch {params.cur_epoch}, "
+                    f"batch {batch_idx}, loss[{loss_info}], "
+                    f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+                    f"lr: {cur_lr:.2e}"
+                )
+
+                if tb_writer is not None:
+                    tb_writer.add_scalar(
+                        "train/learning_rate", cur_lr, params.batch_idx_train
+                    )
+
+                    loss_info.write_summary(
+                        tb_writer, "train/current_", params.batch_idx_train
+                    )
+                    tot_loss.write_summary(
+                        tb_writer, "train/tot_", params.batch_idx_train
+                    )
+
+            if batch_idx > 0 and batch_idx % params.valid_interval == 0:
+                logging.info("Computing validation loss")
+                valid_info = compute_validation_loss(
+                    params=params,
+                    model=model,
+                    sp=sp,
+                    valid_dl=valid_dl,
+                    world_size=world_size,
+                )
+                model.train()
+                logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+                if tb_writer is not None:
+                    valid_info.write_summary(
+                        tb_writer, "train/valid_", params.batch_idx_train
+                    )
+        else:
+            logging.warning(
+                f"Batch {batch_idx} mismatch in dimentions between the input and the output. Skipping ..."
+            )
+            continue
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    params.train_loss = loss_value
+    if params.train_loss < params.best_train_loss:
+        params.best_train_epoch = params.cur_epoch
+        params.best_train_loss = params.train_loss
+
+
+def memory_debugging():
+    # memory nvidia debugging
+    nvidia_smi.nvmlInit()
+
+    deviceCount = nvidia_smi.nvmlDeviceGetCount()
+    for i in range(deviceCount):
+        handle = nvidia_smi.nvmlDeviceGetHandleByIndex(i)
+        info = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)
+        logging.info(
+            "Device {}: {}, Memory : ({:.2f}% free): {}(total), {} (free), {} (used)".format(
+                i,
+                nvidia_smi.nvmlDeviceGetName(handle),
+                100 * info.free / info.total,
+                info.total,
+                info.free,
+                info.used,
+            )
+        )
+
+    nvidia_smi.nvmlShutdown()
+
+
+def run(rank, world_size, args):
+    """
+    Args:
+      rank:
+        It is a value between 0 and `world_size-1`, which is
+        passed automatically by `mp.spawn()` in :func:`main`.
+        The node with rank 0 is responsible for saving checkpoint.
+      world_size:
+        Number of GPUs for DDP training.
+      args:
+        The return value of get_parser().parse_args()
+    """
+    params = get_params()
+    params.update(vars(args))
+
+    fix_random_seed(params.seed)
+    if world_size > 1:
+        setup_dist(rank, world_size, params.master_port)
+
+    setup_logger(f"{params.exp_dir}/log/log-train")
+    logging.info("Training started")
+
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+    #  is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    assert params.save_every_n >= params.average_period
+    model_avg: Optional[nn.Module] = None
+    if rank == 0:
+        # model_avg is only used with rank 0
+        model_avg = copy.deepcopy(model)
+
+    checkpoints = load_checkpoint_if_available(
+        params=params, model=model, model_avg=model_avg
+    )
+
+    model.to(device)
+    if world_size > 1:
+        logging.info("Using DDP")
+        model = DDP(model, device_ids=[rank])
+
+    optimizer = Eve(model.parameters(), lr=params.initial_lr)
+
+    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+    if checkpoints and "optimizer" in checkpoints:
+        logging.info("Loading optimizer state dict")
+        optimizer.load_state_dict(checkpoints["optimizer"])
+
+    if (
+        checkpoints
+        and "scheduler" in checkpoints
+        and checkpoints["scheduler"] is not None
+    ):
+        logging.info("Loading scheduler state dict")
+        scheduler.load_state_dict(checkpoints["scheduler"])
+
+    if params.print_diagnostics:
+        opts = diagnostics.TensorDiagnosticOptions(
+            2**22
+        )  # allow 4 megabytes per sub-module
+        diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+    MGB2 = MGB2AsrDataModule(args)
+    train_cuts = MGB2.train_cuts()
+
+    def remove_short_and_long_utt(c: Cut):
+        # Keep only utterances with duration between 1 second and 30 seconds
+        #
+        # Caution: There is a reason to select 20.0 here. Please see
+        # ../local/display_manifest_statistics.py
+        #
+        # You should use ../local/display_manifest_statistics.py to get
+        # an utterance duration distribution for your dataset to select
+        # the threshold
+        return 0.5 <= c.duration <= 30.0
+
+    def remove_short_and_long_text(c: Cut):
+        # Keep only text with charachters between 20 and 450
+
+        return 20 <= len(c.supervisions[0].text) <= 450
+
+    train_cuts = train_cuts.filter(remove_short_and_long_utt)
+    train_cuts = train_cuts.filter(remove_short_and_long_text)
+
+    if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+        # We only load the sampler's state dict when it loads a checkpoint
+        # saved in the middle of an epoch
+        sampler_state_dict = checkpoints["sampler"]
+    else:
+        sampler_state_dict = None
+
+    train_dl = MGB2.train_dataloaders(train_cuts, sampler_state_dict=sampler_state_dict)
+
+    valid_cuts = MGB2.dev_cuts()
+    valid_dl = MGB2.test_dataloaders(valid_cuts)
+
+    if not params.print_diagnostics:
+        scan_pessimistic_batches_for_oom(
+            model=model,
+            train_dl=train_dl,
+            optimizer=optimizer,
+            sp=sp,
+            params=params,
+        )
+
+    scaler = GradScaler(enabled=params.use_fp16)
+    if checkpoints and "grad_scaler" in checkpoints:
+        logging.info("Loading grad scaler state dict")
+        scaler.load_state_dict(checkpoints["grad_scaler"])
+
+    for epoch in range(params.start_epoch, params.num_epochs + 1):
+
+        scheduler.step_epoch(epoch - 1)
+        fix_random_seed(params.seed + epoch - 1)
+        train_dl.sampler.set_epoch(epoch - 1)
+
+        if tb_writer is not None:
+            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+        params.cur_epoch = epoch
+
+        train_one_epoch(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            sp=sp,
+            train_dl=train_dl,
+            valid_dl=valid_dl,
+            scaler=scaler,
+            tb_writer=tb_writer,
+            world_size=world_size,
+            rank=rank,
+        )
+
+        if params.print_diagnostics:
+            diagnostic.print_diagnostics()
+            break
+
+        save_checkpoint(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            sampler=train_dl.sampler,
+            scaler=scaler,
+            rank=rank,
+        )
+
+    logging.info("Done!")
+
+    if world_size > 1:
+        torch.distributed.barrier()
+        cleanup_dist()
+
+
+def display_and_save_batch(
+    batch: dict,
+    params: AttributeDict,
+    sp: spm.SentencePieceProcessor,
+) -> None:
+    """Display the batch statistics and save the batch into disk.
+
+    Args:
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      params:
+        Parameters for training. See :func:`get_params`.
+      sp:
+        The BPE model.
+    """
+    from lhotse.utils import uuid4
+
+    filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+    logging.info(f"Saving batch to {filename}")
+    torch.save(batch, filename)
+
+    supervisions = batch["supervisions"]
+    features = batch["inputs"]
+
+    logging.info(f"features shape: {features.shape}")
+
+    y = sp.encode(supervisions["text"], out_type=int)
+    num_tokens = sum(len(i) for i in y)
+    logging.info(f"num tokens: {num_tokens}")
+
+
+def scan_pessimistic_batches_for_oom(
+    model: Union[nn.Module, DDP],
+    train_dl: torch.utils.data.DataLoader,
+    optimizer: torch.optim.Optimizer,
+    sp: spm.SentencePieceProcessor,
+    params: AttributeDict,
+):
+    from lhotse.dataset import find_pessimistic_batches
+
+    logging.info(
+        "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+    )
+    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+    for criterion, cuts in batches.items():
+        batch = train_dl.dataset[cuts]
+        try:
+            # warmup = 0.0 is so that the derivs for the pruned loss stay zero
+            # (i.e. are not remembered by the decaying-average in adam), because
+            # we want to avoid these params being subject to shrinkage in adam.
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+
+                loss, _, _ = compute_loss(
+                    params=params,
+                    model=model,
+                    sp=sp,
+                    batch=batch,
+                    is_training=True,
+                    warmup=0.0,
+                )
+            loss.backward()
+            # clip_grad_norm_(model.parameters(), 5.0, 2.0)
+            optimizer.step()
+            optimizer.zero_grad()
+        except Exception as e:
+            if "CUDA out of memory" in str(e):
+                logging.error(
+                    "Your GPU ran out of memory with the current "
+                    "max_duration setting. We recommend decreasing "
+                    "max_duration and trying again.\n"
+                    f"Failing criterion: {criterion} "
+                    f"(={crit_values[criterion]}) ..."
+                )
+            display_and_save_batch(batch, params=params, sp=sp)
+            raise
+
+
+def main():
+    parser = get_parser()
+    MGB2AsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    world_size = args.world_size
+    assert world_size >= 1
+    if world_size > 1:
+        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+    else:
+        run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/mgb2/ASR/shared b/egs/mgb2/ASR/shared
new file mode 120000
index 000000000..4c5e91438
--- /dev/null
+++ b/egs/mgb2/ASR/shared
@@ -0,0 +1 @@
+../../../icefall/shared/
\ No newline at end of file
diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py
index 207c12bf1..6589579d1 100644
--- a/icefall/diagnostics.py
+++ b/icefall/diagnostics.py
@@ -263,7 +263,7 @@ class TensorDiagnostic(object):
                     ans += f", norm={norm:.2g}"
                 mean = stats.mean().item()
                 rms = (stats**2).mean().sqrt().item()
-                ans += f", mean={mean:.3g}, rms={rms:.3g}"
+                ans += f", mean={mean:.2g}, rms={rms:.2g}"
 
                 # OK, "ans" contains the actual stats, e.g.
                 # ans = "percentiles: [0.43 0.46 0.48 0.49 0.49 0.5 0.51 0.52 0.53 0.54 0.59], mean=0.5, rms=0.5"

From 7700ddcb38b5ba0d91334947e3cac44825f1cf7c Mon Sep 17 00:00:00 2001
From: Weiji Zhuang 
Date: Fri, 2 Dec 2022 17:40:42 +0800
Subject: [PATCH 031/174] update multidataset zipformer results (#728)

---
 egs/librispeech/ASR/RESULTS.md | 26 +++++++++++++++-----------
 1 file changed, 15 insertions(+), 11 deletions(-)

diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md
index c2ea3d050..0885fb9b6 100644
--- a/egs/librispeech/ASR/RESULTS.md
+++ b/egs/librispeech/ASR/RESULTS.md
@@ -108,21 +108,25 @@ See  for more details.
 [pruned_transducer_stateless8](./pruned_transducer_stateless8)
 
 The tensorboard log can be found at
-
+
 
 You can find a pretrained model, training logs, decoding logs, and decoding
 results at:
-
+
 
 You can use  to deploy it.
 
 Number of model parameters: 70369391, i.e., 70.37 M
 
-|                      | test-clean | test-other  | comment                                |
-|----------------------|------------|-------------|----------------------------------------|
-| greedy search        | 1.87       | 4.38        | --epoch 16 --avg 2 --max-duration 600  |
-| modified beam search | 1.81       | 4.34        | --epoch 16 --avg 2 --max-duration 600  |
-| fast beam search     | 1.91       | 4.33        | --epoch 16 --avg 2 --max-duration 600  |
+| decoding method      | test-clean | test-other | comment            |
+|----------------------|------------|------------|--------------------|
+| greedy_search        | 1.81       | 4.18       | --epoch 20 --avg 4 |
+| fast_beam_search     | 1.82       | 4.15       | --epoch 20 --avg 4 |
+| modified_beam_search | 1.78       | **4.08**   | --epoch 20 --avg 4 |
+| greedy_search        | 1.84       | 4.3        | --epoch 19 --avg 8 |
+| fast_beam_search     |**1.77**    | 4.25       | --epoch 19 --avg 8 |
+| modified_beam_search | 1.81       | 4.16       | --epoch 19 --avg 8 |
+
 
 The training commands are:
 ```bash
@@ -142,15 +146,15 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
 
 The decoding commands are:
 ```bash
-for m in greedy_search fast_beam_search modified_beam_search ; do
-  for epoch in 16; do
-    for avg in 2; do
+for m in greedy_search fast_beam_search modified_beam_search; do
+  for epoch in $(seq 20 -1 10); do
+    for avg in $(seq 9 -1 1); do
       ./pruned_transducer_stateless8/decode.py \
           --epoch $epoch \
           --avg $avg \
           --use-averaged-model 1 \
           --exp-dir ./pruned_transducer_stateless8/exp \
-          --feedforward-dims  "1024,1024,2048,2048,1024" \
+          --feedforward-dims "1024,1024,2048,2048,1024" \
           --max-duration 600 \
           --decoding-method $m
     done

From 8eb4b9d96da0432c1c27901f2964da954583d69a Mon Sep 17 00:00:00 2001
From: Zengwei Yao 
Date: Sat, 3 Dec 2022 19:01:10 +0800
Subject: [PATCH 032/174] Combining rnnt loss and k2-ctc loss for Dan's
 Zipformer (#683)

* init files

* add ctc as auxiliary loss and ctc_decode.py

* tuning the scalar of HLG score for 1best, nbest and nbest-oracle

* rename to pruned_transducer_stateless7_ctc

* fix doc

* fix bug, recover the hlg scores

* modify ctc_decode.py, move out the hlg scale

* fix hlg_scale

* add export.py and pretrained.py, and so on

* upload files, update README.md and RESULTS.md

* add CI test
---
 ...ed-transducer-stateless7-ctc-2022-12-01.sh |  147 ++
 ...-librispeech-2022-12-01-stateless7-ctc.yml |  163 +++
 egs/librispeech/ASR/README.md                 |    1 +
 egs/librispeech/ASR/RESULTS.md                |   79 ++
 .../ASR/conformer_ctc3/jit_pretrained.py      |   20 +-
 .../__init__.py                               |    0
 .../asr_datamodule.py                         |    1 +
 .../beam_search.py                            |    1 +
 .../ctc_decode.py                             |  818 +++++++++++
 .../decode.py                                 |  841 +++++++++++
 .../decoder.py                                |    1 +
 .../encoder_interface.py                      |    1 +
 .../export.py                                 |  320 +++++
 .../jit_pretrained.py                         |  271 ++++
 .../jit_pretrained_ctc.py                     |  423 ++++++
 .../joiner.py                                 |    1 +
 .../pruned_transducer_stateless7_ctc/model.py |  198 +++
 .../pruned_transducer_stateless7_ctc/optim.py |    1 +
 .../pretrained.py                             |  353 +++++
 .../pretrained_ctc.py                         |  441 ++++++
 .../scaling.py                                |    1 +
 .../scaling_converter.py                      |    1 +
 .../test_model.py                             |   56 +
 .../pruned_transducer_stateless7_ctc/train.py | 1252 +++++++++++++++++
 .../zipformer.py                              |    1 +
 icefall/utils.py                              |   18 +-
 26 files changed, 5396 insertions(+), 14 deletions(-)
 create mode 100755 .github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh
 create mode 100644 .github/workflows/run-librispeech-2022-12-01-stateless7-ctc.yml
 create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/__init__.py
 create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/asr_datamodule.py
 create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/beam_search.py
 create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py
 create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decode.py
 create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decoder.py
 create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/encoder_interface.py
 create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/export.py
 create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained.py
 create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py
 create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/joiner.py
 create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/model.py
 create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/optim.py
 create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py
 create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py
 create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/scaling.py
 create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/scaling_converter.py
 create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/test_model.py
 create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py
 create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/zipformer.py

diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh
new file mode 100755
index 000000000..6642d5f67
--- /dev/null
+++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh
@@ -0,0 +1,147 @@
+#!/usr/bin/env bash
+
+set -e
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+cd egs/librispeech/ASR
+
+repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-ctc-2022-12-01
+
+log "Downloading pre-trained model from $repo_url"
+git lfs install
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+repo=$(basename $repo_url)
+
+log "Display test files"
+tree $repo/
+soxi $repo/test_wavs/*.wav
+ls -lh $repo/test_wavs/*.wav
+
+pushd $repo/exp
+git lfs pull --include "data/*"
+git lfs pull --include "exp/cpu_jit.pt"
+git lfs pull --include "exp/pretrained.pt"
+ln -s pretrained.pt epoch-99.pt
+ls -lh *.pt
+popd
+
+log "Export to torchscript model"
+./pruned_transducer_stateless7_ctc/export.py \
+  --exp-dir $repo/exp \
+  --use-averaged-model false \
+  --bpe-model $repo/data/lang_bpe_500/bpe.model \
+  --epoch 99 \
+  --avg 1 \
+  --jit 1
+
+ls -lh $repo/exp/*.pt
+
+log "Decode with models exported by torch.jit.script()"
+
+./pruned_transducer_stateless7_ctc/jit_pretrained.py \
+  --bpe-model $repo/data/lang_bpe_500/bpe.model \
+  --nn-model-filename $repo/exp/cpu_jit.pt \
+  $repo/test_wavs/1089-134686-0001.wav \
+  $repo/test_wavs/1221-135766-0001.wav \
+  $repo/test_wavs/1221-135766-0002.wav
+
+for m in ctc-decoding 1best; do
+  ./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \
+    --model-filename $repo/exp/cpu_jit.pt \
+    --words-file $repo/data/lang_bpe_500/words.txt  \
+    --HLG $repo/data/lang_bpe_500/HLG.pt \
+    --bpe-model $repo/data/lang_bpe_500/bpe.model \
+    --G $repo/data/lm/G_4_gram.pt \
+    --method $m \
+    --sample-rate 16000 \
+    $repo/test_wavs/1089-134686-0001.wav \
+    $repo/test_wavs/1221-135766-0001.wav \
+    $repo/test_wavs/1221-135766-0002.wav
+done
+
+for sym in 1 2 3; do
+  log "Greedy search with --max-sym-per-frame $sym"
+
+  ./pruned_transducer_stateless7_ctc/pretrained.py \
+    --method greedy_search \
+    --max-sym-per-frame $sym \
+    --checkpoint $repo/exp/pretrained.pt \
+    --bpe-model $repo/data/lang_bpe_500/bpe.model \
+    $repo/test_wavs/1089-134686-0001.wav \
+    $repo/test_wavs/1221-135766-0001.wav \
+    $repo/test_wavs/1221-135766-0002.wav
+done
+
+for method in modified_beam_search beam_search fast_beam_search; do
+  log "$method"
+
+  ./pruned_transducer_stateless7_ctc/pretrained.py \
+    --method $method \
+    --beam-size 4 \
+    --checkpoint $repo/exp/pretrained.pt \
+    --bpe-model $repo/data/lang_bpe_500/bpe.model \
+    $repo/test_wavs/1089-134686-0001.wav \
+    $repo/test_wavs/1221-135766-0001.wav \
+    $repo/test_wavs/1221-135766-0002.wav
+done
+
+for m in ctc-decoding 1best; do
+  ./pruned_transducer_stateless7_ctc/pretrained_ctc.py \
+    --checkpoint $repo/exp/pretrained.pt \
+    --words-file $repo/data/lang_bpe_500/words.txt  \
+    --HLG $repo/data/lang_bpe_500/HLG.pt \
+    --bpe-model $repo/data/lang_bpe_500/bpe.model \
+    --G $repo/data/lm/G_4_gram.pt \
+    --method $m \
+    --sample-rate 16000 \
+    $repo/test_wavs/1089-134686-0001.wav \
+    $repo/test_wavs/1221-135766-0001.wav \
+    $repo/test_wavs/1221-135766-0002.wav
+done
+
+echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
+echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
+if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode"  ]]; then
+  mkdir -p pruned_transducer_stateless7_ctc/exp
+  ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless7_ctc/exp/epoch-999.pt
+  ln -s $PWD/$repo/data/lang_bpe_500 data/
+
+  ls -lh data
+  ls -lh pruned_transducer_stateless7_ctc/exp
+
+  log "Decoding test-clean and test-other"
+
+  # use a small value for decoding with CPU
+  max_duration=100
+
+  for method in greedy_search fast_beam_search modified_beam_search; do
+    log "Decoding with $method"
+
+    ./pruned_transducer_stateless7_ctc/decode.py \
+      --decoding-method $method \
+      --epoch 999 \
+      --avg 1 \
+      --use-averaged-model 0 \
+      --max-duration $max_duration \
+      --exp-dir pruned_transducer_stateless7_ctc/exp
+  done
+
+  for m in ctc-decoding 1best; do
+    ./pruned_transducer_stateless7_ctc/ctc_decode.py \
+        --epoch 999 \
+        --avg 1 \
+        --exp-dir ./pruned_transducer_stateless7_ctc/exp \
+        --max-duration $max_duration \
+        --use-averaged-model 0 \
+        --decoding-method $m \
+        --hlg-scale 0.6 \
+        --lm-dir data/lm
+  done
+
+  rm pruned_transducer_stateless7_ctc/exp/*.pt
+fi
diff --git a/.github/workflows/run-librispeech-2022-12-01-stateless7-ctc.yml b/.github/workflows/run-librispeech-2022-12-01-stateless7-ctc.yml
new file mode 100644
index 000000000..ccd8d50d0
--- /dev/null
+++ b/.github/workflows/run-librispeech-2022-12-01-stateless7-ctc.yml
@@ -0,0 +1,163 @@
+# Copyright      2022  Fangjun Kuang (csukuangfj@gmail.com)
+
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+name: run-librispeech-2022-12-01-stateless7-ctc
+# zipformer
+
+on:
+  push:
+    branches:
+      - master
+  pull_request:
+    types: [labeled]
+
+  schedule:
+    # minute (0-59)
+    # hour (0-23)
+    # day of the month (1-31)
+    # month (1-12)
+    # day of the week (0-6)
+    # nightly build at 15:50 UTC time every day
+    - cron: "50 15 * * *"
+
+jobs:
+  run_librispeech_2022_11_11_zipformer:
+    if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
+    runs-on: ${{ matrix.os }}
+    strategy:
+      matrix:
+        os: [ubuntu-latest]
+        python-version: [3.8]
+
+      fail-fast: false
+
+    steps:
+      - uses: actions/checkout@v2
+        with:
+          fetch-depth: 0
+
+      - name: Setup Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v2
+        with:
+          python-version: ${{ matrix.python-version }}
+          cache: 'pip'
+          cache-dependency-path: '**/requirements-ci.txt'
+
+      - name: Install Python dependencies
+        run: |
+          grep -v '^#' ./requirements-ci.txt  | xargs -n 1 -L 1 pip install
+          pip uninstall -y protobuf
+          pip install --no-binary protobuf protobuf
+
+      - name: Cache kaldifeat
+        id: my-cache
+        uses: actions/cache@v2
+        with:
+          path: |
+            ~/tmp/kaldifeat
+          key: cache-tmp-${{ matrix.python-version }}-2022-09-25
+
+      - name: Install kaldifeat
+        if: steps.my-cache.outputs.cache-hit != 'true'
+        shell: bash
+        run: |
+          .github/scripts/install-kaldifeat.sh
+
+      - name: Cache LibriSpeech test-clean and test-other datasets
+        id: libri-test-clean-and-test-other-data
+        uses: actions/cache@v2
+        with:
+          path: |
+            ~/tmp/download
+          key: cache-libri-test-clean-and-test-other
+
+      - name: Download LibriSpeech test-clean and test-other
+        if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
+        shell: bash
+        run: |
+          .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
+
+      - name: Prepare manifests for LibriSpeech test-clean and test-other
+        shell: bash
+        run: |
+          .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
+
+      - name: Cache LibriSpeech test-clean and test-other fbank features
+        id: libri-test-clean-and-test-other-fbank
+        uses: actions/cache@v2
+        with:
+          path: |
+            ~/tmp/fbank-libri
+          key: cache-libri-fbank-test-clean-and-test-other-v2
+
+      - name: Compute fbank for LibriSpeech test-clean and test-other
+        if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
+        shell: bash
+        run: |
+          .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
+
+      - name: Inference with pre-trained model
+        shell: bash
+        env:
+          GITHUB_EVENT_NAME: ${{ github.event_name }}
+          GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
+        run: |
+          mkdir -p egs/librispeech/ASR/data
+          ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
+          ls -lh egs/librispeech/ASR/data/*
+
+          sudo apt-get -qq install git-lfs tree sox
+          export PYTHONPATH=$PWD:$PYTHONPATH
+          export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
+          export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
+
+          .github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh
+
+      - name: Display decoding results for librispeech pruned_transducer_stateless7_ctc
+        if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+        shell: bash
+        run: |
+          cd egs/librispeech/ASR/
+          tree ./pruned_transducer_stateless7_ctc/exp
+
+          cd pruned_transducer_stateless7_ctc
+          echo "results for pruned_transducer_stateless7_ctc"
+          echo "===greedy search==="
+          find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+          find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+          echo "===fast_beam_search==="
+          find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+          find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+          echo "===modified beam search==="
+          find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+          find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+          echo "===ctc decoding==="
+          find exp/ctc-decoding -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+          find exp/ctc-decoding -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+          echo "===1best==="
+          find exp/1best -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+          find exp/1best -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+      - name: Upload decoding results for librispeech pruned_transducer_stateless7_ctc
+        uses: actions/upload-artifact@v2
+        if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+        with:
+          name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless7-ctc-2022-12-01
+          path: egs/librispeech/ASR/pruned_transducer_stateless7_ctc/exp/
diff --git a/egs/librispeech/ASR/README.md b/egs/librispeech/ASR/README.md
index e737d68bd..caa23a49f 100644
--- a/egs/librispeech/ASR/README.md
+++ b/egs/librispeech/ASR/README.md
@@ -23,6 +23,7 @@ The following table lists the differences among them.
 | `pruned_transducer_stateless5`        | Conformer(modified) | Embedding + Conv1d | same as pruned_transducer_stateless4 + more layers + random combiner|
 | `pruned_transducer_stateless6`        | Conformer(modified) | Embedding + Conv1d | same as pruned_transducer_stateless4 + distillation with hubert|
 | `pruned_transducer_stateless7`        | Zipformer | Embedding + Conv1d | First experiment with Zipformer from Dan|
+| `pruned_transducer_stateless7_ctc`    | Zipformer | Embedding + Conv1d | Same as pruned_transducer_stateless7, but with extra CTC head|
 | `pruned_transducer_stateless8`        | Zipformer | Embedding + Conv1d | Same as pruned_transducer_stateless7, but using extra data from GigaSpeech|
 | `pruned_stateless_emformer_rnnt2`     | Emformer(from torchaudio) | Embedding + Conv1d | Using Emformer from torchaudio for streaming ASR|
 | `conv_emformer_transducer_stateless`  | ConvEmformer | Embedding + Conv1d | Using ConvEmformer for streaming ASR + mechanisms in reworked model |
diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md
index 0885fb9b6..9e5669f6d 100644
--- a/egs/librispeech/ASR/RESULTS.md
+++ b/egs/librispeech/ASR/RESULTS.md
@@ -1,5 +1,84 @@
 ## Results
 
+### pruned_transducer_stateless7_ctc (zipformer with transducer loss and ctc loss)
+
+See  for more details.
+
+[pruned_transducer_stateless7_ctc](./pruned_transducer_stateless7_ctc)
+
+The tensorboard log can be found at
+
+
+You can find a pretrained model, training logs, decoding logs, and decoding
+results at:
+
+
+Number of model parameters: 70561891, i.e., 70.56 M
+
+|                          | test-clean | test-other  | comment            |
+|--------------------------|------------|-------------|--------------------|
+| greedy search            | 2.23       | 5.19        | --epoch 30 --avg 8 |
+| modified beam search     | 2.21       | 5.12        | --epoch 30 --avg 8 |
+| fast beam search         | 2.23       | 5.18        | --epoch 30 --avg 8 |
+| ctc decoding             | 2.48       | 5.82        | --epoch 30 --avg 9 |
+| 1best                    | 2.43       | 5.22        | --epoch 30 --avg 9 |
+| nbest                    | 2.43       | 5.22        | --epoch 30 --avg 9 |
+| nbest rescoring          | 2.34       | 5.05        | --epoch 30 --avg 9 |
+| whole lattice rescoring  | 2.34       | 5.04        | --epoch 30 --avg 9 |
+
+The training commands are:
+```bash
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./pruned_transducer_stateless7_ctc/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --full-libri 1 \
+  --use-fp16 1 \
+  --max-duration 750 \
+  --exp-dir pruned_transducer_stateless7_ctc/exp \
+  --feedforward-dims  "1024,1024,2048,2048,1024" \
+  --ctc-loss-scale 0.2 \
+  --master-port 12535
+```
+
+The decoding commands for the transducer branch are:
+```bash
+for m in greedy_search fast_beam_search modified_beam_search ; do
+  for epoch in 30; do
+    for avg in 8; do
+      ./pruned_transducer_stateless7_ctc/decode.py \
+          --epoch $epoch \
+          --avg $avg \
+          --use-averaged-model 1 \
+          --exp-dir ./pruned_transducer_stateless7_ctc/exp \
+          --feedforward-dims  "1024,1024,2048,2048,1024" \
+          --max-duration 600 \
+          --decoding-method $m
+    done
+  done
+done
+```
+
+The decoding commands for the ctc branch are:
+```bash
+for m in ctc-decoding nbest nbest-rescoring whole-lattice-rescoring; do
+  for epoch in 30; do
+    for avg in 9; do
+      ./pruned_transducer_stateless7_ctc/ctc_decode.py \
+          --epoch $epoch \
+          --avg $avg \
+          --exp-dir ./pruned_transducer_stateless7_ctc/exp \
+          --max-duration 100 \
+          --decoding-method $m \
+          --hlg-scale 0.6 \
+          --lm-dir data/lm
+    done
+  done
+done
+```
+
+
 ### LibriSpeech BPE training results (Conformer CTC, supporting delay penalty)
 
 #### [conformer_ctc3](./conformer_ctc3)
diff --git a/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py b/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py
index c96defd23..5be898e37 100755
--- a/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py
+++ b/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py
@@ -23,40 +23,44 @@ Usage (for non-streaming mode):
 
 (1) ctc-decoding
 ./conformer_ctc3/pretrained.py \
-  --checkpoint conformer_ctc3/exp/pretrained.pt \
+  --nn-model-filename ./conformer_ctc3/exp/cpu_jit.pt \
   --bpe-model data/lang_bpe_500/bpe.model \
   --method ctc-decoding \
   --sample-rate 16000 \
-  test_wavs/1089-134686-0001.wav
+  /path/to/foo.wav \
+  /path/to/bar.wav
 
 (2) 1best
 ./conformer_ctc3/pretrained.py \
-  --checkpoint conformer_ctc3/exp/pretrained.pt \
+  --nn-model-filename ./conformer_ctc3/exp/cpu_jit.pt \
   --HLG data/lang_bpe_500/HLG.pt \
   --words-file data/lang_bpe_500/words.txt  \
   --method 1best \
   --sample-rate 16000 \
-  test_wavs/1089-134686-0001.wav
+  /path/to/foo.wav \
+  /path/to/bar.wav
 
 (3) nbest-rescoring
 ./conformer_ctc3/pretrained.py \
-  --checkpoint conformer_ctc3/exp/pretrained.pt \
+  --nn-model-filename ./conformer_ctc3/exp/cpu_jit.pt \
   --HLG data/lang_bpe_500/HLG.pt \
   --words-file data/lang_bpe_500/words.txt  \
   --G data/lm/G_4_gram.pt \
   --method nbest-rescoring \
   --sample-rate 16000 \
-  test_wavs/1089-134686-0001.wav
+  /path/to/foo.wav \
+  /path/to/bar.wav
 
 (4) whole-lattice-rescoring
 ./conformer_ctc3/pretrained.py \
-  --checkpoint conformer_ctc3/exp/pretrained.pt \
+  --nn-model-filename ./conformer_ctc3/exp/cpu_jit.pt \
   --HLG data/lang_bpe_500/HLG.pt \
   --words-file data/lang_bpe_500/words.txt  \
   --G data/lm/G_4_gram.pt \
   --method whole-lattice-rescoring \
   --sample-rate 16000 \
-  test_wavs/1089-134686-0001.wav
+  /path/to/foo.wav \
+  /path/to/bar.wav
 """
 
 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/__init__.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/asr_datamodule.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/asr_datamodule.py
new file mode 120000
index 000000000..a074d6085
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/asr_datamodule.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless2/asr_datamodule.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/beam_search.py
new file mode 120000
index 000000000..8554e44cc
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/beam_search.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless2/beam_search.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py
new file mode 100755
index 000000000..9c23e7d66
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py
@@ -0,0 +1,818 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
+#                                                 Liyong Guo,
+#                                                 Quandong Wang,
+#                                                 Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) ctc-decoding
+./pruned_transducer_stateless7_ctc/ctc_decode.py \
+    --epoch 30 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless7_ctc/exp \
+    --max-duration 600 \
+    --decoding-method ctc-decoding
+
+(2) 1best
+./pruned_transducer_stateless7_ctc/ctc_decode.py \
+    --epoch 30 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless7_ctc/exp \
+    --max-duration 600 \
+    --hlg-scale 0.8 \
+    --decoding-method 1best
+
+(3) nbest
+./pruned_transducer_stateless7_ctc/ctc_decode.py \
+    --epoch 30 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless7_ctc/exp \
+    --max-duration 600 \
+    --hlg-scale 0.8 \
+    --decoding-method 1best
+
+(4) nbest-rescoring
+./pruned_transducer_stateless7_ctc/ctc_decode.py \
+    --epoch 30 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless7_ctc/exp \
+    --max-duration 600 \
+    --hlg-scale 0.8 \
+    --lm-dir data/lm \
+    --decoding-method nbest-rescoring
+
+(5) whole-lattice-rescoring
+./pruned_transducer_stateless7_ctc/ctc_decode.py \
+    --epoch 30 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless7_ctc/exp \
+    --max-duration 600 \
+    --hlg-scale 0.8 \
+    --lm-dir data/lm \
+    --decoding-method whole-lattice-rescoring
+"""
+
+
+import argparse
+import logging
+import math
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
+from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall.checkpoint import (
+    average_checkpoints,
+    average_checkpoints_with_averaged_model,
+    find_checkpoints,
+    load_checkpoint,
+)
+from icefall.decode import (
+    get_lattice,
+    nbest_decoding,
+    nbest_oracle,
+    one_best_decoding,
+    rescore_with_n_best_list,
+    rescore_with_whole_lattice,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+    AttributeDict,
+    get_texts,
+    setup_logger,
+    store_transcripts,
+    str2bool,
+    write_error_stats,
+)
+
+LOG_EPS = math.log(1e-10)
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=30,
+        help="""It specifies the checkpoint to use for decoding.
+        Note: Epoch counts from 1.
+        You can specify --avg to use more checkpoints for model averaging.""",
+    )
+
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=15,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
+    )
+
+    parser.add_argument(
+        "--use-averaged-model",
+        type=str2bool,
+        default=True,
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="pruned_transducer_stateless7_ctc/exp",
+        help="The experiment dir",
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="data/lang_bpe_500/bpe.model",
+        help="Path to the BPE model",
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=Path,
+        default="data/lang_bpe_500",
+        help="The lang dir containing word table and LG graph",
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+    )
+
+    parser.add_argument(
+        "--decoding-method",
+        type=str,
+        default="ctc-decoding",
+        help="""Decoding method.
+        Supported values are:
+        - (1) ctc-decoding. Use CTC decoding. It uses a sentence piece
+          model, i.e., lang_dir/bpe.model, to convert word pieces to words.
+          It needs neither a lexicon nor an n-gram LM.
+        - (2) 1best. Extract the best path from the decoding lattice as the
+          decoding result.
+        - (3) nbest. Extract n paths from the decoding lattice; the path
+          with the highest score is the decoding result.
+        - (4) nbest-rescoring. Extract n paths from the decoding lattice,
+          rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
+          the highest score is the decoding result.
+        - (5) whole-lattice-rescoring. Rescore the decoding lattice with an
+          n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
+          is the decoding result.
+          you have trained an RNN LM using ./rnn_lm/train.py
+        - (6) nbest-oracle. Its WER is the lower bound of any n-best
+          rescoring method can achieve. Useful for debugging n-best
+          rescoring method.
+        """,
+    )
+
+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=100,
+        help="""Number of paths for n-best based decoding method.
+        Used only when "method" is one of the following values:
+        nbest, nbest-rescoring, and nbest-oracle
+        """,
+    )
+
+    parser.add_argument(
+        "--nbest-scale",
+        type=float,
+        default=0.5,
+        help="""The scale to be applied to `lattice.scores`.
+        It's needed if you use any kinds of n-best based rescoring.
+        Used only when "method" is one of the following values:
+        nbest, nbest-rescoring, and nbest-oracle
+        A smaller value results in more unique paths.
+        """,
+    )
+
+    parser.add_argument(
+        "--hlg-scale",
+        type=float,
+        default=0.8,
+        help="""The scale to be applied to `hlg.scores`.
+        """,
+    )
+
+    parser.add_argument(
+        "--lm-dir",
+        type=str,
+        default="data/lm",
+        help="""The n-gram LM dir.
+        It should contain either G_4_gram.pt or G_4_gram.fst.txt
+        """,
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def get_decoding_params() -> AttributeDict:
+    """Parameters for decoding."""
+    params = AttributeDict(
+        {
+            "frame_shift_ms": 10,
+            "search_beam": 20,
+            "output_beam": 8,
+            "min_active_states": 30,
+            "max_active_states": 10000,
+            "use_double_scores": True,
+        }
+    )
+    return params
+
+
+def decode_one_batch(
+    params: AttributeDict,
+    model: nn.Module,
+    HLG: Optional[k2.Fsa],
+    H: Optional[k2.Fsa],
+    bpe_model: Optional[spm.SentencePieceProcessor],
+    batch: dict,
+    word_table: k2.SymbolTable,
+    G: Optional[k2.Fsa] = None,
+) -> Dict[str, List[List[str]]]:
+    """Decode one batch and return the result in a dict. The dict has the
+    following format:
+    - key: It indicates the setting used for decoding. For example,
+           if no rescoring is used, the key is the string `no_rescore`.
+           If LM rescoring is used, the key is the string `lm_scale_xxx`,
+           where `xxx` is the value of `lm_scale`. An example key is
+           `lm_scale_0.7`
+    - value: It contains the decoding result. `len(value)` equals to
+             batch size. `value[i]` is the decoding result for the i-th
+             utterance in the given batch.
+
+    Args:
+      params:
+        It's the return value of :func:`get_params`.
+
+        - params.decoding_method is "1best", it uses 1best decoding without LM rescoring.
+        - params.decoding_method is "nbest", it uses nbest decoding without LM rescoring.
+        - params.decoding_method is "nbest-rescoring", it uses nbest LM rescoring.
+        - params.decoding_method is "whole-lattice-rescoring", it uses whole lattice LM
+          rescoring.
+
+      model:
+        The neural model.
+      HLG:
+        The decoding graph. Used only when params.decoding_method is NOT ctc-decoding.
+      H:
+        The ctc topo. Used only when params.decoding_method is ctc-decoding.
+      bpe_model:
+        The BPE model. Used only when params.decoding_method is ctc-decoding.
+      batch:
+        It is the return value from iterating
+        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+        for the format of the `batch`.
+      word_table:
+        The word symbol table.
+      G:
+        An LM. It is not None when params.decoding_method is "nbest-rescoring"
+        or "whole-lattice-rescoring". In general, the G in HLG
+        is a 3-gram LM, while this G is a 4-gram LM.
+    Returns:
+      Return the decoding result. See above description for the format of
+      the returned dict. Note: If it decodes to nothing, then return None.
+    """
+    if HLG is not None:
+        device = HLG.device
+    else:
+        device = H.device
+    feature = batch["inputs"]
+    assert feature.ndim == 3
+    feature = feature.to(device)
+    # at entry, feature is (N, T, C)
+
+    supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)
+
+    encoder_out, encoder_out_lens = model.encoder(feature, feature_lens)
+    nnet_output = model.ctc_output(encoder_out)
+    # nnet_output is (N, T, C)
+
+    supervision_segments = torch.stack(
+        (
+            supervisions["sequence_idx"],
+            supervisions["start_frame"] // params.subsampling_factor,
+            supervisions["num_frames"] // params.subsampling_factor,
+        ),
+        1,
+    ).to(torch.int32)
+
+    if H is None:
+        assert HLG is not None
+        decoding_graph = HLG
+    else:
+        assert HLG is None
+        assert bpe_model is not None
+        decoding_graph = H
+
+    lattice = get_lattice(
+        nnet_output=nnet_output,
+        decoding_graph=decoding_graph,
+        supervision_segments=supervision_segments,
+        search_beam=params.search_beam,
+        output_beam=params.output_beam,
+        min_active_states=params.min_active_states,
+        max_active_states=params.max_active_states,
+        subsampling_factor=params.subsampling_factor,
+    )
+
+    if params.decoding_method == "ctc-decoding":
+        best_path = one_best_decoding(
+            lattice=lattice, use_double_scores=params.use_double_scores
+        )
+        # Note: `best_path.aux_labels` contains token IDs, not word IDs
+        # since we are using H, not HLG here.
+        #
+        # token_ids is a lit-of-list of IDs
+        token_ids = get_texts(best_path)
+
+        # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
+        hyps = bpe_model.decode(token_ids)
+
+        # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
+        hyps = [s.split() for s in hyps]
+        key = "ctc-decoding"
+        return {key: hyps}
+
+    if params.decoding_method == "nbest-oracle":
+        # Note: You can also pass rescored lattices to it.
+        # We choose the HLG decoded lattice for speed reasons
+        # as HLG decoding is faster and the oracle WER
+        # is only slightly worse than that of rescored lattices.
+        best_path = nbest_oracle(
+            lattice=lattice,
+            num_paths=params.num_paths,
+            ref_texts=supervisions["text"],
+            word_table=word_table,
+            nbest_scale=params.nbest_scale,
+            oov="",
+        )
+        hyps = get_texts(best_path)
+        hyps = [[word_table[i] for i in ids] for ids in hyps]
+        key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}"  # noqa
+        return {key: hyps}
+
+    if params.decoding_method in ["1best", "nbest"]:
+        if params.decoding_method == "1best":
+            best_path = one_best_decoding(
+                lattice=lattice, use_double_scores=params.use_double_scores
+            )
+            key = "no_rescore"
+        else:
+            best_path = nbest_decoding(
+                lattice=lattice,
+                num_paths=params.num_paths,
+                use_double_scores=params.use_double_scores,
+                nbest_scale=params.nbest_scale,
+            )
+            key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}"  # noqa
+
+        hyps = get_texts(best_path)
+        hyps = [[word_table[i] for i in ids] for ids in hyps]
+        return {key: hyps}
+
+    assert params.decoding_method in [
+        "nbest-rescoring",
+        "whole-lattice-rescoring",
+    ]
+
+    lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
+    lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
+    lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
+
+    if params.decoding_method == "nbest-rescoring":
+        best_path_dict = rescore_with_n_best_list(
+            lattice=lattice,
+            G=G,
+            num_paths=params.num_paths,
+            lm_scale_list=lm_scale_list,
+            nbest_scale=params.nbest_scale,
+        )
+    elif params.decoding_method == "whole-lattice-rescoring":
+        best_path_dict = rescore_with_whole_lattice(
+            lattice=lattice,
+            G_with_epsilon_loops=G,
+            lm_scale_list=lm_scale_list,
+        )
+    else:
+        assert False, f"Unsupported decoding method: {params.decoding_method}"
+
+    ans = dict()
+    if best_path_dict is not None:
+        for lm_scale_str, best_path in best_path_dict.items():
+            hyps = get_texts(best_path)
+            hyps = [[word_table[i] for i in ids] for ids in hyps]
+            ans[lm_scale_str] = hyps
+    else:
+        ans = None
+    return ans
+
+
+def decode_dataset(
+    dl: torch.utils.data.DataLoader,
+    params: AttributeDict,
+    model: nn.Module,
+    HLG: Optional[k2.Fsa],
+    H: Optional[k2.Fsa],
+    bpe_model: Optional[spm.SentencePieceProcessor],
+    word_table: k2.SymbolTable,
+    G: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+    """Decode dataset.
+
+    Args:
+      dl:
+        PyTorch's dataloader containing the dataset to decode.
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The neural model.
+      HLG:
+        The decoding graph. Used only when params.decoding_method is NOT ctc-decoding.
+      H:
+        The ctc topo. Used only when params.decoding_method is ctc-decoding.
+      bpe_model:
+        The BPE model. Used only when params.decoding_method is ctc-decoding.
+      word_table:
+        It is the word symbol table.
+      G:
+        An LM. It is not None when params.decoding_method is "nbest-rescoring"
+        or "whole-lattice-rescoring". In general, the G in HLG
+        is a 3-gram LM, while this G is a 4-gram LM.
+    Returns:
+      Return a dict, whose key may be "no-rescore" if no LM rescoring
+      is used, or it may be "lm_scale_0.7" if LM rescoring is used.
+      Its value is a list of tuples. Each tuple contains two elements:
+      The first is the reference transcript, and the second is the
+      predicted result.
+    """
+    num_cuts = 0
+
+    try:
+        num_batches = len(dl)
+    except TypeError:
+        num_batches = "?"
+
+    results = defaultdict(list)
+    for batch_idx, batch in enumerate(dl):
+        texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+        hyps_dict = decode_one_batch(
+            params=params,
+            model=model,
+            HLG=HLG,
+            H=H,
+            bpe_model=bpe_model,
+            batch=batch,
+            word_table=word_table,
+            G=G,
+        )
+
+        for name, hyps in hyps_dict.items():
+            this_batch = []
+            assert len(hyps) == len(texts)
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+                ref_words = ref_text.split()
+                this_batch.append((cut_id, ref_words, hyp_words))
+
+            results[name].extend(this_batch)
+
+        num_cuts += len(texts)
+
+        if batch_idx % 100 == 0:
+            batch_str = f"{batch_idx}/{num_batches}"
+
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+    return results
+
+
+def save_results(
+    params: AttributeDict,
+    test_set_name: str,
+    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
+    test_set_wers = dict()
+    for key, results in results_dict.items():
+        recog_path = (
+            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
+        results = sorted(results)
+        store_transcripts(filename=recog_path, texts=results)
+        logging.info(f"The transcripts are stored in {recog_path}")
+
+        # The following prints out WERs, per-word error statistics and aligned
+        # ref/hyp pairs.
+        errs_filename = (
+            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
+        with open(errs_filename, "w") as f:
+            wer = write_error_stats(f, f"{test_set_name}-{key}", results)
+            test_set_wers[key] = wer
+
+        logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+    errs_info = (
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+    )
+    with open(errs_info, "w") as f:
+        print("settings\tWER", file=f)
+        for key, val in test_set_wers:
+            print("{}\t{}".format(key, val), file=f)
+
+    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+    note = "\tbest for {}".format(test_set_name)
+    for key, val in test_set_wers:
+        s += "{}\t{}{}\n".format(key, val, note)
+        note = ""
+    logging.info(s)
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    LibriSpeechAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+    args.lang_dir = Path(args.lang_dir)
+    args.lm_dir = Path(args.lm_dir)
+
+    params = get_params()
+    # add decoding params
+    params.update(get_decoding_params())
+    params.update(vars(args))
+
+    assert params.decoding_method in (
+        "ctc-decoding",
+        "1best",
+        "nbest",
+        "nbest-rescoring",
+        "whole-lattice-rescoring",
+        "nbest-oracle",
+    )
+    params.res_dir = params.exp_dir / params.decoding_method
+
+    if params.iter > 0:
+        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+    else:
+        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+    if params.use_averaged_model:
+        params.suffix += "-use-averaged-model"
+
+    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+    logging.info("Decoding started")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"Device: {device}")
+    logging.info(params)
+
+    lexicon = Lexicon(params.lang_dir)
+    max_token_id = max(lexicon.tokens)
+    num_classes = max_token_id + 1  # +1 for the blank
+
+    params.vocab_size = num_classes
+    #  and  are defined in local/train_bpe_model.py
+    params.blank_id = 0
+
+    if params.decoding_method == "ctc-decoding":
+        HLG = None
+        H = k2.ctc_topo(
+            max_token=max_token_id,
+            modified=False,
+            device=device,
+        )
+        bpe_model = spm.SentencePieceProcessor()
+        bpe_model.load(str(params.lang_dir / "bpe.model"))
+    else:
+        H = None
+        bpe_model = None
+        HLG = k2.Fsa.from_dict(
+            torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
+        )
+        assert HLG.requires_grad is False
+
+        HLG.scores *= params.hlg_scale
+        if not hasattr(HLG, "lm_scores"):
+            HLG.lm_scores = HLG.scores.clone()
+
+    if params.decoding_method in (
+        "nbest-rescoring",
+        "whole-lattice-rescoring",
+    ):
+        if not (params.lm_dir / "G_4_gram.pt").is_file():
+            logging.info("Loading G_4_gram.fst.txt")
+            logging.warning("It may take 8 minutes.")
+            with open(params.lm_dir / "G_4_gram.fst.txt") as f:
+                first_word_disambig_id = lexicon.word_table["#0"]
+
+                G = k2.Fsa.from_openfst(f.read(), acceptor=False)
+                # G.aux_labels is not needed in later computations, so
+                # remove it here.
+                del G.aux_labels
+                # CAUTION: The following line is crucial.
+                # Arcs entering the back-off state have label equal to #0.
+                # We have to change it to 0 here.
+                G.labels[G.labels >= first_word_disambig_id] = 0
+                # See https://github.com/k2-fsa/k2/issues/874
+                # for why we need to set G.properties to None
+                G.__dict__["_properties"] = None
+                G = k2.Fsa.from_fsas([G]).to(device)
+                G = k2.arc_sort(G)
+                # Save a dummy value so that it can be loaded in C++.
+                # See https://github.com/pytorch/pytorch/issues/67902
+                # for why we need to do this.
+                G.dummy = 1
+
+                torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
+        else:
+            logging.info("Loading pre-compiled G_4_gram.pt")
+            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
+            G = k2.Fsa.from_dict(d)
+
+        if params.decoding_method == "whole-lattice-rescoring":
+            # Add epsilon self-loops to G as we will compose
+            # it with the whole lattice later
+            G = k2.add_epsilon_self_loops(G)
+            G = k2.arc_sort(G)
+            G = G.to(device)
+
+        # G.lm_scores is used to replace HLG.lm_scores during
+        # LM rescoring.
+        G.lm_scores = G.scores.clone()
+    else:
+        G = None
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+
+    model.to(device)
+    model.eval()
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
+    librispeech = LibriSpeechAsrDataModule(args)
+
+    test_clean_cuts = librispeech.test_clean_cuts()
+    test_other_cuts = librispeech.test_other_cuts()
+
+    test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
+    test_other_dl = librispeech.test_dataloaders(test_other_cuts)
+
+    test_sets = ["test-clean", "test-other"]
+    test_dl = [test_clean_dl, test_other_dl]
+
+    for test_set, test_dl in zip(test_sets, test_dl):
+        results_dict = decode_dataset(
+            dl=test_dl,
+            params=params,
+            model=model,
+            HLG=HLG,
+            H=H,
+            bpe_model=bpe_model,
+            word_table=lexicon.word_table,
+            G=G,
+        )
+
+        save_results(
+            params=params,
+            test_set_name=test_set,
+            results_dict=results_dict,
+        )
+
+    logging.info("Done!")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decode.py
new file mode 100755
index 000000000..32a9b6bb2
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decode.py
@@ -0,0 +1,841 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
+#                                                 Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) greedy search
+./pruned_transducer_stateless7_ctc/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless7_ctc/exp \
+    --max-duration 600 \
+    --decoding-method greedy_search
+
+(2) beam search (not recommended)
+./pruned_transducer_stateless7_ctc/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless7_ctc/exp \
+    --max-duration 600 \
+    --decoding-method beam_search \
+    --beam-size 4
+
+(3) modified beam search
+./pruned_transducer_stateless7_ctc/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless7_ctc/exp \
+    --max-duration 600 \
+    --decoding-method modified_beam_search \
+    --beam-size 4
+
+(4) fast beam search (one best)
+./pruned_transducer_stateless7_ctc/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless7_ctc/exp \
+    --max-duration 600 \
+    --decoding-method fast_beam_search \
+    --beam 20.0 \
+    --max-contexts 8 \
+    --max-states 64
+
+(5) fast beam search (nbest)
+./pruned_transducer_stateless7_ctc/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless7_ctc/exp \
+    --max-duration 600 \
+    --decoding-method fast_beam_search_nbest \
+    --beam 20.0 \
+    --max-contexts 8 \
+    --max-states 64 \
+    --num-paths 200 \
+    --nbest-scale 0.5
+
+(6) fast beam search (nbest oracle WER)
+./pruned_transducer_stateless7_ctc/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless7_ctc/exp \
+    --max-duration 600 \
+    --decoding-method fast_beam_search_nbest_oracle \
+    --beam 20.0 \
+    --max-contexts 8 \
+    --max-states 64 \
+    --num-paths 200 \
+    --nbest-scale 0.5
+
+(7) fast beam search (with LG)
+./pruned_transducer_stateless7_ctc/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless7_ctc/exp \
+    --max-duration 600 \
+    --decoding-method fast_beam_search_nbest_LG \
+    --beam 20.0 \
+    --max-contexts 8 \
+    --max-states 64
+"""
+
+
+import argparse
+import logging
+import math
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
+from beam_search import (
+    beam_search,
+    fast_beam_search_nbest,
+    fast_beam_search_nbest_LG,
+    fast_beam_search_nbest_oracle,
+    fast_beam_search_one_best,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+)
+from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall.checkpoint import (
+    average_checkpoints,
+    average_checkpoints_with_averaged_model,
+    find_checkpoints,
+    load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+    AttributeDict,
+    setup_logger,
+    store_transcripts,
+    str2bool,
+    write_error_stats,
+)
+
+LOG_EPS = math.log(1e-10)
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=30,
+        help="""It specifies the checkpoint to use for decoding.
+        Note: Epoch counts from 1.
+        You can specify --avg to use more checkpoints for model averaging.""",
+    )
+
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=9,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
+    )
+
+    parser.add_argument(
+        "--use-averaged-model",
+        type=str2bool,
+        default=True,
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="pruned_transducer_stateless7_ctc/exp",
+        help="The experiment dir",
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="data/lang_bpe_500/bpe.model",
+        help="Path to the BPE model",
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=Path,
+        default="data/lang_bpe_500",
+        help="The lang dir containing word table and LG graph",
+    )
+
+    parser.add_argument(
+        "--decoding-method",
+        type=str,
+        default="greedy_search",
+        help="""Possible values are:
+          - greedy_search
+          - beam_search
+          - modified_beam_search
+          - fast_beam_search
+          - fast_beam_search_nbest
+          - fast_beam_search_nbest_oracle
+          - fast_beam_search_nbest_LG
+        If you use fast_beam_search_nbest_LG, you have to specify
+        `--lang-dir`, which should contain `LG.pt`.
+        """,
+    )
+
+    parser.add_argument(
+        "--beam-size",
+        type=int,
+        default=4,
+        help="""An integer indicating how many candidates we will keep for each
+        frame. Used only when --decoding-method is beam_search or
+        modified_beam_search.""",
+    )
+
+    parser.add_argument(
+        "--beam",
+        type=float,
+        default=20.0,
+        help="""A floating point value to calculate the cutoff score during beam
+        search (i.e., `cutoff = max-score - beam`), which is the same as the
+        `beam` in Kaldi.
+        Used only when --decoding-method is fast_beam_search,
+        fast_beam_search_nbest, fast_beam_search_nbest_LG,
+        and fast_beam_search_nbest_oracle
+        """,
+    )
+
+    parser.add_argument(
+        "--ngram-lm-scale",
+        type=float,
+        default=0.01,
+        help="""
+        Used only when --decoding_method is fast_beam_search_nbest_LG.
+        It specifies the scale for n-gram LM scores.
+        """,
+    )
+
+    parser.add_argument(
+        "--max-contexts",
+        type=int,
+        default=8,
+        help="""Used only when --decoding-method is
+        fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+        and fast_beam_search_nbest_oracle""",
+    )
+
+    parser.add_argument(
+        "--max-states",
+        type=int,
+        default=64,
+        help="""Used only when --decoding-method is
+        fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+        and fast_beam_search_nbest_oracle""",
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+    )
+    parser.add_argument(
+        "--max-sym-per-frame",
+        type=int,
+        default=1,
+        help="""Maximum number of symbols per frame.
+        Used only when --decoding_method is greedy_search""",
+    )
+
+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=200,
+        help="""Number of paths for nbest decoding.
+        Used only when the decoding method is fast_beam_search_nbest,
+        fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+    )
+
+    parser.add_argument(
+        "--nbest-scale",
+        type=float,
+        default=0.5,
+        help="""Scale applied to lattice scores when computing nbest paths.
+        Used only when the decoding method is fast_beam_search_nbest,
+        fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+    )
+
+    parser.add_argument(
+        "--simulate-streaming",
+        type=str2bool,
+        default=False,
+        help="""Whether to simulate streaming in decoding, this is a good way to
+        test a streaming model.
+        """,
+    )
+
+    parser.add_argument(
+        "--decode-chunk-size",
+        type=int,
+        default=16,
+        help="The chunk size for decoding (in frames after subsampling)",
+    )
+
+    parser.add_argument(
+        "--left-context",
+        type=int,
+        default=64,
+        help="left context can be seen during decoding (in frames after subsampling)",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def decode_one_batch(
+    params: AttributeDict,
+    model: nn.Module,
+    sp: spm.SentencePieceProcessor,
+    batch: dict,
+    word_table: Optional[k2.SymbolTable] = None,
+    decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[List[str]]]:
+    """Decode one batch and return the result in a dict. The dict has the
+    following format:
+
+        - key: It indicates the setting used for decoding. For example,
+               if greedy_search is used, it would be "greedy_search"
+               If beam search with a beam size of 7 is used, it would be
+               "beam_7"
+        - value: It contains the decoding result. `len(value)` equals to
+                 batch size. `value[i]` is the decoding result for the i-th
+                 utterance in the given batch.
+    Args:
+      params:
+        It's the return value of :func:`get_params`.
+      model:
+        The neural model.
+      sp:
+        The BPE model.
+      batch:
+        It is the return value from iterating
+        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+        for the format of the `batch`.
+      word_table:
+        The word symbol table.
+      decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
+        only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
+        fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
+    Returns:
+      Return the decoding result. See above description for the format of
+      the returned dict.
+    """
+    device = next(model.parameters()).device
+    feature = batch["inputs"]
+    assert feature.ndim == 3
+
+    feature = feature.to(device)
+    # at entry, feature is (N, T, C)
+
+    supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)
+
+    if params.simulate_streaming:
+        feature_lens += params.left_context
+        feature = torch.nn.functional.pad(
+            feature,
+            pad=(0, 0, 0, params.left_context),
+            value=LOG_EPS,
+        )
+        encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
+            x=feature,
+            x_lens=feature_lens,
+            chunk_size=params.decode_chunk_size,
+            left_context=params.left_context,
+            simulate_streaming=True,
+        )
+    else:
+        encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+
+    hyps = []
+
+    if params.decoding_method == "fast_beam_search":
+        hyp_tokens = fast_beam_search_one_best(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.decoding_method == "fast_beam_search_nbest_LG":
+        hyp_tokens = fast_beam_search_nbest_LG(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+            num_paths=params.num_paths,
+            nbest_scale=params.nbest_scale,
+        )
+        for hyp in hyp_tokens:
+            hyps.append([word_table[i] for i in hyp])
+    elif params.decoding_method == "fast_beam_search_nbest":
+        hyp_tokens = fast_beam_search_nbest(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+            num_paths=params.num_paths,
+            nbest_scale=params.nbest_scale,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.decoding_method == "fast_beam_search_nbest_oracle":
+        hyp_tokens = fast_beam_search_nbest_oracle(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+            num_paths=params.num_paths,
+            ref_texts=sp.encode(supervisions["text"]),
+            nbest_scale=params.nbest_scale,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+        hyp_tokens = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.decoding_method == "modified_beam_search":
+        hyp_tokens = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam_size,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    else:
+        batch_size = encoder_out.size(0)
+
+        for i in range(batch_size):
+            # fmt: off
+            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+            # fmt: on
+            if params.decoding_method == "greedy_search":
+                hyp = greedy_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    max_sym_per_frame=params.max_sym_per_frame,
+                )
+            elif params.decoding_method == "beam_search":
+                hyp = beam_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
+                )
+            else:
+                raise ValueError(
+                    f"Unsupported decoding method: {params.decoding_method}"
+                )
+            hyps.append(sp.decode(hyp).split())
+
+    if params.decoding_method == "greedy_search":
+        return {"greedy_search": hyps}
+    elif "fast_beam_search" in params.decoding_method:
+        key = f"beam_{params.beam}_"
+        key += f"max_contexts_{params.max_contexts}_"
+        key += f"max_states_{params.max_states}"
+        if "nbest" in params.decoding_method:
+            key += f"_num_paths_{params.num_paths}_"
+            key += f"nbest_scale_{params.nbest_scale}"
+            if "LG" in params.decoding_method:
+                key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
+
+        return {key: hyps}
+    else:
+        return {f"beam_size_{params.beam_size}": hyps}
+
+
+def decode_dataset(
+    dl: torch.utils.data.DataLoader,
+    params: AttributeDict,
+    model: nn.Module,
+    sp: spm.SentencePieceProcessor,
+    word_table: Optional[k2.SymbolTable] = None,
+    decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+    """Decode dataset.
+
+    Args:
+      dl:
+        PyTorch's dataloader containing the dataset to decode.
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The neural model.
+      sp:
+        The BPE model.
+      word_table:
+        The word symbol table.
+      decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
+        only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
+        fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
+    Returns:
+      Return a dict, whose key may be "greedy_search" if greedy search
+      is used, or it may be "beam_7" if beam size of 7 is used.
+      Its value is a list of tuples. Each tuple contains two elements:
+      The first is the reference transcript, and the second is the
+      predicted result.
+    """
+    num_cuts = 0
+
+    try:
+        num_batches = len(dl)
+    except TypeError:
+        num_batches = "?"
+
+    if params.decoding_method == "greedy_search":
+        log_interval = 50
+    else:
+        log_interval = 20
+
+    results = defaultdict(list)
+    for batch_idx, batch in enumerate(dl):
+        texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+        hyps_dict = decode_one_batch(
+            params=params,
+            model=model,
+            sp=sp,
+            decoding_graph=decoding_graph,
+            word_table=word_table,
+            batch=batch,
+        )
+
+        for name, hyps in hyps_dict.items():
+            this_batch = []
+            assert len(hyps) == len(texts)
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+                ref_words = ref_text.split()
+                this_batch.append((cut_id, ref_words, hyp_words))
+
+            results[name].extend(this_batch)
+
+        num_cuts += len(texts)
+
+        if batch_idx % log_interval == 0:
+            batch_str = f"{batch_idx}/{num_batches}"
+
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+    return results
+
+
+def save_results(
+    params: AttributeDict,
+    test_set_name: str,
+    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
+    test_set_wers = dict()
+    for key, results in results_dict.items():
+        recog_path = (
+            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
+        results = sorted(results)
+        store_transcripts(filename=recog_path, texts=results)
+        logging.info(f"The transcripts are stored in {recog_path}")
+
+        # The following prints out WERs, per-word error statistics and aligned
+        # ref/hyp pairs.
+        errs_filename = (
+            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
+        with open(errs_filename, "w") as f:
+            wer = write_error_stats(
+                f, f"{test_set_name}-{key}", results, enable_log=True
+            )
+            test_set_wers[key] = wer
+
+        logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+    errs_info = (
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+    )
+    with open(errs_info, "w") as f:
+        print("settings\tWER", file=f)
+        for key, val in test_set_wers:
+            print("{}\t{}".format(key, val), file=f)
+
+    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+    note = "\tbest for {}".format(test_set_name)
+    for key, val in test_set_wers:
+        s += "{}\t{}{}\n".format(key, val, note)
+        note = ""
+    logging.info(s)
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    LibriSpeechAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    assert params.decoding_method in (
+        "greedy_search",
+        "beam_search",
+        "fast_beam_search",
+        "fast_beam_search_nbest",
+        "fast_beam_search_nbest_LG",
+        "fast_beam_search_nbest_oracle",
+        "modified_beam_search",
+    )
+    params.res_dir = params.exp_dir / params.decoding_method
+
+    if params.iter > 0:
+        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+    else:
+        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+    if params.simulate_streaming:
+        params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
+        params.suffix += f"-left-context-{params.left_context}"
+
+    if "fast_beam_search" in params.decoding_method:
+        params.suffix += f"-beam-{params.beam}"
+        params.suffix += f"-max-contexts-{params.max_contexts}"
+        params.suffix += f"-max-states-{params.max_states}"
+        if "nbest" in params.decoding_method:
+            params.suffix += f"-nbest-scale-{params.nbest_scale}"
+            params.suffix += f"-num-paths-{params.num_paths}"
+            if "LG" in params.decoding_method:
+                params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
+    elif "beam_search" in params.decoding_method:
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+    else:
+        params.suffix += f"-context-{params.context_size}"
+        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+    if params.use_averaged_model:
+        params.suffix += "-use-averaged-model"
+
+    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+    logging.info("Decoding started")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    #  and  are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("")
+    params.unk_id = sp.piece_to_id("")
+    params.vocab_size = sp.get_piece_size()
+
+    if params.simulate_streaming:
+        assert (
+            params.causal_convolution
+        ), "Decoding in streaming requires causal convolution"
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+
+    model.to(device)
+    model.eval()
+
+    if "fast_beam_search" in params.decoding_method:
+        if params.decoding_method == "fast_beam_search_nbest_LG":
+            lexicon = Lexicon(params.lang_dir)
+            word_table = lexicon.word_table
+            lg_filename = params.lang_dir / "LG.pt"
+            logging.info(f"Loading {lg_filename}")
+            decoding_graph = k2.Fsa.from_dict(
+                torch.load(lg_filename, map_location=device)
+            )
+            decoding_graph.scores *= params.ngram_lm_scale
+        else:
+            word_table = None
+            decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+    else:
+        decoding_graph = None
+        word_table = None
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
+    librispeech = LibriSpeechAsrDataModule(args)
+
+    test_clean_cuts = librispeech.test_clean_cuts()
+    test_other_cuts = librispeech.test_other_cuts()
+
+    test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
+    test_other_dl = librispeech.test_dataloaders(test_other_cuts)
+
+    test_sets = ["test-clean", "test-other"]
+    test_dl = [test_clean_dl, test_other_dl]
+
+    for test_set, test_dl in zip(test_sets, test_dl):
+        results_dict = decode_dataset(
+            dl=test_dl,
+            params=params,
+            model=model,
+            sp=sp,
+            word_table=word_table,
+            decoding_graph=decoding_graph,
+        )
+
+        save_results(
+            params=params,
+            test_set_name=test_set,
+            results_dict=results_dict,
+        )
+
+    logging.info("Done!")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decoder.py
new file mode 120000
index 000000000..33944d0d2
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decoder.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless7/decoder.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/encoder_interface.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/encoder_interface.py
new file mode 120000
index 000000000..b9aa0ae08
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/encoder_interface.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless2/encoder_interface.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/export.py
new file mode 100755
index 000000000..59a393739
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/export.py
@@ -0,0 +1,320 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This script converts several saved checkpoints
+# to a single one using model averaging.
+"""
+
+Usage:
+
+(1) Export to torchscript model using torch.jit.script()
+
+./pruned_transducer_stateless7_ctc/export.py \
+  --exp-dir ./pruned_transducer_stateless7_ctc/exp \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --epoch 30 \
+  --avg 9 \
+  --jit 1
+
+It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later
+load it by `torch.jit.load("cpu_jit.pt")`.
+
+Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python
+are on CPU. You can use `to("cuda")` to move them to a CUDA device.
+
+Check
+https://github.com/k2-fsa/sherpa
+for how to use the exported models outside of icefall.
+
+(2) Export `model.state_dict()`
+
+./pruned_transducer_stateless7_ctc/export.py \
+  --exp-dir ./pruned_transducer_stateless7_ctc/exp \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --epoch 20 \
+  --avg 10
+
+It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
+load it by `icefall.checkpoint.load_checkpoint()`.
+
+To use the generated file with `pruned_transducer_stateless7_ctc/decode.py`,
+you can do:
+
+    cd /path/to/exp_dir
+    ln -s pretrained.pt epoch-9999.pt
+
+    cd /path/to/egs/librispeech/ASR
+    ./pruned_transducer_stateless7_ctc/decode.py \
+        --exp-dir ./pruned_transducer_stateless7_ctc/exp \
+        --epoch 9999 \
+        --avg 1 \
+        --max-duration 600 \
+        --decoding-method greedy_search \
+        --bpe-model data/lang_bpe_500/bpe.model
+
+Check ./pretrained.py for its usage.
+
+Note: If you don't want to train a model from scratch, we have
+provided one for you. You can get it at
+
+https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
+
+with the following commands:
+
+    sudo apt-get install git-lfs
+    git lfs install
+    git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
+    # You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/exp
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+import sentencepiece as spm
+import torch
+from scaling_converter import convert_scaled_to_non_scaled
+from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall.checkpoint import (
+    average_checkpoints,
+    average_checkpoints_with_averaged_model,
+    find_checkpoints,
+    load_checkpoint,
+)
+from icefall.utils import str2bool
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=30,
+        help="""It specifies the checkpoint to use for decoding.
+        Note: Epoch counts from 1.
+        You can specify --avg to use more checkpoints for model averaging.""",
+    )
+
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=9,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
+    )
+
+    parser.add_argument(
+        "--use-averaged-model",
+        type=str2bool,
+        default=True,
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="pruned_transducer_stateless7/exp",
+        help="""It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="data/lang_bpe_500/bpe.model",
+        help="Path to the BPE model",
+    )
+
+    parser.add_argument(
+        "--jit",
+        type=str2bool,
+        default=False,
+        help="""True to save a model after applying torch.jit.script.
+        It will generate a file named cpu_jit.pt
+
+        Check ./jit_pretrained.py for how to use it.
+        """,
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+@torch.no_grad()
+def main():
+    args = get_parser().parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    #  is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    model.to(device)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+
+    model.to("cpu")
+    model.eval()
+
+    if params.jit is True:
+        convert_scaled_to_non_scaled(model, inplace=True)
+        logging.info("Using torch.jit.script()")
+        # We won't use the forward() method of the model in C++, so just ignore
+        # it here.
+        # Otherwise, one of its arguments is a ragged tensor and is not
+        # torch scriptabe.
+        model.__class__.forward = torch.jit.ignore(model.__class__.forward)
+        logging.info("Using torch.jit.script")
+        model = torch.jit.script(model)
+        filename = params.exp_dir / "cpu_jit.pt"
+        model.save(str(filename))
+        logging.info(f"Saved to {filename}")
+    else:
+        logging.info("Not using torchscript. Export model.state_dict()")
+        # Save it using a format so that it can be loaded
+        # by :func:`load_checkpoint`
+        filename = params.exp_dir / "pretrained.pt"
+        torch.save({"model": model.state_dict()}, str(filename))
+        logging.info(f"Saved to {filename}")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained.py
new file mode 100755
index 000000000..280b95984
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained.py
@@ -0,0 +1,271 @@
+#!/usr/bin/env python3
+# Copyright      2022  Xiaomi Corp.        (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script loads torchscript models, exported by `torch.jit.script()`
+and uses them to decode waves.
+You can use the following command to get the exported models:
+
+./pruned_transducer_stateless7_ctc/export.py \
+  --exp-dir ./pruned_transducer_stateless7_ctc/exp \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --epoch 20 \
+  --avg 10 \
+  --jit 1
+
+Usage of this script:
+
+./pruned_transducer_stateless7_ctc/jit_pretrained.py \
+  --nn-model-filename ./pruned_transducer_stateless7_ctc/exp/cpu_jit.pt \
+  /path/to/foo.wav \
+  /path/to/bar.wav
+"""
+
+import argparse
+import logging
+import math
+from typing import List
+
+import kaldifeat
+import sentencepiece as spm
+import torch
+import torchaudio
+from torch.nn.utils.rnn import pad_sequence
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--nn-model-filename",
+        type=str,
+        required=True,
+        help="Path to the torchscript model cpu_jit.pt",
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        help="""Path to bpe.model.""",
+    )
+
+    parser.add_argument(
+        "sound_files",
+        type=str,
+        nargs="+",
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
+    )
+
+    return parser
+
+
+def read_sound_files(
+    filenames: List[str], expected_sample_rate: float = 16000
+) -> List[torch.Tensor]:
+    """Read a list of sound files into a list 1-D float32 torch tensors.
+    Args:
+      filenames:
+        A list of sound filenames.
+      expected_sample_rate:
+        The expected sample rate of the sound files.
+    Returns:
+      Return a list of 1-D float32 torch tensors.
+    """
+    ans = []
+    for f in filenames:
+        wave, sample_rate = torchaudio.load(f)
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"Expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        # We use only the first channel
+        ans.append(wave[0])
+    return ans
+
+
+def greedy_search(
+    model: torch.jit.ScriptModule,
+    encoder_out: torch.Tensor,
+    encoder_out_lens: torch.Tensor,
+) -> List[List[int]]:
+    """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
+    Args:
+      model:
+        The transducer model.
+      encoder_out:
+        A 3-D tensor of shape (N, T, C)
+      encoder_out_lens:
+        A 1-D tensor of shape (N,).
+    Returns:
+      Return the decoded results for each utterance.
+    """
+    assert encoder_out.ndim == 3
+    assert encoder_out.size(0) >= 1, encoder_out.size(0)
+
+    packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
+        input=encoder_out,
+        lengths=encoder_out_lens.cpu(),
+        batch_first=True,
+        enforce_sorted=False,
+    )
+
+    device = encoder_out.device
+    blank_id = 0  # hard-code to 0
+
+    batch_size_list = packed_encoder_out.batch_sizes.tolist()
+    N = encoder_out.size(0)
+
+    assert torch.all(encoder_out_lens > 0), encoder_out_lens
+    assert N == batch_size_list[0], (N, batch_size_list)
+
+    context_size = model.decoder.context_size
+    hyps = [[blank_id] * context_size for _ in range(N)]
+
+    decoder_input = torch.tensor(
+        hyps,
+        device=device,
+        dtype=torch.int64,
+    )  # (N, context_size)
+
+    decoder_out = model.decoder(
+        decoder_input,
+        need_pad=torch.tensor([False]),
+    ).squeeze(1)
+
+    offset = 0
+    for batch_size in batch_size_list:
+        start = offset
+        end = offset + batch_size
+        current_encoder_out = packed_encoder_out.data[start:end]
+        current_encoder_out = current_encoder_out
+        # current_encoder_out's shape: (batch_size, encoder_out_dim)
+        offset = end
+
+        decoder_out = decoder_out[:batch_size]
+
+        logits = model.joiner(
+            current_encoder_out,
+            decoder_out,
+        )
+        # logits'shape (batch_size, vocab_size)
+
+        assert logits.ndim == 2, logits.shape
+        y = logits.argmax(dim=1).tolist()
+        emitted = False
+        for i, v in enumerate(y):
+            if v != blank_id:
+                hyps[i].append(v)
+                emitted = True
+        if emitted:
+            # update decoder output
+            decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
+            decoder_input = torch.tensor(
+                decoder_input,
+                device=device,
+                dtype=torch.int64,
+            )
+            decoder_out = model.decoder(
+                decoder_input,
+                need_pad=torch.tensor([False]),
+            )
+            decoder_out = decoder_out.squeeze(1)
+
+    sorted_ans = [h[context_size:] for h in hyps]
+    ans = []
+    unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
+    for i in range(N):
+        ans.append(sorted_ans[unsorted_indices[i]])
+
+    return ans
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+    logging.info(vars(args))
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    model = torch.jit.load(args.nn_model_filename)
+
+    model.eval()
+
+    model.to(device)
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(args.bpe_model)
+
+    logging.info("Constructing Fbank computer")
+    opts = kaldifeat.FbankOptions()
+    opts.device = device
+    opts.frame_opts.dither = 0
+    opts.frame_opts.snip_edges = False
+    opts.frame_opts.samp_freq = 16000
+    opts.mel_opts.num_bins = 80
+
+    fbank = kaldifeat.Fbank(opts)
+
+    logging.info(f"Reading sound files: {args.sound_files}")
+    waves = read_sound_files(
+        filenames=args.sound_files,
+    )
+    waves = [w.to(device) for w in waves]
+
+    logging.info("Decoding started")
+    features = fbank(waves)
+    feature_lengths = [f.size(0) for f in features]
+
+    features = pad_sequence(
+        features,
+        batch_first=True,
+        padding_value=math.log(1e-10),
+    )
+
+    feature_lengths = torch.tensor(feature_lengths, device=device)
+
+    encoder_out, encoder_out_lens = model.encoder(
+        x=features,
+        x_lens=feature_lengths,
+    )
+
+    hyps = greedy_search(
+        model=model,
+        encoder_out=encoder_out,
+        encoder_out_lens=encoder_out_lens,
+    )
+    s = "\n"
+    for filename, hyp in zip(args.sound_files, hyps):
+        words = sp.decode(hyp)
+        s += f"{filename}:\n{words}\n\n"
+    logging.info(s)
+
+    logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py
new file mode 100755
index 000000000..d3343d34a
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py
@@ -0,0 +1,423 @@
+#!/usr/bin/env python3
+# Copyright      2022  Xiaomi Corp.        (authors: Fangjun Kuang,
+#                                                    Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script loads torchscript models, exported by `torch.jit.script()`
+and uses them to decode waves.
+You can use the following command to get the exported models:
+
+./pruned_transducer_stateless7_ctc/export.py \
+  --exp-dir ./pruned_transducer_stateless7_ctc/exp \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --epoch 20 \
+  --avg 10 \
+  --jit 1
+
+Usage of this script:
+
+(1) ctc-decoding
+./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \
+  --nn-model-filename ./pruned_transducer_stateless7_ctc/exp/cpu_jit.pt \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --method ctc-decoding \
+  --sample-rate 16000 \
+  /path/to/foo.wav \
+  /path/to/bar.wav
+
+(2) 1best
+./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \
+  --nn-model-filename ./pruned_transducer_stateless7_ctc/exp/cpu_jit.pt \
+  --HLG data/lang_bpe_500/HLG.pt \
+  --words-file data/lang_bpe_500/words.txt  \
+  --method 1best \
+  --sample-rate 16000 \
+  /path/to/foo.wav \
+  /path/to/bar.wav
+
+
+(3) nbest-rescoring
+./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \
+  --nn-model-filename ./pruned_transducer_stateless7_ctc/exp/cpu_jit.pt \
+  --HLG data/lang_bpe_500/HLG.pt \
+  --words-file data/lang_bpe_500/words.txt  \
+  --G data/lm/G_4_gram.pt \
+  --method nbest-rescoring \
+  --sample-rate 16000 \
+  /path/to/foo.wav \
+  /path/to/bar.wav
+
+
+(4) whole-lattice-rescoring
+./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \
+  --nn-model-filename ./pruned_transducer_stateless7_ctc/exp/cpu_jit.pt \
+  --HLG data/lang_bpe_500/HLG.pt \
+  --words-file data/lang_bpe_500/words.txt  \
+  --G data/lm/G_4_gram.pt \
+  --method whole-lattice-rescoring \
+  --sample-rate 16000 \
+  /path/to/foo.wav \
+  /path/to/bar.wav
+"""
+
+import argparse
+import logging
+import math
+from typing import List
+
+import k2
+import kaldifeat
+import sentencepiece as spm
+import torch
+import torchaudio
+from ctc_decode import get_decoding_params
+from torch.nn.utils.rnn import pad_sequence
+from train import get_params
+
+from icefall.decode import (
+    get_lattice,
+    one_best_decoding,
+    rescore_with_n_best_list,
+    rescore_with_whole_lattice,
+)
+from icefall.utils import get_texts
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--model-filename",
+        type=str,
+        required=True,
+        help="Path to the torchscript model.",
+    )
+
+    parser.add_argument(
+        "--words-file",
+        type=str,
+        help="""Path to words.txt.
+        Used only when method is not ctc-decoding.
+        """,
+    )
+
+    parser.add_argument(
+        "--HLG",
+        type=str,
+        help="""Path to HLG.pt.
+        Used only when method is not ctc-decoding.
+        """,
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        help="""Path to bpe.model.
+        Used only when method is ctc-decoding.
+        """,
+    )
+
+    parser.add_argument(
+        "--method",
+        type=str,
+        default="1best",
+        help="""Decoding method.
+        Possible values are:
+        (0) ctc-decoding - Use CTC decoding. It uses a sentence
+            piece model, i.e., lang_dir/bpe.model, to convert
+            word pieces to words. It needs neither a lexicon
+            nor an n-gram LM.
+        (1) 1best - Use the best path as decoding output. Only
+            the transformer encoder output is used for decoding.
+            We call it HLG decoding.
+        (2) nbest-rescoring. Extract n paths from the decoding lattice,
+            rescore them with an LM, the path with
+            the highest score is the decoding result.
+            We call it HLG decoding + n-gram LM rescoring.
+        (3) whole-lattice-rescoring - Use an LM to rescore the
+            decoding lattice and then use 1best to decode the
+            rescored lattice.
+            We call it HLG decoding + n-gram LM rescoring.
+        """,
+    )
+
+    parser.add_argument(
+        "--G",
+        type=str,
+        help="""An LM for rescoring.
+        Used only when method is
+        whole-lattice-rescoring or nbest-rescoring.
+        It's usually a 4-gram LM.
+        """,
+    )
+
+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=100,
+        help="""
+        Used only when method is attention-decoder.
+        It specifies the size of n-best list.""",
+    )
+
+    parser.add_argument(
+        "--ngram-lm-scale",
+        type=float,
+        default=1.3,
+        help="""
+        Used only when method is whole-lattice-rescoring and nbest-rescoring.
+        It specifies the scale for n-gram LM scores.
+        (Note: You need to tune it on a dataset.)
+        """,
+    )
+
+    parser.add_argument(
+        "--nbest-scale",
+        type=float,
+        default=0.5,
+        help="""
+        Used only when method is nbest-rescoring.
+        It specifies the scale for lattice.scores when
+        extracting n-best lists. A smaller value results in
+        more unique number of paths with the risk of missing
+        the best path.
+        """,
+    )
+
+    parser.add_argument(
+        "--num-classes",
+        type=int,
+        default=500,
+        help="""
+        Vocab size in the BPE model.
+        """,
+    )
+
+    parser.add_argument(
+        "--sample-rate",
+        type=int,
+        default=16000,
+        help="The sample rate of the input sound file",
+    )
+
+    parser.add_argument(
+        "sound_files",
+        type=str,
+        nargs="+",
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
+    )
+
+    return parser
+
+
+def read_sound_files(
+    filenames: List[str], expected_sample_rate: float = 16000
+) -> List[torch.Tensor]:
+    """Read a list of sound files into a list 1-D float32 torch tensors.
+    Args:
+      filenames:
+        A list of sound filenames.
+      expected_sample_rate:
+        The expected sample rate of the sound files.
+    Returns:
+      Return a list of 1-D float32 torch tensors.
+    """
+    ans = []
+    for f in filenames:
+        wave, sample_rate = torchaudio.load(f)
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"Expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        # We use only the first channel
+        ans.append(wave[0])
+    return ans
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+
+    params = get_params()
+    # add decoding params
+    params.update(get_decoding_params())
+    params.update(vars(args))
+
+    logging.info(f"{params}")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    model = torch.jit.load(args.model_filename)
+    model.to(device)
+    model.eval()
+
+    logging.info("Constructing Fbank computer")
+    opts = kaldifeat.FbankOptions()
+    opts.device = device
+    opts.frame_opts.dither = 0
+    opts.frame_opts.snip_edges = False
+    opts.frame_opts.samp_freq = params.sample_rate
+    opts.mel_opts.num_bins = params.feature_dim
+
+    fbank = kaldifeat.Fbank(opts)
+
+    logging.info(f"Reading sound files: {params.sound_files}")
+    waves = read_sound_files(
+        filenames=params.sound_files, expected_sample_rate=params.sample_rate
+    )
+    waves = [w.to(device) for w in waves]
+
+    logging.info("Decoding started")
+    features = fbank(waves)
+    feature_lengths = [f.size(0) for f in features]
+
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    feature_lengths = torch.tensor(feature_lengths, device=device)
+
+    encoder_out, encoder_out_lens = model.encoder(
+        x=features,
+        x_lens=feature_lengths,
+    )
+    nnet_output = model.ctc_output(encoder_out)
+
+    batch_size = nnet_output.shape[0]
+    supervision_segments = torch.tensor(
+        [[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
+        dtype=torch.int32,
+    )
+
+    if params.method == "ctc-decoding":
+        logging.info("Use CTC decoding")
+        bpe_model = spm.SentencePieceProcessor()
+        bpe_model.load(params.bpe_model)
+        max_token_id = params.num_classes - 1
+
+        H = k2.ctc_topo(
+            max_token=max_token_id,
+            modified=False,
+            device=device,
+        )
+
+        lattice = get_lattice(
+            nnet_output=nnet_output,
+            decoding_graph=H,
+            supervision_segments=supervision_segments,
+            search_beam=params.search_beam,
+            output_beam=params.output_beam,
+            min_active_states=params.min_active_states,
+            max_active_states=params.max_active_states,
+            subsampling_factor=params.subsampling_factor,
+        )
+
+        best_path = one_best_decoding(
+            lattice=lattice, use_double_scores=params.use_double_scores
+        )
+        token_ids = get_texts(best_path)
+        hyps = bpe_model.decode(token_ids)
+        hyps = [s.split() for s in hyps]
+    elif params.method in [
+        "1best",
+        "nbest-rescoring",
+        "whole-lattice-rescoring",
+    ]:
+        logging.info(f"Loading HLG from {params.HLG}")
+        HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+        HLG = HLG.to(device)
+        if not hasattr(HLG, "lm_scores"):
+            # For whole-lattice-rescoring and attention-decoder
+            HLG.lm_scores = HLG.scores.clone()
+
+        if params.method in [
+            "nbest-rescoring",
+            "whole-lattice-rescoring",
+        ]:
+            logging.info(f"Loading G from {params.G}")
+            G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+            G = G.to(device)
+            if params.method == "whole-lattice-rescoring":
+                # Add epsilon self-loops to G as we will compose
+                # it with the whole lattice later
+                G = k2.add_epsilon_self_loops(G)
+                G = k2.arc_sort(G)
+
+            # G.lm_scores is used to replace HLG.lm_scores during
+            # LM rescoring.
+            G.lm_scores = G.scores.clone()
+
+        lattice = get_lattice(
+            nnet_output=nnet_output,
+            decoding_graph=HLG,
+            supervision_segments=supervision_segments,
+            search_beam=params.search_beam,
+            output_beam=params.output_beam,
+            min_active_states=params.min_active_states,
+            max_active_states=params.max_active_states,
+            subsampling_factor=params.subsampling_factor,
+        )
+
+        if params.method == "1best":
+            logging.info("Use HLG decoding")
+            best_path = one_best_decoding(
+                lattice=lattice, use_double_scores=params.use_double_scores
+            )
+        if params.method == "nbest-rescoring":
+            logging.info("Use HLG decoding + LM rescoring")
+            best_path_dict = rescore_with_n_best_list(
+                lattice=lattice,
+                G=G,
+                num_paths=params.num_paths,
+                lm_scale_list=[params.ngram_lm_scale],
+                nbest_scale=params.nbest_scale,
+            )
+            best_path = next(iter(best_path_dict.values()))
+        elif params.method == "whole-lattice-rescoring":
+            logging.info("Use HLG decoding + LM rescoring")
+            best_path_dict = rescore_with_whole_lattice(
+                lattice=lattice,
+                G_with_epsilon_loops=G,
+                lm_scale_list=[params.ngram_lm_scale],
+            )
+            best_path = next(iter(best_path_dict.values()))
+
+        hyps = get_texts(best_path)
+        word_sym_table = k2.SymbolTable.from_file(params.words_file)
+        hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
+    else:
+        raise ValueError(f"Unsupported decoding method: {params.method}")
+
+    s = "\n"
+    for filename, hyp in zip(params.sound_files, hyps):
+        words = " ".join(hyp)
+        s += f"{filename}:\n{words}\n\n"
+    logging.info(s)
+
+    logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/joiner.py
new file mode 120000
index 000000000..ecfb6dd8a
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/joiner.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless7/joiner.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/model.py
new file mode 100644
index 000000000..a6e919e2f
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/model.py
@@ -0,0 +1,198 @@
+# Copyright    2021  Xiaomi Corp.        (authors: Fangjun Kuang, Wei Kang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Tuple
+
+import k2
+import torch
+import torch.nn as nn
+from encoder_interface import EncoderInterface
+
+from icefall.utils import add_sos
+
+
+class Transducer(nn.Module):
+    """It implements https://arxiv.org/pdf/1211.3711.pdf
+    "Sequence Transduction with Recurrent Neural Networks"
+    """
+
+    def __init__(
+        self,
+        encoder: EncoderInterface,
+        decoder: nn.Module,
+        joiner: nn.Module,
+        encoder_dim: int,
+        decoder_dim: int,
+        joiner_dim: int,
+        vocab_size: int,
+    ):
+        """
+        Args:
+          encoder:
+            It is the transcription network in the paper. Its accepts
+            two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
+            It returns two tensors: `logits` of shape (N, T, encoder_dm) and
+            `logit_lens` of shape (N,).
+          decoder:
+            It is the prediction network in the paper. Its input shape
+            is (N, U) and its output shape is (N, U, decoder_dim).
+            It should contain one attribute: `blank_id`.
+          joiner:
+            It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
+            Its output shape is (N, T, U, vocab_size). Note that its output contains
+            unnormalized probs, i.e., not processed by log-softmax.
+        """
+        super().__init__()
+        assert isinstance(encoder, EncoderInterface), type(encoder)
+        assert hasattr(decoder, "blank_id")
+
+        self.encoder = encoder
+        self.decoder = decoder
+        self.joiner = joiner
+
+        self.simple_am_proj = nn.Linear(
+            encoder_dim,
+            vocab_size,
+        )
+        self.simple_lm_proj = nn.Linear(decoder_dim, vocab_size)
+
+        self.ctc_output = nn.Sequential(
+            nn.Dropout(p=0.1),
+            nn.Linear(encoder_dim, vocab_size),
+            nn.LogSoftmax(dim=-1),
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        x_lens: torch.Tensor,
+        y: k2.RaggedTensor,
+        prune_range: int = 5,
+        am_scale: float = 0.0,
+        lm_scale: float = 0.0,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Args:
+          x:
+            A 3-D tensor of shape (N, T, C).
+          x_lens:
+            A 1-D tensor of shape (N,). It contains the number of frames in `x`
+            before padding.
+          y:
+            A ragged tensor with 2 axes [utt][label]. It contains labels of each
+            utterance.
+          prune_range:
+            The prune range for rnnt loss, it means how many symbols(context)
+            we are considering for each frame to compute the loss.
+          am_scale:
+            The scale to smooth the loss with am (output of encoder network)
+            part
+          lm_scale:
+            The scale to smooth the loss with lm (output of predictor network)
+            part
+        Returns:
+          Return a tuple containing simple loss, pruned loss, and ctc-output.
+
+        Note:
+           Regarding am_scale & lm_scale, it will make the loss-function one of
+           the form:
+              lm_scale * lm_probs + am_scale * am_probs +
+              (1-lm_scale-am_scale) * combined_probs
+        """
+        assert x.ndim == 3, x.shape
+        assert x_lens.ndim == 1, x_lens.shape
+        assert y.num_axes == 2, y.num_axes
+
+        assert x.size(0) == x_lens.size(0) == y.dim0
+
+        encoder_out, x_lens = self.encoder(x, x_lens)
+        assert torch.all(x_lens > 0)
+
+        # compute ctc log-probs
+        ctc_output = self.ctc_output(encoder_out)
+
+        # Now for the decoder, i.e., the prediction network
+        row_splits = y.shape.row_splits(1)
+        y_lens = row_splits[1:] - row_splits[:-1]
+
+        blank_id = self.decoder.blank_id
+        sos_y = add_sos(y, sos_id=blank_id)
+
+        # sos_y_padded: [B, S + 1], start with SOS.
+        sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
+
+        # decoder_out: [B, S + 1, decoder_dim]
+        decoder_out = self.decoder(sos_y_padded)
+
+        # Note: y does not start with SOS
+        # y_padded : [B, S]
+        y_padded = y.pad(mode="constant", padding_value=0)
+
+        y_padded = y_padded.to(torch.int64)
+        boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device)
+        boundary[:, 2] = y_lens
+        boundary[:, 3] = x_lens
+
+        lm = self.simple_lm_proj(decoder_out)
+        am = self.simple_am_proj(encoder_out)
+
+        with torch.cuda.amp.autocast(enabled=False):
+            simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
+                lm=lm.float(),
+                am=am.float(),
+                symbols=y_padded,
+                termination_symbol=blank_id,
+                lm_only_scale=lm_scale,
+                am_only_scale=am_scale,
+                boundary=boundary,
+                reduction="sum",
+                return_grad=True,
+            )
+
+        # ranges : [B, T, prune_range]
+        ranges = k2.get_rnnt_prune_ranges(
+            px_grad=px_grad,
+            py_grad=py_grad,
+            boundary=boundary,
+            s_range=prune_range,
+        )
+
+        # am_pruned : [B, T, prune_range, encoder_dim]
+        # lm_pruned : [B, T, prune_range, decoder_dim]
+        am_pruned, lm_pruned = k2.do_rnnt_pruning(
+            am=self.joiner.encoder_proj(encoder_out),
+            lm=self.joiner.decoder_proj(decoder_out),
+            ranges=ranges,
+        )
+
+        # logits : [B, T, prune_range, vocab_size]
+
+        # project_input=False since we applied the decoder's input projections
+        # prior to do_rnnt_pruning (this is an optimization for speed).
+        logits = self.joiner(am_pruned, lm_pruned, project_input=False)
+
+        with torch.cuda.amp.autocast(enabled=False):
+            pruned_loss = k2.rnnt_loss_pruned(
+                logits=logits.float(),
+                symbols=y_padded,
+                ranges=ranges,
+                termination_symbol=blank_id,
+                boundary=boundary,
+                reduction="sum",
+            )
+
+        return (simple_loss, pruned_loss, ctc_output)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/optim.py
new file mode 120000
index 000000000..81ac4a89a
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/optim.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless7/optim.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py
new file mode 100755
index 000000000..2f1b1a49f
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py
@@ -0,0 +1,353 @@
+#!/usr/bin/env python3
+# Copyright      2021  Xiaomi Corp.        (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script loads a checkpoint and uses it to decode waves.
+You can generate the checkpoint with the following command:
+
+./pruned_transducer_stateless7_ctc/export.py \
+  --exp-dir ./pruned_transducer_stateless7_ctc/exp \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --epoch 20 \
+  --avg 10
+
+Usage of this script:
+
+(1) greedy search
+./pruned_transducer_stateless7_ctc/pretrained.py \
+    --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method greedy_search \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+
+(2) beam search
+./pruned_transducer_stateless7_ctc/pretrained.py \
+    --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method beam_search \
+    --beam-size 4 \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+
+(3) modified beam search
+./pruned_transducer_stateless7_ctc/pretrained.py \
+    --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method modified_beam_search \
+    --beam-size 4 \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+
+(4) fast beam search
+./pruned_transducer_stateless7_ctc/pretrained.py \
+    --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method fast_beam_search \
+    --beam-size 4 \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+
+You can also use `./pruned_transducer_stateless7_ctc/exp/epoch-xx.pt`.
+
+Note: ./pruned_transducer_stateless7_ctc/exp/pretrained.pt is generated by
+./pruned_transducer_stateless7_ctc/export.py
+"""
+
+
+import argparse
+import logging
+import math
+from typing import List
+
+import k2
+import kaldifeat
+import sentencepiece as spm
+import torch
+import torchaudio
+from beam_search import (
+    beam_search,
+    fast_beam_search_one_best,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+)
+from torch.nn.utils.rnn import pad_sequence
+from train import add_model_arguments, get_params, get_transducer_model
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--checkpoint",
+        type=str,
+        required=True,
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        help="""Path to bpe.model.""",
+    )
+
+    parser.add_argument(
+        "--method",
+        type=str,
+        default="greedy_search",
+        help="""Possible values are:
+          - greedy_search
+          - beam_search
+          - modified_beam_search
+          - fast_beam_search
+        """,
+    )
+
+    parser.add_argument(
+        "sound_files",
+        type=str,
+        nargs="+",
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
+    )
+
+    parser.add_argument(
+        "--sample-rate",
+        type=int,
+        default=16000,
+        help="The sample rate of the input sound file",
+    )
+
+    parser.add_argument(
+        "--beam-size",
+        type=int,
+        default=4,
+        help="""An integer indicating how many candidates we will keep for each
+        frame. Used only when --method is beam_search or
+        modified_beam_search.""",
+    )
+
+    parser.add_argument(
+        "--beam",
+        type=float,
+        default=4,
+        help="""A floating point value to calculate the cutoff score during beam
+        search (i.e., `cutoff = max-score - beam`), which is the same as the
+        `beam` in Kaldi.
+        Used only when --method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--max-contexts",
+        type=int,
+        default=4,
+        help="""Used only when --method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--max-states",
+        type=int,
+        default=8,
+        help="""Used only when --method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+    )
+    parser.add_argument(
+        "--max-sym-per-frame",
+        type=int,
+        default=1,
+        help="""Maximum number of symbols per frame. Used only when
+        --method is greedy_search.
+        """,
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def read_sound_files(
+    filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+    """Read a list of sound files into a list 1-D float32 torch tensors.
+    Args:
+      filenames:
+        A list of sound filenames.
+      expected_sample_rate:
+        The expected sample rate of the sound files.
+    Returns:
+      Return a list of 1-D float32 torch tensors.
+    """
+    ans = []
+    for f in filenames:
+        wave, sample_rate = torchaudio.load(f)
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"Expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        # We use only the first channel
+        ans.append(wave[0])
+    return ans
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+
+    params = get_params()
+
+    params.update(vars(args))
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    #  is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("")
+    params.unk_id = sp.piece_to_id("")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(f"{params}")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    logging.info("Creating model")
+    model = get_transducer_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    model.load_state_dict(checkpoint["model"], strict=False)
+    model.to(device)
+    model.eval()
+    model.device = device
+
+    logging.info("Constructing Fbank computer")
+    opts = kaldifeat.FbankOptions()
+    opts.device = device
+    opts.frame_opts.dither = 0
+    opts.frame_opts.snip_edges = False
+    opts.frame_opts.samp_freq = params.sample_rate
+    opts.mel_opts.num_bins = params.feature_dim
+
+    fbank = kaldifeat.Fbank(opts)
+
+    logging.info(f"Reading sound files: {params.sound_files}")
+    waves = read_sound_files(
+        filenames=params.sound_files, expected_sample_rate=params.sample_rate
+    )
+    waves = [w.to(device) for w in waves]
+
+    logging.info("Decoding started")
+    features = fbank(waves)
+    feature_lengths = [f.size(0) for f in features]
+
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+
+    feature_lengths = torch.tensor(feature_lengths, device=device)
+
+    encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
+
+    num_waves = encoder_out.size(0)
+    hyps = []
+    msg = f"Using {params.method}"
+    if params.method == "beam_search":
+        msg += f" with beam size {params.beam_size}"
+    logging.info(msg)
+
+    if params.method == "fast_beam_search":
+        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+        hyp_tokens = fast_beam_search_one_best(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.method == "modified_beam_search":
+        hyp_tokens = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam_size,
+        )
+
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
+        hyp_tokens = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    else:
+        for i in range(num_waves):
+            # fmt: off
+            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+            # fmt: on
+            if params.method == "greedy_search":
+                hyp = greedy_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    max_sym_per_frame=params.max_sym_per_frame,
+                )
+            elif params.method == "beam_search":
+                hyp = beam_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
+                )
+            else:
+                raise ValueError(f"Unsupported method: {params.method}")
+
+            hyps.append(sp.decode(hyp).split())
+
+    s = "\n"
+    for filename, hyp in zip(params.sound_files, hyps):
+        words = " ".join(hyp)
+        s += f"{filename}:\n{words}\n\n"
+    logging.info(s)
+
+    logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py
new file mode 100755
index 000000000..74aef1bc7
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py
@@ -0,0 +1,441 @@
+#!/usr/bin/env python3
+# Copyright      2022  Xiaomi Corp.        (authors: Fangjun Kuang,
+#                                                    Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script loads torchscript models, exported by `torch.jit.script()`
+and uses them to decode waves.
+You can use the following command to get the exported models:
+
+./pruned_transducer_stateless7_ctc/export.py \
+  --exp-dir ./pruned_transducer_stateless7_ctc/exp \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --epoch 20 \
+  --avg 10
+
+Usage of this script:
+
+(1) ctc-decoding
+./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \
+  --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --method ctc-decoding \
+  --sample-rate 16000 \
+  /path/to/foo.wav \
+  /path/to/bar.wav
+
+(2) 1best
+./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \
+  --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \
+  --HLG data/lang_bpe_500/HLG.pt \
+  --words-file data/lang_bpe_500/words.txt  \
+  --method 1best \
+  --sample-rate 16000 \
+  /path/to/foo.wav \
+  /path/to/bar.wav
+
+(3) nbest-rescoring
+./bruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \
+  --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \
+  --HLG data/lang_bpe_500/HLG.pt \
+  --words-file data/lang_bpe_500/words.txt  \
+  --G data/lm/G_4_gram.pt \
+  --method nbest-rescoring \
+  --sample-rate 16000 \
+  /path/to/foo.wav \
+  /path/to/bar.wav
+
+
+(4) whole-lattice-rescoring
+./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \
+  --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \
+  --HLG data/lang_bpe_500/HLG.pt \
+  --words-file data/lang_bpe_500/words.txt  \
+  --G data/lm/G_4_gram.pt \
+  --method whole-lattice-rescoring \
+  --sample-rate 16000 \
+  /path/to/foo.wav \
+  /path/to/bar.wav
+"""
+
+import argparse
+import logging
+import math
+from typing import List
+
+import k2
+import kaldifeat
+import sentencepiece as spm
+import torch
+import torchaudio
+from ctc_decode import get_decoding_params
+from torch.nn.utils.rnn import pad_sequence
+from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall.decode import (
+    get_lattice,
+    one_best_decoding,
+    rescore_with_n_best_list,
+    rescore_with_whole_lattice,
+)
+from icefall.utils import get_texts
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--checkpoint",
+        type=str,
+        required=True,
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+    )
+
+    parser.add_argument(
+        "--words-file",
+        type=str,
+        help="""Path to words.txt.
+        Used only when method is not ctc-decoding.
+        """,
+    )
+
+    parser.add_argument(
+        "--HLG",
+        type=str,
+        help="""Path to HLG.pt.
+        Used only when method is not ctc-decoding.
+        """,
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        help="""Path to bpe.model.
+        Used only when method is ctc-decoding.
+        """,
+    )
+
+    parser.add_argument(
+        "--method",
+        type=str,
+        default="1best",
+        help="""Decoding method.
+        Possible values are:
+        (0) ctc-decoding - Use CTC decoding. It uses a sentence
+            piece model, i.e., lang_dir/bpe.model, to convert
+            word pieces to words. It needs neither a lexicon
+            nor an n-gram LM.
+        (1) 1best - Use the best path as decoding output. Only
+            the transformer encoder output is used for decoding.
+            We call it HLG decoding.
+        (2) nbest-rescoring. Extract n paths from the decoding lattice,
+            rescore them with an LM, the path with
+            the highest score is the decoding result.
+            We call it HLG decoding + n-gram LM rescoring.
+        (3) whole-lattice-rescoring - Use an LM to rescore the
+            decoding lattice and then use 1best to decode the
+            rescored lattice.
+            We call it HLG decoding + n-gram LM rescoring.
+        """,
+    )
+
+    parser.add_argument(
+        "--G",
+        type=str,
+        help="""An LM for rescoring.
+        Used only when method is
+        whole-lattice-rescoring or nbest-rescoring.
+        It's usually a 4-gram LM.
+        """,
+    )
+
+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=100,
+        help="""
+        Used only when method is attention-decoder.
+        It specifies the size of n-best list.""",
+    )
+
+    parser.add_argument(
+        "--ngram-lm-scale",
+        type=float,
+        default=1.3,
+        help="""
+        Used only when method is whole-lattice-rescoring and nbest-rescoring.
+        It specifies the scale for n-gram LM scores.
+        (Note: You need to tune it on a dataset.)
+        """,
+    )
+
+    parser.add_argument(
+        "--nbest-scale",
+        type=float,
+        default=0.5,
+        help="""
+        Used only when method is nbest-rescoring.
+        It specifies the scale for lattice.scores when
+        extracting n-best lists. A smaller value results in
+        more unique number of paths with the risk of missing
+        the best path.
+        """,
+    )
+
+    parser.add_argument(
+        "--num-classes",
+        type=int,
+        default=500,
+        help="""
+        Vocab size in the BPE model.
+        """,
+    )
+
+    parser.add_argument(
+        "--sample-rate",
+        type=int,
+        default=16000,
+        help="The sample rate of the input sound file",
+    )
+
+    parser.add_argument(
+        "sound_files",
+        type=str,
+        nargs="+",
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def read_sound_files(
+    filenames: List[str], expected_sample_rate: float = 16000
+) -> List[torch.Tensor]:
+    """Read a list of sound files into a list 1-D float32 torch tensors.
+    Args:
+      filenames:
+        A list of sound filenames.
+      expected_sample_rate:
+        The expected sample rate of the sound files.
+    Returns:
+      Return a list of 1-D float32 torch tensors.
+    """
+    ans = []
+    for f in filenames:
+        wave, sample_rate = torchaudio.load(f)
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
+        )
+        # We use only the first channel
+        ans.append(wave[0])
+    return ans
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+
+    params = get_params()
+    # add decoding params
+    params.update(get_decoding_params())
+    params.update(vars(args))
+    params.vocab_size = params.num_classes
+    params.blank_id = 0
+
+    logging.info(f"{params}")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    logging.info("Creating model")
+    model = get_transducer_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    model.load_state_dict(checkpoint["model"], strict=False)
+    model.to(device)
+    model.eval()
+
+    logging.info("Constructing Fbank computer")
+    opts = kaldifeat.FbankOptions()
+    opts.device = device
+    opts.frame_opts.dither = 0
+    opts.frame_opts.snip_edges = False
+    opts.frame_opts.samp_freq = params.sample_rate
+    opts.mel_opts.num_bins = params.feature_dim
+
+    fbank = kaldifeat.Fbank(opts)
+
+    logging.info(f"Reading sound files: {params.sound_files}")
+    waves = read_sound_files(
+        filenames=params.sound_files, expected_sample_rate=params.sample_rate
+    )
+    waves = [w.to(device) for w in waves]
+
+    logging.info("Decoding started")
+    features = fbank(waves)
+    feature_lengths = [f.size(0) for f in features]
+
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    feature_lengths = torch.tensor(feature_lengths, device=device)
+
+    encoder_out, encoder_out_lens = model.encoder(
+        x=features,
+        x_lens=feature_lengths,
+    )
+    nnet_output = model.ctc_output(encoder_out)
+
+    batch_size = nnet_output.shape[0]
+    supervision_segments = torch.tensor(
+        [[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
+        dtype=torch.int32,
+    )
+
+    if params.method == "ctc-decoding":
+        logging.info("Use CTC decoding")
+        bpe_model = spm.SentencePieceProcessor()
+        bpe_model.load(params.bpe_model)
+        max_token_id = params.num_classes - 1
+
+        H = k2.ctc_topo(
+            max_token=max_token_id,
+            modified=False,
+            device=device,
+        )
+
+        lattice = get_lattice(
+            nnet_output=nnet_output,
+            decoding_graph=H,
+            supervision_segments=supervision_segments,
+            search_beam=params.search_beam,
+            output_beam=params.output_beam,
+            min_active_states=params.min_active_states,
+            max_active_states=params.max_active_states,
+            subsampling_factor=params.subsampling_factor,
+        )
+
+        best_path = one_best_decoding(
+            lattice=lattice, use_double_scores=params.use_double_scores
+        )
+        token_ids = get_texts(best_path)
+        hyps = bpe_model.decode(token_ids)
+        hyps = [s.split() for s in hyps]
+    elif params.method in [
+        "1best",
+        "nbest-rescoring",
+        "whole-lattice-rescoring",
+    ]:
+        logging.info(f"Loading HLG from {params.HLG}")
+        HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+        HLG = HLG.to(device)
+        if not hasattr(HLG, "lm_scores"):
+            # For whole-lattice-rescoring and attention-decoder
+            HLG.lm_scores = HLG.scores.clone()
+
+        if params.method in [
+            "nbest-rescoring",
+            "whole-lattice-rescoring",
+        ]:
+            logging.info(f"Loading G from {params.G}")
+            G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+            G = G.to(device)
+            if params.method == "whole-lattice-rescoring":
+                # Add epsilon self-loops to G as we will compose
+                # it with the whole lattice later
+                G = k2.add_epsilon_self_loops(G)
+                G = k2.arc_sort(G)
+
+            # G.lm_scores is used to replace HLG.lm_scores during
+            # LM rescoring.
+            G.lm_scores = G.scores.clone()
+
+        lattice = get_lattice(
+            nnet_output=nnet_output,
+            decoding_graph=HLG,
+            supervision_segments=supervision_segments,
+            search_beam=params.search_beam,
+            output_beam=params.output_beam,
+            min_active_states=params.min_active_states,
+            max_active_states=params.max_active_states,
+            subsampling_factor=params.subsampling_factor,
+        )
+
+        if params.method == "1best":
+            logging.info("Use HLG decoding")
+            best_path = one_best_decoding(
+                lattice=lattice, use_double_scores=params.use_double_scores
+            )
+        if params.method == "nbest-rescoring":
+            logging.info("Use HLG decoding + LM rescoring")
+            best_path_dict = rescore_with_n_best_list(
+                lattice=lattice,
+                G=G,
+                num_paths=params.num_paths,
+                lm_scale_list=[params.ngram_lm_scale],
+                nbest_scale=params.nbest_scale,
+            )
+            best_path = next(iter(best_path_dict.values()))
+        elif params.method == "whole-lattice-rescoring":
+            logging.info("Use HLG decoding + LM rescoring")
+            best_path_dict = rescore_with_whole_lattice(
+                lattice=lattice,
+                G_with_epsilon_loops=G,
+                lm_scale_list=[params.ngram_lm_scale],
+            )
+            best_path = next(iter(best_path_dict.values()))
+
+        hyps = get_texts(best_path)
+        word_sym_table = k2.SymbolTable.from_file(params.words_file)
+        hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
+    else:
+        raise ValueError(f"Unsupported decoding method: {params.method}")
+
+    s = "\n"
+    for filename, hyp in zip(params.sound_files, hyps):
+        words = " ".join(hyp)
+        s += f"{filename}:\n{words}\n\n"
+    logging.info(s)
+
+    logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/scaling.py
new file mode 120000
index 000000000..2428b74b9
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/scaling.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless7/scaling.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/scaling_converter.py
new file mode 120000
index 000000000..b8b8ba432
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/scaling_converter.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless7/scaling_converter.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/test_model.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/test_model.py
new file mode 100755
index 000000000..e482d2040
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/test_model.py
@@ -0,0 +1,56 @@
+#!/usr/bin/env python3
+# Copyright    2022  Xiaomi Corp.        (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+To run this file, do:
+
+    cd icefall/egs/librispeech/ASR
+    python ./pruned_transducer_stateless7_ctc/test_model.py
+"""
+
+from train import get_params, get_transducer_model
+
+
+def test_model_1():
+    params = get_params()
+    params.vocab_size = 500
+    params.blank_id = 0
+    params.context_size = 2
+    params.num_encoder_layers = "2,4,3,2,4"
+    #  params.feedforward_dims = "1024,1024,1536,1536,1024"
+    params.feedforward_dims = "1024,1024,2048,2048,1024"
+    params.nhead = "8,8,8,8,8"
+    params.encoder_dims = "384,384,384,384,384"
+    params.attention_dims = "192,192,192,192,192"
+    params.encoder_unmasked_dims = "256,256,256,256,256"
+    params.zipformer_downsampling_factors = "1,2,4,8,2"
+    params.cnn_module_kernels = "31,31,31,31,31"
+    params.decoder_dim = 512
+    params.joiner_dim = 512
+    model = get_transducer_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    print(f"Number of model parameters: {num_param}")
+
+
+def main():
+    test_model_1()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py
new file mode 100755
index 000000000..abfd56e5a
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py
@@ -0,0 +1,1252 @@
+#!/usr/bin/env python3
+# Copyright    2021-2022  Xiaomi Corp.        (authors: Fangjun Kuang,
+#                                                       Wei Kang,
+#                                                       Mingshuang Luo,)
+#                                                       Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./pruned_transducer_stateless7_ctc/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --exp-dir pruned_transducer_stateless7_ctc/exp \
+  --full-libri 1 \
+  --max-duration 300
+
+# For mix precision training:
+
+./pruned_transducer_stateless7_ctc/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --use-fp16 1 \
+  --exp-dir pruned_transducer_stateless7_ctc/exp \
+  --full-libri 1 \
+  --max-duration 550
+
+"""
+
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
+from decoder import Decoder
+from joiner import Joiner
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import Transducer
+from optim import Eden, ScaledAdam
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+from zipformer import Zipformer
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+    save_checkpoint_with_global_batch_idx,
+    update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    encode_supervisions,
+    setup_logger,
+    str2bool,
+)
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+    if isinstance(model, DDP):
+        # get underlying nn.Module
+        model = model.module
+    for module in model.modules():
+        if hasattr(module, "batch_count"):
+            module.batch_count = batch_count
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+    parser.add_argument(
+        "--num-encoder-layers",
+        type=str,
+        default="2,4,3,2,4",
+        help="Number of zipformer encoder layers, comma separated.",
+    )
+
+    parser.add_argument(
+        "--feedforward-dims",
+        type=str,
+        default="1024,1024,2048,2048,1024",
+        help="Feedforward dimension of the zipformer encoder layers, comma separated.",
+    )
+
+    parser.add_argument(
+        "--nhead",
+        type=str,
+        default="8,8,8,8,8",
+        help="Number of attention heads in the zipformer encoder layers.",
+    )
+
+    parser.add_argument(
+        "--encoder-dims",
+        type=str,
+        default="384,384,384,384,384",
+        help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated",
+    )
+
+    parser.add_argument(
+        "--attention-dims",
+        type=str,
+        default="192,192,192,192,192",
+        help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated;
+        not the same as embedding dimension.""",
+    )
+
+    parser.add_argument(
+        "--encoder-unmasked-dims",
+        type=str,
+        default="256,256,256,256,256",
+        help="Unmasked dimensions in the encoders, relates to augmentation during training.  "
+        "Must be <= each of encoder_dims.  Empirically, less than 256 seems to make performance "
+        " worse.",
+    )
+
+    parser.add_argument(
+        "--zipformer-downsampling-factors",
+        type=str,
+        default="1,2,4,8,2",
+        help="Downsampling factor for each stack of encoder layers.",
+    )
+
+    parser.add_argument(
+        "--cnn-module-kernels",
+        type=str,
+        default="31,31,31,31,31",
+        help="Sizes of kernels in convolution modules",
+    )
+
+    parser.add_argument(
+        "--decoder-dim",
+        type=int,
+        default=512,
+        help="Embedding dimension in the decoder model.",
+    )
+
+    parser.add_argument(
+        "--joiner-dim",
+        type=int,
+        default=512,
+        help="""Dimension used in the joiner model.
+        Outputs from the encoder and decoder model are projected
+        to this dimension before adding.
+        """,
+    )
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--world-size",
+        type=int,
+        default=1,
+        help="Number of GPUs for DDP training.",
+    )
+
+    parser.add_argument(
+        "--master-port",
+        type=int,
+        default=12354,
+        help="Master port to use for DDP training.",
+    )
+
+    parser.add_argument(
+        "--tensorboard",
+        type=str2bool,
+        default=True,
+        help="Should various information be logged in tensorboard.",
+    )
+
+    parser.add_argument(
+        "--num-epochs",
+        type=int,
+        default=30,
+        help="Number of epochs to train.",
+    )
+
+    parser.add_argument(
+        "--start-epoch",
+        type=int,
+        default=1,
+        help="""Resume training from this epoch. It should be positive.
+        If larger than 1, it will load checkpoint from
+        exp-dir/epoch-{start_epoch-1}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--start-batch",
+        type=int,
+        default=0,
+        help="""If positive, --start-epoch is ignored and
+        it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="pruned_transducer_stateless7_ctc/exp",
+        help="""The experiment dir.
+        It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="data/lang_bpe_500/bpe.model",
+        help="Path to the BPE model",
+    )
+
+    parser.add_argument(
+        "--base-lr", type=float, default=0.05, help="The base learning rate."
+    )
+
+    parser.add_argument(
+        "--lr-batches",
+        type=float,
+        default=5000,
+        help="""Number of steps that affects how rapidly the learning rate
+        decreases. We suggest not to change this.""",
+    )
+
+    parser.add_argument(
+        "--lr-epochs",
+        type=float,
+        default=3.5,
+        help="""Number of epochs that affects how rapidly the learning rate decreases.
+        """,
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+    )
+
+    parser.add_argument(
+        "--prune-range",
+        type=int,
+        default=5,
+        help="The prune range for rnnt loss, it means how many symbols(context)"
+        "we are using to compute the loss",
+    )
+
+    parser.add_argument(
+        "--lm-scale",
+        type=float,
+        default=0.25,
+        help="The scale to smooth the loss with lm "
+        "(output of prediction network) part.",
+    )
+
+    parser.add_argument(
+        "--am-scale",
+        type=float,
+        default=0.0,
+        help="The scale to smooth the loss with am (output of encoder network) part.",
+    )
+
+    parser.add_argument(
+        "--simple-loss-scale",
+        type=float,
+        default=0.5,
+        help="To get pruning ranges, we will calculate a simple version"
+        "loss(joiner is just addition), this simple loss also uses for"
+        "training (as a regularization item). We will scale the simple loss"
+        "with this parameter before adding to the final loss.",
+    )
+
+    parser.add_argument(
+        "--ctc-loss-scale",
+        type=float,
+        default=0.2,
+        help="Scale for CTC loss.",
+    )
+
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="The seed for random generators intended for reproducibility",
+    )
+
+    parser.add_argument(
+        "--print-diagnostics",
+        type=str2bool,
+        default=False,
+        help="Accumulate stats on activations, print them and exit.",
+    )
+
+    parser.add_argument(
+        "--inf-check",
+        type=str2bool,
+        default=False,
+        help="Add hooks to check for infinite module outputs and gradients.",
+    )
+
+    parser.add_argument(
+        "--save-every-n",
+        type=int,
+        default=2000,
+        help="""Save checkpoint after processing this number of batches"
+        periodically. We save checkpoint to exp-dir/ whenever
+        params.batch_idx_train % save_every_n == 0. The checkpoint filename
+        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+        Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+        end of each epoch where `xxx` is the epoch number counting from 0.
+        """,
+    )
+
+    parser.add_argument(
+        "--keep-last-k",
+        type=int,
+        default=30,
+        help="""Only keep this number of checkpoints on disk.
+        For instance, if it is 3, there are only 3 checkpoints
+        in the exp-dir with filenames `checkpoint-xxx.pt`.
+        It does not affect checkpoints with name `epoch-xxx.pt`.
+        """,
+    )
+
+    parser.add_argument(
+        "--average-period",
+        type=int,
+        default=200,
+        help="""Update the averaged model, namely `model_avg`, after processing
+        this number of batches. `model_avg` is a separate version of model,
+        in which each floating-point parameter is the average of all the
+        parameters from the start of training. Each time we take the average,
+        we do: `model_avg = model * (average_period / batch_idx_train) +
+            model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+        """,
+    )
+
+    parser.add_argument(
+        "--use-fp16",
+        type=str2bool,
+        default=False,
+        help="Whether to use half precision training.",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def get_params() -> AttributeDict:
+    """Return a dict containing training parameters.
+
+    All training related parameters that are not passed from the commandline
+    are saved in the variable `params`.
+
+    Commandline options are merged into `params` after they are parsed, so
+    you can also access them via `params`.
+
+    Explanation of options saved in `params`:
+
+        - best_train_loss: Best training loss so far. It is used to select
+                           the model that has the lowest training loss. It is
+                           updated during the training.
+
+        - best_valid_loss: Best validation loss so far. It is used to select
+                           the model that has the lowest validation loss. It is
+                           updated during the training.
+
+        - best_train_epoch: It is the epoch that has the best training loss.
+
+        - best_valid_epoch: It is the epoch that has the best validation loss.
+
+        - batch_idx_train: Used to writing statistics to tensorboard. It
+                           contains number of batches trained so far across
+                           epochs.
+
+        - log_interval:  Print training loss if batch_idx % log_interval` is 0
+
+        - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+        - valid_interval:  Run validation if batch_idx % valid_interval is 0
+
+        - feature_dim: The model input dim. It has to match the one used
+                       in computing features.
+
+        - subsampling_factor:  The subsampling factor for the model.
+
+        - encoder_dim: Hidden dim for multi-head attention model.
+
+        - num_decoder_layers: Number of decoder layer of transformer decoder.
+
+        - warm_step: The warmup period that dictates the decay of the
+              scale on "simple" (un-pruned) loss.
+    """
+    params = AttributeDict(
+        {
+            "best_train_loss": float("inf"),
+            "best_valid_loss": float("inf"),
+            "best_train_epoch": -1,
+            "best_valid_epoch": -1,
+            "batch_idx_train": 0,
+            "log_interval": 50,
+            "reset_interval": 200,
+            "valid_interval": 3000,  # For the 100h subset, use 800
+            # parameters for zipformer
+            "feature_dim": 80,
+            "subsampling_factor": 4,  # not passed in, this is fixed.
+            # parameters for ctc loss
+            "beam_size": 10,
+            "use_double_scores": True,
+            "warm_step": 2000,
+            "env_info": get_env_info(),
+        }
+    )
+
+    return params
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+    # TODO: We can add an option to switch between Zipformer and Transformer
+    def to_int_tuple(s: str):
+        return tuple(map(int, s.split(",")))
+
+    encoder = Zipformer(
+        num_features=params.feature_dim,
+        output_downsampling_factor=2,
+        zipformer_downsampling_factors=to_int_tuple(
+            params.zipformer_downsampling_factors
+        ),
+        encoder_dims=to_int_tuple(params.encoder_dims),
+        attention_dim=to_int_tuple(params.attention_dims),
+        encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims),
+        nhead=to_int_tuple(params.nhead),
+        feedforward_dim=to_int_tuple(params.feedforward_dims),
+        cnn_module_kernels=to_int_tuple(params.cnn_module_kernels),
+        num_encoder_layers=to_int_tuple(params.num_encoder_layers),
+    )
+    return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+    decoder = Decoder(
+        vocab_size=params.vocab_size,
+        decoder_dim=params.decoder_dim,
+        blank_id=params.blank_id,
+        context_size=params.context_size,
+    )
+    return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+    joiner = Joiner(
+        encoder_dim=int(params.encoder_dims.split(",")[-1]),
+        decoder_dim=params.decoder_dim,
+        joiner_dim=params.joiner_dim,
+        vocab_size=params.vocab_size,
+    )
+    return joiner
+
+
+def get_transducer_model(params: AttributeDict) -> nn.Module:
+    encoder = get_encoder_model(params)
+    decoder = get_decoder_model(params)
+    joiner = get_joiner_model(params)
+
+    model = Transducer(
+        encoder=encoder,
+        decoder=decoder,
+        joiner=joiner,
+        encoder_dim=int(params.encoder_dims.split(",")[-1]),
+        decoder_dim=params.decoder_dim,
+        joiner_dim=params.joiner_dim,
+        vocab_size=params.vocab_size,
+    )
+    return model
+
+
+def load_checkpoint_if_available(
+    params: AttributeDict,
+    model: nn.Module,
+    model_avg: nn.Module = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+    """Load checkpoint from file.
+
+    If params.start_batch is positive, it will load the checkpoint from
+    `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+    params.start_epoch is larger than 1, it will load the checkpoint from
+    `params.start_epoch - 1`.
+
+    Apart from loading state dict for `model` and `optimizer` it also updates
+    `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+    and `best_valid_loss` in `params`.
+
+    Args:
+      params:
+        The return value of :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer that we are using.
+      scheduler:
+        The scheduler that we are using.
+    Returns:
+      Return a dict containing previously saved training info.
+    """
+    if params.start_batch > 0:
+        filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+    elif params.start_epoch > 1:
+        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+    else:
+        return None
+
+    assert filename.is_file(), f"{filename} does not exist!"
+
+    saved_params = load_checkpoint(
+        filename,
+        model=model,
+        model_avg=model_avg,
+        optimizer=optimizer,
+        scheduler=scheduler,
+    )
+
+    keys = [
+        "best_train_epoch",
+        "best_valid_epoch",
+        "batch_idx_train",
+        "best_train_loss",
+        "best_valid_loss",
+    ]
+    for k in keys:
+        params[k] = saved_params[k]
+
+    if params.start_batch > 0:
+        if "cur_epoch" in saved_params:
+            params["start_epoch"] = saved_params["cur_epoch"]
+
+        if "cur_batch_idx" in saved_params:
+            params["cur_batch_idx"] = saved_params["cur_batch_idx"]
+
+    return saved_params
+
+
+def save_checkpoint(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    model_avg: Optional[nn.Module] = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+    sampler: Optional[CutSampler] = None,
+    scaler: Optional[GradScaler] = None,
+    rank: int = 0,
+) -> None:
+    """Save model, optimizer, scheduler and training stats to file.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer used in the training.
+      sampler:
+       The sampler for the training dataset.
+      scaler:
+        The scaler used for mix precision training.
+    """
+    if rank != 0:
+        return
+    filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+    save_checkpoint_impl(
+        filename=filename,
+        model=model,
+        model_avg=model_avg,
+        params=params,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        sampler=sampler,
+        scaler=scaler,
+        rank=rank,
+    )
+
+    if params.best_train_epoch == params.cur_epoch:
+        best_train_filename = params.exp_dir / "best-train-loss.pt"
+        copyfile(src=filename, dst=best_train_filename)
+
+    if params.best_valid_epoch == params.cur_epoch:
+        best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+        copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    sp: spm.SentencePieceProcessor,
+    batch: dict,
+    is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+    """
+    Compute transducer loss given the model and its inputs.
+
+    Args:
+      params:
+        Parameters for training. See :func:`get_params`.
+      model:
+        The model for training. It is an instance of Zipformer in our case.
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      is_training:
+        True for training. False for validation. When it is True, this
+        function enables autograd during computation; when it is False, it
+        disables autograd.
+     warmup: a floating point value which increases throughout training;
+        values >= 1.0 are fully warmed up and have all modules present.
+    """
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+    feature = batch["inputs"]
+    # at entry, feature is (N, T, C)
+    assert feature.ndim == 3
+    feature = feature.to(device)
+
+    supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)
+
+    batch_idx_train = params.batch_idx_train
+    warm_step = params.warm_step
+
+    texts = batch["supervisions"]["text"]
+    token_ids = sp.encode(texts, out_type=int)
+    y = k2.RaggedTensor(token_ids).to(device)
+
+    with torch.set_grad_enabled(is_training):
+        simple_loss, pruned_loss, ctc_output = model(
+            x=feature,
+            x_lens=feature_lens,
+            y=y,
+            prune_range=params.prune_range,
+            am_scale=params.am_scale,
+            lm_scale=params.lm_scale,
+        )
+
+        s = params.simple_loss_scale
+        # take down the scale on the simple loss from 1.0 at the start
+        # to params.simple_loss scale by warm_step.
+        simple_loss_scale = (
+            s
+            if batch_idx_train >= warm_step
+            else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
+        )
+        pruned_loss_scale = (
+            1.0
+            if batch_idx_train >= warm_step
+            else 0.1 + 0.9 * (batch_idx_train / warm_step)
+        )
+
+        loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+
+    # Compute ctc loss
+
+    # NOTE: We need `encode_supervisions` to sort sequences with
+    # different duration in decreasing order, required by
+    # `k2.intersect_dense` called in `k2.ctc_loss`
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        supervision_segments, token_ids = encode_supervisions(
+            supervisions,
+            subsampling_factor=params.subsampling_factor,
+            token_ids=token_ids,
+        )
+
+    # Works with a BPE model
+    decoding_graph = k2.ctc_graph(token_ids, modified=False, device=device)
+    dense_fsa_vec = k2.DenseFsaVec(
+        ctc_output,
+        supervision_segments,
+        allow_truncate=params.subsampling_factor - 1,
+    )
+
+    ctc_loss = k2.ctc_loss(
+        decoding_graph=decoding_graph,
+        dense_fsa_vec=dense_fsa_vec,
+        output_beam=params.beam_size,
+        reduction="sum",
+        use_double_scores=params.use_double_scores,
+    )
+    assert ctc_loss.requires_grad == is_training
+    loss += params.ctc_loss_scale * ctc_loss
+
+    assert loss.requires_grad == is_training
+
+    info = MetricsTracker()
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+
+    # Note: We use reduction=sum while computing the loss.
+    info["loss"] = loss.detach().cpu().item()
+    info["simple_loss"] = simple_loss.detach().cpu().item()
+    info["pruned_loss"] = pruned_loss.detach().cpu().item()
+    info["ctc_loss"] = ctc_loss.detach().cpu().item()
+
+    return loss, info
+
+
+def compute_validation_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    sp: spm.SentencePieceProcessor,
+    valid_dl: torch.utils.data.DataLoader,
+    world_size: int = 1,
+) -> MetricsTracker:
+    """Run the validation process."""
+    model.eval()
+
+    tot_loss = MetricsTracker()
+
+    for batch_idx, batch in enumerate(valid_dl):
+        loss, loss_info = compute_loss(
+            params=params,
+            model=model,
+            sp=sp,
+            batch=batch,
+            is_training=False,
+        )
+        assert loss.requires_grad is False
+        tot_loss = tot_loss + loss_info
+
+    if world_size > 1:
+        tot_loss.reduce(loss.device)
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    if loss_value < params.best_valid_loss:
+        params.best_valid_epoch = params.cur_epoch
+        params.best_valid_loss = loss_value
+
+    return tot_loss
+
+
+def train_one_epoch(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    optimizer: torch.optim.Optimizer,
+    scheduler: LRSchedulerType,
+    sp: spm.SentencePieceProcessor,
+    train_dl: torch.utils.data.DataLoader,
+    valid_dl: torch.utils.data.DataLoader,
+    scaler: GradScaler,
+    model_avg: Optional[nn.Module] = None,
+    tb_writer: Optional[SummaryWriter] = None,
+    world_size: int = 1,
+    rank: int = 0,
+) -> None:
+    """Train the model for one epoch.
+
+    The training loss from the mean of all frames is saved in
+    `params.train_loss`. It runs the validation process every
+    `params.valid_interval` batches.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The model for training.
+      optimizer:
+        The optimizer we are using.
+      scheduler:
+        The learning rate scheduler, we call step() every step.
+      train_dl:
+        Dataloader for the training dataset.
+      valid_dl:
+        Dataloader for the validation dataset.
+      scaler:
+        The scaler used for mix precision training.
+      model_avg:
+        The stored model averaged from the start of training.
+      tb_writer:
+        Writer to write log messages to tensorboard.
+      world_size:
+        Number of nodes in DDP training. If it is 1, DDP is disabled.
+      rank:
+        The rank of the node in DDP training. If no DDP is used, it should
+        be set to 0.
+    """
+    model.train()
+
+    tot_loss = MetricsTracker()
+
+    cur_batch_idx = params.get("cur_batch_idx", 0)
+
+    for batch_idx, batch in enumerate(train_dl):
+        if batch_idx < cur_batch_idx:
+            continue
+        cur_batch_idx = batch_idx
+
+        params.batch_idx_train += 1
+        batch_size = len(batch["supervisions"]["text"])
+
+        try:
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                loss, loss_info = compute_loss(
+                    params=params,
+                    model=model,
+                    sp=sp,
+                    batch=batch,
+                    is_training=True,
+                )
+            # summary stats
+            tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+            # NOTE: We use reduction==sum and loss is computed over utterances
+            # in the batch and there is no normalization to it so far.
+            scaler.scale(loss).backward()
+            set_batch_count(model, params.batch_idx_train)
+            scheduler.step_batch(params.batch_idx_train)
+
+            scaler.step(optimizer)
+            scaler.update()
+            optimizer.zero_grad()
+        except:  # noqa
+            display_and_save_batch(batch, params=params, sp=sp)
+            raise
+
+        if params.print_diagnostics and batch_idx == 5:
+            return
+
+        if (
+            rank == 0
+            and params.batch_idx_train > 0
+            and params.batch_idx_train % params.average_period == 0
+        ):
+            update_averaged_model(
+                params=params,
+                model_cur=model,
+                model_avg=model_avg,
+            )
+
+        if (
+            params.batch_idx_train > 0
+            and params.batch_idx_train % params.save_every_n == 0
+        ):
+            params.cur_batch_idx = batch_idx
+            save_checkpoint_with_global_batch_idx(
+                out_dir=params.exp_dir,
+                global_batch_idx=params.batch_idx_train,
+                model=model,
+                model_avg=model_avg,
+                params=params,
+                optimizer=optimizer,
+                scheduler=scheduler,
+                sampler=train_dl.sampler,
+                scaler=scaler,
+                rank=rank,
+            )
+            del params.cur_batch_idx
+            remove_checkpoints(
+                out_dir=params.exp_dir,
+                topk=params.keep_last_k,
+                rank=rank,
+            )
+
+        if batch_idx % 100 == 0 and params.use_fp16:
+            # If the grad scale was less than 1, try increasing it.    The _growth_interval
+            # of the grad scaler is configurable, but we can't configure it to have different
+            # behavior depending on the current grad scale.
+            cur_grad_scale = scaler._scale.item()
+            if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
+                scaler.update(cur_grad_scale * 2.0)
+            if cur_grad_scale < 0.01:
+                logging.warning(f"Grad scale is small: {cur_grad_scale}")
+            if cur_grad_scale < 1.0e-05:
+                raise RuntimeError(
+                    f"grad_scale is too small, exiting: {cur_grad_scale}"
+                )
+
+        if batch_idx % params.log_interval == 0:
+            cur_lr = scheduler.get_last_lr()[0]
+            cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+            logging.info(
+                f"Epoch {params.cur_epoch}, "
+                f"batch {batch_idx}, loss[{loss_info}], "
+                f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+                f"lr: {cur_lr:.2e}, "
+                + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+            )
+
+            if tb_writer is not None:
+                tb_writer.add_scalar(
+                    "train/learning_rate", cur_lr, params.batch_idx_train
+                )
+
+                loss_info.write_summary(
+                    tb_writer, "train/current_", params.batch_idx_train
+                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                if params.use_fp16:
+                    tb_writer.add_scalar(
+                        "train/grad_scale",
+                        cur_grad_scale,
+                        params.batch_idx_train,
+                    )
+
+        if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+            logging.info("Computing validation loss")
+            valid_info = compute_validation_loss(
+                params=params,
+                model=model,
+                sp=sp,
+                valid_dl=valid_dl,
+                world_size=world_size,
+            )
+            model.train()
+            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+            logging.info(
+                f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+            )
+            if tb_writer is not None:
+                valid_info.write_summary(
+                    tb_writer, "train/valid_", params.batch_idx_train
+                )
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    params.train_loss = loss_value
+    if params.train_loss < params.best_train_loss:
+        params.best_train_epoch = params.cur_epoch
+        params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+    """
+    Args:
+      rank:
+        It is a value between 0 and `world_size-1`, which is
+        passed automatically by `mp.spawn()` in :func:`main`.
+        The node with rank 0 is responsible for saving checkpoint.
+      world_size:
+        Number of GPUs for DDP training.
+      args:
+        The return value of get_parser().parse_args()
+    """
+    params = get_params()
+    params.update(vars(args))
+    if params.full_libri is False:
+        params.valid_interval = 1600
+
+    fix_random_seed(params.seed)
+    if world_size > 1:
+        setup_dist(rank, world_size, params.master_port)
+
+    setup_logger(f"{params.exp_dir}/log/log-train")
+    logging.info("Training started")
+
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    #  is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    assert params.save_every_n >= params.average_period
+    model_avg: Optional[nn.Module] = None
+    if rank == 0:
+        # model_avg is only used with rank 0
+        model_avg = copy.deepcopy(model).to(torch.float64)
+
+    assert params.start_epoch > 0, params.start_epoch
+    checkpoints = load_checkpoint_if_available(
+        params=params, model=model, model_avg=model_avg
+    )
+
+    model.to(device)
+    if world_size > 1:
+        logging.info("Using DDP")
+        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+    optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=2.0)
+
+    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+    if checkpoints and "optimizer" in checkpoints:
+        logging.info("Loading optimizer state dict")
+        optimizer.load_state_dict(checkpoints["optimizer"])
+
+    if (
+        checkpoints
+        and "scheduler" in checkpoints
+        and checkpoints["scheduler"] is not None
+    ):
+        logging.info("Loading scheduler state dict")
+        scheduler.load_state_dict(checkpoints["scheduler"])
+
+    if params.print_diagnostics:
+        opts = diagnostics.TensorDiagnosticOptions(
+            2**22
+        )  # allow 4 megabytes per sub-module
+        diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+    if params.inf_check:
+        register_inf_check_hooks(model)
+
+    librispeech = LibriSpeechAsrDataModule(args)
+
+    train_cuts = librispeech.train_clean_100_cuts()
+    if params.full_libri:
+        train_cuts += librispeech.train_clean_360_cuts()
+        train_cuts += librispeech.train_other_500_cuts()
+
+    def remove_short_and_long_utt(c: Cut):
+        # Keep only utterances with duration between 1 second and 20 seconds
+        #
+        # Caution: There is a reason to select 20.0 here. Please see
+        # ../local/display_manifest_statistics.py
+        #
+        # You should use ../local/display_manifest_statistics.py to get
+        # an utterance duration distribution for your dataset to select
+        # the threshold
+        return 1.0 <= c.duration <= 20.0
+
+    train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+    if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+        # We only load the sampler's state dict when it loads a checkpoint
+        # saved in the middle of an epoch
+        sampler_state_dict = checkpoints["sampler"]
+    else:
+        sampler_state_dict = None
+
+    train_dl = librispeech.train_dataloaders(
+        train_cuts, sampler_state_dict=sampler_state_dict
+    )
+
+    valid_cuts = librispeech.dev_clean_cuts()
+    valid_cuts += librispeech.dev_other_cuts()
+    valid_dl = librispeech.valid_dataloaders(valid_cuts)
+
+    if not params.print_diagnostics:
+        scan_pessimistic_batches_for_oom(
+            model=model,
+            train_dl=train_dl,
+            optimizer=optimizer,
+            sp=sp,
+            params=params,
+        )
+
+    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+    if checkpoints and "grad_scaler" in checkpoints:
+        logging.info("Loading grad scaler state dict")
+        scaler.load_state_dict(checkpoints["grad_scaler"])
+
+    for epoch in range(params.start_epoch, params.num_epochs + 1):
+        scheduler.step_epoch(epoch - 1)
+        fix_random_seed(params.seed + epoch - 1)
+        train_dl.sampler.set_epoch(epoch - 1)
+
+        if tb_writer is not None:
+            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+        params.cur_epoch = epoch
+
+        train_one_epoch(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            sp=sp,
+            train_dl=train_dl,
+            valid_dl=valid_dl,
+            scaler=scaler,
+            tb_writer=tb_writer,
+            world_size=world_size,
+            rank=rank,
+        )
+
+        if params.print_diagnostics:
+            diagnostic.print_diagnostics()
+            break
+
+        save_checkpoint(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            sampler=train_dl.sampler,
+            scaler=scaler,
+            rank=rank,
+        )
+
+    logging.info("Done!")
+
+    if world_size > 1:
+        torch.distributed.barrier()
+        cleanup_dist()
+
+
+def display_and_save_batch(
+    batch: dict,
+    params: AttributeDict,
+    sp: spm.SentencePieceProcessor,
+) -> None:
+    """Display the batch statistics and save the batch into disk.
+
+    Args:
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      params:
+        Parameters for training. See :func:`get_params`.
+      sp:
+        The BPE model.
+    """
+    from lhotse.utils import uuid4
+
+    filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+    logging.info(f"Saving batch to {filename}")
+    torch.save(batch, filename)
+
+    supervisions = batch["supervisions"]
+    features = batch["inputs"]
+
+    logging.info(f"features shape: {features.shape}")
+
+    y = sp.encode(supervisions["text"], out_type=int)
+    num_tokens = sum(len(i) for i in y)
+    logging.info(f"num tokens: {num_tokens}")
+
+
+def scan_pessimistic_batches_for_oom(
+    model: Union[nn.Module, DDP],
+    train_dl: torch.utils.data.DataLoader,
+    optimizer: torch.optim.Optimizer,
+    sp: spm.SentencePieceProcessor,
+    params: AttributeDict,
+):
+    from lhotse.dataset import find_pessimistic_batches
+
+    logging.info(
+        "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+    )
+    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+    for criterion, cuts in batches.items():
+        batch = train_dl.dataset[cuts]
+        try:
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                loss, _ = compute_loss(
+                    params=params,
+                    model=model,
+                    sp=sp,
+                    batch=batch,
+                    is_training=True,
+                )
+            loss.backward()
+            optimizer.zero_grad()
+        except Exception as e:
+            if "CUDA out of memory" in str(e):
+                logging.error(
+                    "Your GPU ran out of memory with the current "
+                    "max_duration setting. We recommend decreasing "
+                    "max_duration and trying again.\n"
+                    f"Failing criterion: {criterion} "
+                    f"(={crit_values[criterion]}) ..."
+                )
+            display_and_save_batch(batch, params=params, sp=sp)
+            raise
+        logging.info(
+            f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+        )
+
+
+def main():
+    parser = get_parser()
+    LibriSpeechAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    world_size = args.world_size
+    assert world_size >= 1
+    if world_size > 1:
+        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+    else:
+        run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/zipformer.py
new file mode 120000
index 000000000..79b076556
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/zipformer.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless7/zipformer.py
\ No newline at end of file
diff --git a/icefall/utils.py b/icefall/utils.py
index d852491c8..99e51a2a9 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -175,11 +175,13 @@ class AttributeDict(dict):
 
 
 def encode_supervisions(
-    supervisions: dict, subsampling_factor: int
-) -> Tuple[torch.Tensor, List[str]]:
+    supervisions: dict,
+    subsampling_factor: int,
+    token_ids: Optional[List[List[int]]] = None,
+) -> Tuple[torch.Tensor, Union[List[str], List[List[int]]]]:
     """
     Encodes Lhotse's ``batch["supervisions"]`` dict into
-    a pair of torch Tensor, and a list of transcription strings.
+    a pair of torch Tensor, and a list of transcription strings or token indexes
 
     The supervision tensor has shape ``(batch_size, 3)``.
     Its second dimension contains information about sequence index [0],
@@ -208,10 +210,14 @@ def encode_supervisions(
 
     indices = torch.argsort(supervision_segments[:, 2], descending=True)
     supervision_segments = supervision_segments[indices]
-    texts = supervisions["text"]
-    texts = [texts[idx] for idx in indices]
 
-    return supervision_segments, texts
+    if token_ids is None:
+        texts = supervisions["text"]
+        res = [texts[idx] for idx in indices]
+    else:
+        res = [token_ids[idx] for idx in indices]
+
+    return supervision_segments, res
 
 
 def get_texts(

From e6a67270128f607f49c81327190aca63bb3bb4eb Mon Sep 17 00:00:00 2001
From: Senyan Li <1149593720@qq.com>
Date: Sat, 3 Dec 2022 23:50:49 +0800
Subject: [PATCH 033/174] Add Tibetan Amdo dialect xbmu_amdo31 in egs (#706)

* add egs/xbmu_amdo31

* fix xbmu_amdo31/ASR/pruned_transducer_stateless5/train.py

* fix xbmu_amdo31/ASR/pruned_transducer_stateless5/asr_datamodule.py

* fix xbmu_amdo31/ASR/prepare.sh

* add RESULTS.md and README.md

* dix pruned_transducer_stateless5 decode.py

* add transducer stateless7

* fix transducer_stateless7

* fix RESULTS.md error

* Add pruned_transducer_stateless7 validation set results
---
 egs/xbmu_amdo31/ASR/README.md                 |   16 +
 egs/xbmu_amdo31/ASR/RESULTS.md                |   92 ++
 egs/xbmu_amdo31/ASR/local/compile_hlg.py      |    1 +
 egs/xbmu_amdo31/ASR/local/compile_lg.py       |    1 +
 .../ASR/local/compute_fbank_musan.py          |    1 +
 .../ASR/local/compute_fbank_xbmu_amdo31.py    |  130 ++
 .../convert_transcript_words_to_tokens.py     |    1 +
 egs/xbmu_amdo31/ASR/local/filter_cuts.py      |    1 +
 .../ASR/local/generate_unique_lexicon.py      |    1 +
 egs/xbmu_amdo31/ASR/local/prepare_lang.py     |    1 +
 egs/xbmu_amdo31/ASR/local/prepare_lang_bpe.py |    1 +
 .../ASR/local/prepare_lm_training_data.py     |    1 +
 .../ASR/local/sort_lm_training_data.py        |    1 +
 egs/xbmu_amdo31/ASR/local/train_bpe_model.py  |    1 +
 .../ASR/local/validate_bpe_lexicon.py         |    1 +
 egs/xbmu_amdo31/ASR/prepare.sh                |  357 +++++
 .../pruned_transducer_stateless5/__init__.py  |    0
 .../asr_datamodule.py                         |  408 ++++++
 .../beam_search.py                            |    1 +
 .../pruned_transducer_stateless5/conformer.py |    1 +
 .../pruned_transducer_stateless5/decode.py    |  970 +++++++++++++
 .../decode_stream.py                          |    1 +
 .../pruned_transducer_stateless5/decoder.py   |    1 +
 .../encoder_interface.py                      |    1 +
 .../pruned_transducer_stateless5/export.py    |  287 ++++
 .../pruned_transducer_stateless5/joiner.py    |    1 +
 .../ASR/pruned_transducer_stateless5/lstmp.py |    1 +
 .../ASR/pruned_transducer_stateless5/model.py |    1 +
 .../ASR/pruned_transducer_stateless5/optim.py |    1 +
 .../pretrained.py                             |  344 +++++
 .../pruned_transducer_stateless5/scaling.py   |    1 +
 .../scaling_converter.py                      |    1 +
 .../streaming_beam_search.py                  |    1 +
 .../streaming_decode.py                       |    1 +
 .../test_model.py                             |   65 +
 .../ASR/pruned_transducer_stateless5/train.py | 1187 ++++++++++++++++
 .../pruned_transducer_stateless7/__init__.py  |    0
 .../asr_datamodule.py                         |    1 +
 .../beam_search.py                            |    1 +
 .../pruned_transducer_stateless7/decode.py    |  843 ++++++++++++
 .../pruned_transducer_stateless7/decoder.py   |    1 +
 .../encoder_interface.py                      |    1 +
 .../pruned_transducer_stateless7/export.py    |    1 +
 .../jit_pretrained.py                         |    1 +
 .../pruned_transducer_stateless7/joiner.py    |    1 +
 .../ASR/pruned_transducer_stateless7/model.py |    1 +
 .../ASR/pruned_transducer_stateless7/optim.py |    1 +
 .../pretrained.py                             |  355 +++++
 .../pruned_transducer_stateless7/scaling.py   |    1 +
 .../scaling_converter.py                      |    1 +
 .../test_model.py                             |    1 +
 .../ASR/pruned_transducer_stateless7/train.py | 1224 +++++++++++++++++
 .../pruned_transducer_stateless7/zipformer.py |    1 +
 egs/xbmu_amdo31/ASR/shared                    |    1 +
 54 files changed, 6317 insertions(+)
 create mode 100644 egs/xbmu_amdo31/ASR/README.md
 create mode 100644 egs/xbmu_amdo31/ASR/RESULTS.md
 create mode 120000 egs/xbmu_amdo31/ASR/local/compile_hlg.py
 create mode 120000 egs/xbmu_amdo31/ASR/local/compile_lg.py
 create mode 120000 egs/xbmu_amdo31/ASR/local/compute_fbank_musan.py
 create mode 100755 egs/xbmu_amdo31/ASR/local/compute_fbank_xbmu_amdo31.py
 create mode 120000 egs/xbmu_amdo31/ASR/local/convert_transcript_words_to_tokens.py
 create mode 120000 egs/xbmu_amdo31/ASR/local/filter_cuts.py
 create mode 120000 egs/xbmu_amdo31/ASR/local/generate_unique_lexicon.py
 create mode 120000 egs/xbmu_amdo31/ASR/local/prepare_lang.py
 create mode 120000 egs/xbmu_amdo31/ASR/local/prepare_lang_bpe.py
 create mode 120000 egs/xbmu_amdo31/ASR/local/prepare_lm_training_data.py
 create mode 120000 egs/xbmu_amdo31/ASR/local/sort_lm_training_data.py
 create mode 120000 egs/xbmu_amdo31/ASR/local/train_bpe_model.py
 create mode 120000 egs/xbmu_amdo31/ASR/local/validate_bpe_lexicon.py
 create mode 100755 egs/xbmu_amdo31/ASR/prepare.sh
 create mode 100644 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/__init__.py
 create mode 100644 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/asr_datamodule.py
 create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/beam_search.py
 create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/conformer.py
 create mode 100755 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode.py
 create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode_stream.py
 create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decoder.py
 create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/encoder_interface.py
 create mode 100755 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/export.py
 create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/joiner.py
 create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/lstmp.py
 create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/model.py
 create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/optim.py
 create mode 100755 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/pretrained.py
 create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/scaling.py
 create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/scaling_converter.py
 create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/streaming_beam_search.py
 create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/streaming_decode.py
 create mode 100755 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/test_model.py
 create mode 100755 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/train.py
 create mode 100644 egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/__init__.py
 create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/asr_datamodule.py
 create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/beam_search.py
 create mode 100755 egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decode.py
 create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decoder.py
 create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/encoder_interface.py
 create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/export.py
 create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/jit_pretrained.py
 create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/joiner.py
 create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/model.py
 create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/optim.py
 create mode 100755 egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/pretrained.py
 create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/scaling.py
 create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/scaling_converter.py
 create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/test_model.py
 create mode 100755 egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py
 create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/zipformer.py
 create mode 120000 egs/xbmu_amdo31/ASR/shared

diff --git a/egs/xbmu_amdo31/ASR/README.md b/egs/xbmu_amdo31/ASR/README.md
new file mode 100644
index 000000000..0a441d070
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/README.md
@@ -0,0 +1,16 @@
+# Introduction
+About the XBMU-AMDO31 corpus
+XBMU-AMDO31 is an open-source Amdo Tibetan speech corpus published by Northwest Minzu University.
+publicly available on https://huggingface.co/datasets/syzym/xbmu_amdo31
+
+XBMU-AMDO31 dataset is a speech recognition corpus of Amdo Tibetan dialect. 
+The open source corpus contains 31 hours of speech data and resources related 
+to build speech recognition systems,including transcribed texts and a Tibetan 
+pronunciation lexicon.
+(The lexicon is a Tibetan lexicon of the Lhasa dialect, which has been reused 
+for the Amdo dialect because of the uniformity of the Tibetan language)
+The dataset can be used to train a model for Amdo Tibetan Automatic Speech Recognition (ASR). 
+
+This recipe includes some different ASR models trained with XBMU-AMDO31.
+
+[./RESULTS.md](./RESULTS.md) contains the latest results.
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/RESULTS.md b/egs/xbmu_amdo31/ASR/RESULTS.md
new file mode 100644
index 000000000..1bd9b2e2b
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/RESULTS.md
@@ -0,0 +1,92 @@
+## Results
+
+### XBMU-AMDO31 BPE training result (Stateless Transducer)
+
+#### Pruned transducer stateless 5
+
+[./pruned_transducer_stateless5](./pruned_transducer_stateless5)
+
+It uses pruned RNN-T.
+
+A pre-trained model and decoding logs can be found at 
+
+You can use  to deploy it.
+
+Number of model parameters: 87801200, i.e., 87.8 M
+
+|                        | test | dev  | comment                               |
+|------------------------|------|------|---------------------------------------|
+| greedy search          | 11.06| 11.73| --epoch 28 --avg 23 --max-duration 600|
+| beam search            | 10.64| 11.42| --epoch 28 --avg 23 --max-duration 600|
+| modified beam search   | 10.57| 11.24| --epoch 28 --avg 23 --max-duration 600|
+
+
+Training command is:
+
+```bash
+cd egs/xbmu_amdo31/ASR
+./prepare.sh
+
+export CUDA_VISIBLE_DEVICES="0"
+
+./pruned_transducer_stateless5/train.py
+```
+
+**Caution**: It uses `--context-size=1`.
+
+
+The decoding command is:
+```bash
+for method in greedy_search beam_search modified_beam_search;
+do
+./pruned_transducer_stateless5/decode.py \
+    --epoch 28 \
+    --avg 23 \
+    --exp-dir ./pruned_transducer_stateless5/exp \
+    --max-duration 600 \
+    --decoding-method $method
+done
+```
+
+### pruned_transducer_stateless7 (zipformer)
+
+See  for more details.
+
+[pruned_transducer_stateless7](./pruned_transducer_stateless7)
+
+You can find a pretrained model, training logs, decoding logs, and decoding
+results at:
+
+
+You can use  to deploy it.
+
+Number of model parameters: 70369391, i.e., 70.37 M
+
+|                      | test | dev  | comment                                |
+|----------------------|------|------|----------------------------------------|
+| greedy search        | 10.06| 10.59| --epoch 23 --avg 11 --max-duration 600 |
+| beam search          | 9.77 | 10.11| --epoch 23 --avg 11 --max-duration 600 |
+| modified beam search | 9.7  | 10.12| --epoch 23 --avg 11 --max-duration 600 |
+
+The training commands are:
+```bash
+export CUDA_VISIBLE_DEVICES="0"
+
+./pruned_transducer_stateless7/train.py
+```
+
+The decoding commands are:
+```bash
+for m in greedy_search beam_search modified_beam_search; do
+  for epoch in 23; do
+    for avg in 11; do
+      ./pruned_transducer_stateless7/decode.py \
+          --epoch $epoch \
+          --avg $avg \
+          --exp-dir ./pruned_transducer_stateless7/exp \
+          --max-duration 600 \
+          --decoding-method $m
+    done
+  done
+done
+```
diff --git a/egs/xbmu_amdo31/ASR/local/compile_hlg.py b/egs/xbmu_amdo31/ASR/local/compile_hlg.py
new file mode 120000
index 000000000..471aa7fb4
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/local/compile_hlg.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/compile_hlg.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/local/compile_lg.py b/egs/xbmu_amdo31/ASR/local/compile_lg.py
new file mode 120000
index 000000000..462d6d3fb
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/local/compile_lg.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/compile_lg.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/local/compute_fbank_musan.py b/egs/xbmu_amdo31/ASR/local/compute_fbank_musan.py
new file mode 120000
index 000000000..5833f2484
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/local/compute_fbank_musan.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/compute_fbank_musan.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/local/compute_fbank_xbmu_amdo31.py b/egs/xbmu_amdo31/ASR/local/compute_fbank_xbmu_amdo31.py
new file mode 100755
index 000000000..a593e7be3
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/local/compute_fbank_xbmu_amdo31.py
@@ -0,0 +1,130 @@
+#!/usr/bin/env python3
+# Copyright    2021  Xiaomi Corp.        (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file computes fbank features of the XBMU-AMDO31 dataset.
+It looks for manifests in the directory data/manifests.
+
+The generated fbank features are saved in data/fbank.
+"""
+
+import argparse
+import logging
+import os
+from pathlib import Path
+from typing import Optional
+
+import sentencepiece as spm
+import torch
+from filter_cuts import filter_cuts
+from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
+from lhotse.recipes.utils import read_manifests_if_cached
+
+from icefall.utils import get_executor
+
+# Torch's multithreaded behavior needs to be disabled or
+# it wastes a lot of CPU and slow things down.
+# Do this outside of main() in case it needs to take effect
+# even when we are not invoking the main (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        help="""Path to the bpe.model. If not None, we will remove short and
+        long utterances before extracting features""",
+    )
+    return parser.parse_args()
+
+
+def compute_fbank_xbmu_amdo31(bpe_model: Optional[str] = None):
+    src_dir = Path("data/manifests")
+    output_dir = Path("data/fbank")
+    num_jobs = min(15, os.cpu_count())
+    num_mel_bins = 80
+
+    if bpe_model:
+        logging.info(f"Loading {bpe_model}")
+        sp = spm.SentencePieceProcessor()
+        sp.load(bpe_model)
+
+    dataset_parts = (
+        "train",
+        "dev",
+        "test",
+    )
+    prefix = "xbmu_amdo31"
+    suffix = "jsonl.gz"
+    manifests = read_manifests_if_cached(
+        dataset_parts=dataset_parts,
+        output_dir=src_dir,
+        prefix=prefix,
+        suffix=suffix,
+    )
+    assert manifests is not None
+
+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
+
+    extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
+
+    with get_executor() as ex:  # Initialize the executor only once.
+        for partition, m in manifests.items():
+            cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
+            if (output_dir / cuts_filename).is_file():
+                logging.info(f"{partition} already exists - skipping.")
+                continue
+            logging.info(f"Processing {partition}")
+            cut_set = CutSet.from_manifests(
+                recordings=m["recordings"],
+                supervisions=m["supervisions"],
+            )
+            if bpe_model:
+                cut_set = filter_cuts(cut_set, sp)
+
+            if "train" in partition:
+                cut_set = (
+                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+                )
+            cut_set = cut_set.compute_and_store_features(
+                extractor=extractor,
+                storage_path=f"{output_dir}/{prefix}_feats_{partition}",
+                # when an executor is specified, make more partitions
+                num_jobs=num_jobs if ex is None else 80,
+                executor=ex,
+                storage_type=LilcomChunkyWriter,
+            )
+            cut_set.to_file(output_dir / cuts_filename)
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    args = get_args()
+    logging.info(vars(args))
+    compute_fbank_xbmu_amdo31(bpe_model=args.bpe_model)
diff --git a/egs/xbmu_amdo31/ASR/local/convert_transcript_words_to_tokens.py b/egs/xbmu_amdo31/ASR/local/convert_transcript_words_to_tokens.py
new file mode 120000
index 000000000..2ce13fd69
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/local/convert_transcript_words_to_tokens.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/convert_transcript_words_to_tokens.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/local/filter_cuts.py b/egs/xbmu_amdo31/ASR/local/filter_cuts.py
new file mode 120000
index 000000000..27aca1729
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/local/filter_cuts.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/filter_cuts.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/local/generate_unique_lexicon.py b/egs/xbmu_amdo31/ASR/local/generate_unique_lexicon.py
new file mode 120000
index 000000000..c0aea1403
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/local/generate_unique_lexicon.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/generate_unique_lexicon.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/local/prepare_lang.py b/egs/xbmu_amdo31/ASR/local/prepare_lang.py
new file mode 120000
index 000000000..747f2ab39
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/local/prepare_lang.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/prepare_lang.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/local/prepare_lang_bpe.py b/egs/xbmu_amdo31/ASR/local/prepare_lang_bpe.py
new file mode 120000
index 000000000..36b40e7fc
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/local/prepare_lang_bpe.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/prepare_lang_bpe.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/local/prepare_lm_training_data.py b/egs/xbmu_amdo31/ASR/local/prepare_lm_training_data.py
new file mode 120000
index 000000000..abc00d421
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/local/prepare_lm_training_data.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/prepare_lm_training_data.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/local/sort_lm_training_data.py b/egs/xbmu_amdo31/ASR/local/sort_lm_training_data.py
new file mode 120000
index 000000000..1d6ccbe33
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/local/sort_lm_training_data.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/sort_lm_training_data.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/local/train_bpe_model.py b/egs/xbmu_amdo31/ASR/local/train_bpe_model.py
new file mode 120000
index 000000000..6fad36421
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/local/train_bpe_model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/train_bpe_model.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/local/validate_bpe_lexicon.py b/egs/xbmu_amdo31/ASR/local/validate_bpe_lexicon.py
new file mode 120000
index 000000000..721bb48e7
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/local/validate_bpe_lexicon.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/validate_bpe_lexicon.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/prepare.sh b/egs/xbmu_amdo31/ASR/prepare.sh
new file mode 100755
index 000000000..32ae440f7
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/prepare.sh
@@ -0,0 +1,357 @@
+#!/usr/bin/env bash
+
+set -eou pipefail
+
+nj=15
+stage=-1
+stop_stage=100
+
+# We assume dl_dir (download dir) contains the following
+# directories and files. If not, they will be downloaded
+# by this script automatically.
+#
+#  - $dl_dir/xbmu_amdo31
+#      You can find data, resource, etc, inside it.
+#      You can download them from https://huggingface.co/datasets/syzym/xbmu_amdo31
+#
+#  - $dl_dir/lm
+#      This directory contains the following files downloaded from
+#       git lfs install
+#       https://huggingface.co/syzym/xbmu_amdo31_lm
+#
+#        - tibetan.3-gram.arpa
+#        - tibetan.4-gram.arpa
+#
+#  - $dl_dir/musan
+#      This directory contains the following directories downloaded from
+#       http://www.openslr.org/17/
+#
+#     - music
+#     - noise
+#     - speech
+
+dl_dir=$PWD/download
+
+. shared/parse_options.sh || exit 1
+
+# vocab size for sentence piece models.
+# It will generate data/lang_bpe_xxx,
+# data/lang_bpe_yyy if the array contains xxx, yyy
+vocab_sizes=(
+  1000
+  500
+)
+
+# All files generated by this script are saved in "data".
+# You can safely remove "data" and rerun this script to regenerate it.
+mkdir -p data
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+log "dl_dir: $dl_dir"
+
+if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
+  log "stage -1: Download LM"
+  # We assume that you have installed the git-lfs, if not, you could install it
+  # using: `sudo apt-get install git-lfs && git-lfs install`
+  git lfs 1>/dev/null 2>&1 || (echo "please install git-lfs, consider using: sudo apt-get install git-lfs && git-lfs install" && exit 1)
+
+  if [ ! -f $dl_dir/lm/3-gram.unpruned.arpa ]; then
+    git clone https://huggingface.co/syzym/xbmu_amdo31_lm $dl_dir/lm
+    pushd $dl_dir/lm
+    git lfs pull --include "tibetan.3-gram.arpa"
+    git lfs pull --include "tibetan.4-gram.arpa"
+    popd
+  fi
+fi
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+  log "Stage 0: Download data"
+
+  # If you have pre-downloaded it to /path/to/xbmu_amdo31,
+  # you can create a symlink
+  #
+  #   ln -sfv /path/to/xbmu_amdo31 $dl_dir/xbmu_amdo31
+  #
+  
+  if [ ! -f $dl_dir/xbmu_amdo31 ]; then
+    git lfs 1>/dev/null 2>&1 || (echo "please install git-lfs, consider using: sudo apt-get install git-lfs && git-lfs install" && exit 1)
+    lhotse download xbmu-amdo31 $dl_dir
+  fi
+
+  # If you have pre-downloaded it to /path/to/musan,
+  # you can create a symlink
+  #
+  #   ln -sfv /path/to/musan $dl_dir/
+  #
+  if [ ! -d $dl_dir/musan ]; then
+    lhotse download musan $dl_dir
+  fi
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+  log "Stage 1: Prepare xbmu_amdo31 manifest"
+  # We assume that you have downloaded the xbmu_amdo31 corpus
+  # to $dl_dir/xbmu_amdo31
+  if [ ! -f data/manifests/.xbmu_amdo31_manifests.done ]; then
+    mkdir -p data/manifests
+    lhotse prepare xbmu-amdo31 $dl_dir/xbmu_amdo31 data/manifests
+    touch data/manifests/.xbmu_amdo31_manifests.done
+  fi
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+  log "Stage 2: Prepare musan manifest"
+  # We assume that you have downloaded the musan corpus
+  # to data/musan
+  if [ ! -f data/manifests/.musan_manifests.done ]; then
+    log "It may take 6 minutes"
+    mkdir -p data/manifests
+    lhotse prepare musan $dl_dir/musan data/manifests
+    touch data/manifests/.musan_manifests.done
+  fi
+fi
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+  log "Stage 3: Compute fbank for xbmu_amdo31"
+  if [ ! -f data/fbank/.xbmu_amdo31.done ]; then
+    mkdir -p data/fbank
+    ./local/compute_fbank_xbmu_amdo31.py
+    touch data/fbank/.xbmu_amdo31.done
+  fi
+fi
+
+
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+  log "Stage 4: Compute fbank for musan"
+  if [ ! -f data/fbank/.msuan.done ]; then
+    mkdir -p data/fbank
+    ./local/compute_fbank_musan.py
+    touch data/fbank/.msuan.done
+  fi
+fi
+
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+  log "Stage 5: Prepare phone based lang"
+  lang_dir=data/lang_phone
+  mkdir -p $lang_dir
+
+  (echo '!SIL SIL'; echo ' SPN'; echo ' SPN'; ) |
+    cat - $dl_dir/xbmu_amdo31/resource/lexicon.txt |
+    sort | uniq > $lang_dir/lexicon.txt
+
+  ./local/generate_unique_lexicon.py --lang-dir $lang_dir
+
+  if [ ! -f $lang_dir/L_disambig.pt ]; then
+    ./local/prepare_lang.py --lang-dir $lang_dir
+  fi
+fi
+
+
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+  log "Stage 6: Prepare BPE based lang"
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    lang_dir=data/lang_bpe_${vocab_size}
+    mkdir -p $lang_dir
+    # We reuse words.txt from phone based lexicon
+    # so that the two can share G.pt later.
+    cp data/lang_phone/words.txt $lang_dir
+
+  if [ ! -f $lang_dir/transcript_words.txt ]; then
+    log "Generate data to train phone based bigram P"
+    xbmu_amdo31_text=$dl_dir/xbmu_amdo31/data/transcript/transcript_clean.txt
+    xbmu_amdo31_train_uid=$dl_dir/xbmu_amdo31/data/transcript/xbmu_amdo31_train_uid
+    find $dl_dir/xbmu_amdo31/data/wav/train -name "*.wav" | sed 's/\.wav//g' | awk -F '-' '{print $NF}' > $xbmu_amdo31_train_uid
+    awk 'NR==FNR{uid[$1]=$1} NR!=FNR{if($1 in uid) print $0}' $xbmu_amdo31_train_uid $xbmu_amdo31_text |
+	    cut -d " " -f 2- > $lang_dir/transcript_words.txt
+  fi
+
+    if [ ! -f $lang_dir/bpe.model ]; then
+      ./local/train_bpe_model.py \
+        --lang-dir $lang_dir \
+        --vocab-size $vocab_size \
+        --transcript $lang_dir/transcript_words.txt
+    fi
+
+    if [ ! -f $lang_dir/L_disambig.pt ]; then
+      ./local/prepare_lang_bpe.py --lang-dir $lang_dir
+
+      log "Validating $lang_dir/lexicon.txt"
+      ./local/validate_bpe_lexicon.py \
+        --lexicon $lang_dir/lexicon.txt \
+        --bpe-model $lang_dir/bpe.model
+    fi
+  done
+fi
+
+if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+  log "Stage 7: Prepare bigram P"
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    lang_dir=data/lang_bpe_${vocab_size}
+
+    if [ ! -f $lang_dir/transcript_tokens.txt ]; then
+      ./local/convert_transcript_words_to_tokens.py \
+        --lexicon $lang_dir/lexicon.txt \
+        --transcript $lang_dir/transcript_words.txt \
+        --oov "" \
+        > $lang_dir/transcript_tokens.txt
+    fi
+
+    if [ ! -f $lang_dir/P.arpa ]; then
+      ./shared/make_kn_lm.py \
+        -ngram-order 2 \
+        -text $lang_dir/transcript_tokens.txt \
+        -lm $lang_dir/P.arpa
+    fi
+
+    if [ ! -f $lang_dir/P.fst.txt ]; then
+      python3 -m kaldilm \
+        --read-symbol-table="$lang_dir/tokens.txt" \
+        --disambig-symbol='#0' \
+        --max-order=2 \
+        $lang_dir/P.arpa > $lang_dir/P.fst.txt
+    fi
+  done
+fi
+
+if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
+  log "Stage 8: Prepare G"
+  # We assume you have install kaldilm, if not, please install
+  # it using: pip install kaldilm
+
+  mkdir -p data/lm
+  if [ ! -f data/lm/G_3_gram.fst.txt ]; then
+    # It is used in building HLG
+    python3 -m kaldilm \
+      --read-symbol-table="data/lang_phone/words.txt" \
+      --disambig-symbol='#0' \
+      --max-order=3 \
+      $dl_dir/lm/tibetan.3-gram.arpa > data/lm/G_3_gram.fst.txt
+  fi
+
+  if [ ! -f data/lm/G_4_gram.fst.txt ]; then
+    # It is used for LM rescoring
+    python3 -m kaldilm \
+      --read-symbol-table="data/lang_phone/words.txt" \
+      --disambig-symbol='#0' \
+      --max-order=4 \
+      $dl_dir/lm/tibetan.4-gram.arpa > data/lm/G_4_gram.fst.txt
+  fi
+fi
+
+if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
+  log "Stage 9: Compile HLG"
+  ./local/compile_hlg.py --lang-dir data/lang_phone
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    lang_dir=data/lang_bpe_${vocab_size}
+    ./local/compile_hlg.py --lang-dir $lang_dir
+  done
+fi
+
+# Compile LG for RNN-T fast_beam_search decoding
+if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
+  log "Stage 10: Compile LG"
+  ./local/compile_lg.py --lang-dir data/lang_phone
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    lang_dir=data/lang_bpe_${vocab_size}
+    ./local/compile_lg.py --lang-dir $lang_dir
+  done
+fi
+
+if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
+  log "Stage 11: Generate LM training data"
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    log "Processing vocab_size == ${vocab_size}"
+    lang_dir=data/lang_bpe_${vocab_size}
+    out_dir=data/lm_training_bpe_${vocab_size}
+    mkdir -p $out_dir
+
+    ./local/prepare_lm_training_data.py \
+      --bpe-model $lang_dir/bpe.model \
+      --lm-data $dl_dir/lm/lm_train.txt \
+      --lm-archive $out_dir/lm_data.pt
+  done
+fi
+
+if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
+  log "Stage 12: Generate LM validation data"
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    log "Processing vocab_size == ${vocab_size}"
+    out_dir=data/lm_training_bpe_${vocab_size}
+    mkdir -p $out_dir
+
+    if [ ! -f $out_dir/valid.txt ]; then
+      files=$dl_dir/xbmu_amdo31/data/transcript/dev_text
+      for f in ${files[@]}; do
+        cat $f | cut -d " " -f 2-
+      done > $out_dir/valid.txt
+    fi
+
+    lang_dir=data/lang_bpe_${vocab_size}
+    ./local/prepare_lm_training_data.py \
+      --bpe-model $lang_dir/bpe.model \
+      --lm-data $out_dir/valid.txt \
+      --lm-archive $out_dir/lm_data-valid.pt
+  done
+fi
+
+if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then
+  log "Stage 13: Generate LM test data"
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    log "Processing vocab_size == ${vocab_size}"
+    out_dir=data/lm_training_bpe_${vocab_size}
+    mkdir -p $out_dir
+
+    if [ ! -f $out_dir/test.txt ]; then
+        files=$dl_dir/xbmu_amdo31/data/transcript/test_text
+        cat $f | cut -d " " -f 2- > $out_dir/test.txt
+    fi
+
+    lang_dir=data/lang_bpe_${vocab_size}
+    ./local/prepare_lm_training_data.py \
+      --bpe-model $lang_dir/bpe.model \
+      --lm-data $out_dir/test.txt \
+      --lm-archive $out_dir/lm_data-test.pt
+  done
+fi
+
+if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then
+  log "Stage 14: Sort LM training data"
+  # Sort LM training data by sentence length in descending order
+  # for ease of training.
+  #
+  # Sentence length equals to the number of BPE tokens
+  # in a sentence.
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    out_dir=data/lm_training_bpe_${vocab_size}
+    mkdir -p $out_dir
+    ./local/sort_lm_training_data.py \
+      --in-lm-data $out_dir/lm_data.pt \
+      --out-lm-data $out_dir/sorted_lm_data.pt \
+      --out-statistics $out_dir/statistics.txt
+
+    ./local/sort_lm_training_data.py \
+      --in-lm-data $out_dir/lm_data-valid.pt \
+      --out-lm-data $out_dir/sorted_lm_data-valid.pt \
+      --out-statistics $out_dir/statistics-valid.txt
+
+    ./local/sort_lm_training_data.py \
+      --in-lm-data $out_dir/lm_data-test.pt \
+      --out-lm-data $out_dir/sorted_lm_data-test.pt \
+      --out-statistics $out_dir/statistics-test.txt
+  done
+fi
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/__init__.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/asr_datamodule.py
new file mode 100644
index 000000000..55d5f4636
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/asr_datamodule.py
@@ -0,0 +1,408 @@
+# Copyright      2021  Piotr Żelasko
+# Copyright      2022  Xiaomi Corporation     (Author: Mingshuang Luo)
+# Copyright      2022  Northwest Minzu University     (Author: Senyan Li)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import inspect
+import logging
+from functools import lru_cache
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+import torch
+from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
+from lhotse.dataset import CutConcatenate  # noqa F401 for PrecomputedFeatures
+from lhotse.dataset import (
+    CutMix,
+    DynamicBucketingSampler,
+    K2SpeechRecognitionDataset,
+    PrecomputedFeatures,
+    SingleCutSampler,
+    SpecAugment,
+)
+from lhotse.dataset.input_strategies import AudioSamples  # noqa F401 For AudioSamples
+from lhotse.dataset.input_strategies import OnTheFlyFeatures
+from lhotse.utils import fix_random_seed
+from torch.utils.data import DataLoader
+
+from icefall.utils import str2bool
+
+
+class _SeedWorkers:
+    def __init__(self, seed: int):
+        self.seed = seed
+
+    def __call__(self, worker_id: int):
+        fix_random_seed(self.seed + worker_id)
+
+
+class Xbmu_AmdoAsrDataModule:
+    """
+    DataModule for k2 ASR experiments.
+    It assumes there is always one train and valid dataloader,
+    but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
+    and test-other).
+
+    It contains all the common data pipeline modules used in ASR
+    experiments, e.g.:
+    - dynamic batch size,
+    - bucketing samplers,
+    - cut concatenation,
+    - augmentation,
+    - on-the-fly feature extraction
+
+    This class should be derived for specific corpora used in ASR tasks.
+    """
+
+    def __init__(self, args: argparse.Namespace):
+        self.args = args
+
+    @classmethod
+    def add_arguments(cls, parser: argparse.ArgumentParser):
+        group = parser.add_argument_group(
+            title="ASR data related options",
+            description="These options are used for the preparation of "
+            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+            "effective batch sizes, sampling strategies, applied data "
+            "augmentations, etc.",
+        )
+        group.add_argument(
+            "--manifest-dir",
+            type=Path,
+            default=Path("data/fbank"),
+            help="Path to directory with train/valid/test cuts.",
+        )
+        group.add_argument(
+            "--max-duration",
+            type=int,
+            default=200.0,
+            help="Maximum pooled recordings duration (seconds) in a "
+            "single batch. You can reduce it if it causes CUDA OOM.",
+        )
+        group.add_argument(
+            "--bucketing-sampler",
+            type=str2bool,
+            default=True,
+            help="When enabled, the batches will come from buckets of "
+            "similar duration (saves padding frames).",
+        )
+        group.add_argument(
+            "--num-buckets",
+            type=int,
+            default=30,
+            help="The number of buckets for the DynamicBucketingSampler"
+            "(you might want to increase it for larger datasets).",
+        )
+        group.add_argument(
+            "--concatenate-cuts",
+            type=str2bool,
+            default=False,
+            help="When enabled, utterances (cuts) will be concatenated "
+            "to minimize the amount of padding.",
+        )
+        group.add_argument(
+            "--duration-factor",
+            type=float,
+            default=1.0,
+            help="Determines the maximum duration of a concatenated cut "
+            "relative to the duration of the longest cut in a batch.",
+        )
+        group.add_argument(
+            "--gap",
+            type=float,
+            default=1.0,
+            help="The amount of padding (in seconds) inserted between "
+            "concatenated cuts. This padding is filled with noise when "
+            "noise augmentation is used.",
+        )
+        group.add_argument(
+            "--on-the-fly-feats",
+            type=str2bool,
+            default=False,
+            help="When enabled, use on-the-fly cut mixing and feature "
+            "extraction. Will drop existing precomputed feature manifests "
+            "if available.",
+        )
+        group.add_argument(
+            "--shuffle",
+            type=str2bool,
+            default=True,
+            help="When enabled (=default), the examples will be "
+            "shuffled for each epoch.",
+        )
+        group.add_argument(
+            "--drop-last",
+            type=str2bool,
+            default=True,
+            help="Whether to drop last batch. Used by sampler.",
+        )
+        group.add_argument(
+            "--return-cuts",
+            type=str2bool,
+            default=True,
+            help="When enabled, each batch will have the "
+            "field: batch['supervisions']['cut'] with the cuts that "
+            "were used to construct it.",
+        )
+
+        group.add_argument(
+            "--num-workers",
+            type=int,
+            default=2,
+            help="The number of training dataloader workers that "
+            "collect the batches.",
+        )
+
+        group.add_argument(
+            "--enable-spec-aug",
+            type=str2bool,
+            default=True,
+            help="When enabled, use SpecAugment for training dataset.",
+        )
+
+        group.add_argument(
+            "--spec-aug-time-warp-factor",
+            type=int,
+            default=80,
+            help="Used only when --enable-spec-aug is True. "
+            "It specifies the factor for time warping in SpecAugment. "
+            "Larger values mean more warping. "
+            "A value less than 1 means to disable time warp.",
+        )
+
+        group.add_argument(
+            "--enable-musan",
+            type=str2bool,
+            default=True,
+            help="When enabled, select noise from MUSAN and mix it"
+            "with training dataset. ",
+        )
+
+        group.add_argument(
+            "--input-strategy",
+            type=str,
+            default="PrecomputedFeatures",
+            help="AudioSamples or PrecomputedFeatures",
+        )
+
+    def train_dataloaders(
+        self,
+        cuts_train: CutSet,
+        sampler_state_dict: Optional[Dict[str, Any]] = None,
+    ) -> DataLoader:
+        """
+        Args:
+          cuts_train:
+            CutSet for training.
+          sampler_state_dict:
+            The state dict for the training sampler.
+        """
+        transforms = []
+        if self.args.enable_musan:
+            logging.info("Enable MUSAN")
+            logging.info("About to get Musan cuts")
+            cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
+            transforms.append(
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+            )
+        else:
+            logging.info("Disable MUSAN")
+
+        if self.args.concatenate_cuts:
+            logging.info(
+                f"Using cut concatenation with duration factor "
+                f"{self.args.duration_factor} and gap {self.args.gap}."
+            )
+            # Cut concatenation should be the first transform in the list,
+            # so that if we e.g. mix noise in, it will fill the gaps between
+            # different utterances.
+            transforms = [
+                CutConcatenate(
+                    duration_factor=self.args.duration_factor, gap=self.args.gap
+                )
+            ] + transforms
+
+        input_transforms = []
+        if self.args.enable_spec_aug:
+            logging.info("Enable SpecAugment")
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+            # Set the value of num_frame_masks according to Lhotse's version.
+            # In different Lhotse's versions, the default of num_frame_masks is
+            # different.
+            num_frame_masks = 10
+            num_frame_masks_parameter = inspect.signature(
+                SpecAugment.__init__
+            ).parameters["num_frame_masks"]
+            if num_frame_masks_parameter.default == 1:
+                num_frame_masks = 2
+            logging.info(f"Num frame mask: {num_frame_masks}")
+            input_transforms.append(
+                SpecAugment(
+                    time_warp_factor=self.args.spec_aug_time_warp_factor,
+                    num_frame_masks=num_frame_masks,
+                    features_mask_size=27,
+                    num_feature_masks=2,
+                    frames_mask_size=100,
+                )
+            )
+        else:
+            logging.info("Disable SpecAugment")
+
+        logging.info("About to create train dataset")
+        train = K2SpeechRecognitionDataset(
+            input_strategy=eval(self.args.input_strategy)(),
+            cut_transforms=transforms,
+            input_transforms=input_transforms,
+            return_cuts=self.args.return_cuts,
+        )
+
+        if self.args.on_the_fly_feats:
+            # NOTE: the PerturbSpeed transform should be added only if we
+            # remove it from data prep stage.
+            # Add on-the-fly speed perturbation; since originally it would
+            # have increased epoch size by 3, we will apply prob 2/3 and use
+            # 3x more epochs.
+            # Speed perturbation probably should come first before
+            # concatenation, but in principle the transforms order doesn't have
+            # to be strict (e.g. could be randomized)
+            # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms   # noqa
+            # Drop feats to be on the safe side.
+            train = K2SpeechRecognitionDataset(
+                cut_transforms=transforms,
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_transforms=input_transforms,
+                return_cuts=self.args.return_cuts,
+            )
+
+        if self.args.bucketing_sampler:
+            logging.info("Using DynamicBucketingSampler.")
+            train_sampler = DynamicBucketingSampler(
+                cuts_train,
+                max_duration=self.args.max_duration,
+                shuffle=self.args.shuffle,
+                num_buckets=self.args.num_buckets,
+                drop_last=self.args.drop_last,
+            )
+        else:
+            logging.info("Using SingleCutSampler.")
+            train_sampler = SingleCutSampler(
+                cuts_train,
+                max_duration=self.args.max_duration,
+                shuffle=self.args.shuffle,
+            )
+        logging.info("About to create train dataloader")
+
+        if sampler_state_dict is not None:
+            logging.info("Loading sampler state dict")
+            train_sampler.load_state_dict(sampler_state_dict)
+
+        # 'seed' is derived from the current random state, which will have
+        # previously been set in the main process.
+        seed = torch.randint(0, 100000, ()).item()
+        worker_init_fn = _SeedWorkers(seed)
+
+        train_dl = DataLoader(
+            train,
+            sampler=train_sampler,
+            batch_size=None,
+            num_workers=self.args.num_workers,
+            persistent_workers=False,
+            worker_init_fn=worker_init_fn,
+        )
+
+        return train_dl
+
+    def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
+        transforms = []
+        if self.args.concatenate_cuts:
+            transforms = [
+                CutConcatenate(
+                    duration_factor=self.args.duration_factor, gap=self.args.gap
+                )
+            ] + transforms
+
+        logging.info("About to create dev dataset")
+        if self.args.on_the_fly_feats:
+            validate = K2SpeechRecognitionDataset(
+                cut_transforms=transforms,
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                return_cuts=self.args.return_cuts,
+            )
+        else:
+            validate = K2SpeechRecognitionDataset(
+                cut_transforms=transforms,
+                return_cuts=self.args.return_cuts,
+            )
+        valid_sampler = DynamicBucketingSampler(
+            cuts_valid,
+            max_duration=self.args.max_duration,
+            shuffle=False,
+        )
+        logging.info("About to create dev dataloader")
+        valid_dl = DataLoader(
+            validate,
+            sampler=valid_sampler,
+            batch_size=None,
+            num_workers=2,
+            persistent_workers=False,
+        )
+
+        return valid_dl
+
+    def test_dataloaders(self, cuts: CutSet) -> DataLoader:
+        logging.debug("About to create test dataset")
+        test = K2SpeechRecognitionDataset(
+            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
+            if self.args.on_the_fly_feats
+            else eval(self.args.input_strategy)(),
+            return_cuts=self.args.return_cuts,
+        )
+        sampler = DynamicBucketingSampler(
+            cuts,
+            max_duration=self.args.max_duration,
+            shuffle=False,
+        )
+        logging.debug("About to create test dataloader")
+        test_dl = DataLoader(
+            test,
+            batch_size=None,
+            sampler=sampler,
+            num_workers=self.args.num_workers,
+        )
+        return test_dl
+
+    @lru_cache()
+    def train_cuts(self) -> CutSet:
+        f = self.args.manifest_dir / "xbmu_amdo31_cuts_train.jsonl.gz"
+        logging.info(f"About to get train cuts from {f}")
+        cuts_train = load_manifest_lazy(f)
+        return cuts_train
+
+    @lru_cache()
+    def valid_cuts(self) -> CutSet:
+        f = self.args.manifest_dir / "xbmu_amdo31_cuts_dev.jsonl.gz"
+        logging.info(f"About to get valid cuts from {f}")
+        cuts_valid = load_manifest_lazy(f)
+        return cuts_valid
+
+    @lru_cache()
+    def test_cuts(self) -> CutSet:
+        f = self.args.manifest_dir / "xbmu_amdo31_cuts_test.jsonl.gz"
+        logging.info(f"About to get test cuts from {f}")
+        cuts_test = load_manifest_lazy(f)
+        return cuts_test
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/beam_search.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/beam_search.py
new file mode 120000
index 000000000..e24eca39f
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/beam_search.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/conformer.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/conformer.py
new file mode 120000
index 000000000..c7c1a4b6e
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/conformer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless5/conformer.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode.py
new file mode 100755
index 000000000..6a67e26f8
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode.py
@@ -0,0 +1,970 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
+#                                                 Zengwei Yao,
+#                                                 Xiaoyu Yang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) greedy search
+./pruned_transducer_stateless5/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless5/exp \
+    --max-duration 600 \
+    --decoding-method greedy_search
+(2) beam search (not recommended)
+./pruned_transducer_stateless5/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless5/exp \
+    --max-duration 600 \
+    --decoding-method beam_search \
+    --beam-size 4
+(3) modified beam search
+./pruned_transducer_stateless5/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless5/exp \
+    --max-duration 600 \
+    --decoding-method modified_beam_search \
+    --beam-size 4
+(4) fast beam search (one best)
+./pruned_transducer_stateless5/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless5/exp \
+    --max-duration 600 \
+    --decoding-method fast_beam_search \
+    --beam 20.0 \
+    --max-contexts 8 \
+    --max-states 64
+(5) fast beam search (nbest)
+./pruned_transducer_stateless5/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless5/exp \
+    --max-duration 600 \
+    --decoding-method fast_beam_search_nbest \
+    --beam 20.0 \
+    --max-contexts 8 \
+    --max-states 64 \
+    --num-paths 200 \
+    --nbest-scale 0.5
+(6) fast beam search (nbest oracle WER)
+./pruned_transducer_stateless5/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless5/exp \
+    --max-duration 600 \
+    --decoding-method fast_beam_search_nbest_oracle \
+    --beam 20.0 \
+    --max-contexts 8 \
+    --max-states 64 \
+    --num-paths 200 \
+    --nbest-scale 0.5
+(7) fast beam search (with LG)
+./pruned_transducer_stateless5/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless5/exp \
+    --max-duration 600 \
+    --decoding-method fast_beam_search_nbest_LG \
+    --beam 20.0 \
+    --max-contexts 8 \
+    --max-states 64
+
+(8) modified beam search with RNNLM shallow fusion (with LG)
+./pruned_transducer_stateless5/decode.py \
+    --epoch 35 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless5/exp \
+    --max-duration 600 \
+    --decoding-method fast_beam_search_nbest_LG \
+    --beam 4 \
+    --max-contexts 4 \
+    --rnn-lm-scale 0.4 \
+    --rnn-lm-exp-dir /path/to/RNNLM/exp \
+    --rnn-lm-epoch 99 \
+    --rnn-lm-avg 1 \
+    --rnn-lm-num-layers 3 \
+    --rnn-lm-tie-weights 1
+
+
+"""
+
+
+import argparse
+import logging
+import math
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import Xbmu_AmdoAsrDataModule
+from beam_search import (
+    beam_search,
+    fast_beam_search_nbest,
+    fast_beam_search_nbest_LG,
+    fast_beam_search_nbest_oracle,
+    fast_beam_search_one_best,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+    modified_beam_search_rnnlm_shallow_fusion,
+)
+from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall.checkpoint import (
+    average_checkpoints,
+    average_checkpoints_with_averaged_model,
+    find_checkpoints,
+    load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.rnn_lm.model import RnnLmModel
+from icefall.utils import (
+    AttributeDict,
+    setup_logger,
+    store_transcripts,
+    str2bool,
+    write_error_stats,
+)
+
+LOG_EPS = math.log(1e-10)
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=30,
+        help="""It specifies the checkpoint to use for decoding.
+        Note: Epoch counts from 1.
+        You can specify --avg to use more checkpoints for model averaging.""",
+    )
+
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=15,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
+    )
+
+    parser.add_argument(
+        "--use-averaged-model",
+        type=str2bool,
+        default=True,
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="pruned_transducer_stateless5/exp",
+        help="The experiment dir",
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="data/lang_bpe_500/bpe.model",
+        help="Path to the BPE model",
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=Path,
+        default="data/lang_bpe_500",
+        help="The lang dir containing word table and LG graph",
+    )
+
+    parser.add_argument(
+        "--decoding-method",
+        type=str,
+        default="greedy_search",
+        help="""Possible values are:
+          - greedy_search
+          - beam_search
+          - modified_beam_search
+          - fast_beam_search
+          - fast_beam_search_LG
+          - fast_beam_search_nbest
+          - fast_beam_search_nbest_oracle
+          - fast_beam_search_nbest_LG
+          - modified_beam_search_rnnlm_shallow_fusion # for rnn lm shallow fusion
+        If you use fast_beam_search_nbest_LG, you have to specify
+        `--lang-dir`, which should contain `LG.pt`.
+        """,
+    )
+
+    parser.add_argument(
+        "--beam-size",
+        type=int,
+        default=4,
+        help="""An integer indicating how many candidates we will keep for each
+        frame. Used only when --decoding-method is beam_search or
+        modified_beam_search.""",
+    )
+
+    parser.add_argument(
+        "--beam",
+        type=float,
+        default=20.0,
+        help="""A floating point value to calculate the cutoff score during beam
+        search (i.e., `cutoff = max-score - beam`), which is the same as the
+        `beam` in Kaldi.
+        Used only when --decoding-method is fast_beam_search, fast_beam_search_LG,
+        fast_beam_search_nbest, fast_beam_search_nbest_LG,
+        and fast_beam_search_nbest_oracle
+        """,
+    )
+
+    parser.add_argument(
+        "--ngram-lm-scale",
+        type=float,
+        default=0.01,
+        help="""
+        Used only when --decoding_method is fast_beam_search_nbest_LG and fast_beam_search_LG.
+        It specifies the scale for n-gram LM scores.
+        """,
+    )
+
+    parser.add_argument(
+        "--decode-chunk-size",
+        type=int,
+        default=16,
+        help="The chunk size for decoding (in frames after subsampling)",
+    )
+
+    parser.add_argument(
+        "--left-context",
+        type=int,
+        default=64,
+        help="left context can be seen during decoding (in frames after subsampling)",
+    )
+
+    parser.add_argument(
+        "--max-contexts",
+        type=int,
+        default=8,
+        help="""Used only when --decoding-method is fast_beam_search_LG,
+        fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+        and fast_beam_search_nbest_oracle""",
+    )
+
+    parser.add_argument(
+        "--max-states",
+        type=int,
+        default=64,
+        help="""Used only when --decoding-method is fast_beam_search_LG,
+        fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+        and fast_beam_search_nbest_oracle""",
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+    )
+
+    parser.add_argument(
+        "--max-sym-per-frame",
+        type=int,
+        default=1,
+        help="""Maximum number of symbols per frame.
+        Used only when --decoding_method is greedy_search""",
+    )
+
+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=200,
+        help="""Number of paths for nbest decoding.
+        Used only when the decoding method is fast_beam_search_nbest,
+        fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+    )
+
+    parser.add_argument(
+        "--nbest-scale",
+        type=float,
+        default=0.5,
+        help="""Scale applied to lattice scores when computing nbest paths.
+        Used only when the decoding method is fast_beam_search_nbest,
+        fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+    )
+
+    parser.add_argument(
+        "--simulate-streaming",
+        type=str2bool,
+        default=False,
+        help="""Whether to simulate streaming in decoding, this is a good way to
+        test a streaming model.
+        """,
+    )
+
+    parser.add_argument(
+        "--rnn-lm-scale",
+        type=float,
+        default=0.0,
+        help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
+        It specifies the path to RNN LM exp dir.
+        """,
+    )
+
+    parser.add_argument(
+        "--rnn-lm-exp-dir",
+        type=str,
+        default="rnn_lm/exp",
+        help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
+        It specifies the path to RNN LM exp dir.
+        """,
+    )
+
+    parser.add_argument(
+        "--rnn-lm-epoch",
+        type=int,
+        default=7,
+        help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
+        It specifies the checkpoint to use.
+        """,
+    )
+
+    parser.add_argument(
+        "--rnn-lm-avg",
+        type=int,
+        default=2,
+        help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
+        It specifies the number of checkpoints to average.
+        """,
+    )
+
+    parser.add_argument(
+        "--rnn-lm-embedding-dim",
+        type=int,
+        default=2048,
+        help="Embedding dim of the model",
+    )
+
+    parser.add_argument(
+        "--rnn-lm-hidden-dim",
+        type=int,
+        default=2048,
+        help="Hidden dim of the model",
+    )
+
+    parser.add_argument(
+        "--rnn-lm-num-layers",
+        type=int,
+        default=4,
+        help="Number of RNN layers the model",
+    )
+    parser.add_argument(
+        "--rnn-lm-tie-weights",
+        type=str2bool,
+        default=False,
+        help="""True to share the weights between the input embedding layer and the
+        last output linear layer
+        """,
+    )
+    add_model_arguments(parser)
+
+    return parser
+
+
+def decode_one_batch(
+    params: AttributeDict,
+    model: nn.Module,
+    sp: spm.SentencePieceProcessor,
+    batch: dict,
+    word_table: Optional[k2.SymbolTable] = None,
+    decoding_graph: Optional[k2.Fsa] = None,
+    rnnlm: Optional[RnnLmModel] = None,
+    rnnlm_scale: float = 1.0,
+) -> Dict[str, List[List[str]]]:
+    """Decode one batch and return the result in a dict. The dict has the
+    following format:
+
+        - key: It indicates the setting used for decoding. For example,
+               if greedy_search is used, it would be "greedy_search"
+               If beam search with a beam size of 7 is used, it would be
+               "beam_7"
+        - value: It contains the decoding result. `len(value)` equals to
+                 batch size. `value[i]` is the decoding result for the i-th
+                 utterance in the given batch.
+    Args:
+      params:
+        It's the return value of :func:`get_params`.
+      model:
+        The neural model.
+      sp:
+        The BPE model.
+      batch:
+        It is the return value from iterating
+        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+        for the format of the `batch`.
+      word_table:
+        The word symbol table.
+      decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
+        only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
+        fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
+    Returns:
+      Return the decoding result. See above description for the format of
+      the returned dict.
+    """
+    device = next(model.parameters()).device
+    feature = batch["inputs"]
+    assert feature.ndim == 3
+
+    feature = feature.to(device)
+    # at entry, feature is (N, T, C)
+
+    supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)
+
+    if params.simulate_streaming:
+        feature_lens += params.left_context
+        feature = torch.nn.functional.pad(
+            feature,
+            pad=(0, 0, 0, params.left_context),
+            value=LOG_EPS,
+        )
+        encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
+            x=feature,
+            x_lens=feature_lens,
+            chunk_size=params.decode_chunk_size,
+            left_context=params.left_context,
+            simulate_streaming=True,
+        )
+    else:
+        encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+
+    hyps = []
+
+    if (
+        params.decoding_method == "fast_beam_search"
+        or params.decoding_method == "fast_beam_search_LG"
+    ):
+        hyp_tokens = fast_beam_search_one_best(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+        )
+        if params.decoding_method == "fast_beam_search":
+            for hyp in sp.decode(hyp_tokens):
+                hyps.append(hyp.split())
+        else:
+            for hyp in hyp_tokens:
+                hyps.append([word_table[i] for i in hyp])
+    elif params.decoding_method == "fast_beam_search_nbest_LG":
+        hyp_tokens = fast_beam_search_nbest_LG(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+            num_paths=params.num_paths,
+            nbest_scale=params.nbest_scale,
+        )
+        for hyp in hyp_tokens:
+            hyps.append([word_table[i] for i in hyp])
+    elif params.decoding_method == "fast_beam_search_nbest":
+        hyp_tokens = fast_beam_search_nbest(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+            num_paths=params.num_paths,
+            nbest_scale=params.nbest_scale,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.decoding_method == "fast_beam_search_nbest_oracle":
+        hyp_tokens = fast_beam_search_nbest_oracle(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+            num_paths=params.num_paths,
+            ref_texts=sp.encode(supervisions["text"]),
+            nbest_scale=params.nbest_scale,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+        hyp_tokens = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.decoding_method == "modified_beam_search":
+        hyp_tokens = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam_size,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion":
+        hyp_tokens = modified_beam_search_rnnlm_shallow_fusion(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam_size,
+            sp=sp,
+            rnnlm=rnnlm,
+            rnnlm_scale=rnnlm_scale,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    else:
+        batch_size = encoder_out.size(0)
+
+        for i in range(batch_size):
+            # fmt: off
+            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+            # fmt: on
+            if params.decoding_method == "greedy_search":
+                hyp = greedy_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    max_sym_per_frame=params.max_sym_per_frame,
+                )
+            elif params.decoding_method == "beam_search":
+                hyp = beam_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
+                )
+            else:
+                raise ValueError(
+                    f"Unsupported decoding method: {params.decoding_method}"
+                )
+            hyps.append(sp.decode(hyp).split())
+
+    if params.decoding_method == "greedy_search":
+        return {"greedy_search": hyps}
+    elif "fast_beam_search" in params.decoding_method:
+        key = f"beam_{params.beam}_"
+        key += f"max_contexts_{params.max_contexts}_"
+        key += f"max_states_{params.max_states}"
+        if "nbest" in params.decoding_method:
+            key += f"_num_paths_{params.num_paths}_"
+            key += f"nbest_scale_{params.nbest_scale}"
+        if "LG" in params.decoding_method:
+            key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
+
+        return {key: hyps}
+    else:
+        return {f"beam_size_{params.beam_size}": hyps}
+
+
+def decode_dataset(
+    dl: torch.utils.data.DataLoader,
+    params: AttributeDict,
+    model: nn.Module,
+    sp: spm.SentencePieceProcessor,
+    word_table: Optional[k2.SymbolTable] = None,
+    decoding_graph: Optional[k2.Fsa] = None,
+    rnnlm: Optional[RnnLmModel] = None,
+    rnnlm_scale: float = 1.0,
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+    """Decode dataset.
+
+    Args:
+      dl:
+        PyTorch's dataloader containing the dataset to decode.
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The neural model.
+      sp:
+        The BPE model.
+      word_table:
+        The word symbol table.
+      decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
+        only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
+        fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
+    Returns:
+      Return a dict, whose key may be "greedy_search" if greedy search
+      is used, or it may be "beam_7" if beam size of 7 is used.
+      Its value is a list of tuples. Each tuple contains two elements:
+      The first is the reference transcript, and the second is the
+      predicted result.
+    """
+    num_cuts = 0
+
+    try:
+        num_batches = len(dl)
+    except TypeError:
+        num_batches = "?"
+
+    if params.decoding_method == "greedy_search":
+        log_interval = 50
+    else:
+        log_interval = 20
+
+    results = defaultdict(list)
+    for batch_idx, batch in enumerate(dl):
+        texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+        logging.info(f"Decoding {batch_idx}-th batch")
+
+        hyps_dict = decode_one_batch(
+            params=params,
+            model=model,
+            sp=sp,
+            decoding_graph=decoding_graph,
+            word_table=word_table,
+            batch=batch,
+            rnnlm=rnnlm,
+            rnnlm_scale=rnnlm_scale,
+        )
+
+        for name, hyps in hyps_dict.items():
+            this_batch = []
+            assert len(hyps) == len(texts)
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+                ref_words = ref_text.split()
+                this_batch.append((cut_id, ref_words, hyp_words))
+
+            results[name].extend(this_batch)
+
+        num_cuts += len(texts)
+
+        if batch_idx % log_interval == 0:
+            batch_str = f"{batch_idx}/{num_batches}"
+
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+    return results
+
+
+def save_results(
+    params: AttributeDict,
+    test_set_name: str,
+    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
+    test_set_wers = dict()
+    for key, results in results_dict.items():
+        recog_path = (
+            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
+        results = sorted(results)
+        store_transcripts(filename=recog_path, texts=results)
+        logging.info(f"The transcripts are stored in {recog_path}")
+
+        # The following prints out WERs, per-word error statistics and aligned
+        # ref/hyp pairs.
+        errs_filename = (
+            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
+        with open(errs_filename, "w") as f:
+            wer = write_error_stats(
+                f, f"{test_set_name}-{key}", results, enable_log=True
+            )
+            test_set_wers[key] = wer
+
+        logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+    errs_info = (
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+    )
+    with open(errs_info, "w") as f:
+        print("settings\tWER", file=f)
+        for key, val in test_set_wers:
+            print("{}\t{}".format(key, val), file=f)
+
+    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+    note = "\tbest for {}".format(test_set_name)
+    for key, val in test_set_wers:
+        s += "{}\t{}{}\n".format(key, val, note)
+        note = ""
+    logging.info(s)
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    Xbmu_AmdoAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    assert params.decoding_method in (
+        "greedy_search",
+        "beam_search",
+        "fast_beam_search",
+        "fast_beam_search_LG",
+        "fast_beam_search_nbest",
+        "fast_beam_search_nbest_LG",
+        "fast_beam_search_nbest_oracle",
+        "modified_beam_search",
+        "modified_beam_search_rnnlm_shallow_fusion",
+    )
+    params.res_dir = params.exp_dir / params.decoding_method
+
+    if params.iter > 0:
+        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+    else:
+        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+    if params.simulate_streaming:
+        params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
+        params.suffix += f"-left-context-{params.left_context}"
+    if "fast_beam_search" in params.decoding_method:
+        params.suffix += f"-beam-{params.beam}"
+        params.suffix += f"-max-contexts-{params.max_contexts}"
+        params.suffix += f"-max-states-{params.max_states}"
+        if "nbest" in params.decoding_method:
+            params.suffix += f"-nbest-scale-{params.nbest_scale}"
+            params.suffix += f"-num-paths-{params.num_paths}"
+        if "LG" in params.decoding_method:
+            params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
+    elif "beam_search" in params.decoding_method:
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+    else:
+        params.suffix += f"-context-{params.context_size}"
+        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+    params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}"
+
+    if params.use_averaged_model:
+        params.suffix += "-use-averaged-model"
+
+    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+    logging.info("Decoding started")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    #  and  are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("")
+    params.unk_id = sp.piece_to_id("")
+    params.vocab_size = sp.get_piece_size()
+
+    if params.simulate_streaming:
+        assert (
+            params.causal_convolution
+        ), "Decoding in streaming requires causal convolution"
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+
+    model.to(device)
+    model.eval()
+
+    rnn_lm_model = None
+    rnn_lm_scale = params.rnn_lm_scale
+    if params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion":
+        rnn_lm_model = RnnLmModel(
+            vocab_size=params.vocab_size,
+            embedding_dim=params.rnn_lm_embedding_dim,
+            hidden_dim=params.rnn_lm_hidden_dim,
+            num_layers=params.rnn_lm_num_layers,
+            tie_weights=params.rnn_lm_tie_weights,
+        )
+        assert params.rnn_lm_avg == 1
+
+        load_checkpoint(
+            f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt",
+            rnn_lm_model,
+        )
+        rnn_lm_model.to(device)
+        rnn_lm_model.eval()
+
+    if "fast_beam_search" in params.decoding_method:
+        if "LG" in params.decoding_method:
+            lexicon = Lexicon(params.lang_dir)
+            word_table = lexicon.word_table
+            lg_filename = params.lang_dir / "LG.pt"
+            logging.info(f"Loading {lg_filename}")
+            decoding_graph = k2.Fsa.from_dict(
+                torch.load(lg_filename, map_location=device)
+            )
+            decoding_graph.scores *= params.ngram_lm_scale
+        else:
+            word_table = None
+            decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+    else:
+        decoding_graph = None
+        word_table = None
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
+    xbmu_amdo = Xbmu_AmdoAsrDataModule(args)
+
+    test_cuts = xbmu_amdo.test_cuts()
+
+    test_dl = xbmu_amdo.test_dataloaders(test_cuts)
+
+    test_sets = ["test"]
+    test_dl = [test_dl]
+
+    for test_set, test_dl in zip(test_sets, test_dl):
+        results_dict = decode_dataset(
+            dl=test_dl,
+            params=params,
+            model=model,
+            sp=sp,
+            word_table=word_table,
+            decoding_graph=decoding_graph,
+            rnnlm=rnn_lm_model,
+            rnnlm_scale=rnn_lm_scale,
+        )
+
+        save_results(
+            params=params,
+            test_set_name=test_set,
+            results_dict=results_dict,
+        )
+
+    logging.info("Done!")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode_stream.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode_stream.py
new file mode 120000
index 000000000..d59ef95f7
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode_stream.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/decode_stream.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decoder.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decoder.py
new file mode 120000
index 000000000..722e1c894
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decoder.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/decoder.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/encoder_interface.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/encoder_interface.py
new file mode 120000
index 000000000..f58253127
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/encoder_interface.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/encoder_interface.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/export.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/export.py
new file mode 100755
index 000000000..54f656859
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/export.py
@@ -0,0 +1,287 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This script converts several saved checkpoints
+# to a single one using model averaging.
+"""
+Usage:
+./pruned_transducer_stateless5/export.py \
+  --exp-dir ./pruned_transducer_stateless5/exp \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --epoch 20 \
+  --avg 10
+
+It will generate a file exp_dir/pretrained.pt
+
+To use the generated file with `pruned_transducer_stateless5/decode.py`,
+you can do:
+
+    cd /path/to/exp_dir
+    ln -s pretrained.pt epoch-9999.pt
+
+    cd /path/to/egs/librispeech/ASR
+    ./pruned_transducer_stateless5/decode.py \
+        --exp-dir ./pruned_transducer_stateless5/exp \
+        --epoch 9999 \
+        --avg 1 \
+        --max-duration 600 \
+        --decoding-method greedy_search \
+        --bpe-model data/lang_bpe_500/bpe.model
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+import sentencepiece as spm
+import torch
+from scaling_converter import convert_scaled_to_non_scaled
+from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall.checkpoint import (
+    average_checkpoints,
+    average_checkpoints_with_averaged_model,
+    find_checkpoints,
+    load_checkpoint,
+)
+from icefall.utils import str2bool
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=28,
+        help="""It specifies the checkpoint to use for averaging.
+        Note: Epoch counts from 1.
+        You can specify --avg to use more checkpoints for model averaging.""",
+    )
+
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=15,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
+    )
+
+    parser.add_argument(
+        "--use-averaged-model",
+        type=str2bool,
+        default=True,
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="pruned_transducer_stateless5/exp",
+        help="""It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="data/lang_bpe_500/bpe.model",
+        help="Path to the BPE model",
+    )
+
+    parser.add_argument(
+        "--jit",
+        type=str2bool,
+        default=False,
+        help="""True to save a model after applying torch.jit.script.
+        """,
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+    )
+
+    parser.add_argument(
+        "--streaming-model",
+        type=str2bool,
+        default=False,
+        help="""Whether to export a streaming model, if the models in exp-dir
+        are streaming model, this should be True.
+        """,
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def main():
+    args = get_parser().parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    #  is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("")
+    params.vocab_size = sp.get_piece_size()
+
+    if params.streaming_model:
+        assert params.causal_convolution
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+
+    model.to("cpu")
+    model.eval()
+
+    if params.jit:
+        # We won't use the forward() method of the model in C++, so just ignore
+        # it here.
+        # Otherwise, one of its arguments is a ragged tensor and is not
+        # torch scriptabe.
+        convert_scaled_to_non_scaled(model, inplace=True)
+        model.__class__.forward = torch.jit.ignore(model.__class__.forward)
+        logging.info("Using torch.jit.script")
+        model = torch.jit.script(model)
+        filename = params.exp_dir / "cpu_jit.pt"
+        model.save(str(filename))
+        logging.info(f"Saved to {filename}")
+    else:
+        logging.info("Not using torch.jit.script")
+        # Save it using a format so that it can be loaded
+        # by :func:`load_checkpoint`
+        filename = params.exp_dir / "pretrained.pt"
+        torch.save({"model": model.state_dict()}, str(filename))
+        logging.info(f"Saved to {filename}")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/joiner.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/joiner.py
new file mode 120000
index 000000000..9052f3cbb
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/joiner.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/joiner.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/lstmp.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/lstmp.py
new file mode 120000
index 000000000..b82e115fc
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/lstmp.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/lstm_transducer_stateless2/lstmp.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/model.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/model.py
new file mode 120000
index 000000000..a99e74334
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/model.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/optim.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/optim.py
new file mode 120000
index 000000000..0a2f285aa
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/optim.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/optim.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/pretrained.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/pretrained.py
new file mode 100755
index 000000000..74a2210c3
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/pretrained.py
@@ -0,0 +1,344 @@
+#!/usr/bin/env python3
+# Copyright      2021  Xiaomi Corp.        (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+(1) greedy search
+./pruned_transducer_stateless5/pretrained.py \
+    --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method greedy_search \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+
+(2) beam search
+./pruned_transducer_stateless5/pretrained.py \
+    --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method beam_search \
+    --beam-size 4 \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+
+(3) modified beam search
+./pruned_transducer_stateless5/pretrained.py \
+    --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method modified_beam_search \
+    --beam-size 4 \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+
+(4) fast beam search
+./pruned_transducer_stateless5/pretrained.py \
+    --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method fast_beam_search \
+    --beam-size 4 \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+
+You can also use `./pruned_transducer_stateless5/exp/epoch-xx.pt`.
+
+Note: ./pruned_transducer_stateless5/exp/pretrained.pt is generated by
+./pruned_transducer_stateless5/export.py
+"""
+
+
+import argparse
+import logging
+import math
+from typing import List
+
+import k2
+import kaldifeat
+import sentencepiece as spm
+import torch
+import torchaudio
+from beam_search import (
+    beam_search,
+    fast_beam_search_one_best,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+)
+from torch.nn.utils.rnn import pad_sequence
+from train import add_model_arguments, get_params, get_transducer_model
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--checkpoint",
+        type=str,
+        required=True,
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        help="""Path to bpe.model.""",
+    )
+
+    parser.add_argument(
+        "--method",
+        type=str,
+        default="greedy_search",
+        help="""Possible values are:
+          - greedy_search
+          - beam_search
+          - modified_beam_search
+          - fast_beam_search
+        """,
+    )
+
+    parser.add_argument(
+        "sound_files",
+        type=str,
+        nargs="+",
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
+    )
+
+    parser.add_argument(
+        "--sample-rate",
+        type=int,
+        default=16000,
+        help="The sample rate of the input sound file",
+    )
+
+    parser.add_argument(
+        "--beam-size",
+        type=int,
+        default=4,
+        help="""An integer indicating how many candidates we will keep for each
+        frame. Used only when --method is beam_search or
+        modified_beam_search.""",
+    )
+
+    parser.add_argument(
+        "--beam",
+        type=float,
+        default=4,
+        help="""A floating point value to calculate the cutoff score during beam
+        search (i.e., `cutoff = max-score - beam`), which is the same as the
+        `beam` in Kaldi.
+        Used only when --method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--max-contexts",
+        type=int,
+        default=4,
+        help="""Used only when --method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--max-states",
+        type=int,
+        default=8,
+        help="""Used only when --method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+    )
+    parser.add_argument(
+        "--max-sym-per-frame",
+        type=int,
+        default=1,
+        help="""Maximum number of symbols per frame. Used only when
+        --method is greedy_search.
+        """,
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def read_sound_files(
+    filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+    """Read a list of sound files into a list 1-D float32 torch tensors.
+    Args:
+      filenames:
+        A list of sound filenames.
+      expected_sample_rate:
+        The expected sample rate of the sound files.
+    Returns:
+      Return a list of 1-D float32 torch tensors.
+    """
+    ans = []
+    for f in filenames:
+        wave, sample_rate = torchaudio.load(f)
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        # We use only the first channel
+        ans.append(wave[0])
+    return ans
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+
+    params = get_params()
+
+    params.update(vars(args))
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    #  is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("")
+    params.unk_id = sp.piece_to_id("")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(f"{params}")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    logging.info("Creating model")
+    model = get_transducer_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    model.load_state_dict(checkpoint["model"], strict=False)
+    model.to(device)
+    model.eval()
+    model.device = device
+
+    logging.info("Constructing Fbank computer")
+    opts = kaldifeat.FbankOptions()
+    opts.device = device
+    opts.frame_opts.dither = 0
+    opts.frame_opts.snip_edges = False
+    opts.frame_opts.samp_freq = params.sample_rate
+    opts.mel_opts.num_bins = params.feature_dim
+
+    fbank = kaldifeat.Fbank(opts)
+
+    logging.info(f"Reading sound files: {params.sound_files}")
+    waves = read_sound_files(
+        filenames=params.sound_files, expected_sample_rate=params.sample_rate
+    )
+    waves = [w.to(device) for w in waves]
+
+    logging.info("Decoding started")
+    features = fbank(waves)
+    feature_lengths = [f.size(0) for f in features]
+
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+
+    feature_lengths = torch.tensor(feature_lengths, device=device)
+
+    encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
+
+    num_waves = encoder_out.size(0)
+    hyps = []
+    msg = f"Using {params.method}"
+    if params.method == "beam_search":
+        msg += f" with beam size {params.beam_size}"
+    logging.info(msg)
+
+    if params.method == "fast_beam_search":
+        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+        hyp_tokens = fast_beam_search_one_best(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.method == "modified_beam_search":
+        hyp_tokens = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam_size,
+        )
+
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
+        hyp_tokens = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    else:
+        for i in range(num_waves):
+            # fmt: off
+            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+            # fmt: on
+            if params.method == "greedy_search":
+                hyp = greedy_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    max_sym_per_frame=params.max_sym_per_frame,
+                )
+            elif params.method == "beam_search":
+                hyp = beam_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
+                )
+            else:
+                raise ValueError(f"Unsupported method: {params.method}")
+
+            hyps.append(sp.decode(hyp).split())
+
+    s = "\n"
+    for filename, hyp in zip(params.sound_files, hyps):
+        words = " ".join(hyp)
+        s += f"{filename}:\n{words}\n\n"
+    logging.info(s)
+
+    logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/scaling.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/scaling.py
new file mode 120000
index 000000000..c10cdfe12
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/scaling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/scaling.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/scaling_converter.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/scaling_converter.py
new file mode 120000
index 000000000..db93d155b
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/scaling_converter.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/streaming_beam_search.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/streaming_beam_search.py
new file mode 120000
index 000000000..1199a61d6
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/streaming_beam_search.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/streaming_decode.py
new file mode 120000
index 000000000..f29284163
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/streaming_decode.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/test_model.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/test_model.py
new file mode 100755
index 000000000..9aad32014
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/test_model.py
@@ -0,0 +1,65 @@
+#!/usr/bin/env python3
+# Copyright    2022  Xiaomi Corp.        (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+To run this file, do:
+
+    cd icefall/egs/librispeech/ASR
+    python ./pruned_transducer_stateless4/test_model.py
+"""
+
+from train import get_params, get_transducer_model
+
+
+def test_model_1():
+    params = get_params()
+    params.vocab_size = 500
+    params.blank_id = 0
+    params.context_size = 2
+    params.num_encoder_layers = 24
+    params.dim_feedforward = 1536  # 384 * 4
+    params.encoder_dim = 384
+    model = get_transducer_model(params)
+    num_param = sum([p.numel() for p in model.parameters()])
+    print(f"Number of model parameters: {num_param}")
+
+
+# See Table 1 from https://arxiv.org/pdf/2005.08100.pdf
+def test_model_M():
+    params = get_params()
+    params.vocab_size = 500
+    params.blank_id = 0
+    params.context_size = 2
+    params.num_encoder_layers = 18
+    params.dim_feedforward = 1024
+    params.encoder_dim = 256
+    params.nhead = 4
+    params.decoder_dim = 512
+    params.joiner_dim = 512
+    model = get_transducer_model(params)
+    num_param = sum([p.numel() for p in model.parameters()])
+    print(f"Number of model parameters: {num_param}")
+
+
+def main():
+    #  test_model_1()
+    test_model_M()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/train.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/train.py
new file mode 100755
index 000000000..5b5ac17be
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/train.py
@@ -0,0 +1,1187 @@
+#!/usr/bin/env python3
+# Copyright    2021-2022  Xiaomi Corp.        (authors: Fangjun Kuang,
+#                                                       Wei Kang,
+#                                                       Mingshuang Luo,)
+#                                                       Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./pruned_transducer_stateless5/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --exp-dir pruned_transducer_stateless5/exp \
+  --full-libri 1 \
+  --max-duration 300
+
+# For mix precision training:
+
+./pruned_transducer_stateless5/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --use-fp16 1 \
+  --exp-dir pruned_transducer_stateless5/exp \
+  --full-libri 1 \
+  --max-duration 550
+
+"""
+
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import Xbmu_AmdoAsrDataModule
+from conformer import Conformer
+from decoder import Decoder
+from joiner import Joiner
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import Transducer
+from optim import Eden, Eve
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+    save_checkpoint_with_global_batch_idx,
+    update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    display_and_save_batch,
+    setup_logger,
+    str2bool,
+)
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+    parser.add_argument(
+        "--num-encoder-layers",
+        type=int,
+        default=24,
+        help="Number of conformer encoder layers..",
+    )
+
+    parser.add_argument(
+        "--dim-feedforward",
+        type=int,
+        default=1536,
+        help="Feedforward dimension of the conformer encoder layer.",
+    )
+
+    parser.add_argument(
+        "--nhead",
+        type=int,
+        default=8,
+        help="Number of attention heads in the conformer encoder layer.",
+    )
+
+    parser.add_argument(
+        "--encoder-dim",
+        type=int,
+        default=384,
+        help="Attention dimension in the conformer encoder layer.",
+    )
+
+    parser.add_argument(
+        "--decoder-dim",
+        type=int,
+        default=512,
+        help="Embedding dimension in the decoder model.",
+    )
+
+    parser.add_argument(
+        "--joiner-dim",
+        type=int,
+        default=512,
+        help="""Dimension used in the joiner model.
+        Outputs from the encoder and decoder model are projected
+        to this dimension before adding.
+        """,
+    )
+
+    parser.add_argument(
+        "--dynamic-chunk-training",
+        type=str2bool,
+        default=False,
+        help="""Whether to use dynamic_chunk_training, if you want a streaming
+        model, this requires to be True.
+        """,
+    )
+
+    parser.add_argument(
+        "--causal-convolution",
+        type=str2bool,
+        default=False,
+        help="""Whether to use causal convolution, this requires to be True when
+        using dynamic_chunk_training.
+        """,
+    )
+
+    parser.add_argument(
+        "--short-chunk-size",
+        type=int,
+        default=25,
+        help="""Chunk length of dynamic training, the chunk size would be either
+        max sequence length of current batch or uniformly sampled from (1, short_chunk_size).
+        """,
+    )
+
+    parser.add_argument(
+        "--num-left-chunks",
+        type=int,
+        default=4,
+        help="How many left context can be seen in chunks when calculating attention.",
+    )
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--world-size",
+        type=int,
+        default=1,
+        help="Number of GPUs for DDP training.",
+    )
+
+    parser.add_argument(
+        "--master-port",
+        type=int,
+        default=12354,
+        help="Master port to use for DDP training.",
+    )
+
+    parser.add_argument(
+        "--tensorboard",
+        type=str2bool,
+        default=True,
+        help="Should various information be logged in tensorboard.",
+    )
+
+    parser.add_argument(
+        "--num-epochs",
+        type=int,
+        default=30,
+        help="Number of epochs to train.",
+    )
+
+    parser.add_argument(
+        "--start-epoch",
+        type=int,
+        default=1,
+        help="""Resume training from this epoch. It should be positive.
+        If larger than 1, it will load checkpoint from
+        exp-dir/epoch-{start_epoch-1}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--start-batch",
+        type=int,
+        default=0,
+        help="""If positive, --start-epoch is ignored and
+        it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="pruned_transducer_stateless5/exp",
+        help="""The experiment dir.
+        It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="data/lang_bpe_500/bpe.model",
+        help="Path to the BPE model",
+    )
+
+    parser.add_argument(
+        "--initial-lr",
+        type=float,
+        default=0.003,
+        help="The initial learning rate.  This value should not need to be changed.",
+    )
+
+    parser.add_argument(
+        "--lr-batches",
+        type=float,
+        default=5000,
+        help="""Number of steps that affects how rapidly the learning rate
+        decreases. We suggest not to change this.""",
+    )
+
+    parser.add_argument(
+        "--lr-epochs",
+        type=float,
+        default=6,
+        help="""Number of epochs that affects how rapidly the learning rate decreases.
+        """,
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+    )
+
+    parser.add_argument(
+        "--prune-range",
+        type=int,
+        default=5,
+        help="The prune range for rnnt loss, it means how many symbols(context)"
+        "we are using to compute the loss",
+    )
+
+    parser.add_argument(
+        "--lm-scale",
+        type=float,
+        default=0.25,
+        help="The scale to smooth the loss with lm "
+        "(output of prediction network) part.",
+    )
+
+    parser.add_argument(
+        "--am-scale",
+        type=float,
+        default=0.0,
+        help="The scale to smooth the loss with am (output of encoder network) part.",
+    )
+
+    parser.add_argument(
+        "--simple-loss-scale",
+        type=float,
+        default=0.5,
+        help="To get pruning ranges, we will calculate a simple version"
+        "loss(joiner is just addition), this simple loss also uses for"
+        "training (as a regularization item). We will scale the simple loss"
+        "with this parameter before adding to the final loss.",
+    )
+
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="The seed for random generators intended for reproducibility",
+    )
+
+    parser.add_argument(
+        "--print-diagnostics",
+        type=str2bool,
+        default=False,
+        help="Accumulate stats on activations, print them and exit.",
+    )
+
+    parser.add_argument(
+        "--save-every-n",
+        type=int,
+        default=4000,
+        help="""Save checkpoint after processing this number of batches"
+        periodically. We save checkpoint to exp-dir/ whenever
+        params.batch_idx_train % save_every_n == 0. The checkpoint filename
+        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+        Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+        end of each epoch where `xxx` is the epoch number counting from 0.
+        """,
+    )
+
+    parser.add_argument(
+        "--keep-last-k",
+        type=int,
+        default=30,
+        help="""Only keep this number of checkpoints on disk.
+        For instance, if it is 3, there are only 3 checkpoints
+        in the exp-dir with filenames `checkpoint-xxx.pt`.
+        It does not affect checkpoints with name `epoch-xxx.pt`.
+        """,
+    )
+
+    parser.add_argument(
+        "--average-period",
+        type=int,
+        default=100,
+        help="""Update the averaged model, namely `model_avg`, after processing
+        this number of batches. `model_avg` is a separate version of model,
+        in which each floating-point parameter is the average of all the
+        parameters from the start of training. Each time we take the average,
+        we do: `model_avg = model * (average_period / batch_idx_train) +
+            model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+        """,
+    )
+
+    parser.add_argument(
+        "--use-fp16",
+        type=str2bool,
+        default=False,
+        help="Whether to use half precision training.",
+    )
+
+    parser.add_argument(
+        "--delay-penalty",
+        type=float,
+        default=0.0,
+        help="""A constant value used to penalize symbol delay,
+        to encourage streaming models to emit symbols earlier.
+        See https://github.com/k2-fsa/k2/issues/955 and
+        https://arxiv.org/pdf/2211.00490.pdf for more details.""",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def get_params() -> AttributeDict:
+    """Return a dict containing training parameters.
+
+    All training related parameters that are not passed from the commandline
+    are saved in the variable `params`.
+
+    Commandline options are merged into `params` after they are parsed, so
+    you can also access them via `params`.
+
+    Explanation of options saved in `params`:
+
+        - best_train_loss: Best training loss so far. It is used to select
+                           the model that has the lowest training loss. It is
+                           updated during the training.
+
+        - best_valid_loss: Best validation loss so far. It is used to select
+                           the model that has the lowest validation loss. It is
+                           updated during the training.
+
+        - best_train_epoch: It is the epoch that has the best training loss.
+
+        - best_valid_epoch: It is the epoch that has the best validation loss.
+
+        - batch_idx_train: Used to writing statistics to tensorboard. It
+                           contains number of batches trained so far across
+                           epochs.
+
+        - log_interval:  Print training loss if batch_idx % log_interval` is 0
+
+        - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+        - valid_interval:  Run validation if batch_idx % valid_interval is 0
+
+        - feature_dim: The model input dim. It has to match the one used
+                       in computing features.
+
+        - subsampling_factor:  The subsampling factor for the model.
+
+        - encoder_dim: Hidden dim for multi-head attention model.
+
+        - num_decoder_layers: Number of decoder layer of transformer decoder.
+
+        - warm_step: The warm_step for Noam optimizer.
+    """
+    params = AttributeDict(
+        {
+            "best_train_loss": float("inf"),
+            "best_valid_loss": float("inf"),
+            "best_train_epoch": -1,
+            "best_valid_epoch": -1,
+            "batch_idx_train": 0,
+            "log_interval": 50,
+            "reset_interval": 200,
+            "valid_interval": 3000,  # For the 100h subset, use 800
+            # parameters for conformer
+            "feature_dim": 80,
+            "subsampling_factor": 4,
+            # parameters for Noam
+            "model_warm_step": 3000,  # arg given to model, not for lrate
+            "env_info": get_env_info(),
+        }
+    )
+
+    return params
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+    # TODO: We can add an option to switch between Conformer and Transformer
+    encoder = Conformer(
+        num_features=params.feature_dim,
+        subsampling_factor=params.subsampling_factor,
+        d_model=params.encoder_dim,
+        nhead=params.nhead,
+        dim_feedforward=params.dim_feedforward,
+        num_encoder_layers=params.num_encoder_layers,
+        dynamic_chunk_training=params.dynamic_chunk_training,
+        short_chunk_size=params.short_chunk_size,
+        num_left_chunks=params.num_left_chunks,
+        causal=params.causal_convolution,
+    )
+    return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+    decoder = Decoder(
+        vocab_size=params.vocab_size,
+        decoder_dim=params.decoder_dim,
+        blank_id=params.blank_id,
+        context_size=params.context_size,
+    )
+    return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+    joiner = Joiner(
+        encoder_dim=params.encoder_dim,
+        decoder_dim=params.decoder_dim,
+        joiner_dim=params.joiner_dim,
+        vocab_size=params.vocab_size,
+    )
+    return joiner
+
+
+def get_transducer_model(params: AttributeDict) -> nn.Module:
+    encoder = get_encoder_model(params)
+    decoder = get_decoder_model(params)
+    joiner = get_joiner_model(params)
+
+    model = Transducer(
+        encoder=encoder,
+        decoder=decoder,
+        joiner=joiner,
+        encoder_dim=params.encoder_dim,
+        decoder_dim=params.decoder_dim,
+        joiner_dim=params.joiner_dim,
+        vocab_size=params.vocab_size,
+    )
+    return model
+
+
+def load_checkpoint_if_available(
+    params: AttributeDict,
+    model: nn.Module,
+    model_avg: nn.Module = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+    """Load checkpoint from file.
+
+    If params.start_batch is positive, it will load the checkpoint from
+    `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+    params.start_epoch is larger than 1, it will load the checkpoint from
+    `params.start_epoch - 1`.
+
+    Apart from loading state dict for `model` and `optimizer` it also updates
+    `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+    and `best_valid_loss` in `params`.
+
+    Args:
+      params:
+        The return value of :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer that we are using.
+      scheduler:
+        The scheduler that we are using.
+    Returns:
+      Return a dict containing previously saved training info.
+    """
+    if params.start_batch > 0:
+        filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+    elif params.start_epoch > 1:
+        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+    else:
+        return None
+
+    assert filename.is_file(), f"{filename} does not exist!"
+
+    saved_params = load_checkpoint(
+        filename,
+        model=model,
+        model_avg=model_avg,
+        optimizer=optimizer,
+        scheduler=scheduler,
+    )
+
+    keys = [
+        "best_train_epoch",
+        "best_valid_epoch",
+        "batch_idx_train",
+        "best_train_loss",
+        "best_valid_loss",
+    ]
+    for k in keys:
+        params[k] = saved_params[k]
+
+    if params.start_batch > 0:
+        if "cur_epoch" in saved_params:
+            params["start_epoch"] = saved_params["cur_epoch"]
+
+    return saved_params
+
+
+def save_checkpoint(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    model_avg: Optional[nn.Module] = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+    sampler: Optional[CutSampler] = None,
+    scaler: Optional[GradScaler] = None,
+    rank: int = 0,
+) -> None:
+    """Save model, optimizer, scheduler and training stats to file.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer used in the training.
+      sampler:
+       The sampler for the training dataset.
+      scaler:
+        The scaler used for mix precision training.
+    """
+    if rank != 0:
+        return
+    filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+    save_checkpoint_impl(
+        filename=filename,
+        model=model,
+        model_avg=model_avg,
+        params=params,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        sampler=sampler,
+        scaler=scaler,
+        rank=rank,
+    )
+
+    if params.best_train_epoch == params.cur_epoch:
+        best_train_filename = params.exp_dir / "best-train-loss.pt"
+        copyfile(src=filename, dst=best_train_filename)
+
+    if params.best_valid_epoch == params.cur_epoch:
+        best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+        copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    sp: spm.SentencePieceProcessor,
+    batch: dict,
+    is_training: bool,
+    warmup: float = 1.0,
+) -> Tuple[Tensor, MetricsTracker]:
+    """
+    Compute RNN-T loss given the model and its inputs.
+
+    Args:
+      params:
+        Parameters for training. See :func:`get_params`.
+      model:
+        The model for training. It is an instance of Conformer in our case.
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      is_training:
+        True for training. False for validation. When it is True, this
+        function enables autograd during computation; when it is False, it
+        disables autograd.
+     warmup: a floating point value which increases throughout training;
+        values >= 1.0 are fully warmed up and have all modules present.
+    """
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+    feature = batch["inputs"]
+    # at entry, feature is (N, T, C)
+    assert feature.ndim == 3
+    feature = feature.to(device)
+
+    supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)
+
+    texts = batch["supervisions"]["text"]
+    y = sp.encode(texts, out_type=int)
+    y = k2.RaggedTensor(y).to(device)
+
+    with torch.set_grad_enabled(is_training):
+        simple_loss, pruned_loss = model(
+            x=feature,
+            x_lens=feature_lens,
+            y=y,
+            prune_range=params.prune_range,
+            am_scale=params.am_scale,
+            lm_scale=params.lm_scale,
+            warmup=warmup,
+            reduction="none",
+            delay_penalty=params.delay_penalty if warmup >= 2.0 else 0,
+        )
+        simple_loss_is_finite = torch.isfinite(simple_loss)
+        pruned_loss_is_finite = torch.isfinite(pruned_loss)
+        is_finite = simple_loss_is_finite & pruned_loss_is_finite
+        if not torch.all(is_finite):
+            logging.info(
+                "Not all losses are finite!\n"
+                f"simple_loss: {simple_loss}\n"
+                f"pruned_loss: {pruned_loss}"
+            )
+            display_and_save_batch(batch, params=params, sp=sp)
+            simple_loss = simple_loss[simple_loss_is_finite]
+            pruned_loss = pruned_loss[pruned_loss_is_finite]
+
+            # If the batch contains more than 10 utterances AND
+            # if either all simple_loss or pruned_loss is inf or nan,
+            # we stop the training process by raising an exception
+            if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite):
+                raise ValueError(
+                    "There are too many utterances in this batch "
+                    "leading to inf or nan losses."
+                )
+
+        simple_loss = simple_loss.sum()
+        pruned_loss = pruned_loss.sum()
+        # after the main warmup step, we keep pruned_loss_scale small
+        # for the same amount of time (model_warm_step), to avoid
+        # overwhelming the simple_loss and causing it to diverge,
+        # in case it had not fully learned the alignment yet.
+        pruned_loss_scale = (
+            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+        )
+        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+
+    assert loss.requires_grad == is_training
+
+    info = MetricsTracker()
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        # info["frames"] is an approximate number for two reasons:
+        # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2
+        # (2) If some utterances in the batch lead to inf/nan loss, they
+        #     are filtered out.
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+
+    # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances`  # noqa
+    info["utterances"] = feature.size(0)
+    # averaged input duration in frames over utterances
+    info["utt_duration"] = feature_lens.sum().item()
+    # averaged padding proportion over utterances
+    info["utt_pad_proportion"] = (
+        ((feature.size(1) - feature_lens) / feature.size(1)).sum().item()
+    )
+
+    # Note: We use reduction=sum while computing the loss.
+    info["loss"] = loss.detach().cpu().item()
+    info["simple_loss"] = simple_loss.detach().cpu().item()
+    info["pruned_loss"] = pruned_loss.detach().cpu().item()
+
+    return loss, info
+
+
+def compute_validation_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    sp: spm.SentencePieceProcessor,
+    valid_dl: torch.utils.data.DataLoader,
+    world_size: int = 1,
+) -> MetricsTracker:
+    """Run the validation process."""
+    model.eval()
+
+    tot_loss = MetricsTracker()
+
+    for batch_idx, batch in enumerate(valid_dl):
+        loss, loss_info = compute_loss(
+            params=params,
+            model=model,
+            sp=sp,
+            batch=batch,
+            is_training=False,
+        )
+        assert loss.requires_grad is False
+        tot_loss = tot_loss + loss_info
+
+    if world_size > 1:
+        tot_loss.reduce(loss.device)
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    if loss_value < params.best_valid_loss:
+        params.best_valid_epoch = params.cur_epoch
+        params.best_valid_loss = loss_value
+
+    return tot_loss
+
+
+def train_one_epoch(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    optimizer: torch.optim.Optimizer,
+    scheduler: LRSchedulerType,
+    sp: spm.SentencePieceProcessor,
+    train_dl: torch.utils.data.DataLoader,
+    valid_dl: torch.utils.data.DataLoader,
+    scaler: GradScaler,
+    model_avg: Optional[nn.Module] = None,
+    tb_writer: Optional[SummaryWriter] = None,
+    world_size: int = 1,
+    rank: int = 0,
+) -> None:
+    """Train the model for one epoch.
+
+    The training loss from the mean of all frames is saved in
+    `params.train_loss`. It runs the validation process every
+    `params.valid_interval` batches.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The model for training.
+      optimizer:
+        The optimizer we are using.
+      scheduler:
+        The learning rate scheduler, we call step() every step.
+      train_dl:
+        Dataloader for the training dataset.
+      valid_dl:
+        Dataloader for the validation dataset.
+      scaler:
+        The scaler used for mix precision training.
+      model_avg:
+        The stored model averaged from the start of training.
+      tb_writer:
+        Writer to write log messages to tensorboard.
+      world_size:
+        Number of nodes in DDP training. If it is 1, DDP is disabled.
+      rank:
+        The rank of the node in DDP training. If no DDP is used, it should
+        be set to 0.
+    """
+    model.train()
+
+    tot_loss = MetricsTracker()
+
+    for batch_idx, batch in enumerate(train_dl):
+        params.batch_idx_train += 1
+        batch_size = len(batch["supervisions"]["text"])
+
+        try:
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                loss, loss_info = compute_loss(
+                    params=params,
+                    model=model,
+                    sp=sp,
+                    batch=batch,
+                    is_training=True,
+                    warmup=(params.batch_idx_train / params.model_warm_step),
+                )
+            # summary stats
+            tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+            # NOTE: We use reduction==sum and loss is computed over utterances
+            # in the batch and there is no normalization to it so far.
+            scaler.scale(loss).backward()
+            scheduler.step_batch(params.batch_idx_train)
+            scaler.step(optimizer)
+            scaler.update()
+            optimizer.zero_grad()
+        except:  # noqa
+            display_and_save_batch(batch, params=params, sp=sp)
+            raise
+
+        if params.print_diagnostics and batch_idx == 5:
+            return
+
+        if (
+            rank == 0
+            and params.batch_idx_train > 0
+            and params.batch_idx_train % params.average_period == 0
+        ):
+            update_averaged_model(
+                params=params,
+                model_cur=model,
+                model_avg=model_avg,
+            )
+
+        if (
+            params.batch_idx_train > 0
+            and params.batch_idx_train % params.save_every_n == 0
+        ):
+            save_checkpoint_with_global_batch_idx(
+                out_dir=params.exp_dir,
+                global_batch_idx=params.batch_idx_train,
+                model=model,
+                model_avg=model_avg,
+                params=params,
+                optimizer=optimizer,
+                scheduler=scheduler,
+                sampler=train_dl.sampler,
+                scaler=scaler,
+                rank=rank,
+            )
+            remove_checkpoints(
+                out_dir=params.exp_dir,
+                topk=params.keep_last_k,
+                rank=rank,
+            )
+
+        if batch_idx % params.log_interval == 0:
+            cur_lr = scheduler.get_last_lr()[0]
+            logging.info(
+                f"Epoch {params.cur_epoch}, "
+                f"batch {batch_idx}, loss[{loss_info}], "
+                f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+                f"lr: {cur_lr:.2e}"
+            )
+
+            if tb_writer is not None:
+                tb_writer.add_scalar(
+                    "train/learning_rate", cur_lr, params.batch_idx_train
+                )
+
+                loss_info.write_summary(
+                    tb_writer, "train/current_", params.batch_idx_train
+                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+
+        if batch_idx > 0 and batch_idx % params.valid_interval == 0:
+            logging.info("Computing validation loss")
+            valid_info = compute_validation_loss(
+                params=params,
+                model=model,
+                sp=sp,
+                valid_dl=valid_dl,
+                world_size=world_size,
+            )
+            model.train()
+            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+            if tb_writer is not None:
+                valid_info.write_summary(
+                    tb_writer, "train/valid_", params.batch_idx_train
+                )
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    params.train_loss = loss_value
+    if params.train_loss < params.best_train_loss:
+        params.best_train_epoch = params.cur_epoch
+        params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+    """
+    Args:
+      rank:
+        It is a value between 0 and `world_size-1`, which is
+        passed automatically by `mp.spawn()` in :func:`main`.
+        The node with rank 0 is responsible for saving checkpoint.
+      world_size:
+        Number of GPUs for DDP training.
+      args:
+        The return value of get_parser().parse_args()
+    """
+    params = get_params()
+    params.update(vars(args))
+
+    fix_random_seed(params.seed)
+    if world_size > 1:
+        setup_dist(rank, world_size, params.master_port)
+
+    setup_logger(f"{params.exp_dir}/log/log-train")
+    logging.info("Training started")
+
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    #  is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("")
+    params.vocab_size = sp.get_piece_size()
+
+    if params.dynamic_chunk_training:
+        assert (
+            params.causal_convolution
+        ), "dynamic_chunk_training requires causal convolution"
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    assert params.save_every_n >= params.average_period
+    model_avg: Optional[nn.Module] = None
+    if rank == 0:
+        # model_avg is only used with rank 0
+        model_avg = copy.deepcopy(model)
+
+    assert params.start_epoch > 0, params.start_epoch
+    checkpoints = load_checkpoint_if_available(
+        params=params, model=model, model_avg=model_avg
+    )
+
+    model.to(device)
+    if world_size > 1:
+        logging.info("Using DDP")
+        model = DDP(model, device_ids=[rank])
+
+    optimizer = Eve(model.parameters(), lr=params.initial_lr)
+
+    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+    if checkpoints and "optimizer" in checkpoints:
+        logging.info("Loading optimizer state dict")
+        optimizer.load_state_dict(checkpoints["optimizer"])
+
+    if (
+        checkpoints
+        and "scheduler" in checkpoints
+        and checkpoints["scheduler"] is not None
+    ):
+        logging.info("Loading scheduler state dict")
+        scheduler.load_state_dict(checkpoints["scheduler"])
+
+    if params.print_diagnostics:
+        opts = diagnostics.TensorDiagnosticOptions(
+            2**22
+        )  # allow 4 megabytes per sub-module
+        diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+    xbmu_amdo = Xbmu_AmdoAsrDataModule(args)
+
+    train_cuts = xbmu_amdo.train_cuts()
+
+    def remove_short_and_long_utt(c: Cut):
+        # Keep only utterances with duration between 1 second and 20 seconds
+        #
+        # Caution: There is a reason to select 20.0 here. Please see
+        # ../local/display_manifest_statistics.py
+        #
+        # You should use ../local/display_manifest_statistics.py to get
+        # an utterance duration distribution for your dataset to select
+        # the threshold
+        if c.duration < 1.0 or c.duration > 20.0:
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+            )
+            return False
+
+        # In pruned RNN-T, we require that T >= S
+        # where T is the number of feature frames after subsampling
+        # and S is the number of tokens in the utterance
+
+        # In ./conformer.py, the conv module uses the following expression
+        # for subsampling
+        T = ((c.num_frames - 1) // 2 - 1) // 2
+        tokens = sp.encode(c.supervisions[0].text, out_type=str)
+
+        if T < len(tokens):
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. "
+                f"Number of frames (before subsampling): {c.num_frames}. "
+                f"Number of frames (after subsampling): {T}. "
+                f"Text: {c.supervisions[0].text}. "
+                f"Tokens: {tokens}. "
+                f"Number of tokens: {len(tokens)}"
+            )
+            return False
+
+        return True
+
+    train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+    if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+        # We only load the sampler's state dict when it loads a checkpoint
+        # saved in the middle of an epoch
+        sampler_state_dict = checkpoints["sampler"]
+    else:
+        sampler_state_dict = None
+
+    train_dl = xbmu_amdo.train_dataloaders(
+        train_cuts, sampler_state_dict=sampler_state_dict
+    )
+
+    valid_cuts = xbmu_amdo.valid_cuts()
+    valid_dl = xbmu_amdo.valid_dataloaders(valid_cuts)
+
+    if params.start_batch <= 0 and not params.print_diagnostics:
+        scan_pessimistic_batches_for_oom(
+            model=model,
+            train_dl=train_dl,
+            optimizer=optimizer,
+            sp=sp,
+            params=params,
+            warmup=0.0 if params.start_epoch == 1 else 1.0,
+        )
+
+    scaler = GradScaler(enabled=params.use_fp16)
+    if checkpoints and "grad_scaler" in checkpoints:
+        logging.info("Loading grad scaler state dict")
+        scaler.load_state_dict(checkpoints["grad_scaler"])
+
+    for epoch in range(params.start_epoch, params.num_epochs + 1):
+        scheduler.step_epoch(epoch - 1)
+        fix_random_seed(params.seed + epoch - 1)
+        train_dl.sampler.set_epoch(epoch - 1)
+
+        if tb_writer is not None:
+            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+        params.cur_epoch = epoch
+
+        train_one_epoch(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            sp=sp,
+            train_dl=train_dl,
+            valid_dl=valid_dl,
+            scaler=scaler,
+            tb_writer=tb_writer,
+            world_size=world_size,
+            rank=rank,
+        )
+
+        if params.print_diagnostics:
+            diagnostic.print_diagnostics()
+            break
+
+        save_checkpoint(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            sampler=train_dl.sampler,
+            scaler=scaler,
+            rank=rank,
+        )
+
+    logging.info("Done!")
+
+    if world_size > 1:
+        torch.distributed.barrier()
+        cleanup_dist()
+
+
+def scan_pessimistic_batches_for_oom(
+    model: Union[nn.Module, DDP],
+    train_dl: torch.utils.data.DataLoader,
+    optimizer: torch.optim.Optimizer,
+    sp: spm.SentencePieceProcessor,
+    params: AttributeDict,
+    warmup: float,
+):
+    from lhotse.dataset import find_pessimistic_batches
+
+    logging.info(
+        "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+    )
+    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+    for criterion, cuts in batches.items():
+        batch = train_dl.dataset[cuts]
+        try:
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                loss, _ = compute_loss(
+                    params=params,
+                    model=model,
+                    sp=sp,
+                    batch=batch,
+                    is_training=True,
+                    warmup=warmup,
+                )
+            loss.backward()
+            optimizer.step()
+            optimizer.zero_grad()
+        except Exception as e:
+            if "CUDA out of memory" in str(e):
+                logging.error(
+                    "Your GPU ran out of memory with the current "
+                    "max_duration setting. We recommend decreasing "
+                    "max_duration and trying again.\n"
+                    f"Failing criterion: {criterion} "
+                    f"(={crit_values[criterion]}) ..."
+                )
+            display_and_save_batch(batch, params=params, sp=sp)
+            raise
+
+
+def main():
+    parser = get_parser()
+    Xbmu_AmdoAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    world_size = args.world_size
+    assert world_size >= 1
+    if world_size > 1:
+        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+    else:
+        run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/__init__.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/asr_datamodule.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/asr_datamodule.py
new file mode 120000
index 000000000..c473a600a
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/asr_datamodule.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless5/asr_datamodule.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/beam_search.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/beam_search.py
new file mode 120000
index 000000000..e24eca39f
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/beam_search.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decode.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decode.py
new file mode 100755
index 000000000..ace792e13
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decode.py
@@ -0,0 +1,843 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
+#                                                 Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) greedy search
+./pruned_transducer_stateless7/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless7/exp \
+    --max-duration 600 \
+    --decoding-method greedy_search
+
+(2) beam search (not recommended)
+./pruned_transducer_stateless7/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless7/exp \
+    --max-duration 600 \
+    --decoding-method beam_search \
+    --beam-size 4
+
+(3) modified beam search
+./pruned_transducer_stateless7/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless7/exp \
+    --max-duration 600 \
+    --decoding-method modified_beam_search \
+    --beam-size 4
+
+(4) fast beam search (one best)
+./pruned_transducer_stateless7/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless7/exp \
+    --max-duration 600 \
+    --decoding-method fast_beam_search \
+    --beam 20.0 \
+    --max-contexts 8 \
+    --max-states 64
+
+(5) fast beam search (nbest)
+./pruned_transducer_stateless7/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless7/exp \
+    --max-duration 600 \
+    --decoding-method fast_beam_search_nbest \
+    --beam 20.0 \
+    --max-contexts 8 \
+    --max-states 64 \
+    --num-paths 200 \
+    --nbest-scale 0.5
+
+(6) fast beam search (nbest oracle WER)
+./pruned_transducer_stateless7/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless7/exp \
+    --max-duration 600 \
+    --decoding-method fast_beam_search_nbest_oracle \
+    --beam 20.0 \
+    --max-contexts 8 \
+    --max-states 64 \
+    --num-paths 200 \
+    --nbest-scale 0.5
+
+(7) fast beam search (with LG)
+./pruned_transducer_stateless7/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless7/exp \
+    --max-duration 600 \
+    --decoding-method fast_beam_search_nbest_LG \
+    --beam 20.0 \
+    --max-contexts 8 \
+    --max-states 64
+"""
+
+
+import argparse
+import logging
+import math
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import Xbmu_AmdoAsrDataModule
+from beam_search import (
+    beam_search,
+    fast_beam_search_nbest,
+    fast_beam_search_nbest_LG,
+    fast_beam_search_nbest_oracle,
+    fast_beam_search_one_best,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+)
+from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall.checkpoint import (
+    average_checkpoints,
+    average_checkpoints_with_averaged_model,
+    find_checkpoints,
+    load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+    AttributeDict,
+    setup_logger,
+    store_transcripts,
+    str2bool,
+    write_error_stats,
+)
+
+LOG_EPS = math.log(1e-10)
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=30,
+        help="""It specifies the checkpoint to use for decoding.
+        Note: Epoch counts from 1.
+        You can specify --avg to use more checkpoints for model averaging.""",
+    )
+
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=9,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
+    )
+
+    parser.add_argument(
+        "--use-averaged-model",
+        type=str2bool,
+        default=True,
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="pruned_transducer_stateless7/exp",
+        help="The experiment dir",
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="data/lang_bpe_500/bpe.model",
+        help="Path to the BPE model",
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=Path,
+        default="data/lang_bpe_500",
+        help="The lang dir containing word table and LG graph",
+    )
+
+    parser.add_argument(
+        "--decoding-method",
+        type=str,
+        default="greedy_search",
+        help="""Possible values are:
+          - greedy_search
+          - beam_search
+          - modified_beam_search
+          - fast_beam_search
+          - fast_beam_search_nbest
+          - fast_beam_search_nbest_oracle
+          - fast_beam_search_nbest_LG
+        If you use fast_beam_search_nbest_LG, you have to specify
+        `--lang-dir`, which should contain `LG.pt`.
+        """,
+    )
+
+    parser.add_argument(
+        "--beam-size",
+        type=int,
+        default=4,
+        help="""An integer indicating how many candidates we will keep for each
+        frame. Used only when --decoding-method is beam_search or
+        modified_beam_search.""",
+    )
+
+    parser.add_argument(
+        "--beam",
+        type=float,
+        default=20.0,
+        help="""A floating point value to calculate the cutoff score during beam
+        search (i.e., `cutoff = max-score - beam`), which is the same as the
+        `beam` in Kaldi.
+        Used only when --decoding-method is fast_beam_search,
+        fast_beam_search_nbest, fast_beam_search_nbest_LG,
+        and fast_beam_search_nbest_oracle
+        """,
+    )
+
+    parser.add_argument(
+        "--ngram-lm-scale",
+        type=float,
+        default=0.01,
+        help="""
+        Used only when --decoding_method is fast_beam_search_nbest_LG.
+        It specifies the scale for n-gram LM scores.
+        """,
+    )
+
+    parser.add_argument(
+        "--max-contexts",
+        type=int,
+        default=8,
+        help="""Used only when --decoding-method is
+        fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+        and fast_beam_search_nbest_oracle""",
+    )
+
+    parser.add_argument(
+        "--max-states",
+        type=int,
+        default=64,
+        help="""Used only when --decoding-method is
+        fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+        and fast_beam_search_nbest_oracle""",
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+    )
+    parser.add_argument(
+        "--max-sym-per-frame",
+        type=int,
+        default=1,
+        help="""Maximum number of symbols per frame.
+        Used only when --decoding_method is greedy_search""",
+    )
+
+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=200,
+        help="""Number of paths for nbest decoding.
+        Used only when the decoding method is fast_beam_search_nbest,
+        fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+    )
+
+    parser.add_argument(
+        "--nbest-scale",
+        type=float,
+        default=0.5,
+        help="""Scale applied to lattice scores when computing nbest paths.
+        Used only when the decoding method is fast_beam_search_nbest,
+        fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+    )
+
+    parser.add_argument(
+        "--simulate-streaming",
+        type=str2bool,
+        default=False,
+        help="""Whether to simulate streaming in decoding, this is a good way to
+        test a streaming model.
+        """,
+    )
+
+    parser.add_argument(
+        "--decode-chunk-size",
+        type=int,
+        default=16,
+        help="The chunk size for decoding (in frames after subsampling)",
+    )
+
+    parser.add_argument(
+        "--left-context",
+        type=int,
+        default=64,
+        help="left context can be seen during decoding (in frames after subsampling)",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def decode_one_batch(
+    params: AttributeDict,
+    model: nn.Module,
+    sp: spm.SentencePieceProcessor,
+    batch: dict,
+    word_table: Optional[k2.SymbolTable] = None,
+    decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[List[str]]]:
+    """Decode one batch and return the result in a dict. The dict has the
+    following format:
+
+        - key: It indicates the setting used for decoding. For example,
+               if greedy_search is used, it would be "greedy_search"
+               If beam search with a beam size of 7 is used, it would be
+               "beam_7"
+        - value: It contains the decoding result. `len(value)` equals to
+                 batch size. `value[i]` is the decoding result for the i-th
+                 utterance in the given batch.
+    Args:
+      params:
+        It's the return value of :func:`get_params`.
+      model:
+        The neural model.
+      sp:
+        The BPE model.
+      batch:
+        It is the return value from iterating
+        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+        for the format of the `batch`.
+      word_table:
+        The word symbol table.
+      decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
+        only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
+        fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
+    Returns:
+      Return the decoding result. See above description for the format of
+      the returned dict.
+    """
+    device = next(model.parameters()).device
+    feature = batch["inputs"]
+    assert feature.ndim == 3
+
+    feature = feature.to(device)
+    # at entry, feature is (N, T, C)
+
+    supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)
+
+    if params.simulate_streaming:
+        feature_lens += params.left_context
+        feature = torch.nn.functional.pad(
+            feature,
+            pad=(0, 0, 0, params.left_context),
+            value=LOG_EPS,
+        )
+        encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
+            x=feature,
+            x_lens=feature_lens,
+            chunk_size=params.decode_chunk_size,
+            left_context=params.left_context,
+            simulate_streaming=True,
+        )
+    else:
+        encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+
+    hyps = []
+
+    if params.decoding_method == "fast_beam_search":
+        hyp_tokens = fast_beam_search_one_best(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.decoding_method == "fast_beam_search_nbest_LG":
+        hyp_tokens = fast_beam_search_nbest_LG(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+            num_paths=params.num_paths,
+            nbest_scale=params.nbest_scale,
+        )
+        for hyp in hyp_tokens:
+            hyps.append([word_table[i] for i in hyp])
+    elif params.decoding_method == "fast_beam_search_nbest":
+        hyp_tokens = fast_beam_search_nbest(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+            num_paths=params.num_paths,
+            nbest_scale=params.nbest_scale,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.decoding_method == "fast_beam_search_nbest_oracle":
+        hyp_tokens = fast_beam_search_nbest_oracle(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+            num_paths=params.num_paths,
+            ref_texts=sp.encode(supervisions["text"]),
+            nbest_scale=params.nbest_scale,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+        hyp_tokens = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.decoding_method == "modified_beam_search":
+        hyp_tokens = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam_size,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    else:
+        batch_size = encoder_out.size(0)
+
+        for i in range(batch_size):
+            # fmt: off
+            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+            # fmt: on
+            if params.decoding_method == "greedy_search":
+                hyp = greedy_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    max_sym_per_frame=params.max_sym_per_frame,
+                )
+            elif params.decoding_method == "beam_search":
+                hyp = beam_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
+                )
+            else:
+                raise ValueError(
+                    f"Unsupported decoding method: {params.decoding_method}"
+                )
+            hyps.append(sp.decode(hyp).split())
+
+    if params.decoding_method == "greedy_search":
+        return {"greedy_search": hyps}
+    elif "fast_beam_search" in params.decoding_method:
+        key = f"beam_{params.beam}_"
+        key += f"max_contexts_{params.max_contexts}_"
+        key += f"max_states_{params.max_states}"
+        if "nbest" in params.decoding_method:
+            key += f"_num_paths_{params.num_paths}_"
+            key += f"nbest_scale_{params.nbest_scale}"
+            if "LG" in params.decoding_method:
+                key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
+
+        return {key: hyps}
+    else:
+        return {f"beam_size_{params.beam_size}": hyps}
+
+
+def decode_dataset(
+    dl: torch.utils.data.DataLoader,
+    params: AttributeDict,
+    model: nn.Module,
+    sp: spm.SentencePieceProcessor,
+    word_table: Optional[k2.SymbolTable] = None,
+    decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+    """Decode dataset.
+
+    Args:
+      dl:
+        PyTorch's dataloader containing the dataset to decode.
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The neural model.
+      sp:
+        The BPE model.
+      word_table:
+        The word symbol table.
+      decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
+        only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
+        fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
+    Returns:
+      Return a dict, whose key may be "greedy_search" if greedy search
+      is used, or it may be "beam_7" if beam size of 7 is used.
+      Its value is a list of tuples. Each tuple contains two elements:
+      The first is the reference transcript, and the second is the
+      predicted result.
+    """
+    num_cuts = 0
+
+    try:
+        num_batches = len(dl)
+    except TypeError:
+        num_batches = "?"
+
+    if params.decoding_method == "greedy_search":
+        log_interval = 50
+    else:
+        log_interval = 20
+
+    results = defaultdict(list)
+    for batch_idx, batch in enumerate(dl):
+        texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+        hyps_dict = decode_one_batch(
+            params=params,
+            model=model,
+            sp=sp,
+            decoding_graph=decoding_graph,
+            word_table=word_table,
+            batch=batch,
+        )
+
+        for name, hyps in hyps_dict.items():
+            this_batch = []
+            assert len(hyps) == len(texts)
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+                ref_words = ref_text.split()
+                this_batch.append((cut_id, ref_words, hyp_words))
+
+            results[name].extend(this_batch)
+
+        num_cuts += len(texts)
+
+        if batch_idx % log_interval == 0:
+            batch_str = f"{batch_idx}/{num_batches}"
+
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+    return results
+
+
+def save_results(
+    params: AttributeDict,
+    test_set_name: str,
+    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
+    test_set_wers = dict()
+    for key, results in results_dict.items():
+        recog_path = (
+            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
+        results = sorted(results)
+        store_transcripts(filename=recog_path, texts=results)
+        logging.info(f"The transcripts are stored in {recog_path}")
+
+        # The following prints out WERs, per-word error statistics and aligned
+        # ref/hyp pairs.
+        errs_filename = (
+            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
+        with open(errs_filename, "w") as f:
+            wer = write_error_stats(
+                f, f"{test_set_name}-{key}", results, enable_log=True
+            )
+            test_set_wers[key] = wer
+
+        logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+    errs_info = (
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+    )
+    with open(errs_info, "w") as f:
+        print("settings\tWER", file=f)
+        for key, val in test_set_wers:
+            print("{}\t{}".format(key, val), file=f)
+
+    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+    note = "\tbest for {}".format(test_set_name)
+    for key, val in test_set_wers:
+        s += "{}\t{}{}\n".format(key, val, note)
+        note = ""
+    logging.info(s)
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    Xbmu_AmdoAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    assert params.decoding_method in (
+        "greedy_search",
+        "beam_search",
+        "fast_beam_search",
+        "fast_beam_search_nbest",
+        "fast_beam_search_nbest_LG",
+        "fast_beam_search_nbest_oracle",
+        "modified_beam_search",
+    )
+    params.res_dir = params.exp_dir / params.decoding_method
+
+    if params.iter > 0:
+        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+    else:
+        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+    if params.simulate_streaming:
+        params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
+        params.suffix += f"-left-context-{params.left_context}"
+
+    if "fast_beam_search" in params.decoding_method:
+        params.suffix += f"-beam-{params.beam}"
+        params.suffix += f"-max-contexts-{params.max_contexts}"
+        params.suffix += f"-max-states-{params.max_states}"
+        if "nbest" in params.decoding_method:
+            params.suffix += f"-nbest-scale-{params.nbest_scale}"
+            params.suffix += f"-num-paths-{params.num_paths}"
+            if "LG" in params.decoding_method:
+                params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
+    elif "beam_search" in params.decoding_method:
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+    else:
+        params.suffix += f"-context-{params.context_size}"
+        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+    if params.use_averaged_model:
+        params.suffix += "-use-averaged-model"
+
+    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+    logging.info("Decoding started")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    #  and  are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("")
+    params.unk_id = sp.piece_to_id("")
+    params.vocab_size = sp.get_piece_size()
+
+    if params.simulate_streaming:
+        assert (
+            params.causal_convolution
+        ), "Decoding in streaming requires causal convolution"
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+
+    model.to(device)
+    model.eval()
+
+    if "fast_beam_search" in params.decoding_method:
+        if params.decoding_method == "fast_beam_search_nbest_LG":
+            lexicon = Lexicon(params.lang_dir)
+            word_table = lexicon.word_table
+            lg_filename = params.lang_dir / "LG.pt"
+            logging.info(f"Loading {lg_filename}")
+            decoding_graph = k2.Fsa.from_dict(
+                torch.load(lg_filename, map_location=device)
+            )
+            decoding_graph.scores *= params.ngram_lm_scale
+        else:
+            word_table = None
+            decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+    else:
+        decoding_graph = None
+        word_table = None
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
+    xbmu_amdo = Xbmu_AmdoAsrDataModule(args)
+
+    test_cuts = xbmu_amdo.test_cuts()
+
+    test_dl = xbmu_amdo.test_dataloaders(test_cuts)
+
+    test_sets = [
+        "test",
+    ]
+    test_dl = [
+        test_dl,
+    ]
+
+    for test_set, test_dl in zip(test_sets, test_dl):
+        results_dict = decode_dataset(
+            dl=test_dl,
+            params=params,
+            model=model,
+            sp=sp,
+            word_table=word_table,
+            decoding_graph=decoding_graph,
+        )
+
+        save_results(
+            params=params,
+            test_set_name=test_set,
+            results_dict=results_dict,
+        )
+
+    logging.info("Done!")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decoder.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decoder.py
new file mode 120000
index 000000000..8283d8c5a
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decoder.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/decoder.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/encoder_interface.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/encoder_interface.py
new file mode 120000
index 000000000..f58253127
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/encoder_interface.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/encoder_interface.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/export.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/export.py
new file mode 120000
index 000000000..2713792e6
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/export.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/export.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/jit_pretrained.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/jit_pretrained.py
new file mode 120000
index 000000000..a44034e34
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/jit_pretrained.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/joiner.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/joiner.py
new file mode 120000
index 000000000..0f0c3c90a
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/joiner.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/joiner.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/model.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/model.py
new file mode 120000
index 000000000..0d8bc665b
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/model.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/optim.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/optim.py
new file mode 120000
index 000000000..8a05abb5f
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/optim.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/optim.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/pretrained.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/pretrained.py
new file mode 100755
index 000000000..d05bafcfb
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/pretrained.py
@@ -0,0 +1,355 @@
+#!/usr/bin/env python3
+# Copyright      2021  Xiaomi Corp.        (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script loads a checkpoint and uses it to decode waves.
+You can generate the checkpoint with the following command:
+
+./pruned_transducer_stateless7/export.py \
+  --exp-dir ./pruned_transducer_stateless7/exp \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --epoch 20 \
+  --avg 10
+
+Usage of this script:
+
+(1) greedy search
+./pruned_transducer_stateless7/pretrained.py \
+    --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method greedy_search \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+
+(2) beam search
+./pruned_transducer_stateless7/pretrained.py \
+    --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method beam_search \
+    --beam-size 4 \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+
+(3) modified beam search
+./pruned_transducer_stateless7/pretrained.py \
+    --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method modified_beam_search \
+    --beam-size 4 \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+
+(4) fast beam search
+./pruned_transducer_stateless7/pretrained.py \
+    --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method fast_beam_search \
+    --beam-size 4 \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+
+You can also use `./pruned_transducer_stateless7/exp/epoch-xx.pt`.
+
+Note: ./pruned_transducer_stateless7/exp/pretrained.pt is generated by
+./pruned_transducer_stateless7/export.py
+"""
+
+
+import argparse
+import logging
+import math
+from typing import List
+
+import k2
+import kaldifeat
+import sentencepiece as spm
+import torch
+import torchaudio
+from beam_search import (
+    beam_search,
+    fast_beam_search_one_best,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+)
+from torch.nn.utils.rnn import pad_sequence
+from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall.utils import str2bool
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--checkpoint",
+        type=str,
+        required=True,
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        help="""Path to bpe.model.""",
+    )
+
+    parser.add_argument(
+        "--method",
+        type=str,
+        default="greedy_search",
+        help="""Possible values are:
+          - greedy_search
+          - beam_search
+          - modified_beam_search
+          - fast_beam_search
+        """,
+    )
+
+    parser.add_argument(
+        "sound_files",
+        type=str,
+        nargs="+",
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
+    )
+
+    parser.add_argument(
+        "--sample-rate",
+        type=int,
+        default=16000,
+        help="The sample rate of the input sound file",
+    )
+
+    parser.add_argument(
+        "--beam-size",
+        type=int,
+        default=4,
+        help="""An integer indicating how many candidates we will keep for each
+        frame. Used only when --method is beam_search or
+        modified_beam_search.""",
+    )
+
+    parser.add_argument(
+        "--beam",
+        type=float,
+        default=4,
+        help="""A floating point value to calculate the cutoff score during beam
+        search (i.e., `cutoff = max-score - beam`), which is the same as the
+        `beam` in Kaldi.
+        Used only when --method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--max-contexts",
+        type=int,
+        default=4,
+        help="""Used only when --method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--max-states",
+        type=int,
+        default=8,
+        help="""Used only when --method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+    )
+    parser.add_argument(
+        "--max-sym-per-frame",
+        type=int,
+        default=1,
+        help="""Maximum number of symbols per frame. Used only when
+        --method is greedy_search.
+        """,
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def read_sound_files(
+    filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+    """Read a list of sound files into a list 1-D float32 torch tensors.
+    Args:
+      filenames:
+        A list of sound filenames.
+      expected_sample_rate:
+        The expected sample rate of the sound files.
+    Returns:
+      Return a list of 1-D float32 torch tensors.
+    """
+    ans = []
+    for f in filenames:
+        wave, sample_rate = torchaudio.load(f)
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        # We use only the first channel
+        ans.append(wave[0])
+    return ans
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+
+    params = get_params()
+
+    params.update(vars(args))
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    #  is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("")
+    params.unk_id = sp.piece_to_id("")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(f"{params}")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    logging.info("Creating model")
+    model = get_transducer_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    model.load_state_dict(checkpoint["model"], strict=False)
+    model.to(device)
+    model.eval()
+    model.device = device
+
+    logging.info("Constructing Fbank computer")
+    opts = kaldifeat.FbankOptions()
+    opts.device = device
+    opts.frame_opts.dither = 0
+    opts.frame_opts.snip_edges = False
+    opts.frame_opts.samp_freq = params.sample_rate
+    opts.mel_opts.num_bins = params.feature_dim
+
+    fbank = kaldifeat.Fbank(opts)
+
+    logging.info(f"Reading sound files: {params.sound_files}")
+    waves = read_sound_files(
+        filenames=params.sound_files, expected_sample_rate=params.sample_rate
+    )
+    waves = [w.to(device) for w in waves]
+
+    logging.info("Decoding started")
+    features = fbank(waves)
+    feature_lengths = [f.size(0) for f in features]
+
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+
+    feature_lengths = torch.tensor(feature_lengths, device=device)
+
+    encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
+
+    num_waves = encoder_out.size(0)
+    hyps = []
+    msg = f"Using {params.method}"
+    if params.method == "beam_search":
+        msg += f" with beam size {params.beam_size}"
+    logging.info(msg)
+
+    if params.method == "fast_beam_search":
+        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+        hyp_tokens = fast_beam_search_one_best(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.method == "modified_beam_search":
+        hyp_tokens = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam_size,
+        )
+
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
+        hyp_tokens = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    else:
+        for i in range(num_waves):
+            # fmt: off
+            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+            # fmt: on
+            if params.method == "greedy_search":
+                hyp = greedy_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    max_sym_per_frame=params.max_sym_per_frame,
+                )
+            elif params.method == "beam_search":
+                hyp = beam_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
+                )
+            else:
+                raise ValueError(f"Unsupported method: {params.method}")
+
+            hyps.append(sp.decode(hyp).split())
+
+    s = "\n"
+    for filename, hyp in zip(params.sound_files, hyps):
+        words = " ".join(hyp)
+        s += f"{filename}:\n{words}\n\n"
+    logging.info(s)
+
+    logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/scaling.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/scaling.py
new file mode 120000
index 000000000..5f9be9fe0
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/scaling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/scaling.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/scaling_converter.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/scaling_converter.py
new file mode 120000
index 000000000..f9960e5c6
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/scaling_converter.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/test_model.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/test_model.py
new file mode 120000
index 000000000..7ceac5d10
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/test_model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/test_model.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py
new file mode 100755
index 000000000..1332bafd8
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py
@@ -0,0 +1,1224 @@
+#!/usr/bin/env python3
+# Copyright    2021-2022  Xiaomi Corp.        (authors: Fangjun Kuang,
+#                                                       Wei Kang,
+#                                                       Mingshuang Luo,)
+#                                                       Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./pruned_transducer_stateless7/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --exp-dir pruned_transducer_stateless7/exp \
+  --full-libri 1 \
+  --max-duration 300
+
+# For mix precision training:
+
+./pruned_transducer_stateless7/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --use-fp16 1 \
+  --exp-dir pruned_transducer_stateless7/exp \
+  --full-libri 1 \
+  --max-duration 550
+
+"""
+
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import Xbmu_AmdoAsrDataModule
+from decoder import Decoder
+from joiner import Joiner
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import Transducer
+from optim import Eden, ScaledAdam
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+from zipformer import Zipformer
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+    save_checkpoint_with_global_batch_idx,
+    update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+    if isinstance(model, DDP):
+        # get underlying nn.Module
+        model = model.module
+    for module in model.modules():
+        if hasattr(module, "batch_count"):
+            module.batch_count = batch_count
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+    parser.add_argument(
+        "--num-encoder-layers",
+        type=str,
+        default="2,4,3,2,4",
+        help="Number of zipformer encoder layers, comma separated.",
+    )
+
+    parser.add_argument(
+        "--feedforward-dims",
+        type=str,
+        default="1024,1024,2048,2048,1024",
+        help="Feedforward dimension of the zipformer encoder layers, comma separated.",
+    )
+
+    parser.add_argument(
+        "--nhead",
+        type=str,
+        default="8,8,8,8,8",
+        help="Number of attention heads in the zipformer encoder layers.",
+    )
+
+    parser.add_argument(
+        "--encoder-dims",
+        type=str,
+        default="384,384,384,384,384",
+        help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated",
+    )
+
+    parser.add_argument(
+        "--attention-dims",
+        type=str,
+        default="192,192,192,192,192",
+        help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated;
+        not the same as embedding dimension.""",
+    )
+
+    parser.add_argument(
+        "--encoder-unmasked-dims",
+        type=str,
+        default="256,256,256,256,256",
+        help="Unmasked dimensions in the encoders, relates to augmentation during training.  "
+        "Must be <= each of encoder_dims.  Empirically, less than 256 seems to make performance "
+        " worse.",
+    )
+
+    parser.add_argument(
+        "--zipformer-downsampling-factors",
+        type=str,
+        default="1,2,4,8,2",
+        help="Downsampling factor for each stack of encoder layers.",
+    )
+
+    parser.add_argument(
+        "--cnn-module-kernels",
+        type=str,
+        default="31,31,31,31,31",
+        help="Sizes of kernels in convolution modules",
+    )
+
+    parser.add_argument(
+        "--decoder-dim",
+        type=int,
+        default=512,
+        help="Embedding dimension in the decoder model.",
+    )
+
+    parser.add_argument(
+        "--joiner-dim",
+        type=int,
+        default=512,
+        help="""Dimension used in the joiner model.
+        Outputs from the encoder and decoder model are projected
+        to this dimension before adding.
+        """,
+    )
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--world-size",
+        type=int,
+        default=1,
+        help="Number of GPUs for DDP training.",
+    )
+
+    parser.add_argument(
+        "--master-port",
+        type=int,
+        default=12354,
+        help="Master port to use for DDP training.",
+    )
+
+    parser.add_argument(
+        "--tensorboard",
+        type=str2bool,
+        default=True,
+        help="Should various information be logged in tensorboard.",
+    )
+
+    parser.add_argument(
+        "--num-epochs",
+        type=int,
+        default=30,
+        help="Number of epochs to train.",
+    )
+
+    parser.add_argument(
+        "--start-epoch",
+        type=int,
+        default=1,
+        help="""Resume training from this epoch. It should be positive.
+        If larger than 1, it will load checkpoint from
+        exp-dir/epoch-{start_epoch-1}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--start-batch",
+        type=int,
+        default=0,
+        help="""If positive, --start-epoch is ignored and
+        it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="pruned_transducer_stateless7/exp",
+        help="""The experiment dir.
+        It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="data/lang_bpe_500/bpe.model",
+        help="Path to the BPE model",
+    )
+
+    parser.add_argument(
+        "--base-lr", type=float, default=0.05, help="The base learning rate."
+    )
+
+    parser.add_argument(
+        "--lr-batches",
+        type=float,
+        default=5000,
+        help="""Number of steps that affects how rapidly the learning rate
+        decreases. We suggest not to change this.""",
+    )
+
+    parser.add_argument(
+        "--lr-epochs",
+        type=float,
+        default=3.5,
+        help="""Number of epochs that affects how rapidly the learning rate decreases.
+        """,
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+    )
+
+    parser.add_argument(
+        "--prune-range",
+        type=int,
+        default=5,
+        help="The prune range for rnnt loss, it means how many symbols(context)"
+        "we are using to compute the loss",
+    )
+
+    parser.add_argument(
+        "--lm-scale",
+        type=float,
+        default=0.25,
+        help="The scale to smooth the loss with lm "
+        "(output of prediction network) part.",
+    )
+
+    parser.add_argument(
+        "--am-scale",
+        type=float,
+        default=0.0,
+        help="The scale to smooth the loss with am (output of encoder network) part.",
+    )
+
+    parser.add_argument(
+        "--simple-loss-scale",
+        type=float,
+        default=0.5,
+        help="To get pruning ranges, we will calculate a simple version"
+        "loss(joiner is just addition), this simple loss also uses for"
+        "training (as a regularization item). We will scale the simple loss"
+        "with this parameter before adding to the final loss.",
+    )
+
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="The seed for random generators intended for reproducibility",
+    )
+
+    parser.add_argument(
+        "--print-diagnostics",
+        type=str2bool,
+        default=False,
+        help="Accumulate stats on activations, print them and exit.",
+    )
+
+    parser.add_argument(
+        "--inf-check",
+        type=str2bool,
+        default=False,
+        help="Add hooks to check for infinite module outputs and gradients.",
+    )
+
+    parser.add_argument(
+        "--save-every-n",
+        type=int,
+        default=2000,
+        help="""Save checkpoint after processing this number of batches"
+        periodically. We save checkpoint to exp-dir/ whenever
+        params.batch_idx_train % save_every_n == 0. The checkpoint filename
+        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+        Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+        end of each epoch where `xxx` is the epoch number counting from 0.
+        """,
+    )
+
+    parser.add_argument(
+        "--keep-last-k",
+        type=int,
+        default=30,
+        help="""Only keep this number of checkpoints on disk.
+        For instance, if it is 3, there are only 3 checkpoints
+        in the exp-dir with filenames `checkpoint-xxx.pt`.
+        It does not affect checkpoints with name `epoch-xxx.pt`.
+        """,
+    )
+
+    parser.add_argument(
+        "--average-period",
+        type=int,
+        default=200,
+        help="""Update the averaged model, namely `model_avg`, after processing
+        this number of batches. `model_avg` is a separate version of model,
+        in which each floating-point parameter is the average of all the
+        parameters from the start of training. Each time we take the average,
+        we do: `model_avg = model * (average_period / batch_idx_train) +
+            model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+        """,
+    )
+
+    parser.add_argument(
+        "--use-fp16",
+        type=str2bool,
+        default=False,
+        help="Whether to use half precision training.",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def get_params() -> AttributeDict:
+    """Return a dict containing training parameters.
+
+    All training related parameters that are not passed from the commandline
+    are saved in the variable `params`.
+
+    Commandline options are merged into `params` after they are parsed, so
+    you can also access them via `params`.
+
+    Explanation of options saved in `params`:
+
+        - best_train_loss: Best training loss so far. It is used to select
+                           the model that has the lowest training loss. It is
+                           updated during the training.
+
+        - best_valid_loss: Best validation loss so far. It is used to select
+                           the model that has the lowest validation loss. It is
+                           updated during the training.
+
+        - best_train_epoch: It is the epoch that has the best training loss.
+
+        - best_valid_epoch: It is the epoch that has the best validation loss.
+
+        - batch_idx_train: Used to writing statistics to tensorboard. It
+                           contains number of batches trained so far across
+                           epochs.
+
+        - log_interval:  Print training loss if batch_idx % log_interval` is 0
+
+        - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+        - valid_interval:  Run validation if batch_idx % valid_interval is 0
+
+        - feature_dim: The model input dim. It has to match the one used
+                       in computing features.
+
+        - subsampling_factor:  The subsampling factor for the model.
+
+        - encoder_dim: Hidden dim for multi-head attention model.
+
+        - num_decoder_layers: Number of decoder layer of transformer decoder.
+
+        - warm_step: The warmup period that dictates the decay of the
+              scale on "simple" (un-pruned) loss.
+    """
+    params = AttributeDict(
+        {
+            "best_train_loss": float("inf"),
+            "best_valid_loss": float("inf"),
+            "best_train_epoch": -1,
+            "best_valid_epoch": -1,
+            "batch_idx_train": 0,
+            "log_interval": 50,
+            "reset_interval": 200,
+            "valid_interval": 3000,  # For the 100h subset, use 800
+            # parameters for zipformer
+            "feature_dim": 80,
+            "subsampling_factor": 4,  # not passed in, this is fixed.
+            "warm_step": 2000,
+            "env_info": get_env_info(),
+        }
+    )
+
+    return params
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+    # TODO: We can add an option to switch between Zipformer and Transformer
+    def to_int_tuple(s: str):
+        return tuple(map(int, s.split(",")))
+
+    encoder = Zipformer(
+        num_features=params.feature_dim,
+        output_downsampling_factor=2,
+        zipformer_downsampling_factors=to_int_tuple(
+            params.zipformer_downsampling_factors
+        ),
+        encoder_dims=to_int_tuple(params.encoder_dims),
+        attention_dim=to_int_tuple(params.attention_dims),
+        encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims),
+        nhead=to_int_tuple(params.nhead),
+        feedforward_dim=to_int_tuple(params.feedforward_dims),
+        cnn_module_kernels=to_int_tuple(params.cnn_module_kernels),
+        num_encoder_layers=to_int_tuple(params.num_encoder_layers),
+    )
+    return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+    decoder = Decoder(
+        vocab_size=params.vocab_size,
+        decoder_dim=params.decoder_dim,
+        blank_id=params.blank_id,
+        context_size=params.context_size,
+    )
+    return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+    joiner = Joiner(
+        encoder_dim=int(params.encoder_dims.split(",")[-1]),
+        decoder_dim=params.decoder_dim,
+        joiner_dim=params.joiner_dim,
+        vocab_size=params.vocab_size,
+    )
+    return joiner
+
+
+def get_transducer_model(params: AttributeDict) -> nn.Module:
+    encoder = get_encoder_model(params)
+    decoder = get_decoder_model(params)
+    joiner = get_joiner_model(params)
+
+    model = Transducer(
+        encoder=encoder,
+        decoder=decoder,
+        joiner=joiner,
+        encoder_dim=int(params.encoder_dims.split(",")[-1]),
+        decoder_dim=params.decoder_dim,
+        joiner_dim=params.joiner_dim,
+        vocab_size=params.vocab_size,
+    )
+    return model
+
+
+def load_checkpoint_if_available(
+    params: AttributeDict,
+    model: nn.Module,
+    model_avg: nn.Module = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+    """Load checkpoint from file.
+
+    If params.start_batch is positive, it will load the checkpoint from
+    `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+    params.start_epoch is larger than 1, it will load the checkpoint from
+    `params.start_epoch - 1`.
+
+    Apart from loading state dict for `model` and `optimizer` it also updates
+    `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+    and `best_valid_loss` in `params`.
+
+    Args:
+      params:
+        The return value of :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer that we are using.
+      scheduler:
+        The scheduler that we are using.
+    Returns:
+      Return a dict containing previously saved training info.
+    """
+    if params.start_batch > 0:
+        filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+    elif params.start_epoch > 1:
+        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+    else:
+        return None
+
+    assert filename.is_file(), f"{filename} does not exist!"
+
+    saved_params = load_checkpoint(
+        filename,
+        model=model,
+        model_avg=model_avg,
+        optimizer=optimizer,
+        scheduler=scheduler,
+    )
+
+    keys = [
+        "best_train_epoch",
+        "best_valid_epoch",
+        "batch_idx_train",
+        "best_train_loss",
+        "best_valid_loss",
+    ]
+    for k in keys:
+        params[k] = saved_params[k]
+
+    if params.start_batch > 0:
+        if "cur_epoch" in saved_params:
+            params["start_epoch"] = saved_params["cur_epoch"]
+
+        if "cur_batch_idx" in saved_params:
+            params["cur_batch_idx"] = saved_params["cur_batch_idx"]
+
+    return saved_params
+
+
+def save_checkpoint(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    model_avg: Optional[nn.Module] = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+    sampler: Optional[CutSampler] = None,
+    scaler: Optional[GradScaler] = None,
+    rank: int = 0,
+) -> None:
+    """Save model, optimizer, scheduler and training stats to file.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer used in the training.
+      sampler:
+       The sampler for the training dataset.
+      scaler:
+        The scaler used for mix precision training.
+    """
+    if rank != 0:
+        return
+    filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+    save_checkpoint_impl(
+        filename=filename,
+        model=model,
+        model_avg=model_avg,
+        params=params,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        sampler=sampler,
+        scaler=scaler,
+        rank=rank,
+    )
+
+    if params.best_train_epoch == params.cur_epoch:
+        best_train_filename = params.exp_dir / "best-train-loss.pt"
+        copyfile(src=filename, dst=best_train_filename)
+
+    if params.best_valid_epoch == params.cur_epoch:
+        best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+        copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    sp: spm.SentencePieceProcessor,
+    batch: dict,
+    is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+    """
+    Compute transducer loss given the model and its inputs.
+
+    Args:
+      params:
+        Parameters for training. See :func:`get_params`.
+      model:
+        The model for training. It is an instance of Zipformer in our case.
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      is_training:
+        True for training. False for validation. When it is True, this
+        function enables autograd during computation; when it is False, it
+        disables autograd.
+     warmup: a floating point value which increases throughout training;
+        values >= 1.0 are fully warmed up and have all modules present.
+    """
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+    feature = batch["inputs"]
+    # at entry, feature is (N, T, C)
+    assert feature.ndim == 3
+    feature = feature.to(device)
+
+    supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)
+
+    batch_idx_train = params.batch_idx_train
+    warm_step = params.warm_step
+
+    texts = batch["supervisions"]["text"]
+    y = sp.encode(texts, out_type=int)
+    y = k2.RaggedTensor(y).to(device)
+
+    with torch.set_grad_enabled(is_training):
+        simple_loss, pruned_loss = model(
+            x=feature,
+            x_lens=feature_lens,
+            y=y,
+            prune_range=params.prune_range,
+            am_scale=params.am_scale,
+            lm_scale=params.lm_scale,
+        )
+
+        s = params.simple_loss_scale
+        # take down the scale on the simple loss from 1.0 at the start
+        # to params.simple_loss scale by warm_step.
+        simple_loss_scale = (
+            s
+            if batch_idx_train >= warm_step
+            else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
+        )
+        pruned_loss_scale = (
+            1.0
+            if batch_idx_train >= warm_step
+            else 0.1 + 0.9 * (batch_idx_train / warm_step)
+        )
+
+        loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+
+    assert loss.requires_grad == is_training
+
+    info = MetricsTracker()
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+
+    # Note: We use reduction=sum while computing the loss.
+    info["loss"] = loss.detach().cpu().item()
+    info["simple_loss"] = simple_loss.detach().cpu().item()
+    info["pruned_loss"] = pruned_loss.detach().cpu().item()
+
+    return loss, info
+
+
+def compute_validation_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    sp: spm.SentencePieceProcessor,
+    valid_dl: torch.utils.data.DataLoader,
+    world_size: int = 1,
+) -> MetricsTracker:
+    """Run the validation process."""
+    model.eval()
+
+    tot_loss = MetricsTracker()
+
+    for batch_idx, batch in enumerate(valid_dl):
+        loss, loss_info = compute_loss(
+            params=params,
+            model=model,
+            sp=sp,
+            batch=batch,
+            is_training=False,
+        )
+        assert loss.requires_grad is False
+        tot_loss = tot_loss + loss_info
+
+    if world_size > 1:
+        tot_loss.reduce(loss.device)
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    if loss_value < params.best_valid_loss:
+        params.best_valid_epoch = params.cur_epoch
+        params.best_valid_loss = loss_value
+
+    return tot_loss
+
+
+def train_one_epoch(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    optimizer: torch.optim.Optimizer,
+    scheduler: LRSchedulerType,
+    sp: spm.SentencePieceProcessor,
+    train_dl: torch.utils.data.DataLoader,
+    valid_dl: torch.utils.data.DataLoader,
+    scaler: GradScaler,
+    model_avg: Optional[nn.Module] = None,
+    tb_writer: Optional[SummaryWriter] = None,
+    world_size: int = 1,
+    rank: int = 0,
+) -> None:
+    """Train the model for one epoch.
+
+    The training loss from the mean of all frames is saved in
+    `params.train_loss`. It runs the validation process every
+    `params.valid_interval` batches.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The model for training.
+      optimizer:
+        The optimizer we are using.
+      scheduler:
+        The learning rate scheduler, we call step() every step.
+      train_dl:
+        Dataloader for the training dataset.
+      valid_dl:
+        Dataloader for the validation dataset.
+      scaler:
+        The scaler used for mix precision training.
+      model_avg:
+        The stored model averaged from the start of training.
+      tb_writer:
+        Writer to write log messages to tensorboard.
+      world_size:
+        Number of nodes in DDP training. If it is 1, DDP is disabled.
+      rank:
+        The rank of the node in DDP training. If no DDP is used, it should
+        be set to 0.
+    """
+    model.train()
+
+    tot_loss = MetricsTracker()
+
+    cur_batch_idx = params.get("cur_batch_idx", 0)
+
+    for batch_idx, batch in enumerate(train_dl):
+        if batch_idx < cur_batch_idx:
+            continue
+        cur_batch_idx = batch_idx
+
+        params.batch_idx_train += 1
+        batch_size = len(batch["supervisions"]["text"])
+
+        try:
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                loss, loss_info = compute_loss(
+                    params=params,
+                    model=model,
+                    sp=sp,
+                    batch=batch,
+                    is_training=True,
+                )
+            # summary stats
+            tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+            # NOTE: We use reduction==sum and loss is computed over utterances
+            # in the batch and there is no normalization to it so far.
+            scaler.scale(loss).backward()
+            set_batch_count(model, params.batch_idx_train)
+            scheduler.step_batch(params.batch_idx_train)
+
+            scaler.step(optimizer)
+            scaler.update()
+            optimizer.zero_grad()
+        except:  # noqa
+            display_and_save_batch(batch, params=params, sp=sp)
+            raise
+
+        if params.print_diagnostics and batch_idx == 5:
+            return
+
+        if (
+            rank == 0
+            and params.batch_idx_train > 0
+            and params.batch_idx_train % params.average_period == 0
+        ):
+            update_averaged_model(
+                params=params,
+                model_cur=model,
+                model_avg=model_avg,
+            )
+
+        if (
+            params.batch_idx_train > 0
+            and params.batch_idx_train % params.save_every_n == 0
+        ):
+            params.cur_batch_idx = batch_idx
+            save_checkpoint_with_global_batch_idx(
+                out_dir=params.exp_dir,
+                global_batch_idx=params.batch_idx_train,
+                model=model,
+                model_avg=model_avg,
+                params=params,
+                optimizer=optimizer,
+                scheduler=scheduler,
+                sampler=train_dl.sampler,
+                scaler=scaler,
+                rank=rank,
+            )
+            del params.cur_batch_idx
+            remove_checkpoints(
+                out_dir=params.exp_dir,
+                topk=params.keep_last_k,
+                rank=rank,
+            )
+
+        if batch_idx % 100 == 0 and params.use_fp16:
+            # If the grad scale was less than 1, try increasing it.    The _growth_interval
+            # of the grad scaler is configurable, but we can't configure it to have different
+            # behavior depending on the current grad scale.
+            cur_grad_scale = scaler._scale.item()
+            if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
+                scaler.update(cur_grad_scale * 2.0)
+            if cur_grad_scale < 0.01:
+                logging.warning(f"Grad scale is small: {cur_grad_scale}")
+            if cur_grad_scale < 1.0e-05:
+                raise RuntimeError(
+                    f"grad_scale is too small, exiting: {cur_grad_scale}"
+                )
+
+        if batch_idx % params.log_interval == 0:
+            cur_lr = scheduler.get_last_lr()[0]
+            cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+            logging.info(
+                f"Epoch {params.cur_epoch}, "
+                f"batch {batch_idx}, loss[{loss_info}], "
+                f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+                f"lr: {cur_lr:.2e}, "
+                + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+            )
+
+            if tb_writer is not None:
+                tb_writer.add_scalar(
+                    "train/learning_rate", cur_lr, params.batch_idx_train
+                )
+
+                loss_info.write_summary(
+                    tb_writer, "train/current_", params.batch_idx_train
+                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                if params.use_fp16:
+                    tb_writer.add_scalar(
+                        "train/grad_scale",
+                        cur_grad_scale,
+                        params.batch_idx_train,
+                    )
+
+        if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+            logging.info("Computing validation loss")
+            valid_info = compute_validation_loss(
+                params=params,
+                model=model,
+                sp=sp,
+                valid_dl=valid_dl,
+                world_size=world_size,
+            )
+            model.train()
+            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+            logging.info(
+                f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+            )
+            if tb_writer is not None:
+                valid_info.write_summary(
+                    tb_writer, "train/valid_", params.batch_idx_train
+                )
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    params.train_loss = loss_value
+    if params.train_loss < params.best_train_loss:
+        params.best_train_epoch = params.cur_epoch
+        params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+    """
+    Args:
+      rank:
+        It is a value between 0 and `world_size-1`, which is
+        passed automatically by `mp.spawn()` in :func:`main`.
+        The node with rank 0 is responsible for saving checkpoint.
+      world_size:
+        Number of GPUs for DDP training.
+      args:
+        The return value of get_parser().parse_args()
+    """
+    params = get_params()
+    params.update(vars(args))
+
+    fix_random_seed(params.seed)
+    if world_size > 1:
+        setup_dist(rank, world_size, params.master_port)
+
+    setup_logger(f"{params.exp_dir}/log/log-train")
+    logging.info("Training started")
+
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    #  is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    assert params.save_every_n >= params.average_period
+    model_avg: Optional[nn.Module] = None
+    if rank == 0:
+        # model_avg is only used with rank 0
+        model_avg = copy.deepcopy(model).to(torch.float64)
+
+    assert params.start_epoch > 0, params.start_epoch
+    checkpoints = load_checkpoint_if_available(
+        params=params, model=model, model_avg=model_avg
+    )
+
+    model.to(device)
+    if world_size > 1:
+        logging.info("Using DDP")
+        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+    optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=2.0)
+
+    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+    if checkpoints and "optimizer" in checkpoints:
+        logging.info("Loading optimizer state dict")
+        optimizer.load_state_dict(checkpoints["optimizer"])
+
+    if (
+        checkpoints
+        and "scheduler" in checkpoints
+        and checkpoints["scheduler"] is not None
+    ):
+        logging.info("Loading scheduler state dict")
+        scheduler.load_state_dict(checkpoints["scheduler"])
+
+    if params.print_diagnostics:
+        opts = diagnostics.TensorDiagnosticOptions(
+            2**22
+        )  # allow 4 megabytes per sub-module
+        diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+    if params.inf_check:
+        register_inf_check_hooks(model)
+
+    xbmu_amdo = Xbmu_AmdoAsrDataModule(args)
+
+    train_cuts = xbmu_amdo.train_cuts()
+
+    def remove_short_and_long_utt(c: Cut):
+        # Keep only utterances with duration between 1 second and 20 seconds
+        #
+        # Caution: There is a reason to select 20.0 here. Please see
+        # ../local/display_manifest_statistics.py
+        #
+        # You should use ../local/display_manifest_statistics.py to get
+        # an utterance duration distribution for your dataset to select
+        # the threshold
+        if c.duration < 1.0 or c.duration > 20.0:
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+            )
+            return False
+
+        # In pruned RNN-T, we require that T >= S
+        # where T is the number of feature frames after subsampling
+        # and S is the number of tokens in the utterance
+
+        # In ./zipformer.py, the conv module uses the following expression
+        # for subsampling
+        T = ((c.num_frames - 7) // 2 + 1) // 2
+        tokens = sp.encode(c.supervisions[0].text, out_type=str)
+
+        if T < len(tokens):
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. "
+                f"Number of frames (before subsampling): {c.num_frames}. "
+                f"Number of frames (after subsampling): {T}. "
+                f"Text: {c.supervisions[0].text}. "
+                f"Tokens: {tokens}. "
+                f"Number of tokens: {len(tokens)}"
+            )
+            return False
+
+        return True
+
+    train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+    if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+        # We only load the sampler's state dict when it loads a checkpoint
+        # saved in the middle of an epoch
+        sampler_state_dict = checkpoints["sampler"]
+    else:
+        sampler_state_dict = None
+
+    train_dl = xbmu_amdo.train_dataloaders(
+        train_cuts, sampler_state_dict=sampler_state_dict
+    )
+
+    valid_cuts = xbmu_amdo.valid_cuts()
+    valid_dl = xbmu_amdo.valid_dataloaders(valid_cuts)
+
+    if not params.print_diagnostics:
+        scan_pessimistic_batches_for_oom(
+            model=model,
+            train_dl=train_dl,
+            optimizer=optimizer,
+            sp=sp,
+            params=params,
+        )
+
+    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+    if checkpoints and "grad_scaler" in checkpoints:
+        logging.info("Loading grad scaler state dict")
+        scaler.load_state_dict(checkpoints["grad_scaler"])
+
+    for epoch in range(params.start_epoch, params.num_epochs + 1):
+        scheduler.step_epoch(epoch - 1)
+        fix_random_seed(params.seed + epoch - 1)
+        train_dl.sampler.set_epoch(epoch - 1)
+
+        if tb_writer is not None:
+            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+        params.cur_epoch = epoch
+
+        train_one_epoch(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            sp=sp,
+            train_dl=train_dl,
+            valid_dl=valid_dl,
+            scaler=scaler,
+            tb_writer=tb_writer,
+            world_size=world_size,
+            rank=rank,
+        )
+
+        if params.print_diagnostics:
+            diagnostic.print_diagnostics()
+            break
+
+        save_checkpoint(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            sampler=train_dl.sampler,
+            scaler=scaler,
+            rank=rank,
+        )
+
+    logging.info("Done!")
+
+    if world_size > 1:
+        torch.distributed.barrier()
+        cleanup_dist()
+
+
+def display_and_save_batch(
+    batch: dict,
+    params: AttributeDict,
+    sp: spm.SentencePieceProcessor,
+) -> None:
+    """Display the batch statistics and save the batch into disk.
+
+    Args:
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      params:
+        Parameters for training. See :func:`get_params`.
+      sp:
+        The BPE model.
+    """
+    from lhotse.utils import uuid4
+
+    filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+    logging.info(f"Saving batch to {filename}")
+    torch.save(batch, filename)
+
+    supervisions = batch["supervisions"]
+    features = batch["inputs"]
+
+    logging.info(f"features shape: {features.shape}")
+
+    y = sp.encode(supervisions["text"], out_type=int)
+    num_tokens = sum(len(i) for i in y)
+    logging.info(f"num tokens: {num_tokens}")
+
+
+def scan_pessimistic_batches_for_oom(
+    model: Union[nn.Module, DDP],
+    train_dl: torch.utils.data.DataLoader,
+    optimizer: torch.optim.Optimizer,
+    sp: spm.SentencePieceProcessor,
+    params: AttributeDict,
+):
+    from lhotse.dataset import find_pessimistic_batches
+
+    logging.info(
+        "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+    )
+    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+    for criterion, cuts in batches.items():
+        batch = train_dl.dataset[cuts]
+        try:
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                loss, _ = compute_loss(
+                    params=params,
+                    model=model,
+                    sp=sp,
+                    batch=batch,
+                    is_training=True,
+                )
+            loss.backward()
+            optimizer.zero_grad()
+        except Exception as e:
+            if "CUDA out of memory" in str(e):
+                logging.error(
+                    "Your GPU ran out of memory with the current "
+                    "max_duration setting. We recommend decreasing "
+                    "max_duration and trying again.\n"
+                    f"Failing criterion: {criterion} "
+                    f"(={crit_values[criterion]}) ..."
+                )
+            display_and_save_batch(batch, params=params, sp=sp)
+            raise
+        logging.info(
+            f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+        )
+
+
+def main():
+    parser = get_parser()
+    Xbmu_AmdoAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    world_size = args.world_size
+    assert world_size >= 1
+    if world_size > 1:
+        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+    else:
+        run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/zipformer.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/zipformer.py
new file mode 120000
index 000000000..f2f66041e
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/zipformer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/zipformer.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/shared b/egs/xbmu_amdo31/ASR/shared
new file mode 120000
index 000000000..4c5e91438
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/shared
@@ -0,0 +1 @@
+../../../icefall/shared/
\ No newline at end of file

From c25c8c6ad18b8a3d5de2f093947f7b2293eec35a Mon Sep 17 00:00:00 2001
From: Wei Kang 
Date: Sun, 4 Dec 2022 17:20:17 +0800
Subject: [PATCH 034/174] Add need_repeat_flag in phone based ctc graph
 compiler (#727)

* Fix is_repeat_token in icefall

* Fix phone based recipe

* Update egs/librispeech/ASR/conformer_ctc3/train.py

Co-authored-by: Fangjun Kuang 

* Fix black

Co-authored-by: Fangjun Kuang 
---
 egs/librispeech/ASR/conformer_ctc3/train.py |  1 +
 icefall/graph_compiler.py                   | 18 ++++++++++++++----
 2 files changed, 15 insertions(+), 4 deletions(-)

diff --git a/egs/librispeech/ASR/conformer_ctc3/train.py b/egs/librispeech/ASR/conformer_ctc3/train.py
index fb3b740c1..ac489af9e 100755
--- a/egs/librispeech/ASR/conformer_ctc3/train.py
+++ b/egs/librispeech/ASR/conformer_ctc3/train.py
@@ -890,6 +890,7 @@ def run(rank, world_size, args):
         graph_compiler = CtcTrainingGraphCompiler(
             lexicon,
             device=device,
+            need_repeat_flag=params.delay_penalty > 0,
         )
         # Manually add the sos/eos ID with their default values
         # from the BPE recipe which we're adapting here.
diff --git a/icefall/graph_compiler.py b/icefall/graph_compiler.py
index 0dcd777ad..d26ddbbd1 100644
--- a/icefall/graph_compiler.py
+++ b/icefall/graph_compiler.py
@@ -29,6 +29,7 @@ class CtcTrainingGraphCompiler(object):
         lexicon: Lexicon,
         device: torch.device,
         oov: str = "",
+        need_repeat_flag: bool = False,
     ):
         """
         Args:
@@ -39,6 +40,13 @@ class CtcTrainingGraphCompiler(object):
           oov:
             Out of vocabulary word. When a word in the transcript
             does not exist in the lexicon, it is replaced with `oov`.
+          need_repeat_flag:
+            If True, will add an attribute named `_is_repeat_token_` to ctc_topo
+            indicating whether this token is a repeat token in ctc graph.
+            This attribute is needed to implement delay-penalty for phone-based
+            ctc loss. See https://github.com/k2-fsa/k2/pull/1086 for more
+            details. Note: The above change MUST be included in k2 to open this
+            flag.
         """
         L_inv = lexicon.L_inv.to(device)
         assert L_inv.requires_grad is False
@@ -53,6 +61,12 @@ class CtcTrainingGraphCompiler(object):
         ctc_topo = k2.ctc_topo(max_token_id, modified=False)
 
         self.ctc_topo = ctc_topo.to(device)
+
+        if need_repeat_flag:
+            self.ctc_topo._is_repeat_token_ = (
+                self.ctc_topo.labels != self.ctc_topo.aux_labels
+            )
+
         self.device = device
 
     def compile(self, texts: List[str]) -> k2.Fsa:
@@ -79,10 +93,6 @@ class CtcTrainingGraphCompiler(object):
 
         fsa_with_self_loops = k2.arc_sort(fsa_with_self_loops)
 
-        self.ctc_topo._is_repeat_token_ = (
-            self.ctc_topo.labels != self.ctc_topo.aux_labels
-        ).int()
-
         decoding_graph = k2.compose(
             self.ctc_topo, fsa_with_self_loops, treat_epsilons_specially=False
         )

From bd7fa2253dab9f627edc914b3289fb2f6c0e5bb6 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang 
Date: Sun, 4 Dec 2022 20:27:45 +0800
Subject: [PATCH 035/174] Update the manifest statistics of the L subset of
 wenetspeech (#731)

---
 .../ASR/local/display_manifest_statistics.py  | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/egs/wenetspeech/ASR/local/display_manifest_statistics.py b/egs/wenetspeech/ASR/local/display_manifest_statistics.py
index c41445b8d..36e4ac5c3 100644
--- a/egs/wenetspeech/ASR/local/display_manifest_statistics.py
+++ b/egs/wenetspeech/ASR/local/display_manifest_statistics.py
@@ -33,6 +33,7 @@ def main():
     paths = [
         "./data/fbank/cuts_S.jsonl.gz",
         "./data/fbank/cuts_M.jsonl.gz",
+        "./data/fbank/cuts_L.jsonl.gz",
         "./data/fbank/cuts_DEV.jsonl.gz",
         "./data/fbank/cuts_TEST_NET.jsonl.gz",
         "./data/fbank/cuts_TEST_MEETING.jsonl.gz",
@@ -48,6 +49,24 @@ if __name__ == "__main__":
     main()
 
 """
+Starting display the statistics for ./data/fbank/cuts_L.jsonl.gz
+
+Cuts count: 43874235
+Total duration (hours): 30217.3
+Speech duration (hours): 30217.3 (100.0%)
+***
+Duration statistics (seconds):
+mean    2.5
+std     1.7
+min     0.2
+25%     1.4
+50%     2.0
+75%     3.0
+99%     8.4
+99.5%   9.1
+99.9%   15.4
+max     405.1
+
 Starting display the statistics for ./data/fbank/cuts_S.jsonl.gz
 Duration statistics (seconds):
 mean    2.4

From be6e08f69a9384de27c28115a299d4fe64bb5de1 Mon Sep 17 00:00:00 2001
From: Cesc 
Date: Mon, 5 Dec 2022 23:35:10 +0800
Subject: [PATCH 036/174] fix wenet stateless5 jit export error (#735)

---
 egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py      | 2 ++
 egs/wenetspeech/ASR/pruned_transducer_stateless5/lstmp.py       | 1 +
 .../ASR/pruned_transducer_stateless5/scaling_converter.py       | 1 +
 3 files changed, 4 insertions(+)
 mode change 100644 => 100755 egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py
 create mode 120000 egs/wenetspeech/ASR/pruned_transducer_stateless5/lstmp.py
 create mode 120000 egs/wenetspeech/ASR/pruned_transducer_stateless5/scaling_converter.py

diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py
old mode 100644
new mode 100755
index 35577c327..cb541070e
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py
@@ -74,6 +74,7 @@ import logging
 from pathlib import Path
 
 import torch
+from scaling_converter import convert_scaled_to_non_scaled
 from train import add_model_arguments, get_params, get_transducer_model
 
 from icefall.checkpoint import average_checkpoints, load_checkpoint
@@ -184,6 +185,7 @@ def main():
         # it here.
         # Otherwise, one of its arguments is a ragged tensor and is not
         # torch scriptabe.
+        convert_scaled_to_non_scaled(model, inplace=True)
         model.__class__.forward = torch.jit.ignore(model.__class__.forward)
         logging.info("Using torch.jit.script")
         model = torch.jit.script(model)
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/lstmp.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/lstmp.py
new file mode 120000
index 000000000..d13a1e063
--- /dev/null
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/lstmp.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless5/lstmp.py
\ No newline at end of file
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/scaling_converter.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/scaling_converter.py
new file mode 120000
index 000000000..e58473a04
--- /dev/null
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/scaling_converter.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless5/scaling_converter.py
\ No newline at end of file

From f13cf61b05432a989e6a42c95b843a56639bcbde Mon Sep 17 00:00:00 2001
From: Fangjun Kuang 
Date: Tue, 6 Dec 2022 16:34:27 +0800
Subject: [PATCH 037/174] Convert conv-emformer to ncnn (#717)

* Export conv-emformer via torch.jit.trace()
---
 ...former-transducer-stateless2-2022-12-05.sh |   79 +
 ...-lstm-transducer-stateless2-2022-09-03.sh} |    0
 ...ormer-transducer-stateless2-2022-12-05.yml |   77 +
 ...-lstm-transducer-stateless2-2022-09-03.yml |    2 +-
 .../emformer2.py                              | 1798 +++++++++++++++++
 .../export-for-ncnn.py                        |  335 +++
 .../jit_pretrained.py                         |  292 +++
 .../lstmp.py                                  |    1 +
 .../scaling_converter.py                      |    1 +
 .../streaming-ncnn-decode.py                  |  387 ++++
 .../train2.py                                 | 1128 +++++++++++
 11 files changed, 4099 insertions(+), 1 deletion(-)
 create mode 100755 .github/scripts/run-librispeech-conv-emformer-transducer-stateless2-2022-12-05.sh
 rename .github/scripts/{run-librispeech-lstm-transducer-stateless2-2022-09-03.yml => run-librispeech-lstm-transducer-stateless2-2022-09-03.sh} (100%)
 create mode 100644 .github/workflows/run-librispeech-conv-emformer-transducer-stateless2-2022-12-05.yml
 create mode 100644 egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer2.py
 create mode 100755 egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py
 create mode 100755 egs/librispeech/ASR/conv_emformer_transducer_stateless2/jit_pretrained.py
 create mode 120000 egs/librispeech/ASR/conv_emformer_transducer_stateless2/lstmp.py
 create mode 120000 egs/librispeech/ASR/conv_emformer_transducer_stateless2/scaling_converter.py
 create mode 100755 egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming-ncnn-decode.py
 create mode 100755 egs/librispeech/ASR/conv_emformer_transducer_stateless2/train2.py

diff --git a/.github/scripts/run-librispeech-conv-emformer-transducer-stateless2-2022-12-05.sh b/.github/scripts/run-librispeech-conv-emformer-transducer-stateless2-2022-12-05.sh
new file mode 100755
index 000000000..32c939206
--- /dev/null
+++ b/.github/scripts/run-librispeech-conv-emformer-transducer-stateless2-2022-12-05.sh
@@ -0,0 +1,79 @@
+#!/usr/bin/env bash
+#
+set -e
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+cd egs/librispeech/ASR
+
+repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05
+
+log "Downloading pre-trained model from $repo_url"
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+repo=$(basename $repo_url)
+pushd $repo
+git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt"
+git lfs pull --include "data/lang_bpe_500/bpe.model"
+cd exp
+ln -s pretrained-epoch-30-avg-10-averaged.pt epoch-99.pt
+popd
+
+log "Display test files"
+tree $repo/
+soxi $repo/test_wavs/*.wav
+ls -lh $repo/test_wavs/*.wav
+
+log  "Install ncnn and pnnx"
+
+# We are using a modified ncnn here. Will try to merge it to the official repo
+# of ncnn
+git clone https://github.com/csukuangfj/ncnn
+pushd ncnn
+git submodule init
+git submodule update python/pybind11
+python3 setup.py bdist_wheel
+ls -lh dist/
+pip install dist/*.whl
+cd tools/pnnx
+mkdir build
+cd build
+cmake -D Python3_EXECUTABLE=/opt/hostedtoolcache/Python/3.8.14/x64/bin/python3 ..
+make -j4 pnnx
+
+./src/pnnx || echo "pass"
+
+popd
+
+log "Test exporting to pnnx format"
+
+./conv_emformer_transducer_stateless2/export-for-ncnn.py \
+  --exp-dir $repo/exp \
+  --bpe-model $repo/data/lang_bpe_500/bpe.model \
+  --epoch 99 \
+  --avg 1 \
+  --use-averaged-model 0 \
+  \
+  --num-encoder-layers 12 \
+  --chunk-length 32 \
+  --cnn-module-kernel 31 \
+  --left-context-length 32 \
+  --right-context-length 8 \
+  --memory-size 32
+
+./ncnn/tools/pnnx/build/src/pnnx $repo/exp/encoder_jit_trace-pnnx.pt
+./ncnn/tools/pnnx/build/src/pnnx $repo/exp/decoder_jit_trace-pnnx.pt
+./ncnn/tools/pnnx/build/src/pnnx $repo/exp/joiner_jit_trace-pnnx.pt
+
+./conv_emformer_transducer_stateless2/streaming-ncnn-decode.py \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \
+ --encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \
+ --decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \
+ --decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \
+ --joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \
+ --joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \
+ $repo/test_wavs/1089-134686-0001.wav
diff --git a/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml b/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh
similarity index 100%
rename from .github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml
rename to .github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh
diff --git a/.github/workflows/run-librispeech-conv-emformer-transducer-stateless2-2022-12-05.yml b/.github/workflows/run-librispeech-conv-emformer-transducer-stateless2-2022-12-05.yml
new file mode 100644
index 000000000..b9a1582c4
--- /dev/null
+++ b/.github/workflows/run-librispeech-conv-emformer-transducer-stateless2-2022-12-05.yml
@@ -0,0 +1,77 @@
+name: run-librispeech-conv-emformer-transducer-stateless2-2022-12-05
+
+on:
+  push:
+    branches:
+      - master
+  pull_request:
+    types: [labeled]
+
+  schedule:
+    # minute (0-59)
+    # hour (0-23)
+    # day of the month (1-31)
+    # month (1-12)
+    # day of the week (0-6)
+    # nightly build at 15:50 UTC time every day
+    - cron: "50 15 * * *"
+
+jobs:
+  run_librispeech_conv_emformer_transducer_stateless2_2022_12_05:
+    if: github.event.label.name == 'ready' || github.event.label.name == 'ncnn' || github.event_name == 'push' || github.event_name == 'schedule'
+    runs-on: ${{ matrix.os }}
+    strategy:
+      matrix:
+        os: [ubuntu-latest]
+        python-version: [3.8]
+
+      fail-fast: false
+
+    steps:
+      - uses: actions/checkout@v2
+        with:
+          fetch-depth: 0
+
+      - name: Setup Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v2
+        with:
+          python-version: ${{ matrix.python-version }}
+          cache: 'pip'
+          cache-dependency-path: '**/requirements-ci.txt'
+
+      - name: Install Python dependencies
+        run: |
+          grep -v '^#' ./requirements-ci.txt  | grep -v kaldifst | xargs -n 1 -L 1 pip install
+          pip uninstall -y protobuf
+          pip install --no-binary protobuf protobuf
+
+      - name: Cache kaldifeat
+        id: my-cache
+        uses: actions/cache@v2
+        with:
+          path: |
+            ~/tmp/kaldifeat
+          key: cache-tmp-${{ matrix.python-version }}-2022-09-25
+
+      - name: Install kaldifeat
+        if: steps.my-cache.outputs.cache-hit != 'true'
+        shell: bash
+        run: |
+          .github/scripts/install-kaldifeat.sh
+
+      - name: Inference with pre-trained model
+        shell: bash
+        env:
+          GITHUB_EVENT_NAME: ${{ github.event_name }}
+          GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
+        run: |
+          mkdir -p egs/librispeech/ASR/data
+          ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
+          ls -lh egs/librispeech/ASR/data/*
+
+          sudo apt-get -qq install git-lfs tree sox
+          export PYTHONPATH=$PWD:$PYTHONPATH
+          export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
+          export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
+
+          .github/scripts/run-librispeech-conv-emformer-transducer-stateless2-2022-12-05.sh
diff --git a/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml b/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml
index 59f116fde..f5ee09e16 100644
--- a/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml
+++ b/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml
@@ -111,7 +111,7 @@ jobs:
           export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
           export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
 
-          .github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml
+          .github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh
 
       - name: Display decoding results for lstm_transducer_stateless2
         if: github.event_name == 'schedule'
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer2.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer2.py
new file mode 100644
index 000000000..65a7efa77
--- /dev/null
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer2.py
@@ -0,0 +1,1798 @@
+# Copyright      2022  Xiaomi Corporation     (Author: Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# It is modified based on
+# 1) https://github.com/pytorch/audio/blob/main/torchaudio/models/emformer.py  # noqa
+# 2) https://github.com/pytorch/audio/blob/main/torchaudio/prototype/models/conv_emformer.py  # noqa
+
+import math
+from typing import List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+from encoder_interface import EncoderInterface
+from scaling import (
+    ActivationBalancer,
+    BasicNorm,
+    DoubleSwish,
+    ScaledConv1d,
+    ScaledConv2d,
+    ScaledLinear,
+)
+
+from icefall.utils import make_pad_mask
+
+LOG_EPSILON = math.log(1e-10)
+
+
+def unstack_states(
+    states: Tuple[List[List[torch.Tensor]], List[torch.Tensor]]
+) -> List[Tuple[List[List[torch.Tensor]], List[torch.Tensor]]]:
+    """Unstack the emformer state corresponding to a batch of utterances
+    into a list of states, where the i-th entry is the state from the i-th
+    utterance in the batch.
+
+    Args:
+      states:
+        A tuple of 2 elements.
+        ``states[0]`` is the attention caches of a batch of utterance.
+        ``states[1]`` is the convolution caches of a batch of utterance.
+        ``len(states[0])`` and ``len(states[1])`` both eqaul to number of layers.  # noqa
+
+    Returns:
+      A list of states.
+      ``states[i]`` is a tuple of 2 elements of i-th utterance.
+      ``states[i][0]`` is the attention caches of i-th utterance.
+      ``states[i][1]`` is the convolution caches of i-th utterance.
+      ``len(states[i][0])`` and ``len(states[i][1])`` both eqaul to number of layers.  # noqa
+    """
+
+    attn_caches, conv_caches = states
+    batch_size = conv_caches[0].size(0)
+    num_layers = len(attn_caches)
+
+    list_attn_caches = [None] * batch_size
+    for i in range(batch_size):
+        list_attn_caches[i] = [[] for _ in range(num_layers)]
+    for li, layer in enumerate(attn_caches):
+        for s in layer:
+            s_list = s.unbind(dim=1)
+            for bi, b in enumerate(list_attn_caches):
+                b[li].append(s_list[bi])
+
+    list_conv_caches = [None] * batch_size
+    for i in range(batch_size):
+        list_conv_caches[i] = [None] * num_layers
+    for li, layer in enumerate(conv_caches):
+        c_list = layer.unbind(dim=0)
+        for bi, b in enumerate(list_conv_caches):
+            b[li] = c_list[bi]
+
+    ans = [None] * batch_size
+    for i in range(batch_size):
+        ans[i] = [list_attn_caches[i], list_conv_caches[i]]
+
+    return ans
+
+
+def stack_states(
+    state_list: List[Tuple[List[List[torch.Tensor]], List[torch.Tensor]]]
+) -> Tuple[List[List[torch.Tensor]], List[torch.Tensor]]:
+    """Stack list of emformer states that correspond to separate utterances
+    into a single emformer state so that it can be used as an input for
+    emformer when those utterances are formed into a batch.
+
+    Note:
+      It is the inverse of :func:`unstack_states`.
+
+    Args:
+      state_list:
+        Each element in state_list corresponding to the internal state
+        of the emformer model for a single utterance.
+        ``states[i]`` is a tuple of 2 elements of i-th utterance.
+        ``states[i][0]`` is the attention caches of i-th utterance.
+        ``states[i][1]`` is the convolution caches of i-th utterance.
+        ``len(states[i][0])`` and ``len(states[i][1])`` both eqaul to number of layers.  # noqa
+
+    Returns:
+      A new state corresponding to a batch of utterances.
+      See the input argument of :func:`unstack_states` for the meaning
+      of the returned tensor.
+    """
+    batch_size = len(state_list)
+
+    attn_caches = []
+    for layer in state_list[0][0]:
+        if batch_size > 1:
+            # Note: We will stack attn_caches[layer][s][] later to get attn_caches[layer][s]  # noqa
+            attn_caches.append([[s] for s in layer])
+        else:
+            attn_caches.append([s.unsqueeze(1) for s in layer])
+    for b, states in enumerate(state_list[1:], 1):
+        for li, layer in enumerate(states[0]):
+            for si, s in enumerate(layer):
+                attn_caches[li][si].append(s)
+                if b == batch_size - 1:
+                    attn_caches[li][si] = torch.stack(attn_caches[li][si], dim=1)
+
+    conv_caches = []
+    for layer in state_list[0][1]:
+        if batch_size > 1:
+            # Note: We will stack conv_caches[layer][] later to get conv_caches[layer]  # noqa
+            conv_caches.append([layer])
+        else:
+            conv_caches.append(layer.unsqueeze(0))
+    for b, states in enumerate(state_list[1:], 1):
+        for li, layer in enumerate(states[1]):
+            conv_caches[li].append(layer)
+            if b == batch_size - 1:
+                conv_caches[li] = torch.stack(conv_caches[li], dim=0)
+
+    return [attn_caches, conv_caches]
+
+
+class ConvolutionModule(nn.Module):
+    """ConvolutionModule.
+
+    Modified from https://github.com/pytorch/audio/blob/main/torchaudio/prototype/models/conv_emformer.py # noqa
+
+    Args:
+      chunk_length (int):
+        Length of each chunk.
+      right_context_length (int):
+        Length of right context.
+      channels (int):
+        The number of input channels and output channels of conv layers.
+      kernel_size (int):
+        Kernerl size of conv layers.
+      bias (bool):
+        Whether to use bias in conv layers (default=True).
+    """
+
+    def __init__(
+        self,
+        chunk_length: int,
+        right_context_length: int,
+        channels: int,
+        kernel_size: int,
+        bias: bool = True,
+    ) -> None:
+        """Construct an ConvolutionModule object."""
+        super().__init__()
+        # kernerl_size should be an odd number for 'SAME' padding
+        assert (kernel_size - 1) % 2 == 0, kernel_size
+
+        self.chunk_length = chunk_length
+        self.right_context_length = right_context_length
+        self.channels = channels
+
+        self.pointwise_conv1 = ScaledConv1d(
+            channels,
+            2 * channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=bias,
+        )
+        # After pointwise_conv1 we put x through a gated linear unit
+        # (nn.functional.glu).
+        # For most layers the normal rms value of channels of x seems to be in
+        # the range 1 to 4, but sometimes, for some reason, for layer 0 the rms
+        # ends up being very large, between 50 and 100 for different channels.
+        # This will cause very peaky and sparse derivatives for the sigmoid
+        # gating function, which will tend to make the loss function not learn
+        # effectively.  (for most layers the average absolute values are in the
+        # range 0.5..9.0, and the average p(x>0), i.e. positive proportion,
+        # at the output of pointwise_conv1.output is around 0.35 to 0.45 for
+        # different layers, which likely breaks down as 0.5 for the "linear"
+        # half and 0.2 to 0.3 for the part that goes into the sigmoid.
+        # The idea is that if we constrain the rms values to a reasonable range
+        # via a constraint of max_abs=10.0, it will be in a better position to
+        # start learning something, i.e. to latch onto the correct range.
+        self.deriv_balancer1 = ActivationBalancer(
+            channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0
+        )
+
+        # make it causal by padding cached (kernel_size - 1) frames on the left
+        self.cache_size = kernel_size - 1
+        self.depthwise_conv = ScaledConv1d(
+            channels,
+            channels,
+            kernel_size,
+            stride=1,
+            padding=0,
+            groups=channels,
+            bias=bias,
+        )
+
+        self.deriv_balancer2 = ActivationBalancer(
+            channel_dim=1, min_positive=0.05, max_positive=1.0
+        )
+
+        self.activation = DoubleSwish()
+
+        self.pointwise_conv2 = ScaledConv1d(
+            channels,
+            channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=bias,
+            initial_scale=0.25,
+        )
+
+    def _split_right_context(
+        self,
+        pad_utterance: torch.Tensor,
+        right_context: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Args:
+          pad_utterance:
+            Its shape is (cache_size + U, B, D).
+          right_context:
+            Its shape is (R, B, D).
+
+        Returns:
+          Right context segments padding with corresponding context.
+          Its shape is (num_segs * B, D, cache_size + right_context_length).
+        """
+        U_, B, D = pad_utterance.size()
+        R = right_context.size(0)
+        assert self.right_context_length != 0
+        assert R % self.right_context_length == 0
+        num_chunks = R // self.right_context_length
+        right_context = right_context.reshape(
+            num_chunks, self.right_context_length, B, D
+        )
+        right_context = right_context.permute(0, 2, 1, 3).reshape(
+            num_chunks * B, self.right_context_length, D
+        )
+
+        intervals = torch.arange(
+            0, self.chunk_length * (num_chunks - 1), self.chunk_length
+        )
+        first = torch.arange(self.chunk_length, self.chunk_length + self.cache_size)
+        indexes = intervals.unsqueeze(1) + first.unsqueeze(0)
+        indexes = torch.cat(
+            [indexes, torch.arange(U_ - self.cache_size, U_).unsqueeze(0)]
+        )
+        padding = pad_utterance[indexes]  # (num_chunks, cache_size, B, D)
+        padding = padding.permute(0, 2, 1, 3).reshape(
+            num_chunks * B, self.cache_size, D
+        )
+
+        pad_right_context = torch.cat([padding, right_context], dim=1)
+        # (num_chunks * B, cache_size + right_context_length, D)
+        return pad_right_context.permute(0, 2, 1)
+
+    def _merge_right_context(self, right_context: torch.Tensor, B: int) -> torch.Tensor:
+        """
+        Args:
+          right_context:
+            Right context segments.
+            It shape is (num_segs * B, D, right_context_length).
+          B:
+            Batch size.
+
+        Returns:
+          A tensor of shape (B, D, R), where
+          R = num_segs * right_context_length.
+        """
+        right_context = right_context.reshape(
+            -1, B, self.channels, self.right_context_length
+        )
+        right_context = right_context.permute(1, 2, 0, 3)
+        right_context = right_context.reshape(B, self.channels, -1)
+        return right_context
+
+    def forward(
+        self,
+        utterance: torch.Tensor,
+        right_context: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Causal convolution module.
+
+        Args:
+          utterance (torch.Tensor):
+            Utterance tensor of shape (U, B, D).
+          right_context (torch.Tensor):
+            Right context tensor of shape (R, B, D).
+
+        Returns:
+          A tuple of 2 tensors:
+          - output utterance of shape (U, B, D).
+          - output right_context of shape (R, B, D).
+        """
+        U, B, D = utterance.size()
+        R, _, _ = right_context.size()
+
+        # point-wise conv and GLU mechanism
+        x = torch.cat([right_context, utterance], dim=0)  # (R + U, B, D)
+        x = x.permute(1, 2, 0)  # (B, D, R + U)
+        x = self.pointwise_conv1(x)  # (B, 2 * D, R + U)
+        x = self.deriv_balancer1(x)
+        x = nn.functional.glu(x, dim=1)  # (B, D, R + U)
+        utterance = x[:, :, R:]  # (B, D, U)
+        right_context = x[:, :, :R]  # (B, D, R)
+
+        # make causal convolution
+        cache = torch.zeros(B, D, self.cache_size, device=x.device, dtype=x.dtype)
+        pad_utterance = torch.cat([cache, utterance], dim=2)  # (B, D, cache + U)
+
+        # depth-wise conv on utterance
+        utterance = self.depthwise_conv(pad_utterance)  # (B, D, U)
+
+        if self.right_context_length > 0:
+            # depth-wise conv on right_context
+            pad_right_context = self._split_right_context(
+                pad_utterance.permute(2, 0, 1), right_context.permute(2, 0, 1)
+            )  # (num_segs * B, D, cache_size + right_context_length)
+            right_context = self.depthwise_conv(
+                pad_right_context
+            )  # (num_segs * B, D, right_context_length)
+            right_context = self._merge_right_context(right_context, B)  # (B, D, R)
+
+        x = torch.cat([right_context, utterance], dim=2)  # (B, D, R + U)
+        x = self.deriv_balancer2(x)
+        x = self.activation(x)
+
+        # point-wise conv
+        x = self.pointwise_conv2(x)  # (B, D, R + U)
+
+        right_context = x[:, :, :R]  # (B, D, R)
+        utterance = x[:, :, R:]  # (B, D, U)
+        return (
+            utterance.permute(2, 0, 1),
+            right_context.permute(2, 0, 1),
+        )
+
+    @torch.jit.export
+    def infer(
+        self,
+        utterance: torch.Tensor,
+        right_context: torch.Tensor,
+        cache: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Causal convolution module applied on both utterance and right_context.
+
+        Args:
+          utterance (torch.Tensor):
+            Utterance tensor of shape (U, B, D).
+          right_context (torch.Tensor):
+            Right context tensor of shape (R, B, D).
+          cache (torch.Tensor, optional):
+            Cached tensor for left padding of shape (B, D, cache_size).
+
+        Returns:
+          A tuple of 3 tensors:
+            - output utterance of shape (U, B, D).
+            - output right_context of shape (R, B, D).
+            - updated cache tensor of shape (B, D, cache_size).
+        """
+        #  U, B, D = utterance.size()
+        #  R, _, _ = right_context.size()
+        U = self.chunk_length
+        B = 1
+        D = self.channels
+        R = self.right_context_length
+
+        # point-wise conv
+        x = torch.cat([utterance, right_context], dim=0)  # (U + R, B, D)
+        x = x.permute(1, 2, 0)  # (B, D, U + R)
+        x = self.pointwise_conv1(x)  # (B, 2 * D, U + R)
+        x = self.deriv_balancer1(x)
+        x = nn.functional.glu(x, dim=1)  # (B, D, U + R)
+
+        # make causal convolution
+        assert cache.shape == (B, D, self.cache_size), cache.shape
+        x = torch.cat([cache, x], dim=2)  # (B, D, cache_size + U + R)
+        # update cache
+        new_cache = x[:, :, -R - self.cache_size : -R]
+
+        # 1-D depth-wise conv
+        x = self.depthwise_conv(x)  # (B, D, U + R)
+
+        x = self.deriv_balancer2(x)
+        x = self.activation(x)
+
+        # point-wise conv
+        x = self.pointwise_conv2(x)  # (B, D, U + R)
+
+        utterance = x[:, :, :U]  # (B, D, U)
+        right_context = x[:, :, U:]  # (B, D, R)
+        return (
+            utterance.permute(2, 0, 1),
+            right_context.permute(2, 0, 1),
+            new_cache,
+        )
+
+
+class EmformerAttention(nn.Module):
+    r"""Emformer layer attention module.
+
+    Args:
+      embed_dim (int):
+        Embedding dimension.
+      nhead (int):
+        Number of attention heads in each Emformer layer.
+      dropout (float, optional):
+        Dropout probability. (Default: 0.0)
+      tanh_on_mem (bool, optional):
+        If ``True``, applies tanh to memory elements. (Default: ``False``)
+      negative_inf (float, optional):
+        Value to use for negative infinity in attention weights. (Default: -1e8)
+    """
+
+    def __init__(
+        self,
+        embed_dim: int,
+        nhead: int,
+        left_context_length: int,
+        chunk_length: int,
+        right_context_length: int,
+        memory_size: int,
+        dropout: float = 0.0,
+        tanh_on_mem: bool = False,
+        negative_inf: float = -1e8,
+    ):
+        super().__init__()
+
+        if embed_dim % nhead != 0:
+            raise ValueError(
+                f"embed_dim ({embed_dim}) is not a multiple of nhead ({nhead})."
+            )
+
+        self.embed_dim = embed_dim
+        self.nhead = nhead
+        self.tanh_on_mem = tanh_on_mem
+        self.negative_inf = negative_inf
+        self.head_dim = embed_dim // nhead
+        self.dropout = dropout
+
+        self.left_context_length = left_context_length
+        self.right_context_length = right_context_length
+        self.chunk_length = chunk_length
+        self.memory_size = memory_size
+
+        self.emb_to_key_value = ScaledLinear(embed_dim, 2 * embed_dim, bias=True)
+        self.emb_to_query = ScaledLinear(embed_dim, embed_dim, bias=True)
+        self.out_proj = ScaledLinear(
+            embed_dim, embed_dim, bias=True, initial_scale=0.25
+        )
+
+    def _gen_attention_probs(
+        self,
+        attention_weights: torch.Tensor,
+        attention_mask: torch.Tensor,
+        padding_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Given the entire attention weights, mask out unecessary connections
+        and optionally with padding positions, to obtain underlying chunk-wise
+        attention probabilities.
+
+        B: batch size;
+        Q: length of query;
+        KV: length of key and value.
+
+        Args:
+          attention_weights (torch.Tensor):
+            Attention weights computed on the entire concatenated tensor
+            with shape (B * nhead, Q, KV).
+          attention_mask (torch.Tensor):
+            Mask tensor where chunk-wise connections are filled with `False`,
+            and other unnecessary connections are filled with `True`,
+            with shape (Q, KV).
+          padding_mask (torch.Tensor, optional):
+            Mask tensor where the padding positions are fill with `True`,
+            and other positions are filled with `False`, with shapa `(B, KV)`.
+
+        Returns:
+          A tensor of shape (B * nhead, Q, KV).
+        """
+        attention_weights_float = attention_weights.float()
+        attention_weights_float = attention_weights_float.masked_fill(
+            attention_mask.unsqueeze(0), self.negative_inf
+        )
+        if padding_mask is not None:
+            Q = attention_weights.size(1)
+            B = attention_weights.size(0) // self.nhead
+            attention_weights_float = attention_weights_float.view(B, self.nhead, Q, -1)
+            attention_weights_float = attention_weights_float.masked_fill(
+                padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
+                self.negative_inf,
+            )
+            attention_weights_float = attention_weights_float.view(
+                B * self.nhead, Q, -1
+            )
+
+        attention_probs = nn.functional.softmax(
+            attention_weights_float, dim=-1
+        ).type_as(attention_weights)
+
+        attention_probs = nn.functional.dropout(
+            attention_probs, p=self.dropout, training=self.training
+        )
+        return attention_probs
+
+    def _forward_impl(
+        self,
+        utterance: torch.Tensor,
+        right_context: torch.Tensor,
+        memory: torch.Tensor,
+        attention_mask: torch.Tensor,
+        padding_mask: Optional[torch.Tensor] = None,
+        left_context_key: Optional[torch.Tensor] = None,
+        left_context_val: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Underlying chunk-wise attention implementation."""
+        #  U, B, _ = utterance.size()
+        #  R = right_context.size(0)
+        #  M = memory.size(0)
+
+        U = self.chunk_length
+        B = 1
+        R = self.right_context_length
+        M = self.memory_size
+        L = self.left_context_length
+
+        scaling = float(self.head_dim) ** -0.5
+
+        # compute query with [right_context, utterance].
+        query = self.emb_to_query(torch.cat([right_context, utterance]))
+        # compute key and value with [memory, right_context, utterance].
+        key, value = self.emb_to_key_value(
+            torch.cat([memory, right_context, utterance])
+        ).chunk(chunks=2, dim=2)
+
+        if left_context_key is not None and left_context_val is not None:
+            # now compute key and value with
+            #   [memory, right context, left context, uttrance]
+            # this is used in inference mode
+            key = torch.cat([key[: M + R], left_context_key, key[M + R :]])
+            value = torch.cat([value[: M + R], left_context_val, value[M + R :]])
+
+        #  Q = query.size(0)
+        Q = U + R
+
+        # KV = key.size(0)
+
+        reshaped_query = query.view(Q, self.nhead, self.head_dim).permute(1, 0, 2)
+        reshaped_key = key.view(M + R + U + L, self.nhead, self.head_dim).permute(
+            1, 0, 2
+        )
+        reshaped_value = value.view(M + R + U + L, self.nhead, self.head_dim).permute(
+            1, 0, 2
+        )
+
+        #  reshaped_query, reshaped_key, reshaped_value = [
+        #      tensor.contiguous().view(-1, B * self.nhead, self.head_dim).transpose(0, 1)
+        #      for tensor in [query, key, value]
+        #  ]  # (B * nhead, Q or KV, head_dim)
+        attention_weights = torch.bmm(
+            reshaped_query * scaling, reshaped_key.permute(0, 2, 1)
+        )  # (B * nhead, Q, KV)
+
+        # compute attention probabilities
+        if False:
+            attention_probs = self._gen_attention_probs(
+                attention_weights, attention_mask, padding_mask
+            )
+        else:
+            attention_probs = nn.functional.softmax(attention_weights, dim=-1)
+
+        # compute attention outputs
+        attention = torch.bmm(attention_probs, reshaped_value)
+        assert attention.shape == (B * self.nhead, Q, self.head_dim)
+        attention = attention.permute(1, 0, 2).reshape(-1, self.embed_dim)
+        # TODO(fangjun): ncnn does not support reshape(-1, 1, self.embed_dim)
+        # We have to change InnerProduct in ncnn to ignore the extra dim below
+        attention = attention.unsqueeze(1)
+
+        # apply output projection
+        output_right_context_utterance = self.out_proj(attention)
+        # The return shape of output_right_context_utterance is (10, 1, 512)
+
+        return output_right_context_utterance, key, value
+
+    def forward(
+        self,
+        utterance: torch.Tensor,
+        right_context: torch.Tensor,
+        memory: torch.Tensor,
+        attention_mask: torch.Tensor,
+        padding_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        # TODO: Modify docs.
+        """Forward pass for training and validation mode.
+
+        B: batch size;
+        D: embedding dimension;
+        R: length of the hard-copied right contexts;
+        U: length of full utterance;
+        M: length of memory vectors.
+
+        It computes a `big` attention matrix on full utterance and
+        then utilizes a pre-computed mask to simulate chunk-wise attention.
+
+        It concatenates three blocks: hard-copied right contexts,
+        and full utterance, as a `big` block,
+        to compute the query tensor:
+        query = [right_context, utterance],
+        with length Q = R + U.
+        It concatenates the three blocks: memory vectors,
+        hard-copied right contexts, and full utterance as another `big` block,
+        to compute the key and value tensors:
+        key & value = [memory, right_context, utterance],
+        with length KV = M + R + U.
+        Attention scores is computed with above `big` query and key.
+
+        Then the underlying chunk-wise attention is obtained by applying
+        the attention mask. Suppose
+        c_i: chunk at index i;
+        r_i: right context that c_i can use;
+        l_i: left context that c_i can use;
+        m_i: past memory vectors from previous layer that c_i can use;
+        The target chunk-wise attention is:
+        c_i, r_i (in query) -> l_i, c_i, r_i, m_i (in key)
+
+        Args:
+          utterance (torch.Tensor):
+            Full utterance frames, with shape (U, B, D).
+          right_context (torch.Tensor):
+            Hard-copied right context frames, with shape (R, B, D),
+            where R = num_chunks * right_context_length
+          memory (torch.Tensor):
+            Memory elements, with shape (M, B, D), where M = num_chunks - 1.
+            It is an empty tensor without using memory.
+          attention_mask (torch.Tensor):
+            Pre-computed attention mask to simulate underlying chunk-wise
+            attention, with shape (Q, KV).
+          padding_mask (torch.Tensor):
+            Padding mask of key tensor, with shape (B, KV).
+
+        Returns:
+          Output of right context and utterance, with shape (R + U, B, D).
+        """
+        output_right_context_utterance, _, _ = self._forward_impl(
+            utterance,
+            right_context,
+            memory,
+            attention_mask,
+            padding_mask=padding_mask,
+        )
+        return output_right_context_utterance
+
+    @torch.jit.export
+    def infer(
+        self,
+        utterance: torch.Tensor,
+        right_context: torch.Tensor,
+        memory: torch.Tensor,
+        left_context_key: torch.Tensor,
+        left_context_val: torch.Tensor,
+        padding_mask: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Forward pass for inference.
+
+        B: batch size;
+        D: embedding dimension;
+        R: length of right context;
+        U: length of utterance, i.e., current chunk;
+        L: length of cached left context;
+        M: length of cached memory vectors.
+
+        It concatenates the right context and utterance (i.e., current chunk)
+        of current chunk, to compute the query tensor:
+        query = [right_context, utterance],
+        with length Q = R + U.
+        It concatenates the memory vectors, right context, left context, and
+        current chunk, to compute the key and value tensors:
+        key & value = [memory, right_context, left_context, utterance],
+        with length KV = M + R + L + U.
+
+        The chunk-wise attention is:
+        chunk, right context (in query) ->
+          left context, chunk, right context, memory vectors (in key).
+
+        Args:
+          utterance (torch.Tensor):
+            Current chunk frames, with shape (U, B, D), where U = chunk_length.
+          right_context (torch.Tensor):
+            Right context frames, with shape (R, B, D),
+            where R = right_context_length.
+          memory (torch.Tensor):
+            Memory vectors, with shape (M, B, D), or empty tensor.
+          left_context_key (torch,Tensor):
+            Cached attention key of left context from preceding computation,
+            with shape (L, B, D).
+          left_context_val (torch.Tensor):
+            Cached attention value of left context from preceding computation,
+            with shape (L, B, D).
+          padding_mask (torch.Tensor):
+            Padding mask of key tensor, with shape (B, KV).
+
+        Returns:
+          A tuple containing 4 tensors:
+            - output of right context and utterance, with shape (R + U, B, D).
+            - attention key of left context and utterance, which would be cached
+              for next computation, with shape (L + U, B, D).
+            - attention value of left context and utterance, which would be
+              cached for next computation, with shape (L + U, B, D).
+        """
+        #  U = utterance.size(0)
+        #  R = right_context.size(0)
+        #  L = left_context_key.size(0)
+        #  M = memory.size(0)
+
+        U = self.chunk_length
+        R = self.right_context_length
+        L = self.left_context_length
+        M = self.memory_size
+
+        # query = [right context, utterance]
+        Q = R + U
+        # key, value = [memory, right context, left context, utterance]
+        KV = M + R + L + U
+        attention_mask = torch.zeros(Q, KV).to(
+            dtype=torch.bool, device=utterance.device
+        )
+
+        output_right_context_utterance, key, value = self._forward_impl(
+            utterance,
+            right_context,
+            memory,
+            attention_mask,
+            padding_mask=padding_mask,
+            left_context_key=left_context_key,
+            left_context_val=left_context_val,
+        )
+        return (
+            output_right_context_utterance,
+            key[M + R :],
+            value[M + R :],
+        )
+
+
+class EmformerEncoderLayer(nn.Module):
+    """Emformer layer that constitutes Emformer.
+
+    Args:
+      d_model (int):
+        Input dimension.
+      nhead (int):
+        Number of attention heads.
+      dim_feedforward (int):
+        Hidden layer dimension of feedforward network.
+      chunk_length (int):
+        Length of each input segment.
+      dropout (float, optional):
+        Dropout probability. (Default: 0.0)
+      layer_dropout (float, optional):
+        Layer dropout probability. (Default: 0.0)
+      cnn_module_kernel (int):
+        Kernel size of convolution module.
+      left_context_length (int, optional):
+        Length of left context. (Default: 0)
+      right_context_length (int, optional):
+        Length of right context. (Default: 0)
+      memory_size (int, optional):
+        Number of memory elements to use. (Default: 0)
+      tanh_on_mem (bool, optional):
+        If ``True``, applies tanh to memory elements. (Default: ``False``)
+      negative_inf (float, optional):
+        Value to use for negative infinity in attention weights. (Default: -1e8)
+    """
+
+    def __init__(
+        self,
+        d_model: int,
+        nhead: int,
+        dim_feedforward: int,
+        chunk_length: int,
+        dropout: float = 0.1,
+        layer_dropout: float = 0.075,
+        cnn_module_kernel: int = 31,
+        left_context_length: int = 0,
+        right_context_length: int = 0,
+        memory_size: int = 0,
+        tanh_on_mem: bool = False,
+        negative_inf: float = -1e8,
+    ):
+        super().__init__()
+
+        self.attention = EmformerAttention(
+            embed_dim=d_model,
+            nhead=nhead,
+            left_context_length=left_context_length,
+            chunk_length=chunk_length,
+            memory_size=memory_size,
+            right_context_length=right_context_length,
+            dropout=dropout,
+            tanh_on_mem=tanh_on_mem,
+            negative_inf=negative_inf,
+        )
+        self.summary_op = nn.AvgPool1d(
+            kernel_size=chunk_length, stride=chunk_length, ceil_mode=True
+        )
+
+        self.feed_forward_macaron = nn.Sequential(
+            ScaledLinear(d_model, dim_feedforward),
+            ActivationBalancer(channel_dim=-1),
+            DoubleSwish(),
+            nn.Dropout(dropout),
+            ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
+        )
+
+        self.feed_forward = nn.Sequential(
+            ScaledLinear(d_model, dim_feedforward),
+            ActivationBalancer(channel_dim=-1),
+            DoubleSwish(),
+            nn.Dropout(dropout),
+            ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
+        )
+
+        self.conv_module = ConvolutionModule(
+            chunk_length,
+            right_context_length,
+            d_model,
+            cnn_module_kernel,
+        )
+
+        self.norm_final = BasicNorm(d_model)
+
+        # try to ensure the output is close to zero-mean
+        # (or at least, zero-median).
+        self.balancer = ActivationBalancer(
+            channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
+        )
+
+        self.dropout = nn.Dropout(dropout)
+
+        self.layer_dropout = layer_dropout
+        self.left_context_length = left_context_length
+        self.right_context_length = right_context_length
+        self.chunk_length = chunk_length
+        self.memory_size = memory_size
+        self.d_model = d_model
+        self.use_memory = memory_size > 0
+
+    def _update_attn_cache(
+        self,
+        next_key: torch.Tensor,
+        next_val: torch.Tensor,
+        memory: torch.Tensor,
+        attn_cache: List[torch.Tensor],
+    ) -> List[torch.Tensor]:
+        """Update cached attention state:
+        1) output memory of current chunk in the lower layer;
+        2) attention key and value in current chunk's computation, which would
+        be reused in next chunk's computation.
+        """
+        # attn_cache[0].shape (self.memory_size, 1, 512)
+        # memory.shape (1, 1, 512)
+        # attn_cache[1].shape (self.left_context_length, 1, 512)
+        # attn_cache[2].shape (self.left_context_length, 1, 512)
+        # next_key.shape (self.left_context_length + self.right_context_utterance, 1, 512)
+        # next_value.shape (self.left_context_length + self.right_context_utterance, 1, 512)
+        new_memory = torch.cat([attn_cache[0], memory])
+        # TODO(fangjun): Remove torch.cat
+        #  new_key = torch.cat([attn_cache[1], next_key])
+        #  new_val = torch.cat([attn_cache[2], next_val])
+        attn_cache[0] = new_memory[1:]
+        attn_cache[1] = next_key[-self.left_context_length :]
+        attn_cache[2] = next_val[-self.left_context_length :]
+        return attn_cache
+
+    def _apply_conv_module_forward(
+        self,
+        right_context_utterance: torch.Tensor,
+        R: int,
+    ) -> torch.Tensor:
+        """Apply convolution module in training and validation mode."""
+        utterance = right_context_utterance[R:]
+        right_context = right_context_utterance[:R]
+        utterance, right_context = self.conv_module(utterance, right_context)
+        right_context_utterance = torch.cat([right_context, utterance])
+        return right_context_utterance
+
+    def _apply_conv_module_infer(
+        self,
+        right_context_utterance: torch.Tensor,
+        R: int,
+        conv_cache: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Apply convolution module on utterance in inference mode."""
+        utterance = right_context_utterance[R:]
+        right_context = right_context_utterance[:R]
+        utterance, right_context, conv_cache = self.conv_module.infer(
+            utterance, right_context, conv_cache
+        )
+        right_context_utterance = torch.cat([right_context, utterance])
+        return right_context_utterance, conv_cache
+
+    def _apply_attention_module_forward(
+        self,
+        right_context_utterance: torch.Tensor,
+        R: int,
+        attention_mask: torch.Tensor,
+        padding_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Apply attention module in training and validation mode."""
+        utterance = right_context_utterance[R:]
+        right_context = right_context_utterance[:R]
+
+        if self.use_memory:
+            memory = self.summary_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[
+                :-1, :, :
+            ]
+        else:
+            memory = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device)
+        output_right_context_utterance = self.attention(
+            utterance=utterance,
+            right_context=right_context,
+            memory=memory,
+            attention_mask=attention_mask,
+            padding_mask=padding_mask,
+        )
+
+        return output_right_context_utterance
+
+    def _apply_attention_module_infer(
+        self,
+        right_context_utterance: torch.Tensor,
+        R: int,
+        attn_cache: List[torch.Tensor],
+        padding_mask: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+        """Apply attention module in inference mode.
+        1) Unpack cached states including:
+           - memory from previous chunks;
+           - attention key and value of left context from preceding
+             chunk's compuation;
+        2) Apply attention computation;
+        3) Update cached attention states including:
+           - memory of current chunk;
+           - attention key and value in current chunk's computation, which would
+             be resued in next chunk's computation.
+        """
+        utterance = right_context_utterance[R:]
+        right_context = right_context_utterance[:R]
+
+        pre_memory = attn_cache[0]
+        left_context_key = attn_cache[1]
+        left_context_val = attn_cache[2]
+
+        if self.use_memory:
+            memory = torch.mean(utterance, dim=0, keepdim=True)
+
+            #  memory = self.summary_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[
+            #          :1, :, :
+            #  ]
+        else:
+            memory = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device)
+        (output_right_context_utterance, next_key, next_val) = self.attention.infer(
+            utterance=utterance,
+            right_context=right_context,
+            memory=pre_memory,
+            left_context_key=left_context_key,
+            left_context_val=left_context_val,
+            padding_mask=padding_mask,
+        )
+        attn_cache = self._update_attn_cache(next_key, next_val, memory, attn_cache)
+        return output_right_context_utterance, attn_cache
+
+    def forward(
+        self,
+        utterance: torch.Tensor,
+        right_context: torch.Tensor,
+        attention_mask: torch.Tensor,
+        padding_mask: Optional[torch.Tensor] = None,
+        warmup: float = 1.0,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        r"""Forward pass for training and validation mode.
+
+        B: batch size;
+        D: embedding dimension;
+        R: length of hard-copied right contexts;
+        U: length of full utterance;
+        M: length of memory vectors.
+
+        Args:
+          utterance (torch.Tensor):
+            Utterance frames, with shape (U, B, D).
+          right_context (torch.Tensor):
+            Right context frames, with shape (R, B, D).
+          attention_mask (torch.Tensor):
+            Attention mask for underlying attention module,
+            with shape (Q, KV), where Q = R + U, KV = M + R + U.
+          padding_mask (torch.Tensor):
+            Padding mask of ker tensor, with shape (B, KV).
+
+        Returns:
+          A tuple containing 2 tensors:
+            - output utterance, with shape (U, B, D).
+            - output right context, with shape (R, B, D).
+        """
+        R = right_context.size(0)
+        src = torch.cat([right_context, utterance])
+        src_orig = src
+
+        warmup_scale = min(0.1 + warmup, 1.0)
+        # alpha = 1.0 means fully use this encoder layer, 0.0 would mean
+        # completely bypass it.
+        if self.training:
+            alpha = (
+                warmup_scale
+                if torch.rand(()).item() <= (1.0 - self.layer_dropout)
+                else 0.1
+            )
+        else:
+            alpha = 1.0
+
+        # macaron style feed forward module
+        src = src + self.dropout(self.feed_forward_macaron(src))
+
+        # emformer attention module
+        src_att = self._apply_attention_module_forward(
+            src, R, attention_mask, padding_mask=padding_mask
+        )
+        src = src + self.dropout(src_att)
+
+        # convolution module
+        src_conv = self._apply_conv_module_forward(src, R)
+        src = src + self.dropout(src_conv)
+
+        # feed forward module
+        src = src + self.dropout(self.feed_forward(src))
+
+        src = self.norm_final(self.balancer(src))
+
+        if alpha != 1.0:
+            src = alpha * src + (1 - alpha) * src_orig
+
+        output_utterance = src[R:]
+        output_right_context = src[:R]
+        return output_utterance, output_right_context
+
+    @torch.jit.export
+    def infer(
+        self,
+        utterance: torch.Tensor,
+        right_context: torch.Tensor,
+        cache: List[torch.Tensor],
+        padding_mask: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
+        """Forward pass for inference.
+
+         B: batch size;
+         D: embedding dimension;
+         R: length of right_context;
+         U: length of utterance;
+         M: length of memory.
+
+        Args:
+           utterance (torch.Tensor):
+             Utterance frames, with shape (U, B, D).
+           right_context (torch.Tensor):
+             Right context frames, with shape (R, B, D).
+           attn_cache (List[torch.Tensor]):
+             Cached attention tensors generated in preceding computation,
+             including memory, key and value of left context.
+           conv_cache (torch.Tensor, optional):
+             Cache tensor of left context for causal convolution.
+           padding_mask (torch.Tensor):
+             Padding mask of ker tensor.
+
+         Returns:
+           (Tensor, Tensor, List[torch.Tensor], Tensor):
+             - output utterance, with shape (U, B, D);
+             - output right_context, with shape (R, B, D);
+             - output attention cache;
+             - output convolution cache.
+        """
+        R = self.right_context_length
+        src = torch.cat([right_context, utterance])
+        attn_cache = cache[:3]
+        conv_cache = cache[3]
+
+        # macaron style feed forward module
+        src = src + self.dropout(self.feed_forward_macaron(src))
+
+        # emformer attention module
+        src_att, attn_cache = self._apply_attention_module_infer(
+            src, R, attn_cache, padding_mask=padding_mask
+        )
+        src = src + self.dropout(src_att)
+
+        # convolution module
+        src_conv, conv_cache = self._apply_conv_module_infer(src, R, conv_cache)
+        src = src + self.dropout(src_conv)
+
+        # feed forward module
+        src = src + self.dropout(self.feed_forward(src))
+
+        src = self.norm_final(self.balancer(src))
+
+        output_utterance = src[R:]
+        output_right_context = src[:R]
+        return (output_utterance, output_right_context, attn_cache + [conv_cache])
+
+
+def _gen_attention_mask_block(
+    col_widths: List[int],
+    col_mask: List[bool],
+    num_rows: int,
+    device: torch.device,
+) -> torch.Tensor:
+    assert len(col_widths) == len(
+        col_mask
+    ), "Length of col_widths must match that of col_mask"
+
+    mask_block = [
+        torch.ones(num_rows, col_width, device=device)
+        if is_ones_col
+        else torch.zeros(num_rows, col_width, device=device)
+        for col_width, is_ones_col in zip(col_widths, col_mask)
+    ]
+    return torch.cat(mask_block, dim=1)
+
+
+class EmformerEncoder(nn.Module):
+    """Implements the Emformer architecture introduced in
+    *Emformer: Efficient Memory Transformer Based Acoustic Model for Low Latency
+    Streaming Speech Recognition*
+    [:footcite:`shi2021emformer`].
+
+    In this model, the memory bank computation is simplifed, using the averaged
+    value of each chunk as its memory vector.
+
+    Args:
+      d_model (int):
+        Input dimension.
+      nhead (int):
+        Number of attention heads in each emformer layer.
+      dim_feedforward (int):
+        Hidden layer dimension of each emformer layer's feedforward network.
+      num_encoder_layers (int):
+        Number of emformer layers to instantiate.
+      chunk_length (int):
+        Length of each input segment.
+      dropout (float, optional):
+        Dropout probability. (default: 0.0)
+      layer_dropout (float, optional):
+        Layer dropout probability. (default: 0.0)
+      cnn_module_kernel (int):
+        Kernel size of convolution module.
+      left_context_length (int, optional):
+        Length of left context. (default: 0)
+      right_context_length (int, optional):
+        Length of right context. (default: 0)
+      memory_size (int, optional):
+        Number of memory elements to use. (default: 0)
+      tanh_on_mem (bool, optional):
+        If ``true``, applies tanh to memory elements. (default: ``false``)
+      negative_inf (float, optional):
+        Value to use for negative infinity in attention weights. (default: -1e8)
+    """
+
+    def __init__(
+        self,
+        chunk_length: int,
+        d_model: int = 256,
+        nhead: int = 4,
+        dim_feedforward: int = 2048,
+        num_encoder_layers: int = 12,
+        dropout: float = 0.1,
+        layer_dropout: float = 0.075,
+        cnn_module_kernel: int = 31,
+        left_context_length: int = 0,
+        right_context_length: int = 0,
+        memory_size: int = 0,
+        tanh_on_mem: bool = False,
+        negative_inf: float = -1e8,
+    ):
+        super().__init__()
+
+        assert (
+            chunk_length - 1
+        ) & chunk_length == 0, "chunk_length should be a power of 2."
+        self.shift = int(math.log(chunk_length, 2))
+
+        self.use_memory = memory_size > 0
+
+        self.emformer_layers = nn.ModuleList(
+            [
+                EmformerEncoderLayer(
+                    d_model=d_model,
+                    nhead=nhead,
+                    dim_feedforward=dim_feedforward,
+                    chunk_length=chunk_length,
+                    dropout=dropout,
+                    layer_dropout=layer_dropout,
+                    cnn_module_kernel=cnn_module_kernel,
+                    left_context_length=left_context_length,
+                    right_context_length=right_context_length,
+                    memory_size=memory_size,
+                    tanh_on_mem=tanh_on_mem,
+                    negative_inf=negative_inf,
+                )
+                for layer_idx in range(num_encoder_layers)
+            ]
+        )
+
+        self.num_encoder_layers = num_encoder_layers
+        self.d_model = d_model
+        self.left_context_length = left_context_length
+        self.right_context_length = right_context_length
+        self.chunk_length = chunk_length
+        self.memory_size = memory_size
+        self.cnn_module_kernel = cnn_module_kernel
+
+    def _gen_right_context(self, x: torch.Tensor) -> torch.Tensor:
+        """Hard copy each chunk's right context and concat them."""
+        T = x.shape[0]
+        num_chunks = math.ceil((T - self.right_context_length) / self.chunk_length)
+        # first (num_chunks - 1) right context block
+        intervals = torch.arange(
+            0, self.chunk_length * (num_chunks - 1), self.chunk_length
+        )
+        first = torch.arange(
+            self.chunk_length, self.chunk_length + self.right_context_length
+        )
+        indexes = intervals.unsqueeze(1) + first.unsqueeze(0)
+        # cat last right context block
+        indexes = torch.cat(
+            [
+                indexes,
+                torch.arange(T - self.right_context_length, T).unsqueeze(0),
+            ]
+        )
+        right_context_blocks = x[indexes.reshape(-1)]
+        return right_context_blocks
+
+    def _gen_attention_mask_col_widths(self, chunk_idx: int, U: int) -> List[int]:
+        """Calculate column widths (key, value) in attention mask for the
+        chunk_idx chunk."""
+        num_chunks = math.ceil(U / self.chunk_length)
+        rc = self.right_context_length
+        lc = self.left_context_length
+        rc_start = chunk_idx * rc
+        rc_end = rc_start + rc
+        chunk_start = max(chunk_idx * self.chunk_length - lc, 0)
+        chunk_end = min((chunk_idx + 1) * self.chunk_length, U)
+        R = rc * num_chunks
+
+        if self.use_memory:
+            m_start = max(chunk_idx - self.memory_size, 0)
+            M = num_chunks - 1
+            col_widths = [
+                m_start,  # before memory
+                chunk_idx - m_start,  # memory
+                M - chunk_idx,  # after memory
+                rc_start,  # before right context
+                rc,  # right context
+                R - rc_end,  # after right context
+                chunk_start,  # before chunk
+                chunk_end - chunk_start,  # chunk
+                U - chunk_end,  # after chunk
+            ]
+        else:
+            col_widths = [
+                rc_start,  # before right context
+                rc,  # right context
+                R - rc_end,  # after right context
+                chunk_start,  # before chunk
+                chunk_end - chunk_start,  # chunk
+                U - chunk_end,  # after chunk
+            ]
+
+        return col_widths
+
+    def _gen_attention_mask(self, utterance: torch.Tensor) -> torch.Tensor:
+        """Generate attention mask to simulate underlying chunk-wise attention
+        computation, where chunk-wise connections are filled with `False`,
+        and other unnecessary connections beyond chunk are filled with `True`.
+
+        R: length of hard-copied right contexts;
+        U: length of full utterance;
+        M: length of memory vectors;
+        Q: length of attention query;
+        KV: length of attention key and value.
+
+        The shape of attention mask is (Q, KV).
+        If self.use_memory is `True`:
+          query = [right_context, utterance];
+          key, value = [memory, right_context, utterance];
+          Q = R + U, KV = M + R + U.
+        Otherwise:
+          query = [right_context, utterance]
+          key, value = [right_context, utterance]
+          Q = R + U, KV = R + U.
+
+        Suppose:
+          c_i: chunk at index i;
+          r_i: right context that c_i can use;
+          l_i: left context that c_i can use;
+          m_i: past memory vectors from previous layer that c_i can use;
+        The target chunk-wise attention is:
+          c_i, r_i (in query) -> l_i, c_i, r_i, m_i (in key).
+        """
+        U = utterance.size(0)
+        num_chunks = math.ceil(U / self.chunk_length)
+
+        right_context_mask = []
+        utterance_mask = []
+
+        if self.use_memory:
+            num_cols = 9
+            # right context and utterance both attend to memory, right context,
+            # utterance
+            right_context_utterance_cols_mask = [
+                idx in [1, 4, 7] for idx in range(num_cols)
+            ]
+        else:
+            num_cols = 6
+            # right context and utterance both attend to right context and
+            # utterance
+            right_context_utterance_cols_mask = [
+                idx in [1, 4] for idx in range(num_cols)
+            ]
+        masks_to_concat = [right_context_mask, utterance_mask]
+
+        for chunk_idx in range(num_chunks):
+            col_widths = self._gen_attention_mask_col_widths(chunk_idx, U)
+
+            right_context_mask_block = _gen_attention_mask_block(
+                col_widths,
+                right_context_utterance_cols_mask,
+                self.right_context_length,
+                utterance.device,
+            )
+            right_context_mask.append(right_context_mask_block)
+
+            utterance_mask_block = _gen_attention_mask_block(
+                col_widths,
+                right_context_utterance_cols_mask,
+                min(
+                    self.chunk_length,
+                    U - chunk_idx * self.chunk_length,
+                ),
+                utterance.device,
+            )
+            utterance_mask.append(utterance_mask_block)
+
+        attention_mask = (
+            1 - torch.cat([torch.cat(mask) for mask in masks_to_concat])
+        ).to(torch.bool)
+        return attention_mask
+
+    def _forward(
+        self, x: torch.Tensor, lengths: torch.Tensor, warmup: float = 1.0
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Forward pass for training and validation mode.
+
+        B: batch size;
+        D: input dimension;
+        U: length of utterance.
+
+        Args:
+          x (torch.Tensor):
+            Utterance frames right-padded with right context frames,
+            with shape (U + right_context_length, B, D).
+          lengths (torch.Tensor):
+            With shape (B,) and i-th element representing number of valid
+            utterance frames for i-th batch element in x, which contains the
+            right_context at the end.
+
+        Returns:
+          A tuple of 2 tensors:
+            - output utterance frames, with shape (U, B, D).
+            - output_lengths, with shape (B,), without containing the
+              right_context at the end.
+        """
+        U = x.size(0) - self.right_context_length
+
+        right_context = self._gen_right_context(x)
+        utterance = x[:U]
+        output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
+        attention_mask = self._gen_attention_mask(utterance)
+
+        M = (
+            right_context.size(0) // self.right_context_length - 1
+            if self.use_memory
+            else 0
+        )
+        padding_mask = make_pad_mask(M + right_context.size(0) + output_lengths)
+
+        output = utterance
+        for layer in self.emformer_layers:
+            output, right_context = layer(
+                output,
+                right_context,
+                attention_mask,
+                padding_mask=padding_mask,
+                warmup=warmup,
+            )
+
+        return output, output_lengths
+
+    @torch.jit.export
+    def infer(
+        self,
+        x: torch.Tensor,
+        states: List[torch.Tensor],
+    ) -> Tuple[torch.Tensor, List[torch.Tensor],]:
+        """Forward pass for streaming inference.
+
+        B: batch size;
+        D: input dimension;
+        U: length of utterance.
+
+        Args:
+          x (torch.Tensor):
+            Utterance frames right-padded with right context frames,
+            with shape (U + right_context_length, B, D).
+          lengths (torch.Tensor):
+            With shape (B,) and i-th element representing number of valid
+            utterance frames for i-th batch element in x, which contains the
+            right_context at the end.
+          states (List[torch.Tensor, List[List[torch.Tensor]], List[torch.Tensor]]: # noqa
+            Cached states containing:
+            - attn_caches: attention states from preceding chunk's computation,
+              where each element corresponds to each emformer layer
+            - conv_caches: left context for causal convolution, where each
+              element corresponds to each layer.
+
+        Returns:
+          (Tensor, Tensor, List[List[torch.Tensor]], List[torch.Tensor]):
+            - output utterance frames, with shape (U, B, D).
+            - output lengths, with shape (B,), without containing the
+              right_context at the end.
+            - updated states from current chunk's computation.
+        """
+        # lengths = chunk_length + right_context_length
+        utterance = x[: self.chunk_length]
+        right_context = x[self.chunk_length :]
+        #  right_context_utterance = torch.cat([right_context, utterance])
+
+        output = utterance
+        output_states: List[torch.Tensor] = []
+        for layer_idx, layer in enumerate(self.emformer_layers):
+            start = layer_idx * 4
+            end = start + 4
+            cache = states[start:end]
+
+            (output, right_context, output_cache,) = layer.infer(
+                output,
+                right_context,
+                padding_mask=None,
+                cache=cache,
+            )
+            output_states.extend(output_cache)
+
+        return output, output_states
+
+    @torch.jit.export
+    def init_states(
+        self, device: torch.device = torch.device("cpu")
+    ) -> List[torch.Tensor]:
+        """Create initial states."""
+        #
+        states = []
+        # layer0: attn cache, conv cache, 3 tensors + 1 tensor
+        # layer1: attn cache, conv cache, 3 tensors +  1 tensor
+        # layer2: attn cache, conv cache, 3 tensors + 1 tensor
+        # ...
+        # last layer: attn cache, conv cache, 3 tensors + 1 tensor
+        for i in range(self.num_encoder_layers):
+            states.append(torch.zeros(self.memory_size, 1, self.d_model, device=device))
+            states.append(
+                torch.zeros(self.left_context_length, 1, self.d_model, device=device)
+            )
+            states.append(
+                torch.zeros(self.left_context_length, 1, self.d_model, device=device)
+            )
+
+            states.append(
+                torch.zeros(1, self.d_model, self.cnn_module_kernel - 1, device=device)
+            )
+        return states
+
+        attn_caches = [
+            [
+                torch.zeros(self.memory_size, self.d_model, device=device),
+                torch.zeros(self.left_context_length, self.d_model, device=device),
+                torch.zeros(self.left_context_length, self.d_model, device=device),
+            ]
+            for _ in range(self.num_encoder_layers)
+        ]
+        conv_caches = [
+            torch.zeros(self.d_model, self.cnn_module_kernel - 1, device=device)
+            for _ in range(self.num_encoder_layers)
+        ]
+        states: Tuple[List[List[torch.Tensor]], List[torch.Tensor]] = (
+            attn_caches,
+            conv_caches,
+        )
+        return states
+
+
+class Emformer(EncoderInterface):
+    def __init__(
+        self,
+        num_features: int,
+        chunk_length: int,
+        subsampling_factor: int = 4,
+        d_model: int = 256,
+        nhead: int = 4,
+        dim_feedforward: int = 2048,
+        num_encoder_layers: int = 12,
+        dropout: float = 0.1,
+        layer_dropout: float = 0.075,
+        cnn_module_kernel: int = 3,
+        left_context_length: int = 0,
+        right_context_length: int = 0,
+        memory_size: int = 0,
+        tanh_on_mem: bool = False,
+        negative_inf: float = -1e8,
+        is_pnnx: bool = True,
+    ):
+        super().__init__()
+
+        self.subsampling_factor = subsampling_factor
+        self.right_context_length = right_context_length
+        self.chunk_length = chunk_length
+        if subsampling_factor != 4:
+            raise NotImplementedError("Support only 'subsampling_factor=4'.")
+        if chunk_length % subsampling_factor != 0:
+            raise NotImplementedError(
+                "chunk_length must be a mutiple of subsampling_factor."
+            )
+        if left_context_length != 0 and left_context_length % subsampling_factor != 0:
+            raise NotImplementedError(
+                "left_context_length must be 0 or a mutiple of subsampling_factor."  # noqa
+            )
+        if right_context_length != 0 and right_context_length % subsampling_factor != 0:
+            raise NotImplementedError(
+                "right_context_length must be 0 or a mutiple of subsampling_factor."  # noqa
+            )
+
+        # self.encoder_embed converts the input of shape (N, T, num_features)
+        # to the shape (N, T//subsampling_factor, d_model).
+        # That is, it does two things simultaneously:
+        #   (1) subsampling: T -> T//subsampling_factor
+        #   (2) embedding: num_features -> d_model
+        self.encoder_embed = Conv2dSubsampling(num_features, d_model, is_pnnx=is_pnnx)
+        self.is_pnnx = is_pnnx
+
+        self.encoder = EmformerEncoder(
+            chunk_length=chunk_length // subsampling_factor,
+            d_model=d_model,
+            nhead=nhead,
+            dim_feedforward=dim_feedforward,
+            num_encoder_layers=num_encoder_layers,
+            dropout=dropout,
+            layer_dropout=layer_dropout,
+            cnn_module_kernel=cnn_module_kernel,
+            left_context_length=left_context_length // subsampling_factor,
+            right_context_length=right_context_length // subsampling_factor,
+            memory_size=memory_size,
+            tanh_on_mem=tanh_on_mem,
+            negative_inf=negative_inf,
+        )
+
+    def _forward(
+        self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Forward pass for training and non-streaming inference.
+
+        B: batch size;
+        D: feature dimension;
+        T: length of utterance.
+
+        Args:
+          x (torch.Tensor):
+            Utterance frames right-padded with right context frames,
+            with shape (B, T, D).
+          x_lens (torch.Tensor):
+            With shape (B,) and i-th element representing number of valid
+            utterance frames for i-th batch element in x, containing the
+            right_context at the end.
+          warmup:
+            A floating point value that gradually increases from 0 throughout
+            training; when it is >= 1.0 we are "fully warmed up".  It is used
+            to turn modules on sequentially.
+
+        Returns:
+          (Tensor, Tensor):
+            - output embedding, with shape (B, T', D), where
+              T' = ((T - 1) // 2 - 1) // 2 - self.right_context_length // 4.
+            - output lengths, with shape (B,), without containing the
+              right_context at the end.
+        """
+        x = self.encoder_embed(x)
+        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
+
+        x_lens = (((x_lens - 1) >> 1) - 1) >> 1
+        assert x.size(0) == x_lens.max().item()
+
+        output, output_lengths = self.encoder(x, x_lens, warmup=warmup)  # (T, N, C)
+
+        output = output.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
+
+        return output, output_lengths
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        states: List[torch.Tensor],
+    ) -> Tuple[torch.Tensor, List[torch.Tensor],]:
+        """Forward pass for streaming inference.
+
+        B: batch size;
+        D: feature dimension;
+        T: length of utterance.
+
+        Args:
+          x (torch.Tensor):
+            Utterance frames right-padded with right context frames,
+            with shape (B, T, D).
+          lengths (torch.Tensor):
+            With shape (B,) and i-th element representing number of valid
+            utterance frames for i-th batch element in x, containing the
+            right_context at the end.
+          states (List[torch.Tensor, List[List[torch.Tensor]], List[torch.Tensor]]: # noqa
+            Cached states containing:
+            - past_lens: number of past frames for each sample in batch
+            - attn_caches: attention states from preceding chunk's computation,
+              where each element corresponds to each emformer layer
+            - conv_caches: left context for causal convolution, where each
+              element corresponds to each layer.
+        Returns:
+          (Tensor, Tensor):
+            - output embedding, with shape (B, T', D), where
+              T' = ((T - 1) // 2 - 1) // 2 - self.right_context_length // 4.
+            - output lengths, with shape (B,), without containing the
+              right_context at the end.
+            - updated states from current chunk's computation.
+        """
+        x = self.encoder_embed(x)
+        # drop the first and last frames
+        x = x[:, 1:-1, :]
+        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
+
+        # Caution: We assume the subsampling factor is 4!
+
+        output, output_states = self.encoder.infer(x, states)
+
+        output = output.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
+
+        return output, output_states
+
+    @torch.jit.export
+    def init_states(
+        self, device: torch.device = torch.device("cpu")
+    ) -> List[torch.Tensor]:
+        """Create initial states."""
+        return self.encoder.init_states(device)
+
+
+class Conv2dSubsampling(nn.Module):
+    """Convolutional 2D subsampling (to 1/4 length).
+
+    Convert an input of shape (N, T, idim) to an output
+    with shape (N, T', odim), where
+    T' = ((T-1)//2 - 1)//2, which approximates T' == T//4
+
+    It is based on
+    https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py  # noqa
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        layer1_channels: int = 8,
+        layer2_channels: int = 32,
+        layer3_channels: int = 128,
+        is_pnnx: bool = False,
+    ) -> None:
+        """
+        Args:
+          in_channels:
+            Number of channels in. The input shape is (N, T, in_channels).
+            Caution: It requires: T >=7, in_channels >=7
+          out_channels
+            Output dim. The output shape is (N, ((T-1)//2 - 1)//2, out_channels)
+          layer1_channels:
+            Number of channels in layer1
+          layer1_channels:
+            Number of channels in layer2
+          is_pnnx:
+            True if we are converting the model to PNNX format.
+            False otherwise.
+        """
+        assert in_channels >= 7
+        super().__init__()
+
+        self.conv = nn.Sequential(
+            ScaledConv2d(
+                in_channels=1,
+                out_channels=layer1_channels,
+                kernel_size=3,
+                padding=1,
+            ),
+            ActivationBalancer(channel_dim=1),
+            DoubleSwish(),
+            ScaledConv2d(
+                in_channels=layer1_channels,
+                out_channels=layer2_channels,
+                kernel_size=3,
+                stride=2,
+            ),
+            ActivationBalancer(channel_dim=1),
+            DoubleSwish(),
+            ScaledConv2d(
+                in_channels=layer2_channels,
+                out_channels=layer3_channels,
+                kernel_size=3,
+                stride=2,
+            ),
+            ActivationBalancer(channel_dim=1),
+            DoubleSwish(),
+        )
+        self.out = ScaledLinear(
+            layer3_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels
+        )
+        # set learn_eps=False because out_norm is preceded by `out`, and `out`
+        # itself has learned scale, so the extra degree of freedom is not
+        # needed.
+        self.out_norm = BasicNorm(out_channels, learn_eps=False)
+        # constrain median of output to be close to zero.
+        self.out_balancer = ActivationBalancer(
+            channel_dim=-1, min_positive=0.45, max_positive=0.55
+        )
+
+        # ncnn supports only batch size == 1
+        self.is_pnnx = is_pnnx
+        self.conv_out_dim = self.out.weight.shape[1]
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Subsample x.
+
+        Args:
+          x:
+            Its shape is (N, T, idim).
+
+        Returns:
+          Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
+        """
+        # On entry, x is (N, T, idim)
+        x = x.unsqueeze(1)  # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
+        x = self.conv(x)
+
+        if torch.jit.is_tracing() and self.is_pnnx:
+            x = x.permute(0, 2, 1, 3).reshape(1, -1, self.conv_out_dim)
+            x = self.out(x)
+        else:
+            # Now x is of shape (N, odim, ((T-1)//2-1)//2, ((idim-1)//2-1)//2)
+            b, c, t, f = x.size()
+            x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
+        # Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
+        x = self.out_norm(x)
+        x = self.out_balancer(x)
+        return x
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py
new file mode 100755
index 000000000..716de5734
--- /dev/null
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py
@@ -0,0 +1,335 @@
+#!/usr/bin/env python3
+
+"""
+Usage:
+./conv_emformer_transducer_stateless2/export-for-ncnn.py \
+  --exp-dir ./conv_emformer_transducer_stateless2/exp \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --epoch 30 \
+  --avg 10 \
+  --use-averaged-model=True \
+  --num-encoder-layers 12 \
+  --chunk-length 32 \
+  --cnn-module-kernel 31 \
+  --left-context-length 32 \
+  --right-context-length 8 \
+  --memory-size 32 \
+
+cd ./conv_emformer_transducer_stateless2/exp
+pnnx encoder_jit_trace-pnnx.pt
+pnnx decoder_jit_trace-pnnx.pt
+pnnx joiner_jit_trace-pnnx.pt
+
+You can find converted models at
+https://huggingface.co/csukuangfj/sherpa-ncnn-conv-emformer-transducer-2022-12-04
+
+See ./streaming-ncnn-decode.py
+and
+https://github.com/k2-fsa/sherpa-ncnn
+for usage.
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+import sentencepiece as spm
+import torch
+from scaling_converter import convert_scaled_to_non_scaled
+from train2 import add_model_arguments, get_params, get_transducer_model
+
+from icefall.checkpoint import (
+    average_checkpoints,
+    average_checkpoints_with_averaged_model,
+    find_checkpoints,
+    load_checkpoint,
+)
+from icefall.utils import str2bool
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=28,
+        help="""It specifies the checkpoint to use for averaging.
+        Note: Epoch counts from 0.
+        You can specify --avg to use more checkpoints for model averaging.""",
+    )
+
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=15,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="pruned_transducer_stateless2/exp",
+        help="""It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="data/lang_bpe_500/bpe.model",
+        help="Path to the BPE model",
+    )
+
+    parser.add_argument(
+        "--jit",
+        type=str2bool,
+        default=False,
+        help="""True to save a model after applying torch.jit.script.
+        """,
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+    )
+
+    parser.add_argument(
+        "--use-averaged-model",
+        type=str2bool,
+        default=True,
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def export_encoder_model_jit_trace(
+    encoder_model: torch.nn.Module,
+    encoder_filename: str,
+) -> None:
+    """Export the given encoder model with torch.jit.trace()
+
+    Note: The warmup argument is fixed to 1.
+
+    Args:
+      encoder_model:
+        The input encoder model
+      encoder_filename:
+        The filename to save the exported model.
+    """
+    chunk_length = encoder_model.chunk_length  # before subsampling
+    right_context_length = encoder_model.right_context_length  # before subsampling
+    pad_length = right_context_length + 2 * 4 + 3
+    s = f"chunk_length: {chunk_length}, "
+    s += f"right_context_length: {right_context_length}\n"
+    logging.info(s)
+
+    T = chunk_length + pad_length
+
+    x = torch.zeros(1, T, 80, dtype=torch.float32)
+    states = encoder_model.init_states()
+    states = encoder_model.init_states()
+
+    traced_model = torch.jit.trace(encoder_model, (x, states))
+    traced_model.save(encoder_filename)
+    logging.info(f"Saved to {encoder_filename}")
+
+
+def export_decoder_model_jit_trace(
+    decoder_model: torch.nn.Module,
+    decoder_filename: str,
+) -> None:
+    """Export the given decoder model with torch.jit.trace()
+
+    Note: The argument need_pad is fixed to False.
+
+    Args:
+      decoder_model:
+        The input decoder model
+      decoder_filename:
+        The filename to save the exported model.
+    """
+    y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
+    need_pad = torch.tensor([False])
+
+    traced_model = torch.jit.trace(decoder_model, (y, need_pad))
+    traced_model.save(decoder_filename)
+    logging.info(f"Saved to {decoder_filename}")
+
+
+def export_joiner_model_jit_trace(
+    joiner_model: torch.nn.Module,
+    joiner_filename: str,
+) -> None:
+    """Export the given joiner model with torch.jit.trace()
+
+    Note: The argument project_input is fixed to True. A user should not
+    project the encoder_out/decoder_out by himself/herself. The exported joiner
+    will do that for the user.
+
+    Args:
+      joiner_model:
+        The input joiner model
+      joiner_filename:
+        The filename to save the exported model.
+
+    """
+    encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
+    decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
+    encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
+    decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
+
+    traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out))
+    traced_model.save(joiner_filename)
+    logging.info(f"Saved to {joiner_filename}")
+
+
+@torch.no_grad()
+def main():
+    args = get_parser().parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    device = torch.device("cpu")
+
+    logging.info(f"device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    #  is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+
+    model.to("cpu")
+    model.eval()
+
+    convert_scaled_to_non_scaled(model, inplace=True)
+    logging.info("Using torch.jit.trace()")
+
+    logging.info("Exporting encoder")
+    encoder_filename = params.exp_dir / "encoder_jit_trace-pnnx.pt"
+    export_encoder_model_jit_trace(model.encoder, encoder_filename)
+
+    logging.info("Exporting decoder")
+    decoder_filename = params.exp_dir / "decoder_jit_trace-pnnx.pt"
+    export_decoder_model_jit_trace(model.decoder, decoder_filename)
+
+    logging.info("Exporting joiner")
+    joiner_filename = params.exp_dir / "joiner_jit_trace-pnnx.pt"
+    export_joiner_model_jit_trace(model.joiner, joiner_filename)
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/jit_pretrained.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/jit_pretrained.py
new file mode 100755
index 000000000..1fe358c79
--- /dev/null
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/jit_pretrained.py
@@ -0,0 +1,292 @@
+#!/usr/bin/env python3
+# flake8: noqa
+# Copyright      2022  Xiaomi Corp.        (authors: Fangjun Kuang, Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script loads torchscript models exported by `torch.jit.trace()`
+and uses them to decode waves.
+You can use the following command to get the exported models:
+
+./conv_emformer_transducer_stateless2/export-for-ncnn.py \
+  --exp-dir ./conv_emformer_transducer_stateless2/exp \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --epoch 20 \
+  --avg 10
+
+Usage of this script:
+
+./conv_emformer_transducer_stateless2/jit_pretrained.py \
+  --encoder-model-filename ./conv_emformer_transducer_stateless2/exp/encoder_jit_trace-pnnx.pt \
+  --decoder-model-filename ./conv_emformer_transducer_stateless2/exp/decoder_jit_trace-pnnx.pt \
+  --joiner-model-filename ./conv_emformer_transducer_stateless2/exp/joiner_jit_trace-pnnx.pt \
+  --bpe-model ./data/lang_bpe_500/bpe.model \
+  /path/to/foo.wav \
+"""
+
+import argparse
+import logging
+import math
+from typing import List
+
+import kaldifeat
+import sentencepiece as spm
+import torch
+import torchaudio
+from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
+from torch.nn.utils.rnn import pad_sequence
+from typing import Optional, List
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--encoder-model-filename",
+        type=str,
+        required=True,
+        help="Path to the encoder torchscript model. ",
+    )
+
+    parser.add_argument(
+        "--decoder-model-filename",
+        type=str,
+        required=True,
+        help="Path to the decoder torchscript model. ",
+    )
+
+    parser.add_argument(
+        "--joiner-model-filename",
+        type=str,
+        required=True,
+        help="Path to the joiner torchscript model. ",
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        help="""Path to bpe.model.""",
+    )
+
+    parser.add_argument(
+        "sound_file",
+        type=str,
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
+    )
+
+    parser.add_argument(
+        "--sample-rate",
+        type=int,
+        default=16000,
+        help="The sample rate of the input sound file",
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="Context size of the decoder model",
+    )
+
+    return parser
+
+
+def read_sound_files(
+    filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+    """Read a list of sound files into a list 1-D float32 torch tensors.
+    Args:
+      filenames:
+        A list of sound filenames.
+      expected_sample_rate:
+        The expected sample rate of the sound files.
+    Returns:
+      Return a list of 1-D float32 torch tensors.
+    """
+    ans = []
+    for f in filenames:
+        wave, sample_rate = torchaudio.load(f)
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        # We use only the first channel
+        ans.append(wave[0])
+    return ans
+
+
+def greedy_search(
+    decoder: torch.jit.ScriptModule,
+    joiner: torch.jit.ScriptModule,
+    encoder_out: torch.Tensor,
+    decoder_out: Optional[torch.Tensor] = None,
+    hyp: Optional[List[int]] = None,
+):
+    assert encoder_out.ndim == 2
+    context_size = 2
+    blank_id = 0
+
+    if decoder_out is None:
+        assert hyp is None, hyp
+        hyp = [blank_id] * context_size
+        decoder_input = torch.tensor(hyp, dtype=torch.int32).unsqueeze(0)
+        # decoder_input.shape (1,, 1 context_size)
+        decoder_out = decoder(decoder_input, torch.tensor([0])).squeeze(1)
+    else:
+        assert decoder_out.ndim == 2
+        assert hyp is not None, hyp
+
+    T = encoder_out.size(0)
+    for i in range(T):
+        cur_encoder_out = encoder_out[i : i + 1]
+        joiner_out = joiner(cur_encoder_out, decoder_out).squeeze(0)
+        y = joiner_out.argmax(dim=0).item()
+
+        if y != blank_id:
+            hyp.append(y)
+            decoder_input = hyp[-context_size:]
+
+            decoder_input = torch.tensor(decoder_input, dtype=torch.int32).unsqueeze(0)
+            decoder_out = decoder(decoder_input, torch.tensor([0])).squeeze(1)
+
+    return hyp, decoder_out
+
+
+def create_streaming_feature_extractor(sample_rate) -> OnlineFeature:
+    """Create a CPU streaming feature extractor.
+
+    At present, we assume it returns a fbank feature extractor with
+    fixed options. In the future, we will support passing in the options
+    from outside.
+
+    Returns:
+      Return a CPU streaming feature extractor.
+    """
+    opts = FbankOptions()
+    opts.device = "cpu"
+    opts.frame_opts.dither = 0
+    opts.frame_opts.snip_edges = False
+    opts.frame_opts.samp_freq = sample_rate
+    opts.mel_opts.num_bins = 80
+    return OnlineFbank(opts)
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+    logging.info(vars(args))
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    encoder = torch.jit.load(args.encoder_model_filename)
+    decoder = torch.jit.load(args.decoder_model_filename)
+    joiner = torch.jit.load(args.joiner_model_filename)
+
+    encoder.eval()
+    decoder.eval()
+    joiner.eval()
+
+    encoder.to(device)
+    decoder.to(device)
+    joiner.to(device)
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(args.bpe_model)
+
+    logging.info("Constructing Fbank computer")
+    online_fbank = create_streaming_feature_extractor(args.sample_rate)
+
+    logging.info(f"Reading sound files: {args.sound_file}")
+    wave_samples = read_sound_files(
+        filenames=[args.sound_file],
+        expected_sample_rate=args.sample_rate,
+    )[0]
+    logging.info(wave_samples.shape)
+
+    logging.info("Decoding started")
+    chunk_length = encoder.chunk_length
+    right_context_length = encoder.right_context_length
+
+    # Assume the subsampling factor is 4
+    pad_length = right_context_length + 2 * 4 + 3
+    T = chunk_length + pad_length
+
+    logging.info(f"chunk_length: {chunk_length}")
+    logging.info(f"right_context_length: {right_context_length}")
+
+    states = encoder.init_states(device)
+    logging.info(f"num layers: {len(states)//4}")
+
+    tail_padding = torch.zeros(int(0.3 * args.sample_rate), dtype=torch.float32)
+
+    wave_samples = torch.cat([wave_samples, tail_padding])
+
+    chunk = int(0.25 * args.sample_rate)  # 0.2 second
+    num_processed_frames = 0
+
+    hyp = None
+    decoder_out = None
+
+    start = 0
+    while start < wave_samples.numel():
+        logging.info(f"{start}/{wave_samples.numel()}")
+        end = min(start + chunk, wave_samples.numel())
+        samples = wave_samples[start:end]
+        start += chunk
+        online_fbank.accept_waveform(
+            sampling_rate=args.sample_rate,
+            waveform=samples,
+        )
+        while online_fbank.num_frames_ready - num_processed_frames >= T:
+            frames = []
+            for i in range(T):
+                frames.append(online_fbank.get_frame(num_processed_frames + i))
+            num_processed_frames += chunk_length
+            frames = torch.cat(frames, dim=0).unsqueeze(0)
+            # TODO(fangjun): remove x_lens
+            x_lens = torch.tensor([T])
+            encoder_out, _, states = encoder(frames, x_lens, states)
+
+            hyp, decoder_out = greedy_search(
+                decoder, joiner, encoder_out.squeeze(0), decoder_out, hyp
+            )
+
+    context_size = 2
+
+    logging.info(args.sound_file)
+    logging.info(sp.decode(hyp[context_size:]))
+
+    logging.info("Decoding Done")
+
+
+torch.set_num_threads(4)
+torch.set_num_interop_threads(1)
+torch._C._jit_set_profiling_executor(False)
+torch._C._jit_set_profiling_mode(False)
+torch._C._set_graph_executor_optimize(False)
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/lstmp.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/lstmp.py
new file mode 120000
index 000000000..4f377cd01
--- /dev/null
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/lstmp.py
@@ -0,0 +1 @@
+../lstm_transducer_stateless2/lstmp.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/scaling_converter.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/scaling_converter.py
new file mode 120000
index 000000000..3b667058d
--- /dev/null
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/scaling_converter.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless3/scaling_converter.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming-ncnn-decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming-ncnn-decode.py
new file mode 100755
index 000000000..b21fe5c7e
--- /dev/null
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming-ncnn-decode.py
@@ -0,0 +1,387 @@
+#!/usr/bin/env python3
+#
+# Copyright      2022  Xiaomi Corp.        (authors: Fangjun Kuang, Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+./conv_emformer_transducer_stateless2/streaming-ncnn-decode.py \
+  --tokens ./sherpa-ncnn-conv-emformer-transducer-2022-12-04/tokens.txt \
+  --encoder-param-filename ./sherpa-ncnn-conv-emformer-transducer-2022-12-04/encoder_jit_trace-epoch-30-avg-10-pnnx.ncnn.param \
+  --encoder-bin-filename ./sherpa-ncnn-conv-emformer-transducer-2022-12-04/encoder_jit_trace-epoch-30-avg-10-pnnx.ncnn.bin \
+  --decoder-param-filename ./sherpa-ncnn-conv-emformer-transducer-2022-12-04/decoder_jit_trace-epoch-30-avg-10-pnnx.ncnn.param \
+  --decoder-bin-filename ./sherpa-ncnn-conv-emformer-transducer-2022-12-04/decoder_jit_trace-epoch-30-avg-10-pnnx.ncnn.bin \
+  --joiner-param-filename ./sherpa-ncnn-conv-emformer-transducer-2022-12-04/joiner_jit_trace-epoch-30-avg-10-pnnx.ncnn.param \
+  --joiner-bin-filename ./sherpa-ncnn-conv-emformer-transducer-2022-12-04/joiner_jit_trace-epoch-30-avg-10-pnnx.ncnn.bin \
+  ./sherpa-ncnn-conv-emformer-transducer-2022-12-04/test_wavs/1089-134686-0001.wav
+
+You can find pretrained models at
+https://huggingface.co/csukuangfj/sherpa-ncnn-conv-emformer-transducer-2022-12-04
+"""
+
+import argparse
+import logging
+from typing import List, Optional, Tuple
+
+import k2
+import ncnn
+import torch
+import torchaudio
+from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--tokens",
+        type=str,
+        help="Path to tokens.txt",
+    )
+
+    parser.add_argument(
+        "--encoder-param-filename",
+        type=str,
+        help="Path to encoder.ncnn.param",
+    )
+
+    parser.add_argument(
+        "--encoder-bin-filename",
+        type=str,
+        help="Path to encoder.ncnn.bin",
+    )
+
+    parser.add_argument(
+        "--decoder-param-filename",
+        type=str,
+        help="Path to decoder.ncnn.param",
+    )
+
+    parser.add_argument(
+        "--decoder-bin-filename",
+        type=str,
+        help="Path to decoder.ncnn.bin",
+    )
+
+    parser.add_argument(
+        "--joiner-param-filename",
+        type=str,
+        help="Path to joiner.ncnn.param",
+    )
+
+    parser.add_argument(
+        "--joiner-bin-filename",
+        type=str,
+        help="Path to joiner.ncnn.bin",
+    )
+
+    parser.add_argument(
+        "sound_filename",
+        type=str,
+        help="Path to foo.wav",
+    )
+
+    return parser.parse_args()
+
+
+class Model:
+    def __init__(self, args):
+        self.init_encoder(args)
+        self.init_decoder(args)
+        self.init_joiner(args)
+
+        self.num_layers = 12
+        self.memory_size = 32
+        self.d_model = 512
+        self.cnn_module_kernel = 31
+
+        self.left_context_length = 32 // 4  # after subsampling
+        self.chunk_length = 32  # before subsampling
+        right_context_length = 8  # before subsampling
+        pad_length = right_context_length + 2 * 4 + 3
+        self.T = self.chunk_length + pad_length
+        print("T", self.T, self.chunk_length)
+
+    def get_init_states(self) -> List[torch.Tensor]:
+        states = []
+
+        for i in range(self.num_layers):
+            s0 = torch.zeros(self.memory_size, self.d_model)
+            s1 = torch.zeros(self.left_context_length, self.d_model)
+            s2 = torch.zeros(self.left_context_length, self.d_model)
+            s3 = torch.zeros(self.d_model, self.cnn_module_kernel - 1)
+            states.extend([s0, s1, s2, s3])
+
+        return states
+
+    def init_encoder(self, args):
+        encoder_net = ncnn.Net()
+        encoder_net.opt.use_packing_layout = False
+        encoder_net.opt.use_fp16_storage = False
+        encoder_param = args.encoder_param_filename
+        encoder_model = args.encoder_bin_filename
+
+        encoder_net.load_param(encoder_param)
+        encoder_net.load_model(encoder_model)
+
+        self.encoder_net = encoder_net
+
+    def init_decoder(self, args):
+        decoder_param = args.decoder_param_filename
+        decoder_model = args.decoder_bin_filename
+
+        decoder_net = ncnn.Net()
+
+        decoder_net.load_param(decoder_param)
+        decoder_net.load_model(decoder_model)
+
+        self.decoder_net = decoder_net
+
+    def init_joiner(self, args):
+        joiner_param = args.joiner_param_filename
+        joiner_model = args.joiner_bin_filename
+        joiner_net = ncnn.Net()
+        joiner_net.load_param(joiner_param)
+        joiner_net.load_model(joiner_model)
+
+        self.joiner_net = joiner_net
+
+    def run_encoder(
+        self,
+        x: torch.Tensor,
+        states: List[torch.Tensor],
+    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+        """
+        Args:
+          x:
+            A tensor of shape (T, C)
+          states:
+            A list of tensors. len(states) == self.num_layers * 4
+        Returns:
+          Return a tuple containing:
+           - encoder_out, a tensor of shape (T, encoder_dim).
+           - next_states, a list of tensors containing the next states
+        """
+        with self.encoder_net.create_extractor() as ex:
+            ex.set_num_threads(4)
+            ex.input("in0", ncnn.Mat(x.numpy()).clone())
+
+            # layer0 in2-in5
+            # layer1 in6-in9
+            for i in range(self.num_layers):
+                offset = 1 + i * 4
+                name = f"in{offset}"
+                # (32, 1, 512) -> (32, 512)
+                ex.input(name, ncnn.Mat(states[i * 4 + 0].numpy()).clone())
+
+                name = f"in{offset+1}"
+                #  (8, 1, 512) -> (8, 512)
+                ex.input(name, ncnn.Mat(states[i * 4 + 1].numpy()).clone())
+
+                name = f"in{offset+2}"
+                #  (8, 1, 512) -> (8, 512)
+                ex.input(name, ncnn.Mat(states[i * 4 + 2].numpy()).clone())
+
+                name = f"in{offset+3}"
+                #  (1, 512, 2) -> (512, 2)
+                ex.input(name, ncnn.Mat(states[i * 4 + 3].numpy()).clone())
+
+            import pdb
+
+            #  pdb.set_trace()
+            ret, ncnn_out0 = ex.extract("out0")
+            #  assert ret == 0, ret
+            encoder_out = torch.from_numpy(ncnn_out0.numpy()).clone()
+
+            out_states: List[torch.Tensor] = []
+            for i in range(4 * self.num_layers):
+                name = f"out{i+1}"
+                ret, ncnn_out_state = ex.extract(name)
+                assert ret == 0, ret
+                ncnn_out_state = torch.from_numpy(ncnn_out_state.numpy())
+                out_states.append(ncnn_out_state)
+
+            return encoder_out, out_states
+
+    def run_decoder(self, decoder_input):
+        assert decoder_input.dtype == torch.int32
+
+        with self.decoder_net.create_extractor() as ex:
+            ex.set_num_threads(4)
+            ex.input("in0", ncnn.Mat(decoder_input.numpy()).clone())
+            ret, ncnn_out0 = ex.extract("out0")
+            assert ret == 0, ret
+            decoder_out = torch.from_numpy(ncnn_out0.numpy()).clone()
+            return decoder_out
+
+    def run_joiner(self, encoder_out, decoder_out):
+        with self.joiner_net.create_extractor() as ex:
+            ex.set_num_threads(4)
+            ex.input("in0", ncnn.Mat(encoder_out.numpy()).clone())
+            ex.input("in1", ncnn.Mat(decoder_out.numpy()).clone())
+            ret, ncnn_out0 = ex.extract("out0")
+            assert ret == 0, ret
+            joiner_out = torch.from_numpy(ncnn_out0.numpy()).clone()
+            return joiner_out
+
+
+def read_sound_files(
+    filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+    """Read a list of sound files into a list 1-D float32 torch tensors.
+    Args:
+      filenames:
+        A list of sound filenames.
+      expected_sample_rate:
+        The expected sample rate of the sound files.
+    Returns:
+      Return a list of 1-D float32 torch tensors.
+    """
+    ans = []
+    for f in filenames:
+        wave, sample_rate = torchaudio.load(f)
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        # We use only the first channel
+        ans.append(wave[0])
+    return ans
+
+
+def create_streaming_feature_extractor() -> OnlineFeature:
+    """Create a CPU streaming feature extractor.
+
+    At present, we assume it returns a fbank feature extractor with
+    fixed options. In the future, we will support passing in the options
+    from outside.
+
+    Returns:
+      Return a CPU streaming feature extractor.
+    """
+    opts = FbankOptions()
+    opts.device = "cpu"
+    opts.frame_opts.dither = 0
+    opts.frame_opts.snip_edges = False
+    opts.frame_opts.samp_freq = 16000
+    opts.mel_opts.num_bins = 80
+    return OnlineFbank(opts)
+
+
+def greedy_search(
+    model: Model,
+    encoder_out: torch.Tensor,
+    decoder_out: Optional[torch.Tensor] = None,
+    hyp: Optional[List[int]] = None,
+):
+    context_size = 2
+    blank_id = 0
+
+    if decoder_out is None:
+        assert hyp is None, hyp
+        hyp = [blank_id] * context_size
+        decoder_input = torch.tensor(hyp, dtype=torch.int32)  # (1, context_size)
+        decoder_out = model.run_decoder(decoder_input).squeeze(0)
+    else:
+        assert decoder_out.ndim == 1
+        assert hyp is not None, hyp
+
+    T = encoder_out.size(0)
+    for t in range(T):
+        cur_encoder_out = encoder_out[t]
+
+        joiner_out = model.run_joiner(cur_encoder_out, decoder_out)
+        y = joiner_out.argmax(dim=0).item()
+        if y != blank_id:
+            hyp.append(y)
+            decoder_input = hyp[-context_size:]
+            decoder_input = torch.tensor(decoder_input, dtype=torch.int32)
+            decoder_out = model.run_decoder(decoder_input).squeeze(0)
+
+    return hyp, decoder_out
+
+
+def main():
+    args = get_args()
+    logging.info(vars(args))
+
+    model = Model(args)
+
+    sound_file = args.sound_filename
+
+    sample_rate = 16000
+
+    logging.info("Constructing Fbank computer")
+    online_fbank = create_streaming_feature_extractor()
+
+    logging.info(f"Reading sound files: {sound_file}")
+    wave_samples = read_sound_files(
+        filenames=[sound_file],
+        expected_sample_rate=sample_rate,
+    )[0]
+    logging.info(wave_samples.shape)
+
+    tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32)
+
+    wave_samples = torch.cat([wave_samples, tail_padding])
+
+    states = model.get_init_states()
+
+    hyp = None
+    decoder_out = None
+
+    num_processed_frames = 0
+    segment = model.T
+    offset = model.chunk_length
+
+    chunk = int(1 * sample_rate)  # 0.2 second
+
+    start = 0
+    while start < wave_samples.numel():
+        end = min(start + chunk, wave_samples.numel())
+        samples = wave_samples[start:end]
+        start += chunk
+
+        online_fbank.accept_waveform(
+            sampling_rate=sample_rate,
+            waveform=samples,
+        )
+        while online_fbank.num_frames_ready - num_processed_frames >= segment:
+            frames = []
+            for i in range(segment):
+                frames.append(online_fbank.get_frame(num_processed_frames + i))
+            num_processed_frames += offset
+            frames = torch.cat(frames, dim=0)
+            encoder_out, states = model.run_encoder(frames, states)
+            hyp, decoder_out = greedy_search(model, encoder_out, decoder_out, hyp)
+
+    symbol_table = k2.SymbolTable.from_file(args.tokens)
+
+    context_size = 2
+    text = ""
+    for i in hyp[context_size:]:
+        text += symbol_table[i]
+    text = text.replace("▁", " ").strip()
+
+    logging.info(sound_file)
+    logging.info(text)
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
+    main()
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train2.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train2.py
new file mode 100755
index 000000000..c91f94876
--- /dev/null
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train2.py
@@ -0,0 +1,1128 @@
+#!/usr/bin/env python3
+# Copyright    2021  Xiaomi Corp.        (authors: Fangjun Kuang,
+#                                                  Wei Kang,
+#                                                  Mingshuang Luo,)
+#                                                  Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./conv_emformer_transducer_stateless2/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --exp-dir conv_emformer_transducer_stateless2/exp \
+  --full-libri 1 \
+  --max-duration 280 \
+  --master-port 12321 \
+  --num-encoder-layers 12 \
+  --chunk-length 32 \
+  --cnn-module-kernel 31 \
+  --left-context-length 32 \
+  --right-context-length 8 \
+  --memory-size 32
+
+# For mix precision training:
+./conv_emformer_transducer_stateless2/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --use-fp16 1 \
+  --exp-dir conv_emformer_transducer_stateless2/exp \
+  --full-libri 1 \
+  --max-duration 300 \
+  --master-port 12321 \
+  --num-encoder-layers 12 \
+  --chunk-length 32 \
+  --cnn-module-kernel 31 \
+  --left-context-length 32 \
+  --right-context-length 8 \
+  --memory-size 32
+"""
+
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
+from decoder import Decoder
+from emformer2 import Emformer
+from joiner import Joiner
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import Transducer
+from optim import Eden, Eve
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+    save_checkpoint_with_global_batch_idx,
+    update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+    parser.add_argument(
+        "--encoder-dim",
+        type=int,
+        default=512,
+        help="Attention dim for the Emformer",
+    )
+
+    parser.add_argument(
+        "--nhead",
+        type=int,
+        default=8,
+        help="Number of attention heads for the Emformer",
+    )
+
+    parser.add_argument(
+        "--dim-feedforward",
+        type=int,
+        default=2048,
+        help="Feed-forward dimension for the Emformer",
+    )
+
+    parser.add_argument(
+        "--num-encoder-layers",
+        type=int,
+        default=12,
+        help="Number of encoder layers for the Emformer",
+    )
+
+    parser.add_argument(
+        "--cnn-module-kernel",
+        type=int,
+        default=31,
+        help="Kernel size for the convolution module.",
+    )
+
+    parser.add_argument(
+        "--left-context-length",
+        type=int,
+        default=32,
+        help="""Number of frames before subsampling for left context
+        in the Emformer.""",
+    )
+
+    parser.add_argument(
+        "--chunk-length",
+        type=int,
+        default=32,
+        help="""Number of frames before subsampling for each chunk
+        in the Emformer.""",
+    )
+
+    parser.add_argument(
+        "--right-context-length",
+        type=int,
+        default=8,
+        help="""Number of frames before subsampling for right context
+        in the Emformer.""",
+    )
+
+    parser.add_argument(
+        "--memory-size",
+        type=int,
+        default=0,
+        help="Number of entries in the memory for the Emformer",
+    )
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--world-size",
+        type=int,
+        default=1,
+        help="Number of GPUs for DDP training.",
+    )
+
+    parser.add_argument(
+        "--master-port",
+        type=int,
+        default=12354,
+        help="Master port to use for DDP training.",
+    )
+
+    parser.add_argument(
+        "--tensorboard",
+        type=str2bool,
+        default=True,
+        help="Should various information be logged in tensorboard.",
+    )
+
+    parser.add_argument(
+        "--num-epochs",
+        type=int,
+        default=30,
+        help="Number of epochs to train.",
+    )
+
+    parser.add_argument(
+        "--start-epoch",
+        type=int,
+        default=1,
+        help="""Resume training from this epoch. It should be positive.
+        If larger than 1, it will load checkpoint from
+        exp-dir/epoch-{start_epoch-1}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--start-batch",
+        type=int,
+        default=0,
+        help="""If positive, --start-epoch is ignored and
+        it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="pruned_transducer_stateless2/exp",
+        help="""The experiment dir.
+        It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="data/lang_bpe_500/bpe.model",
+        help="Path to the BPE model",
+    )
+
+    parser.add_argument(
+        "--initial-lr",
+        type=float,
+        default=0.003,
+        help="""The initial learning rate. This value should not need to be
+        changed.""",
+    )
+
+    parser.add_argument(
+        "--lr-batches",
+        type=float,
+        default=5000,
+        help="""Number of steps that affects how rapidly the learning rate decreases.
+        We suggest not to change this.""",
+    )
+
+    parser.add_argument(
+        "--lr-epochs",
+        type=float,
+        default=6,
+        help="""Number of epochs that affects how rapidly the learning rate decreases.
+        """,
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+    )
+
+    parser.add_argument(
+        "--prune-range",
+        type=int,
+        default=5,
+        help="The prune range for rnnt loss, it means how many symbols(context)"
+        "we are using to compute the loss",
+    )
+
+    parser.add_argument(
+        "--lm-scale",
+        type=float,
+        default=0.25,
+        help="The scale to smooth the loss with lm "
+        "(output of prediction network) part.",
+    )
+
+    parser.add_argument(
+        "--am-scale",
+        type=float,
+        default=0.0,
+        help="The scale to smooth the loss with am (output of encoder network) part.",
+    )
+
+    parser.add_argument(
+        "--simple-loss-scale",
+        type=float,
+        default=0.5,
+        help="To get pruning ranges, we will calculate a simple version"
+        "loss(joiner is just addition), this simple loss also uses for"
+        "training (as a regularization item). We will scale the simple loss"
+        "with this parameter before adding to the final loss.",
+    )
+
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="The seed for random generators intended for reproducibility",
+    )
+
+    parser.add_argument(
+        "--print-diagnostics",
+        type=str2bool,
+        default=False,
+        help="Accumulate stats on activations, print them and exit.",
+    )
+
+    parser.add_argument(
+        "--save-every-n",
+        type=int,
+        default=8000,
+        help="""Save checkpoint after processing this number of batches"
+        periodically. We save checkpoint to exp-dir/ whenever
+        params.batch_idx_train % save_every_n == 0. The checkpoint filename
+        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+        Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+        end of each epoch where `xxx` is the epoch number counting from 0.
+        """,
+    )
+
+    parser.add_argument(
+        "--keep-last-k",
+        type=int,
+        default=20,
+        help="""Only keep this number of checkpoints on disk.
+        For instance, if it is 3, there are only 3 checkpoints
+        in the exp-dir with filenames `checkpoint-xxx.pt`.
+        It does not affect checkpoints with name `epoch-xxx.pt`.
+        """,
+    )
+
+    parser.add_argument(
+        "--average-period",
+        type=int,
+        default=100,
+        help="""Update the averaged model, namely `model_avg`, after processing
+        this number of batches. `model_avg` is a separate version of model,
+        in which each floating-point parameter is the average of all the
+        parameters from the start of training. Each time we take the average,
+        we do: `model_avg = model * (average_period / batch_idx_train) +
+            model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+        """,
+    )
+
+    parser.add_argument(
+        "--use-fp16",
+        type=str2bool,
+        default=False,
+        help="Whether to use half precision training.",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def get_params() -> AttributeDict:
+    """Return a dict containing training parameters.
+
+    All training related parameters that are not passed from the commandline
+    are saved in the variable `params`.
+
+    Commandline options are merged into `params` after they are parsed, so
+    you can also access them via `params`.
+
+    Explanation of options saved in `params`:
+
+        - best_train_loss: Best training loss so far. It is used to select
+                           the model that has the lowest training loss. It is
+                           updated during the training.
+
+        - best_valid_loss: Best validation loss so far. It is used to select
+                           the model that has the lowest validation loss. It is
+                           updated during the training.
+
+        - best_train_epoch: It is the epoch that has the best training loss.
+
+        - best_valid_epoch: It is the epoch that has the best validation loss.
+
+        - batch_idx_train: Used to writing statistics to tensorboard. It
+                           contains number of batches trained so far across
+                           epochs.
+
+        - log_interval:  Print training loss if batch_idx % log_interval` is 0
+
+        - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+        - valid_interval:  Run validation if batch_idx % valid_interval is 0
+
+        - feature_dim: The model input dim. It has to match the one used
+                       in computing features.
+
+        - subsampling_factor:  The subsampling factor for the model.
+
+        - encoder_dim: Hidden dim for multi-head attention model.
+
+        - num_decoder_layers: Number of decoder layer of transformer decoder.
+
+        - warm_step: The warm_step for Noam optimizer.
+    """
+    params = AttributeDict(
+        {
+            "best_train_loss": float("inf"),
+            "best_valid_loss": float("inf"),
+            "best_train_epoch": -1,
+            "best_valid_epoch": -1,
+            "batch_idx_train": 0,
+            "log_interval": 50,
+            "reset_interval": 200,
+            "valid_interval": 3000,  # For the 100h subset, use 800
+            # parameters for Emformer
+            "feature_dim": 80,
+            "subsampling_factor": 4,
+            # parameters for decoder
+            "decoder_dim": 512,
+            # parameters for joiner
+            "joiner_dim": 512,
+            # parameters for Noam
+            "model_warm_step": 3000,  # arg given to model, not for lrate
+            "env_info": get_env_info(),
+        }
+    )
+
+    return params
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+    # TODO: We can add an option to switch between Conformer and Transformer
+    encoder = Emformer(
+        num_features=params.feature_dim,
+        chunk_length=params.chunk_length,
+        subsampling_factor=params.subsampling_factor,
+        d_model=params.encoder_dim,
+        nhead=params.nhead,
+        dim_feedforward=params.dim_feedforward,
+        num_encoder_layers=params.num_encoder_layers,
+        cnn_module_kernel=params.cnn_module_kernel,
+        left_context_length=params.left_context_length,
+        right_context_length=params.right_context_length,
+        memory_size=params.memory_size,
+    )
+    return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+    decoder = Decoder(
+        vocab_size=params.vocab_size,
+        decoder_dim=params.decoder_dim,
+        blank_id=params.blank_id,
+        context_size=params.context_size,
+    )
+    return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+    joiner = Joiner(
+        encoder_dim=params.encoder_dim,
+        decoder_dim=params.decoder_dim,
+        joiner_dim=params.joiner_dim,
+        vocab_size=params.vocab_size,
+    )
+    return joiner
+
+
+def get_transducer_model(params: AttributeDict) -> nn.Module:
+    encoder = get_encoder_model(params)
+    decoder = get_decoder_model(params)
+    joiner = get_joiner_model(params)
+
+    model = Transducer(
+        encoder=encoder,
+        decoder=decoder,
+        joiner=joiner,
+        encoder_dim=params.encoder_dim,
+        decoder_dim=params.decoder_dim,
+        joiner_dim=params.joiner_dim,
+        vocab_size=params.vocab_size,
+    )
+    return model
+
+
+def load_checkpoint_if_available(
+    params: AttributeDict,
+    model: nn.Module,
+    model_avg: nn.Module = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+    """Load checkpoint from file.
+
+    If params.start_batch is positive, it will load the checkpoint from
+    `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+    params.start_epoch is larger than 1, it will load the checkpoint from
+    `params.start_epoch - 1`.
+
+    Apart from loading state dict for `model` and `optimizer` it also updates
+    `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+    and `best_valid_loss` in `params`.
+
+    Args:
+      params:
+        The return value of :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer that we are using.
+      scheduler:
+        The scheduler that we are using.
+    Returns:
+      Return a dict containing previously saved training info.
+    """
+    if params.start_batch > 0:
+        filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+    elif params.start_epoch > 1:
+        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+    else:
+        return None
+
+    assert filename.is_file(), f"{filename} does not exist!"
+
+    saved_params = load_checkpoint(
+        filename,
+        model=model,
+        model_avg=model_avg,
+        optimizer=optimizer,
+        scheduler=scheduler,
+    )
+
+    keys = [
+        "best_train_epoch",
+        "best_valid_epoch",
+        "batch_idx_train",
+        "best_train_loss",
+        "best_valid_loss",
+    ]
+    for k in keys:
+        params[k] = saved_params[k]
+
+    if params.start_batch > 0:
+        if "cur_epoch" in saved_params:
+            params["start_epoch"] = saved_params["cur_epoch"]
+
+        if "cur_batch_idx" in saved_params:
+            params["cur_batch_idx"] = saved_params["cur_batch_idx"]
+
+    return saved_params
+
+
+def save_checkpoint(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    model_avg: Optional[nn.Module] = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+    sampler: Optional[CutSampler] = None,
+    scaler: Optional[GradScaler] = None,
+    rank: int = 0,
+) -> None:
+    """Save model, optimizer, scheduler and training stats to file.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer used in the training.
+      sampler:
+       The sampler for the training dataset.
+      scaler:
+        The scaler used for mix precision training.
+    """
+    if rank != 0:
+        return
+    filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+    save_checkpoint_impl(
+        filename=filename,
+        model=model,
+        model_avg=model_avg,
+        params=params,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        sampler=sampler,
+        scaler=scaler,
+        rank=rank,
+    )
+
+    if params.best_train_epoch == params.cur_epoch:
+        best_train_filename = params.exp_dir / "best-train-loss.pt"
+        copyfile(src=filename, dst=best_train_filename)
+
+    if params.best_valid_epoch == params.cur_epoch:
+        best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+        copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    sp: spm.SentencePieceProcessor,
+    batch: dict,
+    is_training: bool,
+    warmup: float = 1.0,
+) -> Tuple[Tensor, MetricsTracker]:
+    """
+    Compute RNN-T loss given the model and its inputs.
+
+    Args:
+      params:
+        Parameters for training. See :func:`get_params`.
+      model:
+        The model for training. It is an instance of Conformer in our case.
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      is_training:
+        True for training. False for validation. When it is True, this
+        function enables autograd during computation; when it is False, it
+        disables autograd.
+     warmup: a floating point value which increases throughout training;
+        values >= 1.0 are fully warmed up and have all modules present.
+    """
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+    feature = batch["inputs"]
+    # at entry, feature is (N, T, C)
+    assert feature.ndim == 3
+    feature = feature.to(device)
+
+    supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)
+
+    texts = batch["supervisions"]["text"]
+    y = sp.encode(texts, out_type=int)
+    y = k2.RaggedTensor(y).to(device)
+
+    with torch.set_grad_enabled(is_training):
+        simple_loss, pruned_loss = model(
+            x=feature,
+            x_lens=feature_lens,
+            y=y,
+            prune_range=params.prune_range,
+            am_scale=params.am_scale,
+            lm_scale=params.lm_scale,
+            warmup=warmup,
+        )
+        # after the main warmup step, we keep pruned_loss_scale small
+        # for the same amount of time (model_warm_step), to avoid
+        # overwhelming the simple_loss and causing it to diverge,
+        # in case it had not fully learned the alignment yet.
+        pruned_loss_scale = (
+            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+        )
+        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+
+    assert loss.requires_grad == is_training
+
+    info = MetricsTracker()
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+
+    # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances`  # noqa
+    info["utterances"] = feature.size(0)
+    # averaged input duration in frames over utterances
+    info["utt_duration"] = feature_lens.sum().item()
+    # averaged padding proportion over utterances
+    info["utt_pad_proportion"] = (
+        ((feature.size(1) - feature_lens) / feature.size(1)).sum().item()
+    )
+
+    # Note: We use reduction=sum while computing the loss.
+    info["loss"] = loss.detach().cpu().item()
+    info["simple_loss"] = simple_loss.detach().cpu().item()
+    info["pruned_loss"] = pruned_loss.detach().cpu().item()
+
+    return loss, info
+
+
+def compute_validation_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    sp: spm.SentencePieceProcessor,
+    valid_dl: torch.utils.data.DataLoader,
+    world_size: int = 1,
+) -> MetricsTracker:
+    """Run the validation process."""
+    model.eval()
+
+    tot_loss = MetricsTracker()
+
+    for batch_idx, batch in enumerate(valid_dl):
+        loss, loss_info = compute_loss(
+            params=params,
+            model=model,
+            sp=sp,
+            batch=batch,
+            is_training=False,
+        )
+        assert loss.requires_grad is False
+        tot_loss = tot_loss + loss_info
+
+    if world_size > 1:
+        tot_loss.reduce(loss.device)
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    if loss_value < params.best_valid_loss:
+        params.best_valid_epoch = params.cur_epoch
+        params.best_valid_loss = loss_value
+
+    return tot_loss
+
+
+def train_one_epoch(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    optimizer: torch.optim.Optimizer,
+    scheduler: LRSchedulerType,
+    sp: spm.SentencePieceProcessor,
+    train_dl: torch.utils.data.DataLoader,
+    valid_dl: torch.utils.data.DataLoader,
+    scaler: GradScaler,
+    model_avg: Optional[nn.Module] = None,
+    tb_writer: Optional[SummaryWriter] = None,
+    world_size: int = 1,
+    rank: int = 0,
+) -> None:
+    """Train the model for one epoch.
+
+    The training loss from the mean of all frames is saved in
+    `params.train_loss`. It runs the validation process every
+    `params.valid_interval` batches.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The model for training.
+      optimizer:
+        The optimizer we are using.
+      scheduler:
+        The learning rate scheduler, we call step() every step.
+      train_dl:
+        Dataloader for the training dataset.
+      valid_dl:
+        Dataloader for the validation dataset.
+      scaler:
+        The scaler used for mix precision training.
+      model_avg:
+        The stored model averaged from the start of training.
+      tb_writer:
+        Writer to write log messages to tensorboard.
+      world_size:
+        Number of nodes in DDP training. If it is 1, DDP is disabled.
+      rank:
+        The rank of the node in DDP training. If no DDP is used, it should
+        be set to 0.
+    """
+    model.train()
+
+    tot_loss = MetricsTracker()
+
+    cur_batch_idx = params.get("cur_batch_idx", 0)
+
+    for batch_idx, batch in enumerate(train_dl):
+        if batch_idx < cur_batch_idx:
+            continue
+        cur_batch_idx = batch_idx
+
+        params.batch_idx_train += 1
+        batch_size = len(batch["supervisions"]["text"])
+
+        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            loss, loss_info = compute_loss(
+                params=params,
+                model=model,
+                sp=sp,
+                batch=batch,
+                is_training=True,
+                warmup=(params.batch_idx_train / params.model_warm_step),
+            )
+        # summary stats
+        tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+        # NOTE: We use reduction==sum and loss is computed over utterances
+        # in the batch and there is no normalization to it so far.
+        scaler.scale(loss).backward()
+        scheduler.step_batch(params.batch_idx_train)
+        scaler.step(optimizer)
+        scaler.update()
+        optimizer.zero_grad()
+
+        if params.print_diagnostics and batch_idx == 5:
+            return
+
+        if (
+            rank == 0
+            and params.batch_idx_train > 0
+            and params.batch_idx_train % params.average_period == 0
+        ):
+            update_averaged_model(
+                params=params,
+                model_cur=model,
+                model_avg=model_avg,
+            )
+
+        if (
+            params.batch_idx_train > 0
+            and params.batch_idx_train % params.save_every_n == 0
+        ):
+            params.cur_batch_idx = batch_idx
+            save_checkpoint_with_global_batch_idx(
+                out_dir=params.exp_dir,
+                global_batch_idx=params.batch_idx_train,
+                model=model,
+                model_avg=model_avg,
+                params=params,
+                optimizer=optimizer,
+                scheduler=scheduler,
+                sampler=train_dl.sampler,
+                scaler=scaler,
+                rank=rank,
+            )
+            del params.cur_batch_idx
+            remove_checkpoints(
+                out_dir=params.exp_dir,
+                topk=params.keep_last_k,
+                rank=rank,
+            )
+
+        if batch_idx % params.log_interval == 0:
+            cur_lr = scheduler.get_last_lr()[0]
+            logging.info(
+                f"Epoch {params.cur_epoch}, "
+                f"batch {batch_idx}, loss[{loss_info}], "
+                f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+                f"lr: {cur_lr:.2e}"
+            )
+
+            if tb_writer is not None:
+                tb_writer.add_scalar(
+                    "train/learning_rate", cur_lr, params.batch_idx_train
+                )
+
+                loss_info.write_summary(
+                    tb_writer, "train/current_", params.batch_idx_train
+                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+
+        if batch_idx > 0 and batch_idx % params.valid_interval == 0:
+            logging.info("Computing validation loss")
+            valid_info = compute_validation_loss(
+                params=params,
+                model=model,
+                sp=sp,
+                valid_dl=valid_dl,
+                world_size=world_size,
+            )
+            model.train()
+            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+            if tb_writer is not None:
+                valid_info.write_summary(
+                    tb_writer, "train/valid_", params.batch_idx_train
+                )
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    params.train_loss = loss_value
+    if params.train_loss < params.best_train_loss:
+        params.best_train_epoch = params.cur_epoch
+        params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+    """
+    Args:
+      rank:
+        It is a value between 0 and `world_size-1`, which is
+        passed automatically by `mp.spawn()` in :func:`main`.
+        The node with rank 0 is responsible for saving checkpoint.
+      world_size:
+        Number of GPUs for DDP training.
+      args:
+        The return value of get_parser().parse_args()
+    """
+    params = get_params()
+    params.update(vars(args))
+    if params.full_libri is False:
+        params.valid_interval = 1600
+
+    fix_random_seed(params.seed)
+    if world_size > 1:
+        setup_dist(rank, world_size, params.master_port)
+
+    setup_logger(f"{params.exp_dir}/log/log-train")
+    logging.info("Training started")
+
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    #  is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    assert params.save_every_n >= params.average_period
+    model_avg: Optional[nn.Module] = None
+    if rank == 0:
+        # model_avg is only used with rank 0
+        model_avg = copy.deepcopy(model)
+
+    assert params.start_epoch > 0, params.start_epoch
+    checkpoints = load_checkpoint_if_available(
+        params=params, model=model, model_avg=model_avg
+    )
+
+    model.to(device)
+    if world_size > 1:
+        logging.info("Using DDP")
+        model = DDP(model, device_ids=[rank])
+
+    optimizer = Eve(model.parameters(), lr=params.initial_lr)
+
+    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+    if checkpoints and "optimizer" in checkpoints:
+        logging.info("Loading optimizer state dict")
+        optimizer.load_state_dict(checkpoints["optimizer"])
+
+    if (
+        checkpoints
+        and "scheduler" in checkpoints
+        and checkpoints["scheduler"] is not None
+    ):
+        logging.info("Loading scheduler state dict")
+        scheduler.load_state_dict(checkpoints["scheduler"])
+
+    if params.print_diagnostics:
+        opts = diagnostics.TensorDiagnosticOptions(
+            2**22
+        )  # allow 4 megabytes per sub-module
+        diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+    librispeech = LibriSpeechAsrDataModule(args)
+
+    if params.full_libri:
+        train_cuts = librispeech.train_all_shuf_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
+
+    def remove_short_and_long_utt(c: Cut):
+        # Keep only utterances with duration between 1 second and 20 seconds
+        #
+        # Caution: There is a reason to select 20.0 here. Please see
+        # ../local/display_manifest_statistics.py
+        #
+        # You should use ../local/display_manifest_statistics.py to get
+        # an utterance duration distribution for your dataset to select
+        # the threshold
+        return 1.0 <= c.duration <= 20.0
+
+    train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+    if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+        # We only load the sampler's state dict when it loads a checkpoint
+        # saved in the middle of an epoch
+        sampler_state_dict = checkpoints["sampler"]
+    else:
+        sampler_state_dict = None
+
+    train_dl = librispeech.train_dataloaders(
+        train_cuts, sampler_state_dict=sampler_state_dict
+    )
+
+    valid_cuts = librispeech.dev_clean_cuts()
+    valid_cuts += librispeech.dev_other_cuts()
+    valid_dl = librispeech.valid_dataloaders(valid_cuts)
+
+    if not params.print_diagnostics:
+        scan_pessimistic_batches_for_oom(
+            model=model,
+            train_dl=train_dl,
+            optimizer=optimizer,
+            sp=sp,
+            params=params,
+        )
+
+    scaler = GradScaler(enabled=params.use_fp16)
+    if checkpoints and "grad_scaler" in checkpoints:
+        logging.info("Loading grad scaler state dict")
+        scaler.load_state_dict(checkpoints["grad_scaler"])
+
+    for epoch in range(params.start_epoch, params.num_epochs + 1):
+        scheduler.step_epoch(epoch - 1)
+        fix_random_seed(params.seed + epoch - 1)
+        train_dl.sampler.set_epoch(epoch - 1)
+
+        if tb_writer is not None:
+            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+        params.cur_epoch = epoch
+
+        train_one_epoch(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            sp=sp,
+            train_dl=train_dl,
+            valid_dl=valid_dl,
+            scaler=scaler,
+            tb_writer=tb_writer,
+            world_size=world_size,
+            rank=rank,
+        )
+
+        if params.print_diagnostics:
+            diagnostic.print_diagnostics()
+            break
+
+        save_checkpoint(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            sampler=train_dl.sampler,
+            scaler=scaler,
+            rank=rank,
+        )
+
+    logging.info("Done!")
+
+    if world_size > 1:
+        torch.distributed.barrier()
+        cleanup_dist()
+
+
+def scan_pessimistic_batches_for_oom(
+    model: Union[nn.Module, DDP],
+    train_dl: torch.utils.data.DataLoader,
+    optimizer: torch.optim.Optimizer,
+    sp: spm.SentencePieceProcessor,
+    params: AttributeDict,
+):
+    from lhotse.dataset import find_pessimistic_batches
+
+    logging.info(
+        "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+    )
+    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+    for criterion, cuts in batches.items():
+        batch = train_dl.dataset[cuts]
+        try:
+            # warmup = 0.0 is so that the derivs for the pruned loss stay zero
+            # (i.e. are not remembered by the decaying-average in adam), because
+            # we want to avoid these params being subject to shrinkage in adam.
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                loss, _ = compute_loss(
+                    params=params,
+                    model=model,
+                    sp=sp,
+                    batch=batch,
+                    is_training=True,
+                    warmup=0.0,
+                )
+            loss.backward()
+            optimizer.step()
+            optimizer.zero_grad()
+        except RuntimeError as e:
+            if "CUDA out of memory" in str(e):
+                logging.error(
+                    "Your GPU ran out of memory with the current "
+                    "max_duration setting. We recommend decreasing "
+                    "max_duration and trying again.\n"
+                    f"Failing criterion: {criterion} "
+                    f"(={crit_values[criterion]}) ..."
+                )
+            raise
+
+
+def main():
+    parser = get_parser()
+    LibriSpeechAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    world_size = args.world_size
+    assert world_size >= 1
+    if world_size > 1:
+        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+    else:
+        run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+    main()

From 10472e7ffc8bd3f8a096eb7cc62c86a4b861a9a1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ali=20Haznedaro=C4=9Flu?=
 <53865510+ahazned@users.noreply.github.com>
Date: Wed, 7 Dec 2022 03:22:50 +0300
Subject: [PATCH 038/174] Update prepare.sh (#737)

---
 egs/spgispeech/ASR/prepare.sh | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/egs/spgispeech/ASR/prepare.sh b/egs/spgispeech/ASR/prepare.sh
index 4842f52d0..8331f94d5 100755
--- a/egs/spgispeech/ASR/prepare.sh
+++ b/egs/spgispeech/ASR/prepare.sh
@@ -108,7 +108,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
     pieces=$(find data/manifests -name "cuts_train_[0-9]*.jsonl.gz")
     lhotse combine $pieces data/manifests/cuts_train.jsonl.gz
   fi
-  gunzip -c data/manifests/train_cuts.jsonl.gz | shuf | gzip -c > data/manifests/train_cuts_shuf.jsonl.gz
+  gunzip -c data/manifests/cuts_train.jsonl.gz | shuf | gzip -c > data/manifests/cuts_train_shuf.jsonl.gz
 fi
 
 if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
@@ -136,7 +136,7 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
     # Add special words to words.txt
     echo " 0" > $lang_dir/words.txt
     echo "!SIL 1" >> $lang_dir/words.txt
-    echo "[UNK] 2" >> $lang_dir/words.txt
+    echo " 2" >> $lang_dir/words.txt
 
     # Add regular words to words.txt
     gunzip -c data/manifests/cuts_train_raw.jsonl.gz \

From 0e325c8782c8b9178cf0f2b030e49ae64f2b091d Mon Sep 17 00:00:00 2001
From: huangruizhe 
Date: Wed, 7 Dec 2022 02:43:26 -0500
Subject: [PATCH 039/174] Fixed rnn_lm model.py (#738)

---
 icefall/rnn_lm/model.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/icefall/rnn_lm/model.py b/icefall/rnn_lm/model.py
index 9eef88840..3598a4857 100644
--- a/icefall/rnn_lm/model.py
+++ b/icefall/rnn_lm/model.py
@@ -159,10 +159,10 @@ class RnnLmModel(torch.nn.Module):
         if state:
             h, c = state
         else:
-            h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size).to(
+            h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.hidden_size).to(
                 device
             )
-            c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size).to(
+            c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.hidden_size).to(
                 device
             )
 
@@ -179,8 +179,8 @@ class RnnLmModel(torch.nn.Module):
         if state:
             h, c = state
         else:
-            h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size)
-            c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size)
+            h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.hidden_size)
+            c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.hidden_size)
 
         device = next(self.parameters()).device
 

From d65fe17d2766e34adbb4080f9691ea829ac0ae05 Mon Sep 17 00:00:00 2001
From: armusc <46787089+armusc@users.noreply.github.com>
Date: Thu, 8 Dec 2022 13:21:51 +0100
Subject: [PATCH 040/174] Update train.py with parameters_names as required by
 optimizer initialization (#742)

* Update train.py
---
 egs/ami/ASR/pruned_transducer_stateless7/train.py     | 11 ++++++++++-
 .../ASR/pruned_transducer_stateless7_ctc/train.py     | 11 ++++++++++-
 2 files changed, 20 insertions(+), 2 deletions(-)

diff --git a/egs/ami/ASR/pruned_transducer_stateless7/train.py b/egs/ami/ASR/pruned_transducer_stateless7/train.py
index b5efb3405..81823ced2 100755
--- a/egs/ami/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/ami/ASR/pruned_transducer_stateless7/train.py
@@ -972,7 +972,16 @@ def run(rank, world_size, args):
         logging.info("Using DDP")
         model = DDP(model, device_ids=[rank], find_unused_parameters=True)
 
-    optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=2.0)
+    parameters_names = []
+    parameters_names.append(
+        [name_param_pair[0] for name_param_pair in model.named_parameters()]
+    )
+    optimizer = ScaledAdam(
+        model.parameters(),
+        lr=params.base_lr,
+        clipping_scale=2.0,
+        parameters_names=parameters_names,
+    )
 
     scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py
index abfd56e5a..162ad8412 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py
@@ -1036,7 +1036,16 @@ def run(rank, world_size, args):
         logging.info("Using DDP")
         model = DDP(model, device_ids=[rank], find_unused_parameters=True)
 
-    optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=2.0)
+    parameters_names = []
+    parameters_names.append(
+        [name_param_pair[0] for name_param_pair in model.named_parameters()]
+    )
+    optimizer = ScaledAdam(
+        model.parameters(),
+        lr=params.base_lr,
+        clipping_scale=2.0,
+        parameters_names=parameters_names,
+    )
 
     scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
 

From 4501821fd98821a6cf3a238c6dc5c01422643fdb Mon Sep 17 00:00:00 2001
From: Fangjun Kuang 
Date: Fri, 9 Dec 2022 16:46:44 +0800
Subject: [PATCH 041/174] Support using OpenFst to compile HLG. (#606)

* Support using OpenFst to compile HLG.

* Fix style issues
---
 .../ASR/local/compile_hlg_using_openfst.py    | 184 ++++++++++++++++++
 egs/librispeech/ASR/prepare.sh                |  41 +++-
 icefall/shared/convert-k2-to-openfst.py       | 102 ++++++++++
 requirements.txt                              |   1 +
 4 files changed, 325 insertions(+), 3 deletions(-)
 create mode 100755 egs/librispeech/ASR/local/compile_hlg_using_openfst.py
 create mode 100755 icefall/shared/convert-k2-to-openfst.py

diff --git a/egs/librispeech/ASR/local/compile_hlg_using_openfst.py b/egs/librispeech/ASR/local/compile_hlg_using_openfst.py
new file mode 100755
index 000000000..9e5e3df69
--- /dev/null
+++ b/egs/librispeech/ASR/local/compile_hlg_using_openfst.py
@@ -0,0 +1,184 @@
+#!/usr/bin/env python3
+# Copyright    2022  Xiaomi Corp.        (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This script takes as input lang_dir and generates HLG from
+
+    - H, the ctc topology, built from tokens contained in lang_dir/lexicon.txt
+    - L, the lexicon, built from lang_dir/L_disambig.fst
+
+        Caution: We use a lexicon that contains disambiguation symbols
+
+    - G, the LM, built from data/lm/G_3_gram.fst.txt
+
+The generated HLG is saved in $lang_dir/HLG_fst.pt
+
+So when to use this script instead of ./local/compile_hlg.py ?
+If you have a very large G, ./local/compile_hlg.py may throw OOM for
+determinization. In that case, you can use this script to compile HLG.
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+import k2
+import kaldifst
+import torch
+
+from icefall.lexicon import Lexicon
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        help="""Input and output directory.
+        """,
+    )
+
+    return parser.parse_args()
+
+
+def compile_HLG(lang_dir: str) -> kaldifst.StdVectorFst:
+    """
+    Args:
+      lang_dir:
+        The language directory, e.g., data/lang_phone or data/lang_bpe_5000.
+
+    Return:
+      An FST representing HLG.
+    """
+
+    L = kaldifst.StdVectorFst.read(f"{lang_dir}/L_disambig.fst")
+    logging.info("Arc sort L")
+    kaldifst.arcsort(L, sort_type="olabel")
+    logging.info(f"L: #states {L.num_states}")
+
+    G_filename_txt = "data/lm/G_3_gram.fst.txt"
+    G_filename_binary = "data/lm/G_3_gram.fst"
+    if Path(G_filename_binary).is_file():
+        logging.info(f"Loading {G_filename_binary}")
+        G = kaldifst.StdVectorFst.read(G_filename_binary)
+    else:
+        logging.info(f"Loading {G_filename_txt}")
+        with open(G_filename_txt) as f:
+            G = kaldifst.compile(s=f.read(), acceptor=False)
+            logging.info(f"Saving G to {G_filename_binary}")
+            G.write(G_filename_binary)
+
+    logging.info("Arc sort G")
+    kaldifst.arcsort(G, sort_type="ilabel")
+
+    logging.info(f"G: #states {G.num_states}")
+
+    logging.info("Compose L and G and connect LG")
+    LG = kaldifst.compose(L, G, connect=True)
+    logging.info(f"LG: #states {LG.num_states}")
+
+    logging.info("Determinizestar LG")
+    kaldifst.determinize_star(LG)
+    logging.info(f"LG after determinize_star: #states {LG.num_states}")
+
+    logging.info("Minimize encoded LG")
+    kaldifst.minimize_encoded(LG)
+    logging.info(f"LG after minimize_encoded: #states {LG.num_states}")
+
+    logging.info("Converting LG to k2 format")
+    LG = k2.Fsa.from_openfst(LG.to_str(is_acceptor=False), acceptor=False)
+    logging.info(f"LG in k2: #states: {LG.shape[0]}, #arcs: {LG.num_arcs}")
+
+    lexicon = Lexicon(lang_dir)
+
+    first_token_disambig_id = lexicon.token_table["#0"]
+    first_word_disambig_id = lexicon.word_table["#0"]
+    logging.info(f"token id for #0: {first_token_disambig_id}")
+    logging.info(f"word id for #0: {first_word_disambig_id}")
+
+    max_token_id = max(lexicon.tokens)
+    modified = False
+    logging.info(
+        f"Building ctc_topo. modified: {modified}, max_token_id: {max_token_id}"
+    )
+
+    H = k2.ctc_topo(max_token_id, modified=modified)
+    logging.info(f"H: #states: {H.shape[0]}, #arcs: {H.num_arcs}")
+
+    logging.info("Removing disambiguation symbols on LG")
+    LG.labels[LG.labels >= first_token_disambig_id] = 0
+    LG.aux_labels[LG.aux_labels >= first_word_disambig_id] = 0
+
+    # See https://github.com/k2-fsa/k2/issues/874
+    # for why we need to set LG.properties to None
+    LG.__dict__["_properties"] = None
+
+    logging.info("Removing epsilons from LG")
+    LG = k2.remove_epsilon(LG)
+    logging.info(
+        f"LG after k2.remove_epsilon: #states: {LG.shape[0]}, #arcs: {LG.num_arcs}"
+    )
+
+    logging.info("Connecting LG after removing epsilons")
+    LG = k2.connect(LG)
+    LG.aux_labels = LG.aux_labels.remove_values_eq(0)
+    logging.info(f"LG after k2.connect: #states: {LG.shape[0]}, #arcs: {LG.num_arcs}")
+
+    logging.info("Arc sorting LG")
+    LG = k2.arc_sort(LG)
+
+    logging.info("Composing H and LG")
+
+    HLG = k2.compose(H, LG, inner_labels="tokens")
+    logging.info(
+        f"HLG after k2.compose: #states: {HLG.shape[0]}, #arcs: {HLG.num_arcs}"
+    )
+
+    logging.info("Connecting HLG")
+    HLG = k2.connect(HLG)
+    logging.info(
+        f"HLG after k2.connect: #states: {HLG.shape[0]}, #arcs: {HLG.num_arcs}"
+    )
+
+    logging.info("Arc sorting LG")
+    HLG = k2.arc_sort(HLG)
+
+    return HLG
+
+
+def main():
+    args = get_args()
+    lang_dir = Path(args.lang_dir)
+
+    filename = lang_dir / "HLG_fst.pt"
+
+    if filename.is_file():
+        logging.info(f"{filename} already exists - skipping")
+        return
+
+    HLG = compile_HLG(lang_dir)
+    logging.info(f"Saving HLG to {filename}")
+    torch.save(HLG.as_dict(), filename)
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
+    main()
diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh
index 542bbcdd8..11c8e1066 100755
--- a/egs/librispeech/ASR/prepare.sh
+++ b/egs/librispeech/ASR/prepare.sh
@@ -44,9 +44,9 @@ dl_dir=$PWD/download
 # It will generate data/lang_bpe_xxx,
 # data/lang_bpe_yyy if the array contains xxx, yyy
 vocab_sizes=(
-  5000
-  2000
-  1000
+  # 5000
+  # 2000
+  # 1000
   500
 )
 
@@ -168,6 +168,22 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
   if [ ! -f $lang_dir/L_disambig.pt ]; then
     ./local/prepare_lang.py --lang-dir $lang_dir
   fi
+
+  if [ ! -f $lang_dir/L.fst ]; then
+    log "Converting L.pt to L.fst"
+    ./shared/convert-k2-to-openfst.py \
+      --olabels aux_labels \
+      $lang_dir/L.pt \
+      $lang_dir/L.fst
+  fi
+
+  if [ ! -f $lang_dir/L_disambig.fst ]; then
+    log "Converting L_disambig.pt to L_disambig.fst"
+    ./shared/convert-k2-to-openfst.py \
+      --olabels aux_labels \
+      $lang_dir/L_disambig.pt \
+      $lang_dir/disambig_L.fst
+  fi
 fi
 
 
@@ -208,6 +224,22 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
         --lexicon $lang_dir/lexicon.txt \
         --bpe-model $lang_dir/bpe.model
     fi
+
+    if [ ! -f $lang_dir/L.fst ]; then
+      log "Converting L.pt to L.fst"
+      ./shared/convert-k2-to-openfst.py \
+        --olabels aux_labels \
+        $lang_dir/L.pt \
+        $lang_dir/L.fst
+    fi
+
+    if [ ! -f $lang_dir/L_disambig.fst ]; then
+      log "Converting L_disambig.pt to L_disambig.fst"
+      ./shared/convert-k2-to-openfst.py \
+        --olabels aux_labels \
+        $lang_dir/L_disambig.pt \
+        $lang_dir/L_disambig.fst
+    fi
   done
 fi
 
@@ -270,10 +302,13 @@ fi
 if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
   log "Stage 9: Compile HLG"
   ./local/compile_hlg.py --lang-dir data/lang_phone
+  ./local/compile_hlg_using_openfst.py --lang-dir data/lang_phone
 
   for vocab_size in ${vocab_sizes[@]}; do
     lang_dir=data/lang_bpe_${vocab_size}
     ./local/compile_hlg.py --lang-dir $lang_dir
+
+    ./local/compile_hlg_using_openfst.py --lang-dir $lang_dir
   done
 fi
 
diff --git a/icefall/shared/convert-k2-to-openfst.py b/icefall/shared/convert-k2-to-openfst.py
new file mode 100755
index 000000000..29a2cd7f7
--- /dev/null
+++ b/icefall/shared/convert-k2-to-openfst.py
@@ -0,0 +1,102 @@
+#!/usr/bin/env python3
+# Copyright    2022  Xiaomi Corp.        (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script takes as input an FST in k2 format and convert it
+to an FST in OpenFST format.
+
+The generated FST is saved into a binary file and its type is
+StdVectorFst.
+
+Usage examples:
+(1) Convert an acceptor
+
+  ./convert-k2-to-openfst.py in.pt binary.fst
+
+(2) Convert a transducer
+
+  ./convert-k2-to-openfst.py --olabels aux_labels in.pt binary.fst
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+import k2
+import kaldifst.utils
+import torch
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--olabels",
+        type=str,
+        default=None,
+        help="""If not empty, the input FST is assumed to be a transducer
+        and we use its attribute specified by "olabels" as the output labels.
+        """,
+    )
+    parser.add_argument(
+        "input_filename",
+        type=str,
+        help="Path to the input FST in k2 format",
+    )
+
+    parser.add_argument(
+        "output_filename",
+        type=str,
+        help="Path to the output FST in OpenFst format",
+    )
+
+    return parser.parse_args()
+
+
+def main():
+    args = get_args()
+    logging.info(f"{vars(args)}")
+
+    input_filename = args.input_filename
+    output_filename = args.output_filename
+    olabels = args.olabels
+
+    if Path(output_filename).is_file():
+        logging.info(f"{output_filename} already exists - skipping")
+        return
+
+    assert Path(input_filename).is_file(), f"{input_filename} does not exist"
+    logging.info(f"Loading {input_filename}")
+    k2_fst = k2.Fsa.from_dict(torch.load(input_filename))
+    if olabels:
+        assert hasattr(k2_fst, olabels), f"No such attribute: {olabels}"
+
+    p = Path(output_filename).parent
+    if not p.is_dir():
+        logging.info(f"Creating {p}")
+        p.mkdir(parents=True)
+
+    logging.info("Converting (May take some time if the input FST is large)")
+    fst = kaldifst.utils.k2_to_openfst(k2_fst, olabels=olabels)
+    logging.info(f"Saving to {output_filename}")
+    fst.write(output_filename)
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/requirements.txt b/requirements.txt
index 5e32af853..a07f6b7c7 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,4 @@
+kaldifst
 kaldilm
 kaldialign
 sentencepiece>=0.1.96

From a0cf85343dad31a678ddaac7652f0bb2bbb4cac2 Mon Sep 17 00:00:00 2001
From: Yifan Yang <64255737+yfyeung@users.noreply.github.com>
Date: Fri, 9 Dec 2022 19:23:11 +0800
Subject: [PATCH 042/174] fix for memory usage in
 pruned_transducer_stateless7/scaling.py (#752)

Co-authored-by: yifanyang 
---
 egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 6f63e0629..042c9c3e4 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -562,7 +562,7 @@ class ActivationBalancer(torch.nn.Module):
                 sign_factor = None
 
             scale_factor = _compute_scale_factor(
-                x,
+                x.detach(),
                 self.channel_dim,
                 min_abs=self.min_abs,
                 max_abs=self.max_abs,

From c4aaf3ea3bfcebdad79f4e9d10080ed514113830 Mon Sep 17 00:00:00 2001
From: Desh Raj 
Date: Sat, 10 Dec 2022 15:45:23 +0530
Subject: [PATCH 043/174] Add AliMeeting multi-condition training recipe (#751)

* add AliMeeting multi-domain recipe

* convert scripts to symbolic links
---
 egs/alimeeting/ASR_v2/README.md               |   38 +
 egs/alimeeting/ASR_v2/RESULTS.md              |   90 ++
 egs/alimeeting/ASR_v2/local/__init__.py       |    0
 .../ASR_v2/local/compute_fbank_alimeeting.py  |  193 +++
 .../ASR_v2/local/compute_fbank_musan.py       |    1 +
 .../local/prepare_alimeeting_enhanced.py      |  158 +++
 .../ASR_v2/local/prepare_alimeeting_gss.sh    |   98 ++
 egs/alimeeting/ASR_v2/local/prepare_char.py   |    1 +
 egs/alimeeting/ASR_v2/local/prepare_words.py  |    1 +
 egs/alimeeting/ASR_v2/local/text2segments.py  |    1 +
 egs/alimeeting/ASR_v2/local/text2token.py     |    1 +
 egs/alimeeting/ASR_v2/prepare.sh              |  125 ++
 .../pruned_transducer_stateless7/__init__.py  |    0
 .../asr_datamodule.py                         |  419 ++++++
 .../beam_search.py                            |    1 +
 .../pruned_transducer_stateless7/decode.py    |  698 ++++++++++
 .../pruned_transducer_stateless7/decoder.py   |    1 +
 .../encoder_interface.py                      |    1 +
 .../pruned_transducer_stateless7/export.py    |  320 +++++
 .../jit_pretrained.py                         |    1 +
 .../pruned_transducer_stateless7/joiner.py    |    1 +
 .../pruned_transducer_stateless7/model.py     |    1 +
 .../pruned_transducer_stateless7/optim.py     |    1 +
 .../pretrained.py                             |    1 +
 .../pruned_transducer_stateless7/scaling.py   |    1 +
 .../scaling_converter.py                      |    1 +
 .../test_model.py                             |    1 +
 .../pruned_transducer_stateless7/train.py     | 1186 +++++++++++++++++
 .../pruned_transducer_stateless7/zipformer.py |    1 +
 egs/alimeeting/ASR_v2/shared                  |    1 +
 30 files changed, 3343 insertions(+)
 create mode 100644 egs/alimeeting/ASR_v2/README.md
 create mode 100644 egs/alimeeting/ASR_v2/RESULTS.md
 create mode 100644 egs/alimeeting/ASR_v2/local/__init__.py
 create mode 100755 egs/alimeeting/ASR_v2/local/compute_fbank_alimeeting.py
 create mode 120000 egs/alimeeting/ASR_v2/local/compute_fbank_musan.py
 create mode 100644 egs/alimeeting/ASR_v2/local/prepare_alimeeting_enhanced.py
 create mode 100755 egs/alimeeting/ASR_v2/local/prepare_alimeeting_gss.sh
 create mode 120000 egs/alimeeting/ASR_v2/local/prepare_char.py
 create mode 120000 egs/alimeeting/ASR_v2/local/prepare_words.py
 create mode 120000 egs/alimeeting/ASR_v2/local/text2segments.py
 create mode 120000 egs/alimeeting/ASR_v2/local/text2token.py
 create mode 100755 egs/alimeeting/ASR_v2/prepare.sh
 create mode 100644 egs/alimeeting/ASR_v2/pruned_transducer_stateless7/__init__.py
 create mode 100644 egs/alimeeting/ASR_v2/pruned_transducer_stateless7/asr_datamodule.py
 create mode 120000 egs/alimeeting/ASR_v2/pruned_transducer_stateless7/beam_search.py
 create mode 100755 egs/alimeeting/ASR_v2/pruned_transducer_stateless7/decode.py
 create mode 120000 egs/alimeeting/ASR_v2/pruned_transducer_stateless7/decoder.py
 create mode 120000 egs/alimeeting/ASR_v2/pruned_transducer_stateless7/encoder_interface.py
 create mode 100755 egs/alimeeting/ASR_v2/pruned_transducer_stateless7/export.py
 create mode 120000 egs/alimeeting/ASR_v2/pruned_transducer_stateless7/jit_pretrained.py
 create mode 120000 egs/alimeeting/ASR_v2/pruned_transducer_stateless7/joiner.py
 create mode 120000 egs/alimeeting/ASR_v2/pruned_transducer_stateless7/model.py
 create mode 120000 egs/alimeeting/ASR_v2/pruned_transducer_stateless7/optim.py
 create mode 120000 egs/alimeeting/ASR_v2/pruned_transducer_stateless7/pretrained.py
 create mode 120000 egs/alimeeting/ASR_v2/pruned_transducer_stateless7/scaling.py
 create mode 120000 egs/alimeeting/ASR_v2/pruned_transducer_stateless7/scaling_converter.py
 create mode 120000 egs/alimeeting/ASR_v2/pruned_transducer_stateless7/test_model.py
 create mode 100755 egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py
 create mode 120000 egs/alimeeting/ASR_v2/pruned_transducer_stateless7/zipformer.py
 create mode 120000 egs/alimeeting/ASR_v2/shared

diff --git a/egs/alimeeting/ASR_v2/README.md b/egs/alimeeting/ASR_v2/README.md
new file mode 100644
index 000000000..f70327501
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/README.md
@@ -0,0 +1,38 @@
+
+# Introduction
+
+This recipe trains multi-domain ASR models for AliMeeting. By multi-domain, we mean that
+we train a single model on close-talk and far-field conditions. This recipe optionally
+uses [GSS]-based enhancement for far-field array microphone.
+We pool data in the following 4 ways and train a single model on the pooled data:
+
+(i) individual headset microphone (IHM)
+(ii) IHM with simulated reverb
+(iii) Single distant microphone (SDM)
+(iv) GSS-enhanced array microphones
+
+This is different from `alimeeting/ASR` since that recipe trains a model only on the
+far-field audio. Additionally, we use text normalization here similar to the original
+M2MeT challenge, so the results should be more comparable to those from Table 4 of
+the [paper](https://arxiv.org/abs/2110.07393).
+
+The following additional packages need to be installed to run this recipe:
+* `pip install jieba`
+* `pip install paddlepaddle`
+* `pip install git+https://github.com/desh2608/gss.git`
+
+[./RESULTS.md](./RESULTS.md) contains the latest results.
+
+## Performance Record
+
+### pruned_transducer_stateless7
+
+The following are decoded using `modified_beam_search`:
+
+| Evaluation set           | eval WER    | test WER |
+|--------------------------|------------|---------|
+| IHM                      |  9.58  | 11.53 |
+| SDM                      |  23.37  | 25.85 |
+| MDM (GSS-enhanced)       |  11.82  | 14.22 |
+
+See [RESULTS](/egs/alimeeting/ASR_v2/RESULTS.md) for details.
diff --git a/egs/alimeeting/ASR_v2/RESULTS.md b/egs/alimeeting/ASR_v2/RESULTS.md
new file mode 100644
index 000000000..15b24250d
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/RESULTS.md
@@ -0,0 +1,90 @@
+## Results (CER)
+
+#### 2022-12-09
+
+#### Zipformer (pruned_transducer_stateless7)
+
+Zipformer encoder + non-current decoder. The decoder
+contains only an embedding layer, a Conv1d (with kernel size 2) and a linear
+layer (to transform tensor dim).
+
+All the results below are using a single model that is trained by combining the following
+data: IHM, IHM+reverb, SDM, and GSS-enhanced MDM. Speed perturbation and MUSAN noise
+augmentation are applied on top of the pooled data.
+
+**WERs for IHM:**
+
+|                           | eval | test | comment                                  |
+|---------------------------|------------|------------|------------------------------------------|
+| greedy search             |  10.13  |  12.21  | --epoch 15 --avg 8 --max-duration 500 |
+| modified beam search      |  9.58  |  11.53  | --epoch 15 --avg 8 --max-duration 500 --beam-size 4 |
+| fast beam search          |  9.92  |  12.07  | --epoch 15 --avg 8 --max-duration 500 --beam-size 4 --max-contexts 4 --max-states 8 |
+
+**WERs for SDM:**
+
+|                           | eval | test | comment                                  |
+|---------------------------|------------|------------|------------------------------------------|
+| greedy search             |  23.70  |  26.41  | --epoch 15 --avg 8 --max-duration 500 |
+| modified beam search      |  23.37  |  25.85  | --epoch 15 --avg 8 --max-duration 500 --beam-size 4 |
+| fast beam search          |  23.60  |  26.38  | --epoch 15 --avg 8 --max-duration 500 --beam-size 4 --max-contexts 4 --max-states 8 |
+
+**WERs for GSS-enhanced MDM:**
+
+|                           | eval | test | comment                                  |
+|---------------------------|------------|------------|------------------------------------------|
+| greedy search             |  12.24  |  14.99  | --epoch 15 --avg 8 --max-duration 500 |
+| modified beam search      |  11.82  |  14.22  | --epoch 15 --avg 8 --max-duration 500 --beam-size 4 |
+| fast beam search          |  12.30  |  14.98  | --epoch 15 --avg 8 --max-duration 500 --beam-size 4 --max-contexts 4 --max-states 8 |
+
+The training command for reproducing is given below:
+
+```
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./pruned_transducer_stateless7/train.py \
+  --world-size 4 \
+  --num-epochs 15 \
+  --exp-dir pruned_transducer_stateless7/exp \
+  --max-duration 300 \
+  --max-cuts 100 \
+  --prune-range 5 \
+  --lr-factor 5 \
+  --lm-scale 0.25 \
+  --use-fp16 True
+```
+
+The decoding command is:
+```
+# greedy search
+./pruned_transducer_stateless7/decode.py \
+        --epoch 15 \
+        --avg 8 \
+        --exp-dir ./pruned_transducer_stateless7/exp \
+        --max-duration 500 \
+        --decoding-method greedy_search
+
+# modified beam search
+./pruned_transducer_stateless7/decode.py \
+        --epoch 15 \
+        --avg 8 \
+        --exp-dir ./pruned_transducer_stateless7/exp \
+        --max-duration 500 \
+        --decoding-method modified_beam_search \
+        --beam-size 4
+
+# fast beam search
+./pruned_transducer_stateless7/decode.py \
+        --epoch 15 \
+        --avg 8 \
+        --exp-dir ./pruned_transducer_stateless5/exp \
+        --max-duration 500 \
+        --decoding-method fast_beam_search \
+        --beam 4 \
+        --max-contexts 4 \
+        --max-states 8
+```
+
+Pretrained model is available at 
+
+The tensorboard training log can be found at
+
diff --git a/egs/alimeeting/ASR_v2/local/__init__.py b/egs/alimeeting/ASR_v2/local/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/egs/alimeeting/ASR_v2/local/compute_fbank_alimeeting.py b/egs/alimeeting/ASR_v2/local/compute_fbank_alimeeting.py
new file mode 100755
index 000000000..c6aa2ab36
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/local/compute_fbank_alimeeting.py
@@ -0,0 +1,193 @@
+#!/usr/bin/env python3
+# Copyright    2022  Johns Hopkins University        (authors: Desh Raj)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file computes fbank features of the AliMeeting dataset.
+For the training data, we prepare IHM, reverberated IHM, SDM, and GSS-enhanced
+audios. For the test data, we separately prepare IHM, SDM, and GSS-enhanced
+parts (which are the 3 evaluation settings).
+It looks for manifests in the directory data/manifests.
+
+The generated fbank features are saved in data/fbank.
+"""
+import logging
+from pathlib import Path
+
+import torch
+import torch.multiprocessing
+from lhotse import CutSet, LilcomChunkyWriter
+from lhotse.features.kaldifeat import (
+    KaldifeatFbank,
+    KaldifeatFbankConfig,
+    KaldifeatFrameOptions,
+    KaldifeatMelOptions,
+)
+from lhotse.recipes.utils import read_manifests_if_cached
+
+# Torch's multithreaded behavior needs to be disabled or
+# it wastes a lot of CPU and slow things down.
+# Do this outside of main() in case it needs to take effect
+# even when we are not invoking the main (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+torch.multiprocessing.set_sharing_strategy("file_system")
+
+
+def compute_fbank_ami():
+    src_dir = Path("data/manifests")
+    output_dir = Path("data/fbank")
+
+    sampling_rate = 16000
+    num_mel_bins = 80
+
+    extractor = KaldifeatFbank(
+        KaldifeatFbankConfig(
+            frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate),
+            mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
+            device="cuda",
+        )
+    )
+
+    logging.info("Reading manifests")
+    manifests_ihm = read_manifests_if_cached(
+        dataset_parts=["train", "eval", "test"],
+        output_dir=src_dir,
+        prefix="alimeeting-ihm",
+        suffix="jsonl.gz",
+    )
+    manifests_sdm = read_manifests_if_cached(
+        dataset_parts=["train", "eval", "test"],
+        output_dir=src_dir,
+        prefix="alimeeting-sdm",
+        suffix="jsonl.gz",
+    )
+    # For GSS we already have cuts so we read them directly.
+    manifests_gss = read_manifests_if_cached(
+        dataset_parts=["train", "eval", "test"],
+        output_dir=src_dir,
+        prefix="alimeeting-gss",
+        suffix="jsonl.gz",
+    )
+
+    def _extract_feats(cuts: CutSet, storage_path: Path, manifest_path: Path) -> None:
+        cuts = cuts + cuts.perturb_speed(0.9) + cuts.perturb_speed(1.1)
+        _ = cuts.compute_and_store_features_batch(
+            extractor=extractor,
+            storage_path=storage_path,
+            manifest_path=manifest_path,
+            batch_duration=5000,
+            num_workers=8,
+            storage_type=LilcomChunkyWriter,
+        )
+
+    logging.info(
+        "Preparing training cuts: IHM + reverberated IHM + SDM + GSS (optional)"
+    )
+
+    logging.info("Processing train split IHM")
+    cuts_ihm = (
+        CutSet.from_manifests(**manifests_ihm["train"])
+        .trim_to_supervisions(keep_overlapping=False, keep_all_channels=False)
+        .modify_ids(lambda x: x + "-ihm")
+    )
+    _extract_feats(
+        cuts_ihm,
+        output_dir / "feats_train_ihm",
+        src_dir / "cuts_train_ihm.jsonl.gz",
+    )
+
+    logging.info("Processing train split IHM + reverberated IHM")
+    cuts_ihm_rvb = cuts_ihm.reverb_rir()
+    _extract_feats(
+        cuts_ihm_rvb,
+        output_dir / "feats_train_ihm_rvb",
+        src_dir / "cuts_train_ihm_rvb.jsonl.gz",
+    )
+
+    logging.info("Processing train split SDM")
+    cuts_sdm = (
+        CutSet.from_manifests(**manifests_sdm["train"])
+        .trim_to_supervisions(keep_overlapping=False)
+        .modify_ids(lambda x: x + "-sdm")
+    )
+    _extract_feats(
+        cuts_sdm,
+        output_dir / "feats_train_sdm",
+        src_dir / "cuts_train_sdm.jsonl.gz",
+    )
+
+    logging.info("Processing train split GSS")
+    cuts_gss = (
+        CutSet.from_manifests(**manifests_gss["train"])
+        .trim_to_supervisions(keep_overlapping=False)
+        .modify_ids(lambda x: x + "-gss")
+    )
+    _extract_feats(
+        cuts_gss,
+        output_dir / "feats_train_gss",
+        src_dir / "cuts_train_gss.jsonl.gz",
+    )
+
+    logging.info("Preparing test cuts: IHM, SDM, GSS (optional)")
+    for split in ["eval", "test"]:
+        logging.info(f"Processing {split} IHM")
+        cuts_ihm = (
+            CutSet.from_manifests(**manifests_ihm[split])
+            .trim_to_supervisions(keep_overlapping=False, keep_all_channels=False)
+            .compute_and_store_features_batch(
+                extractor=extractor,
+                storage_path=output_dir / f"feats_{split}_ihm",
+                manifest_path=src_dir / f"cuts_{split}_ihm.jsonl.gz",
+                batch_duration=500,
+                num_workers=4,
+                storage_type=LilcomChunkyWriter,
+            )
+        )
+        logging.info(f"Processing {split} SDM")
+        cuts_sdm = (
+            CutSet.from_manifests(**manifests_sdm[split])
+            .trim_to_supervisions(keep_overlapping=False)
+            .compute_and_store_features_batch(
+                extractor=extractor,
+                storage_path=output_dir / f"feats_{split}_sdm",
+                manifest_path=src_dir / f"cuts_{split}_sdm.jsonl.gz",
+                batch_duration=500,
+                num_workers=4,
+                storage_type=LilcomChunkyWriter,
+            )
+        )
+        logging.info(f"Processing {split} GSS")
+        cuts_gss = (
+            CutSet.from_manifests(**manifests_gss[split])
+            .trim_to_supervisions(keep_overlapping=False)
+            .compute_and_store_features_batch(
+                extractor=extractor,
+                storage_path=output_dir / f"feats_{split}_gss",
+                manifest_path=src_dir / f"cuts_{split}_gss.jsonl.gz",
+                batch_duration=500,
+                num_workers=4,
+                storage_type=LilcomChunkyWriter,
+            )
+        )
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
+    compute_fbank_ami()
diff --git a/egs/alimeeting/ASR_v2/local/compute_fbank_musan.py b/egs/alimeeting/ASR_v2/local/compute_fbank_musan.py
new file mode 120000
index 000000000..5833f2484
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/local/compute_fbank_musan.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/compute_fbank_musan.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR_v2/local/prepare_alimeeting_enhanced.py b/egs/alimeeting/ASR_v2/local/prepare_alimeeting_enhanced.py
new file mode 100644
index 000000000..f1512efa5
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/local/prepare_alimeeting_enhanced.py
@@ -0,0 +1,158 @@
+#!/usr/local/bin/python
+# -*- coding: utf-8 -*-
+# Data preparation for AliMeeting GSS-enhanced dataset.
+
+import logging
+from concurrent.futures import ThreadPoolExecutor
+from pathlib import Path
+
+from lhotse import Recording, RecordingSet, SupervisionSet
+from lhotse.qa import fix_manifests
+from lhotse.recipes.utils import read_manifests_if_cached
+from lhotse.utils import fastcopy
+from tqdm import tqdm
+
+logging.basicConfig(
+    format="%(asctime)s %(levelname)-8s %(message)s",
+    level=logging.INFO,
+    datefmt="%Y-%m-%d %H:%M:%S",
+)
+
+
+def get_args():
+    import argparse
+
+    parser = argparse.ArgumentParser(description="AMI enhanced dataset preparation.")
+    parser.add_argument(
+        "manifests_dir",
+        type=Path,
+        help="Path to directory containing AliMeeting manifests.",
+    )
+    parser.add_argument(
+        "enhanced_dir",
+        type=Path,
+        help="Path to enhanced data directory.",
+    )
+    parser.add_argument(
+        "--num-jobs",
+        "-j",
+        type=int,
+        default=1,
+        help="Number of parallel jobs to run.",
+    )
+    parser.add_argument(
+        "--min-segment-duration",
+        "-d",
+        type=float,
+        default=0.0,
+        help="Minimum duration of a segment in seconds.",
+    )
+    return parser.parse_args()
+
+
+def find_recording_and_create_new_supervision(enhanced_dir, supervision):
+    """
+    Given a supervision (corresponding to original AMI recording), this function finds the
+    enhanced recording correspoding to the supervision, and returns this recording and
+    a new supervision whose start and end times are adjusted to match the enhanced recording.
+    """
+    file_name = Path(
+        f"{supervision.recording_id}-{supervision.speaker}-{int(100*supervision.start):06d}_{int(100*supervision.end):06d}.flac"
+    )
+    save_path = enhanced_dir / f"{supervision.recording_id}" / file_name
+    if save_path.exists():
+        recording = Recording.from_file(save_path)
+        if recording.duration == 0:
+            logging.warning(f"Skipping {save_path} which has duration 0 seconds.")
+            return None
+
+        # Old supervision is wrt to the original recording, we create new supervision
+        # wrt to the enhanced segment
+        new_supervision = fastcopy(
+            supervision,
+            recording_id=recording.id,
+            start=0,
+            duration=recording.duration,
+        )
+        return recording, new_supervision
+    else:
+        logging.warning(f"{save_path} does not exist.")
+        return None
+
+
+def main(args):
+    # Get arguments
+    manifests_dir = args.manifests_dir
+    enhanced_dir = args.enhanced_dir
+
+    # Load manifests from cache if they exist (saves time)
+    manifests = read_manifests_if_cached(
+        dataset_parts=["train", "eval", "test"],
+        output_dir=manifests_dir,
+        prefix="alimeeting-sdm",
+        suffix="jsonl.gz",
+    )
+    if not manifests:
+        raise ValueError(
+            "AliMeeting SDM manifests not found in {}".format(manifests_dir)
+        )
+
+    with ThreadPoolExecutor(args.num_jobs) as ex:
+        for part in ["train", "eval", "test"]:
+            logging.info(f"Processing {part}...")
+            supervisions_orig = manifests[part]["supervisions"].filter(
+                lambda s: s.duration >= args.min_segment_duration
+            )
+            futures = []
+
+            for supervision in tqdm(
+                supervisions_orig,
+                desc="Distributing tasks",
+            ):
+                futures.append(
+                    ex.submit(
+                        find_recording_and_create_new_supervision,
+                        enhanced_dir,
+                        supervision,
+                    )
+                )
+
+            recordings = []
+            supervisions = []
+            for future in tqdm(
+                futures,
+                total=len(futures),
+                desc="Processing tasks",
+            ):
+                result = future.result()
+                if result is not None:
+                    recording, new_supervision = result
+                    recordings.append(recording)
+                    supervisions.append(new_supervision)
+
+            # Remove duplicates from the recordings
+            recordings_nodup = {}
+            for recording in recordings:
+                if recording.id not in recordings_nodup:
+                    recordings_nodup[recording.id] = recording
+                else:
+                    logging.warning("Recording {} is duplicated.".format(recording.id))
+            recordings = RecordingSet.from_recordings(recordings_nodup.values())
+            supervisions = SupervisionSet.from_segments(supervisions)
+
+            recordings, supervisions = fix_manifests(
+                recordings=recordings, supervisions=supervisions
+            )
+
+            logging.info(f"Writing {part} enhanced manifests")
+            recordings.to_file(
+                manifests_dir / f"alimeeting-gss_recordings_{part}.jsonl.gz"
+            )
+            supervisions.to_file(
+                manifests_dir / f"alimeeting-gss_supervisions_{part}.jsonl.gz"
+            )
+
+
+if __name__ == "__main__":
+    args = get_args()
+    main(args)
diff --git a/egs/alimeeting/ASR_v2/local/prepare_alimeeting_gss.sh b/egs/alimeeting/ASR_v2/local/prepare_alimeeting_gss.sh
new file mode 100755
index 000000000..76db19832
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/local/prepare_alimeeting_gss.sh
@@ -0,0 +1,98 @@
+#!/bin/bash
+# This script is used to run GSS-based enhancement on AMI data.
+set -euo pipefail
+nj=4
+stage=0
+
+. shared/parse_options.sh || exit 1
+
+if [ $# != 2 ]; then
+   echo "Wrong #arguments ($#, expected 2)"
+   echo "Usage: local/prepare_alimeeting_gss.sh [options]  "
+   echo "e.g. local/prepare_alimeeting_gss.sh data/manifests exp/ami_gss"
+   echo "main options (for others, see top of script file)"
+   echo "  --nj                                 # number of parallel jobs"
+   echo "  --stage                           # stage to start running from"
+   exit 1;
+fi
+
+DATA_DIR=$1
+EXP_DIR=$2
+
+mkdir -p $EXP_DIR
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+if [ $stage -le 1 ]; then
+  log "Stage 1: Prepare cut sets"
+  for part in train eval test; do
+    lhotse cut simple \
+      -r $DATA_DIR/alimeeting-mdm_recordings_${part}.jsonl.gz \
+      -s $DATA_DIR/alimeeting-mdm_supervisions_${part}.jsonl.gz \
+      $EXP_DIR/cuts_${part}.jsonl.gz
+  done
+fi
+
+if [ $stage -le 2 ]; then
+  log "Stage 2: Trim cuts to supervisions (1 cut per supervision segment)"
+  for part in train eval test; do
+    lhotse cut trim-to-supervisions --discard-overlapping \
+        $EXP_DIR/cuts_${part}.jsonl.gz $EXP_DIR/cuts_per_segment_${part}.jsonl.gz
+  done
+fi
+
+if [ $stage -le 3 ]; then
+  log "Stage 3: Split manifests for multi-GPU processing (optional)"
+  for part in train eval test; do
+    gss utils split $nj $EXP_DIR/cuts_per_segment_${part}.jsonl.gz \
+      $EXP_DIR/cuts_per_segment_${part}_split$nj
+  done
+fi
+
+if [ $stage -le 4 ]; then
+  log "Stage 4: Enhance train segments using GSS (requires GPU)"
+  # for train, we use smaller context and larger batches to speed-up processing
+  for JOB in $(seq $nj); do
+    gss enhance cuts $EXP_DIR/cuts_train.jsonl.gz \
+      $EXP_DIR/cuts_per_segment_train_split$nj/cuts_per_segment_train.JOB.jsonl.gz $EXP_DIR/enhanced \
+      --bss-iterations 10 \
+      --context-duration 5.0 \
+      --use-garbage-class \
+      --channels 0,1,2,3,4,5,6,7 \
+      --min-segment-length 0.05 \
+      --max-segment-length 25.0 \
+      --max-batch-duration 60.0 \
+      --num-buckets 4 \
+      --num-workers 4
+  done
+fi
+
+if [ $stage -le 5 ]; then
+  log "Stage 5: Enhance eval/test segments using GSS (using GPU)"
+  # for eval/test, we use larger context and smaller batches to get better quality
+  for part in eval test; do
+    for JOB in $(seq $nj); do
+      gss enhance cuts $EXP_DIR/cuts_${part}.jsonl.gz \
+      $EXP_DIR/cuts_per_segment_${part}_split$nj/cuts_per_segment_${part}.JOB.jsonl.gz \
+      $EXP_DIR/enhanced \
+      --bss-iterations 10 \
+      --context-duration 15.0 \
+      --use-garbage-class \
+      --channels 0,1,2,3,4,5,6,7 \
+      --min-segment-length 0.05 \
+      --max-segment-length 16.0 \
+      --max-batch-duration 45.0 \
+      --num-buckets 4 \
+      --num-workers 4
+    done
+  done
+fi
+
+if [ $stage -le 6 ]; then
+  log "Stage 6: Prepare manifests for GSS-enhanced data"
+  python local/prepare_alimeeting_enhanced.py $DATA_DIR $EXP_DIR/enhanced -j $nj --min-segment-duration 0.05
+fi
diff --git a/egs/alimeeting/ASR_v2/local/prepare_char.py b/egs/alimeeting/ASR_v2/local/prepare_char.py
new file mode 120000
index 000000000..ee5dd34f1
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/local/prepare_char.py
@@ -0,0 +1 @@
+../../ASR/local/prepare_char.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR_v2/local/prepare_words.py b/egs/alimeeting/ASR_v2/local/prepare_words.py
new file mode 120000
index 000000000..970bfd60c
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/local/prepare_words.py
@@ -0,0 +1 @@
+../../ASR/local/prepare_words.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR_v2/local/text2segments.py b/egs/alimeeting/ASR_v2/local/text2segments.py
new file mode 120000
index 000000000..bf4547794
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/local/text2segments.py
@@ -0,0 +1 @@
+../../ASR/local/text2segments.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR_v2/local/text2token.py b/egs/alimeeting/ASR_v2/local/text2token.py
new file mode 120000
index 000000000..f6b8531b6
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/local/text2token.py
@@ -0,0 +1 @@
+../../ASR/local/text2token.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR_v2/prepare.sh b/egs/alimeeting/ASR_v2/prepare.sh
new file mode 100755
index 000000000..76a108771
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/prepare.sh
@@ -0,0 +1,125 @@
+#!/usr/bin/env bash
+
+set -eou pipefail
+
+stage=-1
+stop_stage=100
+use_gss=true  # Use GSS-based enhancement with MDM setting
+
+# We assume dl_dir (download dir) contains the following
+# directories and files. If not, they will be downloaded
+# by this script automatically.
+#
+#  - $dl_dir/alimeeting
+#     This directory contains the following files downloaded from
+#       https://openslr.org/62/
+#
+#     - Train_Ali_far.tar.gz
+#     - Train_Ali_near.tar.gz
+#     - Test_Ali.tar.gz
+#     - Eval_Ali.tar.gz
+#
+#  - $dl_dir/musan
+#      This directory contains the following directories downloaded from
+#       http://www.openslr.org/17/
+#
+#     - music
+#     - noise
+#     - speech
+
+dl_dir=$PWD/download
+
+. shared/parse_options.sh || exit 1
+
+# All files generated by this script are saved in "data".
+# You can safely remove "data" and rerun this script to regenerate it.
+mkdir -p data
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+log "dl_dir: $dl_dir"
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+  log "Stage 0: Download data"
+
+  if [ ! -f $dl_dir/alimeeting/Train_Ali_far.tar.gz ]; then
+    lhotse download ali-meeting $dl_dir/alimeeting
+  fi
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+  log "Stage 1: Prepare alimeeting manifest"
+  # We assume that you have downloaded the alimeeting corpus
+  # to $dl_dir/alimeeting
+  for part in ihm sdm mdm; do
+    mkdir -p data/manifests/alimeeting
+    lhotse prepare ali-meeting --mic $part --save-mono --normalize-text m2met \
+      $dl_dir/alimeeting data/manifests
+  done
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+  log "Stage 2: Prepare musan manifest"
+  # We assume that you have downloaded the musan corpus
+  # to data/musan
+  mkdir -p data/manifests
+  lhotse prepare musan $dl_dir/musan data/manifests
+fi
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ] && [ $use_gss = true ]; then
+  log "Stage 3: Apply GSS enhancement on MDM data (this stage requires a GPU)"
+  # We assume that you have installed the GSS package: https://github.com/desh2608/gss
+  local/prepare_alimeeting_gss.sh data/manifests exp/alimeeting_gss
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+  log "Stage 4: Compute fbank for musan"
+  mkdir -p data/fbank
+  python local/compute_fbank_musan.py
+fi
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+  log "Stage 5: Compute fbank for alimeeting"
+  mkdir -p data/fbank
+  python local/compute_fbank_alimeeting.py
+  log "Combine features from train splits"
+  lhotse combine data/manifests/cuts_train_{ihm,ihm_rvb,sdm,gss}.jsonl.gz - | shuf |\
+    gzip -c > data/manifests/cuts_train_all.jsonl.gz
+fi
+
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+  log "Stage 6: Prepare char based lang"
+  lang_char_dir=data/lang_char
+  mkdir -p $lang_char_dir
+
+  # Prepare text.
+  # Note: in Linux, you can install jq with the  following command:
+  # wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64
+  gunzip -c data/manifests/alimeeting-sdm_supervisions_train.jsonl.gz \
+    | jq ".text" | sed 's/"//g' \
+    | ./local/text2token.py -t "char" > $lang_char_dir/text
+
+  # Prepare words segments
+  python ./local/text2segments.py \
+    --input $lang_char_dir/text \
+    --output $lang_char_dir/text_words_segmentation
+
+  cat $lang_char_dir/text_words_segmentation | sed "s/ /\n/g" \
+    | sort -u | sed "/^$/d" \
+    | uniq > $lang_char_dir/words_no_ids.txt
+
+  # Prepare words.txt
+  if [ ! -f $lang_char_dir/words.txt ]; then
+    ./local/prepare_words.py \
+      --input-file $lang_char_dir/words_no_ids.txt \
+      --output-file $lang_char_dir/words.txt
+  fi
+
+  if [ ! -f $lang_char_dir/L_disambig.pt ]; then
+    ./local/prepare_char.py
+  fi
+fi
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/__init__.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/asr_datamodule.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/asr_datamodule.py
new file mode 100644
index 000000000..1cfd053c7
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/asr_datamodule.py
@@ -0,0 +1,419 @@
+# Copyright      2021  Piotr Żelasko
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import logging
+import re
+from functools import lru_cache
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+import torch
+from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
+from lhotse.cut import Cut
+from lhotse.dataset import (
+    CutConcatenate,
+    CutMix,
+    DynamicBucketingSampler,
+    K2SpeechRecognitionDataset,
+    PrecomputedFeatures,
+    SpecAugment,
+)
+from lhotse.dataset.input_strategies import OnTheFlyFeatures
+from lhotse.utils import fix_random_seed
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from icefall.utils import str2bool
+
+
+class _SeedWorkers:
+    def __init__(self, seed: int):
+        self.seed = seed
+
+    def __call__(self, worker_id: int):
+        fix_random_seed(self.seed + worker_id)
+
+
+class AlimeetingAsrDataModule:
+    """
+    DataModule for k2 ASR experiments.
+    It assumes there is always one train and valid dataloader,
+    but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
+    and test-other).
+    It contains all the common data pipeline modules used in ASR
+    experiments, e.g.:
+    - dynamic batch size,
+    - bucketing samplers,
+    - cut concatenation,
+    - augmentation,
+    - on-the-fly feature extraction
+    This class should be derived for specific corpora used in ASR tasks.
+    """
+
+    def __init__(self, args: argparse.Namespace):
+        self.args = args
+
+    @classmethod
+    def add_arguments(cls, parser: argparse.ArgumentParser):
+        group = parser.add_argument_group(
+            title="ASR data related options",
+            description=(
+                "These options are used for the preparation of "
+                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+                "effective batch sizes, sampling strategies, applied data "
+                "augmentations, etc."
+            ),
+        )
+        group.add_argument(
+            "--manifest-dir",
+            type=Path,
+            default=Path("data/manifests"),
+            help="Path to directory with train/valid/test cuts.",
+        )
+        group.add_argument(
+            "--enable-musan",
+            type=str2bool,
+            default=True,
+            help=(
+                "When enabled, select noise from MUSAN and mix it "
+                "with training dataset. "
+            ),
+        )
+        group.add_argument(
+            "--concatenate-cuts",
+            type=str2bool,
+            default=False,
+            help=(
+                "When enabled, utterances (cuts) will be concatenated "
+                "to minimize the amount of padding."
+            ),
+        )
+        group.add_argument(
+            "--duration-factor",
+            type=float,
+            default=1.0,
+            help=(
+                "Determines the maximum duration of a concatenated cut "
+                "relative to the duration of the longest cut in a batch."
+            ),
+        )
+        group.add_argument(
+            "--gap",
+            type=float,
+            default=1.0,
+            help=(
+                "The amount of padding (in seconds) inserted between "
+                "concatenated cuts. This padding is filled with noise when "
+                "noise augmentation is used."
+            ),
+        )
+        group.add_argument(
+            "--max-duration",
+            type=int,
+            default=100.0,
+            help=(
+                "Maximum pooled recordings duration (seconds) in a "
+                "single batch. You can reduce it if it causes CUDA OOM."
+            ),
+        )
+        group.add_argument(
+            "--max-cuts", type=int, default=None, help="Maximum cuts in a single batch."
+        )
+        group.add_argument(
+            "--num-buckets",
+            type=int,
+            default=50,
+            help=(
+                "The number of buckets for the BucketingSampler"
+                "(you might want to increase it for larger datasets)."
+            ),
+        )
+        group.add_argument(
+            "--on-the-fly-feats",
+            type=str2bool,
+            default=False,
+            help=(
+                "When enabled, use on-the-fly cut mixing and feature "
+                "extraction. Will drop existing precomputed feature manifests "
+                "if available."
+            ),
+        )
+        group.add_argument(
+            "--shuffle",
+            type=str2bool,
+            default=True,
+            help=(
+                "When enabled (=default), the examples will be "
+                "shuffled for each epoch."
+            ),
+        )
+
+        group.add_argument(
+            "--num-workers",
+            type=int,
+            default=8,
+            help=(
+                "The number of training dataloader workers that " "collect the batches."
+            ),
+        )
+        group.add_argument(
+            "--enable-spec-aug",
+            type=str2bool,
+            default=True,
+            help="When enabled, use SpecAugment for training dataset.",
+        )
+        group.add_argument(
+            "--spec-aug-time-warp-factor",
+            type=int,
+            default=80,
+            help=(
+                "Used only when --enable-spec-aug is True. "
+                "It specifies the factor for time warping in SpecAugment. "
+                "Larger values mean more warping. "
+                "A value less than 1 means to disable time warp."
+            ),
+        )
+
+    def train_dataloaders(
+        self,
+        cuts_train: CutSet,
+        sampler_state_dict: Optional[Dict[str, Any]] = None,
+    ) -> DataLoader:
+        """
+        Args:
+          cuts_train:
+            CutSet for training.
+          sampler_state_dict:
+            The state dict for the training sampler.
+        """
+        logging.info("About to get Musan cuts")
+
+        transforms = []
+        if self.args.enable_musan:
+            logging.info("Enable MUSAN")
+            cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
+            transforms.append(
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+            )
+        else:
+            logging.info("Disable MUSAN")
+
+        if self.args.concatenate_cuts:
+            logging.info(
+                "Using cut concatenation with duration factor "
+                f"{self.args.duration_factor} and gap {self.args.gap}."
+            )
+            # Cut concatenation should be the first transform in the list,
+            # so that if we e.g. mix noise in, it will fill the gaps between
+            # different utterances.
+            transforms = [
+                CutConcatenate(
+                    duration_factor=self.args.duration_factor, gap=self.args.gap
+                )
+            ] + transforms
+
+        input_transforms = []
+        if self.args.enable_spec_aug:
+            logging.info("Enable SpecAugment")
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+            input_transforms.append(
+                SpecAugment(
+                    time_warp_factor=self.args.spec_aug_time_warp_factor,
+                    num_frame_masks=2,
+                    features_mask_size=27,
+                    num_feature_masks=2,
+                    frames_mask_size=100,
+                )
+            )
+        else:
+            logging.info("Disable SpecAugment")
+
+        logging.info("About to create train dataset")
+        if self.args.on_the_fly_feats:
+            train = K2SpeechRecognitionDataset(
+                cut_transforms=transforms,
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_transforms=input_transforms,
+            )
+        else:
+            train = K2SpeechRecognitionDataset(
+                cut_transforms=transforms,
+                input_transforms=input_transforms,
+            )
+
+        logging.info("Using DynamicBucketingSampler.")
+        train_sampler = DynamicBucketingSampler(
+            cuts_train,
+            max_duration=self.args.max_duration,
+            max_cuts=self.args.max_cuts,
+            shuffle=False,
+            num_buckets=self.args.num_buckets,
+            drop_last=True,
+        )
+        logging.info("About to create train dataloader")
+
+        if sampler_state_dict is not None:
+            logging.info("Loading sampler state dict")
+            train_sampler.load_state_dict(sampler_state_dict)
+
+        # 'seed' is derived from the current random state, which will have
+        # previously been set in the main process.
+        seed = torch.randint(0, 100000, ()).item()
+        worker_init_fn = _SeedWorkers(seed)
+
+        train_dl = DataLoader(
+            train,
+            sampler=train_sampler,
+            batch_size=None,
+            num_workers=self.args.num_workers,
+            persistent_workers=False,
+            worker_init_fn=worker_init_fn,
+        )
+
+        return train_dl
+
+    def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
+
+        transforms = []
+        if self.args.concatenate_cuts:
+            transforms = [
+                CutConcatenate(
+                    duration_factor=self.args.duration_factor, gap=self.args.gap
+                )
+            ] + transforms
+
+        logging.info("About to create dev dataset")
+        if self.args.on_the_fly_feats:
+            validate = K2SpeechRecognitionDataset(
+                cut_transforms=transforms,
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+            )
+        else:
+            validate = K2SpeechRecognitionDataset(
+                cut_transforms=transforms,
+            )
+        valid_sampler = DynamicBucketingSampler(
+            cuts_valid,
+            max_duration=self.args.max_duration,
+            shuffle=False,
+        )
+        logging.info("About to create dev dataloader")
+        valid_dl = DataLoader(
+            validate,
+            sampler=valid_sampler,
+            batch_size=None,
+            num_workers=2,
+            persistent_workers=False,
+        )
+
+        return valid_dl
+
+    def test_dataloaders(self, cuts: CutSet) -> DataLoader:
+        logging.debug("About to create test dataset")
+        test = K2SpeechRecognitionDataset(
+            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
+            if self.args.on_the_fly_feats
+            else PrecomputedFeatures(),
+            return_cuts=True,
+        )
+        sampler = DynamicBucketingSampler(
+            cuts, max_duration=self.args.max_duration, shuffle=False
+        )
+        logging.debug("About to create test dataloader")
+        test_dl = DataLoader(
+            test,
+            batch_size=None,
+            sampler=sampler,
+            num_workers=self.args.num_workers,
+        )
+        return test_dl
+
+    def remove_short_cuts(self, cut: Cut) -> bool:
+        """
+        See: https://github.com/k2-fsa/icefall/issues/500
+        Basically, the zipformer model subsamples the input using the following formula:
+        num_out_frames = ((num_in_frames - 7)//2 + 1)//2
+        For num_out_frames to be at least 1, num_in_frames must be at least 9.
+        """
+        return cut.duration >= 0.09
+
+    @lru_cache()
+    def train_cuts(self, sp: Optional[Any] = None) -> CutSet:
+        logging.info("About to get AMI train cuts")
+
+        def _remove_short_and_long_utt(c: Cut):
+            if c.duration < 0.1 or c.duration > 25.0:
+                return False
+
+            # In pruned RNN-T, we require that T >= S
+            # where T is the number of feature frames after subsampling
+            # and S is the number of tokens in the utterance
+
+            # In ./zipformer.py, the conv module uses the following expression
+            # for subsampling
+            T = ((c.num_frames - 7) // 2 + 1) // 2
+            tokens = c.supervisions[0].text
+            return T >= len(tokens)
+
+        cuts_train = load_manifest_lazy(
+            self.args.manifest_dir / "cuts_train_all.jsonl.gz"
+        )
+
+        return cuts_train.filter(_remove_short_and_long_utt)
+
+    @lru_cache()
+    def eval_ihm_cuts(self) -> CutSet:
+        logging.info("About to get AliMeeting IHM eval cuts")
+        cs = load_manifest_lazy(self.args.manifest_dir / "cuts_eval_ihm.jsonl.gz")
+        return cs.filter(self.remove_short_cuts)
+
+    @lru_cache()
+    def eval_sdm_cuts(self) -> CutSet:
+        logging.info("About to get AliMeeting SDM eval cuts")
+        cs = load_manifest_lazy(self.args.manifest_dir / "cuts_eval_sdm.jsonl.gz")
+        return cs.filter(self.remove_short_cuts)
+
+    @lru_cache()
+    def eval_gss_cuts(self) -> CutSet:
+        if not (self.args.manifest_dir / "cuts_eval_gss.jsonl.gz").exists():
+            logging.info("No GSS dev cuts found")
+            return None
+        logging.info("About to get AliMeeting GSS-enhanced eval cuts")
+        cs = load_manifest_lazy(self.args.manifest_dir / "cuts_eval_gss.jsonl.gz")
+        return cs.filter(self.remove_short_cuts)
+
+    @lru_cache()
+    def test_ihm_cuts(self) -> CutSet:
+        logging.info("About to get AliMeeting IHM test cuts")
+        cs = load_manifest_lazy(self.args.manifest_dir / "cuts_test_ihm.jsonl.gz")
+        return cs.filter(self.remove_short_cuts)
+
+    @lru_cache()
+    def test_sdm_cuts(self) -> CutSet:
+        logging.info("About to get AliMeeting SDM test cuts")
+        cs = load_manifest_lazy(self.args.manifest_dir / "cuts_test_sdm.jsonl.gz")
+        return cs.filter(self.remove_short_cuts)
+
+    @lru_cache()
+    def test_gss_cuts(self) -> CutSet:
+        if not (self.args.manifest_dir / "cuts_test_gss.jsonl.gz").exists():
+            logging.info("No GSS test cuts found")
+            return None
+        logging.info("About to get AliMeeting GSS-enhanced test cuts")
+        cs = load_manifest_lazy(self.args.manifest_dir / "cuts_test_gss.jsonl.gz")
+        return cs.filter(self.remove_short_cuts)
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/beam_search.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/beam_search.py
new file mode 120000
index 000000000..37516affc
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/beam_search.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/beam_search.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/decode.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/decode.py
new file mode 100755
index 000000000..53381c1f4
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/decode.py
@@ -0,0 +1,698 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) greedy search
+./pruned_transducer_stateless7/decode.py \
+        --epoch 15 \
+        --avg 8 \
+        --exp-dir ./pruned_transducer_stateless7/exp \
+        --max-duration 500 \
+        --decoding-method greedy_search
+
+(2) modified beam search
+./pruned_transducer_stateless7/decode.py \
+        --epoch 15 \
+        --avg 8 \
+        --exp-dir ./pruned_transducer_stateless7/exp \
+        --max-duration 500 \
+        --decoding-method modified_beam_search \
+        --beam-size 4
+
+(3) fast beam search
+./pruned_transducer_stateless7/decode.py \
+        --epoch 15 \
+        --avg 8 \
+        --exp-dir ./pruned_transducer_stateless7/exp \
+        --max-duration 500 \
+        --decoding-method fast_beam_search \
+        --beam 4 \
+        --max-contexts 4 \
+        --max-states 8
+"""
+
+
+import argparse
+import logging
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import AlimeetingAsrDataModule
+from beam_search import (
+    beam_search,
+    fast_beam_search_nbest_LG,
+    fast_beam_search_one_best,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+)
+from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall import NgramLm
+from icefall.checkpoint import (
+    average_checkpoints,
+    average_checkpoints_with_averaged_model,
+    find_checkpoints,
+    load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+    AttributeDict,
+    setup_logger,
+    store_transcripts,
+    str2bool,
+    write_error_stats,
+)
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=30,
+        help="""It specifies the checkpoint to use for decoding.
+        Note: Epoch counts from 0.
+        You can specify --avg to use more checkpoints for model averaging.""",
+    )
+
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=10,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
+    )
+
+    parser.add_argument(
+        "--use-averaged-model",
+        type=str2bool,
+        default=True,
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="pruned_transducer_stateless2/exp",
+        help="The experiment dir",
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_char",
+        help="""The lang dir
+        It contains language related input files such as
+        "lexicon.txt"
+        """,
+    )
+
+    parser.add_argument(
+        "--decoding-method",
+        type=str,
+        default="greedy_search",
+        help="""Possible values are:
+          - greedy_search
+          - beam_search
+          - modified_beam_search
+          - fast_beam_search
+          - fast_beam_search_nbest
+          - fast_beam_search_nbest_oracle
+          - fast_beam_search_nbest_LG
+        If you use fast_beam_search_nbest_LG, you have to specify
+        `--lang-dir`, which should contain `LG.pt`.
+        """,
+    )
+
+    parser.add_argument(
+        "--beam-size",
+        type=int,
+        default=4,
+        help="""An interger indicating how many candidates we will keep for each
+        frame. Used only when --decoding-method is beam_search or
+        modified_beam_search.""",
+    )
+
+    parser.add_argument(
+        "--beam",
+        type=float,
+        default=4,
+        help="""A floating point value to calculate the cutoff score during beam
+        search (i.e., `cutoff = max-score - beam`), which is the same as the
+        `beam` in Kaldi.
+        Used only when --decoding-method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--ngram-lm-scale",
+        type=float,
+        default=0.01,
+        help="""
+        Used only when --decoding_method is fast_beam_search_nbest_LG.
+        It specifies the scale for n-gram LM scores.
+        """,
+    )
+
+    parser.add_argument(
+        "--max-contexts",
+        type=int,
+        default=8,
+        help="""Used only when --decoding-method is
+        fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+        and fast_beam_search_nbest_oracle""",
+    )
+
+    parser.add_argument(
+        "--max-states",
+        type=int,
+        default=64,
+        help="""Used only when --decoding-method is
+        fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+        and fast_beam_search_nbest_oracle""",
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+    )
+    parser.add_argument(
+        "--max-sym-per-frame",
+        type=int,
+        default=1,
+        help="""Maximum number of symbols per frame.
+        Used only when --decoding_method is greedy_search""",
+    )
+
+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=200,
+        help="""Number of paths for nbest decoding.
+        Used only when the decoding method is fast_beam_search_nbest,
+        fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+    )
+
+    parser.add_argument(
+        "--nbest-scale",
+        type=float,
+        default=0.5,
+        help="""Scale applied to lattice scores when computing nbest paths.
+        Used only when the decoding method is fast_beam_search_nbest,
+        fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def decode_one_batch(
+    params: AttributeDict,
+    model: nn.Module,
+    lexicon: Lexicon,
+    batch: dict,
+    decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[List[str]]]:
+    """Decode one batch and return the result in a dict. The dict has the
+    following format:
+
+        - key: It indicates the setting used for decoding. For example,
+               if greedy_search is used, it would be "greedy_search"
+               If beam search with a beam size of 7 is used, it would be
+               "beam_7"
+        - value: It contains the decoding result. `len(value)` equals to
+                 batch size. `value[i]` is the decoding result for the i-th
+                 utterance in the given batch.
+    Args:
+      params:
+        It's the return value of :func:`get_params`.
+      model:
+        The neural model.
+      batch:
+        It is the return value from iterating
+        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+        for the format of the `batch`.
+      decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
+        only when --decoding_method is fast_beam_search.
+    Returns:
+      Return the decoding result. See above description for the format of
+      the returned dict.
+    """
+    device = model.device
+    feature = batch["inputs"]
+    assert feature.ndim == 3
+
+    feature = feature.to(device)
+    # at entry, feature is (N, T, C)
+
+    supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)
+
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    hyps = []
+
+    if params.decoding_method == "fast_beam_search":
+        hyp_tokens = fast_beam_search_one_best(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+        )
+        for i in range(encoder_out.size(0)):
+            hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+    elif params.decoding_method == "fast_beam_search_nbest_LG":
+        hyp_tokens = fast_beam_search_nbest_LG(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+            num_paths=params.num_paths,
+            nbest_scale=params.nbest_scale,
+        )
+        for i in range(encoder_out.size(0)):
+            hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+        hyp_tokens = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+        )
+        for i in range(encoder_out.size(0)):
+            hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+    elif params.decoding_method == "modified_beam_search":
+        hyp_tokens = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam_size,
+        )
+        for i in range(encoder_out.size(0)):
+            hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+    else:
+        batch_size = encoder_out.size(0)
+
+        for i in range(batch_size):
+            # fmt: off
+            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+            # fmt: on
+            if params.decoding_method == "greedy_search":
+                hyp = greedy_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    max_sym_per_frame=params.max_sym_per_frame,
+                )
+            elif params.decoding_method == "beam_search":
+                hyp = beam_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
+                )
+            else:
+                raise ValueError(
+                    f"Unsupported decoding method: {params.decoding_method}"
+                )
+            hyps.append([lexicon.token_table[idx] for idx in hyp])
+
+    if params.decoding_method == "greedy_search":
+        return {"greedy_search": hyps}
+    elif params.decoding_method == "fast_beam_search":
+        return {
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}"
+            ): hyps
+        }
+    elif "fast_beam_search" in params.decoding_method:
+        key = f"beam_{params.beam}_"
+        key += f"max_contexts_{params.max_contexts}_"
+        key += f"max_states_{params.max_states}"
+        if "nbest" in params.decoding_method:
+            key += f"_num_paths_{params.num_paths}_"
+            key += f"nbest_scale_{params.nbest_scale}"
+            if "LG" in params.decoding_method:
+                key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
+
+        return {key: hyps}
+    else:
+        return {f"beam_size_{params.beam_size}": hyps}
+
+
+def decode_dataset(
+    dl: torch.utils.data.DataLoader,
+    params: AttributeDict,
+    model: nn.Module,
+    lexicon: Lexicon,
+    decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[List[str], List[str]]]]:
+    """Decode dataset.
+
+    Args:
+      dl:
+        PyTorch's dataloader containing the dataset to decode.
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The neural model.
+      decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
+        only when --decoding_method is fast_beam_search.
+    Returns:
+      Return a dict, whose key may be "greedy_search" if greedy search
+      is used, or it may be "beam_7" if beam size of 7 is used.
+      Its value is a list of tuples. Each tuple contains two elements:
+      The first is the reference transcript, and the second is the
+      predicted result.
+    """
+    num_cuts = 0
+
+    try:
+        num_batches = len(dl)
+    except TypeError:
+        num_batches = "?"
+
+    if params.decoding_method == "greedy_search":
+        log_interval = 100
+    else:
+        log_interval = 2
+
+    results = defaultdict(list)
+    for batch_idx, batch in enumerate(dl):
+        texts = batch["supervisions"]["text"]
+        texts = [list(str(text).replace(" ", "")) for text in texts]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+        hyps_dict = decode_one_batch(
+            params=params,
+            model=model,
+            lexicon=lexicon,
+            decoding_graph=decoding_graph,
+            batch=batch,
+        )
+
+        for name, hyps in hyps_dict.items():
+            this_batch = []
+            assert len(hyps) == len(texts)
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+                this_batch.append((cut_id, ref_text, hyp_words))
+
+            results[name].extend(this_batch)
+
+        num_cuts += len(texts)
+
+        if batch_idx % log_interval == 0:
+            batch_str = f"{batch_idx}/{num_batches}"
+
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+    return results
+
+
+def save_results(
+    params: AttributeDict,
+    test_set_name: str,
+    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
+    test_set_wers = dict()
+    for key, results in results_dict.items():
+        recog_path = (
+            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
+        results = sorted(results)
+        store_transcripts(filename=recog_path, texts=results)
+        logging.info(f"The transcripts are stored in {recog_path}")
+
+        # The following prints out WERs, per-word error statistics and aligned
+        # ref/hyp pairs.
+        errs_filename = (
+            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
+        with open(errs_filename, "w") as f:
+            wer = write_error_stats(
+                f, f"{test_set_name}-{key}", results, enable_log=True
+            )
+            test_set_wers[key] = wer
+
+        logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+    errs_info = (
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+    )
+    with open(errs_info, "w") as f:
+        print("settings\tWER", file=f)
+        for key, val in test_set_wers:
+            print("{}\t{}".format(key, val), file=f)
+
+    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+    note = "\tbest for {}".format(test_set_name)
+    for key, val in test_set_wers:
+        s += "{}\t{}{}\n".format(key, val, note)
+        note = ""
+    logging.info(s)
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    AlimeetingAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    assert params.decoding_method in (
+        "greedy_search",
+        "beam_search",
+        "fast_beam_search",
+        "fast_beam_search_nbest_LG",
+        "modified_beam_search",
+    )
+    params.res_dir = params.exp_dir / params.decoding_method
+
+    if params.iter > 0:
+        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+    else:
+        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+    if "fast_beam_search" in params.decoding_method:
+        params.suffix += f"-beam-{params.beam}"
+        params.suffix += f"-max-contexts-{params.max_contexts}"
+        params.suffix += f"-max-states-{params.max_states}"
+        if "nbest" in params.decoding_method:
+            params.suffix += f"-nbest-scale-{params.nbest_scale}"
+            params.suffix += f"-num-paths-{params.num_paths}"
+            if "LG" in params.decoding_method:
+                params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
+    elif "beam_search" in params.decoding_method:
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+    else:
+        params.suffix += f"-context-{params.context_size}"
+        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+    logging.info("Decoding started")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"Device: {device}")
+
+    lexicon = Lexicon(params.lang_dir)
+    params.blank_id = lexicon.token_table[""]
+    params.vocab_size = max(lexicon.tokens) + 1
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+
+    model.to(device)
+    model.eval()
+    model.device = device
+
+    if "fast_beam_search" in params.decoding_method:
+        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+    else:
+        decoding_graph = None
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    alimeeting = AlimeetingAsrDataModule(args)
+
+    eval_ihm_cuts = alimeeting.eval_ihm_cuts()
+    test_ihm_cuts = alimeeting.test_ihm_cuts()
+    eval_sdm_cuts = alimeeting.eval_sdm_cuts()
+    test_sdm_cuts = alimeeting.test_sdm_cuts()
+    eval_gss_cuts = alimeeting.eval_gss_cuts()
+    test_gss_cuts = alimeeting.test_gss_cuts()
+
+    eval_ihm_dl = alimeeting.test_dataloaders(eval_ihm_cuts)
+    test_ihm_dl = alimeeting.test_dataloaders(test_ihm_cuts)
+    eval_sdm_dl = alimeeting.test_dataloaders(eval_sdm_cuts)
+    test_sdm_dl = alimeeting.test_dataloaders(test_sdm_cuts)
+    if eval_gss_cuts is not None:
+        eval_gss_dl = alimeeting.test_dataloaders(eval_gss_cuts)
+    if test_gss_cuts is not None:
+        test_gss_dl = alimeeting.test_dataloaders(test_gss_cuts)
+
+    test_sets = {
+        "eval_ihm": (eval_ihm_dl, eval_ihm_cuts),
+        "test_ihm": (test_ihm_dl, test_ihm_cuts),
+        "eval_sdm": (eval_sdm_dl, eval_sdm_cuts),
+        "test_sdm": (test_sdm_dl, test_sdm_cuts),
+    }
+    if eval_gss_cuts is not None:
+        test_sets["eval_gss"] = (eval_gss_dl, eval_gss_cuts)
+    if test_gss_cuts is not None:
+        test_sets["test_gss"] = (test_gss_dl, test_gss_cuts)
+
+    for test_set in test_sets:
+        logging.info(f"Decoding {test_set}")
+        dl, cuts = test_sets[test_set]
+        results_dict = decode_dataset(
+            dl=dl,
+            params=params,
+            model=model,
+            lexicon=lexicon,
+            decoding_graph=decoding_graph,
+        )
+
+        save_results(
+            params=params,
+            test_set_name=test_set,
+            results_dict=results_dict,
+        )
+
+    logging.info("Done!")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/decoder.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/decoder.py
new file mode 120000
index 000000000..8283d8c5a
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/decoder.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/decoder.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/encoder_interface.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/encoder_interface.py
new file mode 120000
index 000000000..0c2673d46
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/encoder_interface.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/encoder_interface.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/export.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/export.py
new file mode 100755
index 000000000..23a88dd29
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/export.py
@@ -0,0 +1,320 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This script converts several saved checkpoints
+# to a single one using model averaging.
+"""
+
+Usage:
+
+(1) Export to torchscript model using torch.jit.script()
+
+./pruned_transducer_stateless7/export.py \
+  --exp-dir ./pruned_transducer_stateless7/exp \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --epoch 30 \
+  --avg 9 \
+  --jit 1
+
+It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later
+load it by `torch.jit.load("cpu_jit.pt")`.
+
+Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python
+are on CPU. You can use `to("cuda")` to move them to a CUDA device.
+
+Check
+https://github.com/k2-fsa/sherpa
+for how to use the exported models outside of icefall.
+
+(2) Export `model.state_dict()`
+
+./pruned_transducer_stateless7/export.py \
+  --exp-dir ./pruned_transducer_stateless7/exp \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --epoch 20 \
+  --avg 10
+
+It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
+load it by `icefall.checkpoint.load_checkpoint()`.
+
+To use the generated file with `pruned_transducer_stateless7/decode.py`,
+you can do:
+
+    cd /path/to/exp_dir
+    ln -s pretrained.pt epoch-9999.pt
+
+    cd /path/to/egs/librispeech/ASR
+    ./pruned_transducer_stateless7/decode.py \
+        --exp-dir ./pruned_transducer_stateless7/exp \
+        --epoch 9999 \
+        --avg 1 \
+        --max-duration 600 \
+        --decoding-method greedy_search \
+        --bpe-model data/lang_bpe_500/bpe.model
+
+Check ./pretrained.py for its usage.
+
+Note: If you don't want to train a model from scratch, we have
+provided one for you. You can get it at
+
+https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
+
+with the following commands:
+
+    sudo apt-get install git-lfs
+    git lfs install
+    git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
+    # You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/exp
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from scaling_converter import convert_scaled_to_non_scaled
+from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall.checkpoint import (
+    average_checkpoints,
+    average_checkpoints_with_averaged_model,
+    find_checkpoints,
+    load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import str2bool
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=15,
+        help="""It specifies the checkpoint to use for decoding.
+        Note: Epoch counts from 1.
+        You can specify --avg to use more checkpoints for model averaging.""",
+    )
+
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=8,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
+    )
+
+    parser.add_argument(
+        "--use-averaged-model",
+        type=str2bool,
+        default=True,
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="pruned_transducer_stateless7/exp",
+        help="""It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_char",
+        help="The lang dir",
+    )
+
+    parser.add_argument(
+        "--jit",
+        type=str2bool,
+        default=False,
+        help="""True to save a model after applying torch.jit.script.
+        It will generate a file named cpu_jit.pt
+
+        Check ./jit_pretrained.py for how to use it.
+        """,
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+@torch.no_grad()
+def main():
+    args = get_parser().parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    lexicon = Lexicon(params.lang_dir)
+
+    params.blank_id = 0
+    params.vocab_size = max(lexicon.tokens) + 1
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    model.to(device)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+
+    model.to("cpu")
+    model.eval()
+
+    if params.jit is True:
+        convert_scaled_to_non_scaled(model, inplace=True)
+        logging.info("Using torch.jit.script()")
+        # We won't use the forward() method of the model in C++, so just ignore
+        # it here.
+        # Otherwise, one of its arguments is a ragged tensor and is not
+        # torch scriptabe.
+        model.__class__.forward = torch.jit.ignore(model.__class__.forward)
+        logging.info("Using torch.jit.script")
+        model = torch.jit.script(model)
+        filename = params.exp_dir / "cpu_jit.pt"
+        model.save(str(filename))
+        logging.info(f"Saved to {filename}")
+    else:
+        logging.info("Not using torchscript. Export model.state_dict()")
+        # Save it using a format so that it can be loaded
+        # by :func:`load_checkpoint`
+        filename = params.exp_dir / "pretrained.pt"
+        torch.save({"model": model.state_dict()}, str(filename))
+        logging.info(f"Saved to {filename}")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/jit_pretrained.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/jit_pretrained.py
new file mode 120000
index 000000000..a44034e34
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/jit_pretrained.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/joiner.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/joiner.py
new file mode 120000
index 000000000..0f0c3c90a
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/joiner.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/joiner.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/model.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/model.py
new file mode 120000
index 000000000..0d8bc665b
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/model.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/optim.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/optim.py
new file mode 120000
index 000000000..8a05abb5f
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/optim.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/optim.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/pretrained.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/pretrained.py
new file mode 120000
index 000000000..068f0f57f
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/pretrained.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/pretrained.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/scaling.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/scaling.py
new file mode 120000
index 000000000..5f9be9fe0
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/scaling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/scaling.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/scaling_converter.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/scaling_converter.py
new file mode 120000
index 000000000..f9960e5c6
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/scaling_converter.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/test_model.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/test_model.py
new file mode 120000
index 000000000..7ceac5d10
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/test_model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/test_model.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py
new file mode 100755
index 000000000..757d6535e
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py
@@ -0,0 +1,1186 @@
+#!/usr/bin/env python3
+# Copyright    2021-2022  Xiaomi Corp.        (authors: Fangjun Kuang,
+#                                                       Wei Kang,
+#                                                       Mingshuang Luo,)
+#                                                       Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./pruned_transducer_stateless7/train.py \
+  --world-size 4 \
+  --num-epochs 15 \
+  --start-epoch 1 \
+  --exp-dir pruned_transducer_stateless7/exp \
+  --max-duration 150 \
+    --use-fp16 True
+
+"""
+
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import AlimeetingAsrDataModule
+from decoder import Decoder
+from joiner import Joiner
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import Transducer
+from optim import Eden, ScaledAdam
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+from zipformer import Zipformer
+
+from icefall import diagnostics
+from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+    save_checkpoint_with_global_batch_idx,
+    update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.lexicon import Lexicon
+from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+    if isinstance(model, DDP):
+        # get underlying nn.Module
+        model = model.module
+    for module in model.modules():
+        if hasattr(module, "batch_count"):
+            module.batch_count = batch_count
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+    parser.add_argument(
+        "--num-encoder-layers",
+        type=str,
+        default="2,4,3,2,4",
+        help="Number of zipformer encoder layers, comma separated.",
+    )
+
+    parser.add_argument(
+        "--feedforward-dims",
+        type=str,
+        default="1024,1024,2048,2048,1024",
+        help="Feedforward dimension of the zipformer encoder layers, comma separated.",
+    )
+
+    parser.add_argument(
+        "--nhead",
+        type=str,
+        default="8,8,8,8,8",
+        help="Number of attention heads in the zipformer encoder layers.",
+    )
+
+    parser.add_argument(
+        "--encoder-dims",
+        type=str,
+        default="384,384,384,384,384",
+        help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated",
+    )
+
+    parser.add_argument(
+        "--attention-dims",
+        type=str,
+        default="192,192,192,192,192",
+        help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated;
+        not the same as embedding dimension.""",
+    )
+
+    parser.add_argument(
+        "--encoder-unmasked-dims",
+        type=str,
+        default="256,256,256,256,256",
+        help="Unmasked dimensions in the encoders, relates to augmentation during training.  "
+        "Must be <= each of encoder_dims.  Empirically, less than 256 seems to make performance "
+        " worse.",
+    )
+
+    parser.add_argument(
+        "--zipformer-downsampling-factors",
+        type=str,
+        default="1,2,4,8,2",
+        help="Downsampling factor for each stack of encoder layers.",
+    )
+
+    parser.add_argument(
+        "--cnn-module-kernels",
+        type=str,
+        default="31,31,31,31,31",
+        help="Sizes of kernels in convolution modules",
+    )
+
+    parser.add_argument(
+        "--decoder-dim",
+        type=int,
+        default=512,
+        help="Embedding dimension in the decoder model.",
+    )
+
+    parser.add_argument(
+        "--joiner-dim",
+        type=int,
+        default=512,
+        help="""Dimension used in the joiner model.
+        Outputs from the encoder and decoder model are projected
+        to this dimension before adding.
+        """,
+    )
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--world-size",
+        type=int,
+        default=1,
+        help="Number of GPUs for DDP training.",
+    )
+
+    parser.add_argument(
+        "--master-port",
+        type=int,
+        default=12354,
+        help="Master port to use for DDP training.",
+    )
+
+    parser.add_argument(
+        "--tensorboard",
+        type=str2bool,
+        default=True,
+        help="Should various information be logged in tensorboard.",
+    )
+
+    parser.add_argument(
+        "--num-epochs",
+        type=int,
+        default=15,
+        help="Number of epochs to train.",
+    )
+
+    parser.add_argument(
+        "--start-epoch",
+        type=int,
+        default=1,
+        help="""Resume training from this epoch. It should be positive.
+        If larger than 1, it will load checkpoint from
+        exp-dir/epoch-{start_epoch-1}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--start-batch",
+        type=int,
+        default=0,
+        help="""If positive, --start-epoch is ignored and
+        it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="pruned_transducer_stateless7/exp",
+        help="""The experiment dir.
+        It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_char",
+        help="""The lang dir
+        It contains language related input files such as
+        "lexicon.txt"
+        """,
+    )
+
+    parser.add_argument(
+        "--base-lr", type=float, default=0.05, help="The base learning rate."
+    )
+
+    parser.add_argument(
+        "--lr-batches",
+        type=float,
+        default=5000,
+        help="""Number of steps that affects how rapidly the learning rate
+        decreases. We suggest not to change this.""",
+    )
+
+    parser.add_argument(
+        "--lr-epochs",
+        type=float,
+        default=3.5,
+        help="""Number of epochs that affects how rapidly the learning rate decreases.
+        """,
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+    )
+
+    parser.add_argument(
+        "--prune-range",
+        type=int,
+        default=5,
+        help="The prune range for rnnt loss, it means how many symbols(context)"
+        "we are using to compute the loss",
+    )
+
+    parser.add_argument(
+        "--lm-scale",
+        type=float,
+        default=0.25,
+        help="The scale to smooth the loss with lm "
+        "(output of prediction network) part.",
+    )
+
+    parser.add_argument(
+        "--am-scale",
+        type=float,
+        default=0.0,
+        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+    )
+
+    parser.add_argument(
+        "--simple-loss-scale",
+        type=float,
+        default=0.5,
+        help="To get pruning ranges, we will calculate a simple version"
+        "loss(joiner is just addition), this simple loss also uses for"
+        "training (as a regularization item). We will scale the simple loss"
+        "with this parameter before adding to the final loss.",
+    )
+
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="The seed for random generators intended for reproducibility",
+    )
+
+    parser.add_argument(
+        "--print-diagnostics",
+        type=str2bool,
+        default=False,
+        help="Accumulate stats on activations, print them and exit.",
+    )
+
+    parser.add_argument(
+        "--inf-check",
+        type=str2bool,
+        default=False,
+        help="Add hooks to check for infinite module outputs and gradients.",
+    )
+
+    parser.add_argument(
+        "--save-every-n",
+        type=int,
+        default=5000,
+        help="""Save checkpoint after processing this number of batches"
+        periodically. We save checkpoint to exp-dir/ whenever
+        params.batch_idx_train % save_every_n == 0. The checkpoint filename
+        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+        Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+        end of each epoch where `xxx` is the epoch number counting from 0.
+        """,
+    )
+
+    parser.add_argument(
+        "--keep-last-k",
+        type=int,
+        default=10,
+        help="""Only keep this number of checkpoints on disk.
+        For instance, if it is 3, there are only 3 checkpoints
+        in the exp-dir with filenames `checkpoint-xxx.pt`.
+        It does not affect checkpoints with name `epoch-xxx.pt`.
+        """,
+    )
+
+    parser.add_argument(
+        "--average-period",
+        type=int,
+        default=200,
+        help="""Update the averaged model, namely `model_avg`, after processing
+        this number of batches. `model_avg` is a separate version of model,
+        in which each floating-point parameter is the average of all the
+        parameters from the start of training. Each time we take the average,
+        we do: `model_avg = model * (average_period / batch_idx_train) +
+            model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+        """,
+    )
+
+    parser.add_argument(
+        "--use-fp16",
+        type=str2bool,
+        default=False,
+        help="Whether to use half precision training.",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def get_params() -> AttributeDict:
+    """Return a dict containing training parameters.
+
+    All training related parameters that are not passed from the commandline
+    are saved in the variable `params`.
+
+    Commandline options are merged into `params` after they are parsed, so
+    you can also access them via `params`.
+
+    Explanation of options saved in `params`:
+
+        - best_train_loss: Best training loss so far. It is used to select
+                           the model that has the lowest training loss. It is
+                           updated during the training.
+
+        - best_valid_loss: Best validation loss so far. It is used to select
+                           the model that has the lowest validation loss. It is
+                           updated during the training.
+
+        - best_train_epoch: It is the epoch that has the best training loss.
+
+        - best_valid_epoch: It is the epoch that has the best validation loss.
+
+        - batch_idx_train: Used to writing statistics to tensorboard. It
+                           contains number of batches trained so far across
+                           epochs.
+
+        - log_interval:  Print training loss if batch_idx % log_interval` is 0
+
+        - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+        - valid_interval:  Run validation if batch_idx % valid_interval is 0
+
+        - feature_dim: The model input dim. It has to match the one used
+                       in computing features.
+
+        - subsampling_factor:  The subsampling factor for the model.
+
+        - encoder_dim: Hidden dim for multi-head attention model.
+
+        - num_decoder_layers: Number of decoder layer of transformer decoder.
+
+        - warm_step: The warmup period that dictates the decay of the
+              scale on "simple" (un-pruned) loss.
+    """
+    params = AttributeDict(
+        {
+            "best_train_loss": float("inf"),
+            "best_valid_loss": float("inf"),
+            "best_train_epoch": -1,
+            "best_valid_epoch": -1,
+            "batch_idx_train": 0,
+            "log_interval": 100,
+            "reset_interval": 200,
+            "valid_interval": 3000,  # For the 100h subset, use 800
+            # parameters for zipformer
+            "feature_dim": 80,
+            "subsampling_factor": 4,  # not passed in, this is fixed.
+            "warm_step": 2000,
+            "env_info": get_env_info(),
+        }
+    )
+
+    return params
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+    # TODO: We can add an option to switch between Zipformer and Transformer
+    def to_int_tuple(s: str):
+        return tuple(map(int, s.split(",")))
+
+    encoder = Zipformer(
+        num_features=params.feature_dim,
+        output_downsampling_factor=2,
+        zipformer_downsampling_factors=to_int_tuple(
+            params.zipformer_downsampling_factors
+        ),
+        encoder_dims=to_int_tuple(params.encoder_dims),
+        attention_dim=to_int_tuple(params.attention_dims),
+        encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims),
+        nhead=to_int_tuple(params.nhead),
+        feedforward_dim=to_int_tuple(params.feedforward_dims),
+        cnn_module_kernels=to_int_tuple(params.cnn_module_kernels),
+        num_encoder_layers=to_int_tuple(params.num_encoder_layers),
+    )
+    return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+    decoder = Decoder(
+        vocab_size=params.vocab_size,
+        decoder_dim=params.decoder_dim,
+        blank_id=params.blank_id,
+        context_size=params.context_size,
+    )
+    return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+    joiner = Joiner(
+        encoder_dim=int(params.encoder_dims.split(",")[-1]),
+        decoder_dim=params.decoder_dim,
+        joiner_dim=params.joiner_dim,
+        vocab_size=params.vocab_size,
+    )
+    return joiner
+
+
+def get_transducer_model(params: AttributeDict) -> nn.Module:
+    encoder = get_encoder_model(params)
+    decoder = get_decoder_model(params)
+    joiner = get_joiner_model(params)
+
+    model = Transducer(
+        encoder=encoder,
+        decoder=decoder,
+        joiner=joiner,
+        encoder_dim=int(params.encoder_dims.split(",")[-1]),
+        decoder_dim=params.decoder_dim,
+        joiner_dim=params.joiner_dim,
+        vocab_size=params.vocab_size,
+    )
+    return model
+
+
+def load_checkpoint_if_available(
+    params: AttributeDict,
+    model: nn.Module,
+    model_avg: nn.Module = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+    """Load checkpoint from file.
+
+    If params.start_batch is positive, it will load the checkpoint from
+    `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+    params.start_epoch is larger than 1, it will load the checkpoint from
+    `params.start_epoch - 1`.
+
+    Apart from loading state dict for `model` and `optimizer` it also updates
+    `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+    and `best_valid_loss` in `params`.
+
+    Args:
+      params:
+        The return value of :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer that we are using.
+      scheduler:
+        The scheduler that we are using.
+    Returns:
+      Return a dict containing previously saved training info.
+    """
+    if params.start_batch > 0:
+        filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+    elif params.start_epoch > 1:
+        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+    else:
+        return None
+
+    assert filename.is_file(), f"{filename} does not exist!"
+
+    saved_params = load_checkpoint(
+        filename,
+        model=model,
+        model_avg=model_avg,
+        optimizer=optimizer,
+        scheduler=scheduler,
+    )
+
+    keys = [
+        "best_train_epoch",
+        "best_valid_epoch",
+        "batch_idx_train",
+        "best_train_loss",
+        "best_valid_loss",
+    ]
+    for k in keys:
+        params[k] = saved_params[k]
+
+    if params.start_batch > 0:
+        if "cur_epoch" in saved_params:
+            params["start_epoch"] = saved_params["cur_epoch"]
+
+        if "cur_batch_idx" in saved_params:
+            params["cur_batch_idx"] = saved_params["cur_batch_idx"]
+
+    return saved_params
+
+
+def save_checkpoint(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    model_avg: Optional[nn.Module] = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+    sampler: Optional[CutSampler] = None,
+    scaler: Optional[GradScaler] = None,
+    rank: int = 0,
+) -> None:
+    """Save model, optimizer, scheduler and training stats to file.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer used in the training.
+      sampler:
+       The sampler for the training dataset.
+      scaler:
+        The scaler used for mix precision training.
+    """
+    if rank != 0:
+        return
+    filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+    save_checkpoint_impl(
+        filename=filename,
+        model=model,
+        model_avg=model_avg,
+        params=params,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        sampler=sampler,
+        scaler=scaler,
+        rank=rank,
+    )
+
+    if params.best_train_epoch == params.cur_epoch:
+        best_train_filename = params.exp_dir / "best-train-loss.pt"
+        copyfile(src=filename, dst=best_train_filename)
+
+    if params.best_valid_epoch == params.cur_epoch:
+        best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+        copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    graph_compiler: CharCtcTrainingGraphCompiler,
+    batch: dict,
+    is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+    """
+    Compute transducer loss given the model and its inputs.
+
+    Args:
+      params:
+        Parameters for training. See :func:`get_params`.
+      model:
+        The model for training. It is an instance of Zipformer in our case.
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      is_training:
+        True for training. False for validation. When it is True, this
+        function enables autograd during computation; when it is False, it
+        disables autograd.
+     warmup: a floating point value which increases throughout training;
+        values >= 1.0 are fully warmed up and have all modules present.
+    """
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+    feature = batch["inputs"]
+    # at entry, feature is (N, T, C)
+    assert feature.ndim == 3
+    feature = feature.to(device)
+
+    supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)
+
+    batch_idx_train = params.batch_idx_train
+    warm_step = params.warm_step
+
+    texts = batch["supervisions"]["text"]
+
+    y = graph_compiler.texts_to_ids(texts)
+    if type(y) == list:
+        y = k2.RaggedTensor(y).to(device)
+    else:
+        y = y.to(device)
+
+    with torch.set_grad_enabled(is_training):
+        simple_loss, pruned_loss = model(
+            x=feature,
+            x_lens=feature_lens,
+            y=y,
+            prune_range=params.prune_range,
+            am_scale=params.am_scale,
+            lm_scale=params.lm_scale,
+        )
+
+        s = params.simple_loss_scale
+        # take down the scale on the simple loss from 1.0 at the start
+        # to params.simple_loss scale by warm_step.
+        simple_loss_scale = (
+            s
+            if batch_idx_train >= warm_step
+            else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
+        )
+        pruned_loss_scale = (
+            1.0
+            if batch_idx_train >= warm_step
+            else 0.1 + 0.9 * (batch_idx_train / warm_step)
+        )
+
+        loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+
+    assert loss.requires_grad == is_training
+
+    info = MetricsTracker()
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        info["frames"] = ((feature_lens - 7) // 2).sum().item()
+
+    # Note: We use reduction=sum while computing the loss.
+    info["loss"] = loss.detach().cpu().item()
+    info["simple_loss"] = simple_loss.detach().cpu().item()
+    info["pruned_loss"] = pruned_loss.detach().cpu().item()
+
+    return loss, info
+
+
+def compute_validation_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    graph_compiler: CharCtcTrainingGraphCompiler,
+    valid_dl: torch.utils.data.DataLoader,
+    world_size: int = 1,
+) -> MetricsTracker:
+    """Run the validation process."""
+    model.eval()
+
+    tot_loss = MetricsTracker()
+
+    for batch_idx, batch in enumerate(valid_dl):
+        loss, loss_info = compute_loss(
+            params=params,
+            model=model,
+            graph_compiler=graph_compiler,
+            batch=batch,
+            is_training=False,
+        )
+        assert loss.requires_grad is False
+        tot_loss = tot_loss + loss_info
+
+    if world_size > 1:
+        tot_loss.reduce(loss.device)
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    if loss_value < params.best_valid_loss:
+        params.best_valid_epoch = params.cur_epoch
+        params.best_valid_loss = loss_value
+
+    return tot_loss
+
+
+def train_one_epoch(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    optimizer: torch.optim.Optimizer,
+    scheduler: LRSchedulerType,
+    graph_compiler: CharCtcTrainingGraphCompiler,
+    train_dl: torch.utils.data.DataLoader,
+    valid_dl: torch.utils.data.DataLoader,
+    scaler: GradScaler,
+    model_avg: Optional[nn.Module] = None,
+    tb_writer: Optional[SummaryWriter] = None,
+    world_size: int = 1,
+    rank: int = 0,
+) -> None:
+    """Train the model for one epoch.
+
+    The training loss from the mean of all frames is saved in
+    `params.train_loss`. It runs the validation process every
+    `params.valid_interval` batches.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The model for training.
+      optimizer:
+        The optimizer we are using.
+      scheduler:
+        The learning rate scheduler, we call step() every step.
+      train_dl:
+        Dataloader for the training dataset.
+      valid_dl:
+        Dataloader for the validation dataset.
+      scaler:
+        The scaler used for mix precision training.
+      model_avg:
+        The stored model averaged from the start of training.
+      tb_writer:
+        Writer to write log messages to tensorboard.
+      world_size:
+        Number of nodes in DDP training. If it is 1, DDP is disabled.
+      rank:
+        The rank of the node in DDP training. If no DDP is used, it should
+        be set to 0.
+    """
+    model.train()
+
+    tot_loss = MetricsTracker()
+
+    cur_batch_idx = params.get("cur_batch_idx", 0)
+
+    for batch_idx, batch in enumerate(train_dl):
+        if batch_idx < cur_batch_idx:
+            continue
+        cur_batch_idx = batch_idx
+
+        params.batch_idx_train += 1
+        batch_size = len(batch["supervisions"]["text"])
+
+        try:
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                loss, loss_info = compute_loss(
+                    params=params,
+                    model=model,
+                    graph_compiler=graph_compiler,
+                    batch=batch,
+                    is_training=True,
+                )
+            # summary stats
+            tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+            # NOTE: We use reduction==sum and loss is computed over utterances
+            # in the batch and there is no normalization to it so far.
+            scaler.scale(loss).backward()
+            set_batch_count(model, params.batch_idx_train)
+            scheduler.step_batch(params.batch_idx_train)
+
+            scaler.step(optimizer)
+            scaler.update()
+            optimizer.zero_grad()
+        except:  # noqa
+            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
+            raise
+
+        if params.print_diagnostics and batch_idx == 5:
+            return
+
+        if (
+            rank == 0
+            and params.batch_idx_train > 0
+            and params.batch_idx_train % params.average_period == 0
+        ):
+            update_averaged_model(
+                params=params,
+                model_cur=model,
+                model_avg=model_avg,
+            )
+
+        if (
+            params.batch_idx_train > 0
+            and params.batch_idx_train % params.save_every_n == 0
+        ):
+            params.cur_batch_idx = batch_idx
+            save_checkpoint_with_global_batch_idx(
+                out_dir=params.exp_dir,
+                global_batch_idx=params.batch_idx_train,
+                model=model,
+                model_avg=model_avg,
+                params=params,
+                optimizer=optimizer,
+                scheduler=scheduler,
+                sampler=train_dl.sampler,
+                scaler=scaler,
+                rank=rank,
+            )
+            del params.cur_batch_idx
+            remove_checkpoints(
+                out_dir=params.exp_dir,
+                topk=params.keep_last_k,
+                rank=rank,
+            )
+
+        if batch_idx % 100 == 0 and params.use_fp16:
+            # If the grad scale was less than 1, try increasing it.    The _growth_interval
+            # of the grad scaler is configurable, but we can't configure it to have different
+            # behavior depending on the current grad scale.
+            cur_grad_scale = scaler._scale.item()
+            if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
+                scaler.update(cur_grad_scale * 2.0)
+            if cur_grad_scale < 0.01:
+                logging.warning(f"Grad scale is small: {cur_grad_scale}")
+            if cur_grad_scale < 1.0e-05:
+                raise RuntimeError(
+                    f"grad_scale is too small, exiting: {cur_grad_scale}"
+                )
+
+        if batch_idx % params.log_interval == 0:
+            cur_lr = scheduler.get_last_lr()[0]
+            cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+            logging.info(
+                f"Epoch {params.cur_epoch}, "
+                f"batch {batch_idx}, loss[{loss_info}], "
+                f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+                f"lr: {cur_lr:.2e}, "
+                + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+            )
+
+            if tb_writer is not None:
+                tb_writer.add_scalar(
+                    "train/learning_rate", cur_lr, params.batch_idx_train
+                )
+
+                loss_info.write_summary(
+                    tb_writer, "train/current_", params.batch_idx_train
+                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                if params.use_fp16:
+                    tb_writer.add_scalar(
+                        "train/grad_scale", cur_grad_scale, params.batch_idx_train
+                    )
+
+        if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+            logging.info("Computing validation loss")
+            valid_info = compute_validation_loss(
+                params=params,
+                model=model,
+                graph_compiler=graph_compiler,
+                valid_dl=valid_dl,
+                world_size=world_size,
+            )
+            model.train()
+            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+            logging.info(
+                f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+            )
+            if tb_writer is not None:
+                valid_info.write_summary(
+                    tb_writer, "train/valid_", params.batch_idx_train
+                )
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    params.train_loss = loss_value
+    if params.train_loss < params.best_train_loss:
+        params.best_train_epoch = params.cur_epoch
+        params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+    """
+    Args:
+      rank:
+        It is a value between 0 and `world_size-1`, which is
+        passed automatically by `mp.spawn()` in :func:`main`.
+        The node with rank 0 is responsible for saving checkpoint.
+      world_size:
+        Number of GPUs for DDP training.
+      args:
+        The return value of get_parser().parse_args()
+    """
+    params = get_params()
+    params.update(vars(args))
+
+    fix_random_seed(params.seed)
+    if world_size > 1:
+        setup_dist(rank, world_size, params.master_port)
+
+    setup_logger(f"{params.exp_dir}/log/log-train")
+    logging.info("Training started")
+
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+    logging.info(f"Device: {device}")
+
+    lexicon = Lexicon(params.lang_dir)
+    graph_compiler = CharCtcTrainingGraphCompiler(
+        lexicon=lexicon,
+        device=device,
+    )
+
+    params.blank_id = lexicon.token_table[""]
+    params.vocab_size = max(lexicon.tokens) + 1
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    assert params.save_every_n >= params.average_period
+    model_avg: Optional[nn.Module] = None
+    if rank == 0:
+        # model_avg is only used with rank 0
+        model_avg = copy.deepcopy(model).to(torch.float64)
+
+    assert params.start_epoch > 0, params.start_epoch
+    checkpoints = load_checkpoint_if_available(
+        params=params, model=model, model_avg=model_avg
+    )
+
+    model.to(device)
+    if world_size > 1:
+        logging.info("Using DDP")
+        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+    parameters_names = []
+    parameters_names.append(
+        [name_param_pair[0] for name_param_pair in model.named_parameters()]
+    )
+    optimizer = ScaledAdam(
+        model.parameters(),
+        lr=params.base_lr,
+        clipping_scale=2.0,
+        parameters_names=parameters_names,
+    )
+
+    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+    if checkpoints and "optimizer" in checkpoints:
+        logging.info("Loading optimizer state dict")
+        optimizer.load_state_dict(checkpoints["optimizer"])
+
+    if (
+        checkpoints
+        and "scheduler" in checkpoints
+        and checkpoints["scheduler"] is not None
+    ):
+        logging.info("Loading scheduler state dict")
+        scheduler.load_state_dict(checkpoints["scheduler"])
+
+    if params.print_diagnostics:
+        opts = diagnostics.TensorDiagnosticOptions(
+            2**22
+        )  # allow 4 megabytes per sub-module
+        diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+    if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+        # We only load the sampler's state dict when it loads a checkpoint
+        # saved in the middle of an epoch
+        sampler_state_dict = checkpoints["sampler"]
+    else:
+        sampler_state_dict = None
+
+    if params.inf_check:
+        register_inf_check_hooks(model)
+
+    alimeeting = AlimeetingAsrDataModule(args)
+
+    train_cuts = alimeeting.train_cuts()
+    train_dl = alimeeting.train_dataloaders(
+        train_cuts, sampler_state_dict=sampler_state_dict
+    )
+
+    valid_cuts = alimeeting.eval_ihm_cuts()
+    valid_dl = alimeeting.valid_dataloaders(valid_cuts)
+
+    # if not params.print_diagnostics:
+    #     scan_pessimistic_batches_for_oom(
+    #         model=model,
+    #         train_dl=train_dl,
+    #         optimizer=optimizer,
+    #         graph_compiler=graph_compiler,
+    #         params=params,
+    #     )
+
+    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+    if checkpoints and "grad_scaler" in checkpoints:
+        logging.info("Loading grad scaler state dict")
+        scaler.load_state_dict(checkpoints["grad_scaler"])
+
+    for epoch in range(params.start_epoch, params.num_epochs + 1):
+        scheduler.step_epoch(epoch - 1)
+        fix_random_seed(params.seed + epoch - 1)
+        train_dl.sampler.set_epoch(epoch - 1)
+
+        if tb_writer is not None:
+            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+        params.cur_epoch = epoch
+
+        train_one_epoch(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            graph_compiler=graph_compiler,
+            train_dl=train_dl,
+            valid_dl=valid_dl,
+            scaler=scaler,
+            tb_writer=tb_writer,
+            world_size=world_size,
+            rank=rank,
+        )
+
+        if params.print_diagnostics:
+            diagnostic.print_diagnostics()
+            break
+
+        save_checkpoint(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            sampler=train_dl.sampler,
+            scaler=scaler,
+            rank=rank,
+        )
+
+    logging.info("Done!")
+
+    if world_size > 1:
+        torch.distributed.barrier()
+        cleanup_dist()
+
+
+def display_and_save_batch(
+    batch: dict,
+    params: AttributeDict,
+    graph_compiler: CharCtcTrainingGraphCompiler,
+) -> None:
+    """Display the batch statistics and save the batch into disk.
+
+    Args:
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      params:
+        Parameters for training. See :func:`get_params`.
+      sp:
+        The BPE model.
+    """
+    from lhotse.utils import uuid4
+
+    filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+    logging.info(f"Saving batch to {filename}")
+    torch.save(batch, filename)
+
+    supervisions = batch["supervisions"]
+    features = batch["inputs"]
+
+    logging.info(f"features shape: {features.shape}")
+
+
+def scan_pessimistic_batches_for_oom(
+    model: Union[nn.Module, DDP],
+    train_dl: torch.utils.data.DataLoader,
+    optimizer: torch.optim.Optimizer,
+    graph_compiler: CharCtcTrainingGraphCompiler,
+    params: AttributeDict,
+):
+    from lhotse.dataset import find_pessimistic_batches
+
+    logging.info(
+        "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+    )
+    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+    for criterion, cuts in batches.items():
+        batch = train_dl.dataset[cuts]
+        try:
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                loss, _ = compute_loss(
+                    params=params,
+                    model=model,
+                    graph_compiler=graph_compiler,
+                    batch=batch,
+                    is_training=True,
+                )
+            loss.backward()
+            optimizer.zero_grad()
+        except Exception as e:
+            if "CUDA out of memory" in str(e):
+                logging.error(
+                    "Your GPU ran out of memory with the current "
+                    "max_duration setting. We recommend decreasing "
+                    "max_duration and trying again.\n"
+                    f"Failing criterion: {criterion} "
+                    f"(={crit_values[criterion]}) ..."
+                )
+            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
+            raise
+        logging.info(
+            f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+        )
+
+
+def main():
+    parser = get_parser()
+    AlimeetingAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    world_size = args.world_size
+    assert world_size >= 1
+    if world_size > 1:
+        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+    else:
+        run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/zipformer.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/zipformer.py
new file mode 120000
index 000000000..f2f66041e
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/zipformer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/zipformer.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR_v2/shared b/egs/alimeeting/ASR_v2/shared
new file mode 120000
index 000000000..3a3b28f96
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/shared
@@ -0,0 +1 @@
+../../../egs/aishell/ASR/shared
\ No newline at end of file

From 02c18ba4b25a805db8e8dbb6b7fc4766ad1e006a Mon Sep 17 00:00:00 2001
From: Yifan Yang <64255737+yfyeung@users.noreply.github.com>
Date: Sat, 10 Dec 2022 19:34:19 +0800
Subject: [PATCH 044/174] rm the dup line of Zipformer.py (#755)

Co-authored-by: yifanyang 
---
 egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index b007a7308..e8fd89abd 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -81,7 +81,6 @@ class Zipformer(EncoderInterface):
         super(Zipformer, self).__init__()
 
         self.num_features = num_features
-        self.encoder_unmasked_dims = encoder_unmasked_dims
         assert 0 < encoder_dims[0] <= encoder_dims[1]
         self.encoder_dims = encoder_dims
         self.encoder_unmasked_dims = encoder_unmasked_dims

From e83409cbe536cc031728f394fc3eb1132aac01e1 Mon Sep 17 00:00:00 2001
From: wzy <38179632+v-yunbin@users.noreply.github.com>
Date: Sun, 11 Dec 2022 20:16:10 +0800
Subject: [PATCH 045/174]  Filter the training data of T <  S  for Wenet train
 recipe (#753)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* filter the case of T <  S  for training data

* fix style issues

* fix style issues

* fix style issues

Co-authored-by: 张云斌 
---
 .../ASR/pruned_transducer_stateless2/train.py | 32 +++++++++++++++++--
 .../ASR/pruned_transducer_stateless5/train.py | 32 +++++++++++++++++--
 2 files changed, 58 insertions(+), 6 deletions(-)

diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
index 43fa0d01b..48b347b64 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
@@ -861,15 +861,41 @@ def run(rank, world_size, args):
     valid_cuts = wenetspeech.valid_cuts()
 
     def remove_short_and_long_utt(c: Cut):
-        # Keep only utterances with duration between 1 second and 15.0 seconds
+        # Keep only utterances with duration between 1 second and 10 seconds
         #
-        # Caution: There is a reason to select 15.0 here. Please see
+        # Caution: There is a reason to select 10.0 here. Please see
         # ../local/display_manifest_statistics.py
         #
         # You should use ../local/display_manifest_statistics.py to get
         # an utterance duration distribution for your dataset to select
         # the threshold
-        return 1.0 <= c.duration <= 15.0
+        if c.duration < 1.0 or c.duration > 10.0:
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+            )
+            return False
+
+        # In pruned RNN-T, we require that T >= S
+        # where T is the number of feature frames after subsampling
+        # and S is the number of tokens in the utterance
+
+        # In ./conformer.py, the conv module uses the following expression
+        # for subsampling
+        T = ((c.num_frames - 1) // 2 - 1) // 2
+        tokens = c.supervisions[0].text.replace(" ", "")
+
+        if T < len(tokens):
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. "
+                f"Number of frames (before subsampling): {c.num_frames}. "
+                f"Number of frames (after subsampling): {T}. "
+                f"Text: {c.supervisions[0].text}. "
+                f"Tokens: {tokens}. "
+                f"Number of tokens: {len(tokens)}"
+            )
+            return False
+
+        return True
 
     train_cuts = train_cuts.filter(remove_short_and_long_utt)
 
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py
index 440b65f32..34a72be8f 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py
@@ -1006,15 +1006,41 @@ def run(rank, world_size, args):
     valid_cuts = wenetspeech.valid_cuts()
 
     def remove_short_and_long_utt(c: Cut):
-        # Keep only utterances with duration between 1 second and 15.0 seconds
+        # Keep only utterances with duration between 1 second and 10 seconds
         #
-        # Caution: There is a reason to select 15.0 here. Please see
+        # Caution: There is a reason to select 10.0 here. Please see
         # ../local/display_manifest_statistics.py
         #
         # You should use ../local/display_manifest_statistics.py to get
         # an utterance duration distribution for your dataset to select
         # the threshold
-        return 1.0 <= c.duration <= 15.0
+        if c.duration < 1.0 or c.duration > 10.0:
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+            )
+            return False
+
+        # In pruned RNN-T, we require that T >= S
+        # where T is the number of feature frames after subsampling
+        # and S is the number of tokens in the utterance
+
+        # In ./conformer.py, the conv module uses the following expression
+        # for subsampling
+        T = ((c.num_frames - 1) // 2 - 1) // 2
+        tokens = c.supervisions[0].text.replace(" ", "")
+
+        if T < len(tokens):
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. "
+                f"Number of frames (before subsampling): {c.num_frames}. "
+                f"Number of frames (after subsampling): {T}. "
+                f"Text: {c.supervisions[0].text}. "
+                f"Tokens: {tokens}. "
+                f"Number of tokens: {len(tokens)}"
+            )
+            return False
+
+        return True
 
     train_cuts = train_cuts.filter(remove_short_and_long_utt)
 

From b25c234c51426d61552cdca819ab57fe712214c9 Mon Sep 17 00:00:00 2001
From: Zengwei Yao 
Date: Sun, 11 Dec 2022 21:30:39 +0800
Subject: [PATCH 046/174] Add Zipformer-MMI (#746)

* Minor fix to conformer-mmi

* Minor fixes

* Fix decode.py

* add training files

* train with ctc warmup

* add pruned_transducer_stateless7_mmi

* add zipformer_mmi/mmi_decode.py, using HP as decoding graph

* add mmi_decode.py

* remove pruned_transducer_stateless7_mmi

* rename zipformer_mmi/train_with_ctc.py as zipformer_mmi/train.py

* remove unused method

* rename mmi_decode.py

* add export.py pretrained.py jit_pretrained.py ...

* add RESULTS.md

* add CI test

* add docs

* add README.md

Co-authored-by: pkufool 
---
 .flake8                                       |    3 +-
 ...n-librispeech-conformer-ctc3-2022-11-28.sh |    8 +-
 ...ed-transducer-stateless7-ctc-2022-12-01.sh |    8 +-
 ...un-librispeech-zipformer-mmi-2022-12-08.sh |  103 ++
 ...n-librispeech-2022-12-08-zipformer-mmi.yml |  167 +++
 docs/source/recipes/librispeech/index.rst     |    1 +
 .../recipes/librispeech/zipformer_mmi.rst     |  422 ++++++
 egs/librispeech/ASR/RESULTS.md                |   57 +
 .../ASR/conformer_ctc3/jit_pretrained.py      |    5 +-
 .../ASR/conformer_ctc3/pretrained.py          |    5 +-
 egs/librispeech/ASR/conformer_mmi/decode.py   |   12 +-
 .../ASR/conformer_mmi/train-with-attention.py |   76 +-
 egs/librispeech/ASR/conformer_mmi/train.py    |   67 +-
 egs/librispeech/ASR/generate-lm.sh            |    2 +-
 .../export.py                                 |    6 +-
 .../jit_pretrained_ctc.py                     |    5 +-
 .../pretrained_ctc.py                         |    5 +-
 egs/librispeech/ASR/zipformer_mmi/README.md   |   26 +
 egs/librispeech/ASR/zipformer_mmi/__init__.py |    0
 .../ASR/zipformer_mmi/asr_datamodule.py       |    1 +
 egs/librispeech/ASR/zipformer_mmi/decode.py   |  736 ++++++++++
 .../ASR/zipformer_mmi/encoder_interface.py    |    1 +
 egs/librispeech/ASR/zipformer_mmi/export.py   |  307 +++++
 .../ASR/zipformer_mmi/jit_pretrained.py       |  391 ++++++
 egs/librispeech/ASR/zipformer_mmi/model.py    |   75 ++
 egs/librispeech/ASR/zipformer_mmi/optim.py    |    1 +
 .../ASR/zipformer_mmi/pretrained.py           |  410 ++++++
 egs/librispeech/ASR/zipformer_mmi/scaling.py  |    1 +
 .../ASR/zipformer_mmi/scaling_converter.py    |    1 +
 .../ASR/zipformer_mmi/test_model.py           |   57 +
 egs/librispeech/ASR/zipformer_mmi/train.py    | 1198 +++++++++++++++++
 .../ASR/zipformer_mmi/zipformer.py            |    1 +
 icefall/decode.py                             |  101 ++
 icefall/mmi.py                                |   10 +-
 34 files changed, 4224 insertions(+), 45 deletions(-)
 create mode 100755 .github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh
 create mode 100644 .github/workflows/run-librispeech-2022-12-08-zipformer-mmi.yml
 create mode 100644 docs/source/recipes/librispeech/zipformer_mmi.rst
 create mode 100644 egs/librispeech/ASR/zipformer_mmi/README.md
 create mode 100644 egs/librispeech/ASR/zipformer_mmi/__init__.py
 create mode 120000 egs/librispeech/ASR/zipformer_mmi/asr_datamodule.py
 create mode 100755 egs/librispeech/ASR/zipformer_mmi/decode.py
 create mode 120000 egs/librispeech/ASR/zipformer_mmi/encoder_interface.py
 create mode 100755 egs/librispeech/ASR/zipformer_mmi/export.py
 create mode 100755 egs/librispeech/ASR/zipformer_mmi/jit_pretrained.py
 create mode 100644 egs/librispeech/ASR/zipformer_mmi/model.py
 create mode 120000 egs/librispeech/ASR/zipformer_mmi/optim.py
 create mode 100755 egs/librispeech/ASR/zipformer_mmi/pretrained.py
 create mode 120000 egs/librispeech/ASR/zipformer_mmi/scaling.py
 create mode 120000 egs/librispeech/ASR/zipformer_mmi/scaling_converter.py
 create mode 100755 egs/librispeech/ASR/zipformer_mmi/test_model.py
 create mode 100755 egs/librispeech/ASR/zipformer_mmi/train.py
 create mode 120000 egs/librispeech/ASR/zipformer_mmi/zipformer.py

diff --git a/.flake8 b/.flake8
index a0f44263c..41d8799c8 100644
--- a/.flake8
+++ b/.flake8
@@ -1,7 +1,7 @@
 [flake8]
 show-source=true
 statistics=true
-max-line-length = 80
+max-line-length = 88
 per-file-ignores =
     # line too long
     icefall/diagnostics.py: E501,
@@ -12,6 +12,7 @@ per-file-ignores =
     egs/librispeech/ASR/lstm_transducer_stateless*/*.py: E501, E203
     egs/librispeech/ASR/conv_emformer_transducer_stateless*/*.py: E501, E203
     egs/librispeech/ASR/conformer_ctc*/*py: E501,
+    egs/librispeech/ASR/zipformer_mmi/*.py: E501, E203
     egs/librispeech/ASR/RESULTS.md: E999,
 
     # invalid escape sequence (cause by tex formular), W605
diff --git a/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh b/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh
index 27944807f..df29f188e 100755
--- a/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh
+++ b/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh
@@ -13,7 +13,6 @@ cd egs/librispeech/ASR
 repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-conformer-ctc3-2022-11-27
 
 log "Downloading pre-trained model from $repo_url"
-git lfs install
 GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
 repo=$(basename $repo_url)
 
@@ -23,7 +22,12 @@ soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav
 
 pushd $repo/exp
-git lfs pull --include "data/*"
+git lfs pull --include "data/lang_bpe_500/HLG.pt"
+git lfs pull --include "data/lang_bpe_500/L.pt"
+git lfs pull --include "data/lang_bpe_500/LG.pt"
+git lfs pull --include "data/lang_bpe_500/Linv.pt"
+git lfs pull --include "data/lang_bpe_500/bpe.model"
+git lfs pull --include "data/lm/G_4_gram.pt"
 git lfs pull --include "exp/jit_trace.pt"
 git lfs pull --include "exp/pretrained.pt"
 ln -s pretrained.pt epoch-99.pt
diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh
index 6642d5f67..e081c9374 100755
--- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh
+++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh
@@ -13,7 +13,6 @@ cd egs/librispeech/ASR
 repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-ctc-2022-12-01
 
 log "Downloading pre-trained model from $repo_url"
-git lfs install
 GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
 repo=$(basename $repo_url)
 
@@ -23,7 +22,12 @@ soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav
 
 pushd $repo/exp
-git lfs pull --include "data/*"
+git lfs pull --include "data/lang_bpe_500/HLG.pt"
+git lfs pull --include "data/lang_bpe_500/L.pt"
+git lfs pull --include "data/lang_bpe_500/LG.pt"
+git lfs pull --include "data/lang_bpe_500/Linv.pt"
+git lfs pull --include "data/lang_bpe_500/bpe.model"
+git lfs pull --include "data/lm/G_4_gram.pt"
 git lfs pull --include "exp/cpu_jit.pt"
 git lfs pull --include "exp/pretrained.pt"
 ln -s pretrained.pt epoch-99.pt
diff --git a/.github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh b/.github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh
new file mode 100755
index 000000000..77f28b054
--- /dev/null
+++ b/.github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh
@@ -0,0 +1,103 @@
+#!/usr/bin/env bash
+
+set -e
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+cd egs/librispeech/ASR
+
+repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-mmi-2022-12-08
+
+log "Downloading pre-trained model from $repo_url"
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+repo=$(basename $repo_url)
+
+log "Display test files"
+tree $repo/
+soxi $repo/test_wavs/*.wav
+ls -lh $repo/test_wavs/*.wav
+
+pushd $repo/exp
+git lfs pull --include "data/lang_bpe_500/3gram.pt"
+git lfs pull --include "data/lang_bpe_500/4gram.pt"
+git lfs pull --include "data/lang_bpe_500/L.pt"
+git lfs pull --include "data/lang_bpe_500/LG.pt"
+git lfs pull --include "data/lang_bpe_500/Linv.pt"
+git lfs pull --include "data/lang_bpe_500/bpe.model"
+git lfs pull --include "exp/cpu_jit.pt"
+git lfs pull --include "exp/pretrained.pt"
+ln -s pretrained.pt epoch-99.pt
+ls -lh *.pt
+popd
+
+log "Export to torchscript model"
+./zipformer_mmi/export.py \
+  --exp-dir $repo/exp \
+  --use-averaged-model false \
+  --bpe-model $repo/data/lang_bpe_500/bpe.model \
+  --epoch 99 \
+  --avg 1 \
+  --jit 1
+
+ls -lh $repo/exp/*.pt
+
+log "Decode with models exported by torch.jit.script()"
+
+./zipformer_mmi/jit_pretrained.py \
+  --bpe-model $repo/data/lang_bpe_500/bpe.model \
+  --nn-model-filename $repo/exp/cpu_jit.pt \
+  --lang-dir $repo/data/lang_bpe_500 \
+  $repo/test_wavs/1089-134686-0001.wav \
+  $repo/test_wavs/1221-135766-0001.wav \
+  $repo/test_wavs/1221-135766-0002.wav
+
+for method in 1best nbest nbest-rescoring-LG nbest-rescoring-3-gram nbest-rescoring-4-gram; do
+  log "$method"
+
+  ./zipformer_mmi/pretrained.py \
+    --method $method \
+    --checkpoint $repo/exp/pretrained.pt \
+    --lang-dir $repo/data/lang_bpe_500 \
+    --bpe-model $repo/data/lang_bpe_500/bpe.model \
+    $repo/test_wavs/1089-134686-0001.wav \
+    $repo/test_wavs/1221-135766-0001.wav \
+    $repo/test_wavs/1221-135766-0002.wav
+done
+
+
+echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
+echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
+if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode"  ]]; then
+  mkdir -p zipformer_mmi/exp
+  ln -s $PWD/$repo/exp/pretrained.pt zipformer_mmi/exp/epoch-999.pt
+  ln -s $PWD/$repo/data/lang_bpe_500 data/
+
+  ls -lh data
+  ls -lh zipformer_mmi/exp
+
+  log "Decoding test-clean and test-other"
+
+  # use a small value for decoding with CPU
+  max_duration=100
+
+  for method in 1best nbest nbest-rescoring-LG nbest-rescoring-3-gram nbest-rescoring-4-gram; do
+    log "Decoding with $method"
+
+    ./zipformer_mmi/decode.py \
+      --decoding-method $method \
+      --epoch 999 \
+      --avg 1 \
+      --use-averaged-model 0 \
+      --nbest-scale 1.2 \
+      --hp-scale 1.0 \
+      --max-duration $max_duration \
+      --lang-dir $repo/data/lang_bpe_500 \
+      --exp-dir zipformer_mmi/exp
+  done
+
+  rm zipformer_mmi/exp/*.pt
+fi
diff --git a/.github/workflows/run-librispeech-2022-12-08-zipformer-mmi.yml b/.github/workflows/run-librispeech-2022-12-08-zipformer-mmi.yml
new file mode 100644
index 000000000..5472ca59b
--- /dev/null
+++ b/.github/workflows/run-librispeech-2022-12-08-zipformer-mmi.yml
@@ -0,0 +1,167 @@
+# Copyright      2022  Zengwei Yao
+
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+name: run-librispeech-2022-12-08-zipformer-mmi
+# zipformer
+
+on:
+  push:
+    branches:
+      - master
+  pull_request:
+    types: [labeled]
+
+  schedule:
+    # minute (0-59)
+    # hour (0-23)
+    # day of the month (1-31)
+    # month (1-12)
+    # day of the week (0-6)
+    # nightly build at 15:50 UTC time every day
+    - cron: "50 15 * * *"
+
+concurrency:
+  group: run_librispeech_2022_12_08_zipformer-${{ github.ref }}
+  cancel-in-progress: true
+
+jobs:
+  run_librispeech_2022_12_08_zipformer:
+    if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
+    runs-on: ${{ matrix.os }}
+    strategy:
+      matrix:
+        os: [ubuntu-latest]
+        python-version: [3.8]
+
+      fail-fast: false
+
+    steps:
+      - uses: actions/checkout@v2
+        with:
+          fetch-depth: 0
+
+      - name: Setup Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v2
+        with:
+          python-version: ${{ matrix.python-version }}
+          cache: 'pip'
+          cache-dependency-path: '**/requirements-ci.txt'
+
+      - name: Install Python dependencies
+        run: |
+          grep -v '^#' ./requirements-ci.txt  | xargs -n 1 -L 1 pip install
+          pip uninstall -y protobuf
+          pip install --no-binary protobuf protobuf
+
+      - name: Cache kaldifeat
+        id: my-cache
+        uses: actions/cache@v2
+        with:
+          path: |
+            ~/tmp/kaldifeat
+          key: cache-tmp-${{ matrix.python-version }}-2022-09-25
+
+      - name: Install kaldifeat
+        if: steps.my-cache.outputs.cache-hit != 'true'
+        shell: bash
+        run: |
+          .github/scripts/install-kaldifeat.sh
+
+      - name: Cache LibriSpeech test-clean and test-other datasets
+        id: libri-test-clean-and-test-other-data
+        uses: actions/cache@v2
+        with:
+          path: |
+            ~/tmp/download
+          key: cache-libri-test-clean-and-test-other
+
+      - name: Download LibriSpeech test-clean and test-other
+        if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
+        shell: bash
+        run: |
+          .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
+
+      - name: Prepare manifests for LibriSpeech test-clean and test-other
+        shell: bash
+        run: |
+          .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
+
+      - name: Cache LibriSpeech test-clean and test-other fbank features
+        id: libri-test-clean-and-test-other-fbank
+        uses: actions/cache@v2
+        with:
+          path: |
+            ~/tmp/fbank-libri
+          key: cache-libri-fbank-test-clean-and-test-other-v2
+
+      - name: Compute fbank for LibriSpeech test-clean and test-other
+        if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
+        shell: bash
+        run: |
+          .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
+
+      - name: Inference with pre-trained model
+        shell: bash
+        env:
+          GITHUB_EVENT_NAME: ${{ github.event_name }}
+          GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
+        run: |
+          mkdir -p egs/librispeech/ASR/data
+          ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
+          ls -lh egs/librispeech/ASR/data/*
+
+          sudo apt-get -qq install git-lfs tree sox
+          export PYTHONPATH=$PWD:$PYTHONPATH
+          export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
+          export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
+
+          .github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh
+
+      - name: Display decoding results for librispeech zipformer-mmi
+        if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+        shell: bash
+        run: |
+          cd egs/librispeech/ASR/
+          tree ./zipformer-mmi/exp
+
+          cd zipformer-mmi
+          echo "results for zipformer-mmi"
+          echo "===1best==="
+          find exp/1best -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+          find exp/1best -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+          echo "===nbest==="
+          find exp/nbest -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+          find exp/nbest -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+          echo "===nbest-rescoring-LG==="
+          find exp/nbest-rescoring-LG -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+          find exp/nbest-rescoring-LG -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+          echo "===nbest-rescoring-3-gram==="
+          find exp/nbest-rescoring-3-gram -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+          find exp/nbest-rescoring-3-gram -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+          echo "===nbest-rescoring-4-gram==="
+          find exp/nbest-rescoring-4-gram -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+          find exp/nbest-rescoring-4-gram -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+      - name: Upload decoding results for librispeech zipformer-mmi
+        uses: actions/upload-artifact@v2
+        if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+        with:
+          name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-zipformer_mmi-2022-12-08
+          path: egs/librispeech/ASR/zipformer_mmi/exp/
diff --git a/docs/source/recipes/librispeech/index.rst b/docs/source/recipes/librispeech/index.rst
index 6c91b6750..568a8016f 100644
--- a/docs/source/recipes/librispeech/index.rst
+++ b/docs/source/recipes/librispeech/index.rst
@@ -7,3 +7,4 @@ LibriSpeech
    tdnn_lstm_ctc
    conformer_ctc
    lstm_pruned_stateless_transducer
+   zipformer_mmi
diff --git a/docs/source/recipes/librispeech/zipformer_mmi.rst b/docs/source/recipes/librispeech/zipformer_mmi.rst
new file mode 100644
index 000000000..db268dd02
--- /dev/null
+++ b/docs/source/recipes/librispeech/zipformer_mmi.rst
@@ -0,0 +1,422 @@
+Zipformer MMI
+===============
+
+.. hint::
+
+   Please scroll down to the bottom of this page to find download links
+   for pretrained models if you don't want to train a model from scratch.
+
+
+This tutorial shows you how to train an Zipformer MMI model
+with the `LibriSpeech `_ dataset.
+
+We use LF-MMI to compute the loss.
+
+.. note::
+
+   You can find the document about LF-MMI training at the following address:
+
+   ``_
+
+
+Data preparation
+----------------
+
+.. code-block:: bash
+
+  $ cd egs/librispeech/ASR
+  $ ./prepare.sh
+
+The script ``./prepare.sh`` handles the data preparation for you, **automagically**.
+All you need to do is to run it.
+
+.. note::
+
+   We encourage you to read ``./prepare.sh``.
+
+The data preparation contains several stages. You can use the following two
+options:
+
+  - ``--stage``
+  - ``--stop-stage``
+
+to control which stage(s) should be run. By default, all stages are executed.
+
+
+For example,
+
+.. code-block:: bash
+
+  $ cd egs/librispeech/ASR
+  $ ./prepare.sh --stage 0 --stop-stage 0
+
+means to run only stage 0.
+
+To run stage 2 to stage 5, use:
+
+.. code-block:: bash
+
+  $ ./prepare.sh --stage 2 --stop-stage 5
+
+.. hint::
+
+  If you have pre-downloaded the `LibriSpeech `_
+  dataset and the `musan `_ dataset, say,
+  they are saved in ``/tmp/LibriSpeech`` and ``/tmp/musan``, you can modify
+  the ``dl_dir`` variable in ``./prepare.sh`` to point to ``/tmp`` so that
+  ``./prepare.sh`` won't re-download them.
+
+.. note::
+
+  All generated files by ``./prepare.sh``, e.g., features, lexicon, etc,
+  are saved in ``./data`` directory.
+
+We provide the following YouTube video showing how to run ``./prepare.sh``.
+
+.. note::
+
+   To get the latest news of `next-gen Kaldi `_, please subscribe
+   the following YouTube channel by `Nadira Povey `_:
+
+      ``_
+
+..  youtube:: ofEIoJL-mGM
+
+Training
+--------
+
+For stability, it uses CTC loss for model warm-up and then switches to MMI loss.
+
+Configurable options
+~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: bash
+
+  $ cd egs/librispeech/ASR
+  $ ./zipformer_mmi/train.py --help
+
+shows you the training options that can be passed from the commandline.
+The following options are used quite often:
+
+  - ``--full-libri``
+
+    If it's True, the training part uses all the training data, i.e.,
+    960 hours. Otherwise, the training part uses only the subset
+    ``train-clean-100``, which has 100 hours of training data.
+
+    .. CAUTION::
+
+      The training set is perturbed by speed with two factors: 0.9 and 1.1.
+      If ``--full-libri`` is True, each epoch actually processes
+      ``3x960 == 2880`` hours of data.
+
+  - ``--num-epochs``
+
+    It is the number of epochs to train. For instance,
+    ``./zipformer_mmi/train.py --num-epochs 30`` trains for 30 epochs
+    and generates ``epoch-1.pt``, ``epoch-2.pt``, ..., ``epoch-30.pt``
+    in the folder ``./zipformer_mmi/exp``.
+
+  - ``--start-epoch``
+
+    It's used to resume training.
+    ``./zipformer_mmi/train.py --start-epoch 10`` loads the
+    checkpoint ``./zipformer_mmi/exp/epoch-9.pt`` and starts
+    training from epoch 10, based on the state from epoch 9.
+
+  - ``--world-size``
+
+    It is used for multi-GPU single-machine DDP training.
+
+      - (a) If it is 1, then no DDP training is used.
+
+      - (b) If it is 2, then GPU 0 and GPU 1 are used for DDP training.
+
+    The following shows some use cases with it.
+
+      **Use case 1**: You have 4 GPUs, but you only want to use GPU 0 and
+      GPU 2 for training. You can do the following:
+
+        .. code-block:: bash
+
+          $ cd egs/librispeech/ASR
+          $ export CUDA_VISIBLE_DEVICES="0,2"
+          $ ./zipformer_mmi/train.py --world-size 2
+
+      **Use case 2**: You have 4 GPUs and you want to use all of them
+      for training. You can do the following:
+
+        .. code-block:: bash
+
+          $ cd egs/librispeech/ASR
+          $ ./zipformer_mmi/train.py --world-size 4
+
+      **Use case 3**: You have 4 GPUs but you only want to use GPU 3
+      for training. You can do the following:
+
+        .. code-block:: bash
+
+          $ cd egs/librispeech/ASR
+          $ export CUDA_VISIBLE_DEVICES="3"
+          $ ./zipformer_mmi/train.py --world-size 1
+
+    .. caution::
+
+      Only multi-GPU single-machine DDP training is implemented at present.
+      Multi-GPU multi-machine DDP training will be added later.
+
+  - ``--max-duration``
+
+    It specifies the number of seconds over all utterances in a
+    batch, before **padding**.
+    If you encounter CUDA OOM, please reduce it.
+
+    .. HINT::
+
+      Due to padding, the number of seconds of all utterances in a
+      batch will usually be larger than ``--max-duration``.
+
+      A larger value for ``--max-duration`` may cause OOM during training,
+      while a smaller value may increase the training time. You have to
+      tune it.
+
+
+Pre-configured options
+~~~~~~~~~~~~~~~~~~~~~~
+
+There are some training options, e.g., weight decay,
+number of warmup steps, results dir, etc,
+that are not passed from the commandline.
+They are pre-configured by the function ``get_params()`` in
+`zipformer_mmi/train.py `_
+
+You don't need to change these pre-configured parameters. If you really need to change
+them, please modify ``./zipformer_mmi/train.py`` directly.
+
+Training logs
+~~~~~~~~~~~~~
+
+Training logs and checkpoints are saved in ``zipformer_mmi/exp``.
+You will find the following files in that directory:
+
+  - ``epoch-1.pt``, ``epoch-2.pt``, ...
+
+    These are checkpoint files saved at the end of each epoch, containing model
+    ``state_dict`` and optimizer ``state_dict``.
+    To resume training from some checkpoint, say ``epoch-10.pt``, you can use:
+
+      .. code-block:: bash
+
+        $ ./zipformer_mmi/train.py --start-epoch 11
+
+  - ``checkpoint-436000.pt``, ``checkpoint-438000.pt``, ...
+
+    These are checkpoint files saved every ``--save-every-n`` batches,
+    containing model ``state_dict`` and optimizer ``state_dict``.
+    To resume training from some checkpoint, say ``checkpoint-436000``, you can use:
+
+      .. code-block:: bash
+
+        $ ./zipformer_mmi/train.py --start-batch 436000
+
+  - ``tensorboard/``
+
+    This folder contains tensorBoard logs. Training loss, validation loss, learning
+    rate, etc, are recorded in these logs. You can visualize them by:
+
+      .. code-block:: bash
+
+        $ cd zipformer_mmi/exp/tensorboard
+        $ tensorboard dev upload --logdir . --description "Zipformer MMI training for LibriSpeech with icefall"
+
+    It will print something like below:
+
+      .. code-block::
+
+        TensorFlow installation not found - running with reduced feature set.
+        Upload started and will continue reading any new data as it's added to the logdir.
+
+        To stop uploading, press Ctrl-C.
+
+        New experiment created. View your TensorBoard at: https://tensorboard.dev/experiment/xyOZUKpEQm62HBIlUD4uPA/
+
+    Note there is a URL in the above output. Click it and you will see
+    tensorboard.
+
+  .. hint::
+
+    If you don't have access to google, you can use the following command
+    to view the tensorboard log locally:
+
+      .. code-block:: bash
+
+        cd zipformer_mmi/exp/tensorboard
+        tensorboard --logdir . --port 6008
+
+    It will print the following message:
+
+      .. code-block::
+
+        Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
+        TensorBoard 2.8.0 at http://localhost:6008/ (Press CTRL+C to quit)
+
+    Now start your browser and go to ``_ to view the tensorboard
+    logs.
+
+
+  - ``log/log-train-xxxx``
+
+    It is the detailed training log in text format, same as the one
+    you saw printed to the console during training.
+
+Usage example
+~~~~~~~~~~~~~
+
+You can use the following command to start the training using 8 GPUs:
+
+.. code-block:: bash
+
+  export CUDA_VISIBLE_DEVICES="0,1,2,3"
+  ./zipformer_mmi/train.py \
+    --world-size 4 \
+    --num-epochs 30 \
+    --start-epoch 1 \
+    --full-libri 1 \
+    --exp-dir zipformer_mmi/exp \
+    --max-duration 500 \
+    --use-fp16 1 \
+    --num-workers 2
+
+Decoding
+--------
+
+The decoding part uses checkpoints saved by the training part, so you have
+to run the training part first.
+
+.. hint::
+
+   There are two kinds of checkpoints:
+
+    - (1) ``epoch-1.pt``, ``epoch-2.pt``, ..., which are saved at the end
+      of each epoch. You can pass ``--epoch`` to
+      ``zipformer_mmi/decode.py`` to use them.
+
+    - (2) ``checkpoints-436000.pt``, ``epoch-438000.pt``, ..., which are saved
+      every ``--save-every-n`` batches. You can pass ``--iter`` to
+      ``zipformer_mmi/decode.py`` to use them.
+
+    We suggest that you try both types of checkpoints and choose the one
+    that produces the lowest WERs.
+
+.. code-block:: bash
+
+  $ cd egs/librispeech/ASR
+  $ ./zipformer_mmi/decode.py --help
+
+shows the options for decoding.
+
+The following shows the example using ``epoch-*.pt``:
+
+.. code-block:: bash
+
+  for m in nbest nbest-rescoring-LG nbest-rescoring-3-gram nbest-rescoring-4-gram; do
+    ./zipformer_mmi/decode.py \
+      --epoch 30 \
+      --avg 10 \
+      --exp-dir ./zipformer_mmi/exp/ \
+      --max-duration 100 \
+      --lang-dir data/lang_bpe_500 \
+      --nbest-scale 1.2 \
+      --hp-scale 1.0 \
+      --decoding-method $m
+  done
+
+
+Export models
+-------------
+
+`zipformer_mmi/export.py `_ supports exporting checkpoints from ``zipformer_mmi/exp`` in the following ways.
+
+Export ``model.state_dict()``
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Checkpoints saved by ``zipformer_mmi/train.py`` also include
+``optimizer.state_dict()``. It is useful for resuming training. But after training,
+we are interested only in ``model.state_dict()``. You can use the following
+command to extract ``model.state_dict()``.
+
+.. code-block:: bash
+
+  ./zipformer_mmi/export.py \
+    --exp-dir ./zipformer_mmi/exp \
+    --bpe-model data/lang_bpe_500/bpe.model \
+    --epoch 30 \
+    --avg 9 \
+    --jit 0
+
+It will generate a file ``./zipformer_mmi/exp/pretrained.pt``.
+
+.. hint::
+
+   To use the generated ``pretrained.pt`` for ``zipformer_mmi/decode.py``,
+   you can run:
+
+   .. code-block:: bash
+
+      cd zipformer_mmi/exp
+      ln -s pretrained epoch-9999.pt
+
+   And then pass ``--epoch 9999 --avg 1 --use-averaged-model 0`` to
+   ``./zipformer_mmi/decode.py``.
+
+To use the exported model with ``./zipformer_mmi/pretrained.py``, you
+can run:
+
+.. code-block:: bash
+
+  ./zipformer_mmi/pretrained.py \
+    --checkpoint ./zipformer_mmi/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method 1best \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+
+Export model using ``torch.jit.script()``
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: bash
+
+  ./zipformer_mmi/export.py \
+    --exp-dir ./zipformer_mmi/exp \
+    --bpe-model data/lang_bpe_500/bpe.model \
+    --epoch 30 \
+    --avg 9 \
+    --jit 1
+
+It will generate a file ``cpu_jit.pt`` in the given ``exp_dir``. You can later
+load it by ``torch.jit.load("cpu_jit.pt")``.
+
+Note ``cpu`` in the name ``cpu_jit.pt`` means the parameters when loaded into Python
+are on CPU. You can use ``to("cuda")`` to move them to a CUDA device.
+
+To use the generated files with ``./zipformer_mmi/jit_pretrained.py``:
+
+.. code-block:: bash
+
+  ./zipformer_mmi/jit_pretrained.py \
+    --nn-model-filename ./zipformer_mmi/exp/cpu_jit.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method 1best \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+
+Download pretrained models
+--------------------------
+
+If you don't want to train from scratch, you can download the pretrained models
+by visiting the following links:
+
+  - ``_
+
+  See ``_
+  for the details of the above pretrained models
diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md
index 9e5669f6d..092f77814 100644
--- a/egs/librispeech/ASR/RESULTS.md
+++ b/egs/librispeech/ASR/RESULTS.md
@@ -1,5 +1,62 @@
 ## Results
 
+### zipformer_mmi (zipformer with mmi loss)
+
+See  for more details.
+
+[zipformer_mmi](./zipformer_mmi)
+
+The tensorboard log can be found at
+
+
+You can find a pretrained model, training logs, decoding logs, and decoding
+results at:
+
+
+Number of model parameters: 69136519, i.e., 69.14 M
+
+|                          | test-clean | test-other  | comment             |
+|--------------------------|------------|-------------|---------------------|
+| 1best                    | 2.54       | 5.65        | --epoch 30 --avg 10 |
+| nbest                    | 2.54       | 5.66        | --epoch 30 --avg 10 |
+| nbest-rescoring-LG       | 2.49       | 5.42        | --epoch 30 --avg 10 |
+| nbest-rescoring-3-gram   | 2.52       | 5.62        | --epoch 30 --avg 10 |
+| nbest-rescoring-4-gram   | 2.5        | 5.51        | --epoch 30 --avg 10 |
+
+The training commands are:
+```bash
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./zipformer_mmi/train.py \
+  --world-size 4 \
+  --master-port 12345 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --lang-dir data/lang_bpe_500 \
+  --max-duration 500 \
+  --full-libri 1 \
+  --use-fp16 1 \
+  --exp-dir zipformer_mmi/exp
+```
+
+The decoding commands for the transducer branch are:
+```bash
+export CUDA_VISIBLE_DEVICES="5"
+
+for m in nbest nbest-rescoring-LG nbest-rescoring-3-gram nbest-rescoring-4-gram; do
+  ./zipformer_mmi/decode.py \
+    --epoch 30 \
+    --avg 10 \
+    --exp-dir ./zipformer_mmi/exp/ \
+    --max-duration 100 \
+    --lang-dir data/lang_bpe_500 \
+    --nbest-scale 1.2 \
+    --hp-scale 1.0 \
+    --decoding-method $m
+done
+```
+
+
 ### pruned_transducer_stateless7_ctc (zipformer with transducer loss and ctc loss)
 
 See  for more details.
diff --git a/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py b/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py
index 5be898e37..76db46cc8 100755
--- a/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py
+++ b/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py
@@ -291,7 +291,10 @@ def main():
 
     batch_size = nnet_output.shape[0]
     supervision_segments = torch.tensor(
-        [[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
+        [
+            [i, 0, feature_lengths[i] // params.subsampling_factor]
+            for i in range(batch_size)
+        ],
         dtype=torch.int32,
     )
 
diff --git a/egs/librispeech/ASR/conformer_ctc3/pretrained.py b/egs/librispeech/ASR/conformer_ctc3/pretrained.py
index 3628d6a5f..880945ea0 100755
--- a/egs/librispeech/ASR/conformer_ctc3/pretrained.py
+++ b/egs/librispeech/ASR/conformer_ctc3/pretrained.py
@@ -339,7 +339,10 @@ def main():
 
     batch_size = nnet_output.shape[0]
     supervision_segments = torch.tensor(
-        [[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
+        [
+            [i, 0, feature_lengths[i] // params.subsampling_factor]
+            for i in range(batch_size)
+        ],
         dtype=torch.int32,
     )
 
diff --git a/egs/librispeech/ASR/conformer_mmi/decode.py b/egs/librispeech/ASR/conformer_mmi/decode.py
index e3c7b685f..74f6e73fa 100755
--- a/egs/librispeech/ASR/conformer_mmi/decode.py
+++ b/egs/librispeech/ASR/conformer_mmi/decode.py
@@ -660,14 +660,22 @@ def main():
     # we need cut ids to display recognition results.
     args.return_cuts = True
     librispeech = LibriSpeechAsrDataModule(args)
+
+    test_clean_cuts = librispeech.test_clean_cuts()
+    test_other_cuts = librispeech.test_other_cuts()
+
+    test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
+    test_other_dl = librispeech.test_dataloaders(test_other_cuts)
+
     # CAUTION: `test_sets` is for displaying only.
     # If you want to skip test-clean, you have to skip
     # it inside the for loop. That is, use
     #
     #   if test_set == 'test-clean': continue
-    #
     test_sets = ["test-clean", "test-other"]
-    for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()):
+    test_dls = [test_clean_dl, test_other_dl]
+
+    for test_set, test_dl in zip(test_sets, test_dls):
         results_dict = decode_dataset(
             dl=test_dl,
             params=params,
diff --git a/egs/librispeech/ASR/conformer_mmi/train-with-attention.py b/egs/librispeech/ASR/conformer_mmi/train-with-attention.py
index f8c94cff9..100bc846a 100755
--- a/egs/librispeech/ASR/conformer_mmi/train-with-attention.py
+++ b/egs/librispeech/ASR/conformer_mmi/train-with-attention.py
@@ -30,6 +30,8 @@ import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
 from conformer import Conformer
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_norm_
@@ -100,6 +102,41 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="conformer_mmi/exp-attn",
+        help="""The experiment dir.
+        It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_bpe_500",
+        help="""The lang dir
+        It contains language related input files such as
+        "lexicon.txt"
+        """,
+    )
+
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="The seed for random generators intended for reproducibility",
+    )
+
+    parser.add_argument(
+        "--use-pruned-intersect",
+        type=str2bool,
+        default=False,
+        help="""Whether to use `intersect_dense_pruned` to get denominator
+        lattice.""",
+    )
+
     return parser
 
 
@@ -114,12 +151,6 @@ def get_params() -> AttributeDict:
 
     Explanation of options saved in `params`:
 
-        - exp_dir: It specifies the directory where all training related
-                   files, e.g., checkpoints, log, etc, are saved
-
-        - lang_dir: It contains language related input files such as
-                    "lexicon.txt"
-
         - best_train_loss: Best training loss so far. It is used to select
                            the model that has the lowest training loss. It is
                            updated during the training.
@@ -164,8 +195,6 @@ def get_params() -> AttributeDict:
     """
     params = AttributeDict(
         {
-            "exp_dir": Path("conformer_mmi/exp_500_with_attention"),
-            "lang_dir": Path("data/lang_bpe_500"),
             "best_train_loss": float("inf"),
             "best_valid_loss": float("inf"),
             "best_train_epoch": -1,
@@ -184,15 +213,12 @@ def get_params() -> AttributeDict:
             "beam_size": 6,  # will change it to 8 after some batches (see code)
             "reduction": "sum",
             "use_double_scores": True,
-            #  "att_rate": 0.0,
-            #  "num_decoder_layers": 0,
             "att_rate": 0.7,
             "num_decoder_layers": 6,
             # parameters for Noam
             "weight_decay": 1e-6,
             "lr_factor": 5.0,
             "warm_step": 80000,
-            "use_pruned_intersect": False,
             "den_scale": 1.0,
             # use alignments before this number of batches
             "use_ali_until": 13000,
@@ -661,7 +687,7 @@ def run(rank, world_size, args):
     params = get_params()
     params.update(vars(args))
 
-    fix_random_seed(42)
+    fix_random_seed(params.seed)
     if world_size > 1:
         setup_dist(rank, world_size, params.master_port)
 
@@ -745,8 +771,29 @@ def run(rank, world_size, args):
         valid_ali = None
 
     librispeech = LibriSpeechAsrDataModule(args)
-    train_dl = librispeech.train_dataloaders()
-    valid_dl = librispeech.valid_dataloaders()
+    train_cuts = librispeech.train_clean_100_cuts()
+    if params.full_libri:
+        train_cuts += librispeech.train_clean_360_cuts()
+        train_cuts += librispeech.train_other_500_cuts()
+
+    def remove_short_and_long_utt(c: Cut):
+        # Keep only utterances with duration between 1 second and 20 seconds
+        #
+        # Caution: There is a reason to select 20.0 here. Please see
+        # ../local/display_manifest_statistics.py
+        #
+        # You should use ../local/display_manifest_statistics.py to get
+        # an utterance duration distribution for your dataset to select
+        # the threshold
+        return 1.0 <= c.duration <= 20.0
+
+    train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+    train_dl = librispeech.train_dataloaders(train_cuts)
+
+    valid_cuts = librispeech.dev_clean_cuts()
+    valid_cuts += librispeech.dev_other_cuts()
+    valid_dl = librispeech.valid_dataloaders(valid_cuts)
 
     for epoch in range(params.start_epoch, params.num_epochs):
         train_dl.sampler.set_epoch(epoch)
@@ -796,6 +843,7 @@ def main():
     parser = get_parser()
     LibriSpeechAsrDataModule.add_arguments(parser)
     args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
 
     world_size = args.world_size
     assert world_size >= 1
diff --git a/egs/librispeech/ASR/conformer_mmi/train.py b/egs/librispeech/ASR/conformer_mmi/train.py
index 5cfb2bfc7..f9f80632e 100755
--- a/egs/librispeech/ASR/conformer_mmi/train.py
+++ b/egs/librispeech/ASR/conformer_mmi/train.py
@@ -30,6 +30,8 @@ import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
 from conformer import Conformer
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_norm_
@@ -100,6 +102,26 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="conformer_mmi/exp",
+        help="""The experiment dir.
+        It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_bpe_500",
+        help="""The lang dir
+        It contains language related input files such as
+        "lexicon.txt"
+        """,
+    )
+
     parser.add_argument(
         "--seed",
         type=int,
@@ -107,6 +129,14 @@ def get_parser():
         help="The seed for random generators intended for reproducibility",
     )
 
+    parser.add_argument(
+        "--use-pruned-intersect",
+        type=str2bool,
+        default=False,
+        help="""Whether to use `intersect_dense_pruned` to get denominator
+        lattice.""",
+    )
+
     return parser
 
 
@@ -121,12 +151,6 @@ def get_params() -> AttributeDict:
 
     Explanation of options saved in `params`:
 
-        - exp_dir: It specifies the directory where all training related
-                   files, e.g., checkpoints, log, etc, are saved
-
-        - lang_dir: It contains language related input files such as
-                    "lexicon.txt"
-
         - best_train_loss: Best training loss so far. It is used to select
                            the model that has the lowest training loss. It is
                            updated during the training.
@@ -171,8 +195,6 @@ def get_params() -> AttributeDict:
     """
     params = AttributeDict(
         {
-            "exp_dir": Path("conformer_mmi/exp_500"),
-            "lang_dir": Path("data/lang_bpe_500"),
             "best_train_loss": float("inf"),
             "best_valid_loss": float("inf"),
             "best_train_epoch": -1,
@@ -193,13 +215,10 @@ def get_params() -> AttributeDict:
             "use_double_scores": True,
             "att_rate": 0.0,
             "num_decoder_layers": 0,
-            #  "att_rate": 0.7,
-            #  "num_decoder_layers": 6,
             # parameters for Noam
             "weight_decay": 1e-6,
             "lr_factor": 5.0,
             "warm_step": 80000,
-            "use_pruned_intersect": False,
             "den_scale": 1.0,
             # use alignments before this number of batches
             "use_ali_until": 13000,
@@ -752,8 +771,29 @@ def run(rank, world_size, args):
         valid_ali = None
 
     librispeech = LibriSpeechAsrDataModule(args)
-    train_dl = librispeech.train_dataloaders()
-    valid_dl = librispeech.valid_dataloaders()
+    train_cuts = librispeech.train_clean_100_cuts()
+    if params.full_libri:
+        train_cuts += librispeech.train_clean_360_cuts()
+        train_cuts += librispeech.train_other_500_cuts()
+
+    def remove_short_and_long_utt(c: Cut):
+        # Keep only utterances with duration between 1 second and 20 seconds
+        #
+        # Caution: There is a reason to select 20.0 here. Please see
+        # ../local/display_manifest_statistics.py
+        #
+        # You should use ../local/display_manifest_statistics.py to get
+        # an utterance duration distribution for your dataset to select
+        # the threshold
+        return 1.0 <= c.duration <= 20.0
+
+    train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+    train_dl = librispeech.train_dataloaders(train_cuts)
+
+    valid_cuts = librispeech.dev_clean_cuts()
+    valid_cuts += librispeech.dev_other_cuts()
+    valid_dl = librispeech.valid_dataloaders(valid_cuts)
 
     for epoch in range(params.start_epoch, params.num_epochs):
         fix_random_seed(params.seed + epoch)
@@ -804,6 +844,7 @@ def main():
     parser = get_parser()
     LibriSpeechAsrDataModule.add_arguments(parser)
     args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
 
     world_size = args.world_size
     assert world_size >= 1
diff --git a/egs/librispeech/ASR/generate-lm.sh b/egs/librispeech/ASR/generate-lm.sh
index 6baccd381..dacd276d1 100755
--- a/egs/librispeech/ASR/generate-lm.sh
+++ b/egs/librispeech/ASR/generate-lm.sh
@@ -2,7 +2,7 @@
 
 lang_dir=data/lang_bpe_500
 
-for ngram in 2 3 5; do
+for ngram in 2 3 4 5; do
   if [ ! -f $lang_dir/${ngram}gram.arpa ]; then
     ./shared/make_kn_lm.py \
       -ngram-order ${ngram} \
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/export.py
index 59a393739..c1607699f 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/export.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/export.py
@@ -72,14 +72,14 @@ Check ./pretrained.py for its usage.
 Note: If you don't want to train a model from scratch, we have
 provided one for you. You can get it at
 
-https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
+https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-ctc-2022-12-01
 
 with the following commands:
 
     sudo apt-get install git-lfs
     git lfs install
-    git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
-    # You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/exp
+    git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-ctc-2022-12-01
+    # You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless7-ctc-2022-12-01/exp
 """
 
 import argparse
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py
index d3343d34a..ad9cf08dc 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py
@@ -304,7 +304,10 @@ def main():
 
     batch_size = nnet_output.shape[0]
     supervision_segments = torch.tensor(
-        [[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
+        [
+            [i, 0, feature_lengths[i] // params.subsampling_factor]
+            for i in range(batch_size)
+        ],
         dtype=torch.int32,
     )
 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py
index 74aef1bc7..5d460edb5 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py
@@ -322,7 +322,10 @@ def main():
 
     batch_size = nnet_output.shape[0]
     supervision_segments = torch.tensor(
-        [[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
+        [
+            [i, 0, feature_lengths[i] // params.subsampling_factor]
+            for i in range(batch_size)
+        ],
         dtype=torch.int32,
     )
 
diff --git a/egs/librispeech/ASR/zipformer_mmi/README.md b/egs/librispeech/ASR/zipformer_mmi/README.md
new file mode 100644
index 000000000..8ca844180
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_mmi/README.md
@@ -0,0 +1,26 @@
+This recipe implements Zipformer-MMI model.
+
+See https://k2-fsa.github.io/icefall/recipes/librispeech/zipformer_mmi.html for detailed tutorials.
+
+It uses **CTC loss for warm-up** and then switches to MMI loss during training.
+
+For decoding, it uses HP (H is ctc_topo, P is token-level bi-gram) as decoding graph. Supported decoding methods are:
+- **1best**. Extract the best path from the decoding lattice as the decoding result.
+- **nbest**. Extract n paths from the decoding lattice; the path with the highest score is the decoding result.
+- **nbest-rescoring-LG**. Extract n paths from the decoding lattice, rescore them with an word-level 3-gram LM, the path with the highest score is the decoding result.
+- **nbest-rescoring-3-gram**. Extract n paths from the decoding lattice, rescore them with an token-level 3-gram LM, the path with the highest score is the decoding result.
+- **nbest-rescoring-4-gram**. Extract n paths from the decoding lattice, rescore them with an token-level 4-gram LM, the path with the highest score is the decoding result.
+
+Experimental results training on train-clean-100 (epoch-30-avg-10):
+- 1best. 6.43 & 17.44
+- nbest, nbest-scale=1.2, 6.43 & 17.45
+- nbest-rescoring-LG, nbest-scale=1.2, 5.87 & 16.35
+- nbest-rescoring-3-gram,  nbest-scale=1.2, 6.19 & 16.57
+- nbest-rescoring-4-gram,  nbest-scale=1.2, 5.87 & 16.07
+
+Experimental results training on full librispeech (epoch-30-avg-10):
+- 1best. 2.54 & 5.65
+- nbest, nbest-scale=1.2, 2.54 & 5.66
+- nbest-rescoring-LG, nbest-scale=1.2, 2.49 & 5.42
+- nbest-rescoring-3-gram,  nbest-scale=1.2, 2.52 & 5.62
+- nbest-rescoring-4-gram,  nbest-scale=1.2, 2.5 & 5.51
diff --git a/egs/librispeech/ASR/zipformer_mmi/__init__.py b/egs/librispeech/ASR/zipformer_mmi/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/egs/librispeech/ASR/zipformer_mmi/asr_datamodule.py b/egs/librispeech/ASR/zipformer_mmi/asr_datamodule.py
new file mode 120000
index 000000000..a074d6085
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_mmi/asr_datamodule.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless2/asr_datamodule.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/zipformer_mmi/decode.py b/egs/librispeech/ASR/zipformer_mmi/decode.py
new file mode 100755
index 000000000..7d0ea78bb
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_mmi/decode.py
@@ -0,0 +1,736 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
+#                                                 Liyong Guo,
+#                                                 Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) 1best
+./zipformer_mmi/mmi_decode.py \
+    --epoch 30 \
+    --avg 15 \
+    --exp-dir ./zipformer_mmi/exp \
+    --max-duration 100 \
+    --decoding-method 1best
+(2) nbest
+./zipformer_mmi/mmi_decode.py \
+    --epoch 30 \
+    --avg 15 \
+    --exp-dir ./zipformer_mmi/exp \
+    --max-duration 100 \
+    --nbest-scale 1.0 \
+    --decoding-method nbest
+(3) nbest-rescoring-LG
+./zipformer_mmi/mmi_decode.py \
+    --epoch 30 \
+    --avg 15 \
+    --exp-dir ./zipformer_mmi/exp \
+    --max-duration 100 \
+    --nbest-scale 1.0 \
+    --decoding-method nbest-rescoring-LG
+(4) nbest-rescoring-3-gram
+./zipformer_mmi/mmi_decode.py \
+    --epoch 30 \
+    --avg 15 \
+    --exp-dir ./zipformer_mmi/exp \
+    --max-duration 100 \
+    --nbest-scale 1.0 \
+    --decoding-method nbest-rescoring-3-gram
+(5) nbest-rescoring-4-gram
+./zipformer_mmi/mmi_decode.py \
+    --epoch 30 \
+    --avg 15 \
+    --exp-dir ./zipformer_mmi/exp \
+    --max-duration 100 \
+    --nbest-scale 1.0 \
+    --decoding-method nbest-rescoring-4-gram
+"""
+
+
+import argparse
+import logging
+import math
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
+from train import add_model_arguments, get_ctc_model, get_params
+
+from icefall.checkpoint import (
+    average_checkpoints,
+    average_checkpoints_with_averaged_model,
+    find_checkpoints,
+    load_checkpoint,
+)
+from icefall.decode import (
+    get_lattice,
+    nbest_decoding,
+    nbest_rescore_with_LM,
+    one_best_decoding,
+)
+from icefall.lexicon import Lexicon
+from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
+from icefall.utils import (
+    AttributeDict,
+    get_texts,
+    setup_logger,
+    store_transcripts,
+    str2bool,
+    write_error_stats,
+)
+
+LOG_EPS = math.log(1e-10)
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=30,
+        help="""It specifies the checkpoint to use for decoding.
+        Note: Epoch counts from 1.
+        You can specify --avg to use more checkpoints for model averaging.""",
+    )
+
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=15,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
+    )
+
+    parser.add_argument(
+        "--use-averaged-model",
+        type=str2bool,
+        default=True,
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="zipformer_mmi/exp",
+        help="The experiment dir",
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="data/lang_bpe_500/bpe.model",
+        help="Path to the BPE model",
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=Path,
+        default="data/lang_bpe_500",
+        help="The lang dir containing word table and LG graph",
+    )
+
+    parser.add_argument(
+        "--decoding-method",
+        type=str,
+        default="1best",
+        help="""Decoding method. Use HP as decoding graph, where H is
+        ctc_topo and P is token-level bi-gram lm.
+        Supported values are:
+        - (1) 1best. Extract the best path from the decoding lattice as the
+          decoding result.
+        - (2) nbest. Extract n paths from the decoding lattice; the path
+          with the highest score is the decoding result.
+        - (4) nbest-rescoring-LG. Extract n paths from the decoding lattice,
+          rescore them with an word-level 3-gram LM, the path with the
+          highest score is the decoding result.
+        - (5) nbest-rescoring-3-gram. Extract n paths from the decoding
+          lattice, rescore them with an token-level 3-gram LM, the path with
+          the highest score is the decoding result.
+        - (6) nbest-rescoring-4-gram. Extract n paths from the decoding
+          lattice, rescore them with an token-level 4-gram LM, the path with
+          the highest score is the decoding result.
+        """,
+    )
+
+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=100,
+        help="""Number of paths for n-best based decoding method.
+        Used only when "method" is one of the following values:
+        nbest, nbest-rescoring, and nbest-oracle
+        """,
+    )
+
+    parser.add_argument(
+        "--nbest-scale",
+        type=float,
+        default=1.0,
+        help="""The scale to be applied to `lattice.scores`.
+        It's needed if you use any kinds of n-best based rescoring.
+        Used only when "method" is one of the following values:
+        nbest, nbest-rescoring, and nbest-oracle
+        A smaller value results in more unique paths.
+        """,
+    )
+
+    parser.add_argument(
+        "--hp-scale",
+        type=float,
+        default=1.0,
+        help="""The scale to be applied to `ctc_topo_P.scores`.
+        """,
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def get_decoding_params() -> AttributeDict:
+    """Parameters for decoding."""
+    params = AttributeDict(
+        {
+            "frame_shift_ms": 10,
+            "search_beam": 20,
+            "output_beam": 8,
+            "min_active_states": 30,
+            "max_active_states": 10000,
+            "use_double_scores": True,
+        }
+    )
+    return params
+
+
+def decode_one_batch(
+    params: AttributeDict,
+    model: nn.Module,
+    HP: Optional[k2.Fsa],
+    bpe_model: Optional[spm.SentencePieceProcessor],
+    batch: dict,
+    G: Optional[k2.Fsa] = None,
+    LG: Optional[k2.Fsa] = None,
+) -> Dict[str, List[List[str]]]:
+    """Decode one batch and return the result in a dict. The dict has the
+    following format:
+    - key: It indicates the setting used for decoding. For example,
+           if no rescoring is used, the key is the string `no_rescore`.
+           If LM rescoring is used, the key is the string `lm_scale_xxx`,
+           where `xxx` is the value of `lm_scale`. An example key is
+           `lm_scale_0.7`
+    - value: It contains the decoding result. `len(value)` equals to
+             batch size. `value[i]` is the decoding result for the i-th
+             utterance in the given batch.
+
+    Args:
+      params:
+        It's the return value of :func:`get_params`.
+
+        - params.decoding_method is "1best", it uses 1best decoding without LM rescoring.
+        - params.decoding_method is "nbest", it uses nbest decoding without LM rescoring.
+        - params.decoding_method is "nbest-rescoring-LG", it uses nbest rescoring with word-level 3-gram LM.
+        - params.decoding_method is "nbest-rescoring-3-gram", it uses nbest rescoring with token-level 3-gram LM.
+        - params.decoding_method is "nbest-rescoring-4-gram", it uses nbest rescoring with token-level 4-gram LM.
+
+      model:
+        The neural model.
+      HP:
+        The decoding graph. H is ctc_topo, P is token-level bi-gram LM.
+      bpe_model:
+        The BPE model.
+      batch:
+        It is the return value from iterating
+        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+        for the format of the `batch`.
+      LG:
+        An LM. L is the lexicon, G is a word-level 3-gram LM.
+        It is used when params.decoding_method is "nbest-rescoring-LG".
+      G:
+        An LM. L is the lexicon, G is a token-level 3-gram or 4-gram LM.
+        It is used when params.decoding_method is "nbest-rescoring-3-gram"
+        or "nbest-rescoring-4-gram".
+    Returns:
+      Return the decoding result. See above description for the format of
+      the returned dict. Note: If it decodes to nothing, then return None.
+    """
+    device = HP.device
+    feature = batch["inputs"]
+    assert feature.ndim == 3, feature.shape
+    feature = feature.to(device)
+
+    # at entry, feature is (N, T, C)
+
+    supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)
+
+    nnet_output, encoder_out_lens = model(x=feature, x_lens=feature_lens)
+    # nnet_output is (N, T, C)
+
+    supervision_segments = torch.stack(
+        (
+            supervisions["sequence_idx"],
+            supervisions["start_frame"] // params.subsampling_factor,
+            supervisions["num_frames"] // params.subsampling_factor,
+        ),
+        1,
+    ).to(torch.int32)
+
+    lattice = get_lattice(
+        nnet_output=nnet_output,
+        decoding_graph=HP,
+        supervision_segments=supervision_segments,
+        search_beam=params.search_beam,
+        output_beam=params.output_beam,
+        min_active_states=params.min_active_states,
+        max_active_states=params.max_active_states,
+        subsampling_factor=params.subsampling_factor,
+    )
+
+    method = params.decoding_method
+
+    if method in ["1best", "nbest"]:
+        if method == "1best":
+            best_path = one_best_decoding(
+                lattice=lattice, use_double_scores=params.use_double_scores
+            )
+            key = "no_rescore"
+        else:
+            best_path = nbest_decoding(
+                lattice=lattice,
+                num_paths=params.num_paths,
+                use_double_scores=params.use_double_scores,
+                nbest_scale=params.nbest_scale,
+            )
+            key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}"  # noqa
+
+        # Note: `best_path.aux_labels` contains token IDs, not word IDs
+        # since we are using HP, not HLG here.
+        #
+        # token_ids is a lit-of-list of IDs
+        token_ids = get_texts(best_path)
+        # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
+        hyps = bpe_model.decode(token_ids)
+        # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
+        hyps = [s.split() for s in hyps]
+        return {key: hyps}
+
+    assert method in [
+        "nbest-rescoring-LG",  # word-level 3-gram lm
+        "nbest-rescoring-3-gram",  # token-level 3-gram lm
+        "nbest-rescoring-4-gram",  # token-level 4-gram lm
+    ]
+
+    lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
+    lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
+    lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
+
+    if method == "nbest-rescoring-LG":
+        assert LG is not None
+        LM = LG
+    else:
+        assert G is not None
+        LM = G
+    best_path_dict = nbest_rescore_with_LM(
+        lattice=lattice,
+        LM=LM,
+        num_paths=params.num_paths,
+        lm_scale_list=lm_scale_list,
+        nbest_scale=params.nbest_scale,
+    )
+
+    ans = dict()
+    suffix = f"-nbest-scale-{params.nbest_scale}-{params.num_paths}"
+    for lm_scale_str, best_path in best_path_dict.items():
+        token_ids = get_texts(best_path)
+        # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
+        hyps = bpe_model.decode(token_ids)
+        # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
+        hyps = [s.split() for s in hyps]
+        ans[lm_scale_str + suffix] = hyps
+    return ans
+
+
+def decode_dataset(
+    dl: torch.utils.data.DataLoader,
+    params: AttributeDict,
+    model: nn.Module,
+    HP: k2.Fsa,
+    bpe_model: spm.SentencePieceProcessor,
+    G: Optional[k2.Fsa] = None,
+    LG: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+    """Decode dataset.
+
+    Args:
+      dl:
+        PyTorch's dataloader containing the dataset to decode.
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The neural model.
+      HP:
+        The decoding graph. H is ctc_topo, P is token-level bi-gram LM.
+      bpe_model:
+        The BPE model.
+      LG:
+        An LM. L is the lexicon, G is a word-level 3-gram LM.
+        It is used when params.decoding_method is "nbest-rescoring-LG".
+      G:
+        An LM. L is the lexicon, G is a token-level 3-gram or 4-gram LM.
+        It is used when params.decoding_method is "nbest-rescoring-3-gram"
+        or "nbest-rescoring-4-gram".
+
+    Returns:
+      Return a dict, whose key may be "no-rescore" if no LM rescoring
+      is used, or it may be "lm_scale_0.7" if LM rescoring is used.
+      Its value is a list of tuples. Each tuple contains two elements:
+      The first is the reference transcript, and the second is the
+      predicted result.
+    """
+    num_cuts = 0
+
+    try:
+        num_batches = len(dl)
+    except TypeError:
+        num_batches = "?"
+
+    results = defaultdict(list)
+    for batch_idx, batch in enumerate(dl):
+        texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+        hyps_dict = decode_one_batch(
+            params=params,
+            model=model,
+            HP=HP,
+            bpe_model=bpe_model,
+            batch=batch,
+            G=G,
+            LG=LG,
+        )
+
+        for name, hyps in hyps_dict.items():
+            this_batch = []
+            assert len(hyps) == len(texts)
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+                ref_words = ref_text.split()
+                this_batch.append((cut_id, ref_words, hyp_words))
+
+            results[name].extend(this_batch)
+
+        num_cuts += len(texts)
+
+        if batch_idx % 100 == 0:
+            batch_str = f"{batch_idx}/{num_batches}"
+
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+    return results
+
+
+def save_results(
+    params: AttributeDict,
+    test_set_name: str,
+    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
+    test_set_wers = dict()
+    for key, results in results_dict.items():
+        recog_path = (
+            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
+        results = sorted(results)
+        store_transcripts(filename=recog_path, texts=results)
+        logging.info(f"The transcripts are stored in {recog_path}")
+
+        # The following prints out WERs, per-word error statistics and aligned
+        # ref/hyp pairs.
+        errs_filename = (
+            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
+        with open(errs_filename, "w") as f:
+            wer = write_error_stats(f, f"{test_set_name}-{key}", results)
+            test_set_wers[key] = wer
+
+        logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+    errs_info = (
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+    )
+    with open(errs_info, "w") as f:
+        print("settings\tWER", file=f)
+        for key, val in test_set_wers:
+            print("{}\t{}".format(key, val), file=f)
+
+    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+    note = "\tbest for {}".format(test_set_name)
+    for key, val in test_set_wers:
+        s += "{}\t{}{}\n".format(key, val, note)
+        note = ""
+    logging.info(s)
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    LibriSpeechAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+    args.lang_dir = Path(args.lang_dir)
+
+    params = get_params()
+    # add decoding params
+    params.update(get_decoding_params())
+    params.update(vars(args))
+
+    assert params.decoding_method in (
+        "1best",
+        "nbest",
+        "nbest-rescoring-LG",  # word-level 3-gram lm
+        "nbest-rescoring-3-gram",  # token-level 3-gram lm
+        "nbest-rescoring-4-gram",  # token-level 4-gram lm
+    ), params.decoding_method
+    params.res_dir = params.exp_dir / params.decoding_method
+
+    if params.iter > 0:
+        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+    else:
+        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+    if params.use_averaged_model:
+        params.suffix += "-use-averaged-model"
+
+    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+    logging.info("decoding started")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+    logging.info(params)
+
+    lexicon = Lexicon(params.lang_dir)
+    max_token_id = max(lexicon.tokens)
+    num_classes = max_token_id + 1  # +1 for the blank
+
+    params.vocab_size = num_classes
+    #  and  are defined in local/train_bpe_model.py
+    params.blank_id = 0
+
+    bpe_model = spm.SentencePieceProcessor()
+    bpe_model.load(str(params.lang_dir / "bpe.model"))
+    mmi_graph_compiler = MmiTrainingGraphCompiler(
+        params.lang_dir,
+        uniq_filename="lexicon.txt",
+        device=device,
+        oov="",
+        sos_id=1,
+        eos_id=1,
+    )
+    HP = mmi_graph_compiler.ctc_topo_P
+    HP.scores *= params.hp_scale
+    if not hasattr(HP, "lm_scores"):
+        HP.lm_scores = HP.scores.clone()
+
+    LG = None
+    G = None
+
+    if params.decoding_method == "nbest-rescoring-LG":
+        lg_filename = params.lang_dir / "LG.pt"
+        logging.info(f"Loading {lg_filename}")
+        LG = k2.Fsa.from_dict(torch.load(lg_filename, map_location=device))
+        LG = k2.Fsa.from_fsas([LG]).to(device)
+        LG.lm_scores = LG.scores.clone()
+
+    elif params.decoding_method in ["nbest-rescoring-3-gram", "nbest-rescoring-4-gram"]:
+        order = params.decoding_method[-6]
+        assert order in ("3", "4"), (params.decoding_method, order)
+        order = int(order)
+        if not (params.lang_dir / f"{order}gram.pt").is_file():
+            logging.info(f"Loading {order}gram.fst.txt")
+            logging.warning("It may take a few minutes.")
+            with open(params.lang_dir / f"{order}gram.fst.txt") as f:
+                first_token_disambig_id = lexicon.token_table["#0"]
+
+                G = k2.Fsa.from_openfst(f.read(), acceptor=False)
+                # G.aux_labels is not needed in later computations, so
+                # remove it here.
+                del G.aux_labels
+                # CAUTION: The following line is crucial.
+                # Arcs entering the back-off state have label equal to #0.
+                # We have to change it to 0 here.
+                G.labels[G.labels >= first_token_disambig_id] = 0
+                G = k2.Fsa.from_fsas([G]).to(device)
+                # G = k2.remove_epsilon(G)
+                G = k2.arc_sort(G)
+                # Save a dummy value so that it can be loaded in C++.
+                # See https://github.com/pytorch/pytorch/issues/67902
+                # for why we need to do this.
+                G.dummy = 1
+
+                torch.save(G.as_dict(), params.lang_dir / f"{order}gram.pt")
+        else:
+            logging.info(f"Loading pre-compiled {order}gram.pt")
+            d = torch.load(params.lang_dir / f"{order}gram.pt", map_location=device)
+            G = k2.Fsa.from_dict(d)
+
+        G.lm_scores = G.scores.clone()
+
+    logging.info("About to create model")
+    model = get_ctc_model(params)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+
+    model.to(device)
+    model.eval()
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
+    librispeech = LibriSpeechAsrDataModule(args)
+
+    test_clean_cuts = librispeech.test_clean_cuts()
+    test_other_cuts = librispeech.test_other_cuts()
+
+    test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
+    test_other_dl = librispeech.test_dataloaders(test_other_cuts)
+
+    test_sets = ["test-clean", "test-other"]
+    test_dl = [test_clean_dl, test_other_dl]
+
+    for test_set, test_dl in zip(test_sets, test_dl):
+        results_dict = decode_dataset(
+            dl=test_dl,
+            params=params,
+            model=model,
+            HP=HP,
+            bpe_model=bpe_model,
+            G=G,
+            LG=LG,
+        )
+
+        save_results(
+            params=params,
+            test_set_name=test_set,
+            results_dict=results_dict,
+        )
+
+    logging.info("Done!")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/librispeech/ASR/zipformer_mmi/encoder_interface.py b/egs/librispeech/ASR/zipformer_mmi/encoder_interface.py
new file mode 120000
index 000000000..b9aa0ae08
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_mmi/encoder_interface.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless2/encoder_interface.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/zipformer_mmi/export.py b/egs/librispeech/ASR/zipformer_mmi/export.py
new file mode 100755
index 000000000..0af7bd367
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_mmi/export.py
@@ -0,0 +1,307 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This script converts several saved checkpoints
+# to a single one using model averaging.
+"""
+
+Usage:
+
+(1) Export to torchscript model using torch.jit.script()
+
+./zipformer_mmi/export.py \
+  --exp-dir ./zipformer_mmi/exp \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --epoch 30 \
+  --avg 9 \
+  --jit 1
+
+It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later
+load it by `torch.jit.load("cpu_jit.pt")`.
+
+Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python
+are on CPU. You can use `to("cuda")` to move them to a CUDA device.
+
+Check
+https://github.com/k2-fsa/sherpa
+for how to use the exported models outside of icefall.
+
+(2) Export `model.state_dict()`
+
+./zipformer_mmi/export.py \
+  --exp-dir ./zipformer_mmi/exp \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --epoch 20 \
+  --avg 10
+
+It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
+load it by `icefall.checkpoint.load_checkpoint()`.
+
+To use the generated file with `zipformer_mmi/decode.py`,
+you can do:
+
+    cd /path/to/exp_dir
+    ln -s pretrained.pt epoch-9999.pt
+
+    cd /path/to/egs/librispeech/ASR
+    ./zipformer_mmi/decode.py \
+        --exp-dir ./zipformer_mmi/exp \
+        --epoch 9999 \
+        --avg 1 \
+        --max-duration 600 \
+        --decoding-method greedy_search \
+        --bpe-model data/lang_bpe_500/bpe.model
+
+Check ./pretrained.py for its usage.
+
+Note: If you don't want to train a model from scratch, we have
+provided one for you. You can get it at
+
+https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-mmi-2022-12-08
+
+with the following commands:
+
+    sudo apt-get install git-lfs
+    git lfs install
+    git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-mmi-2022-12-08
+    # You will find the pre-trained model in icefall-asr-librispeech-zipformer-mmi-2022-12-08/exp
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+import sentencepiece as spm
+import torch
+from scaling_converter import convert_scaled_to_non_scaled
+from train import add_model_arguments, get_ctc_model, get_params
+
+from icefall.checkpoint import (
+    average_checkpoints,
+    average_checkpoints_with_averaged_model,
+    find_checkpoints,
+    load_checkpoint,
+)
+from icefall.utils import str2bool
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=30,
+        help="""It specifies the checkpoint to use for decoding.
+        Note: Epoch counts from 1.
+        You can specify --avg to use more checkpoints for model averaging.""",
+    )
+
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=9,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
+    )
+
+    parser.add_argument(
+        "--use-averaged-model",
+        type=str2bool,
+        default=True,
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="zipformer_mmi/exp",
+        help="""It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="data/lang_bpe_500/bpe.model",
+        help="Path to the BPE model",
+    )
+
+    parser.add_argument(
+        "--jit",
+        type=str2bool,
+        default=False,
+        help="""True to save a model after applying torch.jit.script.
+        It will generate a file named cpu_jit.pt
+
+        Check ./jit_pretrained.py for how to use it.
+        """,
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+@torch.no_grad()
+def main():
+    args = get_parser().parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    #  is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_ctc_model(params)
+
+    model.to(device)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+
+    model.to("cpu")
+    model.eval()
+
+    if params.jit is True:
+        convert_scaled_to_non_scaled(model, inplace=True)
+        logging.info("Using torch.jit.script()")
+        model = torch.jit.script(model)
+        filename = params.exp_dir / "cpu_jit.pt"
+        model.save(str(filename))
+        logging.info(f"Saved to {filename}")
+    else:
+        logging.info("Not using torchscript. Export model.state_dict()")
+        # Save it using a format so that it can be loaded
+        # by :func:`load_checkpoint`
+        filename = params.exp_dir / "pretrained.pt"
+        torch.save({"model": model.state_dict()}, str(filename))
+        logging.info(f"Saved to {filename}")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/librispeech/ASR/zipformer_mmi/jit_pretrained.py b/egs/librispeech/ASR/zipformer_mmi/jit_pretrained.py
new file mode 100755
index 000000000..c9ef16ffa
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_mmi/jit_pretrained.py
@@ -0,0 +1,391 @@
+#!/usr/bin/env python3
+# Copyright      2021-2022  Xiaomi Corp.   (authors: Fangjun Kuang,
+#                                                    Zengwei)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script loads torchscript models, exported by `torch.jit.script()`
+and uses them to decode waves.
+You can use the following command to get the exported models:
+
+./zipformer_mmi/export.py \
+  --exp-dir ./zipformer_mmi/exp \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --epoch 20 \
+  --avg 10 \
+  --jit 1
+
+Usage of this script:
+
+(1) 1best
+./zipformer_mmi/jit_pretrained.py \
+    --nn-model-filename ./zipformer_mmi/exp/cpu_jit.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method 1best \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+(2) nbest
+./zipformer_mmi/jit_pretrained.py \
+    --nn-model-filename ./zipformer_mmi/exp/cpu_jit.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --nbest-scale 1.2 \
+    --method nbest \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+(3) nbest-rescoring-LG
+./zipformer_mmi/jit_pretrained.py \
+    --nn-model-filename ./zipformer_mmi/exp/cpu_jit.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --nbest-scale 1.2 \
+    --method nbest-rescoring-LG \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+(4) nbest-rescoring-3-gram
+./zipformer_mmi/jit_pretrained.py \
+    --nn-model-filename ./zipformer_mmi/exp/cpu_jit.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --nbest-scale 1.2 \
+    --method nbest-rescoring-3-gram \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+(5) nbest-rescoring-4-gram
+./zipformer_mmi/jit_pretrained.py \
+    --nn-model-filename ./zipformer_mmi/exp/cpu_jit.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --nbest-scale 1.2 \
+    --method nbest-rescoring-4-gram \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+"""
+
+import argparse
+import logging
+import math
+from pathlib import Path
+from typing import List
+
+import k2
+import kaldifeat
+import sentencepiece as spm
+import torch
+import torchaudio
+from decode import get_decoding_params
+from torch.nn.utils.rnn import pad_sequence
+from train import get_params
+
+from icefall.decode import (
+    get_lattice,
+    nbest_decoding,
+    nbest_rescore_with_LM,
+    one_best_decoding,
+)
+from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
+from icefall.utils import get_texts
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--nn-model-filename",
+        type=str,
+        required=True,
+        help="Path to the torchscript model cpu_jit.pt",
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        help="""Path to bpe.model.""",
+    )
+
+    parser.add_argument(
+        "--method",
+        type=str,
+        default="1best",
+        help="""Decoding method. Use HP as decoding graph, where H is
+        ctc_topo and P is token-level bi-gram lm.
+        Supported values are:
+        - (1) 1best. Extract the best path from the decoding lattice as the
+          decoding result.
+        - (2) nbest. Extract n paths from the decoding lattice; the path
+          with the highest score is the decoding result.
+        - (4) nbest-rescoring-LG. Extract n paths from the decoding lattice,
+          rescore them with an word-level 3-gram LM, the path with the
+          highest score is the decoding result.
+        - (5) nbest-rescoring-3-gram. Extract n paths from the decoding
+          lattice, rescore them with an token-level 3-gram LM, the path with
+          the highest score is the decoding result.
+        - (6) nbest-rescoring-4-gram. Extract n paths from the decoding
+          lattice, rescore them with an token-level 4-gram LM, the path with
+          the highest score is the decoding result.
+        """,
+    )
+
+    parser.add_argument(
+        "--sample-rate",
+        type=int,
+        default=16000,
+        help="The sample rate of the input sound file",
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=Path,
+        default="data/lang_bpe_500",
+        help="The lang dir containing word table and LG graph",
+    )
+
+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=100,
+        help="""Number of paths for n-best based decoding method.
+        Used only when "method" is one of the following values:
+        nbest, nbest-rescoring, and nbest-oracle
+        """,
+    )
+
+    parser.add_argument(
+        "--nbest-scale",
+        type=float,
+        default=1.2,
+        help="""The scale to be applied to `lattice.scores`.
+        It's needed if you use any kinds of n-best based rescoring.
+        Used only when "method" is one of the following values:
+        nbest, nbest-rescoring, and nbest-oracle
+        A smaller value results in more unique paths.
+        """,
+    )
+
+    parser.add_argument(
+        "--ngram-lm-scale",
+        type=float,
+        default=0.1,
+        help="""
+        Used when method is nbest-rescoring-LG, nbest-rescoring-3-gram,
+        and nbest-rescoring-4-gram.
+        It specifies the scale for n-gram LM scores.
+        (Note: You need to tune it on a dataset.)
+        """,
+    )
+
+    parser.add_argument(
+        "--hp-scale",
+        type=float,
+        default=1.0,
+        help="""The scale to be applied to `ctc_topo_P.scores`.
+        """,
+    )
+
+    parser.add_argument(
+        "sound_files",
+        type=str,
+        nargs="+",
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
+    )
+
+    return parser
+
+
+def read_sound_files(
+    filenames: List[str], expected_sample_rate: float = 16000
+) -> List[torch.Tensor]:
+    """Read a list of sound files into a list 1-D float32 torch tensors.
+    Args:
+      filenames:
+        A list of sound filenames.
+      expected_sample_rate:
+        The expected sample rate of the sound files.
+    Returns:
+      Return a list of 1-D float32 torch tensors.
+    """
+    ans = []
+    for f in filenames:
+        wave, sample_rate = torchaudio.load(f)
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        # We use only the first channel
+        ans.append(wave[0])
+    return ans
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+    logging.info(vars(args))
+
+    params = get_params()
+    # add decoding params
+    params.update(get_decoding_params())
+    params.update(vars(args))
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    model = torch.jit.load(params.nn_model_filename)
+    model.eval()
+    model.to(device)
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(args.bpe_model)
+
+    logging.info("Constructing Fbank computer")
+    opts = kaldifeat.FbankOptions()
+    opts.device = device
+    opts.frame_opts.dither = 0
+    opts.frame_opts.snip_edges = False
+    opts.frame_opts.samp_freq = 16000
+    opts.mel_opts.num_bins = 80
+
+    fbank = kaldifeat.Fbank(opts)
+
+    logging.info(f"Reading sound files: {args.sound_files}")
+    waves = read_sound_files(
+        filenames=params.sound_files, expected_sample_rate=params.sample_rate
+    )
+    waves = [w.to(device) for w in waves]
+
+    logging.info("Decoding started")
+    features = fbank(waves)
+    feature_lengths = [f.size(0) for f in features]
+
+    features = pad_sequence(
+        features,
+        batch_first=True,
+        padding_value=math.log(1e-10),
+    )
+    feature_lengths = torch.tensor(feature_lengths, device=device)
+
+    bpe_model = spm.SentencePieceProcessor()
+    bpe_model.load(str(params.lang_dir / "bpe.model"))
+    mmi_graph_compiler = MmiTrainingGraphCompiler(
+        params.lang_dir,
+        uniq_filename="lexicon.txt",
+        device=device,
+        oov="",
+        sos_id=1,
+        eos_id=1,
+    )
+    HP = mmi_graph_compiler.ctc_topo_P
+    HP.scores *= params.hp_scale
+    if not hasattr(HP, "lm_scores"):
+        HP.lm_scores = HP.scores.clone()
+
+    method = params.method
+    assert method in (
+        "1best",
+        "nbest",
+        "nbest-rescoring-LG",  # word-level 3-gram lm
+        "nbest-rescoring-3-gram",  # token-level 3-gram lm
+        "nbest-rescoring-4-gram",  # token-level 4-gram lm
+    )
+    # loading language model for rescoring
+    LM = None
+    if method == "nbest-rescoring-LG":
+        lg_filename = params.lang_dir / "LG.pt"
+        logging.info(f"Loading {lg_filename}")
+        LG = k2.Fsa.from_dict(torch.load(lg_filename, map_location=device))
+        LG = k2.Fsa.from_fsas([LG]).to(device)
+        LG.lm_scores = LG.scores.clone()
+        LM = LG
+    elif method in ["nbest-rescoring-3-gram", "nbest-rescoring-4-gram"]:
+        order = method[-6]
+        assert order in ("3", "4")
+        order = int(order)
+        logging.info(f"Loading pre-compiled {order}gram.pt")
+        d = torch.load(params.lang_dir / f"{order}gram.pt", map_location=device)
+        G = k2.Fsa.from_dict(d)
+        G.lm_scores = G.scores.clone()
+        LM = G
+
+    # Encoder forward
+    nnet_output, encoder_out_lens = model(x=features, x_lens=feature_lengths)
+
+    batch_size = nnet_output.shape[0]
+    supervision_segments = torch.tensor(
+        [
+            [i, 0, feature_lengths[i] // params.subsampling_factor]
+            for i in range(batch_size)
+        ],
+        dtype=torch.int32,
+    )
+
+    lattice = get_lattice(
+        nnet_output=nnet_output,
+        decoding_graph=HP,
+        supervision_segments=supervision_segments,
+        search_beam=params.search_beam,
+        output_beam=params.output_beam,
+        min_active_states=params.min_active_states,
+        max_active_states=params.max_active_states,
+        subsampling_factor=params.subsampling_factor,
+    )
+
+    if method in ["1best", "nbest"]:
+        if method == "1best":
+            best_path = one_best_decoding(
+                lattice=lattice, use_double_scores=params.use_double_scores
+            )
+        else:
+            best_path = nbest_decoding(
+                lattice=lattice,
+                num_paths=params.num_paths,
+                use_double_scores=params.use_double_scores,
+                nbest_scale=params.nbest_scale,
+            )
+    else:
+        best_path_dict = nbest_rescore_with_LM(
+            lattice=lattice,
+            LM=LM,
+            num_paths=params.num_paths,
+            lm_scale_list=[params.ngram_lm_scale],
+            nbest_scale=params.nbest_scale,
+        )
+        best_path = next(iter(best_path_dict.values()))
+
+    # Note: `best_path.aux_labels` contains token IDs, not word IDs
+    # since we are using HP, not HLG here.
+    #
+    # token_ids is a lit-of-list of IDs
+    token_ids = get_texts(best_path)
+    # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
+    hyps = bpe_model.decode(token_ids)
+    # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
+    hyps = [s.split() for s in hyps]
+    s = "\n"
+    for filename, hyp in zip(params.sound_files, hyps):
+        words = " ".join(hyp)
+        s += f"{filename}:\n{words}\n\n"
+    logging.info(s)
+
+    logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/librispeech/ASR/zipformer_mmi/model.py b/egs/librispeech/ASR/zipformer_mmi/model.py
new file mode 100644
index 000000000..4045c8b64
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_mmi/model.py
@@ -0,0 +1,75 @@
+# Copyright    2022  Xiaomi Corp.        (authors: Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+from encoder_interface import EncoderInterface
+
+
+class CTCModel(nn.Module):
+    def __init__(
+        self,
+        encoder: EncoderInterface,
+        encoder_dim: int,
+        vocab_size: int,
+    ):
+        """
+        Args:
+          encoder:
+            It is the transcription network in the paper. Its accepts
+            two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
+            It returns two tensors: `logits` of shape (N, T, encoder_dm) and
+            `logit_lens` of shape (N,).
+        """
+        super().__init__()
+        assert isinstance(encoder, EncoderInterface), type(encoder)
+
+        self.encoder = encoder
+
+        self.ctc_output = nn.Sequential(
+            nn.Dropout(p=0.1),
+            nn.Linear(encoder_dim, vocab_size),
+            nn.LogSoftmax(dim=-1),
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        x_lens: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Args:
+          x:
+            A 3-D tensor of shape (N, T, C).
+          x_lens:
+            A 1-D tensor of shape (N,). It contains the number of frames in `x`
+            before padding.
+        Returns:
+          Return the ctc outputs and encoder output lengths.
+        """
+        assert x.ndim == 3, x.shape
+        assert x_lens.ndim == 1, x_lens.shape
+
+        encoder_out, x_lens = self.encoder(x, x_lens)
+        assert torch.all(x_lens > 0)
+
+        # compute ctc log-probs
+        ctc_output = self.ctc_output(encoder_out)
+
+        return ctc_output, x_lens
diff --git a/egs/librispeech/ASR/zipformer_mmi/optim.py b/egs/librispeech/ASR/zipformer_mmi/optim.py
new file mode 120000
index 000000000..81ac4a89a
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_mmi/optim.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless7/optim.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/zipformer_mmi/pretrained.py b/egs/librispeech/ASR/zipformer_mmi/pretrained.py
new file mode 100755
index 000000000..0e7fd0daf
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_mmi/pretrained.py
@@ -0,0 +1,410 @@
+#!/usr/bin/env python3
+# Copyright      2021-2022  Xiaomi Corp.   (authors: Fangjun Kuang,
+#                                                    Zengwei)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script loads a checkpoint and uses it to decode waves.
+You can generate the checkpoint with the following command:
+
+./zipformer_mmi/export.py \
+  --exp-dir ./zipformer_mmi/exp \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --epoch 20 \
+  --avg 10
+
+Usage of this script:
+
+(1) 1best
+./zipformer_mmi/pretrained.py \
+    --checkpoint ./zipformer_mmi/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method 1best \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+(2) nbest
+./zipformer_mmi/pretrained.py \
+    --checkpoint ./zipformer_mmi/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --nbest-scale 1.2 \
+    --method nbest \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+(3) nbest-rescoring-LG
+./zipformer_mmi/pretrained.py \
+    --checkpoint ./zipformer_mmi/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --nbest-scale 1.2 \
+    --method nbest-rescoring-LG \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+(4) nbest-rescoring-3-gram
+./zipformer_mmi/pretrained.py \
+    --checkpoint ./zipformer_mmi/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --nbest-scale 1.2 \
+    --method nbest-rescoring-3-gram \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+(5) nbest-rescoring-4-gram
+./zipformer_mmi/pretrained.py \
+    --checkpoint ./zipformer_mmi/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --nbest-scale 1.2 \
+    --method nbest-rescoring-4-gram \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+
+
+You can also use `./zipformer_mmi/exp/epoch-xx.pt`.
+
+Note: ./zipformer_mmi/exp/pretrained.pt is generated by
+./zipformer_mmi/export.py
+"""
+
+
+import argparse
+import logging
+import math
+from pathlib import Path
+from typing import List
+
+import k2
+import kaldifeat
+import sentencepiece as spm
+import torch
+import torchaudio
+from decode import get_decoding_params
+from torch.nn.utils.rnn import pad_sequence
+from train import add_model_arguments, get_ctc_model, get_params
+
+from icefall.decode import (
+    get_lattice,
+    nbest_decoding,
+    nbest_rescore_with_LM,
+    one_best_decoding,
+)
+from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
+from icefall.utils import get_texts
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--checkpoint",
+        type=str,
+        required=True,
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        help="""Path to bpe.model.""",
+    )
+
+    parser.add_argument(
+        "--method",
+        type=str,
+        default="1best",
+        help="""Decoding method. Use HP as decoding graph, where H is
+        ctc_topo and P is token-level bi-gram lm.
+        Supported values are:
+        - (1) 1best. Extract the best path from the decoding lattice as the
+          decoding result.
+        - (2) nbest. Extract n paths from the decoding lattice; the path
+          with the highest score is the decoding result.
+        - (4) nbest-rescoring-LG. Extract n paths from the decoding lattice,
+          rescore them with an word-level 3-gram LM, the path with the
+          highest score is the decoding result.
+        - (5) nbest-rescoring-3-gram. Extract n paths from the decoding
+          lattice, rescore them with an token-level 3-gram LM, the path with
+          the highest score is the decoding result.
+        - (6) nbest-rescoring-4-gram. Extract n paths from the decoding
+          lattice, rescore them with an token-level 4-gram LM, the path with
+          the highest score is the decoding result.
+        """,
+    )
+
+    parser.add_argument(
+        "--sample-rate",
+        type=int,
+        default=16000,
+        help="The sample rate of the input sound file",
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=Path,
+        default="data/lang_bpe_500",
+        help="The lang dir containing word table and LG graph",
+    )
+
+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=100,
+        help="""Number of paths for n-best based decoding method.
+        Used only when "method" is one of the following values:
+        nbest, nbest-rescoring, and nbest-oracle
+        """,
+    )
+
+    parser.add_argument(
+        "--nbest-scale",
+        type=float,
+        default=1.2,
+        help="""The scale to be applied to `lattice.scores`.
+        It's needed if you use any kinds of n-best based rescoring.
+        Used only when "method" is one of the following values:
+        nbest, nbest-rescoring, and nbest-oracle
+        A smaller value results in more unique paths.
+        """,
+    )
+
+    parser.add_argument(
+        "--ngram-lm-scale",
+        type=float,
+        default=0.1,
+        help="""
+        Used when method is nbest-rescoring-LG, nbest-rescoring-3-gram,
+        and nbest-rescoring-4-gram.
+        It specifies the scale for n-gram LM scores.
+        (Note: You need to tune it on a dataset.)
+        """,
+    )
+
+    parser.add_argument(
+        "--hp-scale",
+        type=float,
+        default=1.0,
+        help="""The scale to be applied to `ctc_topo_P.scores`.
+        """,
+    )
+
+    parser.add_argument(
+        "sound_files",
+        type=str,
+        nargs="+",
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def read_sound_files(
+    filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+    """Read a list of sound files into a list 1-D float32 torch tensors.
+    Args:
+      filenames:
+        A list of sound filenames.
+      expected_sample_rate:
+        The expected sample rate of the sound files.
+    Returns:
+      Return a list of 1-D float32 torch tensors.
+    """
+    ans = []
+    for f in filenames:
+        wave, sample_rate = torchaudio.load(f)
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        # We use only the first channel
+        ans.append(wave[0])
+    return ans
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+
+    params = get_params()
+    # add decoding params
+    params.update(get_decoding_params())
+    params.update(vars(args))
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    #  is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("")
+    params.unk_id = sp.piece_to_id("")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(f"{params}")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    logging.info("Creating model")
+    model = get_ctc_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    model.load_state_dict(checkpoint["model"], strict=False)
+    model.to(device)
+    model.eval()
+    model.device = device
+
+    logging.info("Constructing Fbank computer")
+    opts = kaldifeat.FbankOptions()
+    opts.device = device
+    opts.frame_opts.dither = 0
+    opts.frame_opts.snip_edges = False
+    opts.frame_opts.samp_freq = params.sample_rate
+    opts.mel_opts.num_bins = params.feature_dim
+
+    fbank = kaldifeat.Fbank(opts)
+
+    logging.info(f"Reading sound files: {params.sound_files}")
+    waves = read_sound_files(
+        filenames=params.sound_files, expected_sample_rate=params.sample_rate
+    )
+    waves = [w.to(device) for w in waves]
+
+    logging.info("Decoding started")
+    features = fbank(waves)
+    feature_lengths = [f.size(0) for f in features]
+
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    feature_lengths = torch.tensor(feature_lengths, device=device)
+
+    bpe_model = spm.SentencePieceProcessor()
+    bpe_model.load(str(params.lang_dir / "bpe.model"))
+    mmi_graph_compiler = MmiTrainingGraphCompiler(
+        params.lang_dir,
+        uniq_filename="lexicon.txt",
+        device=device,
+        oov="",
+        sos_id=1,
+        eos_id=1,
+    )
+    HP = mmi_graph_compiler.ctc_topo_P
+    HP.scores *= params.hp_scale
+    if not hasattr(HP, "lm_scores"):
+        HP.lm_scores = HP.scores.clone()
+
+    method = params.method
+    assert method in (
+        "1best",
+        "nbest",
+        "nbest-rescoring-LG",  # word-level 3-gram lm
+        "nbest-rescoring-3-gram",  # token-level 3-gram lm
+        "nbest-rescoring-4-gram",  # token-level 4-gram lm
+    )
+    # loading language model for rescoring
+    LM = None
+    if method == "nbest-rescoring-LG":
+        lg_filename = params.lang_dir / "LG.pt"
+        logging.info(f"Loading {lg_filename}")
+        LG = k2.Fsa.from_dict(torch.load(lg_filename, map_location=device))
+        LG = k2.Fsa.from_fsas([LG]).to(device)
+        LG.lm_scores = LG.scores.clone()
+        LM = LG
+    elif method in ["nbest-rescoring-3-gram", "nbest-rescoring-4-gram"]:
+        order = method[-6]
+        assert order in ("3", "4")
+        order = int(order)
+        logging.info(f"Loading pre-compiled {order}gram.pt")
+        d = torch.load(params.lang_dir / f"{order}gram.pt", map_location=device)
+        G = k2.Fsa.from_dict(d)
+        G.lm_scores = G.scores.clone()
+        LM = G
+
+    # Encoder forward
+    nnet_output, encoder_out_lens = model(x=features, x_lens=feature_lengths)
+
+    batch_size = nnet_output.shape[0]
+    supervision_segments = torch.tensor(
+        [
+            [i, 0, feature_lengths[i] // params.subsampling_factor]
+            for i in range(batch_size)
+        ],
+        dtype=torch.int32,
+    )
+
+    lattice = get_lattice(
+        nnet_output=nnet_output,
+        decoding_graph=HP,
+        supervision_segments=supervision_segments,
+        search_beam=params.search_beam,
+        output_beam=params.output_beam,
+        min_active_states=params.min_active_states,
+        max_active_states=params.max_active_states,
+        subsampling_factor=params.subsampling_factor,
+    )
+
+    if method in ["1best", "nbest"]:
+        if method == "1best":
+            best_path = one_best_decoding(
+                lattice=lattice, use_double_scores=params.use_double_scores
+            )
+        else:
+            best_path = nbest_decoding(
+                lattice=lattice,
+                num_paths=params.num_paths,
+                use_double_scores=params.use_double_scores,
+                nbest_scale=params.nbest_scale,
+            )
+    else:
+        best_path_dict = nbest_rescore_with_LM(
+            lattice=lattice,
+            LM=LM,
+            num_paths=params.num_paths,
+            lm_scale_list=[params.ngram_lm_scale],
+            nbest_scale=params.nbest_scale,
+        )
+        best_path = next(iter(best_path_dict.values()))
+
+    # Note: `best_path.aux_labels` contains token IDs, not word IDs
+    # since we are using HP, not HLG here.
+    #
+    # token_ids is a lit-of-list of IDs
+    token_ids = get_texts(best_path)
+    # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
+    hyps = bpe_model.decode(token_ids)
+    # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
+    hyps = [s.split() for s in hyps]
+    s = "\n"
+    for filename, hyp in zip(params.sound_files, hyps):
+        words = " ".join(hyp)
+        s += f"{filename}:\n{words}\n\n"
+    logging.info(s)
+
+    logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/librispeech/ASR/zipformer_mmi/scaling.py b/egs/librispeech/ASR/zipformer_mmi/scaling.py
new file mode 120000
index 000000000..2428b74b9
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_mmi/scaling.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless7/scaling.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/zipformer_mmi/scaling_converter.py b/egs/librispeech/ASR/zipformer_mmi/scaling_converter.py
new file mode 120000
index 000000000..b8b8ba432
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_mmi/scaling_converter.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless7/scaling_converter.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/zipformer_mmi/test_model.py b/egs/librispeech/ASR/zipformer_mmi/test_model.py
new file mode 100755
index 000000000..7782845f4
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_mmi/test_model.py
@@ -0,0 +1,57 @@
+#!/usr/bin/env python3
+# Copyright    2022  Xiaomi Corp.        (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+To run this file, do:
+
+    cd icefall/egs/librispeech/ASR
+    python ./zipformer_mmi/test_model.py
+"""
+
+import torch
+from train import get_ctc_model, get_params
+
+
+def test_model():
+    params = get_params()
+    params.vocab_size = 500
+    params.num_encoder_layers = "2,4,3,2,4"
+    #  params.feedforward_dims = "1024,1024,1536,1536,1024"
+    params.feedforward_dims = "1024,1024,2048,2048,1024"
+    params.nhead = "8,8,8,8,8"
+    params.encoder_dims = "384,384,384,384,384"
+    params.attention_dims = "192,192,192,192,192"
+    params.encoder_unmasked_dims = "256,256,256,256,256"
+    params.zipformer_downsampling_factors = "1,2,4,8,2"
+    params.cnn_module_kernels = "31,31,31,31,31"
+    model = get_ctc_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    print(f"Number of model parameters: {num_param}")
+
+    features = torch.randn(2, 100, 80)
+    feature_lengths = torch.full((2,), 100)
+    model(x=features, x_lens=feature_lengths)
+
+
+def main():
+    test_model()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/librispeech/ASR/zipformer_mmi/train.py b/egs/librispeech/ASR/zipformer_mmi/train.py
new file mode 100755
index 000000000..b2784e47c
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_mmi/train.py
@@ -0,0 +1,1198 @@
+#!/usr/bin/env python3
+# Copyright    2021-2022  Xiaomi Corp.        (authors: Fangjun Kuang,
+#                                                       Wei Kang,
+#                                                       Mingshuang Luo,)
+#                                                       Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./zipformer_mmi/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --exp-dir zipformer_mmi/exp \
+  --full-libri 1 \
+  --max-duration 300
+
+# For mix precision training:
+
+./zipformer_mmi/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --use-fp16 1 \
+  --exp-dir zipformer_mmi/exp \
+  --full-libri 1 \
+  --max-duration 500
+
+"""
+
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import CTCModel
+from optim import Eden, ScaledAdam
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+from zipformer import Zipformer
+
+from icefall import diagnostics
+from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+    save_checkpoint_with_global_batch_idx,
+    update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.lexicon import Lexicon, UniqLexicon
+from icefall.mmi import LFMMILoss
+from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    encode_supervisions,
+    setup_logger,
+    str2bool,
+)
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+    if isinstance(model, DDP):
+        # get underlying nn.Module
+        model = model.module
+    for module in model.modules():
+        if hasattr(module, "batch_count"):
+            module.batch_count = batch_count
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+    parser.add_argument(
+        "--num-encoder-layers",
+        type=str,
+        default="2,4,3,2,4",
+        help="Number of zipformer encoder layers, comma separated.",
+    )
+
+    parser.add_argument(
+        "--feedforward-dims",
+        type=str,
+        default="1024,1024,2048,2048,1024",
+        help="Feedforward dimension of the zipformer encoder layers, comma separated.",
+    )
+
+    parser.add_argument(
+        "--nhead",
+        type=str,
+        default="8,8,8,8,8",
+        help="Number of attention heads in the zipformer encoder layers.",
+    )
+
+    parser.add_argument(
+        "--encoder-dims",
+        type=str,
+        default="384,384,384,384,384",
+        help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated",
+    )
+
+    parser.add_argument(
+        "--attention-dims",
+        type=str,
+        default="192,192,192,192,192",
+        help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated;
+        not the same as embedding dimension.""",
+    )
+
+    parser.add_argument(
+        "--encoder-unmasked-dims",
+        type=str,
+        default="256,256,256,256,256",
+        help="Unmasked dimensions in the encoders, relates to augmentation during training.  "
+        "Must be <= each of encoder_dims.  Empirically, less than 256 seems to make performance "
+        " worse.",
+    )
+
+    parser.add_argument(
+        "--zipformer-downsampling-factors",
+        type=str,
+        default="1,2,4,8,2",
+        help="Downsampling factor for each stack of encoder layers.",
+    )
+
+    parser.add_argument(
+        "--cnn-module-kernels",
+        type=str,
+        default="31,31,31,31,31",
+        help="Sizes of kernels in convolution modules",
+    )
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--world-size",
+        type=int,
+        default=1,
+        help="Number of GPUs for DDP training.",
+    )
+
+    parser.add_argument(
+        "--master-port",
+        type=int,
+        default=12354,
+        help="Master port to use for DDP training.",
+    )
+
+    parser.add_argument(
+        "--tensorboard",
+        type=str2bool,
+        default=True,
+        help="Should various information be logged in tensorboard.",
+    )
+
+    parser.add_argument(
+        "--num-epochs",
+        type=int,
+        default=30,
+        help="Number of epochs to train.",
+    )
+
+    parser.add_argument(
+        "--start-epoch",
+        type=int,
+        default=1,
+        help="""Resume training from this epoch. It should be positive.
+        If larger than 1, it will load checkpoint from
+        exp-dir/epoch-{start_epoch-1}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--start-batch",
+        type=int,
+        default=0,
+        help="""If positive, --start-epoch is ignored and
+        it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="zipformer_mmi/exp",
+        help="""The experiment dir.
+        It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_bpe_500",
+        help="""The lang dir
+        It contains language related input files such as
+        "lexicon.txt"
+        """,
+    )
+
+    parser.add_argument(
+        "--base-lr", type=float, default=0.05, help="The base learning rate."
+    )
+
+    parser.add_argument(
+        "--lr-batches",
+        type=float,
+        default=5000,
+        help="""Number of steps that affects how rapidly the learning rate
+        decreases. We suggest not to change this.""",
+    )
+
+    parser.add_argument(
+        "--lr-epochs",
+        type=float,
+        default=3.5,
+        help="""Number of epochs that affects how rapidly the learning rate decreases.
+        """,
+    )
+
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="The seed for random generators intended for reproducibility",
+    )
+
+    parser.add_argument(
+        "--use-pruned-intersect",
+        type=str2bool,
+        default=False,
+        help="""Whether to use `intersect_dense_pruned` to get denominator
+        lattice.""",
+    )
+
+    parser.add_argument(
+        "--print-diagnostics",
+        type=str2bool,
+        default=False,
+        help="Accumulate stats on activations, print them and exit.",
+    )
+
+    parser.add_argument(
+        "--inf-check",
+        type=str2bool,
+        default=False,
+        help="Add hooks to check for infinite module outputs and gradients.",
+    )
+
+    parser.add_argument(
+        "--save-every-n",
+        type=int,
+        default=2000,
+        help="""Save checkpoint after processing this number of batches"
+        periodically. We save checkpoint to exp-dir/ whenever
+        params.batch_idx_train % save_every_n == 0. The checkpoint filename
+        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+        Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+        end of each epoch where `xxx` is the epoch number counting from 0.
+        """,
+    )
+
+    parser.add_argument(
+        "--keep-last-k",
+        type=int,
+        default=30,
+        help="""Only keep this number of checkpoints on disk.
+        For instance, if it is 3, there are only 3 checkpoints
+        in the exp-dir with filenames `checkpoint-xxx.pt`.
+        It does not affect checkpoints with name `epoch-xxx.pt`.
+        """,
+    )
+
+    parser.add_argument(
+        "--average-period",
+        type=int,
+        default=200,
+        help="""Update the averaged model, namely `model_avg`, after processing
+        this number of batches. `model_avg` is a separate version of model,
+        in which each floating-point parameter is the average of all the
+        parameters from the start of training. Each time we take the average,
+        we do: `model_avg = model * (average_period / batch_idx_train) +
+            model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+        """,
+    )
+
+    parser.add_argument(
+        "--use-fp16",
+        type=str2bool,
+        default=False,
+        help="Whether to use half precision training.",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def get_params() -> AttributeDict:
+    """Return a dict containing training parameters.
+
+    All training related parameters that are not passed from the commandline
+    are saved in the variable `params`.
+
+    Commandline options are merged into `params` after they are parsed, so
+    you can also access them via `params`.
+
+    Explanation of options saved in `params`:
+
+        - best_train_loss: Best training loss so far. It is used to select
+                           the model that has the lowest training loss. It is
+                           updated during the training.
+
+        - best_valid_loss: Best validation loss so far. It is used to select
+                           the model that has the lowest validation loss. It is
+                           updated during the training.
+
+        - best_train_epoch: It is the epoch that has the best training loss.
+
+        - best_valid_epoch: It is the epoch that has the best validation loss.
+
+        - batch_idx_train: Used to writing statistics to tensorboard. It
+                           contains number of batches trained so far across
+                           epochs.
+
+        - log_interval:  Print training loss if batch_idx % log_interval` is 0
+
+        - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+        - valid_interval:  Run validation if batch_idx % valid_interval is 0
+
+        - feature_dim: The model input dim. It has to match the one used
+                       in computing features.
+
+        - subsampling_factor:  The subsampling factor for the model.
+
+        - encoder_dim: Hidden dim for multi-head attention model.
+
+        - num_decoder_layers: Number of decoder layer of transformer decoder.
+
+        - warm_step: The warmup period that dictates the decay of the
+              scale on "simple" (un-pruned) loss.
+    """
+    params = AttributeDict(
+        {
+            "best_train_loss": float("inf"),
+            "best_valid_loss": float("inf"),
+            "best_train_epoch": -1,
+            "best_valid_epoch": -1,
+            "batch_idx_train": 0,
+            "log_interval": 50,
+            "reset_interval": 200,
+            "valid_interval": 3000,  # For the 100h subset, use 800
+            # parameters for zipformer
+            "feature_dim": 80,
+            "subsampling_factor": 4,  # not passed in, this is fixed.
+            # parameters for mmi loss
+            "mmi_beam_size": 6,
+            "den_scale": 1.0,
+            # parameters for mmi loss
+            "ctc_beam_size": 10,
+            "reduction": "sum",
+            "use_double_scores": True,
+            "warm_step": 2000,
+            "env_info": get_env_info(),
+        }
+    )
+
+    return params
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+    # TODO: We can add an option to switch between Zipformer and Transformer
+    def to_int_tuple(s: str):
+        return tuple(map(int, s.split(",")))
+
+    encoder = Zipformer(
+        num_features=params.feature_dim,
+        output_downsampling_factor=2,
+        zipformer_downsampling_factors=to_int_tuple(
+            params.zipformer_downsampling_factors
+        ),
+        encoder_dims=to_int_tuple(params.encoder_dims),
+        attention_dim=to_int_tuple(params.attention_dims),
+        encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims),
+        nhead=to_int_tuple(params.nhead),
+        feedforward_dim=to_int_tuple(params.feedforward_dims),
+        cnn_module_kernels=to_int_tuple(params.cnn_module_kernels),
+        num_encoder_layers=to_int_tuple(params.num_encoder_layers),
+    )
+    return encoder
+
+
+def get_ctc_model(params: AttributeDict) -> nn.Module:
+    encoder = get_encoder_model(params)
+
+    model = CTCModel(
+        encoder=encoder,
+        encoder_dim=int(params.encoder_dims.split(",")[-1]),
+        vocab_size=params.vocab_size,
+    )
+    return model
+
+
+def load_checkpoint_if_available(
+    params: AttributeDict,
+    model: nn.Module,
+    model_avg: nn.Module = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+    """Load checkpoint from file.
+
+    If params.start_batch is positive, it will load the checkpoint from
+    `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+    params.start_epoch is larger than 1, it will load the checkpoint from
+    `params.start_epoch - 1`.
+
+    Apart from loading state dict for `model` and `optimizer` it also updates
+    `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+    and `best_valid_loss` in `params`.
+
+    Args:
+      params:
+        The return value of :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer that we are using.
+      scheduler:
+        The scheduler that we are using.
+    Returns:
+      Return a dict containing previously saved training info.
+    """
+    if params.start_batch > 0:
+        filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+    elif params.start_epoch > 1:
+        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+    else:
+        return None
+
+    assert filename.is_file(), f"{filename} does not exist!"
+
+    saved_params = load_checkpoint(
+        filename,
+        model=model,
+        model_avg=model_avg,
+        optimizer=optimizer,
+        scheduler=scheduler,
+    )
+
+    keys = [
+        "best_train_epoch",
+        "best_valid_epoch",
+        "batch_idx_train",
+        "best_train_loss",
+        "best_valid_loss",
+    ]
+    for k in keys:
+        params[k] = saved_params[k]
+
+    if params.start_batch > 0:
+        if "cur_epoch" in saved_params:
+            params["start_epoch"] = saved_params["cur_epoch"]
+
+        if "cur_batch_idx" in saved_params:
+            params["cur_batch_idx"] = saved_params["cur_batch_idx"]
+
+    return saved_params
+
+
+def save_checkpoint(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    model_avg: Optional[nn.Module] = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+    sampler: Optional[CutSampler] = None,
+    scaler: Optional[GradScaler] = None,
+    rank: int = 0,
+) -> None:
+    """Save model, optimizer, scheduler and training stats to file.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer used in the training.
+      sampler:
+       The sampler for the training dataset.
+      scaler:
+        The scaler used for mix precision training.
+    """
+    if rank != 0:
+        return
+    filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+    save_checkpoint_impl(
+        filename=filename,
+        model=model,
+        model_avg=model_avg,
+        params=params,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        sampler=sampler,
+        scaler=scaler,
+        rank=rank,
+    )
+
+    if params.best_train_epoch == params.cur_epoch:
+        best_train_filename = params.exp_dir / "best-train-loss.pt"
+        copyfile(src=filename, dst=best_train_filename)
+
+    if params.best_valid_epoch == params.cur_epoch:
+        best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+        copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    ctc_graph_compiler: BpeCtcTrainingGraphCompiler,
+    mmi_graph_compiler: MmiTrainingGraphCompiler,
+    batch: dict,
+    is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+    """
+    Compute ctc loss given the model and its inputs.
+
+    Args:
+      params:
+        Parameters for training. See :func:`get_params`.
+      model:
+        The model for training. It is an instance of Zipformer in our case.
+      graph_compiler:
+        It is used to build a decoding graph from a ctc topo and training
+        transcript. The training transcript is contained in the given `batch`,
+        while the ctc topo is built when this compiler is instantiated.
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      is_training:
+        True for training. False for validation. When it is True, this
+        function enables autograd during computation; when it is False, it
+        disables autograd.
+    """
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+    feature = batch["inputs"]
+    # at entry, feature is (N, T, C)
+    assert feature.ndim == 3
+    feature = feature.to(device)
+
+    supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)
+
+    batch_idx_train = params.batch_idx_train
+    warm_step = params.warm_step
+
+    with torch.set_grad_enabled(is_training):
+        nnet_output, encoder_out_lens = model(x=feature, x_lens=feature_lens)
+
+    # NOTE: We need `encode_supervisions` to sort sequences with
+    # different duration in decreasing order, required by
+    # `k2.intersect_dense` called in `LFMMILoss.forward()`
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        supervision_segments, texts = encode_supervisions(
+            supervisions, subsampling_factor=params.subsampling_factor
+        )
+
+    dense_fsa_vec = k2.DenseFsaVec(
+        nnet_output,
+        supervision_segments,
+        allow_truncate=params.subsampling_factor - 1,
+    )
+
+    info = MetricsTracker()
+    if batch_idx_train < warm_step:
+        # Training with ctc loss
+        # Works with a BPE model
+        token_ids = ctc_graph_compiler.texts_to_ids(texts)
+        decoding_graph = ctc_graph_compiler.compile(token_ids)
+        loss = k2.ctc_loss(
+            decoding_graph=decoding_graph,
+            dense_fsa_vec=dense_fsa_vec,
+            output_beam=params.ctc_beam_size,
+            reduction=params.reduction,
+            use_double_scores=params.use_double_scores,
+        )
+        info["ctc_loss"] = loss.detach().cpu().item()
+        info["mmi_loss"] = 0
+    else:
+        # Training with mmi loss
+        loss_fn = LFMMILoss(
+            graph_compiler=mmi_graph_compiler,
+            use_pruned_intersect=params.use_pruned_intersect,
+            den_scale=params.den_scale,
+            beam_size=params.mmi_beam_size,
+        )
+        loss = loss_fn(dense_fsa_vec=dense_fsa_vec, texts=texts)
+        info["ctc_loss"] = 0
+        info["mmi_loss"] = loss.detach().cpu().item()
+
+    assert loss.requires_grad == is_training
+
+    info["frames"] = encoder_out_lens.sum().cpu().item()
+    # Note: We use reduction=sum while computing the loss.
+    info["loss"] = loss.detach().cpu().item()
+
+    return loss, info
+
+
+def compute_validation_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    ctc_graph_compiler: BpeCtcTrainingGraphCompiler,
+    mmi_graph_compiler: MmiTrainingGraphCompiler,
+    valid_dl: torch.utils.data.DataLoader,
+    world_size: int = 1,
+) -> MetricsTracker:
+    """Run the validation process."""
+    model.eval()
+
+    tot_loss = MetricsTracker()
+
+    for batch_idx, batch in enumerate(valid_dl):
+        loss, loss_info = compute_loss(
+            params=params,
+            model=model,
+            ctc_graph_compiler=ctc_graph_compiler,
+            mmi_graph_compiler=mmi_graph_compiler,
+            batch=batch,
+            is_training=False,
+        )
+        assert loss.requires_grad is False
+        tot_loss = tot_loss + loss_info
+
+    if world_size > 1:
+        tot_loss.reduce(loss.device)
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    if loss_value < params.best_valid_loss:
+        params.best_valid_epoch = params.cur_epoch
+        params.best_valid_loss = loss_value
+
+    return tot_loss
+
+
+def train_one_epoch(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    optimizer: torch.optim.Optimizer,
+    scheduler: LRSchedulerType,
+    ctc_graph_compiler: BpeCtcTrainingGraphCompiler,
+    mmi_graph_compiler: MmiTrainingGraphCompiler,
+    train_dl: torch.utils.data.DataLoader,
+    valid_dl: torch.utils.data.DataLoader,
+    scaler: GradScaler,
+    model_avg: Optional[nn.Module] = None,
+    tb_writer: Optional[SummaryWriter] = None,
+    world_size: int = 1,
+    rank: int = 0,
+) -> None:
+    """Train the model for one epoch.
+
+    The training loss from the mean of all frames is saved in
+    `params.train_loss`. It runs the validation process every
+    `params.valid_interval` batches.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The model for training.
+      optimizer:
+        The optimizer we are using.
+      scheduler:
+        The learning rate scheduler, we call step() every step.
+      graph_compiler:
+        It is used to convert transcripts to FSAs.
+      train_dl:
+        Dataloader for the training dataset.
+      valid_dl:
+        Dataloader for the validation dataset.
+      scaler:
+        The scaler used for mix precision training.
+      model_avg:
+        The stored model averaged from the start of training.
+      tb_writer:
+        Writer to write log messages to tensorboard.
+      world_size:
+        Number of nodes in DDP training. If it is 1, DDP is disabled.
+      rank:
+        The rank of the node in DDP training. If no DDP is used, it should
+        be set to 0.
+    """
+    model.train()
+
+    tot_loss = MetricsTracker()
+
+    cur_batch_idx = params.get("cur_batch_idx", 0)
+
+    for batch_idx, batch in enumerate(train_dl):
+        if batch_idx < cur_batch_idx:
+            continue
+        cur_batch_idx = batch_idx
+
+        params.batch_idx_train += 1
+        batch_size = len(batch["supervisions"]["text"])
+
+        try:
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                loss, loss_info = compute_loss(
+                    params=params,
+                    model=model,
+                    ctc_graph_compiler=ctc_graph_compiler,
+                    mmi_graph_compiler=mmi_graph_compiler,
+                    batch=batch,
+                    is_training=True,
+                )
+            # summary stats
+            tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+            # NOTE: We use reduction==sum and loss is computed over utterances
+            # in the batch and there is no normalization to it so far.
+            scaler.scale(loss).backward()
+            set_batch_count(model, params.batch_idx_train)
+            scheduler.step_batch(params.batch_idx_train)
+
+            scaler.step(optimizer)
+            scaler.update()
+            optimizer.zero_grad()
+        except:  # noqa
+            display_and_save_batch(
+                batch, params=params, graph_compiler=mmi_graph_compiler
+            )
+            raise
+
+        if params.print_diagnostics and batch_idx == 5:
+            return
+
+        if (
+            rank == 0
+            and params.batch_idx_train > 0
+            and params.batch_idx_train % params.average_period == 0
+        ):
+            update_averaged_model(
+                params=params,
+                model_cur=model,
+                model_avg=model_avg,
+            )
+
+        if (
+            params.batch_idx_train > 0
+            and params.batch_idx_train % params.save_every_n == 0
+        ):
+            params.cur_batch_idx = batch_idx
+            save_checkpoint_with_global_batch_idx(
+                out_dir=params.exp_dir,
+                global_batch_idx=params.batch_idx_train,
+                model=model,
+                model_avg=model_avg,
+                params=params,
+                optimizer=optimizer,
+                scheduler=scheduler,
+                sampler=train_dl.sampler,
+                scaler=scaler,
+                rank=rank,
+            )
+            del params.cur_batch_idx
+            remove_checkpoints(
+                out_dir=params.exp_dir,
+                topk=params.keep_last_k,
+                rank=rank,
+            )
+
+        if batch_idx % 100 == 0 and params.use_fp16:
+            # If the grad scale was less than 1, try increasing it.    The _growth_interval
+            # of the grad scaler is configurable, but we can't configure it to have different
+            # behavior depending on the current grad scale.
+            cur_grad_scale = scaler._scale.item()
+            if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
+                scaler.update(cur_grad_scale * 2.0)
+            if cur_grad_scale < 0.01:
+                logging.warning(f"Grad scale is small: {cur_grad_scale}")
+            if cur_grad_scale < 1.0e-05:
+                raise RuntimeError(
+                    f"grad_scale is too small, exiting: {cur_grad_scale}"
+                )
+
+        if batch_idx % params.log_interval == 0:
+            cur_lr = scheduler.get_last_lr()[0]
+            cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+            logging.info(
+                f"Epoch {params.cur_epoch}, "
+                f"batch {batch_idx}, loss[{loss_info}], "
+                f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+                f"lr: {cur_lr:.2e}, "
+                + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+            )
+
+            if tb_writer is not None:
+                tb_writer.add_scalar(
+                    "train/learning_rate", cur_lr, params.batch_idx_train
+                )
+
+                loss_info.write_summary(
+                    tb_writer, "train/current_", params.batch_idx_train
+                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                if params.use_fp16:
+                    tb_writer.add_scalar(
+                        "train/grad_scale",
+                        cur_grad_scale,
+                        params.batch_idx_train,
+                    )
+
+        if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+            logging.info("Computing validation loss")
+            valid_info = compute_validation_loss(
+                params=params,
+                model=model,
+                ctc_graph_compiler=ctc_graph_compiler,
+                mmi_graph_compiler=mmi_graph_compiler,
+                valid_dl=valid_dl,
+                world_size=world_size,
+            )
+            model.train()
+            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+            logging.info(
+                f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+            )
+            if tb_writer is not None:
+                valid_info.write_summary(
+                    tb_writer, "train/valid_", params.batch_idx_train
+                )
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    params.train_loss = loss_value
+    if params.train_loss < params.best_train_loss:
+        params.best_train_epoch = params.cur_epoch
+        params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+    """
+    Args:
+      rank:
+        It is a value between 0 and `world_size-1`, which is
+        passed automatically by `mp.spawn()` in :func:`main`.
+        The node with rank 0 is responsible for saving checkpoint.
+      world_size:
+        Number of GPUs for DDP training.
+      args:
+        The return value of get_parser().parse_args()
+    """
+    params = get_params()
+    params.update(vars(args))
+    if params.full_libri is False:
+        params.valid_interval = 1600
+
+    fix_random_seed(params.seed)
+    if world_size > 1:
+        setup_dist(rank, world_size, params.master_port)
+
+    setup_logger(f"{params.exp_dir}/log/log-train")
+    logging.info("Training started")
+
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
+    lexicon = Lexicon(params.lang_dir)
+    max_token_id = max(lexicon.tokens)
+    num_classes = max_token_id + 1  # +1 for the blank
+    params.vocab_size = num_classes
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+    logging.info(f"Device: {device}")
+
+    assert "lang_bpe" in str(params.lang_dir)
+    ctc_graph_compiler = BpeCtcTrainingGraphCompiler(
+        params.lang_dir,
+        device=device,
+        sos_token="",
+        eos_token="",
+    )
+    mmi_graph_compiler = MmiTrainingGraphCompiler(
+        params.lang_dir,
+        uniq_filename="lexicon.txt",
+        device=device,
+        oov="",
+        sos_id=1,
+        eos_id=1,
+    )
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_ctc_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    assert params.save_every_n >= params.average_period
+    model_avg: Optional[nn.Module] = None
+    if rank == 0:
+        # model_avg is only used with rank 0
+        model_avg = copy.deepcopy(model).to(torch.float64)
+
+    assert params.start_epoch > 0, params.start_epoch
+    checkpoints = load_checkpoint_if_available(
+        params=params, model=model, model_avg=model_avg
+    )
+
+    model.to(device)
+    if world_size > 1:
+        logging.info("Using DDP")
+        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+    parameters_names = []
+    parameters_names.append(
+        [name_param_pair[0] for name_param_pair in model.named_parameters()]
+    )
+    optimizer = ScaledAdam(
+        model.parameters(),
+        lr=params.base_lr,
+        clipping_scale=2.0,
+        parameters_names=parameters_names,
+    )
+
+    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+    if checkpoints and "optimizer" in checkpoints:
+        logging.info("Loading optimizer state dict")
+        optimizer.load_state_dict(checkpoints["optimizer"])
+
+    if (
+        checkpoints
+        and "scheduler" in checkpoints
+        and checkpoints["scheduler"] is not None
+    ):
+        logging.info("Loading scheduler state dict")
+        scheduler.load_state_dict(checkpoints["scheduler"])
+
+    if params.print_diagnostics:
+        opts = diagnostics.TensorDiagnosticOptions(
+            2**22
+        )  # allow 4 megabytes per sub-module
+        diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+    if params.inf_check:
+        register_inf_check_hooks(model)
+
+    librispeech = LibriSpeechAsrDataModule(args)
+
+    # train_cuts = librispeech.train_clean_100_cuts()
+    if params.full_libri:
+        # train_cuts += librispeech.train_clean_360_cuts()
+        # train_cuts += librispeech.train_other_500_cuts()
+        train_cuts = librispeech.train_all_shuf_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
+
+    def remove_short_and_long_utt(c: Cut):
+        # Keep only utterances with duration between 1 second and 20 seconds
+        #
+        # Caution: There is a reason to select 20.0 here. Please see
+        # ../local/display_manifest_statistics.py
+        #
+        # You should use ../local/display_manifest_statistics.py to get
+        # an utterance duration distribution for your dataset to select
+        # the threshold
+        return 1.0 <= c.duration <= 20.0
+
+    train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+    if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+        # We only load the sampler's state dict when it loads a checkpoint
+        # saved in the middle of an epoch
+        sampler_state_dict = checkpoints["sampler"]
+    else:
+        sampler_state_dict = None
+
+    train_dl = librispeech.train_dataloaders(
+        train_cuts, sampler_state_dict=sampler_state_dict
+    )
+
+    valid_cuts = librispeech.dev_clean_cuts()
+    valid_cuts += librispeech.dev_other_cuts()
+    valid_dl = librispeech.valid_dataloaders(valid_cuts)
+
+    if not params.print_diagnostics:
+        scan_pessimistic_batches_for_oom(
+            model=model,
+            train_dl=train_dl,
+            optimizer=optimizer,
+            ctc_graph_compiler=ctc_graph_compiler,
+            mmi_graph_compiler=mmi_graph_compiler,
+            params=params,
+        )
+
+    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+    if checkpoints and "grad_scaler" in checkpoints:
+        logging.info("Loading grad scaler state dict")
+        scaler.load_state_dict(checkpoints["grad_scaler"])
+
+    for epoch in range(params.start_epoch, params.num_epochs + 1):
+        scheduler.step_epoch(epoch - 1)
+        fix_random_seed(params.seed + epoch - 1)
+        train_dl.sampler.set_epoch(epoch - 1)
+
+        if tb_writer is not None:
+            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+        params.cur_epoch = epoch
+
+        train_one_epoch(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            ctc_graph_compiler=ctc_graph_compiler,
+            mmi_graph_compiler=mmi_graph_compiler,
+            train_dl=train_dl,
+            valid_dl=valid_dl,
+            scaler=scaler,
+            tb_writer=tb_writer,
+            world_size=world_size,
+            rank=rank,
+        )
+
+        if params.print_diagnostics:
+            diagnostic.print_diagnostics()
+            break
+
+        save_checkpoint(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            sampler=train_dl.sampler,
+            scaler=scaler,
+            rank=rank,
+        )
+
+    logging.info("Done!")
+
+    if world_size > 1:
+        torch.distributed.barrier()
+        cleanup_dist()
+
+
+def display_and_save_batch(
+    batch: dict,
+    params: AttributeDict,
+    graph_compiler: MmiTrainingGraphCompiler,
+) -> None:
+    """Display the batch statistics and save the batch into disk.
+
+    Args:
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      params:
+        Parameters for training. See :func:`get_params`.
+      sp:
+        The BPE model.
+    """
+    from lhotse.utils import uuid4
+
+    filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+    logging.info(f"Saving batch to {filename}")
+    torch.save(batch, filename)
+
+    supervisions = batch["supervisions"]
+    features = batch["inputs"]
+
+    logging.info(f"features shape: {features.shape}")
+    y = graph_compiler.texts_to_ids(supervisions["text"])
+    num_tokens = sum(len(i) for i in y)
+    logging.info(f"num tokens: {num_tokens}")
+
+
+def scan_pessimistic_batches_for_oom(
+    model: Union[nn.Module, DDP],
+    train_dl: torch.utils.data.DataLoader,
+    optimizer: torch.optim.Optimizer,
+    ctc_graph_compiler: BpeCtcTrainingGraphCompiler,
+    mmi_graph_compiler: MmiTrainingGraphCompiler,
+    params: AttributeDict,
+):
+    from lhotse.dataset import find_pessimistic_batches
+
+    logging.info(
+        "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+    )
+    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+    for criterion, cuts in batches.items():
+        batch = train_dl.dataset[cuts]
+        try:
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                loss, _ = compute_loss(
+                    params=params,
+                    model=model,
+                    ctc_graph_compiler=ctc_graph_compiler,
+                    mmi_graph_compiler=mmi_graph_compiler,
+                    batch=batch,
+                    is_training=True,
+                )
+            loss.backward()
+            optimizer.zero_grad()
+        except Exception as e:
+            if "CUDA out of memory" in str(e):
+                logging.error(
+                    "Your GPU ran out of memory with the current "
+                    "max_duration setting. We recommend decreasing "
+                    "max_duration and trying again.\n"
+                    f"Failing criterion: {criterion} "
+                    f"(={crit_values[criterion]}) ..."
+                )
+            display_and_save_batch(
+                batch, params=params, graph_compiler=mmi_graph_compiler
+            )
+            raise
+        logging.info(
+            f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+        )
+
+
+def main():
+    parser = get_parser()
+    LibriSpeechAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    world_size = args.world_size
+    assert world_size >= 1
+    if world_size > 1:
+        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+    else:
+        run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/librispeech/ASR/zipformer_mmi/zipformer.py b/egs/librispeech/ASR/zipformer_mmi/zipformer.py
new file mode 120000
index 000000000..79b076556
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_mmi/zipformer.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless7/zipformer.py
\ No newline at end of file
diff --git a/icefall/decode.py b/icefall/decode.py
index e4c614c4e..68e490c5e 100644
--- a/icefall/decode.py
+++ b/icefall/decode.py
@@ -717,6 +717,107 @@ def rescore_with_n_best_list(
     return ans
 
 
+def nbest_rescore_with_LM(
+    lattice: k2.Fsa,
+    LM: k2.Fsa,
+    num_paths: int,
+    lm_scale_list: List[float],
+    nbest_scale: float = 1.0,
+    use_double_scores: bool = True,
+) -> Dict[str, k2.Fsa]:
+    """Rescore an n-best list with an n-gram LM.
+    The path with the maximum score is used as the decoding output.
+
+    Args:
+      lattice:
+        An FsaVec with axes [utt][state][arc]. It must have the following
+        attributes: ``aux_labels`` and ``lm_scores``. They are both token
+        IDs.
+      LM:
+        An FsaVec containing only a single FSA. It is one of follows:
+        - LG, L is lexicon and G is word-level n-gram LM.
+        - G, token-level n-gram LM.
+      num_paths:
+        Size of nbest list.
+      lm_scale_list:
+        A list of floats representing LM score scales.
+      nbest_scale:
+        Scale to be applied to ``lattice.score`` when sampling paths
+        using ``k2.random_paths``.
+      use_double_scores:
+        True to use double precision during computation. False to use
+        single precision.
+    Returns:
+      A dict of FsaVec, whose key is an lm_scale and the value is the
+      best decoding path for each utterance in the lattice.
+    """
+    device = lattice.device
+
+    assert len(lattice.shape) == 3
+    assert hasattr(lattice, "aux_labels")
+    assert hasattr(lattice, "lm_scores")
+
+    assert LM.shape == (1, None, None)
+    assert LM.device == device
+
+    nbest = Nbest.from_lattice(
+        lattice=lattice,
+        num_paths=num_paths,
+        use_double_scores=use_double_scores,
+        nbest_scale=nbest_scale,
+    )
+    # nbest.fsa.scores contains 0s
+
+    nbest = nbest.intersect(lattice)
+
+    # Now nbest.fsa has its scores set
+    assert hasattr(nbest.fsa, "lm_scores")
+
+    # am scores + bi-gram scores
+    hp_scores = nbest.tot_scores()
+
+    # Now start to intersect nbest with LG or G
+    inv_fsa = k2.invert(nbest.fsa)
+    if hasattr(LM, "aux_labels"):
+        # LM is LG here
+        # delete token IDs as it is not needed
+        del inv_fsa.aux_labels
+    inv_fsa.scores.zero_()
+    inv_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(inv_fsa)
+    path_to_utt_map = nbest.shape.row_ids(1)
+
+    LM = k2.arc_sort(LM)
+    path_lattice = k2.intersect_device(
+        LM,
+        inv_fsa_with_epsilon_loops,
+        b_to_a_map=torch.zeros_like(path_to_utt_map),
+        sorted_match_a=True,
+    )
+
+    # Its labels are token IDs.
+    # If LM is G, its aux_labels are tokens IDs;
+    # If LM is LG, its aux_labels are words IDs.
+    path_lattice = k2.top_sort(k2.connect(path_lattice))
+    one_best = k2.shortest_path(path_lattice, use_double_scores=use_double_scores)
+
+    lm_scores = one_best.get_tot_scores(
+        use_double_scores=use_double_scores,
+        log_semiring=True,  # Note: we always use True
+    )
+    # If LM is LG, we might get empty paths
+    lm_scores[lm_scores == float("-inf")] = -1e9
+
+    ans = dict()
+    for lm_scale in lm_scale_list:
+        tot_scores = hp_scores.values / lm_scale + lm_scores
+        tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
+        max_indexes = tot_scores.argmax()
+        best_path = k2.index_fsa(nbest.fsa, max_indexes)
+        key = f"lm_scale_{lm_scale}"
+        ans[key] = best_path
+    return ans
+
+
 def rescore_with_whole_lattice(
     lattice: k2.Fsa,
     G_with_epsilon_loops: k2.Fsa,
diff --git a/icefall/mmi.py b/icefall/mmi.py
index 16ed6e032..b7777b434 100644
--- a/icefall/mmi.py
+++ b/icefall/mmi.py
@@ -112,8 +112,12 @@ def _compute_mmi_loss_exact_non_optimized(
     num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=True)
 
     # TODO: pass output_beam as function argument
-    num_lats = k2.intersect_dense(num_graphs, dense_fsa_vec, output_beam=beam_size)
-    den_lats = k2.intersect_dense(den_graphs, dense_fsa_vec, output_beam=beam_size)
+    num_lats = k2.intersect_dense(
+        num_graphs, dense_fsa_vec, output_beam=beam_size, max_arcs=2147483600
+    )
+    den_lats = k2.intersect_dense(
+        den_graphs, dense_fsa_vec, output_beam=beam_size, max_arcs=2147483600
+    )
 
     num_tot_scores = num_lats.get_tot_scores(log_semiring=True, use_double_scores=True)
 
@@ -144,7 +148,7 @@ def _compute_mmi_loss_pruned(
     """
     num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=False)
 
-    num_lats = k2.intersect_dense(num_graphs, dense_fsa_vec, output_beam=10.0)
+    num_lats = k2.intersect_dense(num_graphs, dense_fsa_vec, output_beam=8.0)
 
     # the values for search_beam/output_beam/min_active_states/max_active_states
     # are not tuned. You may want to tune them.

From 0470bbae66d2c9ebc91ee5d0dfa37dfb4df3a9cb Mon Sep 17 00:00:00 2001
From: Zengwei Yao 
Date: Tue, 13 Dec 2022 15:47:30 +0800
Subject: [PATCH 047/174] minor fix for zipformer recipe (#758)

* minor fix

* add CI test
---
 .github/workflows/test.yml                    |  3 +++
 .../pruned_transducer_stateless7/export.py    |  1 -
 .../test_model.py                             | 20 +++++++++++++++----
 .../pruned_transducer_stateless7/zipformer.py | 16 +++++----------
 4 files changed, 24 insertions(+), 16 deletions(-)

diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 4dbe99827..c062a2a3d 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -113,6 +113,9 @@ jobs:
           cd ../pruned_transducer_stateless4
           pytest -v -s
 
+          cd ../pruned_transducer_stateless7
+          pytest -v -s
+
           cd ../transducer_stateless
           pytest -v -s
 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7/export.py
index 9a6f3ed37..3e3160e7e 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/export.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/export.py
@@ -294,7 +294,6 @@ def main():
 
     if params.jit is True:
         convert_scaled_to_non_scaled(model, inplace=True)
-        logging.info("Using torch.jit.script()")
         # We won't use the forward() method of the model in C++, so just ignore
         # it here.
         # Otherwise, one of its arguments is a ragged tensor and is not
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/test_model.py b/egs/librispeech/ASR/pruned_transducer_stateless7/test_model.py
index db7fb7b3e..cdf914df3 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/test_model.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/test_model.py
@@ -20,19 +20,21 @@
 To run this file, do:
 
     cd icefall/egs/librispeech/ASR
-    python ./pruned_transducer_stateless4/test_model.py
+    python ./pruned_transducer_stateless7/test_model.py
 """
 
+import torch
+
+from scaling_converter import convert_scaled_to_non_scaled
 from train import get_params, get_transducer_model
 
 
-def test_model_1():
+def test_model():
     params = get_params()
     params.vocab_size = 500
     params.blank_id = 0
     params.context_size = 2
     params.num_encoder_layers = "2,4,3,2,4"
-    #  params.feedforward_dims = "1024,1024,1536,1536,1024"
     params.feedforward_dims = "1024,1024,2048,2048,1024"
     params.nhead = "8,8,8,8,8"
     params.encoder_dims = "384,384,384,384,384"
@@ -47,9 +49,19 @@ def test_model_1():
     num_param = sum([p.numel() for p in model.parameters()])
     print(f"Number of model parameters: {num_param}")
 
+    # Test jit script
+    convert_scaled_to_non_scaled(model, inplace=True)
+    # We won't use the forward() method of the model in C++, so just ignore
+    # it here.
+    # Otherwise, one of its arguments is a ragged tensor and is not
+    # torch scriptabe.
+    model.__class__.forward = torch.jit.ignore(model.__class__.forward)
+    print("Using torch.jit.script")
+    model = torch.jit.script(model)
+
 
 def main():
-    test_model_1()
+    test_model()
 
 
 if __name__ == "__main__":
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index e8fd89abd..ed1e2efa2 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -1,5 +1,5 @@
 #!/usr/bin/env python3
-# Copyright (c)  2021  University of Chinese Academy of Sciences (author: Han Zhu)
+# Copyright    2022  Xiaomi Corp.        (authors: Daniel Povey)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -454,7 +454,7 @@ class ZipformerEncoderLayer(nn.Module):
         # pooling module
         if torch.jit.is_scripting():
             src = src + self.pooling(src, key_padding_mask=src_key_padding_mask)
-        elif random.random() > dynamic_dropout:
+        elif random.random() >= dynamic_dropout:
             src = src + self.pooling(src, key_padding_mask=src_key_padding_mask)
 
         if torch.jit.is_scripting():
@@ -478,7 +478,7 @@ class ZipformerEncoderLayer(nn.Module):
                 src, src_key_padding_mask=src_key_padding_mask
             )
         else:
-            use_self_attn = random.random() > dynamic_dropout
+            use_self_attn = random.random() >= dynamic_dropout
             if use_self_attn:
                 src_att, attn_weights = self.self_attn(
                     src,
@@ -488,7 +488,7 @@ class ZipformerEncoderLayer(nn.Module):
                 )
                 src = src + src_att
 
-            if random.random() > dynamic_dropout:
+            if random.random() >= dynamic_dropout:
                 src = src + self.conv_module1(
                     src, src_key_padding_mask=src_key_padding_mask
                 )
@@ -497,7 +497,7 @@ class ZipformerEncoderLayer(nn.Module):
             if use_self_attn:
                 src = src + self.self_attn.forward2(src, attn_weights)
 
-            if random.random() > dynamic_dropout:
+            if random.random() >= dynamic_dropout:
                 src = src + self.conv_module2(
                     src, src_key_padding_mask=src_key_padding_mask
                 )
@@ -1289,12 +1289,6 @@ class RelPositionMultiheadAttention(nn.Module):
             bsz * num_heads, seq_len, seq_len
         )
 
-        assert list(attn_output_weights.size()) == [
-            bsz * num_heads,
-            seq_len,
-            seq_len,
-        ]
-
         if attn_mask is not None:
             if attn_mask.dtype == torch.bool:
                 attn_output_weights.masked_fill_(attn_mask, float("-inf"))

From b293db4baf1606cfe95066cf28ffde56173a7ddb Mon Sep 17 00:00:00 2001
From: Daniil 
Date: Tue, 13 Dec 2022 03:13:26 -0500
Subject: [PATCH 048/174] Tedlium3 conformer ctc2 (#696)

* modify preparation

* small refacor

* add tedlium3 conformer_ctc2

* modify decode

* filter unk in decode

* add scaling converter

* address comments

* fix lambda function lhotse

* add implicit manifest shuffle

* refactor ctc_greedy_search

* import model arguments from train.py

* style fix

* fix ci test and last style issues

* update RESULTS

* fix RESULTS numbers

* fix label smoothing loss

* update model parameters number in RESULTS
---
 .../ASR/conformer_ctc/label_smoothing.py      |    3 +-
 .../ASR/conformer_ctc2/subsampling.py         |    5 +-
 .../emformer2.py                              |    4 +-
 egs/librispeech/ASR/local/compile_hlg.py      |    2 +-
 .../ASR/local/compute_fbank_musan.py          |    8 +-
 egs/librispeech/ASR/local/prepare_lang_bpe.py |   23 +-
 .../pruned_transducer_stateless2/scaling.py   |   20 +-
 .../scaling_converter.py                      |    2 +-
 egs/tedlium3/ASR/RESULTS.md                   |   83 ++
 egs/tedlium3/ASR/conformer_ctc2/__init__.py   |    0
 .../ASR/conformer_ctc2/asr_datamodule.py      |    1 +
 egs/tedlium3/ASR/conformer_ctc2/attention.py  |  201 +++
 egs/tedlium3/ASR/conformer_ctc2/combiner.py   |  244 ++++
 egs/tedlium3/ASR/conformer_ctc2/conformer.py  | 1033 ++++++++++++++++
 egs/tedlium3/ASR/conformer_ctc2/decode.py     |  899 ++++++++++++++
 egs/tedlium3/ASR/conformer_ctc2/export.py     |  294 +++++
 .../ASR/conformer_ctc2/label_smoothing.py     |    1 +
 egs/tedlium3/ASR/conformer_ctc2/lstmp.py      |    1 +
 egs/tedlium3/ASR/conformer_ctc2/optim.py      |    1 +
 egs/tedlium3/ASR/conformer_ctc2/scaling.py    |    1 +
 .../ASR/conformer_ctc2/scaling_converter.py   |    1 +
 .../ASR/conformer_ctc2/subsampling.py         |    1 +
 egs/tedlium3/ASR/conformer_ctc2/train.py      | 1061 ++++++++++++++++
 .../ASR/conformer_ctc2/transformer.py         | 1093 +++++++++++++++++
 .../convert_transcript_words_to_bpe_ids.py    |   42 +-
 .../convert_transcript_words_to_tokens.py     |    1 -
 .../ASR/local/generate_unique_lexicon.py      |    1 -
 egs/tedlium3/ASR/local/prepare_lang.py        |    1 -
 egs/tedlium3/ASR/local/prepare_lexicon.py     |   94 --
 egs/tedlium3/ASR/local/prepare_transcripts.py |   66 +-
 egs/tedlium3/ASR/local/prepare_words.py       |   83 ++
 egs/tedlium3/ASR/local/test_prepare_lang.py   |    1 -
 egs/tedlium3/ASR/prepare.sh                   |   98 +-
 icefall/decode.py                             |    2 -
 test/test_lexicon.py                          |    2 +-
 35 files changed, 5158 insertions(+), 215 deletions(-)
 create mode 100755 egs/tedlium3/ASR/conformer_ctc2/__init__.py
 create mode 120000 egs/tedlium3/ASR/conformer_ctc2/asr_datamodule.py
 create mode 100644 egs/tedlium3/ASR/conformer_ctc2/attention.py
 create mode 100644 egs/tedlium3/ASR/conformer_ctc2/combiner.py
 create mode 100644 egs/tedlium3/ASR/conformer_ctc2/conformer.py
 create mode 100755 egs/tedlium3/ASR/conformer_ctc2/decode.py
 create mode 100755 egs/tedlium3/ASR/conformer_ctc2/export.py
 create mode 120000 egs/tedlium3/ASR/conformer_ctc2/label_smoothing.py
 create mode 120000 egs/tedlium3/ASR/conformer_ctc2/lstmp.py
 create mode 120000 egs/tedlium3/ASR/conformer_ctc2/optim.py
 create mode 120000 egs/tedlium3/ASR/conformer_ctc2/scaling.py
 create mode 120000 egs/tedlium3/ASR/conformer_ctc2/scaling_converter.py
 create mode 120000 egs/tedlium3/ASR/conformer_ctc2/subsampling.py
 create mode 100755 egs/tedlium3/ASR/conformer_ctc2/train.py
 create mode 100644 egs/tedlium3/ASR/conformer_ctc2/transformer.py
 delete mode 120000 egs/tedlium3/ASR/local/convert_transcript_words_to_tokens.py
 delete mode 120000 egs/tedlium3/ASR/local/generate_unique_lexicon.py
 delete mode 120000 egs/tedlium3/ASR/local/prepare_lang.py
 delete mode 100755 egs/tedlium3/ASR/local/prepare_lexicon.py
 create mode 100755 egs/tedlium3/ASR/local/prepare_words.py
 delete mode 120000 egs/tedlium3/ASR/local/test_prepare_lang.py

diff --git a/egs/librispeech/ASR/conformer_ctc/label_smoothing.py b/egs/librispeech/ASR/conformer_ctc/label_smoothing.py
index cb0d6e04d..52d2eda3b 100644
--- a/egs/librispeech/ASR/conformer_ctc/label_smoothing.py
+++ b/egs/librispeech/ASR/conformer_ctc/label_smoothing.py
@@ -44,7 +44,8 @@ class LabelSmoothingLoss(torch.nn.Module):
             mean of the output is taken. (3) "sum": the output will be summed.
         """
         super().__init__()
-        assert 0.0 <= label_smoothing < 1.0
+        assert 0.0 <= label_smoothing < 1.0, f"{label_smoothing}"
+        assert reduction in ("none", "sum", "mean"), reduction
         self.ignore_index = ignore_index
         self.label_smoothing = label_smoothing
         self.reduction = reduction
diff --git a/egs/librispeech/ASR/conformer_ctc2/subsampling.py b/egs/librispeech/ASR/conformer_ctc2/subsampling.py
index 3fcb4196f..85a4dc8df 100644
--- a/egs/librispeech/ASR/conformer_ctc2/subsampling.py
+++ b/egs/librispeech/ASR/conformer_ctc2/subsampling.py
@@ -24,10 +24,9 @@ from scaling import (
     ScaledConv2d,
     ScaledLinear,
 )
-from torch import nn
 
 
-class Conv2dSubsampling(nn.Module):
+class Conv2dSubsampling(torch.nn.Module):
     """Convolutional 2D subsampling (to 1/4 length).
 
     Convert an input of shape (N, T, idim) to an output
@@ -61,7 +60,7 @@ class Conv2dSubsampling(nn.Module):
         assert in_channels >= 7
         super().__init__()
 
-        self.conv = nn.Sequential(
+        self.conv = torch.nn.Sequential(
             ScaledConv2d(
                 in_channels=1,
                 out_channels=layer1_channels,
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer2.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer2.py
index 65a7efa77..188059044 100644
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer2.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer2.py
@@ -1435,7 +1435,7 @@ class EmformerEncoder(nn.Module):
         self,
         x: torch.Tensor,
         states: List[torch.Tensor],
-    ) -> Tuple[torch.Tensor, List[torch.Tensor],]:
+    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
         """Forward pass for streaming inference.
 
         B: batch size;
@@ -1640,7 +1640,7 @@ class Emformer(EncoderInterface):
         self,
         x: torch.Tensor,
         states: List[torch.Tensor],
-    ) -> Tuple[torch.Tensor, List[torch.Tensor],]:
+    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
         """Forward pass for streaming inference.
 
         B: batch size;
diff --git a/egs/librispeech/ASR/local/compile_hlg.py b/egs/librispeech/ASR/local/compile_hlg.py
index df6c609bb..08dac6a7b 100755
--- a/egs/librispeech/ASR/local/compile_hlg.py
+++ b/egs/librispeech/ASR/local/compile_hlg.py
@@ -24,7 +24,7 @@ This script takes as input lang_dir and generates HLG from
 
         Caution: We use a lexicon that contains disambiguation symbols
 
-    - G, the LM, built from data/lm/G_3_gram.fst.txt
+    - G, the LM, built from data/lm/G_n_gram.fst.txt
 
 The generated HLG is saved in $lang_dir/HLG.pt
 """
diff --git a/egs/librispeech/ASR/local/compute_fbank_musan.py b/egs/librispeech/ASR/local/compute_fbank_musan.py
index 4a4093ae4..62036467e 100755
--- a/egs/librispeech/ASR/local/compute_fbank_musan.py
+++ b/egs/librispeech/ASR/local/compute_fbank_musan.py
@@ -28,7 +28,7 @@ import os
 from pathlib import Path
 
 import torch
-from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter, combine
+from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter, MonoCut, combine
 from lhotse.recipes.utils import read_manifests_if_cached
 
 from icefall.utils import get_executor
@@ -41,6 +41,10 @@ torch.set_num_threads(1)
 torch.set_num_interop_threads(1)
 
 
+def is_cut_long(c: MonoCut) -> bool:
+    return c.duration > 5
+
+
 def compute_fbank_musan():
     src_dir = Path("data/manifests")
     output_dir = Path("data/fbank")
@@ -86,7 +90,7 @@ def compute_fbank_musan():
                 recordings=combine(part["recordings"] for part in manifests.values())
             )
             .cut_into_windows(10.0)
-            .filter(lambda c: c.duration > 5)
+            .filter(is_cut_long)
             .compute_and_store_features(
                 extractor=extractor,
                 storage_path=f"{output_dir}/musan_feats",
diff --git a/egs/librispeech/ASR/local/prepare_lang_bpe.py b/egs/librispeech/ASR/local/prepare_lang_bpe.py
index e121aefa9..2a2d9c219 100755
--- a/egs/librispeech/ASR/local/prepare_lang_bpe.py
+++ b/egs/librispeech/ASR/local/prepare_lang_bpe.py
@@ -127,7 +127,7 @@ def lexicon_to_fst_no_sil(
 
 
 def generate_lexicon(
-    model_file: str, words: List[str]
+    model_file: str, words: List[str], oov: str
 ) -> Tuple[Lexicon, Dict[str, int]]:
     """Generate a lexicon from a BPE model.
 
@@ -136,6 +136,8 @@ def generate_lexicon(
         Path to a sentencepiece model.
       words:
         A list of strings representing words.
+      oov:
+        The out of vocabulary word in lexicon.
     Returns:
       Return a tuple with two elements:
         - A dict whose keys are words and values are the corresponding
@@ -156,12 +158,9 @@ def generate_lexicon(
     for word, pieces in zip(words, words_pieces):
         lexicon.append((word, pieces))
 
-    # The OOV word is 
-    lexicon.append(("", [sp.id_to_piece(sp.unk_id())]))
+    lexicon.append((oov, ["▁", sp.id_to_piece(sp.unk_id())]))
 
-    token2id: Dict[str, int] = dict()
-    for i in range(sp.vocab_size()):
-        token2id[sp.id_to_piece(i)] = i
+    token2id: Dict[str, int] = {sp.id_to_piece(i): i for i in range(sp.vocab_size())}
 
     return lexicon, token2id
 
@@ -176,6 +175,13 @@ def get_args():
         """,
     )
 
+    parser.add_argument(
+        "--oov",
+        type=str,
+        default="",
+        help="The out of vocabulary word in lexicon.",
+    )
+
     parser.add_argument(
         "--debug",
         type=str2bool,
@@ -202,12 +208,13 @@ def main():
 
     words = word_sym_table.symbols
 
-    excluded = ["", "!SIL", "", "", "#0", "", ""]
+    excluded = ["", "!SIL", "", args.oov, "#0", "", ""]
+
     for w in excluded:
         if w in words:
             words.remove(w)
 
-    lexicon, token_sym_table = generate_lexicon(model_file, words)
+    lexicon, token_sym_table = generate_lexicon(model_file, words, args.oov)
 
     lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
index c802ecf89..963ebdc2d 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
@@ -652,16 +652,16 @@ class ActivationBalancer(torch.nn.Module):
     def forward(self, x: Tensor) -> Tensor:
         if random.random() >= self.balance_prob:
             return x
-        else:
-            return ActivationBalancerFunction.apply(
-                x,
-                self.channel_dim,
-                self.min_positive,
-                self.max_positive,
-                self.max_factor / self.balance_prob,
-                self.min_abs,
-                self.max_abs,
-            )
+
+        return ActivationBalancerFunction.apply(
+            x,
+            self.channel_dim,
+            self.min_positive,
+            self.max_positive,
+            self.max_factor / self.balance_prob,
+            self.min_abs,
+            self.max_abs,
+        )
 
 
 class DoubleSwishFunction(torch.autograd.Function):
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py
index b712eeda0..a6540c584 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py
@@ -282,7 +282,7 @@ def convert_scaled_to_non_scaled(
     if not inplace:
         model = copy.deepcopy(model)
 
-    excluded_patterns = r"self_attn\.(in|out)_proj"
+    excluded_patterns = r"(self|src)_attn\.(in|out)_proj"
     p = re.compile(excluded_patterns)
 
     d = {}
diff --git a/egs/tedlium3/ASR/RESULTS.md b/egs/tedlium3/ASR/RESULTS.md
index 511b19f73..38eaa8f44 100644
--- a/egs/tedlium3/ASR/RESULTS.md
+++ b/egs/tedlium3/ASR/RESULTS.md
@@ -1,5 +1,88 @@
 ## Results
 
+### TedLium3 BPE training results (Conformer-CTC 2)
+
+#### [conformer_ctc2](./conformer_ctc2)
+
+See  for more details.
+
+The tensorboard log can be found at
+
+
+You can find a pretrained model and decoding results at:
+
+
+Number of model parameters: 101141699, i.e., 101.14 M
+
+The WERs are
+
+|                          | dev        | test        | comment             |
+|--------------------------|------------|-------------|---------------------|
+| ctc decoding             | 6.45       | 5.96        | --epoch 38 --avg 26 |
+| 1best                    | 5.92       | 5.51        | --epoch 38 --avg 26 |
+| whole lattice rescoring  | 5.96       | 5.47        | --epoch 38 --avg 26 |
+| attention decoder        | 5.60       | 5.33        | --epoch 38 --avg 26 |
+
+The training command for reproducing is given below:
+
+```
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./conformer_ctc2/train.py \
+    --world-size 4 \
+    --num-epochs 40 \
+    --exp-dir conformer_ctc2/exp \
+    --max-duration 350 \
+    --use-fp16 true
+```
+
+The decoding command is:
+```
+epoch=38
+avg=26
+
+## ctc decoding
+./conformer_ctc2/decode.py \
+  --method ctc-decoding \
+  --exp-dir conformer_ctc2/exp \
+  --lang-dir data/lang_bpe_500 \
+  --result-dir conformer_ctc2/exp \
+  --max-duration 500 \
+  --epoch $epoch \
+  --avg $avg
+
+## 1best
+./conformer_ctc2/decode.py \
+  --method 1best \
+  --exp-dir conformer_ctc2/exp \
+  --lang-dir data/lang_bpe_500 \
+  --result-dir conformer_ctc2/exp \
+  --max-duration 500 \
+  --epoch $epoch \
+  --avg $avg
+
+## whole lattice rescoring
+./conformer_ctc2/decode.py \
+  --method whole-lattice-rescoring \
+  --exp-dir conformer_ctc2/exp \
+  --lm-path data/lm/G_4_gram_big.pt \
+  --lang-dir data/lang_bpe_500 \
+  --result-dir conformer_ctc2/exp \
+  --max-duration 500 \
+  --epoch $epoch \
+  --avg $avg
+
+## attention decoder
+./conformer_ctc2/decode.py \
+  --method attention-decoder \
+  --exp-dir conformer_ctc2/exp \
+  --lang-dir data/lang_bpe_500 \
+  --result-dir conformer_ctc2/exp \
+  --max-duration 500 \
+  --epoch $epoch \
+  --avg $avg
+```
+
 ### TedLium3 BPE training results (Pruned Transducer)
 
 #### 2022-03-21
diff --git a/egs/tedlium3/ASR/conformer_ctc2/__init__.py b/egs/tedlium3/ASR/conformer_ctc2/__init__.py
new file mode 100755
index 000000000..e69de29bb
diff --git a/egs/tedlium3/ASR/conformer_ctc2/asr_datamodule.py b/egs/tedlium3/ASR/conformer_ctc2/asr_datamodule.py
new file mode 120000
index 000000000..49b2ee483
--- /dev/null
+++ b/egs/tedlium3/ASR/conformer_ctc2/asr_datamodule.py
@@ -0,0 +1 @@
+../transducer_stateless/asr_datamodule.py
\ No newline at end of file
diff --git a/egs/tedlium3/ASR/conformer_ctc2/attention.py b/egs/tedlium3/ASR/conformer_ctc2/attention.py
new file mode 100644
index 000000000..178cd7e62
--- /dev/null
+++ b/egs/tedlium3/ASR/conformer_ctc2/attention.py
@@ -0,0 +1,201 @@
+# Copyright    2022  Behavox LLC.        (author: Daniil Kulko)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple, Union
+
+import torch
+from scaling import ScaledLinear
+
+
+class MultiheadAttention(torch.nn.Module):
+    """Allows the model to jointly attend to information
+    from different representation subspaces. This is a modified
+    version of the original version of multihead attention
+    (see Attention Is All You Need )
+    with replacement of input / output projection layers
+    with newly introduced ScaleLinear layer
+    (see https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py).
+
+    Args:
+        embed_dim:
+          total dimension of the model.
+        num_heads:
+          number of parallel attention heads. Note that embed_dim will be split
+          across num_heads, i.e. each head will have dimension (embed_dim // num_heads).
+        dropout:
+          dropout probability on attn_output_weights. (default=0.0).
+        bias:
+          if specified, adds bias to input / output projection layers (default=True).
+        add_bias_kv:
+          if specified, adds bias to the key and value sequences at dim=0 (default=False).
+        add_zero_attn:
+          if specified, adds a new batch of zeros to the key and value sequences
+          at dim=1 (default=False).
+        batch_first:
+          if True, then the input and output tensors are provided as
+          (batch, seq, feature), otherwise (seq, batch, feature) (default=False).
+
+    Examples::
+        >>> multihead_attn = MultiheadAttention(embed_dim, num_heads)
+        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
+    """
+
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        dropout: float = 0.0,
+        bias: bool = True,
+        add_bias_kv: bool = False,
+        add_zero_attn: bool = False,
+        batch_first: bool = False,
+        device: Union[torch.device, str, None] = None,
+        dtype: Union[torch.dtype, str, None] = None,
+    ) -> None:
+
+        super().__init__()
+
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.batch_first = batch_first
+
+        if embed_dim % num_heads != 0:
+            raise ValueError(
+                f"embed_dim must be divisible by num_heads. "
+                "Got embedding dim vs number 0f heads: "
+                f"{embed_dim} vs {num_heads}"
+            )
+
+        self.head_dim = embed_dim // num_heads
+
+        self.in_proj = ScaledLinear(
+            embed_dim,
+            3 * embed_dim,
+            bias=bias,
+            device=device,
+            dtype=dtype,
+        )
+        self.out_proj = ScaledLinear(
+            embed_dim,
+            embed_dim,
+            bias=bias,
+            initial_scale=0.25,
+            device=device,
+            dtype=dtype,
+        )
+
+        if add_bias_kv:
+            self.bias_k = torch.nn.Parameter(
+                torch.empty((1, 1, embed_dim), device=device, dtype=dtype)
+            )
+            self.bias_v = torch.nn.Parameter(
+                torch.empty((1, 1, embed_dim), device=device, dtype=dtype)
+            )
+        else:
+            self.register_parameter("bias_k", None)
+            self.register_parameter("bias_v", None)
+
+        self.add_zero_attn = add_zero_attn
+
+        self._reset_parameters()
+
+    def _reset_parameters(self) -> None:
+        if self.bias_k is not None:
+            torch.nn.init.xavier_normal_(self.bias_k)
+        if self.bias_v is not None:
+            torch.nn.init.xavier_normal_(self.bias_v)
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        key_padding_mask: Optional[torch.Tensor] = None,
+        need_weights: bool = True,
+        attn_mask: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """
+        Args:
+            query:
+              Query embeddings of shape (L, N, E_q) when batch_first=False or (N, L, E_q)
+              when batch_first=True, where L is the target sequence length, N is the batch size,
+              and E_q is the query embedding dimension embed_dim. Queries are compared against
+              key-value pairs to produce the output. See "Attention Is All You Need" for more details.
+            key:
+              Key embeddings of shape (S, N, E_k) when batch_first=False or (N, S, E_k) when
+              batch_first=True, where S is the source sequence length, N is the batch size, and
+              E_k is the key embedding dimension kdim. See "Attention Is All You Need" for more details.
+            value:
+              Value embeddings of shape (S, N, E_v) when batch_first=False or (N, S, E_v) when
+              batch_first=True, where S is the source sequence length, N is the batch size, and
+              E_v is the value embedding dimension vdim. See "Attention Is All You Need" for more details.
+            key_padding_mask:
+              If specified, a mask of shape (N, S) indicating which elements within key
+              to ignore for the purpose of attention (i.e. treat as "padding").
+              Binary and byte masks are supported. For a binary mask, a True value indicates
+              that the corresponding key value will be ignored for the purpose of attention.
+              For a byte mask, a non-zero value indicates that the corresponding key value will be ignored.
+            need_weights:
+              If specifid, returns attn_output_weights in addition to attn_outputs (default=True).
+            attn_mask:
+              If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
+              (L, S) or (N * num_heads, L, S), where N is the batch size, L is the target sequence length,
+              and S is the source sequence length. A 2D mask will be broadcasted across the batch while
+              a 3D mask allows for a different mask for each entry in the batch.
+              Binary, byte, and float masks are supported. For a binary mask, a True value indicates
+              that the corresponding position is not allowed to attend. For a byte mask, a non-zero
+              value indicates that the corresponding position is not allowed to attend. For a float mask,
+              the mask values will be added to the attention weight.
+
+        Returns:
+            attn_output:
+              Attention outputs of shape (L, N, E) when batch_first=False or (N, L, E) when batch_first=True,
+              where L is the target sequence length, N is the batch size, and E is the embedding dimension
+              embed_dim.
+            attn_output_weights:
+              Attention output weights of shape (N, L, S), where N is the batch size, L is the target sequence
+              length, and S is the source sequence length. Only returned when need_weights=True.
+        """
+        if self.batch_first:
+            query, key, value = [x.transpose(1, 0) for x in (query, key, value)]
+
+        (
+            attn_output,
+            attn_output_weights,
+        ) = torch.nn.functional.multi_head_attention_forward(
+            query,
+            key,
+            value,
+            self.embed_dim,
+            self.num_heads,
+            in_proj_weight=self.in_proj.get_weight(),
+            in_proj_bias=self.in_proj.get_bias(),
+            bias_k=self.bias_k,
+            bias_v=self.bias_v,
+            add_zero_attn=self.add_zero_attn,
+            dropout_p=self.dropout,
+            out_proj_weight=self.out_proj.get_weight(),
+            out_proj_bias=self.out_proj.get_bias(),
+            training=self.training,
+            key_padding_mask=key_padding_mask,
+            need_weights=need_weights,
+            attn_mask=attn_mask,
+        )
+
+        if self.batch_first:
+            return attn_output.transpose(1, 0), attn_output_weights
+        return attn_output, attn_output_weights
diff --git a/egs/tedlium3/ASR/conformer_ctc2/combiner.py b/egs/tedlium3/ASR/conformer_ctc2/combiner.py
new file mode 100644
index 000000000..ff526029d
--- /dev/null
+++ b/egs/tedlium3/ASR/conformer_ctc2/combiner.py
@@ -0,0 +1,244 @@
+# Copyright    2022  Behavox LLC.        (author: Daniil Kulko)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List
+
+import torch
+
+
+class RandomCombine(torch.nn.Module):
+    """
+    This module combines a list of Tensors, all with the same shape, to
+    produce a single output of that same shape which, in training time,
+    is a random combination of all the inputs; but which in test time
+    will be just the last input.
+    The idea is that the list of Tensors will be a list of outputs of multiple
+    conformer layers.  This has a similar effect as iterated loss. (See:
+    DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER
+    NETWORKS).
+    """
+
+    def __init__(
+        self,
+        num_inputs: int,
+        final_weight: float = 0.5,
+        pure_prob: float = 0.5,
+        stddev: float = 2.0,
+    ) -> None:
+        """
+        Args:
+          num_inputs:
+            The number of tensor inputs, which equals the number of layers'
+            outputs that are fed into this module.  E.g. in an 18-layer neural
+            net if we output layers 16, 12, 18, num_inputs would be 3.
+          final_weight:
+            The amount of weight or probability we assign to the
+            final layer when randomly choosing layers or when choosing
+            continuous layer weights.
+          pure_prob:
+            The probability, on each frame, with which we choose
+            only a single layer to output (rather than an interpolation)
+          stddev:
+            A standard deviation that we add to log-probs for computing
+            randomized weights.
+        The method of choosing which layers, or combinations of layers, to use,
+        is conceptually as follows::
+            With probability `pure_prob`::
+               With probability `final_weight`: choose final layer,
+               Else: choose random non-final layer.
+            Else::
+               Choose initial log-weights that correspond to assigning
+               weight `final_weight` to the final layer and equal
+               weights to other layers; then add Gaussian noise
+               with variance `stddev` to these log-weights, and normalize
+               to weights (note: the average weight assigned to the
+               final layer here will not be `final_weight` if stddev>0).
+        """
+        super().__init__()
+        assert 0 <= pure_prob <= 1, pure_prob
+        assert 0 < final_weight < 1, final_weight
+        assert num_inputs >= 1, num_inputs
+
+        self.num_inputs = num_inputs
+        self.final_weight = final_weight
+        self.pure_prob = pure_prob
+        self.stddev = stddev
+
+        self.final_log_weight = (
+            torch.tensor((final_weight / (1 - final_weight)) * (self.num_inputs - 1))
+            .log()
+            .item()
+        )
+
+    def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
+        """Forward function.
+        Args:
+          inputs:
+            A list of Tensor, e.g. from various layers of a transformer.
+            All must be the same shape, of (*, num_channels)
+        Returns:
+          A Tensor of shape (*, num_channels). In test mode
+          this is just the final input.
+        """
+        num_inputs = self.num_inputs
+        assert len(inputs) == num_inputs, f"{len(inputs)}, {num_inputs}"
+        if not self.training or torch.jit.is_scripting() or len(inputs) == 1:
+            return inputs[-1]
+
+        # Shape of weights: (*, num_inputs)
+        num_channels = inputs[0].shape[-1]
+        num_frames = inputs[0].numel() // num_channels
+
+        ndim = inputs[0].ndim
+        # stacked_inputs: (num_frames, num_channels, num_inputs)
+        stacked_inputs = torch.stack(inputs, dim=ndim).reshape(
+            (num_frames, num_channels, num_inputs)
+        )
+
+        # weights: (num_frames, num_inputs)
+        weights = self._get_random_weights(
+            inputs[0].dtype, inputs[0].device, num_frames
+        )
+
+        weights = weights.reshape(num_frames, num_inputs, 1)
+        # ans: (num_frames, num_channels, 1)
+        ans = torch.matmul(stacked_inputs, weights)
+        # ans: (*, num_channels)
+
+        ans = ans.reshape(inputs[0].shape[:-1] + (num_channels,))
+
+        return ans
+
+    def _get_random_weights(
+        self, dtype: torch.dtype, device: torch.device, num_frames: int
+    ) -> torch.Tensor:
+        """Return a tensor of random weights, of shape
+        `(num_frames, self.num_inputs)`,
+        Args:
+          dtype:
+            The data-type desired for the answer, e.g. float, double.
+          device:
+            The device needed for the answer.
+          num_frames:
+            The number of sets of weights desired
+        Returns:
+          A tensor of shape (num_frames, self.num_inputs), such that
+          `ans.sum(dim=1)` is all ones.
+        """
+        pure_prob = self.pure_prob
+        if pure_prob == 0.0:
+            return self._get_random_mixed_weights(dtype, device, num_frames)
+        elif pure_prob == 1.0:
+            return self._get_random_pure_weights(dtype, device, num_frames)
+        else:
+            p = self._get_random_pure_weights(dtype, device, num_frames)
+            m = self._get_random_mixed_weights(dtype, device, num_frames)
+            return torch.where(
+                torch.rand(num_frames, 1, device=device) < self.pure_prob, p, m
+            )
+
+    def _get_random_pure_weights(
+        self, dtype: torch.dtype, device: torch.device, num_frames: int
+    ) -> torch.Tensor:
+        """Return a tensor of random one-hot weights, of shape
+        `(num_frames, self.num_inputs)`,
+        Args:
+          dtype:
+            The data-type desired for the answer, e.g. float, double.
+          device:
+            The device needed for the answer.
+          num_frames:
+            The number of sets of weights desired.
+        Returns:
+          A one-hot tensor of shape `(num_frames, self.num_inputs)`, with
+          exactly one weight equal to 1.0 on each frame.
+        """
+        final_prob = self.final_weight
+
+        # final contains self.num_inputs - 1 in all elements
+        final = torch.full((num_frames,), self.num_inputs - 1, device=device)
+        # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights.
+        nonfinal = torch.randint(self.num_inputs - 1, (num_frames,), device=device)
+
+        indexes = torch.where(
+            torch.rand(num_frames, device=device) < final_prob, final, nonfinal
+        )
+        ans = torch.nn.functional.one_hot(indexes, num_classes=self.num_inputs).to(
+            dtype=dtype
+        )
+        return ans
+
+    def _get_random_mixed_weights(
+        self, dtype: torch.dtype, device: torch.device, num_frames: int
+    ) -> torch.Tensor:
+        """Return a tensor of random one-hot weights, of shape
+        `(num_frames, self.num_inputs)`,
+        Args:
+          dtype:
+            The data-type desired for the answer, e.g. float, double.
+          device:
+            The device needed for the answer.
+          num_frames:
+            The number of sets of weights desired.
+        Returns:
+          A tensor of shape (num_frames, self.num_inputs), which elements
+          in [0..1] that sum to one over the second axis, i.e.
+          `ans.sum(dim=1)` is all ones.
+        """
+        logprobs = (
+            torch.randn(num_frames, self.num_inputs, dtype=dtype, device=device)
+            * self.stddev
+        )
+        logprobs[:, -1] += self.final_log_weight
+        return logprobs.softmax(dim=1)
+
+
+def _test_random_combine(
+    final_weight: float,
+    pure_prob: float,
+    stddev: float,
+) -> None:
+    print(
+        f"_test_random_combine: final_weight={final_weight}, "
+        f"pure_prob={pure_prob}, stddev={stddev}"
+    )
+    num_inputs = 3
+    num_channels = 50
+    m = RandomCombine(
+        num_inputs=num_inputs,
+        final_weight=final_weight,
+        pure_prob=pure_prob,
+        stddev=stddev,
+    )
+
+    x = [torch.ones(3, 4, num_channels) for _ in range(num_inputs)]
+
+    y = m(x)
+    assert y.shape == x[0].shape
+    assert torch.allclose(y, x[0])  # .. since actually all ones.
+
+
+def _test_random_combine_main() -> None:
+    _test_random_combine(0.999, 0, 0.0)
+    _test_random_combine(0.5, 0, 0.0)
+    _test_random_combine(0.999, 0, 0.0)
+    _test_random_combine(0.5, 0, 0.3)
+    _test_random_combine(0.5, 1, 0.3)
+    _test_random_combine(0.5, 0.5, 0.3)
+
+
+if __name__ == "__main__":
+    _test_random_combine_main()
diff --git a/egs/tedlium3/ASR/conformer_ctc2/conformer.py b/egs/tedlium3/ASR/conformer_ctc2/conformer.py
new file mode 100644
index 000000000..fad2f371f
--- /dev/null
+++ b/egs/tedlium3/ASR/conformer_ctc2/conformer.py
@@ -0,0 +1,1033 @@
+#!/usr/bin/env python3
+# Copyright (c)  2021  University of Chinese Academy of Sciences (author: Han Zhu)
+#                2022  Xiaomi Corp.                              (author: Quandong Wang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import math
+import warnings
+from typing import List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+from combiner import RandomCombine
+from scaling import (
+    ActivationBalancer,
+    BasicNorm,
+    DoubleSwish,
+    ScaledConv1d,
+    ScaledLinear,
+)
+from subsampling import Conv2dSubsampling
+from transformer import Supervisions, Transformer, encoder_padding_mask
+
+
+class Conformer(Transformer):
+    def __init__(
+        self,
+        num_features: int,
+        num_classes: int,
+        subsampling_factor: int = 4,
+        d_model: int = 256,
+        nhead: int = 4,
+        dim_feedforward: int = 2048,
+        num_encoder_layers: int = 12,
+        num_decoder_layers: int = 6,
+        dropout: float = 0.1,
+        layer_dropout: float = 0.075,
+        cnn_module_kernel: int = 31,
+        aux_layer_period: int = 3,
+    ) -> None:
+        """
+        Args:
+          num_features (int):
+            number of input features.
+          num_classes (int):
+            number of output classes.
+          subsampling_factor (int):
+            subsampling factor of encoder;
+            currently, subsampling_factor MUST be 4.
+          d_model (int):
+            attention dimension, also the output dimension.
+          nhead (int):
+            number of heads in multi-head attention;
+            must satisfy d_model // nhead == 0.
+          dim_feedforward (int):
+            feedforward dimention.
+          num_encoder_layers (int):
+            number of encoder layers.
+          num_decoder_layers (int):
+            number of decoder layers.
+          dropout (float):
+            dropout rate.
+          layer_dropout (float):
+            layer-dropout rate.
+          cnn_module_kernel (int):
+            kernel size of convolution module.
+          aux_layer_period (int):
+            determines the auxiliary encoder layers.
+        """
+
+        super().__init__(
+            num_features=num_features,
+            num_classes=num_classes,
+            subsampling_factor=subsampling_factor,
+            d_model=d_model,
+            nhead=nhead,
+            dim_feedforward=dim_feedforward,
+            num_encoder_layers=num_encoder_layers,
+            num_decoder_layers=num_decoder_layers,
+            dropout=dropout,
+            layer_dropout=layer_dropout,
+        )
+
+        self.num_features = num_features
+        self.subsampling_factor = subsampling_factor
+        if subsampling_factor != 4:
+            raise NotImplementedError("Support only 'subsampling_factor=4'.")
+
+        # self.encoder_embed converts the input of shape (N, T, num_features)
+        # to the shape (N, T//subsampling_factor, d_model).
+        # That is, it does two things simultaneously:
+        #   (1) subsampling: T -> T//subsampling_factor
+        #   (2) embedding: num_features -> d_model
+        self.encoder_embed = Conv2dSubsampling(num_features, d_model)
+
+        self.encoder_pos = RelPositionalEncoding(d_model, dropout)
+
+        encoder_layer = ConformerEncoderLayer(
+            d_model=d_model,
+            nhead=nhead,
+            dim_feedforward=dim_feedforward,
+            dropout=dropout,
+            layer_dropout=layer_dropout,
+            cnn_module_kernel=cnn_module_kernel,
+        )
+
+        # aux_layers from 1/3
+        self.encoder = ConformerEncoder(
+            encoder_layer=encoder_layer,
+            num_layers=num_encoder_layers,
+            aux_layers=list(
+                range(
+                    num_encoder_layers // 3,
+                    num_encoder_layers - 1,
+                    aux_layer_period,
+                )
+            ),
+        )
+
+    def run_encoder(
+        self,
+        x: torch.Tensor,
+        supervisions: Optional[Supervisions] = None,
+        warmup: float = 1.0,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """
+        Args:
+          x:
+            the input tensor. Its shape is (batch_size, seq_len, feature_dim).
+          supervisions:
+            Supervision in lhotse format.
+            See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
+            CAUTION: It contains length information, i.e., start and number of
+            frames, before subsampling
+            It is read directly from the batch, without any sorting. It is used
+            to compute encoder padding mask, which is used as memory key padding
+            mask for the decoder.
+          warmup:
+            a floating point value that gradually increases from 0 throughout
+            training; when it is >= 1.0 we are "fully warmed up".  It is used
+            to turn modules on sequentially.
+
+        Returns:
+          torch.Tensor: Predictor tensor of dimension (S, N, C).
+          torch.Tensor: Mask tensor of dimension (N, S)
+        """
+        x = self.encoder_embed(x)
+        x, pos_emb = self.encoder_pos(x)
+        x = x.permute(1, 0, 2)  # (N, S, C) -> (S, N, C)
+        mask = encoder_padding_mask(x.size(0), supervisions)
+        mask = mask.to(x.device) if mask is not None else None
+
+        x = self.encoder(
+            x, pos_emb, src_key_padding_mask=mask, warmup=warmup
+        )  # (S, N, C)
+
+        return x, mask
+
+
+class ConformerEncoderLayer(nn.Module):
+    """
+    ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks.
+    See: "Conformer: Convolution-augmented Transformer for Speech Recognition"
+
+    Examples:
+        >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
+        >>> src = torch.rand(10, 32, 512)
+        >>> pos_emb = torch.rand(32, 19, 512)
+        >>> out = encoder_layer(src, pos_emb)
+    """
+
+    def __init__(
+        self,
+        d_model: int,
+        nhead: int,
+        dim_feedforward: int = 2048,
+        dropout: float = 0.1,
+        bypass_scale: float = 0.1,
+        layer_dropout: float = 0.075,
+        cnn_module_kernel: int = 31,
+    ) -> None:
+        """
+        Args:
+          d_model:
+            the number of expected features in the input (required).
+          nhead:
+            the number of heads in the multiheadattention models (required).
+          dim_feedforward:
+            the dimension of the feedforward network model (default=2048).
+          dropout:
+            the dropout value (default=0.1).
+          bypass_scale:
+            a scale on the layer's output, used in bypass (resnet-type) skip-connection;
+            when the layer is bypassed the final output will be a
+            weighted sum of the layer's input and layer's output with weights
+            (1.0-bypass_scale) and bypass_scale correspondingly (default=0.1).
+          layer_dropout:
+            the probability to bypass the layer (default=0.075).
+          cnn_module_kernel (int):
+            kernel size of convolution module (default=31).
+        """
+        super().__init__()
+
+        if bypass_scale < 0.0 or bypass_scale > 1.0:
+            raise ValueError("bypass_scale should be between 0.0 and 1.0")
+
+        if layer_dropout < 0.0 or layer_dropout > 1.0:
+            raise ValueError("layer_dropout should be between 0.0 and 1.0")
+
+        self.bypass_scale = bypass_scale
+        self.layer_dropout = layer_dropout
+
+        self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
+
+        self.feed_forward = nn.Sequential(
+            ScaledLinear(d_model, dim_feedforward),
+            ActivationBalancer(channel_dim=-1),
+            DoubleSwish(),
+            nn.Dropout(dropout),
+            ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
+        )
+
+        self.feed_forward_macaron = nn.Sequential(
+            ScaledLinear(d_model, dim_feedforward),
+            ActivationBalancer(channel_dim=-1),
+            DoubleSwish(),
+            nn.Dropout(dropout),
+            ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
+        )
+
+        self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
+
+        self.norm_final = BasicNorm(d_model)
+
+        # try to ensure the output is close to zero-mean (or at least, zero-median).
+        self.balancer = ActivationBalancer(
+            channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
+        )
+
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(
+        self,
+        src: torch.Tensor,
+        pos_emb: torch.Tensor,
+        src_mask: Optional[torch.Tensor] = None,
+        src_key_padding_mask: Optional[torch.Tensor] = None,
+        warmup: float = 1.0,
+    ) -> torch.Tensor:
+        """
+        Pass the input through the encoder layer.
+
+        Args:
+          src:
+            the sequence to the encoder layer of shape (S, N, C) (required).
+          pos_emb:
+            positional embedding tensor of shape (N, 2*S-1, C) (required).
+          src_mask:
+            the mask for the src sequence of shape (S, S) (optional).
+          src_key_padding_mask:
+            the mask for the src keys per batch of shape (N, S) (optional).
+          warmup:
+            controls selective bypass of of layers; if < 1.0, we will
+            bypass layers more frequently.
+
+        Returns:
+            Output tensor of the shape (S, N, C), where
+            S is the source sequence length,
+            N is the batch size,
+            C is the feature number
+        """
+        src_orig = src
+
+        warmup_scale = min(self.bypass_scale + warmup, 1.0)
+        # alpha = 1.0 means fully use this encoder layer, 0.0 would mean
+        # completely bypass it.
+        if self.training:
+            alpha = (
+                warmup_scale
+                if torch.rand(()).item() <= (1.0 - self.layer_dropout)
+                else self.bypass_scale
+            )
+        else:
+            alpha = 1.0
+
+        # macaron style feed forward module
+        src = src + self.dropout(self.feed_forward_macaron(src))
+
+        # multi-headed self-attention module
+        src_att = self.self_attn(
+            src,
+            src,
+            src,
+            pos_emb=pos_emb,
+            attn_mask=src_mask,
+            key_padding_mask=src_key_padding_mask,
+        )[0]
+
+        src = src + self.dropout(src_att)
+
+        # convolution module
+        src = src + self.dropout(self.conv_module(src))
+
+        # feed forward module
+        src = src + self.dropout(self.feed_forward(src))
+
+        src = self.norm_final(self.balancer(src))
+
+        if alpha != 1.0:
+            src = alpha * src + (1 - alpha) * src_orig
+
+        return src
+
+
+class ConformerEncoder(nn.Module):
+    """
+    ConformerEncoder is a stack of N encoder layers
+
+    Examples:
+        >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
+        >>> conformer_encoder = ConformerEncoder(encoder_layer, num_layers=6)
+        >>> src = torch.rand(10, 32, 512)
+        >>> pos_emb = torch.rand(32, 19, 512)
+        >>> out = conformer_encoder(src, pos_emb)
+    """
+
+    def __init__(
+        self,
+        encoder_layer: nn.Module,
+        num_layers: int,
+        aux_layers: List[int],
+    ) -> None:
+
+        """
+        Args:
+          encoder_layer:
+            an instance of the ConformerEncoderLayer() class (required).
+          num_layers:
+            the number of sub-encoder-layers in the encoder (required).
+          aux_layers:
+            list of indexes of sub-encoder-layers outputs to be combined (required).
+        """
+
+        super().__init__()
+        self.layers = nn.ModuleList(
+            [copy.deepcopy(encoder_layer) for i in range(num_layers)]
+        )
+        self.num_layers = num_layers
+
+        assert len(set(aux_layers)) == len(aux_layers)
+
+        assert num_layers - 1 not in aux_layers
+        self.aux_layers = aux_layers + [num_layers - 1]
+
+        self.combiner = RandomCombine(
+            num_inputs=len(self.aux_layers),
+            final_weight=0.5,
+            pure_prob=0.333,
+            stddev=2.0,
+        )
+
+    def forward(
+        self,
+        src: torch.Tensor,
+        pos_emb: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+        src_key_padding_mask: Optional[torch.Tensor] = None,
+        warmup: float = 1.0,
+    ) -> torch.Tensor:
+        """
+        Pass the input through the encoder layers in turn.
+
+        Args:
+          src:
+            the sequence to the encoder of shape (S, N, C) (required).
+          pos_emb:
+            positional embedding tensor of shape (N, 2*S-1, C) (required).
+          mask:
+            the mask for the src sequence of shape (S, S) (optional).
+          src_key_padding_mask:
+            the mask for the src keys per batch of shape (N, S) (optional).
+          warmup:
+            controls selective bypass of layer; if < 1.0, we will
+            bypass the layer more frequently (default=1.0).
+
+        Returns:
+          Output tensor of the shape (S, N, C), where
+          S is the source sequence length,
+          N is the batch size,
+          C is the feature number.
+
+        """
+        output = src
+
+        outputs = []
+        for i, mod in enumerate(self.layers):
+            output = mod(
+                output,
+                pos_emb,
+                src_mask=mask,
+                src_key_padding_mask=src_key_padding_mask,
+                warmup=warmup,
+            )
+
+            if i in self.aux_layers:
+                outputs.append(output)
+
+        output = self.combiner(outputs)
+
+        return output
+
+
+class RelPositionalEncoding(torch.nn.Module):
+    """
+    Relative positional encoding module.
+
+    See: Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
+    Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py
+
+    """
+
+    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
+        """
+        Construct an PositionalEncoding object.
+
+        Args:
+          d_model: Embedding dimension.
+          dropout_rate: Dropout rate.
+          max_len: Maximum input length.
+
+        """
+        super().__init__()
+        self.d_model = d_model
+        self.dropout = torch.nn.Dropout(p=dropout_rate)
+        self.pe = None
+        self.extend_pe(torch.tensor(0.0).expand(1, max_len))
+
+    def extend_pe(self, x: torch.Tensor) -> None:
+        """
+        Reset the positional encodings.
+
+        Args:
+          x:
+            input tensor (N, T, C), where
+            T is the source sequence length,
+            N is the batch size.
+            C is the feature number.
+
+        """
+        if self.pe is not None:
+            # self.pe contains both positive and negative parts
+            # the length of self.pe is 2 * input_len - 1
+            if self.pe.size(1) >= x.size(1) * 2 - 1:
+                # Note: TorchScript doesn't implement operator== for torch.Device
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
+                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+                return
+        # Suppose `i` means to the position of query vecotr and `j` means the
+        # position of key vector. We use position relative positions when keys
+        # are to the left (i>j) and negative relative positions otherwise (i Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Add positional encoding.
+
+        Args:
+          x:
+            input tensor (N, T, C).
+
+        Returns:
+          torch.Tensor: Encoded tensor (N, T, C).
+          torch.Tensor: Encoded tensor (N, 2*T-1, C), where
+          T is the source sequence length,
+          N is the batch size.
+          C is the feature number.
+
+        """
+        self.extend_pe(x)
+        pos_emb = self.pe[
+            :,
+            self.pe.size(1) // 2
+            - x.size(1)
+            + 1 : self.pe.size(1) // 2  # noqa E203
+            + x.size(1),
+        ]
+        return self.dropout(x), self.dropout(pos_emb)
+
+
+class RelPositionMultiheadAttention(nn.Module):
+    """
+    Multi-Head Attention layer with relative position encoding
+    See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context".
+
+    """
+
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        dropout: float = 0.0,
+    ) -> None:
+        """
+        Args:
+          embed_dim:
+            total dimension of the model.
+          num_heads:
+            parallel attention heads.
+          dropout:
+            a Dropout layer on attn_output_weights. Default: 0.0.
+        """
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.head_dim = embed_dim // num_heads
+        assert (
+            self.head_dim * num_heads == self.embed_dim
+        ), "embed_dim must be divisible by num_heads"
+
+        self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True)
+        self.out_proj = ScaledLinear(
+            embed_dim, embed_dim, bias=True, initial_scale=0.25
+        )
+
+        # linear transformation for positional encoding.
+        self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False)
+        # these two learnable bias are used in matrix c and matrix d
+        # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
+        self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
+        self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
+        self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach())
+        self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach())
+        self._reset_parameters()
+
+    def _pos_bias_u(self):
+        return self.pos_bias_u * self.pos_bias_u_scale.exp()
+
+    def _pos_bias_v(self):
+        return self.pos_bias_v * self.pos_bias_v_scale.exp()
+
+    def _reset_parameters(self) -> None:
+        nn.init.normal_(self.pos_bias_u, std=0.01)
+        nn.init.normal_(self.pos_bias_v, std=0.01)
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        pos_emb: torch.Tensor,
+        key_padding_mask: Optional[torch.Tensor] = None,
+        need_weights: bool = False,
+        attn_mask: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """
+        Args:
+          query, key, value: map a query and a set of key-value pairs to an output.
+          pos_emb: Positional embedding tensor
+          key_padding_mask: if provided, specified padding elements in the key will
+                            be ignored by the attention. When given a binary mask
+                            and a value is True, the corresponding value on the attention
+                            layer will be ignored. When given a byte mask and a value is
+                            non-zero, the corresponding value on the attention layer will be ignored.
+          need_weights: output attn_output_weights.
+          attn_mask: 2D or 3D mask that prevents attention to certain positions.
+                     A 2D mask will be broadcasted for all the batches while a 3D
+                     mask allows to specify a different mask for the entries of each batch.
+
+        Shape:
+          - Inputs:
+          - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
+            the embedding dimension.
+          - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
+            the embedding dimension.
+          - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
+            the embedding dimension.
+          - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is
+            the embedding dimension.
+          - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
+            If a ByteTensor is provided, the non-zero positions will be ignored while the position
+            with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
+            value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
+          - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
+            3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
+            S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked
+            positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
+            while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
+            is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
+            is provided, it will be added to the attention weight.
+
+          - Outputs:
+          - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
+            E is the embedding dimension.
+          - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
+            L is the target sequence length, S is the source sequence length.
+        """
+        return self.multi_head_attention_forward(
+            query,
+            key,
+            value,
+            pos_emb,
+            self.embed_dim,
+            self.num_heads,
+            self.in_proj.get_weight(),
+            self.in_proj.get_bias(),
+            self.dropout,
+            self.out_proj.get_weight(),
+            self.out_proj.get_bias(),
+            training=self.training,
+            key_padding_mask=key_padding_mask,
+            need_weights=need_weights,
+            attn_mask=attn_mask,
+        )
+
+    def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Compute relative positional encoding.
+
+        Args:
+          x:
+            input tensor (batch, head, time1, 2*time1-1).
+            time1 means the length of query vector.
+
+        Returns:
+          torch.Tensor: tensor of shape (batch, head, time1, time2)
+          (note: time2 has the same value as time1, but it is for
+          the key, while time1 is for the query).
+        """
+        (batch_size, num_heads, time1, n) = x.shape
+        assert n == 2 * time1 - 1
+        # Note: TorchScript requires explicit arg for stride()
+        batch_stride = x.stride(0)
+        head_stride = x.stride(1)
+        time1_stride = x.stride(2)
+        n_stride = x.stride(3)
+        return x.as_strided(
+            (batch_size, num_heads, time1, time1),
+            (batch_stride, head_stride, time1_stride - n_stride, n_stride),
+            storage_offset=n_stride * (time1 - 1),
+        )
+
+    def multi_head_attention_forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        pos_emb: torch.Tensor,
+        embed_dim_to_check: int,
+        num_heads: int,
+        in_proj_weight: torch.Tensor,
+        in_proj_bias: torch.Tensor,
+        dropout_p: float,
+        out_proj_weight: torch.Tensor,
+        out_proj_bias: torch.Tensor,
+        training: bool = True,
+        key_padding_mask: Optional[torch.Tensor] = None,
+        need_weights: bool = False,
+        attn_mask: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """
+        Args:
+          query, key, value: map a query and a set of key-value pairs to an output.
+          pos_emb: Positional embedding tensor
+          embed_dim_to_check: total dimension of the model.
+          num_heads: parallel attention heads.
+          in_proj_weight, in_proj_bias: input projection weight and bias.
+          dropout_p: probability of an element to be zeroed.
+          out_proj_weight, out_proj_bias: the output projection weight and bias.
+          training: apply dropout if is ``True``.
+          key_padding_mask: if provided, specified padding elements in the key will
+                            be ignored by the attention. This is an binary mask.
+                            When the value is True, the corresponding value on the
+                            attention layer will be filled with -inf.
+          need_weights: output attn_output_weights.
+          attn_mask: 2D or 3D mask that prevents attention to certain positions.
+                     A 2D mask will be broadcasted for all the batches while a 3D
+                     mask allows to specify a different mask for the entries of each batch.
+
+        Shape:
+          Inputs:
+          - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
+            the embedding dimension.
+          - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
+            the embedding dimension.
+          - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
+            the embedding dimension.
+          - pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence
+            length, N is the batch size, E is the embedding dimension.
+          - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
+            If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
+            will be unchanged. If a BoolTensor is provided, the positions with the
+            value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
+          - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
+            3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
+            S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
+            positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
+            while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
+            are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
+            is provided, it will be added to the attention weight.
+
+          Outputs:
+          - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
+            E is the embedding dimension.
+          - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
+            L is the target sequence length, S is the source sequence length.
+        """
+
+        tgt_len, bsz, embed_dim = query.size()
+        assert embed_dim == embed_dim_to_check
+        assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
+
+        head_dim = embed_dim // num_heads
+        assert (
+            head_dim * num_heads == embed_dim
+        ), "embed_dim must be divisible by num_heads"
+
+        scaling = float(head_dim) ** -0.5
+
+        if torch.equal(query, key) and torch.equal(key, value):
+            # self-attention
+            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
+                3, dim=-1
+            )
+
+        elif torch.equal(key, value):
+            # encoder-decoder attention
+            # This is inline in_proj function with in_proj_weight and in_proj_bias
+            _b = in_proj_bias
+            _start = 0
+            _end = embed_dim
+            _w = in_proj_weight[_start:_end, :]
+            if _b is not None:
+                _b = _b[_start:_end]
+            q = nn.functional.linear(query, _w, _b)
+
+            # This is inline in_proj function with in_proj_weight and in_proj_bias
+            _b = in_proj_bias
+            _start = embed_dim
+            _end = None
+            _w = in_proj_weight[_start:, :]
+            if _b is not None:
+                _b = _b[_start:]
+            k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1)
+
+        else:
+            # This is inline in_proj function with in_proj_weight and in_proj_bias
+            _b = in_proj_bias
+            _start = 0
+            _end = embed_dim
+            _w = in_proj_weight[_start:_end, :]
+            if _b is not None:
+                _b = _b[_start:_end]
+            q = nn.functional.linear(query, _w, _b)
+
+            # This is inline in_proj function with in_proj_weight and in_proj_bias
+            _b = in_proj_bias
+            _start = embed_dim
+            _end = embed_dim * 2
+            _w = in_proj_weight[_start:_end, :]
+            if _b is not None:
+                _b = _b[_start:_end]
+            k = nn.functional.linear(key, _w, _b)
+
+            # This is inline in_proj function with in_proj_weight and in_proj_bias
+            _b = in_proj_bias
+            _start = embed_dim * 2
+            _end = None
+            _w = in_proj_weight[_start:, :]
+            if _b is not None:
+                _b = _b[_start:]
+            v = nn.functional.linear(value, _w, _b)
+
+        if attn_mask is not None:
+            assert (
+                attn_mask.dtype == torch.float32
+                or attn_mask.dtype == torch.float64
+                or attn_mask.dtype == torch.float16
+                or attn_mask.dtype == torch.uint8
+                or attn_mask.dtype == torch.bool
+            ), "Only float, byte, and bool types are supported for attn_mask, not {}".format(
+                attn_mask.dtype
+            )
+            if attn_mask.dtype == torch.uint8:
+                warnings.warn(
+                    "Byte tensor for attn_mask is deprecated. Use bool tensor instead."
+                )
+                attn_mask = attn_mask.to(torch.bool)
+
+            if attn_mask.dim() == 2:
+                attn_mask = attn_mask.unsqueeze(0)
+                if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
+                    raise RuntimeError("The size of the 2D attn_mask is not correct.")
+            elif attn_mask.dim() == 3:
+                if list(attn_mask.size()) != [
+                    bsz * num_heads,
+                    query.size(0),
+                    key.size(0),
+                ]:
+                    raise RuntimeError("The size of the 3D attn_mask is not correct.")
+            else:
+                raise RuntimeError(
+                    f"attn_mask's dimension {attn_mask.dim()} is not supported"
+                )
+            # attn_mask's dim is 3 now.
+
+        # convert ByteTensor key_padding_mask to bool
+        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
+            warnings.warn(
+                "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
+            )
+            key_padding_mask = key_padding_mask.to(torch.bool)
+
+        q = (q * scaling).contiguous().view(tgt_len, bsz, num_heads, head_dim)
+        k = k.contiguous().view(-1, bsz, num_heads, head_dim)
+        v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
+
+        src_len = k.size(0)
+
+        if key_padding_mask is not None:
+            assert key_padding_mask.size(0) == bsz, "{} == {}".format(
+                key_padding_mask.size(0), bsz
+            )
+            assert key_padding_mask.size(1) == src_len, "{} == {}".format(
+                key_padding_mask.size(1), src_len
+            )
+
+        q = q.transpose(0, 1)  # (batch, time1, head, d_k)
+
+        pos_emb_bsz = pos_emb.size(0)
+        assert pos_emb_bsz in (1, bsz)  # actually it is 1
+        p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
+        p = p.transpose(1, 2)  # (batch, head, 2*time1-1, d_k)
+
+        q_with_bias_u = (q + self._pos_bias_u()).transpose(
+            1, 2
+        )  # (batch, head, time1, d_k)
+
+        q_with_bias_v = (q + self._pos_bias_v()).transpose(
+            1, 2
+        )  # (batch, head, time1, d_k)
+
+        # compute attention score
+        # first compute matrix a and matrix c
+        # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
+        k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
+        matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch, head, time1, time2)
+
+        # compute matrix b and matrix d
+        matrix_bd = torch.matmul(
+            q_with_bias_v, p.transpose(-2, -1)
+        )  # (batch, head, time1, 2*time1-1)
+        matrix_bd = self.rel_shift(matrix_bd)
+
+        attn_output_weights = matrix_ac + matrix_bd  # (batch, head, time1, time2)
+        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
+
+        assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]
+
+        if attn_mask is not None:
+            if attn_mask.dtype == torch.bool:
+                attn_output_weights.masked_fill_(attn_mask, float("-inf"))
+            else:
+                attn_output_weights += attn_mask
+
+        if key_padding_mask is not None:
+            attn_output_weights = attn_output_weights.view(
+                bsz, num_heads, tgt_len, src_len
+            )
+            attn_output_weights = attn_output_weights.masked_fill(
+                key_padding_mask.unsqueeze(1).unsqueeze(2),
+                float("-inf"),
+            )
+            attn_output_weights = attn_output_weights.view(
+                bsz * num_heads, tgt_len, src_len
+            )
+
+        attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
+        attn_output_weights = nn.functional.dropout(
+            attn_output_weights, p=dropout_p, training=training
+        )
+
+        attn_output = torch.bmm(attn_output_weights, v)
+        assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
+        attn_output = (
+            attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+        )
+        attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
+
+        if need_weights:
+            # average attention weights over heads
+            attn_output_weights = attn_output_weights.view(
+                bsz, num_heads, tgt_len, src_len
+            )
+            return attn_output, attn_output_weights.sum(dim=1) / num_heads
+        else:
+            return attn_output, None
+
+
+class ConvolutionModule(nn.Module):
+    def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None:
+        """
+        ConvolutionModule in Conformer model.
+        Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py
+        Construct a ConvolutionModule object.
+
+        Args:
+          channels (int):
+            the number of channels of conv layers.
+          kernel_size (int):
+            kernerl size of conv layers.
+          bias (bool):
+            whether to use bias in conv layers (default=True).
+        """
+        super().__init__()
+        # kernerl_size should be a odd number for 'SAME' padding
+        assert (kernel_size - 1) % 2 == 0
+
+        self.pointwise_conv1 = ScaledConv1d(
+            channels,
+            2 * channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=bias,
+        )
+
+        # after pointwise_conv1 we put x through a gated linear unit (nn.functional.glu).
+        # For most layers the normal rms value of channels of x seems to be in the range 1 to 4,
+        # but sometimes, for some reason, for layer 0 the rms ends up being very large,
+        # between 50 and 100 for different channels.  This will cause very peaky and
+        # sparse derivatives for the sigmoid gating function, which will tend to make
+        # the loss function not learn effectively.  (for most layers the average absolute values
+        # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion,
+        # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different
+        # layers, which likely breaks down as 0.5 for the "linear" half and
+        # 0.2 to 0.3 for the part that goes into the sigmoid.  The idea is that if we
+        # constrain the rms values to a reasonable range via a constraint of max_abs=10.0,
+        # it will be in a better position to start learning something, i.e. to latch onto
+        # the correct range.
+        self.deriv_balancer1 = ActivationBalancer(
+            channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0
+        )
+
+        self.depthwise_conv = ScaledConv1d(
+            channels,
+            channels,
+            kernel_size,
+            stride=1,
+            padding=(kernel_size - 1) // 2,
+            groups=channels,
+            bias=bias,
+        )
+
+        self.deriv_balancer2 = ActivationBalancer(
+            channel_dim=1, min_positive=0.05, max_positive=1.0
+        )
+
+        self.activation = DoubleSwish()
+
+        self.pointwise_conv2 = ScaledConv1d(
+            channels,
+            channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=bias,
+            initial_scale=0.25,
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Compute convolution module.
+
+        Args:
+          x:
+            input tensor of shape (T, N, C).
+
+        Returns:
+          torch.Tensor: Output tensor (T, N, C), where
+          T is the source sequence length,
+          N is the batch size,
+          C is the feature number.
+
+        """
+        # exchange the temporal dimension and the feature dimension
+        x = x.permute(1, 2, 0)  # (#batch, channels, time).
+
+        # GLU mechanism
+        x = self.pointwise_conv1(x)  # (batch, 2*channels, time)
+
+        x = self.deriv_balancer1(x)
+        x = nn.functional.glu(x, dim=1)  # (batch, channels, time)
+
+        # 1D Depthwise Conv
+        x = self.depthwise_conv(x)
+
+        x = self.deriv_balancer2(x)
+        x = self.activation(x)
+
+        x = self.pointwise_conv2(x)  # (batch, channel, time)
+
+        return x.permute(2, 0, 1)
diff --git a/egs/tedlium3/ASR/conformer_ctc2/decode.py b/egs/tedlium3/ASR/conformer_ctc2/decode.py
new file mode 100755
index 000000000..ce4dcd142
--- /dev/null
+++ b/egs/tedlium3/ASR/conformer_ctc2/decode.py
@@ -0,0 +1,899 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo,
+#                                            Fangjun Kuang,
+#                                            Quandong Wang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import logging
+import shutil
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import TedLiumAsrDataModule
+from conformer import Conformer
+from train import add_model_arguments
+
+from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
+from icefall.checkpoint import (
+    average_checkpoints,
+    average_checkpoints_with_averaged_model,
+    find_checkpoints,
+    load_checkpoint,
+)
+from icefall.decode import (
+    get_lattice,
+    nbest_decoding,
+    nbest_oracle,
+    one_best_decoding,
+    rescore_with_attention_decoder,
+    rescore_with_n_best_list,
+    rescore_with_whole_lattice,
+)
+from icefall.env import get_env_info
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+    AttributeDict,
+    get_texts,
+    load_averaged_model,
+    setup_logger,
+    store_transcripts,
+    str2bool,
+    write_error_stats,
+)
+
+
+def get_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=30,
+        help="""It specifies the checkpoint to use for decoding.
+        Note: Epoch counts from 1.
+        You can specify --avg to use more checkpoints for model averaging.""",
+    )
+
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=15,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
+    )
+
+    parser.add_argument(
+        "--method",
+        type=str,
+        default="attention-decoder",
+        help="""Decoding method.
+        Supported values are:
+            - (0) ctc-decoding. Use CTC decoding. It uses a sentence piece
+              model, i.e., lang_dir/bpe.model, to convert word pieces to words.
+              It needs neither a lexicon nor an n-gram LM.
+            - (1) ctc-greedy-search. It only use CTC output and a sentence piece
+              model for decoding. It produces the same results with ctc-decoding.
+            - (2) 1best. Extract the best path from the decoding lattice as the
+              decoding result.
+            - (3) nbest. Extract n paths from the decoding lattice; the path
+              with the highest score is the decoding result.
+            - (4) nbest-rescoring. Extract n paths from the decoding lattice,
+              rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
+              the highest score is the decoding result.
+            - (5) whole-lattice-rescoring. Rescore the decoding lattice with an
+              n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
+              is the decoding result.
+            - (6) attention-decoder. Extract n paths from the LM rescored
+              lattice, the path with the highest score is the decoding result.
+            - (7) nbest-oracle. Its WER is the lower bound of any n-best
+              rescoring method can achieve. Useful for debugging n-best
+              rescoring method.
+        """,
+    )
+
+    parser.add_argument(
+        "--use-averaged-model",
+        type=str2bool,
+        default=True,
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
+    )
+
+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=100,
+        help="""Number of paths for n-best based decoding method.
+        Used only when "method" is one of the following values:
+        nbest, nbest-rescoring, attention-decoder, and nbest-oracle
+        """,
+    )
+
+    parser.add_argument(
+        "--nbest-scale",
+        type=float,
+        default=0.5,
+        help="""The scale to be applied to `lattice.scores`.
+        It's needed if you use any kinds of n-best based rescoring.
+        Used only when "method" is one of the following values:
+        nbest, nbest-rescoring, attention-decoder, and nbest-oracle
+        A smaller value results in more unique paths.
+        """,
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="conformer_ctc2/exp",
+        help="The experiment dir",
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_bpe_500",
+        help="The lang dir",
+    )
+
+    parser.add_argument(
+        "--lm-path",
+        type=str,
+        default="data/lm/G_4_gram.pt",
+        help="""The n-gram LM dir for rescoring.
+        It should contain either lm_fname.pt or lm_fname.fst.txt
+        """,
+    )
+
+    parser.add_argument(
+        "--result-dir",
+        type=str,
+        default="conformer_ctc2/exp",
+        help="Directory to store results.",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def get_params() -> AttributeDict:
+    """Return a dict containing training parameters.
+
+    All training related parameters that are not passed from the commandline
+    are saved in the variable `params`.
+
+    Commandline options are merged into `params` after they are parsed, so
+    you can also access them via `params`.
+
+    Explanation of options saved in `params`:
+
+        - feature_dim: The model input dim. It has to match the one used
+                       in computing features.
+
+        - subsampling_factor:  The subsampling factor for the model.
+    """
+    params = AttributeDict(
+        {
+            # parameters for conformer
+            "subsampling_factor": 4,
+            "feature_dim": 80,
+            # parameters for decoding
+            "search_beam": 15,
+            "output_beam": 8,
+            "min_active_states": 10,
+            "max_active_states": 7000,
+            "use_double_scores": True,
+            "env_info": get_env_info(),
+        }
+    )
+    return params
+
+
+def ctc_greedy_search(
+    ctc_probs: torch.Tensor,
+    mask: torch.Tensor,
+) -> List[List[int]]:
+    """Apply CTC greedy search
+    Args:
+      ctc_probs (torch.Tensor): (batch, max_len, num_bpe)
+      mask (torch.Tensor): (batch, max_len)
+    Returns:
+      best path result
+    """
+
+    _, max_index = ctc_probs.max(2)  # (B, maxlen)
+    max_index = max_index.masked_fill_(mask, 0)  # (B, maxlen)
+
+    ret_hyps = []
+    for hyp in max_index:
+        hyp = torch.unique_consecutive(hyp)
+        hyp = hyp[hyp > 0].tolist()
+        ret_hyps.append(hyp)
+    return ret_hyps
+
+
+def decode_one_batch(
+    params: AttributeDict,
+    model: nn.Module,
+    HLG: Optional[k2.Fsa],
+    H: Optional[k2.Fsa],
+    bpe_model: Optional[spm.SentencePieceProcessor],
+    batch: dict,
+    word_table: k2.SymbolTable,
+    sos_id: int,
+    eos_id: int,
+    G: Optional[k2.Fsa] = None,
+) -> Dict[str, List[List[str]]]:
+    """Decode one batch and return the result in a dict. The dict has the
+    following format:
+
+        - key: It indicates the setting used for decoding. For example,
+               if no rescoring is used, the key is the string `no_rescore`.
+               If LM rescoring is used, the key is the string `lm_scale_xxx`,
+               where `xxx` is the value of `lm_scale`. An example key is
+               `lm_scale_0.7`
+        - value: It contains the decoding result. `len(value)` equals to
+                 batch size. `value[i]` is the decoding result for the i-th
+                 utterance in the given batch.
+    Args:
+      params:
+        It's the return value of :func:`get_params`.
+
+        - params.method is "1best", it uses 1best decoding without LM rescoring.
+        - params.method is "nbest", it uses nbest decoding without LM rescoring.
+        - params.method is "nbest-rescoring", it uses nbest LM rescoring.
+        - params.method is "whole-lattice-rescoring", it uses whole lattice LM
+          rescoring.
+
+      model:
+        The neural model.
+      HLG:
+        The decoding graph. Used only when params.method is NOT ctc-decoding.
+      H:
+        The ctc topo. Used only when params.method is ctc-decoding.
+      bpe_model:
+        The BPE model. Used only when params.method is ctc-decoding.
+      batch:
+        It is the return value from iterating
+        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+        for the format of the `batch`.
+      word_table:
+        The word symbol table.
+      sos_id:
+        The token ID of the SOS.
+      eos_id:
+        The token ID of the EOS.
+      G:
+        An LM. It is not None when params.method is "nbest-rescoring"
+        or "whole-lattice-rescoring". In general, the G in HLG
+        is a 3-gram LM, while this G is a 4-gram LM.
+    Returns:
+      Return the decoding result. See above description for the format of
+      the returned dict. Note: If it decodes to nothing, then return None.
+    """
+    if HLG is not None:
+        device = HLG.device
+    else:
+        device = H.device
+    feature = batch["inputs"]
+    assert feature.ndim == 3
+    feature = feature.to(device)
+    # at entry, feature is (N, T, C)
+
+    supervisions = batch["supervisions"]
+
+    nnet_output, memory, memory_key_padding_mask = model(feature, supervisions)
+    # nnet_output is (N, T, C)
+
+    supervision_segments = torch.stack(
+        (
+            supervisions["sequence_idx"],
+            torch.div(
+                supervisions["start_frame"],
+                params.subsampling_factor,
+                rounding_mode="floor",
+            ),
+            torch.div(
+                supervisions["num_frames"],
+                params.subsampling_factor,
+                rounding_mode="floor",
+            ),
+        ),
+        1,
+    ).to(torch.int32)
+
+    if H is None:
+        assert HLG is not None
+        decoding_graph = HLG
+    else:
+        assert HLG is None
+        assert bpe_model is not None
+        decoding_graph = H
+
+    lattice = get_lattice(
+        nnet_output=nnet_output,
+        decoding_graph=decoding_graph,
+        supervision_segments=supervision_segments,
+        search_beam=params.search_beam,
+        output_beam=params.output_beam,
+        min_active_states=params.min_active_states,
+        max_active_states=params.max_active_states,
+        subsampling_factor=params.subsampling_factor,
+    )
+
+    if params.method == "ctc-decoding":
+        best_path = one_best_decoding(
+            lattice=lattice, use_double_scores=params.use_double_scores
+        )
+        # Note: `best_path.aux_labels` contains token IDs, not word IDs
+        # since we are using H, not HLG here.
+        #
+        # token_ids is a lit-of-list of IDs
+        token_ids = get_texts(best_path)
+
+        # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
+        hyps = bpe_model.decode(token_ids)
+
+        # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
+        unk = bpe_model.decode(bpe_model.unk_id()).strip()
+        hyps = [[w for w in s.split() if w != unk] for s in hyps]
+        key = "ctc-decoding"
+
+        return {key: hyps}
+
+    if params.method == "ctc-greedy-search":
+        hyps = ctc_greedy_search(nnet_output, memory_key_padding_mask)
+
+        # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
+        hyps = bpe_model.decode(hyps)
+
+        # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
+        unk = bpe_model.decode(bpe_model.unk_id()).strip()
+        hyps = [[w for w in s.split() if w != unk] for s in hyps]
+        key = "ctc-greedy-search"
+
+        return {key: hyps}
+
+    if params.method == "nbest-oracle":
+        # Note: You can also pass rescored lattices to it.
+        # We choose the HLG decoded lattice for speed reasons
+        # as HLG decoding is faster and the oracle WER
+        # is only slightly worse than that of rescored lattices.
+        best_path = nbest_oracle(
+            lattice=lattice,
+            num_paths=params.num_paths,
+            ref_texts=supervisions["text"],
+            word_table=word_table,
+            nbest_scale=params.nbest_scale,
+            oov="",
+        )
+        hyps = get_texts(best_path)
+        hyps = [
+            [word_table[i] for i in ids if word_table[i] != ""] for ids in hyps
+        ]
+        key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}"  # noqa
+        return {key: hyps}
+
+    if params.method == "nbest":
+        best_path = nbest_decoding(
+            lattice=lattice,
+            num_paths=params.num_paths,
+            use_double_scores=params.use_double_scores,
+            nbest_scale=params.nbest_scale,
+        )
+        key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}"  # noqa
+
+        hyps = get_texts(best_path)
+        hyps = [
+            [word_table[i] for i in ids if word_table[i] != ""] for ids in hyps
+        ]
+        return {key: hyps}
+
+    assert params.method in [
+        "1best",
+        "nbest-rescoring",
+        "whole-lattice-rescoring",
+        "attention-decoder",
+    ]
+
+    lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
+    lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
+    lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
+
+    if params.method == "1best":
+        best_path_dict = one_best_decoding(
+            lattice=lattice,
+            lm_scale_list=lm_scale_list,
+        )
+    elif params.method == "nbest-rescoring":
+        best_path_dict = rescore_with_n_best_list(
+            lattice=lattice,
+            G=G,
+            num_paths=params.num_paths,
+            lm_scale_list=lm_scale_list,
+            nbest_scale=params.nbest_scale,
+        )
+    elif params.method == "whole-lattice-rescoring":
+        best_path_dict = rescore_with_whole_lattice(
+            lattice=lattice,
+            G_with_epsilon_loops=G,
+            lm_scale_list=lm_scale_list,
+        )
+    elif params.method == "attention-decoder":
+        best_path_dict = rescore_with_attention_decoder(
+            lattice=lattice,
+            num_paths=params.num_paths,
+            model=model,
+            memory=memory,
+            memory_key_padding_mask=memory_key_padding_mask,
+            sos_id=sos_id,
+            eos_id=eos_id,
+            nbest_scale=params.nbest_scale,
+        )
+    else:
+        raise ValueError(f"Unsupported decoding method: {params.method}")
+
+    ans = dict()
+    if best_path_dict is not None:
+        for lm_scale_str, best_path in best_path_dict.items():
+            hyps = get_texts(best_path)
+            hyps = [
+                [word_table[i] for i in ids if word_table[i] != ""] for ids in hyps
+            ]
+            ans[lm_scale_str] = hyps
+    else:
+        ans = None
+    return ans
+
+
+def decode_dataset(
+    dl: torch.utils.data.DataLoader,
+    params: AttributeDict,
+    model: nn.Module,
+    HLG: Optional[k2.Fsa],
+    H: Optional[k2.Fsa],
+    bpe_model: Optional[spm.SentencePieceProcessor],
+    word_table: k2.SymbolTable,
+    sos_id: int,
+    eos_id: int,
+    G: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+    """Decode dataset.
+
+    Args:
+      dl:
+        PyTorch's dataloader containing the dataset to decode.
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The neural model.
+      HLG:
+        The decoding graph. Used only when params.method is NOT ctc-decoding.
+      H:
+        The ctc topo. Used only when params.method is ctc-decoding.
+      bpe_model:
+        The BPE model. Used only when params.method is ctc-decoding.
+      word_table:
+        It is the word symbol table.
+      sos_id:
+        The token ID for SOS.
+      eos_id:
+        The token ID for EOS.
+      G:
+        An LM. It is not None when params.method is "nbest-rescoring"
+        or "whole-lattice-rescoring". In general, the G in HLG
+        is a 3-gram LM, while this G is a 4-gram LM.
+    Returns:
+      Return a dict, whose key may be "no-rescore" if no LM rescoring
+      is used, or it may be "lm_scale_0.7" if LM rescoring is used.
+      Its value is a list of tuples. Each tuple contains two elements:
+      The first is the reference transcript, and the second is the
+      predicted result.
+    """
+    num_cuts = 0
+
+    try:
+        num_batches = len(dl)
+    except TypeError:
+        num_batches = "?"
+
+    results = defaultdict(list)
+    for batch_idx, batch in enumerate(dl):
+        texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+        hyps_dict = decode_one_batch(
+            params=params,
+            model=model,
+            HLG=HLG,
+            H=H,
+            bpe_model=bpe_model,
+            batch=batch,
+            word_table=word_table,
+            G=G,
+            sos_id=sos_id,
+            eos_id=eos_id,
+        )
+
+        if hyps_dict is not None:
+            for lm_scale, hyps in hyps_dict.items():
+                this_batch = []
+                assert len(hyps) == len(texts)
+                for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+                    ref_words = ref_text.split()
+                    this_batch.append((cut_id, ref_words, hyp_words))
+
+                results[lm_scale].extend(this_batch)
+        else:
+            assert len(results) > 0, "It should not decode to empty in the first batch!"
+            this_batch = []
+            hyp_words = []
+            for ref_text in texts:
+                ref_words = ref_text.split()
+                this_batch.append((ref_words, hyp_words))
+
+            for lm_scale in results.keys():
+                results[lm_scale].extend(this_batch)
+
+        num_cuts += len(texts)
+
+        if batch_idx % 100 == 0:
+            batch_str = f"{batch_idx}/{num_batches}"
+
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+    return results
+
+
+def save_results(
+    params: AttributeDict,
+    test_set_name: str,
+    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+) -> None:
+    if params.method == "attention-decoder":
+        # Set it to False since there are too many logs.
+        enable_log = False
+    else:
+        enable_log = True
+    test_set_wers = dict()
+    for key, results in results_dict.items():
+        recog_path = params.result_dir / f"recogs-{test_set_name}-{key}.txt"
+        results = sorted(results)
+        store_transcripts(filename=recog_path, texts=results)
+        if enable_log:
+            logging.info(f"The transcripts are stored in {recog_path}")
+
+        # The following prints out WERs, per-word error statistics and aligned
+        # ref/hyp pairs.
+        errs_filename = params.result_dir / f"errs-{test_set_name}-{key}.txt"
+        with open(errs_filename, "w") as f:
+            wer = write_error_stats(
+                f, f"{test_set_name}-{key}", results, enable_log=enable_log
+            )
+            test_set_wers[key] = wer
+
+        if enable_log:
+            logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+    errs_info = params.result_dir / f"wer-summary-{test_set_name}.txt"
+    with open(errs_info, "w") as f:
+        print("settings\tWER", file=f)
+        for key, val in test_set_wers:
+            print("{}\t{}".format(key, val), file=f)
+
+    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+    note = "\tbest for {}".format(test_set_name)
+    for key, val in test_set_wers:
+        s += "{}\t{}{}\n".format(key, val, note)
+        note = ""
+    logging.info(s)
+
+
+@torch.no_grad()
+def main() -> None:
+    parser = get_parser()
+    TedLiumAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+    args.lang_dir = Path(args.lang_dir)
+    args.lm_path = Path(args.lm_path)
+    args.result_dir = Path(args.result_dir)
+
+    if args.result_dir.is_dir():
+        shutil.rmtree(args.result_dir)
+    args.result_dir.mkdir()
+
+    params = get_params()
+    params.update(vars(args))
+
+    setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode")
+    logging.info("Decoding started")
+    logging.info(params)
+
+    lexicon = Lexicon(params.lang_dir)
+    max_token_id = max(lexicon.tokens)
+    num_classes = max_token_id + 1  # +1 for the blank
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    graph_compiler = BpeCtcTrainingGraphCompiler(
+        params.lang_dir,
+        device=device,
+        sos_token="",
+        eos_token="",
+    )
+    sos_id = graph_compiler.sos_id
+    eos_id = graph_compiler.eos_id
+
+    if params.method in ("ctc-decoding", "ctc-greedy-search"):
+        HLG = None
+        H = k2.ctc_topo(
+            max_token=max_token_id,
+            modified=False,
+            device=device,
+        )
+        bpe_model = spm.SentencePieceProcessor()
+        bpe_model.load(str(params.lang_dir / "bpe.model"))
+    else:
+        H = None
+        bpe_model = None
+        HLG = k2.Fsa.from_dict(
+            torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
+        )
+        assert HLG.requires_grad is False
+
+        if not hasattr(HLG, "lm_scores"):
+            HLG.lm_scores = HLG.scores.clone()
+
+    if params.method in ("nbest-rescoring", "whole-lattice-rescoring"):
+        assert params.lm_path.suffix in (".pt", ".txt")
+
+        if params.lm_path.is_file() and params.lm_path.suffix == ".pt":
+            logging.info(f"Loading pre-compiled {params.lm_path.name}")
+            d = torch.load(params.lm_path, map_location=device)
+            G = k2.Fsa.from_dict(d)
+        elif not params.lm_path.is_file() and params.lm_path.suffix == ".txt":
+            raise FileNotFoundError(f"No such language model file: '{params.lm_path}'")
+        else:
+            # here we pass only if LM filename ends with '.pt' and doesn't exist
+            # or if LM filename ends '.txt' and exists.
+            if (
+                not params.lm_path.is_file()
+                and params.lm_path.suffix == ".pt"
+                and not (
+                    params.lm_path.parent / f"{params.lm_path.stem}.fst.txt"
+                ).is_file()
+            ):
+                raise FileNotFoundError(
+                    f"No such language model file: '{params.lm_path}'\n"
+                    "'.fst.txt' representation of the language model was "
+                    "not found either."
+                )
+            else:
+                # whatever params.lm_path.name we got lm_name.pt or lm_name.fst.txt
+                # we are going to load lm_name.fst.txt here
+                params.lm_path = params.lm_path.parent / params.lm_path.name.replace(
+                    ".pt", ".fst.txt"
+                )
+                logging.info(f"Loading {params.lm_path.name}")
+                logging.warning("It may take 8 minutes.")
+                with open(params.lm_path) as f:
+                    first_word_disambig_id = lexicon.word_table["#0"]
+
+                    G = k2.Fsa.from_openfst(f.read(), acceptor=False)
+                    # G.aux_labels is not needed in later computations, so
+                    # remove it here.
+                    del G.aux_labels
+                    # CAUTION: The following line is crucial.
+                    # Arcs entering the back-off state have label equal to #0.
+                    # We have to change it to 0 here.
+                    G.labels[G.labels >= first_word_disambig_id] = 0
+                    # See https://github.com/k2-fsa/k2/issues/874
+                    # for why we need to set G.properties to None
+                    G.__dict__["_properties"] = None
+                    G = k2.Fsa.from_fsas([G]).to(device)
+                    G = k2.arc_sort(G)
+                    # Save a dummy value so that it can be loaded in C++.
+                    # See https://github.com/pytorch/pytorch/issues/67902
+                    # for why we need to do this.
+                    G.dummy = 1
+
+                    torch.save(
+                        G.as_dict(),
+                        params.lm_path.parent
+                        / params.lm_path.name.replace(".fst.txt", ".pt"),
+                    )
+
+        if params.method == "whole-lattice-rescoring":
+            # Add epsilon self-loops to G as we will compose
+            # it with the whole lattice later
+            G = k2.add_epsilon_self_loops(G)
+            G = k2.arc_sort(G)
+            G = G.to(device)
+
+        # G.lm_scores is used to replace HLG.lm_scores during
+        # LM rescoring.
+        G.lm_scores = G.scores.clone()
+    else:
+        G = None
+
+    model = Conformer(
+        num_features=params.feature_dim,
+        num_classes=num_classes,
+        subsampling_factor=params.subsampling_factor,
+        d_model=params.dim_model,
+        nhead=params.nhead,
+        dim_feedforward=params.dim_feedforward,
+        num_encoder_layers=params.num_encoder_layers,
+        num_decoder_layers=params.num_decoder_layers,
+    )
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+
+    model.to(device)
+    model.eval()
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
+    tedlium = TedLiumAsrDataModule(args)
+
+    valid_cuts = tedlium.dev_cuts()
+    test_cuts = tedlium.test_cuts()
+
+    valid_dl = tedlium.valid_dataloaders(valid_cuts)
+    test_dl = tedlium.test_dataloaders(test_cuts)
+
+    test_sets = ["dev", "test"]
+    test_dls = [valid_dl, test_dl]
+
+    for test_set, test_dl in zip(test_sets, test_dls):
+        results_dict = decode_dataset(
+            dl=test_dl,
+            params=params,
+            model=model,
+            HLG=HLG,
+            H=H,
+            bpe_model=bpe_model,
+            word_table=lexicon.word_table,
+            G=G,
+            sos_id=sos_id,
+            eos_id=eos_id,
+        )
+
+        save_results(params=params, test_set_name=test_set, results_dict=results_dict)
+
+    logging.info("Done!")
+
+
+torch.set_num_threads(1)
+# when we import add_model_arguments from train.py
+# we enforce torch.set_num_interop_threads(1) in it,
+# so we ended up with setting num_interop_threads to one
+# two times: in train.py and decode.py which cause an error,
+# that is why added an additional if statement.
+if torch.get_num_interop_threads() != 1:
+    torch.set_num_interop_threads(1)
+
+# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
+# in PyTorch 1.12 and later.
+torch.backends.cuda.matmul.allow_tf32 = True
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/tedlium3/ASR/conformer_ctc2/export.py b/egs/tedlium3/ASR/conformer_ctc2/export.py
new file mode 100755
index 000000000..009bea230
--- /dev/null
+++ b/egs/tedlium3/ASR/conformer_ctc2/export.py
@@ -0,0 +1,294 @@
+#!/usr/bin/env python3
+#
+# Copyright 2022 Behavox LLC (Author: Daniil Kulko)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This script converts several saved checkpoints
+# to a single one using model averaging.
+"""
+Usage:
+./conformer_ctc2/export.py \
+  --exp-dir ./conformer_ctc2/exp \
+  --epoch 20 \
+  --avg 10
+
+It will generate a file exp_dir/pretrained.pt
+
+To use the generated file with `conformer_ctc2/decode.py`,
+you can do:
+
+    cd /path/to/exp_dir
+    ln -s pretrained.pt epoch-9999.pt
+
+    cd /path/to/egs/tedlium3/ASR
+    ./conformer_ctc2/decode.py \
+        --exp-dir ./conformer_ctc2/exp \
+        --epoch 9999 \
+        --avg 1 \
+        --max-duration 100
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+import torch
+from conformer import Conformer
+from scaling_converter import convert_scaled_to_non_scaled
+from train import add_model_arguments
+
+from icefall.checkpoint import (
+    average_checkpoints,
+    average_checkpoints_with_averaged_model,
+    find_checkpoints,
+    load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import AttributeDict, str2bool
+
+
+def get_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=30,
+        help="""It specifies the checkpoint to use for averaging.
+        Note: Epoch counts from 0.
+        You can specify --avg to use more checkpoints for model averaging.""",
+    )
+
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=15,
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch' and '--iter'"
+        ),
+    )
+
+    parser.add_argument(
+        "--use-averaged-model",
+        type=str2bool,
+        default=True,
+        help=(
+            "Whether to load averaged model. Currently it only supports "
+            "using --epoch. If True, it would decode with the averaged model "
+            "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+            "Actually only the models with epoch number of `epoch-avg` and "
+            "`epoch` are loaded for averaging. "
+        ),
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="conformer_ctc2/exp",
+        help="""It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_bpe_500",
+        help="The lang dir",
+    )
+
+    parser.add_argument(
+        "--jit",
+        type=str2bool,
+        default=True,
+        help="""True to save a model after applying torch.jit.script.
+        """,
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def get_params() -> AttributeDict:
+    """Return a dict containing training parameters.
+
+    All training related parameters that are not passed from the commandline
+    are saved in the variable `params`.
+
+    Commandline options are merged into `params` after they are parsed, so
+    you can also access them via `params`.
+
+    Explanation of options saved in `params`:
+
+        - feature_dim: The model input dim. It has to match the one used
+                       in computing features.
+
+        - subsampling_factor:  The subsampling factor for the model.
+    """
+    # parameters for conformer
+    params = AttributeDict({"subsampling_factor": 4, "feature_dim": 80})
+    return params
+
+
+def main():
+    args = get_parser().parse_args()
+    args.exp_dir = Path(args.exp_dir)
+    args.lang_dir = Path(args.lang_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    lexicon = Lexicon(params.lang_dir)
+    max_token_id = max(lexicon.tokens)
+    num_classes = max_token_id + 1  # +1 for the blank
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    logging.info(params)
+
+    logging.info("About to create model")
+
+    model = Conformer(
+        num_features=params.feature_dim,
+        num_classes=num_classes,
+        subsampling_factor=params.subsampling_factor,
+        d_model=params.dim_model,
+        nhead=params.nhead,
+        dim_feedforward=params.dim_feedforward,
+        num_encoder_layers=params.num_encoder_layers,
+        num_decoder_layers=params.num_decoder_layers,
+    )
+
+    model.to(device)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                "Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+
+    model.to("cpu")
+    model.eval()
+
+    if params.jit:
+        convert_scaled_to_non_scaled(model, inplace=True)
+        logging.info("Using torch.jit.script")
+        model = torch.jit.script(model)
+        filename = params.exp_dir / "cpu_jit.pt"
+        model.save(str(filename))
+        logging.info(f"Saved to {filename}")
+    else:
+        logging.info("Not using torch.jit.script")
+        # Save it using a format so that it can be loaded
+        # by :func:`load_checkpoint`
+        filename = params.exp_dir / "pretrained.pt"
+        torch.save({"model": model.state_dict()}, str(filename))
+        logging.info(f"Saved to {filename}")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/tedlium3/ASR/conformer_ctc2/label_smoothing.py b/egs/tedlium3/ASR/conformer_ctc2/label_smoothing.py
new file mode 120000
index 000000000..e9d239fff
--- /dev/null
+++ b/egs/tedlium3/ASR/conformer_ctc2/label_smoothing.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/conformer_ctc/label_smoothing.py
\ No newline at end of file
diff --git a/egs/tedlium3/ASR/conformer_ctc2/lstmp.py b/egs/tedlium3/ASR/conformer_ctc2/lstmp.py
new file mode 120000
index 000000000..b82e115fc
--- /dev/null
+++ b/egs/tedlium3/ASR/conformer_ctc2/lstmp.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/lstm_transducer_stateless2/lstmp.py
\ No newline at end of file
diff --git a/egs/tedlium3/ASR/conformer_ctc2/optim.py b/egs/tedlium3/ASR/conformer_ctc2/optim.py
new file mode 120000
index 000000000..0a2f285aa
--- /dev/null
+++ b/egs/tedlium3/ASR/conformer_ctc2/optim.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/optim.py
\ No newline at end of file
diff --git a/egs/tedlium3/ASR/conformer_ctc2/scaling.py b/egs/tedlium3/ASR/conformer_ctc2/scaling.py
new file mode 120000
index 000000000..c10cdfe12
--- /dev/null
+++ b/egs/tedlium3/ASR/conformer_ctc2/scaling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/scaling.py
\ No newline at end of file
diff --git a/egs/tedlium3/ASR/conformer_ctc2/scaling_converter.py b/egs/tedlium3/ASR/conformer_ctc2/scaling_converter.py
new file mode 120000
index 000000000..db93d155b
--- /dev/null
+++ b/egs/tedlium3/ASR/conformer_ctc2/scaling_converter.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py
\ No newline at end of file
diff --git a/egs/tedlium3/ASR/conformer_ctc2/subsampling.py b/egs/tedlium3/ASR/conformer_ctc2/subsampling.py
new file mode 120000
index 000000000..8c91f2336
--- /dev/null
+++ b/egs/tedlium3/ASR/conformer_ctc2/subsampling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/conformer_ctc2/subsampling.py
\ No newline at end of file
diff --git a/egs/tedlium3/ASR/conformer_ctc2/train.py b/egs/tedlium3/ASR/conformer_ctc2/train.py
new file mode 100755
index 000000000..42e4c010a
--- /dev/null
+++ b/egs/tedlium3/ASR/conformer_ctc2/train.py
@@ -0,0 +1,1061 @@
+#!/usr/bin/env python3
+# Copyright    2022  Behavox LLC.        (authors: Daniil Kulko)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./conformer_ctc/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --exp-dir conformer_ctc/exp \
+  --max-duration 300
+
+# For mix precision training:
+
+./conformer_ctc/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --use-fp16 1 \
+  --exp-dir conformer_ctc/exp \
+  --max-duration 550
+
+"""
+
+
+import argparse
+import copy
+import logging
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+from asr_datamodule import TedLiumAsrDataModule
+from conformer import Conformer
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from local.convert_transcript_words_to_bpe_ids import convert_texts_into_ids
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+
+from icefall import diagnostics
+from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+    save_checkpoint_with_global_batch_idx,
+    update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    display_and_save_batch,
+    encode_supervisions,
+    setup_logger,
+    str2bool,
+)
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def add_model_arguments(parser: argparse.ArgumentParser) -> None:
+    parser.add_argument(
+        "--num-encoder-layers",
+        type=int,
+        default=24,
+        help="Number of conformer encoder layers..",
+    )
+
+    parser.add_argument(
+        "--num-decoder-layers",
+        type=int,
+        default=6,
+        help="""Number of decoder layer of transformer decoder.
+        Setting this to 0 will not create the decoder at all (pure CTC model)
+        """,
+    )
+
+    parser.add_argument(
+        "--att-rate",
+        type=float,
+        default=0.8,
+        help="""The attention rate.
+        The total loss is (1 -  att_rate) * ctc_loss + att_rate * att_loss
+        """,
+    )
+
+    parser.add_argument(
+        "--dim-feedforward",
+        type=int,
+        default=1536,
+        help="Feedforward module dimension of the conformer model.",
+    )
+
+    parser.add_argument(
+        "--nhead",
+        type=int,
+        default=8,
+        help="Number of attention heads in the conformer multiheadattention modules.",
+    )
+
+    parser.add_argument(
+        "--dim-model",
+        type=int,
+        default=384,
+        help="Attention dimension in the conformer model.",
+    )
+
+
+def get_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--world-size",
+        type=int,
+        default=1,
+        help="Number of GPUs for DDP training.",
+    )
+
+    parser.add_argument(
+        "--master-port",
+        type=int,
+        default=12354,
+        help="Master port to use for DDP training.",
+    )
+
+    parser.add_argument(
+        "--tensorboard",
+        type=str2bool,
+        default=True,
+        help="Should various information be logged in tensorboard.",
+    )
+
+    parser.add_argument(
+        "--num-epochs",
+        type=int,
+        default=30,
+        help="Number of epochs to train.",
+    )
+
+    parser.add_argument(
+        "--start-epoch",
+        type=int,
+        default=1,
+        help="""Resume training from this epoch. It should be positive.
+        If larger than 1, it will load checkpoint from
+        exp-dir/epoch-{start_epoch-1}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--start-batch",
+        type=int,
+        default=0,
+        help="""If positive, --start-epoch is ignored and
+        it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="conformer_ctc/exp",
+        help="""The experiment dir.
+        It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_bpe_500",
+        help="""The lang dir
+        It contains language related input files such as
+        "lexicon.txt" and "bpe.model"
+        """,
+    )
+
+    parser.add_argument(
+        "--initial-lr",
+        type=float,
+        default=0.003,
+        help="The initial learning rate.  This value should not need to be changed.",
+    )
+
+    parser.add_argument(
+        "--lr-batches",
+        type=float,
+        default=5000,
+        help="""Number of steps that affects how rapidly the learning rate
+        decreases. We suggest not to change this.""",
+    )
+
+    parser.add_argument(
+        "--lr-epochs",
+        type=float,
+        default=6,
+        help="Number of epochs that affects how rapidly the learning rate decreases.",
+    )
+
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="The seed for random generators intended for reproducibility",
+    )
+
+    parser.add_argument(
+        "--print-diagnostics",
+        type=str2bool,
+        default=False,
+        help="Accumulate stats on activations, print them and exit.",
+    )
+
+    parser.add_argument(
+        "--save-every-n",
+        type=int,
+        default=4000,
+        help="""Save checkpoint after processing this number of batches"
+        periodically. We save checkpoint to exp-dir/ whenever
+        params.batch_idx_train % save_every_n == 0. The checkpoint filename
+        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+        Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+        end of each epoch where `xxx` is the epoch number counting from 0.
+        """,
+    )
+
+    parser.add_argument(
+        "--keep-last-k",
+        type=int,
+        default=30,
+        help="""Only keep this number of checkpoints on disk.
+        For instance, if it is 3, there are only 3 checkpoints
+        in the exp-dir with filenames `checkpoint-xxx.pt`.
+        It does not affect checkpoints with name `epoch-xxx.pt`.
+        """,
+    )
+
+    parser.add_argument(
+        "--average-period",
+        type=int,
+        default=100,
+        help="""Update the averaged model, namely `model_avg`, after processing
+        this number of batches. `model_avg` is a separate version of model,
+        in which each floating-point parameter is the average of all the
+        parameters from the start of training. Each time we take the average,
+        we do: `model_avg = model * (average_period / batch_idx_train) +
+            model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+        """,
+    )
+
+    parser.add_argument(
+        "--use-fp16",
+        type=str2bool,
+        default=False,
+        help="Whether to use half precision training.",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def get_params() -> AttributeDict:
+    """Return a dict containing training parameters.
+
+    All training related parameters that are not passed from the commandline
+    are saved in the variable `params`.
+
+    Commandline options are merged into `params` after they are parsed, so
+    you can also access them via `params`.
+
+    Explanation of options saved in `params`:
+
+        - best_train_loss: Best training loss so far. It is used to select
+                           the model that has the lowest training loss. It is
+                           updated during the training.
+
+        - best_valid_loss: Best validation loss so far. It is used to select
+                           the model that has the lowest validation loss. It is
+                           updated during the training.
+
+        - best_train_epoch: It is the epoch that has the best training loss.
+
+        - best_valid_epoch: It is the epoch that has the best validation loss.
+
+        - batch_idx_train: Used to writing statistics to tensorboard. It
+                           contains number of batches trained so far across
+                           epochs.
+
+        - log_interval:  Print training loss if batch_idx % log_interval` is 0
+
+        - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+        - valid_interval:  Run validation if batch_idx % valid_interval is 0
+
+        - feature_dim: The model input dim. It has to match the one used
+                       in computing features.
+
+        - subsampling_factor:  The subsampling factor for the model.
+
+        - warm_step: The warm_step for Noam optimizer.
+    """
+    params = AttributeDict(
+        {
+            "best_train_loss": float("inf"),
+            "best_valid_loss": float("inf"),
+            "best_train_epoch": -1,
+            "best_valid_epoch": -1,
+            "batch_idx_train": 0,
+            "log_interval": 10,
+            "reset_interval": 200,
+            "valid_interval": 1000,
+            # parameters for conformer
+            "feature_dim": 80,
+            "subsampling_factor": 4,
+            # parameters for ctc loss
+            "beam_size": 10,
+            "reduction": "none",
+            "use_double_scores": True,
+            # parameters for Noam
+            "model_warm_step": 3000,  # arg given to model, not for lrate
+            "env_info": get_env_info(),
+        }
+    )
+
+    return params
+
+
+def load_checkpoint_if_available(
+    params: AttributeDict,
+    model: torch.nn.Module,
+    model_avg: torch.nn.Module = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+    """Load checkpoint from file.
+
+    If params.start_batch is positive, it will load the checkpoint from
+    `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+    params.start_epoch is larger than 1, it will load the checkpoint from
+    `params.start_epoch - 1`.
+
+    Apart from loading state dict for `model` and `optimizer` it also updates
+    `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+    and `best_valid_loss` in `params`.
+
+    Args:
+      params:
+        The return value of :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer that we are using.
+      scheduler:
+        The scheduler that is used for training.
+    Returns:
+      Return a dict containing previously saved training info.
+    """
+    if params.start_batch > 0:
+        filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+    elif params.start_epoch > 1:
+        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+    else:
+        return None
+
+    assert filename.is_file(), f"{filename} does not exist!"
+
+    saved_params = load_checkpoint(
+        filename,
+        model=model,
+        model_avg=model_avg,
+        optimizer=optimizer,
+        scheduler=scheduler,
+    )
+
+    keys = [
+        "best_train_epoch",
+        "best_valid_epoch",
+        "batch_idx_train",
+        "best_train_loss",
+        "best_valid_loss",
+    ]
+    for k in keys:
+        params[k] = saved_params[k]
+
+    if params.start_batch > 0:
+        if "cur_epoch" in saved_params:
+            params["start_epoch"] = saved_params["cur_epoch"]
+
+    return saved_params
+
+
+def save_checkpoint(
+    params: AttributeDict,
+    model: Union[torch.nn.Module, DDP],
+    model_avg: Optional[torch.nn.Module] = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+    sampler: Optional[CutSampler] = None,
+    scaler: Optional[GradScaler] = None,
+    rank: int = 0,
+) -> None:
+    """Save model, optimizer, scheduler and training stats to file.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer used for training.
+      scheduler:
+        The learning rate scheduler used for training.
+      sampler:
+       The sampler for the training dataset.
+      scaler:
+        The scaler used for mix precision training.
+    """
+    if rank != 0:
+        return
+    filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+    save_checkpoint_impl(
+        filename=filename,
+        model=model,
+        model_avg=model_avg,
+        params=params,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        sampler=sampler,
+        scaler=scaler,
+        rank=rank,
+    )
+
+    if params.best_train_epoch == params.cur_epoch:
+        best_train_filename = params.exp_dir / "best-train-loss.pt"
+        copyfile(src=filename, dst=best_train_filename)
+
+    if params.best_valid_epoch == params.cur_epoch:
+        best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+        copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+    params: AttributeDict,
+    model: Union[torch.nn.Module, DDP],
+    graph_compiler: BpeCtcTrainingGraphCompiler,
+    batch: dict,
+    is_training: bool,
+    warmup: float = 1.0,
+) -> Tuple[Tensor, MetricsTracker]:
+    """
+    Compute CTC loss given the model and its inputs.
+    Args:
+      params:
+        Parameters for training. See :func:`get_params`.
+      model:
+        The model for training. It is an instance of Conformer in our case.
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      graph_compiler:
+        It is used to build a decoding graph from a ctc topo and training
+        transcript. The training transcript is contained in the given `batch`,
+        while the ctc topo is built when this compiler is instantiated.
+      is_training:
+        True for training. False for validation. When it is True, this
+        function enables autograd during computation; when it is False, it
+        disables autograd.
+     warmup: a floating point value which increases throughout training;
+        values >= 1.0 are fully warmed up and have all modules present.
+    """
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+    feature = batch["inputs"]
+    # at entry, feature is (N, T, C)
+    assert feature.ndim == 3
+    feature = feature.to(device)
+
+    supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)
+
+    with torch.set_grad_enabled(is_training):
+        nnet_output, encoder_memory, memory_mask = model(
+            feature, supervisions, warmup=warmup
+        )
+
+        supervision_segments, texts = encode_supervisions(
+            supervisions, subsampling_factor=params.subsampling_factor
+        )
+
+        token_ids = convert_texts_into_ids(texts, graph_compiler.sp)
+        decoding_graph = graph_compiler.compile(token_ids)
+
+        dense_fsa_vec = k2.DenseFsaVec(
+            nnet_output,
+            supervision_segments,
+            allow_truncate=params.subsampling_factor - 1,
+        )
+
+        ctc_loss = k2.ctc_loss(
+            decoding_graph=decoding_graph,
+            dense_fsa_vec=dense_fsa_vec,
+            output_beam=params.beam_size,
+            reduction=params.reduction,
+            use_double_scores=params.use_double_scores,
+        )
+
+        if params.att_rate > 0.0:
+            with torch.set_grad_enabled(is_training):
+                mmodel = model.module if hasattr(model, "module") else model
+                # Note: We need to generate an unsorted version of token_ids
+                # `encode_supervisions()` called above sorts text, but
+                # encoder_memory and memory_mask are not sorted, so we
+                # use an unsorted version `supervisions["text"]` to regenerate
+                # the token_ids
+                #
+                # See https://github.com/k2-fsa/icefall/issues/97
+                # for more details
+                unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"])
+                att_loss = mmodel.decoder_forward(
+                    encoder_memory,
+                    memory_mask,
+                    token_ids=unsorted_token_ids,
+                    sos_id=graph_compiler.sos_id,
+                    eos_id=graph_compiler.eos_id,
+                    warmup=warmup,
+                )
+        else:
+            att_loss = torch.tensor([0])
+
+        ctc_loss_is_finite = torch.isfinite(ctc_loss)
+        att_loss_is_finite = torch.isfinite(att_loss)
+        if torch.any(~ctc_loss_is_finite) or torch.any(~att_loss_is_finite):
+            logging.info(
+                "Not all losses are finite!\n"
+                f"ctc_loss: {ctc_loss}\n"
+                f"att_loss: {att_loss}"
+            )
+            display_and_save_batch(batch, params=params, sp=graph_compiler.sp)
+            ctc_loss = ctc_loss[ctc_loss_is_finite]
+            att_loss = att_loss[att_loss_is_finite]
+
+            # If the batch contains more than 10 utterances AND
+            # if either all ctc_loss or att_loss is inf or nan,
+            # we stop the training process by raising an exception
+            if torch.all(~ctc_loss_is_finite) or torch.all(~att_loss_is_finite):
+                raise ValueError(
+                    "There are too many utterances in this batch "
+                    "leading to inf or nan losses."
+                )
+
+        ctc_loss = ctc_loss.sum()
+        att_loss = att_loss.sum()
+
+        if params.att_rate > 0.0:
+            loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss
+        else:
+            loss = ctc_loss
+
+    assert loss.requires_grad == is_training
+
+    info = MetricsTracker()
+    # info["frames"] is an approximate number for two reasons:
+    # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2
+    # (2) If some utterances in the batch lead to inf/nan loss, they
+    #     are filtered out.
+    info["frames"] = (
+        torch.div(feature_lens, params.subsampling_factor, rounding_mode="floor")
+        .sum()
+        .item()
+    )
+
+    # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances`  # noqa
+    info["utterances"] = feature.size(0)
+    # averaged input duration in frames over utterances
+    info["utt_duration"] = feature_lens.sum().item()
+    # averaged padding proportion over utterances
+    info["utt_pad_proportion"] = (
+        ((feature.size(1) - feature_lens) / feature.size(1)).sum().item()
+    )
+
+    # Note: We use reduction=sum while computing the loss.
+    info["loss"] = loss.detach().cpu().item()
+    info["ctc_loss"] = ctc_loss.detach().cpu().item()
+    if params.att_rate > 0.0:
+        info["att_loss"] = att_loss.detach().cpu().item()
+
+    return loss, info
+
+
+def compute_validation_loss(
+    params: AttributeDict,
+    model: Union[torch.nn.Module, DDP],
+    graph_compiler: BpeCtcTrainingGraphCompiler,
+    valid_dl: torch.utils.data.DataLoader,
+    world_size: int = 1,
+) -> MetricsTracker:
+    """Run the validation process."""
+    model.eval()
+
+    tot_loss = MetricsTracker()
+
+    for batch in valid_dl:
+        loss, loss_info = compute_loss(
+            params=params,
+            model=model,
+            graph_compiler=graph_compiler,
+            batch=batch,
+            is_training=False,
+        )
+        assert loss.requires_grad is False
+        tot_loss = tot_loss + loss_info
+
+    if world_size > 1:
+        tot_loss.reduce(loss.device)
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    if loss_value < params.best_valid_loss:
+        params.best_valid_epoch = params.cur_epoch
+        params.best_valid_loss = loss_value
+
+    return tot_loss
+
+
+def train_one_epoch(
+    params: AttributeDict,
+    model: Union[torch.nn.Module, DDP],
+    optimizer: torch.optim.Optimizer,
+    scheduler: LRSchedulerType,
+    graph_compiler: BpeCtcTrainingGraphCompiler,
+    train_dl: torch.utils.data.DataLoader,
+    valid_dl: torch.utils.data.DataLoader,
+    scaler: GradScaler,
+    model_avg: Optional[torch.nn.Module] = None,
+    tb_writer: Optional[SummaryWriter] = None,
+    world_size: int = 1,
+    rank: int = 0,
+) -> None:
+    """Train the model for one epoch.
+
+    The training loss from the mean of all frames is saved in
+    `params.train_loss`. It runs the validation process every
+    `params.valid_interval` batches.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The model for training.
+      optimizer:
+        The optimizer we are using.
+      scheduler:
+        The learning rate scheduler, we call step() every step.
+      graph_compiler:
+        It is used to convert transcripts to FSAs.
+      train_dl:
+        Dataloader for the training dataset.
+      valid_dl:
+        Dataloader for the validation dataset.
+      scaler:
+        The scaler used for mix precision training.
+      model_avg:
+        The stored model averaged from the start of training.
+      tb_writer:
+        Writer to write log messages to tensorboard.
+      world_size:
+        Number of nodes in DDP training. If it is 1, DDP is disabled.
+      rank:
+        The rank of the node in DDP training. If no DDP is used, it should
+        be set to 0.
+    """
+    model.train()
+
+    tot_loss = MetricsTracker()
+
+    for batch_idx, batch in enumerate(train_dl):
+        params.batch_idx_train += 1
+        batch_size = len(batch["supervisions"]["text"])
+
+        try:
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                loss, loss_info = compute_loss(
+                    params=params,
+                    model=model,
+                    graph_compiler=graph_compiler,
+                    batch=batch,
+                    is_training=True,
+                    warmup=(params.batch_idx_train / params.model_warm_step),
+                )
+            # summary stats
+            tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+            # NOTE: We use reduction==sum and loss is computed over utterances
+            # in the batch and there is no normalization to it so far.
+            scaler.scale(loss).backward()
+            scheduler.step_batch(params.batch_idx_train)
+            scaler.step(optimizer)
+            scaler.update()
+            optimizer.zero_grad()
+        except:  # noqa
+            display_and_save_batch(batch, params=params, sp=graph_compiler.sp)
+            raise
+
+        if params.print_diagnostics and batch_idx == 5:
+            return
+
+        if (
+            rank == 0
+            and params.batch_idx_train > 0
+            and params.batch_idx_train % params.average_period == 0
+        ):
+            update_averaged_model(
+                params=params,
+                model_cur=model,
+                model_avg=model_avg,
+            )
+
+        if (
+            params.batch_idx_train > 0
+            and params.batch_idx_train % params.save_every_n == 0
+        ):
+            save_checkpoint_with_global_batch_idx(
+                out_dir=params.exp_dir,
+                global_batch_idx=params.batch_idx_train,
+                model=model,
+                model_avg=model_avg,
+                params=params,
+                optimizer=optimizer,
+                scheduler=scheduler,
+                sampler=train_dl.sampler,
+                scaler=scaler,
+                rank=rank,
+            )
+            remove_checkpoints(
+                out_dir=params.exp_dir,
+                topk=params.keep_last_k,
+                rank=rank,
+            )
+
+        if batch_idx % params.log_interval == 0:
+            cur_lr = scheduler.get_last_lr()[0]
+            logging.info(
+                f"Epoch {params.cur_epoch}, "
+                f"batch {batch_idx}, loss[{loss_info}], "
+                f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+                f"lr: {cur_lr:.2e}"
+            )
+
+            if tb_writer is not None:
+                tb_writer.add_scalar(
+                    "train/learning_rate", cur_lr, params.batch_idx_train
+                )
+
+                loss_info.write_summary(
+                    tb_writer, "train/current_", params.batch_idx_train
+                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+
+        if batch_idx > 0 and batch_idx % params.valid_interval == 0:
+            logging.info("Computing validation loss")
+            valid_info = compute_validation_loss(
+                params=params,
+                model=model,
+                graph_compiler=graph_compiler,
+                valid_dl=valid_dl,
+                world_size=world_size,
+            )
+            model.train()
+            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+            if tb_writer is not None:
+                valid_info.write_summary(
+                    tb_writer, "train/valid_", params.batch_idx_train
+                )
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    params.train_loss = loss_value
+    if params.train_loss < params.best_train_loss:
+        params.best_train_epoch = params.cur_epoch
+        params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+    """
+    Args:
+      rank:
+        It is a value between 0 and `world_size-1`, which is
+        passed automatically by `mp.spawn()` in :func:`main`.
+        The node with rank 0 is responsible for saving checkpoint.
+      world_size:
+        Number of GPUs for DDP training.
+      args:
+        The return value of get_parser().parse_args()
+    """
+    params = get_params()
+    params.update(vars(args))
+
+    fix_random_seed(params.seed)
+    if world_size > 1:
+        setup_dist(rank, world_size, params.master_port)
+
+    setup_logger(f"{params.exp_dir}/log/log-train")
+    logging.info("Training started")
+    logging.info(params)
+
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
+    lexicon = Lexicon(params.lang_dir)
+    max_token_id = max(lexicon.tokens)
+    num_classes = max_token_id + 1  # +1 for the blank
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+    logging.info(f"Device: {device}")
+
+    if "lang_bpe" not in str(params.lang_dir):
+        raise ValueError(
+            f"Unsupported type of lang dir (we expected it to have "
+            f"'lang_bpe' in its name): {params.lang_dir}"
+        )
+
+    graph_compiler = BpeCtcTrainingGraphCompiler(
+        params.lang_dir,
+        device=device,
+        sos_token="",
+        eos_token="",
+    )
+
+    logging.info("About to create model")
+    model = Conformer(
+        num_features=params.feature_dim,
+        num_classes=num_classes,
+        subsampling_factor=params.subsampling_factor,
+        d_model=params.dim_model,
+        nhead=params.nhead,
+        dim_feedforward=params.dim_feedforward,
+        num_encoder_layers=params.num_encoder_layers,
+        num_decoder_layers=params.num_decoder_layers,
+    )
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    assert params.save_every_n >= params.average_period
+    model_avg: Optional[torch.nn.Module] = None
+    if rank == 0:
+        # model_avg is only used with rank 0
+        model_avg = copy.deepcopy(model)
+
+    assert params.start_epoch > 0, params.start_epoch
+    checkpoints = load_checkpoint_if_available(
+        params=params, model=model, model_avg=model_avg
+    )
+
+    model.to(device)
+    if world_size > 1:
+        logging.info("Using DDP")
+        model = DDP(model, device_ids=[rank])
+
+    optimizer = optim.Eve(model.parameters(), lr=params.initial_lr)
+    scheduler = optim.Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+    if checkpoints and checkpoints.get("optimizer") is not None:
+        logging.info("Loading optimizer state dict")
+        optimizer.load_state_dict(checkpoints["optimizer"])
+
+    if checkpoints and checkpoints.get("scheduler") is not None:
+        logging.info("Loading scheduler state dict")
+        scheduler.load_state_dict(checkpoints["scheduler"])
+
+    if params.print_diagnostics:
+        opts = diagnostics.TensorDiagnosticOptions(
+            2**22
+        )  # allow 4 megabytes per sub-module
+        diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+    tedlium = TedLiumAsrDataModule(args)
+
+    train_cuts = tedlium.train_cuts()
+
+    if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+        # We only load the sampler's state dict when it loads a checkpoint
+        # saved in the middle of an epoch
+        sampler_state_dict = checkpoints["sampler"]
+    else:
+        sampler_state_dict = None
+
+    train_dl = tedlium.train_dataloaders(
+        train_cuts, sampler_state_dict=sampler_state_dict
+    )
+
+    valid_cuts = tedlium.dev_cuts()
+    valid_dl = tedlium.valid_dataloaders(valid_cuts)
+
+    if (
+        params.start_epoch <= 1
+        and params.start_batch <= 0
+        and not params.print_diagnostics
+    ):
+        scan_pessimistic_batches_for_oom(
+            model=model,
+            train_dl=train_dl,
+            optimizer=optimizer,
+            graph_compiler=graph_compiler,
+            params=params,
+            warmup=0.0 if params.start_epoch == 1 else 1.0,
+        )
+
+    scaler = GradScaler(enabled=params.use_fp16)
+    if checkpoints and "grad_scaler" in checkpoints:
+        logging.info("Loading grad scaler state dict")
+        scaler.load_state_dict(checkpoints["grad_scaler"])
+
+    for epoch in range(params.start_epoch, params.num_epochs + 1):
+        scheduler.step_epoch(epoch - 1)
+        fix_random_seed(params.seed + epoch - 1)
+        train_dl.sampler.set_epoch(epoch - 1)
+        train_dl.dataset.epoch = epoch - 1
+
+        if tb_writer is not None:
+            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+        params.cur_epoch = epoch
+
+        train_one_epoch(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            graph_compiler=graph_compiler,
+            train_dl=train_dl,
+            valid_dl=valid_dl,
+            scaler=scaler,
+            tb_writer=tb_writer,
+            world_size=world_size,
+            rank=rank,
+        )
+
+        if params.print_diagnostics:
+            diagnostic.print_diagnostics()
+            break
+
+        save_checkpoint(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            sampler=train_dl.sampler,
+            scaler=scaler,
+            rank=rank,
+        )
+
+    logging.info("Done!")
+
+    if world_size > 1:
+        torch.distributed.barrier()
+        cleanup_dist()
+
+
+def scan_pessimistic_batches_for_oom(
+    model: Union[torch.nn.Module, DDP],
+    train_dl: torch.utils.data.DataLoader,
+    optimizer: torch.optim.Optimizer,
+    graph_compiler: BpeCtcTrainingGraphCompiler,
+    params: AttributeDict,
+    warmup: float,
+):
+    from lhotse.dataset import find_pessimistic_batches
+
+    logging.info(
+        "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+    )
+    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+    for criterion, cuts in batches.items():
+        batch = train_dl.dataset[cuts]
+        try:
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                loss, _ = compute_loss(
+                    params=params,
+                    model=model,
+                    graph_compiler=graph_compiler,
+                    batch=batch,
+                    is_training=True,
+                    warmup=warmup,
+                )
+            loss.backward()
+            optimizer.step()
+            optimizer.zero_grad()
+        except Exception as e:
+            if "CUDA out of memory" in str(e):
+                logging.error(
+                    "Your GPU ran out of memory with the current "
+                    "max_duration setting. We recommend decreasing "
+                    "max_duration and trying again.\n"
+                    f"Failing criterion: {criterion} "
+                    f"(={crit_values[criterion]}) ..."
+                )
+            display_and_save_batch(batch, params=params, sp=graph_compiler.sp)
+            raise
+
+
+def main():
+    parser = get_parser()
+    TedLiumAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    world_size = args.world_size
+    assert world_size >= 1
+    if world_size > 1:
+        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+    else:
+        run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
+# in PyTorch 1.12 and later.
+torch.backends.cuda.matmul.allow_tf32 = True
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/tedlium3/ASR/conformer_ctc2/transformer.py b/egs/tedlium3/ASR/conformer_ctc2/transformer.py
new file mode 100644
index 000000000..9dbf32e48
--- /dev/null
+++ b/egs/tedlium3/ASR/conformer_ctc2/transformer.py
@@ -0,0 +1,1093 @@
+# Copyright    2021  University of Chinese Academy of Sciences (author: Han Zhu)
+# Copyright    2022  Xiaomi Corp.                              (author: Quandong Wang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import math
+from typing import Dict, List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+from attention import MultiheadAttention
+from combiner import RandomCombine
+from label_smoothing import LabelSmoothingLoss
+from scaling import (
+    ActivationBalancer,
+    BasicNorm,
+    DoubleSwish,
+    ScaledEmbedding,
+    ScaledLinear,
+)
+from subsampling import Conv2dSubsampling
+from torch.nn.utils.rnn import pad_sequence
+
+# Note: TorchScript requires Dict/List/etc. to be fully typed.
+Supervisions = Dict[str, torch.Tensor]
+
+
+class Transformer(nn.Module):
+    def __init__(
+        self,
+        num_features: int,
+        num_classes: int,
+        subsampling_factor: int = 4,
+        d_model: int = 256,
+        nhead: int = 4,
+        dim_feedforward: int = 2048,
+        num_encoder_layers: int = 12,
+        num_decoder_layers: int = 6,
+        dropout: float = 0.1,
+        layer_dropout: float = 0.075,
+        aux_layer_period: int = 3,
+    ) -> None:
+        """
+        Args:
+          num_features:
+            the input dimension of the model.
+          num_classes:
+            the output dimension of the model.
+          subsampling_factor:
+            number of output frames is num_in_frames // subsampling_factor;
+            currently, subsampling_factor MUST be 4.
+          d_model:
+            attention dimension.
+          nhead:
+            number of heads in multi-head attention;
+            must satisfy d_model // nhead == 0.
+          dim_feedforward:
+            the output dimension of the feedforward layers in encoder/decoder.
+          num_encoder_layers:
+            number of encoder layers.
+          num_decoder_layers:
+            number of decoder layers.
+          dropout:
+            dropout in encoder/decoder.
+          layer_dropout:
+            layer-dropout rate.
+          aux_layer_period:
+            determines the auxiliary encoder layers.
+        """
+        super().__init__()
+
+        self.num_features = num_features
+        self.num_classes = num_classes
+        self.subsampling_factor = subsampling_factor
+        if subsampling_factor != 4:
+            raise NotImplementedError("Support only 'subsampling_factor=4'.")
+
+        # self.encoder_embed converts the input of shape (N, T, num_classes)
+        # to the shape (N, T//subsampling_factor, d_model).
+        # That is, it does two things simultaneously:
+        #   (1) subsampling: T -> T//subsampling_factor
+        #   (2) embedding: num_classes -> d_model
+        self.encoder_embed = Conv2dSubsampling(num_features, d_model)
+
+        self.encoder_pos = PositionalEncoding(d_model, dropout)
+
+        encoder_layer = TransformerEncoderLayer(
+            d_model=d_model,
+            nhead=nhead,
+            dim_feedforward=dim_feedforward,
+            dropout=dropout,
+            layer_dropout=layer_dropout,
+        )
+        # aux_layers from 1/3
+        self.encoder = TransformerEncoder(
+            encoder_layer=encoder_layer,
+            num_layers=num_encoder_layers,
+            aux_layers=list(
+                range(
+                    num_encoder_layers // 3,
+                    num_encoder_layers - 1,
+                    aux_layer_period,
+                )
+            ),
+        )
+
+        # TODO(fangjun): remove dropout
+        self.encoder_output_layer = nn.Sequential(
+            nn.Dropout(p=dropout), ScaledLinear(d_model, num_classes, bias=True)
+        )
+
+        if num_decoder_layers > 0:
+            self.decoder_num_class = (
+                self.num_classes
+            )  # bpe model already has sos/eos symbol
+
+            self.decoder_embed = ScaledEmbedding(
+                num_embeddings=self.decoder_num_class, embedding_dim=d_model
+            )
+            self.decoder_pos = PositionalEncoding(d_model, dropout)
+
+            decoder_layer = TransformerDecoderLayer(
+                d_model=d_model,
+                nhead=nhead,
+                dim_feedforward=dim_feedforward,
+                dropout=dropout,
+            )
+
+            self.decoder = TransformerDecoder(
+                decoder_layer=decoder_layer,
+                num_layers=num_decoder_layers,
+                aux_layers=[],
+            )
+
+            self.decoder_output_layer = ScaledLinear(
+                d_model, self.decoder_num_class, bias=True
+            )
+
+            self.decoder_criterion = LabelSmoothingLoss(reduction="none")
+        else:
+            self.decoder_criterion = None
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        supervision: Optional[Supervisions] = None,
+        warmup: float = 1.0,
+    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+        """
+        Args:
+          x:
+            The input tensor. Its shape is (N, S, C).
+          supervision:
+            Supervision in lhotse format.
+            See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
+            (CAUTION: It contains length information, i.e., start and number of
+             frames, before subsampling)
+          warmup:
+            a floating point value that gradually increases from 0 throughout
+            training; when it is >= 1.0 we are "fully warmed up". It is used
+            to turn modules on sequentially.
+
+        Returns:
+          Return a tuple containing 3 tensors:
+            - CTC output for ctc decoding. Its shape is (N, S, C)
+            - Encoder output with shape (S, N, C). It can be used as key and
+              value for the decoder.
+            - Encoder output padding mask. It can be used as
+              memory_key_padding_mask for the decoder. Its shape is (N, S).
+              It is None if `supervision` is None.
+        """
+
+        encoder_memory, memory_key_padding_mask = self.run_encoder(
+            x, supervision, warmup
+        )
+
+        x = self.ctc_output(encoder_memory)
+        return x, encoder_memory, memory_key_padding_mask
+
+    def run_encoder(
+        self,
+        x: torch.Tensor,
+        supervisions: Optional[Supervisions] = None,
+        warmup: float = 1.0,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """Run the transformer encoder.
+
+        Args:
+          x:
+            The model input. Its shape is (N, S, C).
+          supervisions:
+            Supervision in lhotse format.
+            See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
+            CAUTION: It contains length information, i.e., start and number of
+            frames, before subsampling
+            It is read directly from the batch, without any sorting. It is used
+            to compute the encoder padding mask, which is used as memory key
+            padding mask for the decoder.
+          warmup:
+            a floating point value that gradually increases from 0 throughout
+            training; when it is >= 1.0 we are "fully warmed up". It is used
+            to turn modules on sequentially.
+
+        Returns:
+          Return a tuple with two tensors:
+            - The encoder output, with shape (S, N, C)
+            - encoder padding mask, with shape (N, S).
+              The mask is None if `supervisions` is None.
+              It is used as memory key padding mask in the decoder.
+        """
+        x = self.encoder_embed(x)
+        x = self.encoder_pos(x)
+        x = x.permute(1, 0, 2)  # (N, S, C) -> (S, N, C)
+        mask = encoder_padding_mask(x.size(0), supervisions)
+        mask = mask.to(x.device) if mask is not None else None
+        x = self.encoder(x, src_key_padding_mask=mask, warmup=warmup)  # (S, N, C)
+
+        return x, mask
+
+    def ctc_output(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+          x:
+            the output tensor from the transformer encoder;
+            its shape is (S, N, C)
+
+        Returns:
+          Return a tensor that can be used for CTC decoding.
+          Its shape is (N, S, C)
+        """
+        x = self.encoder_output_layer(x)
+        x = x.permute(1, 0, 2)  # (S, N, C) -> (N, S, C)
+        x = nn.functional.log_softmax(x, dim=-1)  # (N, S, C)
+        return x
+
+    @torch.jit.export
+    def decoder_forward(
+        self,
+        memory: torch.Tensor,
+        memory_key_padding_mask: torch.Tensor,
+        token_ids: List[List[int]],
+        sos_id: int,
+        eos_id: int,
+        warmup: float = 1.0,
+    ) -> torch.Tensor:
+        """
+        Args:
+          memory:
+            It's the output of the encoder of shape (S, N, C)
+          memory_key_padding_mask:
+            The padding mask from the encoder of shape (N, S).
+          token_ids:
+            A list-of-list IDs. Each sublist contains IDs for an utterance.
+            The IDs can be either phone IDs or word piece IDs.
+          sos_id:
+            sos token id
+          eos_id:
+            eos token id
+          warmup:
+            a floating point value that gradually increases from 0 throughout
+            training; when it is >= 1.0 we are "fully warmed up". It is used
+            to turn modules on sequentially.
+
+        Returns:
+          A scalar, the **sum** of label smoothing loss over utterances
+          in the batch without any normalization.
+        """
+        ys_in = add_sos(token_ids, sos_id=sos_id)
+        ys_in = [torch.tensor(y) for y in ys_in]
+        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
+
+        ys_out = add_eos(token_ids, eos_id=eos_id)
+        ys_out = [torch.tensor(y) for y in ys_out]
+        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
+
+        device = memory.device
+        ys_in_pad = ys_in_pad.to(device)
+        ys_out_pad = ys_out_pad.to(device)
+
+        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
+
+        tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
+        # TODO: Use length information to create the decoder padding mask
+        # We set the first column to False since the first column in ys_in_pad
+        # contains sos_id, which is the same as eos_id in our current setting.
+        tgt_key_padding_mask[:, 0] = False
+
+        tgt = self.decoder_embed(ys_in_pad)  # (N, T) -> (N, T, C)
+        tgt = self.decoder_pos(tgt)
+        tgt = tgt.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
+        pred_pad = self.decoder(
+            tgt=tgt,
+            memory=memory,
+            tgt_mask=tgt_mask,
+            tgt_key_padding_mask=tgt_key_padding_mask,
+            memory_key_padding_mask=memory_key_padding_mask,
+            warmup=warmup,
+        )  # (T, N, C)
+        pred_pad = pred_pad.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
+        pred_pad = self.decoder_output_layer(pred_pad)  # (N, T, C)
+
+        decoder_loss = self.decoder_criterion(pred_pad, ys_out_pad)
+
+        return decoder_loss
+
+    @torch.jit.export
+    def decoder_nll(
+        self,
+        memory: torch.Tensor,
+        memory_key_padding_mask: torch.Tensor,
+        token_ids: List[torch.Tensor],
+        sos_id: int,
+        eos_id: int,
+        warmup: float = 1.0,
+    ) -> torch.Tensor:
+        """
+        Args:
+          memory:
+            It's the output of the encoder of shape (S, N, C).
+          memory_key_padding_mask:
+            The padding mask from the encoder of shape (N, S).
+          token_ids:
+            A list-of-list IDs (e.g., word piece IDs).
+            Each sublist represents an utterance.
+          sos_id:
+            The token ID for SOS.
+          eos_id:
+            The token ID for EOS.
+          warmup:
+            a floating point value that gradually increases from 0 throughout
+            training; when it is >= 1.0 we are "fully warmed up". It is used
+            to turn modules on sequentially.
+
+        Returns:
+          A 2-D tensor of shape (len(token_ids), max_token_length)
+          representing the cross entropy loss (i.e., negative log-likelihood).
+        """
+        # The common part between this function and decoder_forward could be
+        # extracted as a separate function.
+        if isinstance(token_ids[0], torch.Tensor):
+            # This branch is executed by torchscript in C++.
+            # See https://github.com/k2-fsa/k2/pull/870
+            # https://github.com/k2-fsa/k2/blob/3c1c18400060415b141ccea0115fd4bf0ad6234e/k2/torch/bin/attention_rescore.cu#L286
+            token_ids = [tolist(t) for t in token_ids]
+
+        ys_in = add_sos(token_ids, sos_id=sos_id)
+        ys_in = [torch.tensor(y) for y in ys_in]
+        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
+
+        ys_out = add_eos(token_ids, eos_id=eos_id)
+        ys_out = [torch.tensor(y) for y in ys_out]
+        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
+
+        device = memory.device
+        ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
+        ys_out_pad = ys_out_pad.to(device, dtype=torch.int64)
+
+        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
+
+        tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
+        # TODO: Use length information to create the decoder padding mask
+        # We set the first column to False since the first column in ys_in_pad
+        # contains sos_id, which is the same as eos_id in our current setting.
+        tgt_key_padding_mask[:, 0] = False
+
+        tgt = self.decoder_embed(ys_in_pad)  # (N, T) -> (N, T, C)
+        tgt = self.decoder_pos(tgt)
+        tgt = tgt.permute(1, 0, 2)  # (N, T, С) -> (T, N, C)
+        pred_pad = self.decoder(
+            tgt=tgt,
+            memory=memory,
+            tgt_mask=tgt_mask,
+            tgt_key_padding_mask=tgt_key_padding_mask,
+            memory_key_padding_mask=memory_key_padding_mask,
+            warmup=warmup,
+        )  # (T, B, F)
+        pred_pad = pred_pad.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
+        pred_pad = self.decoder_output_layer(pred_pad)  # (N, T, C)
+        # nll: negative log-likelihood
+        nll = torch.nn.functional.cross_entropy(
+            pred_pad.view(-1, self.decoder_num_class),
+            ys_out_pad.view(-1),
+            ignore_index=-1,
+            reduction="none",
+        )
+
+        nll = nll.view(pred_pad.shape[0], -1)
+
+        return nll
+
+
+class TransformerEncoderLayer(nn.Module):
+    """
+    Modified from torch.nn.TransformerEncoderLayer.
+
+    Example:
+        >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
+        >>> src = torch.rand(10, 32, 512)
+        >>> out = encoder_layer(src)
+    """
+
+    def __init__(
+        self,
+        d_model: int,
+        nhead: int,
+        dim_feedforward: int = 2048,
+        dropout: float = 0.1,
+        bypass_scale: float = 0.1,
+        layer_dropout: float = 0.075,
+    ) -> None:
+        """
+        Args:
+          d_model:
+            the number of expected features in the input (required).
+          nhead:
+            the number of heads in the multiheadattention models (required).
+          dim_feedforward:
+            the dimension of the feedforward network model (default=2048).
+          dropout:
+            the dropout value (default=0.1).
+          bypass_scale:
+            a scale on the layer's output, used in bypass (resnet-type) skip-connection;
+            when the layer is bypassed the final output will be a
+            weighted sum of the layer's input and layer's output with weights
+            (1.0-bypass_scale) and bypass_scale correspondingly (default=0.1).
+          layer_dropout:
+            the probability to bypass the layer (default=0.075).
+        """
+
+        super().__init__()
+
+        if bypass_scale < 0.0 or bypass_scale > 1.0:
+            raise ValueError("bypass_scale should be between 0.0 and 1.0")
+
+        if layer_dropout < 0.0 or layer_dropout > 1.0:
+            raise ValueError("layer_dropout should be between 0.0 and 1.0")
+
+        self.bypass_scale = bypass_scale
+        self.layer_dropout = layer_dropout
+
+        self.self_attn = MultiheadAttention(d_model, nhead)
+        # Implementation of Feedforward model
+
+        self.feed_forward = nn.Sequential(
+            ScaledLinear(d_model, dim_feedforward),
+            ActivationBalancer(channel_dim=-1),
+            DoubleSwish(),
+            nn.Dropout(dropout),
+            ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
+        )
+
+        self.norm_final = BasicNorm(d_model)
+
+        # try to ensure the output is close to zero-mean (or at least, zero-median).
+        self.balancer = ActivationBalancer(
+            channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
+        )
+
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(
+        self,
+        src: torch.Tensor,
+        src_mask: Optional[torch.Tensor] = None,
+        src_key_padding_mask: Optional[torch.Tensor] = None,
+        warmup: float = 1.0,
+    ) -> torch.Tensor:
+        """
+        Pass the input through the encoder layer.
+
+        Args:
+          src:
+            the sequence to the encoder layer of shape (S, N, C) (required).
+          src_mask:
+            the mask for the src sequence of shape (S, S) (optional).
+          src_key_padding_mask:
+            the mask for the src keys per batch of shape (N, S) (optional)
+          warmup:
+            controls selective bypass of layers; if < 1.0, we will
+            bypass the layer more frequently (default=1.0).
+
+        Returns:
+          Output tensor of the shape (S, N, C), where
+          S is the source sequence length,
+          N is the batch size,
+          C is the feature number.
+
+        """
+        src_orig = src
+
+        warmup_scale = min(self.bypass_scale + warmup, 1.0)
+        # alpha = 1.0 means fully use this encoder layer, 0.0 would mean
+        # completely bypass it.
+        if self.training:
+            alpha = (
+                warmup_scale
+                if torch.rand(()).item() <= (1.0 - self.layer_dropout)
+                else self.bypass_scale
+            )
+        else:
+            alpha = 1.0
+
+        src_att = self.self_attn(
+            src,
+            src,
+            src,
+            attn_mask=src_mask,
+            key_padding_mask=src_key_padding_mask,
+        )[0]
+        src = src + self.dropout(src_att)
+
+        src = src + self.dropout(self.feed_forward(src))
+
+        src = self.norm_final(self.balancer(src))
+
+        if alpha != 1.0:
+            src = alpha * src + (1.0 - alpha) * src_orig
+
+        return src
+
+
+class TransformerDecoderLayer(nn.Module):
+    """Modified from torch.nn.TransformerDecoderLayer.
+
+    Example:
+        >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
+        >>> memory = torch.rand(10, 32, 512)
+        >>> tgt = torch.rand(20, 32, 512)
+        >>> out = decoder_layer(tgt, memory)
+    """
+
+    def __init__(
+        self,
+        d_model: int,
+        nhead: int,
+        dim_feedforward: int = 2048,
+        dropout: float = 0.1,
+        bypass_scale: float = 0.1,
+        layer_dropout: float = 0.075,
+    ) -> None:
+
+        """
+        Args:
+          d_model:
+            the number of expected features in the input (required).
+          nhead:
+            the number of heads in the multiheadattention models (required).
+          dim_feedforward:
+            the dimension of the feedforward network model (default=2048).
+          dropout:
+            the dropout value (default=0.1).
+          bypass_scale:
+            a scale on the layer's output, used in bypass (resnet-type) skip-connection;
+            when the layer is bypassed, the final output will be a
+            weighted sum of the layer's input and layer's output with weights
+            (1.0-bypass_scale) and bypass_scale correspondingly (default=0.1).
+          layer_dropout:
+            the probability to bypass the layer (default=0.075).
+        """
+
+        super().__init__()
+
+        if bypass_scale < 0.0 or bypass_scale > 1.0:
+            raise ValueError("bypass_scale should be between 0.0 and 1.0")
+
+        if layer_dropout < 0.0 or layer_dropout > 1.0:
+            raise ValueError("layer_dropout should be between 0.0 and 1.0")
+
+        self.bypass_scale = bypass_scale
+        self.layer_dropout = layer_dropout
+
+        self.self_attn = MultiheadAttention(d_model, nhead)
+        self.src_attn = MultiheadAttention(d_model, nhead)
+
+        # Implementation of Feedforward model
+        self.feed_forward = nn.Sequential(
+            ScaledLinear(d_model, dim_feedforward),
+            ActivationBalancer(channel_dim=-1),
+            DoubleSwish(),
+            nn.Dropout(dropout),
+            ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
+        )
+
+        self.norm_final = BasicNorm(d_model)
+
+        # try to ensure the output is close to zero-mean (or at least, zero-median).
+        self.balancer = ActivationBalancer(
+            channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
+        )
+
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(
+        self,
+        tgt: torch.Tensor,
+        memory: torch.Tensor,
+        tgt_mask: Optional[torch.Tensor] = None,
+        memory_mask: Optional[torch.Tensor] = None,
+        tgt_key_padding_mask: Optional[torch.Tensor] = None,
+        memory_key_padding_mask: Optional[torch.Tensor] = None,
+        warmup: float = 1.0,
+    ) -> torch.Tensor:
+        """Pass the inputs (and mask) through the decoder layer.
+
+        Args:
+          tgt:
+            the sequence to the decoder layer of shape (T, N, C) (required).
+          memory:
+            the sequence from the last layer of the encoder of shape (S, N, C) (required).
+          tgt_mask:
+            the mask for the tgt sequence of shape (T, T) (optional).
+          memory_mask:
+            the mask for the memory sequence of shape (T, S) (optional).
+          tgt_key_padding_mask:
+            the mask for the tgt keys per batch of shape (N, T) (optional).
+          memory_key_padding_mask:
+            the mask for the memory keys per batch of shape (N, S) (optional).
+          warmup: controls selective bypass of layers; if < 1.0, we will
+            bypass the layer more frequently (default=1.0).
+
+        Returns:
+          Output tensor of the shape (T, N, C), where
+          S is the source sequence length,
+          T is the target sequence length,
+          N is the batch size,
+          C is the feature number.
+
+        """
+        tgt_orig = tgt
+
+        warmup_scale = min(self.bypass_scale + warmup, 1.0)
+        # alpha = 1.0 means fully use this encoder layer, 0.0 would mean
+        # completely bypass it.
+        if self.training:
+            alpha = (
+                warmup_scale
+                if torch.rand(()).item() <= (1.0 - self.layer_dropout)
+                else self.bypass_scale
+            )
+        else:
+            alpha = 1.0
+
+        tgt_att = self.self_attn(
+            tgt,
+            tgt,
+            tgt,
+            attn_mask=tgt_mask,
+            key_padding_mask=tgt_key_padding_mask,
+        )[0]
+        tgt = tgt + self.dropout(tgt_att)
+
+        src_att = self.src_attn(
+            tgt,
+            memory,
+            memory,
+            attn_mask=memory_mask,
+            key_padding_mask=memory_key_padding_mask,
+        )[0]
+        tgt = tgt + self.dropout(src_att)
+
+        tgt = tgt + self.dropout(self.feed_forward(tgt))
+
+        tgt = self.norm_final(self.balancer(tgt))
+
+        if alpha != 1.0:
+            tgt = alpha * tgt + (1.0 - alpha) * tgt_orig
+
+        return tgt
+
+
+class TransformerEncoder(nn.Module):
+    """TransformerEncoder is a stack of N encoder layers
+
+    Examples:
+        >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
+        >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)
+        >>> src = torch.rand(10, 32, 512)
+        >>> out = transformer_encoder(src)
+    """
+
+    def __init__(
+        self,
+        encoder_layer: nn.Module,
+        num_layers: int,
+        aux_layers: List[int],
+    ) -> None:
+        """
+        Args:
+          encoder_layer:
+            an instance of the TransformerEncoderLayer() class (required).
+          num_layers:
+            the number of sub-encoder-layers in the encoder (required).
+          aux_layers:
+            list of indexes of sub-encoder-layers outputs to be combined (required).
+        """
+
+        super().__init__()
+        self.layers = nn.ModuleList(
+            [copy.deepcopy(encoder_layer) for i in range(num_layers)]
+        )
+        self.num_layers = num_layers
+
+        assert len(set(aux_layers)) == len(aux_layers)
+
+        assert num_layers - 1 not in aux_layers
+        self.aux_layers = aux_layers + [num_layers - 1]
+
+        self.combiner = RandomCombine(
+            num_inputs=len(self.aux_layers),
+            final_weight=0.5,
+            pure_prob=0.333,
+            stddev=2.0,
+        )
+
+    def forward(
+        self,
+        src: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+        src_key_padding_mask: Optional[torch.Tensor] = None,
+        warmup: float = 1.0,
+    ) -> torch.Tensor:
+        """Pass the input through the encoder layers in turn.
+
+        Args:
+          src:
+            the input to the encoder of shape (S, N, C) (required).
+          mask:
+            the mask for the src sequence of shape (S, S) (optional).
+          src_key_padding_mask:
+            the mask for the src keys per batch of shape (N, S) (optional).
+          warmup:
+            controls selective bypass of layer; if < 1.0, we will
+            bypass the layer more frequently (default=1.0).
+
+        Returns:
+          Output tensor of the shape (S, N, C), where
+          S is the source sequence length,
+          N is the batch size,
+          C is the feature number.
+
+        """
+        output = src
+
+        outputs = []
+        for i, mod in enumerate(self.layers):
+            output = mod(
+                output,
+                src_mask=mask,
+                src_key_padding_mask=src_key_padding_mask,
+                warmup=warmup,
+            )
+
+            if i in self.aux_layers:
+                outputs.append(output)
+
+        output = self.combiner(outputs)
+
+        return output
+
+
+class TransformerDecoder(nn.Module):
+    """TransformerDecoder is a stack of N decoder layers
+
+    Examples:
+        >>> decoder_layer = TransformerDecoderLayer(d_model=512, nhead=8)
+        >>> transformer_decoder = TransformerDecoder(decoder_layer, num_layers=6)
+        >>> memory = torch.rand(10, 32, 512)
+        >>> tgt = torch.rand(20, 32, 512)
+        >>> out = transformer_decoder(tgt, memory)
+    """
+
+    def __init__(
+        self,
+        decoder_layer: nn.Module,
+        num_layers: int,
+        aux_layers: List[int],
+    ) -> None:
+        """
+        Args:
+          decoder_layer:
+            an instance of the TransformerDecoderLayer() class (required).
+          num_layers:
+            the number of decoder layers in the decoder (required).
+          aux_layers:
+            list of indexes of decoder layer outputs to be combined (required).
+        """
+
+        super().__init__()
+        self.layers = nn.ModuleList(
+            [copy.deepcopy(decoder_layer) for i in range(num_layers)]
+        )
+        self.num_layers = num_layers
+
+        assert len(set(aux_layers)) == len(aux_layers)
+
+        assert num_layers - 1 not in aux_layers
+        self.aux_layers = aux_layers + [num_layers - 1]
+
+        self.combiner = RandomCombine(
+            num_inputs=len(self.aux_layers),
+            final_weight=0.5,
+            pure_prob=0.333,
+            stddev=2.0,
+        )
+
+    def forward(
+        self,
+        tgt: torch.Tensor,
+        memory: torch.Tensor,
+        tgt_mask: Optional[torch.Tensor] = None,
+        memory_mask: Optional[torch.Tensor] = None,
+        tgt_key_padding_mask: Optional[torch.Tensor] = None,
+        memory_key_padding_mask: Optional[torch.Tensor] = None,
+        warmup: float = 1.0,
+    ) -> torch.Tensor:
+        """Pass the input (and mask) through the decoder layers in turn.
+
+        Args:
+          tgt:
+            the sequence to the decoder of shape (T, N, C) (required).
+          memory:
+            the sequence from the last layer of the encoder of shape (S, N, C) (required).
+          tgt_mask:
+            the mask for the tgt sequence of shape (T, T) (optional).
+          memory_mask:
+            the mask for the memory sequence of shape (T, S) (optional).
+          tgt_key_padding_mask:
+            the mask for the tgt keys per batch of shape (N, T)  (optional).
+          memory_key_padding_mask:
+            the mask for the memory keys per batch of shape (N, S) (optional).
+          warmup:
+            controls selective bypass of layer; if < 1.0, we will
+            bypass the layer more frequently (default=1.0).
+
+        Returns:
+          Output tensor of the shape (T, N, C), where
+          S is the source sequence length,
+          T is the target sequence length,
+          N is the batch size,
+          C is the feature number.
+
+        """
+        output = tgt
+
+        outputs = []
+        for i, mod in enumerate(self.layers):
+            output = mod(
+                output,
+                memory,
+                tgt_mask=tgt_mask,
+                memory_mask=memory_mask,
+                tgt_key_padding_mask=tgt_key_padding_mask,
+                memory_key_padding_mask=memory_key_padding_mask,
+                warmup=warmup,
+            )
+
+            if i in self.aux_layers:
+                outputs.append(output)
+
+        output = self.combiner(outputs)
+
+        return output
+
+
+class PositionalEncoding(nn.Module):
+    """This class implements the positional encoding
+    proposed in the following paper:
+
+    - Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf
+
+        PE(pos, 2i) = sin(pos / (10000^(2i/d_modle))
+        PE(pos, 2i+1) = cos(pos / (10000^(2i/d_modle))
+
+    Note:
+
+      1 / (10000^(2i/d_model)) = exp(-log(10000^(2i/d_model)))
+                               = exp(-1* 2i / d_model * log(100000))
+                               = exp(2i * -(log(10000) / d_model))
+    """
+
+    def __init__(self, d_model: int, dropout: float = 0.1) -> None:
+        """
+        Args:
+          d_model: Embedding dimension.
+          dropout: Dropout probability to be applied to the output of this module.
+        """
+        super().__init__()
+        self.d_model = d_model
+        self.xscale = math.sqrt(self.d_model)
+        self.dropout = nn.Dropout(p=dropout)
+        # not doing: self.pe = None because of errors thrown by torchscript
+        self.pe = torch.zeros(1, 0, self.d_model, dtype=torch.float32)
+
+    def extend_pe(self, x: torch.Tensor) -> None:
+        """Extend the time t in the positional encoding if required.
+        The shape of `self.pe` is (1, T1, d_model). The shape of the input x
+        is (N, T, d_model). If T > T1, then we change the shape of self.pe
+        to (N, T, d_model). Otherwise, nothing is done.
+
+        Args:
+          x:
+            It is a tensor of shape (N, T, C).
+            T is the target sequence length,
+            N is the batch size,
+            C is the feature number.
+        """
+        if self.pe is not None:
+            if self.pe.size(1) >= x.size(1):
+                self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+                return
+        pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32)
+        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0, self.d_model, 2, dtype=torch.float32)
+            * -(math.log(10000.0) / self.d_model)
+        )
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+        pe = pe.unsqueeze(0)
+        # Now pe is of shape (1, T, d_model), where T is x.size(1)
+        self.pe = pe.to(device=x.device, dtype=x.dtype)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Add positional encoding.
+
+        Args:
+          x: Input of shape is (N, T, C)
+
+        Returns:
+          A tensor of the same shape (N, T, C),
+          T is the target sequence length,
+          N is the batch size,
+          C is the feature number.
+
+        """
+        self.extend_pe(x)
+        x = x + self.pe[:, : x.size(1), :]
+        return self.dropout(x)
+
+
+def encoder_padding_mask(
+    max_len: int, supervisions: Optional[Supervisions] = None
+) -> Optional[torch.Tensor]:
+    """Make mask tensor containing indexes of padded part.
+
+    TODO:
+      This function **assumes** that the model uses
+      a subsampling factor of 4. We should remove that
+      assumption later.
+
+    Args:
+      max_len:
+        Maximum length of input features.
+        CAUTION: It is the length after subsampling.
+      supervisions:
+        Supervision in lhotse format.
+        See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
+        (CAUTION: It contains length information, i.e., start and number of
+         frames, before subsampling)
+
+    Returns:
+      Mask tensor of dimension (batch_size, input_length),
+      True denotes the masked indices.
+    """
+    if supervisions is None:
+        return None
+
+    supervision_segments = torch.stack(
+        (
+            supervisions["sequence_idx"],
+            supervisions["start_frame"],
+            supervisions["num_frames"],
+        ),
+        1,
+    ).to(torch.int32)
+
+    lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)]
+    for idx in range(supervision_segments.size(0)):
+        # Note: TorchScript doesn't allow to unpack tensors as tuples
+        sequence_idx = supervision_segments[idx, 0].item()
+        start_frame = supervision_segments[idx, 1].item()
+        num_frames = supervision_segments[idx, 2].item()
+        lengths[sequence_idx] = start_frame + num_frames
+
+    lengths = [((i - 1) // 2 - 1) // 2 for i in lengths]
+    bs = int(len(lengths))
+    seq_range = torch.arange(0, max_len, dtype=torch.int64)
+    seq_range_expand = seq_range.unsqueeze(0).expand(bs, max_len)
+    # Note: TorchScript doesn't implement Tensor.new()
+    seq_length_expand = torch.tensor(
+        lengths, device=seq_range_expand.device, dtype=seq_range_expand.dtype
+    ).unsqueeze(-1)
+    mask = seq_range_expand >= seq_length_expand
+
+    return mask
+
+
+def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor:
+    """Generate a length mask for input.
+
+    The masked position are filled with True,
+    Unmasked positions are filled with False.
+
+    Args:
+      ys_pad:
+        padded tensor of dimension (batch_size, input_length).
+      ignore_id:
+        the ignored number (the padding number) in ys_pad
+
+    Returns:
+        A bool tensor of the same shape as the input tensor.
+    """
+    ys_mask = ys_pad == ignore_id
+    return ys_mask
+
+
+def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
+    """Generate a square mask for the sequence. The masked positions are
+    filled with float('-inf'). Unmasked positions are filled with float(0.0).
+    The mask can be used for masked self-attention.
+
+    For instance, if sz is 3, it returns::
+
+        tensor([[0., -inf, -inf],
+                [0., 0., -inf],
+                [0., 0., 0]])
+
+    Args:
+      sz: mask size
+
+    Returns:
+      A square mask tensor of dimension (sz, sz)
+    """
+    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
+    mask = (
+        mask.float()
+        .masked_fill(mask == 0, float("-inf"))
+        .masked_fill(mask == 1, float(0.0))
+    )
+    return mask
+
+
+def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]:
+    """Prepend sos_id to each utterance.
+
+    Args:
+      token_ids:
+        A list-of-list of token IDs. Each sublist contains
+        token IDs (e.g., word piece IDs) of an utterance.
+      sos_id:
+        The ID of the SOS token.
+
+    Return:
+      Return a new list-of-list, where each sublist starts
+      with SOS ID.
+    """
+    return [[sos_id] + utt for utt in token_ids]
+
+
+def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]:
+    """Append eos_id to each utterance.
+
+    Args:
+      token_ids:
+        A list-of-lists of token IDs. Each sublist contains
+        token IDs (e.g., word piece IDs) of an utterance.
+      eos_id:
+        The ID of the EOS token.
+
+    Return:
+      Return a new list-of-lists, where each sublist ends
+      with EOS ID.
+    """
+    return [utt + [eos_id] for utt in token_ids]
+
+
+def tolist(t: torch.Tensor) -> List[int]:
+    """Used by jit"""
+    return torch.jit.annotate(List[int], t.tolist())
diff --git a/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py b/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py
index 9dbcc9d9e..19ba8d24b 100644
--- a/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py
+++ b/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py
@@ -4,16 +4,18 @@
 """
 Convert a transcript based on words to a list of BPE ids.
 
-For example, if we use 2 as the encoding id of :
+For example, if we use 2 as the encoding id of 
+Note: it, inserts a space token before each 
 
 texts = ['this is a  day']
-spm_ids = [[38, 33, 6, 2, 316]]
+spm_ids = [[38, 33, 6, 15, 2, 316]]
 
 texts = [' this is a sunny day']
-spm_ids = [[2, 38, 33, 6, 118, 11, 11, 21, 316]]
+spm_ids = [[15, 2, 38, 33, 6, 118, 11, 11, 21, 316]]
 
 texts = ['']
-spm_ids = [[2]]
+spm_ids = [[15, 2]]
+
 """
 
 import argparse
@@ -38,29 +40,27 @@ def get_args():
 
 def convert_texts_into_ids(
     texts: List[str],
-    unk_id: int,
     sp: spm.SentencePieceProcessor,
 ) -> List[List[int]]:
     """
     Args:
       texts:
         A string list of transcripts, such as ['Today is Monday', 'It's sunny'].
-      unk_id:
-        A number id for the token ''.
+      sp:
+        A sentencepiece BPE model.
     Returns:
       Return an integer list of bpe ids.
     """
     y = []
     for text in texts:
-        y_ids = []
         if "" in text:
-            text_segments = text.split("")
-            id_segments = sp.encode(text_segments, out_type=int)
+            id_segments = sp.encode(text.split(""), out_type=int)
+
+            y_ids = []
             for i in range(len(id_segments)):
-                if i != len(id_segments) - 1:
-                    y_ids.extend(id_segments[i] + [unk_id])
-                else:
-                    y_ids.extend(id_segments[i])
+                y_ids += id_segments[i]
+                if i < len(id_segments) - 1:
+                    y_ids += [sp.piece_to_id("▁"), sp.unk_id()]
         else:
             y_ids = sp.encode(text, out_type=int)
         y.append(y_ids)
@@ -70,19 +70,13 @@ def convert_texts_into_ids(
 
 def main():
     args = get_args()
-    texts = args.texts
-    bpe_model = args.bpe_model
 
     sp = spm.SentencePieceProcessor()
-    sp.load(bpe_model)
-    unk_id = sp.piece_to_id("")
+    sp.load(args.bpe_model)
 
-    y = convert_texts_into_ids(
-        texts=texts,
-        unk_id=unk_id,
-        sp=sp,
-    )
-    logging.info(f"The input texts: {texts}")
+    y = convert_texts_into_ids(texts=args.texts, sp=sp)
+
+    logging.info(f"The input texts: {args.texts}")
     logging.info(f"The encoding ids: {y}")
 
 
diff --git a/egs/tedlium3/ASR/local/convert_transcript_words_to_tokens.py b/egs/tedlium3/ASR/local/convert_transcript_words_to_tokens.py
deleted file mode 120000
index 2ce13fd69..000000000
--- a/egs/tedlium3/ASR/local/convert_transcript_words_to_tokens.py
+++ /dev/null
@@ -1 +0,0 @@
-../../../librispeech/ASR/local/convert_transcript_words_to_tokens.py
\ No newline at end of file
diff --git a/egs/tedlium3/ASR/local/generate_unique_lexicon.py b/egs/tedlium3/ASR/local/generate_unique_lexicon.py
deleted file mode 120000
index c0aea1403..000000000
--- a/egs/tedlium3/ASR/local/generate_unique_lexicon.py
+++ /dev/null
@@ -1 +0,0 @@
-../../../librispeech/ASR/local/generate_unique_lexicon.py
\ No newline at end of file
diff --git a/egs/tedlium3/ASR/local/prepare_lang.py b/egs/tedlium3/ASR/local/prepare_lang.py
deleted file mode 120000
index 747f2ab39..000000000
--- a/egs/tedlium3/ASR/local/prepare_lang.py
+++ /dev/null
@@ -1 +0,0 @@
-../../../librispeech/ASR/local/prepare_lang.py
\ No newline at end of file
diff --git a/egs/tedlium3/ASR/local/prepare_lexicon.py b/egs/tedlium3/ASR/local/prepare_lexicon.py
deleted file mode 100755
index b9160b6d4..000000000
--- a/egs/tedlium3/ASR/local/prepare_lexicon.py
+++ /dev/null
@@ -1,94 +0,0 @@
-#!/usr/bin/env python3
-# Copyright    2022  Xiaomi Corp.        (authors: Mingshuang Luo)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-"""
-This script takes as input supervisions json dir "data/manifests"
-consisting of supervisions_train.json and does the following:
-
-1. Generate lexicon_words.txt.
-
-"""
-import argparse
-import logging
-from pathlib import Path
-
-import lhotse
-
-
-def get_args():
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--manifests-dir",
-        type=str,
-        help="""Input directory.
-        """,
-    )
-    parser.add_argument(
-        "--lang-dir",
-        type=str,
-        help="""Output directory.
-        """,
-    )
-
-    return parser.parse_args()
-
-
-def prepare_lexicon(manifests_dir: str, lang_dir: str):
-    """
-    Args:
-      manifests_dir:
-        The manifests directory, e.g., data/manifests.
-      lang_dir:
-        The language directory, e.g., data/lang_phone.
-
-    Return:
-      The lexicon_words.txt file.
-    """
-    words = set()
-
-    lexicon = Path(lang_dir) / "lexicon_words.txt"
-    sups = lhotse.load_manifest(f"{manifests_dir}/tedlium_supervisions_train.jsonl.gz")
-    for s in sups:
-        # list the words units and filter the empty item
-        words_list = list(filter(None, s.text.split()))
-
-        for word in words_list:
-            if word not in words and word != "":
-                words.add(word)
-
-    with open(lexicon, "w") as f:
-        for word in sorted(words):
-            f.write(word + "  " + word)
-            f.write("\n")
-
-
-def main():
-    args = get_args()
-    manifests_dir = Path(args.manifests_dir)
-    lang_dir = Path(args.lang_dir)
-
-    logging.info("Generating lexicon_words.txt")
-    prepare_lexicon(manifests_dir, lang_dir)
-
-
-if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-
-    logging.basicConfig(format=formatter, level=logging.INFO)
-
-    main()
diff --git a/egs/tedlium3/ASR/local/prepare_transcripts.py b/egs/tedlium3/ASR/local/prepare_transcripts.py
index 7ea4e89a4..d4ccdd1e3 100755
--- a/egs/tedlium3/ASR/local/prepare_transcripts.py
+++ b/egs/tedlium3/ASR/local/prepare_transcripts.py
@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
-# Copyright    2021  Xiaomi Corp.        (authors: Mingshuang Luo)
+# Copyright    2021  Xiaomi Corp.        (author: Mingshuang Luo)
+# Copyright    2022  Behavox LLC.        (author: Daniil Kulko)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -17,68 +18,67 @@
 
 
 """
-This script takes as input supervisions json dir "data/manifests"
-consisting of supervisions_train.json and does the following:
-
-1. Generate train.text.
+This script takes input text file and removes all words
+that iclude any character out of English alphabet.
 
 """
 import argparse
 import logging
+import re
 from pathlib import Path
 
-import lhotse
-
 
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument(
-        "--manifests-dir",
+        "--input-text-path",
         type=str,
-        help="""Input directory.
-        """,
+        help="Input text file path.",
     )
     parser.add_argument(
-        "--lang-dir",
+        "--output-text-path",
         type=str,
-        help="""Output directory.
-        """,
+        help="Output text file path.",
     )
 
     return parser.parse_args()
 
 
-def prepare_transcripts(manifests_dir: str, lang_dir: str):
+def prepare_transcripts(input_text_path: Path, output_text_path: Path) -> None:
     """
     Args:
-      manifests_dir:
-        The manifests directory, e.g., data/manifests.
-      lang_dir:
-        The language directory, e.g., data/lang_phone.
+      input_text_path:
+        The input data text file path, e.g., data/lang/train_orig.txt.
+      output_text_path:
+        The output data text file path, e.g., data/lang/train.txt.
 
     Return:
-      The train.text in lang_dir.
+      Saved text file in output_text_path.
     """
-    texts = []
 
-    train_text = Path(lang_dir) / "train.text"
-    sups = lhotse.load_manifest(f"{manifests_dir}/tedlium_supervisions_train.jsonl.gz")
-    for s in sups:
-        texts.append(s.text)
+    foreign_chr_check = re.compile(r"[^a-z']")
 
-    with open(train_text, "w") as f:
-        for text in texts:
-            f.write(text)
-            f.write("\n")
+    logging.info(f"Loading {input_text_path.name}")
+    with open(input_text_path, "r", encoding="utf8") as f:
+        texts = {t.rstrip("\n") for t in f}
+
+    texts = {
+        " ".join([w for w in t.split() if foreign_chr_check.search(w) is None])
+        for t in texts
+    }
+
+    with open(output_text_path, "w+", encoding="utf8") as f:
+        for t in texts:
+            f.write(f"{t}\n")
 
 
-def main():
+def main() -> None:
     args = get_args()
-    manifests_dir = Path(args.manifests_dir)
-    lang_dir = Path(args.lang_dir)
+    input_text_path = Path(args.input_text_path)
+    output_text_path = Path(args.output_text_path)
 
-    logging.info("Generating train.text")
-    prepare_transcripts(manifests_dir, lang_dir)
+    logging.info(f"Generating {output_text_path.name}")
+    prepare_transcripts(input_text_path, output_text_path)
 
 
 if __name__ == "__main__":
diff --git a/egs/tedlium3/ASR/local/prepare_words.py b/egs/tedlium3/ASR/local/prepare_words.py
new file mode 100755
index 000000000..a37d0f08f
--- /dev/null
+++ b/egs/tedlium3/ASR/local/prepare_words.py
@@ -0,0 +1,83 @@
+#!/usr/bin/env python3
+# Copyright    2022  Behavox LLC.        (authors: Daniil Kulko)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This script takes as input supervisions json dir "data/manifests"
+consisting of tedlium_supervisions_train.json and does the following:
+
+1. Generate words.txt.
+
+"""
+import argparse
+import logging
+import re
+from pathlib import Path
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        help="Output directory.",
+    )
+
+    return parser.parse_args()
+
+
+def prepare_words(lang_dir: str) -> None:
+    """
+    Args:
+      lang_dir:
+        The language directory, e.g., data/lang.
+
+    Return:
+      The words.txt file.
+    """
+
+    words_orig_path = Path(lang_dir) / "words_orig.txt"
+    words_path = Path(lang_dir) / "words.txt"
+
+    foreign_chr_check = re.compile(r"[^a-z']")
+
+    logging.info(f"Loading {words_orig_path.name}")
+    with open(words_orig_path, "r", encoding="utf8") as f:
+        words = {w for w_compl in f for w in w_compl.strip("-\n").split("_")}
+    words = {w for w in words if foreign_chr_check.search(w) is None and w != ""}
+    words.add("")
+    words = ["", "!SIL"] + sorted(words) + ["#0", "", ""]
+
+    with open(words_path, "w+", encoding="utf8") as f:
+        for idx, word in enumerate(words):
+            f.write(f"{word} {idx}\n")
+
+
+def main() -> None:
+    args = get_args()
+    lang_dir = Path(args.lang_dir)
+
+    logging.info("Generating words.txt")
+    prepare_words(lang_dir)
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
+    main()
diff --git a/egs/tedlium3/ASR/local/test_prepare_lang.py b/egs/tedlium3/ASR/local/test_prepare_lang.py
deleted file mode 120000
index f0f864998..000000000
--- a/egs/tedlium3/ASR/local/test_prepare_lang.py
+++ /dev/null
@@ -1 +0,0 @@
-../../../librispeech/ASR/local/test_prepare_lang.py
\ No newline at end of file
diff --git a/egs/tedlium3/ASR/prepare.sh b/egs/tedlium3/ASR/prepare.sh
index 272cf7aed..3d90436ff 100755
--- a/egs/tedlium3/ASR/prepare.sh
+++ b/egs/tedlium3/ASR/prepare.sh
@@ -5,7 +5,6 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
 
 set -eou pipefail
 
-nj=15
 stage=0
 stop_stage=100
 
@@ -63,6 +62,13 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
     mv $dl_dir/TEDLIUM_release-3 $dl_dir/tedlium3
   fi
 
+  # Download big and small 4 gram lanuage models
+  if [ ! -d $dl_dir/lm ]; then
+    wget --continue http://kaldi-asr.org/models/5/4gram_small.arpa.gz -P $dl_dir/lm
+    wget --continue http://kaldi-asr.org/models/5/4gram_big.arpa.gz -P $dl_dir/lm
+    gzip -d $dl_dir/lm/4gram_small.arpa.gz $dl_dir/lm/4gram_big.arpa.gz
+  fi
+
   # If you have pre-downloaded it to /path/to/musan,
   # you can create a symlink
   #
@@ -100,7 +106,14 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
 
   if [ ! -e data/fbank/.tedlium3.done ]; then
     mkdir -p data/fbank
+
     python3 ./local/compute_fbank_tedlium.py
+
+    gunzip -c data/fbank/tedlium_cuts_train.jsonl.gz | shuf | \
+    gzip -c > data/fbank/tedlium_cuts_train-shuf.jsonl.gz
+    mv data/fbank/tedlium_cuts_train-shuf.jsonl.gz \
+       data/fbank/tedlium_cuts_train.jsonl.gz
+
     touch data/fbank/.tedlium3.done
   fi
 fi
@@ -115,28 +128,24 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
 fi
 
 if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
-  log "Stage 5: Prepare phone based lang"
-  lang_dir=data/lang_phone
+  log "Stage 5: Prepare BPE train data and set of words"
+  lang_dir=data/lang
   mkdir -p $lang_dir
 
-  if [ ! -f $lang_dir/train.text ]; then
+  if [ ! -f $lang_dir/train.txt ]; then
+    gunzip -c $dl_dir/tedlium3/LM/*.en.gz | sed 's: <\/s>::g' > $lang_dir/train_orig.txt
+
     ./local/prepare_transcripts.py \
-      --lang-dir $lang_dir \
-      --manifests-dir data/manifests
+      --input-text-path $lang_dir/train_orig.txt \
+      --output-text-path $lang_dir/train.txt
   fi
 
-  if [ ! -f $lang_dir/lexicon_words.txt ]; then
-    ./local/prepare_lexicon.py \
-      --lang-dir $lang_dir \
-      --manifests-dir data/manifests
-  fi
+  if [ ! -f $lang_dir/words.txt ]; then
 
-  (echo '!SIL SIL'; echo ' '; ) |
-    cat - $lang_dir/lexicon_words.txt |
-    sort | uniq > $lang_dir/lexicon.txt
+    awk '{print $1}' $dl_dir/tedlium3/TEDLIUM.152k.dic |
+    sed 's:([0-9])::g' | sort | uniq > $lang_dir/words_orig.txt
 
-  if [ ! -f $lang_dir/L_disambig.pt ]; then
-    ./local/prepare_lang.py --lang-dir $lang_dir
+    ./local/prepare_words.py --lang-dir $lang_dir
   fi
 fi
 
@@ -148,25 +157,56 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
     mkdir -p $lang_dir
     # We reuse words.txt from phone based lexicon
     # so that the two can share G.pt later.
-    cp data/lang_phone/words.txt $lang_dir
-
-    if [ ! -f $lang_dir/transcript_words.txt ]; then
-      log "Generate data for BPE training"
-      cat data/lang_phone/train.text |
-      cut -d " " -f 2- > $lang_dir/transcript_words.txt
-      # remove the  for transcript_words.txt
-      sed -i 's/ //g' $lang_dir/transcript_words.txt
-      sed -i 's/ //g' $lang_dir/transcript_words.txt
-      sed -i 's///g' $lang_dir/transcript_words.txt
-    fi
+    cp data/lang/words.txt $lang_dir
 
     ./local/train_bpe_model.py \
       --lang-dir $lang_dir \
       --vocab-size $vocab_size \
-      --transcript $lang_dir/transcript_words.txt
+      --transcript data/lang/train.txt
 
     if [ ! -f $lang_dir/L_disambig.pt ]; then
-      ./local/prepare_lang_bpe.py --lang-dir $lang_dir
+      ./local/prepare_lang_bpe.py --lang-dir $lang_dir --oov ""
+    fi
+  done
+fi
+
+if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+  log "Stage 7: Prepare G"
+  # We assume you have install kaldilm, if not, please install
+  # it using: pip install kaldilm
+
+  mkdir -p data/lm
+  if [ ! -f data/lm/G_4_gram_small.fst.txt ]; then
+    # It is used in building HLG
+    python3 -m kaldilm \
+      --read-symbol-table="data/lang/words.txt" \
+      --disambig-symbol='#0' \
+      --max-order=4 \
+      --max-arpa-warnings=-1 \
+      $dl_dir/lm/4gram_small.arpa > data/lm/G_4_gram_small.fst.txt
+  fi
+
+  if [ ! -f data/lm/G_4_gram_big.fst.txt ]; then
+    # It is used for LM rescoring
+    python3 -m kaldilm \
+      --read-symbol-table="data/lang/words.txt" \
+      --disambig-symbol='#0' \
+      --max-order=4 \
+      --max-arpa-warnings=-1 \
+      $dl_dir/lm/4gram_big.arpa > data/lm/G_4_gram_big.fst.txt
+  fi
+fi
+
+if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
+  log "Stage 8: Compile HLG"
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    lang_dir=data/lang_bpe_${vocab_size}
+
+    if [ ! -f $lang_dir/HLG.pt ]; then
+      ./local/compile_hlg.py \
+        --lang-dir $lang_dir \
+        --lm G_4_gram_small
     fi
   done
 fi
diff --git a/icefall/decode.py b/icefall/decode.py
index 68e490c5e..23f9fb9b3 100644
--- a/icefall/decode.py
+++ b/icefall/decode.py
@@ -466,9 +466,7 @@ def one_best_decoding(
     Return:
       An FsaVec containing linear paths.
     """
-
     if lm_scale_list is not None:
-
         ans = dict()
         saved_am_scores = lattice.scores - lattice.lm_scores
         for lm_scale in lm_scale_list:
diff --git a/test/test_lexicon.py b/test/test_lexicon.py
index 69867efc7..b1beab3f6 100755
--- a/test/test_lexicon.py
+++ b/test/test_lexicon.py
@@ -112,7 +112,7 @@ def uniq_lexicon_test():
     # But there is no word "ca" in the lexicon, so our
     # implementation returns the id of ""
     print(token_ids, expected_token_ids)
-    assert token_ids.tolist() == [[sp.unk_id()]]
+    assert token_ids.tolist() == [[sp.piece_to_id("▁"), sp.unk_id()]]
 
     # case 3: With OOV
     texts = ["foo"]

From fbc88948044278b687a57309248eb2ae6df0a415 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang 
Date: Wed, 14 Dec 2022 13:47:23 +0800
Subject: [PATCH 049/174] Add comment for compile_hlg_using_openfst.py (#762)

---
 egs/librispeech/ASR/prepare.sh | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh
index 11c8e1066..59bed8389 100755
--- a/egs/librispeech/ASR/prepare.sh
+++ b/egs/librispeech/ASR/prepare.sh
@@ -302,13 +302,20 @@ fi
 if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
   log "Stage 9: Compile HLG"
   ./local/compile_hlg.py --lang-dir data/lang_phone
-  ./local/compile_hlg_using_openfst.py --lang-dir data/lang_phone
+
+  # Note If ./local/compile_hlg.py throws OOM,
+  # please switch to the following command
+  #
+  # ./local/compile_hlg_using_openfst.py --lang-dir data/lang_phone
 
   for vocab_size in ${vocab_sizes[@]}; do
     lang_dir=data/lang_bpe_${vocab_size}
     ./local/compile_hlg.py --lang-dir $lang_dir
 
-    ./local/compile_hlg_using_openfst.py --lang-dir $lang_dir
+    # Note If ./local/compile_hlg.py throws OOM,
+    # please switch to the following command
+    #
+    # ./local/compile_hlg_using_openfst.py --lang-dir $lang_dir
   done
 fi
 

From ad475ec10dec864373099ba541cad5f743a4726b Mon Sep 17 00:00:00 2001
From: Wei Kang 
Date: Thu, 15 Dec 2022 19:07:28 +0800
Subject: [PATCH 050/174] Add documents for pruned_transducer_stateless (#526)

* begin to add documents for pruned_transducer_stateless

* Move lstm docs to Streaming folder

* Add documents for pruned transducer stateless models

* Move zipformer mmi to non-streaming recipe

* Add more docs for streaming decoding

* Fix typo
---
 docs/source/index.rst                         |   8 +
 .../aishell/conformer_ctc.rst                 |   0
 .../aishell-conformer-ctc-tensorboard-log.jpg | Bin
 .../aishell-tdnn-lstm-ctc-tensorboard-log.jpg | Bin
 ...cer_stateless_modified-tensorboard-log.png | Bin
 .../{ => Non-streaming-ASR}/aishell/index.rst |   0
 .../aishell/stateless_transducer.rst          |   0
 .../aishell/tdnn_lstm_ctc.rst                 |   0
 .../recipes/Non-streaming-ASR/index.rst       |  10 +
 .../librispeech/conformer_ctc.rst             |   0
 ...rispeech-conformer-ctc-tensorboard-log.png | Bin
 ...eech-pruned-transducer-tensorboard-log.jpg | Bin 0 -> 566971 bytes
 .../librispeech/index.rst                     |   1 +
 .../pruned_transducer_stateless.rst           | 545 +++++++++++++
 .../librispeech/tdnn_lstm_ctc.rst             |   0
 .../librispeech/zipformer_mmi.rst             |   0
 .../{ => Non-streaming-ASR}/timit/index.rst   |   0
 .../timit/tdnn_ligru_ctc.rst                  |   0
 .../timit/tdnn_lstm_ctc.rst                   |   0
 .../yesno/images/tdnn-tensorboard-log.png     | Bin
 .../{ => Non-streaming-ASR}/yesno/index.rst   |   0
 .../{ => Non-streaming-ASR}/yesno/tdnn.rst    |   0
 docs/source/recipes/Streaming-ASR/index.rst   |  12 +
 .../recipes/Streaming-ASR/introduction.rst    |  52 ++
 ...speech-lstm-transducer-tensorboard-log.png | Bin
 ...eech-pruned-transducer-tensorboard-log.jpg | Bin 0 -> 560358 bytes
 .../Streaming-ASR/librispeech/index.rst       |   9 +
 .../lstm_pruned_stateless_transducer.rst      |   0
 .../pruned_transducer_stateless.rst           | 735 ++++++++++++++++++
 docs/source/recipes/index.rst                 |   6 +-
 30 files changed, 1374 insertions(+), 4 deletions(-)
 rename docs/source/recipes/{ => Non-streaming-ASR}/aishell/conformer_ctc.rst (100%)
 rename docs/source/recipes/{ => Non-streaming-ASR}/aishell/images/aishell-conformer-ctc-tensorboard-log.jpg (100%)
 rename docs/source/recipes/{ => Non-streaming-ASR}/aishell/images/aishell-tdnn-lstm-ctc-tensorboard-log.jpg (100%)
 rename docs/source/recipes/{ => Non-streaming-ASR}/aishell/images/aishell-transducer_stateless_modified-tensorboard-log.png (100%)
 rename docs/source/recipes/{ => Non-streaming-ASR}/aishell/index.rst (100%)
 rename docs/source/recipes/{ => Non-streaming-ASR}/aishell/stateless_transducer.rst (100%)
 rename docs/source/recipes/{ => Non-streaming-ASR}/aishell/tdnn_lstm_ctc.rst (100%)
 create mode 100644 docs/source/recipes/Non-streaming-ASR/index.rst
 rename docs/source/recipes/{ => Non-streaming-ASR}/librispeech/conformer_ctc.rst (100%)
 rename docs/source/recipes/{ => Non-streaming-ASR}/librispeech/images/librispeech-conformer-ctc-tensorboard-log.png (100%)
 create mode 100644 docs/source/recipes/Non-streaming-ASR/librispeech/images/librispeech-pruned-transducer-tensorboard-log.jpg
 rename docs/source/recipes/{ => Non-streaming-ASR}/librispeech/index.rst (82%)
 create mode 100644 docs/source/recipes/Non-streaming-ASR/librispeech/pruned_transducer_stateless.rst
 rename docs/source/recipes/{ => Non-streaming-ASR}/librispeech/tdnn_lstm_ctc.rst (100%)
 rename docs/source/recipes/{ => Non-streaming-ASR}/librispeech/zipformer_mmi.rst (100%)
 rename docs/source/recipes/{ => Non-streaming-ASR}/timit/index.rst (100%)
 rename docs/source/recipes/{ => Non-streaming-ASR}/timit/tdnn_ligru_ctc.rst (100%)
 rename docs/source/recipes/{ => Non-streaming-ASR}/timit/tdnn_lstm_ctc.rst (100%)
 rename docs/source/recipes/{ => Non-streaming-ASR}/yesno/images/tdnn-tensorboard-log.png (100%)
 rename docs/source/recipes/{ => Non-streaming-ASR}/yesno/index.rst (100%)
 rename docs/source/recipes/{ => Non-streaming-ASR}/yesno/tdnn.rst (100%)
 create mode 100644 docs/source/recipes/Streaming-ASR/index.rst
 create mode 100644 docs/source/recipes/Streaming-ASR/introduction.rst
 rename docs/source/recipes/{ => Streaming-ASR}/librispeech/images/librispeech-lstm-transducer-tensorboard-log.png (100%)
 create mode 100644 docs/source/recipes/Streaming-ASR/librispeech/images/streaming-librispeech-pruned-transducer-tensorboard-log.jpg
 create mode 100644 docs/source/recipes/Streaming-ASR/librispeech/index.rst
 rename docs/source/recipes/{ => Streaming-ASR}/librispeech/lstm_pruned_stateless_transducer.rst (100%)
 create mode 100644 docs/source/recipes/Streaming-ASR/librispeech/pruned_transducer_stateless.rst

diff --git a/docs/source/index.rst b/docs/source/index.rst
index be9977ca9..4ea446259 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -22,6 +22,14 @@ speech recognition recipes using `k2 `_.
 
    installation/index
    model-export/index
+
+.. toctree::
+   :maxdepth: 3
+
    recipes/index
+
+.. toctree::
+   :maxdepth: 2
+
    contributing/index
    huggingface/index
diff --git a/docs/source/recipes/aishell/conformer_ctc.rst b/docs/source/recipes/Non-streaming-ASR/aishell/conformer_ctc.rst
similarity index 100%
rename from docs/source/recipes/aishell/conformer_ctc.rst
rename to docs/source/recipes/Non-streaming-ASR/aishell/conformer_ctc.rst
diff --git a/docs/source/recipes/aishell/images/aishell-conformer-ctc-tensorboard-log.jpg b/docs/source/recipes/Non-streaming-ASR/aishell/images/aishell-conformer-ctc-tensorboard-log.jpg
similarity index 100%
rename from docs/source/recipes/aishell/images/aishell-conformer-ctc-tensorboard-log.jpg
rename to docs/source/recipes/Non-streaming-ASR/aishell/images/aishell-conformer-ctc-tensorboard-log.jpg
diff --git a/docs/source/recipes/aishell/images/aishell-tdnn-lstm-ctc-tensorboard-log.jpg b/docs/source/recipes/Non-streaming-ASR/aishell/images/aishell-tdnn-lstm-ctc-tensorboard-log.jpg
similarity index 100%
rename from docs/source/recipes/aishell/images/aishell-tdnn-lstm-ctc-tensorboard-log.jpg
rename to docs/source/recipes/Non-streaming-ASR/aishell/images/aishell-tdnn-lstm-ctc-tensorboard-log.jpg
diff --git a/docs/source/recipes/aishell/images/aishell-transducer_stateless_modified-tensorboard-log.png b/docs/source/recipes/Non-streaming-ASR/aishell/images/aishell-transducer_stateless_modified-tensorboard-log.png
similarity index 100%
rename from docs/source/recipes/aishell/images/aishell-transducer_stateless_modified-tensorboard-log.png
rename to docs/source/recipes/Non-streaming-ASR/aishell/images/aishell-transducer_stateless_modified-tensorboard-log.png
diff --git a/docs/source/recipes/aishell/index.rst b/docs/source/recipes/Non-streaming-ASR/aishell/index.rst
similarity index 100%
rename from docs/source/recipes/aishell/index.rst
rename to docs/source/recipes/Non-streaming-ASR/aishell/index.rst
diff --git a/docs/source/recipes/aishell/stateless_transducer.rst b/docs/source/recipes/Non-streaming-ASR/aishell/stateless_transducer.rst
similarity index 100%
rename from docs/source/recipes/aishell/stateless_transducer.rst
rename to docs/source/recipes/Non-streaming-ASR/aishell/stateless_transducer.rst
diff --git a/docs/source/recipes/aishell/tdnn_lstm_ctc.rst b/docs/source/recipes/Non-streaming-ASR/aishell/tdnn_lstm_ctc.rst
similarity index 100%
rename from docs/source/recipes/aishell/tdnn_lstm_ctc.rst
rename to docs/source/recipes/Non-streaming-ASR/aishell/tdnn_lstm_ctc.rst
diff --git a/docs/source/recipes/Non-streaming-ASR/index.rst b/docs/source/recipes/Non-streaming-ASR/index.rst
new file mode 100644
index 000000000..67123a648
--- /dev/null
+++ b/docs/source/recipes/Non-streaming-ASR/index.rst
@@ -0,0 +1,10 @@
+Non Streaming ASR
+=================
+
+.. toctree::
+   :maxdepth: 2
+
+   aishell/index
+   librispeech/index
+   timit/index
+   yesno/index
diff --git a/docs/source/recipes/librispeech/conformer_ctc.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/conformer_ctc.rst
similarity index 100%
rename from docs/source/recipes/librispeech/conformer_ctc.rst
rename to docs/source/recipes/Non-streaming-ASR/librispeech/conformer_ctc.rst
diff --git a/docs/source/recipes/librispeech/images/librispeech-conformer-ctc-tensorboard-log.png b/docs/source/recipes/Non-streaming-ASR/librispeech/images/librispeech-conformer-ctc-tensorboard-log.png
similarity index 100%
rename from docs/source/recipes/librispeech/images/librispeech-conformer-ctc-tensorboard-log.png
rename to docs/source/recipes/Non-streaming-ASR/librispeech/images/librispeech-conformer-ctc-tensorboard-log.png
diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/images/librispeech-pruned-transducer-tensorboard-log.jpg b/docs/source/recipes/Non-streaming-ASR/librispeech/images/librispeech-pruned-transducer-tensorboard-log.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..800835749a7c7b2ca62398ba2e75d3091d06a6a7
GIT binary patch
literal 566971
zcmeFabzD_l_b<9Od(+)r(j`cDN=uit2uOFgf`A|)AgF+JBOpjicZf(!N;gQe>ATS9
zdEWQD_jm8-oO{ptoIj2e=UU%A)|fHonlZ+hb8m0H-^>FzN^**F00aU6ufRXxW&v=Q
z^|pBe04gc~D*yl}01QF~Ac6>F0mLEX|6n-?GXVWXM*x6u8vynfk1BY-4Pb7+()_+7
zWFY*_fi9W>{Tn0QmI2)C0D|h)PVP=_)=theJa@SOL1{%5*lmtr_=OpML7Y*_sL5EM
z0=DrK_3buXI`C!@2Srgv##CKXO-}KF>@PvE;bzXx_6Rrt;Na-)swpo`^HAS_26+{%
z4JJSUO2uzx?&d6|uKwUQ&%f`#^8e>@GXAUXzzD}}UjLB)`v9JWrJFe@yb;L$*uu@+
z0mLl;082G@c6A2;q+2?jm%H;VehOj&SFnO0etnCre&d2${P;I+_{F2GDFgC+178V-
znVG8<0AS64bQ(`{D=;6dOAvE8TG%)O05&^_#m%kFEI@o)cXkH{r(66M#LOVSGl*M2
zEcp1}^0fLlY-aZOpZv|ttp35jg5YiAuBZ)S5)i+zbd%Nkizma?N%OC=5UQWJ%1Hf1
z?=f=)(eJY6JnVFC>9_eGSX(Ool|K~A*IHW-q@w`A!efJv49YPY=>HEgj{hK@7?Q^?c$gcPslXf24yOhFtsBUXP;Fs(Ulzom?9g>fN1a=FZ}WjW
zyG!4e1*V0eIotmlCmQT|S$_v?ux
zc+~=YfgQjRRj>pHKpC(HKYM`K5=`@V
zX}Z8NNU{0n^WUYLfq9yP^|c4%?eqVd{x|hsoKm2?5B^H)@pp+d*ng)-l0~|Yq<|!c
z!~!FOiNi!-55TJ=7`b7BFrmNY_!mFoD&i30BH|e0JmRFOO|a{Kw2=xh4@&gA4K2a;
z{kOJ&F@kyW!K7hgVD4bfFd`Tkzyae0%aMXffF<*RRGGgdyls`g-^jnq^p6CmZ=dMf|(HvXys
z^ab=2v>VzA?S(c#YXBN(Gqe`^1=@Iv|0-YZ*W0!Fr!>u9^|1rri_Jegf6IIO#&65K
z?H_Jj#atU)m%npyxAb-g`!gWp7=I0F|`h@Spz|jcFL1Cct^RC(hi%)$@0Lu-h9V
zI7?uIBaIZG0_Xr{P#$i89}oc~0a-u^Py@69eZU0luhw8ccLlrvf8ZGq2D}7U<3u1C
zNC&ck0-yw_0BV2+pcUu>`hjo27%&a?o)usV*aJ?03kU>)48erpLr5Xi5C#Yvgc~9V
z5r@b^9zZl9`jAHuD~Kb+1L6+}fxLvoLy{qxkU~f~qz=*w>4gkICLs%u4afoH907`e
zfk1#jiNJ`!i6Dp|g`kX}jbMymh2VnVi|`yF2H`D27QzRFYJ^sVeuOcEd4w&5V<-Sc
zhY~?)p&U>_s0>sMY5;u#b%h2(BcVyqENCgT9(>Ou(0S-K^bCds!-vtpIKUpP0Mmh)
z!(3p2uqaqEtN>OG>wt~G7GQg@YeY;$3Pd(UVMIm5hlo~)UWj3cNr-ufRfwJ7SXx0m
zK|(?zLSjM^0LP*}k}Z-SQWVmAq%x!yq#>jwq+?`cWKv`{WHDqlWHV%UC`l+qC`~BCC~GJes5qz$sKTgfs1~R`sIjOysP(9W
zs4J)!Xt-!hXkuvEXtrq2(B7hzqIIIppdF&4qu)UnLf1gIL4S(=7X2f75BdW783rB(
zD~1e)5rzjw3`Rai3&teIAtoj!Bc>##0j4`<3}zu_JLViF919=oE|wye1y&GN3RX4N
zFxDJoYrrVus~&J#h27>N{#?1^HDDu~93&WS0BrHCzwBZy0hhlx)~$VntgEJ-3s
z%1FjY&Pi!VWl8Nw<49{rXUSk>tYjKwUSw%xon+hO1mq&*=H!v&735PC2o$UoniPH%
z*%V(Wjwz`qy-7nzBSm9R^OmNA
zW{;MFR)N-oHk0-%?Zq9YJ34oQ@08t{r9-0=q<9oe(kCpa)Tq&Pe|ia8eU65V}pH~4PN-Ca(4P6N(3
z&Th_YE2docVAJnTLs3@rP+e`>Yp``G}JXxG*&h5Y5Hq+YvF5|Y87dnXiICyYESDh>v-t2>0;{|=@#mq>fP6i
z*PDOH`7q#Nzdo71jee~GvVpEap24x9tYMetn9;Z~ld-pPuL+rny-A}fwyBxv
zr$>m7^d1#Gx-?TW%Q8EDtoS(f@t(Pi`5W^s3rUOD78{mgmhqNrPsE`B}v&*-;vDdRNb3k!;)0efNd;%fv?#Jh;=NQ=n6ms~HCUS39-NA*SDi+&q@
z6Z0hIYb<|k+AE}2j<3e!B;pF$P2yUKUWpD)lX=qMB@EG(iZiY-Pc_9))?VD#Zz$^DYrQjXHh
zGNQ7VAAye^A2-V%m5)~_SG0W+{8U=WP?=hVUlmymsrISftFf+`uhp*|s#B_Ks~4%S
z`h53uegkboN+V%oY!h13^QN0--{#{M=a%hOtJbA9)3)jMhwY;s8XaFdl{@!T8q<5=1W`4Zp)V|&sQ;5lh&x$3fH;U8#d%NzHJ(Au53APU2KQ#;O?aEGVfOY
zl>GT+@8RCczU%(YLDV7XVg3>CQQNWl@!W~S$;D|ToD^PgCV19;u6Mq6;dP08`SyzC
zs{UH}dgjLA=Em8~)$ErpaBFuXJOQorgM0wM)CU0kZ=n8Q_}doz#R2(kzk(R@SG*8OAqyzxudH}d`y}7w~e{*w{1KJal0MO#{+XlU5M#2Gr46$1@Mq=^w_WtkTw|xpy
z{+97S@1W#7JiNC_{+s(v3xJD)#ELEoh1>xUa3N4!$V~@84b}}2RA|8PcOei2C=3w^
z83h#$9b~A&0T3WiC;|+Mh9TBzz4s=*3s1iRXKAD%O_UWHnwi=
z9-dy_KE5H(L&L&fL`23XyiQDd^Y&eGR(4KqUVcGgQN^dqs_L5By85>Ej?S*`p5DIU
zkoU4+k9;9p`^I+{}Uws>+)w00RmEI};QakN{3i;R<{~XFqR%7*4px4S9yYtu=MhpAgE2M_y!mZx^xL@D=)~tavrz=*17YqkvR{s
zkCj=$ER=%au%MpF8{m0}-VKm;1Hcz=fKckaSGG66;g1`D!gTM0E%Oav#1nLmQFsG9
z>%VOOBi)~O_y05${`Aa0K_~7kB#uhM)+eR{IL=K*a&}Ygg-XI9~t=?;xyadl>rY|C2Av7O|
zQePTV?-gFt1bt%z&quc20G&hE#ZQ8c&_QBf*9?UJ
z>5(<@Yo>$4!Dq*y-53=crKr5KN2Omex_R5@H-x_zMdBJ8)*M~FtkmPXTf;+f{~2;K
zW6^tE_)!A&Q*i1zVcEOBqqN(Jtjs;|Cw8L+w)V(8-R7KI3Q>_9bSAr
z6Lb_1poXgRrlwHtdb@D6e*Mg*PkA?9#g~fTjUZd82>yCw@}yyj_%v@sR5Oyf)O`DY
zY}R1Yr9?NZjj*a2DP@T6(n(^nq#MQIvRzQs9A;fs%`yM*KB|bdbL{=>>n7Q4Y$ryZ
zYFXy*VtdZEjSsR$KCyDE7$XvBD459-s6+nCNcggNj^xeuP%2OC+jH#eM)w7-;Svq4
zSgp226PoPA`L(FBN>5|M^*3dgc#YRZ58pbdB1_LTE46O*W<_2hqlSc%9JEq*7%6v>
zG)&;^lN6d(Y!FUzZZu`e66k3hf>(9Kmv)6ErvCDqB4k?*cb7iL`Z~*NY#8Z0tg7?V
ze_jZW)i9B>N$hIy$=;{V_x4(3@Xu?o9#2R_71KL7>pnIPP+g3)aAhy3pzZtElRwc~
zIr2%OV&nH5bSx!Hpvlgh6$#u7|L61rT~`YYB!}NChP&F8y)Y>IUPbG=8L>ukh;pCR
zrSb;Ab+nrm71@)Y*Lz8qY>GKISin$+P8+%WKDQ3IY}_}6PRvMDNL&W{%ijRTJ2$`)
zaoH#wpFFg4(dH4GnN5|IkKvrBN_vvZ95UhacAf5_PHolAq2~J1VHP8A=!r#DnzCKA
zcQ-q`#^@|XF-t4*%C_AKHp$ewt_kyKi71c26mLPEWle(T&d`a_5+8D8(6SZy)t{|{
zr5iww&u^kNHuzjEk~iYgZ7Cv(jIik-F6^MAnE1Tgi{gV_uwW}am7LK2mfMr#adgzM
z=VT7EHd(tDf%)a~Z%ClnR#l+$5ht(;Y0JR@jaePd|2+t?(JD1pmW
z5iMMrD68sl*NfUiVEFW6C+OW?C*jX@gljwlf5lNw&Z;^SF*cR&lyo8{Hq`XY_7u!e
zcHaYFD{jkEK;je9r_#>(eih1)L8dsu!JkR|aWP{EYCK<-kLb|P
zm$yS@WyM%I+{XFqf&2r?~)hJ
zxgR9^xeo5!tvI^#QQ>NSs>44^!T81ZA9D}fW|j!SZUt+A=GmFzd0?Y?|4=*ZjIO-4
zz}QWB5_#Wfg;cRi&7Z9Pk}x9Lmp5CmkIy)p_b2gT)x?Y}7GJwksBOjbY*wX%5Z!o_
zH=iyBfwIy-41bC}&S-zMYx;fWR29mOo&k!_IPxS~sgI6&OK$+3J?!nny+fTkC!z*r
zjtx912Rap|I9!M`F8V#hXCmq1-#cwC5ndSi(~Z%T8!sxStVGEBXP_b3Z3|k3(1m>w
zi{Wd_UP-S{=S!|CEy?|vxqqqB<=U|_Hx0{Iag&*|OWK$_32<;wB^0N;{!+7BT|LKe
zPhkzXD%VT+dB>rQ2j0BL8Wf(SxX3}YUqwkdL-!8AYct$+WNT-0p8?jO6cdzod&-ON
zXEq$uC~KvcF^3(Z1)hrhNzN|ZnHwOd!_ZMCCt^i*2i9khCy&*
zan!7)7&zUY$zV7CDHY4Ub_r6rbtj=tHR(Gfqc_(C8ly4!qETc=7qk}~@Rz!3s8bJL
z2>+yzIa>gred=3t9Gi7EKSvz0lSKyKME?#_Rof
z;l-Nb)`y8Eq0Fk>(WYUFHN+0*rfM>1Ms{I|XvbEla@tMIm9G;|wj2V~(yhxHJ*W|`
zzuo}Yd~9$14e1$IcfQ0K_Kox|OJA}+k1Fpbow;1gM5tAJo{l6+KZoP=A|3Cmy)oIt
zOfTzf1?_pI6_h6&v5E39^I(N;qf7D5qWyN(SFWe2C7jRXrutPl1F#-?hdg&@h;U!c
z-YuAuAWuAssh^rkW7u(++9%lK~VIW#-1!h;-XfhV`rZyu%#qYAKVy2^7`A(q3~Hit0e42MM_G{<$O}%b=}1%FY}Z7bozN8h$A?RylS?toZU>+3
zY2T+u7YVjPVQJ#TcY5qr35{9#b@d!R)4$Oj{5*D
zvac*Na%}jp%;AXu`NBIQg=vR|DDDp^f;6~s^{xm{x}*JR9(vX#Xw_^A71cAuD@|zER(m{N$D7{TDU3l@8=)7WDp^HU2r&8Zl{7
z7b^NQn-YZfycf)0CWv))hvY(vMf&d)P%WqG&6;!bg)N-rGLJeIRc8J4uqX*QUH50i
z7WGb}32{l}Efnc8u=D)sP`&BjD;rQNPU=qsm&!1(rN&nr;}dB@z}(JlK}=%)R=`>Z
zd`j=9(GfFWkZ9+OFu|(dmv56CuNn6{bVQG%E0|s3b;Z+oagnP!=XXrSeksvePn=y!5dRkq?b75;)mo`KU!!eTa2)ztO6rm3-%#4j(9$-QU0*gCE;;lk71)Hu9XNlf18V;PwFau{L5Mmwv`L#n9NP
zG1ic+yPGgPjPah0qA^or+O!dsBLrn%Oy9aIF3&gwwk^u*Jd%OBHS^0O=q>++o+{Fp~&__HT+9*
z&2v$`jF!6|jbw*vp&y|$hP`;pN96A$2dmtwjt1L*#!j`4t&TJ?&D{XD*2Mb?3GA=v
z2p*%7JRHR%-;L{86Mfu3t?{wEdH=z=Rj7cGfzj;S4v=b+JQu~f$XDByG7wtvHC!TvP0G1+CM%mDGqd@yaz1WZK
zGpEX#sr9%o@$OWyC(VM32b9pYZ^=_Fags58sdBo=gn
z3+2%K`+Nna6sMF|+>zA8ce@_=^QmXS83b
zR(T|11PzI)N;j1c91S|q1QnJt6p14iwl%VSo|!tcdx_3NUK$QYCkO1Ei+k?n+`6uc
z-$zGy#heVeqsi53o4J5tbni6dr$-zgDX>c{CuY)-?(zs>A_JS2E9zdcuKusdba08%Nbx`UlM(`W-Gfg>MU^62`8Uy+BDner`!hqY*S
z9-3&@a9^8PdS7i&sp@oW;l~y~U#T&=kgZTGVh3B~lPfTB4+AKQ3E;0A`K=1WWm8ioU
zToIUjpAY-@_F}9IO>er6`tHuyP;svCMY7x(0}J+((`dY|`!DgEWGM*SP`xb0c_LyXq*P
zVNS&eM{YxH&}RzfVWIA<0{Zc!vIQr6Ycy)3VaKCm`$scU^pb1eJ|Z-g!r9-9HzQq2
z75a-svNRYmO-%<^K7&YSAk9ao^g&zgkGt67_D;J7{FG})F9r6ChSzjQN*7qiho|Kd
z=u^VPbr1Gx2Wx2vyZO4Q$l9bsNf&MV?iSU*;h3+fa6D~Fp0xJtk8p7-Ga#|mh|k$x
z8n9_B9j+=VedGK6EG#m4ZTkkWezTp@5WTbVvuo$wnU{IgQYkALJnX}EK!ukcX!pOm
zou3eaaui=~N$a7U)#sntzn3lG!?W&iI$T*$64WEu>
zv98v>(+wr#9d?c7S#>ZpG`v5Uo*SuPVVx&Kp4qBPLUWE|Yo{#Ie>Oheg?jgi{{o%n
zwllXgk{Uwe+G3l0AzY$;j=8zUFu-7)wFj8e*GH#_sfA%!<+urGfTaib~BlloA_+JjV|E?+h?UZK640I@zWQd5LNPK
zs(QDxYi?~Bm1yCt^ExIYq=zrtGGiG$dr*=dp
z%pXZBCusC#utLY=4>j|*k~bTOi%hM5ARRlXJ$Tbo7rYWF2+Swoh;q?o9u;NB$!~c$
zA6x9eW+${0C+0#8#nQ)E9<#{0sT52eIZR%C*|VF)HW*1580EcqseOG^E{r|!w1C}x
zR$eR7Y)k9;m4E2niUO-u*iLy?}X@9D;lp9
zT9IP{>`WK(-rMB(tL`(Vn0GcdxUyobSs&u;Tng*Lje=wDXy?i~#cR5F>DuH5?(n1&
zrjNIGeXx^q>ta>ute9DojwD3Iuk9=+H%^Bf
z<6VMY0yP%U!p)m`W!${-^8AYJdlN$^y#!_h9YK$|yFxn~2F`stnh+Bca#puogPh$u
z!jm6{RHbZu>cw)Y_Ghz=ymgBPo)I55FPbYpST#(x``qJwd%lBAXt3LLs+O{SdU9Z>
zBs9rN;MDt+@8|}2dAC6&+gtg*Q#gv;o8?u#8rgVNNt(DYI?*J$)MLI?vcV2?34CwC
z4JxLZ_yE`08^FcAAQI^5F@GgRNAt9H-FG50wwLP&Kc6*Pg>AOXW5bWh?{dfPM06X;
z>qY69!iXy7RPRN

eU{ir;kIa;1i>f8Ful-YLY5%0|_>uVT`EirH0fpC<6y(>-=| z{vd}f^X;UdgnjqtnxP>H)vDFk(P-uJmfjXpQ7pS34H{NnetUN9F8Y$ESd7t;UAaBv z45Tf{ktn1i8kHJPx3NfjZ{*zS@Q{&tUdmEFLJPa#Dh{*H(aGE zX;+p!vMrhp@wR_*+f7&fq_|Ys>yAhcu=4k05$XukJj0XyoYf@McRHJHu*~;^QpRa9 zqCsY!mND_9R{8;z;F;^>W8yV+EQ0ejz8Fr<){!8J#P`WIEWRZ{B4i6Z7?!eova`m# z)3Mnj2X%cxZFGyx=Ln#RGS8cSsiM&|#P^VQrlbxHvJi(A(c<11NOH zhObBGA}3;I#0+#J5vCh3P%@A@;NXMa%I(~HC-G50>#|VV=jo}q}TVU_sGLNwolP@%&{6RGxH;}l0(QZA?BY;dgmi6i9G0!Aw=~8k^0W2<- zv0kp7(QLVF51JHT>z#Vl(~$vNf=><7+FQDLi9(f0^~KwBNfW*6{#Q;;E*{*|Qyn8p zm#lY^%#WbS!%cg;&ADxN`8fDj$LV6z9vTy;rG7+cENeTEomUIPgZGWo2k2?N>cp=P zSt&huifzv(P}k*67`hPjvc7>Vqqk>wo(?5DxlLDt^c5pNk6D@q^9>a}?xNdAkWNDj z+k#yTG~uQl-)Q3l7i>65~mFT1| z#>VW3>Wd{d{+P;A?vCL2qVax{o)rftv+)@+uX^9der&R_)vFD`QtjlC(shqRBjpaI zXcSIQR<(DP$4x2X-7j^vI5@RN`=f=2?Xq$D&on;mAlFS{v5f}`+SCVxM=qrG5((Yi z=}2JRe>2lx8>@-Dk-Qs7-fubYg!*u^K7J>?!>hcbZvKN6>+=gN!)qAp$4k*Pes6^%GNHGHH$VHNOW>ANO8ntO?vR6LJK-cFqFbkdY&E+_ULhpW5x zqWx;h#}upnxL6(Z;soA3QC@B5U1X4sCOcXhAiYMLa9WOiS*&VItY7oVSk4dq#SR4v zKCdbrZPNyZ9G9Qlfa#FGh*yz!+K+ZkmKdwr0}IM;v{#@d`1GmTf~ktXM*XK+xKF|K zcjkxorFW`HHr~FIf?karTocCV$6fe~*&Zg=;aBVqe>UYDo!}V^2)h%F!gg2{L_eEw zb{*bFBEIk z`M6gQHXU~00ludO;p+THdzP;c{WCD|YhJNM)WBak@p7oO2jR>opUYFP@7|H3*L-v^ z$r79oqp&Ivbuk17#L)txZnoxYpM;vasvW-{YYty*th^%=QgT%UyNpUqf6Pr(l^is6 zM-9mm9;}Uo5@<+*8xN9X;|5rDg)%46BC2)uFYrDuhR(VQM-|ZV8Jwz&x=-M4@>fiK za>vW2l?Z{DN_ul@Hw`smZce4W?d%Bv@+6oKa!W(P={l83#17&1tEEl(TzT8 zioUF2F4$bpZGH~!($d>s(HnjrMRjZRW)&Dg_bmoNU;xgmVdU6Uixuhs?xv+#ve!C_V1 z*LjPVCFcI}ILyK&a{E!*Z{F|v&^M9nv-95ob}5RV6@-6UQ`{AEi%FF0RvP1zb01+} z{cPRBP8*%Wiv0jwpm)g9V*?qz%{#K3a8dCFp0vxZaK`Uj#gRX99T-yy<;SU0Yb05{%;0_|LQ zqRWIKVF7b3!sO_^5=$`qG(N5WsFA~TO(tEgk4fm_Ph9WpE)z;)&%If6zn(8X_l-Zl zb34`Xo%bk>dZHDQw;|1wxSBN{Iqo>M)ja1me`s$htLysQ0h#{Tt6pzJrFz?C=RZ~C zae{J1jObfEN07B0MlFpQ0NrO>Bz}1*w!~H0QGy}+{M#?(706Ihwpxs1N;U;=wNCcTH*$@va#9zf zZ@aG!O&a&huBFllKs(hZ=!m5FBmCVWHpzNKR(@zsB9ZmO9=cG=UUJVh*6*8kK6}Sx zZly&^y7|2oDovJK-?}JdOmdioU7|PL9H*2M&&39l93Ad{$mKGT-cu4{yOZeDGe0_F z_UnsV=@x=j!s1$8))!r$EA`w`|;`FR9(y65yhn9j-RE>9}=Du{8l=PK!;WqB67|WHjK4ZMZCMLH90+psLY8Y2>)=)#*g}>bJpr zYN(1&_P^(m3gLEB3U|Erwq0e z6Gjj4(BeOE$-~GR{c>%SZM&wTQ^&LHUvqc$kW+f@LdE2jPSjo+c!Xulv+wx_Hjxa<-XA@MzFYgz5cEK03+jldG;Z6@hwxRN(f4RCjsn zPZOO*n$X;iCMUWw)qbsq9+sN+yiuD&7+HyC9T zr8of`+*LVxr$h;|4*a5ZCj1Y7(9zl9hP3P$3Rj1eiWmF#<{MSK%b|H;3Ud@6+X@tq zd?=3R{g`wpJB~hz`y)-rN;8aXquuhZa|w&FD_e=ccsD;2zxsQUuEF=y0hM)2>Y&y) zfy>a@6z9XP$!y_K6PNfpDj3BGM_xw%wB9j0uQpbrDrqcMij^Bb_6b5$?7mmT@ed4t zvAx&lltWcMJDGa4Iu$Th$+Ks9aB|~OaXRlHqhV1KqFr+9_S{xk=_ncV-KX5OJv3IO zZYB-2epp-V?5xUDd>-y~o{eRG*T`{HG##v*gu*T-Hj|T04m*AZKOzntN^-}JHta0X zXH*zHb8L-ckW_(QQ8$&^>e^55cS;w%8KV~BMvGaFXb#w>#BmMUbEfS~M4KM8=Subx zdC$^eyuvQ!0yp=X2!=dCbVSJh4g4vd~uh-A$K8SS-k}stc{zB+sS`=f82fa zkdGYwrS@?Wlhoj6h{^ump+>`SAf9)K=E(Aaqrs{+MxTQ;lt1XP%Qg| zs#`Uind%tnub&EgbhimQTZKGOB)ss+*i&oTDGi8rwG@JO2$#Pu>CZY$sx}P~{$zAz zHpU_+usU`ouiV^8P2NqBpz!M;u8s~kW&yy1vGm6VxPSi9)>BLWtZs5mZfLjL$Pp^o zGhniJ=_<1rz)z~3ox9B1atu>|AT?#cYs`>e_zM-#qo>S5mQrI@S%_C99$0`UYFRAn zsK>dg>{txcC%uMh;EbaEqUTMF(Zu9keoa$3%|x+rk0bS06lmn)6-ii&g7x2~0>E;- zFIr6yII$Vt+4OWXAm!FpA(wlIRXc-ZHynt)!8Pk0W9)f>Zj~iHfz_R1k5hN<`eTo5 ztBGID$;)q{IOUZlYL=C=@bkyduW2;o=`>rENm=xNfYvPA)eZ1UcaXwmzqLz1ckFEP zbHI;xi$r8dpNqPljPP!~teprRy1$zdGpR0XHDhQrr8Iu(BFbH+?ZMi56MP#}Y8<|z{b)C#`~2*#18;F%<||8Mnz+R90-8rd zE(>M2lP#wLzB)5FUEE*y6pFG(Fgh#@9EU` zP%Jhi81BF6@GIdr6!Y}yeBbhYQQ0Hfut~t4>U}i3r%q@l;+NIs#k2@Roq%1?2^3kz zxR{ZZGhXsT-r<9bK(bs(ekaB0iPp5dz>JN#h3=9l5kMuz!p>WwrEO)*{nhRh6*`K! z4O+;Eh*YOvO|TQkH4U75t3^2bS22C-98VLrb1a}->Zzb&~5sdW;r zjju`8qpA>E9C^~N#fJxdYMOo%wf>Lx(r5rvtvnh4gxU4y4+%w-C2Io^bQbOo}nKj{4Ie?+G_wy&&7?3h!*OSeBE z9;d2^N(gX@#($nPo?2s+AlznFtt#tFsqtfbE8mjFOJq8QEYnB?|N4{m+~e!LRKKHA z*Ws_b*Q+8TBqocCT2*w>$|BT~+3|K85+vR=i>D#}#xBLC>=ibsA|ylVe8sJwS2X*? zHTCLfKbiQ<4?Nns0f>1AEK9N~?LN1xGE1T4grp*<@FlLMn&+$djjL-0x}ghN@e zG0f?YkITgOcI(Q(x)Jb_|G?`zG%1;Fu&77*&TRiwNBaC+>&t;?eHw!qb6>uMaENP^ zw$~?{{2&!5B*olHYSFg9_27^`#mTOCHel_&t$}WS_Dn#6Ze?;`KxRW-T((1gW;@{0 zDLgGdR>qz`>cGC5JUnq_Ur5diezT*Yx<;t^cVyYQ%H;i{TmT34=Xc5za%9vcGu1Vnq);H+`M5YV|sZI%X(7WpXHtIYt z!OIF~^H7T3?hJI~O)QA#`ek48?1vTR^)JsFnk_j)rdkCBX=1Jru2^^-*@kt3JtXMQD^B!`{-h{4Dgk`4QL}q23Qu!`yiY-;DbnGI&f*4fy7} z`;PK>kGBT6#t#onp6b;HMn439pBx1{CTQJcRAP_F^qR3^4fJB#70BioZPeT|RfWVL zhmPwZ(-2B?yM=u7Ec1uI*+*aGk5~#&Qyy}%4C_OR?UxPa2Y3~_5nm^Lg(Zw=s&9fyDMLuOU71G@^fF=@pKo= zqiIjceeS;LS|eIgp`A`Y&~n74^w%10uq?7U=j;69son!qs_xk&C=k6Dl`t{haM51B z+FW+2{>{l>nQg$g#&bV@o$gU)1w;pZYeVbZvX6(TQkzDiYM=+w3Q57wCnKuY$%Ib_ z%0Ff7^JxqD_{6K!RE_O69liyf%*fpmN1V3QaCA;^@_sl`PgRvS*O|9TggnCr|Ervh5hc1E4r@wwD%b^M)GFp@++`vd36k1Rgx~n<+S7FxUIQ6 z``9r@cFhD24oy1U?1+Ejnm-CwZ*kBu{fR%)DCZHsGRh{(L8`r`wbGDHNJ;rV+3V=r zhhpBY@too@^%a-Zq3VQR6)Xvh-$7L;8wsA5;XE7{Gmwv-8z z*ZhE8D;jxjzBtiQM(%CZbxLLuQAlq0a49rNf|ZHX5_Yhz#x$4nag3vh4=%4+KJJhm z+mX!6XnmJXz3Fk(E3s#Ec37!l6nkPArl^mYB8#84j=!l6=T_3tO^Du2P?_Aw9nU8s zf3Y8zsYV)+OKr(!NA#1BBSpobwAzR~|J2z5J7sUI3+h>v-V_-}*PHp{y&r9U@UfT@bD__Jm&dQI zNOwujBE^Fn(g+u<%{RxQZByf-6Hgm*&K@_Ce?F4W?4XchrxbS0#j;V_oi@0qmv9zn z5BDsyIe(H{Y!r~7T$yz6n5TlkhMcNaU6oD7=-b^RQ(R-XtE7RS%@?-o2|-PXITE@p zQ%QJxhs~asa^;>xNF3dmWvFFC6LW5_{Sk~x@hi`(;-^AF6h3;`U4c`j{Eh@0Hf(IY zz16QZNA#>dS5)Wgnx*!M_oWeu9u^jj_w`LA_-+&(dWlkN@9DjCH*~7Dt(P6Egno#{ zwL8x)pOqI0`WYHq^I!4c=uc->&)KYf80NO`pk18Y27lW(rtc z?~MZTBZ7Rzd#vVe`y7s>L_F*QSULo(nC*f(6o3fU*jQK}_qGWSwhXGKXKI~6yq=+L zvZvP-mhV}T0Crf3AwJ3++D7MRv7=9V@)61euX^^r2f7~CfeS}eN&*S4YKqZx*4IC$cF*7S;xr{`;PDE9f8@(rLWxv#{WRW9DZoK?1w zY?57K+Mc_?KWscZLs{O5m*gqt4gKb-X4YHuoFX|&XPVJXLsipLAyd}(N!gy$qUVc` z;AxWZ(tED0_`TCEhXRsw%HlIOKvm=kHHJTA*k2@{GqQ=>)@9@WV(zWqn*PK0Zxoc0 zRHQ*k>FyE{a3CPvpfpUnbAr;{N=>90FuJ?DLuocbdLzbwf#1D9-{W&1_i_IP_YXVx zb>kJ!>$=YKab7-L5zm1fjEUw|hbaKu6V(j4D6B9_!~bk(sdWFAliq2yV?NX8&pCRX zr~5;3d7wVe-?ZS}4Vz?Dlr=&=7Ut}$yCgQ|-KN3#>5L`WSH<>w)1=~74BjSFc<8~S zz!Pw(9|*n?*>J~fIL$WORlC}Qw}j&wk61A`bxN~g-P3%H_wbz;tRgI!$eB!ac21Sr!8*@Xa^j^5%||c` z-tcCgrZl-fbk<|_{M8r&q=y>y+1*p4({lv^jtUwsTCN-IlXtHpL=Hm+dNJ+sQznkC zewC|Djg9hFzhZchN~r2z|3i%&1W{}aB7I})!K4rMT~8<&BlQ!dk`mp%B%6Nvt5*4L zHmofb5KzaU>^ED&UFV~%(e*aKwHF$iwbB1Bq9IrSWfxJQnx|snT=IlPckr>G5rXV} z((B9pAii53w2RhJhpEo3_b4Ww6YS1UU#xBtO2c}_c#wEZ4{xb|R);p;@W7lEdv%J>Xk?%L&GZ>ejKtJ_m5ABEUo`>C zl+Di?T=}M={^6K71^beG?S(U%xD+xjr`e<*!8vWmb1^6dA?&CPgL8-ZulwcsiRcQy z1k)V)+%r#qIzKr9baN7w7x0d+rID5(XeFtpmR3P^dy$_I(7p~Wn3+Lk*@18J_{%Vg zRxSRusI&PeIm3skN7RXygrU<|P#wahUVnEv&QR07 z0n>^}WS#qmV~?doEaSXl<#aXFJETtXMOGDA+F7XFZ1|cXWNAzUw10(xe-&ak7zOU> z&>0(jtrkBX!5Ydnk6DF}?d?DOd^L#k0YrD*JBCU`7kwf*ZeMa+mFo$P7k$yuCNy<| zOlXC(+F^tUGKl*E5-ZtwSax1iwu1i0$>fCtnqc@#! z3++M02&i9il-Fww!eB{fn9#M&fnfw|)SiZJm?PWrt}AqG>`bl>&KY?LirF9u5`ck8 z_lr-pT{-Myc%Q#Ju^RvKyyvA25!v4KQ%f;)>ZX0a0+FA4y9Q)gh40t0U4qoxl>Q3$ zmay%Q3^rwwhjSdQJg#&Y3-nXJd)j%o>3aK-h&_?+XBR1t4QOys__ViUStzY98x-x@ zNXlDt#8o_Nu)*K~%dRdtc|Xl(2U7o;qi-b~q%A_X);wX2&0xM5bp~El9tJ?C-TOpr%LG20Dm2Wp16e7GLZaE{`ImYY?AKe}KY`i+H zm0Tq&2GJp2@ryy@v0at(%LV&aHx;)y7RnvG>E;t9vRYL9~m;p=O`D4pSyNS}Gjyz>{fF zcVq8CNC!`Uhydo{={P)E5H$%VF6Qqv%a4xzfr;0F+h{$h28k@ul zvG>FcrZ2gaJYvEG_gSy~Z7kE?r34iwcBH1b@tC55kzN*3i2D*m!NA6TQc)W{HmV5A zrcxEG`j6W{ULMHD@W}emNgipTTs{^Y+q9G|#59R5ED0avt!ugLerI@JZRC2)+cEY> z`E0(3GsILz-f*T8vv`*nQ(DI_OP2Ml$$ijPruy;oTlk|y#R<2ir=sVw6ByNtV{@OMJG@1o_7kUKV26j2 zE-^dlTFIC>_&Si&kStW|ZO#V7$I=_TP9@jSFmc>(;NvcPKH0;XAjW%gG=}fKqR)5!0+*DW2=2MXZV5w|bX!oO z6N`7Y*!cC_cPv!c-mDE0eGF{ywyYgYwZ_IAL~J0^oC2+~`J7Ok!`L1jepIfZt9=QiBQAp)(4z1S&5s>< z62PU+pL@)RJ_*?xzM)&G(o*$k$_KR*+;`2kj-q+bVj^CyRF$Ayze!cp7_*B>_-Zr< zWk60e1d_1Va+|s2Kd%Avzq$V}{)Yo?2HAkF@z9C)Z7bN~OTd!M@OM}3RWwFy>rTWn z(PE)ebTP|;i1-KmTQkugZpH}r4XP^(`h{;I0Xo^W^-ce9h~=sKV|4Sh4FnB+6E1=Xt zPQpiL`Nvgpq2jVq#-LRlLaXBttzL^e==&s*CwYjap&^kTtuSOsA4FV=*11>@22aFT zS*7r`i@dmlRRuJu6@nj!;2E3cI-L#*O%?OFk_Xp55yY1vE*bA@q{hB|&({me( zUb&XN+*t2;I8kKpq{zDc=Ne6QM7jgvV22V|_D~ zhPWzOUWapS?e-3%d98n%0_9G;-Jg`s4jGg#sEP4bzOv9RE6y}%SX`8(m8U;YAPo^h zYiTU;|HG;0bvgYRXgsL`W7BR(OKte-JdBs1IQ41nesRt$H}G|6^sV%Y+Nj7N%8fqp zr^$}I490j4^9g3o-pJA%oN=`6Ul44nzNJP(os2+f@Xb)!b;!(;MA`E6rHjc zU!;qhZCxfYMyujuGI&*=W3Z%Uu@So6RXd-lj%4xM;v5YmY@bdS z{lYUMkMFlQ3y1Ai!ls@dJV)VcH#YmE_#4plwI?UVOupWH@*ro^(9V?GkKm)n!2-(x zcl*hN7?oYc`ttf7hc(^zYj_=<=yqAFS21&$iGSGe023dK57&@T zMK>Vc1u1pZ#LO8t)u(K~|4P#O=ZDSvJT>vbWkShexqvhH7XHzp)7x{u-|wSGw2tYO z+Exg%ySO1pdpLgBu`Fc>eUN1T;wfGC*{|w|SJWw`6JaO6cUl-yQ-DyC%Q$}Or+0PH zaa*s%Xj;yj_~TuHpq{@TRw2-sflpmE)wAE=qC4!QZwnUWpdd7(<`^Iq`=e_sQXI{m zesZIM;_v1zT7V6pVaY#eq$w*3$yMcBs_^l9usp3r*$rc$R*hgPWLHHc0AEwPz>g zikf65G$hvHNg<-tC>mWd^|_GaZ4z$}Y_hac3KQ?{^Aq*sn`LR4hkOX3&vkU2j@`v(Q|5{v4`t?SFX32HDpR~eAU_Xx< z9yx%P^@jfL*~0ePwhogPoZc^{gjj{CdBe_afs!ikvq*tcGXu!BH^JG-@Ui z_1Q^{aL{A-Lv2jQtwTHJ#Q%O!_kK2Aa-!^wY?ZHvPgAgEW_ltWdK-X-!dNsxA>cxQ zr545_spnS40qNL(m-SAxUlJ}vFkb~CZ#aVv%G@&Q!t6Rg#jC2m&Z46_MrHu|L_QV$ zv?TuvX#PD_5CEYSE6*&^M;6rt0B#R#v!yo2NWIi~5@;hygiON0pQgwIZr8ror-j!s zuoeG0^ek<>=vdD@Z@Fw)e_kOd<-;%Nq36$#VzNAD>TiV53hg@6$H-t!Q`mKu%5nh; z$Cm$a*zq;~;Z&Ph-pD&$Xgxs#H|=|&PyaYo!t@u!Xl>7&`?QFMoOj)9l8UL^2(NQO zzxM97H=&Xj+vQ-)QFQu~KgtvXZ5qjkco%Eg4!~V$)uWUWk6ys#hAL?)8a8W*noa6r zvrBW+QX8XzSC6TMjaKh7xYO8vmNoc3he85^_a+vunuf}rFynd|)x$z2Sbx79qK$ep z-hz3v6ppb)O!aCItIyhD{iH<;tNqa_GKuxLOsl6MnuAj`a9MguBv!+xtbD4Y)1%7$lSB z&LnJ9?kBq#&n2mw?GbBn+1Sxu;b(b-71!AV`~KmWg6}voI>V5?rhr>5w7v_n+Vrgo zs8jX~j7D(>65nxyv-hd*QztNDD=Vb&0?NV;L@<-3P!?c{#t< zx%=l;*)Uf^C-JNJPr%Q}30D>knhYtM&y?R1-zO@WUpQ1iUwwc|wy4zG%sT$9EV%7@ zI#p@B`=c$F0HC60OU043uwglbkpGGv!S4x-4;TsIuvj%+FRB(~v+~H(_Sw67Hn z65NUOm&vp>E&7GmHt|vLuGsJZz$2}SrREqUwI za+L+l4%X-b6%Y}+uZM{ZM1by6L$XN`Tf0?ldc0TNNd`S*Gh7BSb#nTB#K^VGwhKX) zE!Pey*@TIW;pvM4?MK+qy@>;Q3_lFA2twf|A;hmwQ`gDvKeR~_RJ9O?KJA) znc5amhf>okAShKf<|D_}?T`M~A^e|9`2YC?4$pFa3C4fGLnj-86p>$}T_xF7beqtno%o)pP3gTjT+u<)DI9igEB9qcf zZAmb5J++TDx&U4oTZSZv{xkl?ytI|0H?iOUK`V;efcuR2dg5; z=?-M+Jf*$nTTJFTPaVv4g5$Dndg6Snza)|h~ylw zB!$|vB942b_PVnO4TFT?#XG}NOh3P_NJ6Wj1CjuP?b6lxvS!VoGVztxbSRWBWnrxV z{@1OMU(|oU2%j#-EmFjhtpXnCIN@_)@Lnr21ZN?H)i2bGk*^;|;pF4xDJ9bc6#4#8R~X$SQ# z#dbal(W+2(&=&6vujW(&>pOmzu3auP1~g;jon!M&NM3msQPHoDbOmS{n2Q3KwM;B9 zTwh^8j5?D7i7#)(ATBY%_~$&MxJGHbZv5YL=dLj!Pnl@Z$Hf+=#VZ<4sKqm3X;=R= zHTkw4*8-yJj1FBhbQ?^I!qm@$QFD1_2QfB?n){YI0pb;)Onr$bKEU^-JweEvN)kdA z-8Y#V;6$mbYt}w_*3}W%)Ox(YaDg_Yi0FS)M3msZ?H06JKGW0&WUo7&$z3|ewju2N z3O9(aSjjP^U7`1eWGN;{xx_DE>dqx!`YoAxe&YDh9~A(K>SXqQK14dVX`;~KOyYe* zd;fjilAJs8YCwXCeXd<^Su>+4%S?BGz{pmQDFj#O;pYsTkDhT)ht#H|^7Nwz$EYtO z9RU(ds&xT}A?qqFq9X6bi6pf?LTT!YCY#hX_%_CGaAE7|h$fS^Ia#3!bi0F6pIel5 z{kyTF)lx-HMxRGRC2+H*W5S~SJZ^-1IFjsjhq*p4+3&p;#Y7>RO&Q!d$}pjW%QK6% zN=xi>+OgEYLE;S(nc!Frf3hd}0np8eGxe~IWY^yP|ju>dy=dMq))^=nMTvvFsusmmPlkh|U-^vX_Y>;lqr5jKJ zG_hpNOgq&jTQQGFs7df6^U35ASsZS;1 zrK4H#vx~w2;2H!X7OX3n-{Z~oM)JF&^i1gvWLDdG|BsAJi8Gx-;6(oaFEec$gwBqi3jjRn6^N z^u~wT-C3_TQ$|vmq_?$3y(6qHG@nI89!-_fM_QZxcpe8?zuS?<0QJ+?%J^U|VC%D#S|=-iKW}Vl_9v1uip}pTA-*AM_y$(4tsLUb zk=?4G660u$eF5tTVucUfzs%t;y}^(Cz`fHHsFCD6F;4oXq56;qY4&q0%?gjv;q5X$ z%Y5t6uEAmxy{yFFn*g+vvp*@4FD|QOR2RMF9^T7o0dgx(SpN3qANlRfH;6}dY9;oPh*Tz-E)`K@dgmFku&135eA$nFZJJ^S8X+wP=~a?=f0 zm+|zdpKWs1UOC&EowvLkk zt-&#KJB6nMa;vAF)}Ay#Ma(Nn(l4YUISiLCVlFgH-y#;1H_Zilr0=UmqZ#=6yZKSc zwd%M>3<%I4Dp(?t*%+)ody&0jYFaCD(m z%j4h=Dz!&Oz5N8r`cpVU*7K=LWktHuy}jOtJ)UYs>A4qL#ylIHvBTy5{M>H&Z(K7( zPfjW!Cd$ie4I%cO2RE9L&>#UYU|Q@YWaMYaHQP$Vmq!V+3ah;MRgZKK&r6=iBVzUD zck2_bU$j_#O0#9=cscyyVlu!m4vjTCHhtx60l!b6%{a#MP?fHGiL=eX!!^9^{Z|9v z3J=y<>b=UNj39WFeiWf1V>zy)mcz9cM^aF#U@{9PpWaJgS*_bn`<%XTKY4Y%QIo65 z1Rdd(37X!Et6IzLTwJ_1ZTSO0#@sB^vLf!|tqVt9^`yM&`E=2Zd;Ykg6HkM0DCSaX z&yy|11RF&#VkgYS2G#-g04orYao!4#HZ)2d-IaJ#hjs*i7-zaDxi@6y#GIXB?n@T; z4F)ut2kti50{kPCv<}i8wmzm6V(+?a>Da{BdxZp;(USHDs_$g0%G1nOz$SEouO|de z)|=T>%awON=FbohX@TLThIWq}@$C6CjIq93GfQzlHMeVkYv`I!*Kd%mD-)USJ3*eZ zRC`eQko4E!C%)z`wO?RY>5x?hsQ+;ES#Q$X`3cz;gXztQ*>V0QSp2uS{b@<47!&kHAvjIl0j68ku4k)vz>@Cv*) zSkz8>`}{zgkdbhN4L&D{s+u}B<9#*%&7Sjd#m|Ln%fO9-rMK=NTdNY2oE?I)5?|3{ z@8!jV`h^^JtY}RajGd`I`%k!jaIsqeGq>*DAi&6ge!YK@r~2+k49^bHaQ~-JMm+<= zzQkxI0VIALZ$0*jvTOLRX8cyW%4cv~N*})z$~lh#mLGJ|)F5*EKU5ay9UxmFgoGcu zUD%`US*M965F1efK>`(umHYb)sVk90wP7DlbxK~UI3adCWz{afMo_fArCjtcfPkb^3) zCF^Xv7*cDX@#dR6)6y1Qh466WnX)LTlUC;4q`Z|J4a1Tg(z>tJH3DO_xU3WPvnAaA zZ0OyeAEDZR<;_A@fbV6b0!J1LxH4%No76JX%bj5@qYTo;b43hY0Lb z2l&Rj^6$h&382QF!XvI53qaI@!d+-HUPnpF=Nose z?Y8hS!^vXm+XyZW)0*agMd3v35KF^L#%b)CuH`RB(U8H`t2j=8x4&xU^PpyHdh7R> zFKkkRM>skNcO|90xnZRvEgCY)(D1Qf*UDBVFX1CwHLvQ1Uuho~w!!@cjTy6dTE=H= zNQoW-zKaDBGe?odUaOsJQ9VozYi2<}q!Mc^y*dvq8?$KgTi?{f zb0OyMVpUHi`&wq9Qq7}W%2j&vDaTcr2evUINFp(ZH$>X09$G)D3~nJnmq_5|tsh^A z7*HHFQz69W5&rgp9$D&72BHPTR6kICicxFp_E>yzx_(aS39^c@(BQ>m&f9M0A*O~+ zd1hpB-PW;>D@9N33Ci2sr};@wQxZVNSkv#af;e7f7p{|PfIr%tVh83*ZWwd94oLnw zm&N6-TJz1Ec`q%iNpKr!z`B^f!#>lwNfIvr>rJW*Z$ z$wV|$LAG<@o~ryH_w*Wj@D3Q{(&`80B;Sgo^Af!lDzEdfZaNpZV5mjzU?o|6%9uL9a^w{^$2=OL}^4i|xo;3!!CNj8OmAM5uC1U%a zWgTRkJz>Bf{)P4p&4aK9wD?&rHE?r769g)EtF1I&-CV*oW4MA%Mz3^LonjjGu&EL} zlUim$r|E9V#nDB~N<{m}q1=nJ;-v)Efmr_!hraIbN-n1+;-sMhqZ=shNARnc{cf*Y%k+^`7saGrLIpaIlP#bvQ&$2_>su6uMqKRr z2}e}4a2&5^8)ywNOY~dN7V%mBEkW=tswq47_U2?I4%?&6*7v)UGHE&y&7@)TYC|dnX%G12$I??dm!$UmP1y0+U$O@ecRwdke+>jHi58m>yYPcI%~kB+Nt@(r90Y^`WSP)|0Zcp$&w=4Z z#GYfx5NT&~373Q$p2I-175X$-`6MQJPxIiHkE3Ndmh*SDo^bdTgYG-}UoSaX2;v_O zCTYU+zLV@eoa$o`?ewXutm)%=Rvz0D-ng^mZ>z%}Udf4^+cSmLwDTb)%n=!B%kp4n zpez}yTF`~h=W1MLp3)(vv{UjReWY*H=M~KiqoZRRdJ7qwSG2)|?Ac7+n^(}<2wQAh zXAhN<;C5F}=UaBJbZr^|Gs-s6-p{amDtQLEprnRDNKyV4_`~z+8D4K;_CN=^j~a&0UhDY8FSy^8-;5)4+^> z5D#%o9^jpzSC0NFd~8Ly_208v%G<79;h43Z_j1neR%|ulM2#%%RLYjyGaV&`IZWQ2 z=hi8q+wTg|Nw43ilg`k$$Tdm%IeW!Qvtb#54R@Pw=R&FvP&yf`(<7e{Lrhmk8&6WO z1akAS6r3bsE0lBIkR7Sow+xEFn2Kd&h?|bly=TQ|5xyQx{2Ukh7w8xC87chrRZaD+ zHqALEV()t~4N` zZ5GFEcl|s5&PG*vf)n{=;$Rso9%nH>8AX_kmL;I6>aP zI?Er!wcN;}2%X}Rktwg`ruK_9k&*=)SNb6CW~5eh<$@RwKF_0W^Xxef(yw7OHJ0z= zVoG4&f|WLlL){&U-!`Pn&Ppvm1JH7=_~64WoU#983wERz7qDzJc{%^d%Wxl#t?4{~ z}q#Lo2-_t>@_65K@+=jAGT)v8f#qV(TUp&e==F|I3z~HFL)Iq zHb2L?mi#AjI%nw@aD%EHP`N+rNTKRlVxmUi)%YfS{t_?6PYxM(A-3xc-}P(d59bwo zyd1ycY;wF}TeT<`S4U`U2#Kyr@u&sJ{NXvlnjSharr3K^wWsUy7^{>1ssM$8&!n-u zu|*2jCjc{R-rc zb=3wPjwpL{mK5Mc;4#}h)mBzMR))tP3Vf z()(5#h;8^=0w%o9{s!A_$vATw&3A1n%tR?KIKTBATC-DO3?3wqZ^rY6W?*=>6eyff zOY1KY$FoY}N$Tl_-pvYJpQ^cYwjbiN%5im)qobi{Zu(efzE4^75l?iy?esql8UpzV z8AHOVF$%C?``)WP?U|A@n@^E=wv1SW4gvapYyv+Y=V$-4_O8Ljp2rc=2-$^NGdxy} z?UO#pN$ui}znrmDJ@aWf3c*LGR-zO28WG^=W2=QqS3rEMibT}|sgLqHkEJ%5uCb%v z^=QYfjpt6%a-kgS_&Q&5(n<4Cdn=Mv6Zc|%-&J;(M+cX8Un^-baE|YF@v4OUOmsQ} zwzXv+{%{pwb1eB0+g3ATCz}F{lo&75?;FgNDHA;%U&u08@mX)p{i9-PkinuKoK^C; zax{0&cf%$liwbL*#7dW8|4~*QO`8&!1ZvPFobx_Drd_QSOz797gRm>&C@UIzRy+3e z&D)+IK6N~l;iRkKUO~1X)q!v**>3%duvmyKBfu)?;!k zUJrM36$rMnGD56HhM!o4JA1*3#g#w3TrN_986`wkQ|)U{9jN67cps>#sf`|YxM01l zkR;uBHL6#-ac1UgDJI7EB^~Im{fp~#2>fu)mf6ZnZLtcX=UYm6{oEeHWE|~aGc=^a zDL^laNxdcM-!3xIQ+#x!0iA49zKh+px=5q1aFRl++J9A=_Uhjo(l*5C zd0!{Q(C`=RJt1v~9VsC9uocolg`BW65s;&g9p8JlC1pBu*FQYch6CU+_ zPgmAMf!lZY8gts0PJze8d$kzR`GiCS)O^05%gy{hbg^jYwFn=8s2lW zkQd5eBH!?E$KcPKQ_TX#p3hWMR*DmTa^M7Sbuuz13);q}CQ~_0!}NQ@)$rc9mm6Pj zwBy2p4<&l!{5KnVSJM?* zBx(hf`bCr5MfAr-^+0FMtIgfWy8_tIf~(4lS@Q9nRzlJ@d-6F=n8$6mf#3%Na64ov z@*^;s&B{tBH=EW6Z&LbE_f#WrcnPLiLsC`e=xFcwS!~cZw?Op?an&M$deVdKGcvYY z>nUJDg0CuxN-npcu49Y%Ny9!+Y_4U=kgYPhQm4{Uk|RyWAi%q0?(jEiiqV1jch|Ec zmzQRy2Zn-X#nLX0I9M}&u+2S72dD1h%efGNc>Z$|$4a9iXBKd)QN|xI6k7^~8ipN! z$-Q%}7@1^#^ZmI{QHl5f-4&l`m)Zy;oz}$Jt!oFsNeugM=1uNkL$2OLze{PI=Pp7$s=oBK}71j&&lPah=T zHoef(vBbY^-hGa8n8e@xX#;9>(uxx&0fDwOgsIki{}C^GCPe@2Ri0YyHiP%-#1vzV zFwJeTgU)I3U zW>M)W3{Cro<5q1zBQrr60L4vZmp5P#i3{O z@WX)HD$}#FJ0OovXSUwFOM}c82;VAHBz*Y8tdP}XfaC}~3&joph!E_9MgzK1-y!!u z7T3)6D})aSWPQGo(MEFAZGpapFm}ywK0hDce=4LLly6KEC&*eJgLDQBVNGmuDf`O{ zn|W(y@17~s1$C2?>MQ=$@Cr~EIHhoCRj!JCl$$nf!apxOc+?+|^X{?#Me!l?^M+$r z^_E$e&%7)#`$QtDPa|<|D7WnqP*0%=F2mw9cgh%T^^D1hB5e9K*;D?}gXz-MYGkU1 zYu3TO#D1`xX6%-DtZSjkdXUQENWz?!s*wz}hjMcV)PEL9az3DrjP4e7u}-FosPiS~t#?@7zlTY`HQX zGCgf*uda*1b|G8QFPja#_QPHqSQLCc5YpBX@8#yfbzN(KG`8Pv%sBtmJIgVxKrY?E z0(kaE8agblPTciRyO=@SB1NF3I?!M5%{V2YkNJvH-IMrDQeunfuEDI5c>cDNkCQ11^tR~veX-8+#G_`h+2 z|G)TaQL`y<{VWJept>;Mh(SgU7iAy)!(mB`RXE0aCsY~&jj54NprzKX8z=iIq{MCc zKO9DZJDG>M*oiEW*Z+v)-c00(yaKqKt*8e7yI;BKU=i_!#ta1} z3{pXV-2#*MZ87R~7SZIGi~!vvn#Y`9GGNxEZPJH68=MsV+_zvt-xe)95>)qy*1WgU zd@kWF^+5(VE_B*HW&le-53ul^mKcd4%)|DuQqQ1B{2EBh>Fm~(2Iz(^SD@8mSvvp- zFt;EsDC)j7#n0`D4bDOWsW04#uV`Ih0tTkWiqxif2c5`u?YlI zhzU?>Di>~8Sh#HQfYehK>3t;4Av#Jtrh>om@VXNP-|S#PO-87~LI?}cWGJ=*WAE_+clXQkk*e`rhQtjB0bRs`kmFfU2gg2^nPs zXIz6EsdBJ@S-8N9KQq3G-^oy(VR8?KoUWPsJ707l4O+aR2ZRj1)VAu_+JIi~|8}EQ z8ACA3s0vv|_f2mjySWqGQb($|8c0t18yXyu)X}`+`myFE5<2wol18sfVOblHT6AfR z?eb`rmA-djN1cRQgKI!h!^??Jp+l@B)&3_&8VjKT;Giqj`%nvkpMIW3Ncn>2nZH)V z`e&sfO1C#a;@3-_Kgg=o^^axyP27%bZf2k5^~Vd%lkLSp0Y|?6~MD;HUAo^vJ6p+OOp!3BmbPyZ>-V6}E$?6%J8k zcLatn<@&F5M?5`QHAE|n*GBB@3% zLAtkGxJ#7}P0H5n#x|njKBt_|Fcnk3@)ExK7TZ?eaI)j`59fI>EAlB^ATCG<&Zqun zPJRLfG>GFs$4X4l)r_plq*^#{k&z!Taid#`fXl^G4Q8uDU1R2uGu7&kHYjZ~d~C=j zc8=h2XgRMVdBQ)MX%tEzUaI(IAW`v9+ktLjDN-z3!()JAug;x?_e&aE!2f)+OHU?Y zX!Llx+JY^&5M(m-Z(9A_C?_~`CdUbRL4;!d}Io*Mx15~xE1u{*f!zgh1o7^UZx$sa*;3v?8#1tb{)*hi!Rb2 zGFh(wg!{H4#9<);aK+K?v}8FSq>aSYToSW0Y-?$+U)O(=n4FA{`$sw8<=wA#2Te_I zaq2%Dv!Z`E@P@q*R^)vR%$3sE@7E2nUsc6hMM`$txNA)t^X?ki)JXzZ9qjXc0!K+h zg19&9)c8->tlFvlI7)Nu%^5>pB|)G>cpNh)PgaU^o3|On2ime2^6hz6P?|KUWD|LJ2di&>VEH%v=SKFiDg{>suqOH-3IWF|!1ML8j-@^0%{ z3v*7xZXOSEFGkjpTfEK3CPpMJ9OqrqO*o@g%g9w_jnuB860@mT-+a1Siidlm0C#Nq zrOR=frA_*u+$SZ1m0wFT4ZeNC`&D>dwQMlM!(*{#5m^(&wjki8711fUP`^j9x}1cp zc~FfF5Rk%07NS8&)}bPzP?q4+}}n~ufF9qdpz4KecVPyj&<)p_D-Y*_3Hb>WG=dzQiSu= zn7N+NtTYJ<4XOm3r&`c%fura)!uh<7M1ODjH+XUJ_2)0!4u4AfxQUy=$;RZKwFMcD zr2d5Xt*lP`M%W{#gq`h}(fJX*q9P?u?`fKadnuseL_te-a!Uw`nRF!IT_um_J)<0z7wvMP~nZ@@-*Q!xJM><*Uu^%t|89OATC zI}<@S>Z4-MU{I$elhYLwabkB9c@JM4rGL%4(Lsk5ch~eL=OH$alCa?NrLXrKevqDb zsAaP;vFz<^;1>(tyh0pBoxshwtExC(N?UY!!i@!*kO+ zcM*&fHMZ!VHd7?&m6q(|c`*E9)&k>?R9}{>O0=~ttS{KjAV|PmvAT0BkZbsFgk$V0 z{0#f?`x~?Qp+=I{o-KW}@<>Trw%gtI@y<@eFN0f<83@T8<3>nih;W7i zZ8%h|NU|@BPubIaV`6zxKKUywBoj78eXv1>QFUbqU_2Ddt@-;lx#mMZir?SzB&lG2 zX(CS)k^T=S45R4&%QC2k^l#xHUJY|=-!;YWFrK{j_6*cb;I&}w6swxgHqN*&gSVU( z!;(FAd>F74=LKPuFbi8Y7DYTXf8_4gNAoh$2`mUGR0Upj%4!ojTZXu%!QXuvKV5iEcJf>i5I<8%v!tO}FYt zh&!-C;K=H_j{>uR*qO5BoS25Ha%ZIOtcy)u+n19BqwkQi{T?Q4VJiuBjTEsB9h@B) zl^qpuO|=Z@So#5TXVO9b=5BsBmOkQou65#5cc$1*wZxLeD<4{2*N+$-tdKpTM($qE==P(A&LkAK%e-s1xPXv-h$SXL?v&?1ocF~FsU#e60#ttO%mlW?QW0jgK zcpvE~U?TY55xgnlB-j(^Key9Sq`d~{+G||=BOir+0*ORkYdFG#zj1q-Ad4$tnP--P zCrbuFMascI$o+&*D<B{am%}6IkwwMpZZRwXb#YEG}uJY;q#p^-rQDVlZ63 zrQjgCg7~el7Fo)tB$l`ut=zs%&jeCw8>A(tWvIwI6;G5zoOi2Zaa*PLHS6Ik_jN<5 zH%!a385xenVEbaQc+dd$iVlO#{^yDwY~Y?AUu?vzGq$vFQ`GaE93FKW9rYk!&41h`d}P`O<62U~)M}mhN}e;Y^nW_|;IcXcJT;Iz zMPO0iQ2>0JSa(KY`tleh#8)zZ=b zLf;2}{H|vZy@At;n=b#|K%TWcWJ2Ltg-e{cc#@O0OLqo+)U@?{v>L^UAc*cZfnzI2 zg_DA}Paqbe2HDk=u3xVdZO{;bN!mwOqd&`fV0zWkP1MuN9I{L}vS*4Daej3+Af}GS zKOebxgQHt3O!OBiTz;J~BvWv{i9X5MqN)y6i!ip7+j9u{!# zm;0Zsh)@W~P62IXxf18qv4M!S?fkO`Dmxd)bQ*wmirsTNI_gIja&A#9fP++Yt~Y zL_aLg2@4EOJQWB23BkokBDMNDl+3YinH^{T!9|ul!z?;lpZH@DXLhw;cXLZ@7lS+n ze1ld({{6`6zEM$#a*vi>iuy+9(wtT=N^rc|FNjI&$ulHHx{0u^+nV^w6dLyOzwl+n7j2($UeI9J?Ih%R& zlJ#~#@A(x+b1O!bo<^k}u4y(RVZUYx%G*2s&84aLi9316nWt)UC8!ZgQ@11}K^sI* z{||d_9uM{Z_Y02@LKCvDSt4uM31i8YB(jUCBwGl{lrdw;zJ;PFDr7g6eI5HQJK4<; zvSfxZ%;Nrhf7kCizu$Gvb)R!z=RS}7zOVcEgDH>C<1_Ei>+^oU_Sf@ys2G>@)p=k> z{Ou$K+4Kk~5lCU{h-~}blNF7~wz}=s_r4T|=qtqXhvyYju z_<+j;Q{(xn;wux`$v9H*O%nz)*4N{bFEgEtHx9G;Cv!-b3BHAJ6a2a1Em5LU3FFG_ zEK7}{fN%4XbXNrTo+ed^84uyA4NflF8?IT>X>DZtX1S|J@fzvdg+C%YYIH9>XV9)P zbRE;9uu#$^tT|>P#6HIC$EOAbx?RA(FBN$>m!E%UcgW^#+o4|J+ct*ieDUH2tJJ5a zmBl8X>q?c!{t&`dEO&aaxINZk`*$UD6<_c zA9d5^$@?`ro(;*m=jQTE>5e*lq%fdt_uU_#Nz^9|m8Ao~b!G&(ja)nDaIS*$M2#Rc zeseN$Vj(7=mSlouHiXvlJ{Xe~eOKmR9mLAHMtAC~6m~b>(i6L36FKrK*C9 z^i{f{37@SCD8VR%AWpDt5Zue*$ffuE=5K$uQ@Z^*1+Mc8RwzaCTP;Zf67#*X#2!|) zIpzNR6@9B|%5@rC)%iyWcSyZksfn!(+e=Z2UWf~D!?8-Tva{B^k3tpq^Awm&;xQrj zYU5Od1U$F!+@+AswO)Dh#&vL=x!1z##|<+!!MhoV{RQ8Hv|0?KVD!8O!5x3SRWZoy z(&u+y4y#HGurqjLhK*4iDi-Z7$CkuKau|=lDZnh{^5OwW4L--gz7c7&C2}{1tyUC8N$Q;omKOO>{|`~G96>jKW+uJ_5X} zbjM@=%}n&)Q=_1h&;Q#r!T;%G_+N1@FYj##P_4E)Nt;h{hzqn@5B&RSNVR`XL;mxu z&j0MceNZGoZLuR~k|;?fXBPP25WKO*gCC*4;p<1?}hT_=Tjwq-SSeU8hdYNk&Zy0s=qVl*^AE&F#ZVl+}DeY5e)iqQ@Pt7_^Ls5b# zcIo!Y^6TlJEAt!f@WHU}H5^?1mK#AfTw_016H-@xP7lIJK)n?euJuO(0IK=xCx!cS z^=ny~xld{G+>XZ9)_h^~%PdZL_wa4ls!VU0(gyvjLXIE$lfVA}kcTZF4mSBa-PV74 zPl3;Xzgjmp2Qh%^Uh4Sy1Fn0$QMGq&1o?K1kkmOZgZGqzrCQgZt<)@&2h;M$Z{5>5 zI92N*QR0b}v`T$%r@UbOd!XRnuaVdYiQM#SCBUjmOo3k?PSt@K2|0GSr`0}K=F>jH zzkj@Al`ko|+}k&2v<_hj55GfZ&*K@*SF*?~xBpzHpgwlb;_Q3rA1-?6)sLr~5pS}6QFURpq8WKTl+m3^-4 zH<2q(4W2Tc!S{X~@(wH#uB=q}!pt5@S_`}&|>W<(l>Uhfn%eIx5uibj` z@rt9+XLwc0SznYH%j@roF6SqbO%*BI?ZmhxaI})oi(?Q|g8~M1b{*<6mf91V8_zt* z-f+#R`LZhW#nGcD7n>g8Rl1U*wWM*QFsOAFSoje`UaXv~`zyXJzq3Q~YSYD8QHfIB zWhmM0yxncrpqMKOBUhd_Z>6SMZDeXJY23V6()c3pW@c=XoPtM>176)vi>d8iq2WxY zlH?^df4|(nnxKGlF13{O6)=q<0J4lnMr2IE%=)zDmRFz$)@WHTtOW9%e$9|B_M^V# zMeNw8&0&SG6yKCO*O?5n4rC;RQ_;3T6K_8+12K=jC(C8|BxM38t`vazfj2DoMUb#Pqd)@n}J(I$Ue5PjgWkhvoSMVs`==MPm^5{XY6w!);tViABom^ zF_!YmDGDzu&TNrjY*EaS`#vM+GjQ0=Lx65{kojb&zK^^8v-`52`M`BxRL z`f0^Q5Ik8juuBnbr3b;nb=72#KT1!CKIIOQDhu8Y>mFO@ZB~_vY6reMloL@>eT8>% ze{dC6IW#GUh{jgn$b83r0!Cnt$4@Nf-yqQb_p^$?>;G>oc^Rrz>!J@cv`2I(u2`z# z@@y+?f%`Z3A-~QsXm9e8{*!xG0wgRG?7at&P{pINf|U-;rO7-=B`Yk@?T9X43xdAI z{a!pAhDDlO^S)*=ZS=dTBlRaFf)^nOWji;thwVL{|oPJI^CLObh7z_O zJO;q0*lh*<jhyu9Il#t8ud)Qazd$WRMs@ueMVC-Os(hS^4CqQc*Nh;ZI{Kxv&{_#`|7O`6O zKqKx4Nn|jTVlViE9l0i(KWBCfdYv`-kLCYci5X5lM*g8S+Xn95r$;=6BGZl$3xO{z zpWgq+@^k+2PymeL(Qhy%q>#LaJaj|e1yY3{>aqhM6zd?T|FZZ$UjOBZoG$4LI<^1; zVuv)LjzKcu%`@mfoVS1HAA@>NUdi8g)c=03L=fC1-XXg5&HC4P|4&% z`;sER^;flzvV#n+)H})r`PBRcVI@HLn~D3E!K=TahW;=3Kb6zUgjDX7#E}5Dw)nl| z%Bkz%E1HJTb{hx%=c>{prOYBln{aUs!n5>M0kYf97{zUScWtt)F++nCc9xla&e(@$ zVZ+(_DAD8V$fnsvIY7xI1k-AX=KDOCJ0d1^e9bzwv?R9=zb1-q9N;8eLb&qZ*7H%^ zsyZ)2LaT42ntjuI9gP4oe(F^q^9U}1@z>s>3FttTRowXX@^QjL;76T0tce?8S zH5Nb~Py9b@yX!yyn#^d*@yz`-d{}MsE%)4K-QOt&XL6*qxkZ^$=t?bxmdEawP06Rk z`lWu;mz0`@V6vJnBZq)~_={Ue!j>Rz5XuEvqdEqqK1GwBHlp{O6#&I#pG}}acjJ*0 z@N!BRpw$+Et~7;?H58CpihNE0w30TMlvxSXNh^?TnsA_?`GJmNIt@(6Ohd-w2PvT* z$VDJ}5;_Li?F=5B0aR9K7ip05Q)3Q=dt-@?41(-57UM`~w24Pd$`!)sUnx*$F-ZA(!U> z{m7p~w(SS(1V=uFN^S?dmn{t}st8HoL-ymJ|BoI!>7Sn(a`Mps>)-Tbiidr`N^-%R zyLvprYWx+7rsw6%gM`>Gy(wquMTC-n-LdV@XOFyY8*3339>K=o# z*^WU-Q2_d^cyocEQ;MvU05J4~wc(GO6AOL7F8 z3{Daun|LS=xEcAcOj0bW9EmwyNm(~IVIE`P@pXTq7HW< zM`r--7TQM|j7MXf+@QNyK|WNX6^dbRx+w9o?~6cJVazK(xc6xkL~ zs2tf8D#5=r0j?VY=7JZZk%ZWN=)aoVzc~RR8lDsm!FI?IeKrvfXQf1e|9l%aqg-OFyrE7FxD3kBsHk;J77*U-Y)Tc zh`g~X2z;GL{m)V^(|G6@)t<+vuR0)8n}U2h6W$NNTmOtIlu$uSdG7p&ydf0ipWGqM z6MyK$TRiclf%xYn(m093PNLM482%)~agyQqKh2Mw#PBCE{7DS||A!dJ z6MyH#-#PJjPW+vdYG5qnKY3RE#G;>A^ph@xlOBze82uzhKZ(&#V)T>Vu#?WYlfJ@} zuFI32)RT_gzc)0WbnO0D{oW_O@`17@=R{egqy zAu_2s=M*sReDeE50w)qUk--0D2~ZyolS%&`P{QB-Rus3vQLoJC%Ux?;)6VxS)XcB; zyXN}%-;jF#bxQZ-^IwpQPUP_aZ8;#1NB+%Dr`~mKM}py2kgd6UO9hmx?rG|lGTsRo z@$o7Tq2Sk4<&Kc2or(U$&6m6@v**7sLqBE!lM^IOjHaxkPltM5Swg-}lKjg4lj=}~ z<~h*_OV=t>1Pg&kvDgEKukQr|cp#=6d*WMazNLuqg=$JFQ3c@q+()nsmpPXkeR$;D zTcMI(PV3dP`UT8dYjLUwKXwj01EwW85bW>*?xjtFgf=6*=)_{h*I$!Ejz+O(Rpi?@ zwFmt!dsd89YN{A~$+jB*js5zq(qd;&BksKUhxPB9x&;EclS|O{ZNkNaL@i#DAuhjN z?#H${k?F{8H{a2M$jsg@opPbu>erq7|bzYx_pS;6sJJ;P~ke}9_ zxA9S$<%wfdJ`{#h92;gO4@L0Y4;^a~c9@54K`r0H)zi;uj_CGjruO+oOv9^oC9Z0> zSQyV`kmK=59orEQu(%0SfY4`%+uQ+J%ji-V-MLhBSq-g4X3vl|0O;h$i0^3drq&sl z^M2^CaK=nm##|ce;eo~V9ftzW(dMYlL&nzDw=LW)s~LyX<^*#h7l95d9_x+t@JX-{ zWP_W)Iv$wHPvDGcC+u9Fl?k~IFYBI)?hAQy;Z;0{ig<|GW+nJ{Ok3J-M<_X%dYaY@ zrU$ra1YXuUdv;1#I-jnL8NrI^LGPpvMx*!H@jjH-J+{Xn+zmmlrck080fkw85ummW z=Pt2Z=sDGU=Gmw30m{`f+|-XES(_8J4(UF0c5P-JgDzCVO}La4j$nXW;^|uFOj=Zq-Va8t zB#PI@u8SHowx=vhT@s&UiT_hP851Fb@#O} zB?(FTcM17-z#x$bSw{r$wci$^HK2={)@Bjq9>OdA9)UN?m0nnX|BUro!$tHvGzYKW zdVh#T^ZPUjAex;R3D2-7Fd+)-wUDTFp>9qovM}Gysm-0aIq86*A}Z3w-QWT8N;UP) z>rOZ7BlC*-N_1uHe&>Ul>KgqNHRro)Yi*LfZ9-vc05zEk-luar96>eCc@5u{SfNR6 z?v=keIh{~h(X4VI-cS9@?R5jqjh!3{gU?b@wEAQVI8Na7(C>_u?KXG5qw{`WE-{aj zIIo%qJ}!6|;XwDZ@5*m~oa-hbyZjQ~HU3Zbj5ojcLc7ZCQ1+$F4-=EL*;_JxUsiln zYb_XaW1n@?70~U`<-x*rCZ9w%=zQ|9o{mgc=N752ek)e}gx}6Z_f|pIE#EtQf_#)S zj_y2B(K2u(=J3yecpqoF`Fa(Kxi&K}O)U9=WW>I5bG!Gn?>ioX$i6Uv+ZhL-ASvVN zLbPP<;CeWNT4(h*6V)5{r}ve5yBB*`)R?Xn7wSLcHOA;891(2`-zrgC<|tUE?qq{; zz(+TAChK$?7JCVI*@ww_Mvn!rly&TC-7hiNq>8*40?BV=a~LKW=h|&h1e&5OMwRH( zJcBZyao9eL>9iUid+h=MJ6B5WW@3*)TRMxVAH3Q)a7)Y+X1xyOA-7-rmA!-K<$l^{ zi#`h}V^o%Z4BAH-kbzn7M*v*}WxOsa`p*`)&BmS@<|pOuSmtW{;?vLETby^fI=}mG zXbWa$)4B{Tzb1J225TM4`bCwj}88Xlkuo3FcE_ z-F`eV@}jod^;SJ5P5G9d|L$e+JP=aHe!CqVCBQK$qalYyvb(Drlp{nuJ*E2AWPZD{ z)+Q?O$*b2#1xVJ=Nz?Pi?m}Y(RB>iWa0Wbea~1<3J zi)wIOB2(*+pNsHx7`QpjFqE9!D4^Y!Jo=dZP#<9B0syUEOj5XJmM#-~8VCim3&$Xr zL%8w&;O38pY&i{|xrDb%PZc`g3PjauM7PbpfD~?Z>+Vrg5_Y3F-9*!fT+rch9TT72 zlD<3P*g;x$5lVc)rY$%oQ&s(G-klN4XsSCOEQ3`joxYyneUIAr<#h(#H>*#56Ze|) zy9C3&t)(NxqPi%8sYau7!nF;i*2ATk6z%JxDGbWRsuotpJuO|=35NSO36cmF_$fky zN+jaa6x;~sUHjS5I>E<&h4)K%$l6nL~hent5k-ERZI`=KjW&vhMmQh`7U^{Js@RAdfKnq%dR)PZaY8=apaIqz5QX)08Q1FQ(APpKTvJal{xJQd4!K zl*u^=ae~cK{#I$>S~fq)%>M&i373`^n446dTh4WNw9kNY<8jxeUD^DkmDGE!UwT?4 zK7iN4TaauNz_p3BYZ8PNo8X0KLR=`{8zR>_c|#kn_1ma!$F2u>L5+u={Mt)+y3efM zIvzOS>o45K$VbDp4>TD0(hR`U;-_^uNBG$afWh0J?y>S82gm|PFcTbV+;wR z=Ej8d(6-ysBJR^WC+{q+p$V6KI*3X9h#_@wi>8TVFlY0$u=~z@cd=-4j8AKy*lhAO z?%bWyIc3=ik(HUg3=p{q7uA{cLhxL`=@B^NQZQS)N>5E@5VOJj$|E_kDuc}q!vd}Y zltE2R*a(AcS+4E92$TBSGX+KRb}@?}HPOkP1)`2UqDSj=ISw6%d?3j0F)1KE>OFR+ z1V67QS@hBEH#W_|jos1l>9=D4ntR&J&|8#wNCi5cA_P3!6^x_u+O(jc+t7ykv;!Mg zO*Z41wh-jl;k@U0HJTTYHf_Q3Bk?X|n3LF~O)Eu!5cY4<%g!NW`InwhgCOymZ$L z-Jj66_lG+7>qD6Zz;iUp#!RY6f`q_pyqpd>S?lZ?lFK!5yW{9-o=oMRVUtzq&s;7> zx$<{89!6jAGcaL$m*qzN$8eFe6Z~QU7!-E(d68k5NOI4CJv~}&?t zFp@#4aN0;jwG%cF9ON0`8SKft^{U6$=W(2T_QSWpe)1$c0)@HMt6YN!s^XBvaY^=N~6&{A#83+b6oH(xYOaukQ^$R^`NA6`SaF+Rs(3pP>zF;|EBn!V5H5N@+G|WrKJu>R z`EHr!)!4=lBr}31K)tluCcuUFPA=ZyxNnfv`<*h6UAe$U^RmIbUj4;=_HU}9)B^}z zO6$Uxi5M)B8LmnG47bAe9acI7y+VKa5^=jbR*Eq)Q6l#l^;PPhs|;TxRQcuML<$4` z7$p7+uR>|m&;bm(kp%Vrwml&mH`?}`jE>OI$IgkEOl)hQFSf2z<{1m&D^mj9j%`)@ zn8#`wX94#Cu%tYKF*0!8QdqI^=pJ#Db>`P3=aBw0g&Q;XS(<-~-rs7)++1o#wHgUMk@lQk2BRi+AO$_)x8IH*j3gc(= zzQfYJEPq6K{mma8e)2ZkG{HQN$EF@RrLyJ-jcdfFR+^c^ivIkNi9jr^slu zvsomh+$LN@WCf*iTj(`F)3FwNKbbq{#br8xM$lT~Lrh%MT;i!mr^D z_FK_uUXV&J__l)TRC)h*0I?hl_1NpB0GyE+CURu@zMp66(v}Y8GSZb?a!tlZKYPX( zy=hAnPp=lN+tujzIXu|JHf0cyL`#x3V6ql7DKgRP-?GoS)m1q0d~K9hnHKrlbnO+( z#Xuh;(%RBp9eswRI^Z=GLE-^QF~=CO#gNl;DyXLagb}!9>j&I>gK0`O`xV> zxdvzD9Ix{kP9MeJ%l{ZSzv)U;ysdxP5nmJqkyBmQht~tbiNMj3*-8OllKS3XSPFng z8!kcpCD+KUw4{{av-M(cRW0EY>|Vy5`>6D^rI`B!tvRQ zD@jXcEF06JP9hEraskCxr3%Jug^kvoY_se{BTf~FcK_IZXY;wb3mWG6VV00j;h>9} zF(`k1Hv=DFVqKl+YSG7QBr3@tA}Z^ry=b!v*{sINVGhEO_mM1cry;NWq;|v=oKs0g ze9~17ZxNG(dJ%|Z@bCy4g1m?RLA*&WTEr_5^4kwGlp`>OUSGW)7LAztK1J=Wm=1p) zkE7WXT|Nt1ram%4S! zPs5=Eb3Dh2ryvN{XNcz?sF4v>4QLCVDQbb zz~(IOFQXMiFql`(Th6uWn|>_TZDQ`nVfkWV_U5v9?^)eYXl6EVgDDP-ne6akB<$`h zPAKuh{(LF#+fP=1BPW#l^MPjwe~Kk$Y&n1~at3hQ=wh10NUAt1a@p%8Ilw(;-oi2a z(31afeCyAv3g@}L-ml`@tbw4 zh-#4+lj`G%z&7iULh&Jbp*XX+IW`kK330i$DMRL}{>0`kSKLjs&{m{=n(k8rF&1@n z@0qR6VJ9$A_%Hec3NQR2IqMNcf-q=3wn@^*D%Yo)Wqh}~_SL~Xi~e1wE?x%fB zg&;3A4-f#;6$bg1-<}kGb6K)1`H`^(h_(Z<2s6X_v>^peeZ3Z;?7y@`GKQ^IOqD5& zoiVl;lW51-Fda}rATp80@@ z7@6sCZVuFv84<=`c@ppLgY>TMSk@7io9a8PeQ+UUCMB1p;}I;gqv<@(`b_|uYeU(d z<6X)5*@am{Px|`|)cvnt`)O;qRwCG8t3V>f>#t;?_}`L+dRV2jk3w-P+3~Q)-HY+8>-Krk>vtc!-G%!|3a|Ko0bKy+RYz7HA ziC)Y?vrw2xLAdJng()qv*B>v{HknzT*ZWi94D>ZGSOm8(WbmEh6=j_z&%zIZK#zw| z0+0hkJRGksX2&W0dLjrSSeR;Eof5GVyB5^FR8v!bw6cRCyEsJziW@wH;qJ=wASAlV@Iv}>qsU1*J~PyqY1&C1~B z9FYfx+;>(%R$MH^2J3YQh_4WsuLi_Q3_mZgKUYNOcT;MhK>5n7$;96r9MrDR!4^y7) z>2Tce3>4OVL>W%Je7;viJYj0^s^^NyU+DW^qB;jSTk8T_P|@-vDSYn53s35BHK|*y z>|0XZX&Vx$yf-WgAhR^U9y9(An!D#on_uz7n}l6lzBXPUulN@$p>Abp)+56^Q$4*x z%k$HAwS14Xc8J#Das~;&Ee%V7(~Q7naFA9AyT)y#Y$Kj^!b_@ca=L0Y@oM+Is@jLI zn(e~s>rOkT_-7nl-De_aH^^YnQ@x|WxjxJrZtz_iI1}_NfWfisLZq)~G1I5nmo*8k zUvC*~I*X`%0N0Wn0VPEjA{Y5~DVujKOR!PQFd11eGh?#L6lP&)4GMp^#I;iQ0G}QH z6W4<;dA=1+cy?dmy^-U!M=zu%^nC04eq62EwAtPB!J+rh;Vw2}-$2~V;sc><8YW0_ zPIc49HKInpZJ*Q~18?D)*KZQPUFK4=FyqO_-S|qE)S!bm*?&4&wmPQ4jMY5r`SzZT zo8RyqJ)ac=_0ADPm4U>YpK5PopRJ1R{vH+gjE<&)kdS3EdCJp{;Npl6`O%ELZDU%&xmbh$+|kqfdv;z61r5I-O~ znmBP)UHLtl28TEL|vS4DZG+wlH2i*uqp3>yNmhdh8DDglWP^uC-sfPyaO zS(XjN`_09U&l+R`@N8{E>ud84_FBar+JR)xG&9!C?(DrrO59^QYj_rj%~~6wtx%plAV6n;8`q`CzC zm@nP$(w$5z^h8M!!!(v_7)l7E|HrOkwnkVd^YSHdw(VBz1#W zDI}6~8&B7EWKhieC}m>XvNF9uM>LFlH^F#>Sq)4tk2DwKAsRFYXx0S;$1J2Q42k&NuTbCT$IhtyRU?FQ>m|F77a|XWc zln%(*3iQ+w?*Yy<1(%?`COJQVWG?p2m{9Ltt|MPMG$DS6UyOWz+2r6WYaOD7}5g~l3 z-pl4z^|v%iXDjc6EX#!aSKk*!??ubDCx_H(Om1(T)ZPEi67>Iet(_fZu|=KZYVgaz#Uxw{0fG zMoak*-AcT4zX@mc>Itc?-?~L-o=X`*2#=7ga$!)w{S1}W05EPDXtHi*KlKrAs!Dm- z`C}RO`r0Eu(D$3qL*xxd_gdh31RqRN%#nUs9QM`g58vuz_cY9BrNT}1l(I3u|omo-DWmq`aMVbL5-QRfwfEA1$##nm+p_5l(v@=)LR}>fR`Z% zOKs0)UdH8bl0U$(wtK-ogAHRE#-%q59xJ)u><>ts4j#{jy8B?ox&|d3MW)vCHxkyYOJq&^+3l*^u;53tz9|Y zvwmlW=%y=dI-zGY9138!m+k;>@IJmGO9Qu(pT~c__xH+AIoZ;OG5Z?wH?v{vw6&dD zcc@a|A?aq|dgRYa3I3e}OER}xwV1v%r1Tk94}T_VR1333X1Q|Dg6U}ZQnaTzrKz^+ zCOep)&+Lcc0C{Li&QF`FZ(q;07!!N)H0v@`%(;Z{7kBE)wt8m;uC$(&_qlNId*QNNmGmC&v?9MWq14&vf!kUJy;d{s7XsrhIyVNX$B zC)c5($yQa~=<(N1u;ssP?_V*vtrVV1v%HPGD{wcI1{dWHYxWEk+j}ThVW^HgLnoPr0ahx%E zI7)WD+r?G$Cw(t=^IRSp#rzocg0*rIYo*mreZbWbW-!pL3ppf}%ub%{I7JBCFaNpj zjFKjJ)kxdQx$Yl_mTz#N=EG>Am&~BxVk^1rl8Sn%1QYZ3GNm%k?k;HtXOcU^{T0}n zK}@OE8_N)m1xL!TlUg<=FUxoZix9=uv6(-=#cUdU zlQ4&ZWwmLKh*pk}w1(-oAEaMgzWM#!p#xQtu>23{n^)f1471}`H-kJqgK}W2omG!> zIgSpZ%bNO=3oIQ>9@9vMJZ2C47NSlCy*ES$#lVGxp04slmqlUaA+No9|#63M#Ct|4AXA2n5uQgxTX3jBd+`O z8h%FXe{5nKm@z)Q@8u}0nS#~KJUds3R`N!%>6gh}lJOt$scm6u8vam4pyES9b0~2E zEVc_&U+U0f4J50Bh0}%OX|s3))Dvo>CE3s z9d$(w*Moiwp)}Y!T;T3sE5{MCaCU;JcV@+mXJNpldp>WkyMLvozxzC1Bs{CO)NICL z$)~xm@d_vWsFa*gB2xRZR}_SxccY%9Tq4;GYjOOr<*m2db9A=L+&uEmbgM6uU0J#i zMC;h*FYuO+RU0K=PX5qf{s~sDk6Y+Kfo5Uq3+<_KURr;kdPugL?KFu#mkiTvvn`(e z3b73e=n8>fDS)Nw%w?A9H}pJ0?iZ3dcjS+oA~6Bk zp0V+2^iTa2;;;SSgNG1F;lvW8^eIE;45*E{v==!vP~PL+|-(1UkOshc3tEIOr~sLn=+Sv>)byNzi| z@?x*dRe#-CG=cln#Um3|l8iYIqU$9WuM2X*^xEdH;`Sy1COHqEC-qxlZr4)u8$_aR z*I2vXE$8lz*wDIbHx_+Ud_z26LF3U}T3m#BAW zz(}Q;H-9!fOC?>AUwcUgsXh;xKzxf5(+o6Q9iSj^0z%#%4#8Q+p%k-86O7CHy)3-4|aOZNl0yk@&zK_2%5bEULi%Ozg1nh}i7-|W~XG2(~p{BkMHxw`HGoIR5`6vnX`EJQpAvX{o9}+z0 zB#u{El17}K__II@^U7M5&>y}@hhltgMS&_OM{(Y{@H@SuMOdP zRX)EB-MBB_3g*r+P~9E?JqNd;5_qy>;4aT_GhOqX=`Hh@swz3Z&S>@rKAxT7n0C69 zb0#-@Tw4>YMstR8k?~vy)uh^AZ4}MYNsZ8Irl8~smHy97Q4;~GijKlWyE-f{ zy;u$M;J(G`pYv2Zji@b=r3DtxtP7(JRJf+4_B1WQF>l4yhNRn(>?uZEWxB>32AIfZ zaZsXXBOwrvo$Q#G$;2$gk~D)Bx7qc<)6U--B{zKO*N1v*z={5$L#zyK&Oei;0)aF} z%U~3gX|)26itfm&E%HzSVo<0{-K)q~MREgtH}WMza9t>k$vlGItwzEYQFjf9LgA3l zvp)z|TP8R;CJw62E5vWr)Y^$?S$yfyKWkPs|JYUr^qhZt>q6cS%FtyveX6lo-j9TW z;16DQ52}c|`Kdp5-xaMQZ;v#kp?V40s#Mi#sy|~( zfr4mQD&;!`VkQXl-%rIkC2H8X=FQl*mDpDn)LNyq>Is|rdQ-7F^hQ0%@g8T~g_m}1 zIW|^S<|0f>jG3N^dq|k4_}!!{EL@{L#3N_e77gPnHUydBGUV49SA&;Ib#h{4mPa0y zd9~+-+~a#ZQ!=vpQlXrW?Ik_rK2S<{(x)l1&->-eSKO=sIKk&hzPjeO)-OMdFSw@A z5>;9F%la8~BGoiE_D)IhO2~@VN6xFRW^Q*EP@{?cr_GkE3#fPd_C-en=3FTZ8z0sv zoY|Xo1N(2Ws?sjs7i8BRh(9BCXFOl}LsNiO&FW4$I!aHI93BylFUGh`G-40u=NT2# z=m+L9{5HpvI*3f+3N*5QT0z5krl)?Cx5sxCN-|KaG=WnJ-Wpv@;w5bC zJK)nJwRmTY{fn-rjVk4R(Fz!7Y;5-%IDhZiLy6oiZn=aI)*nC_zC3@6b8+w#iFE=& z>jtG4$BTFH`w(@b}lNC-Ar$-Ys|_~q;m7>pVHS^JJES(%F@!l4vLLE{xl## zH7mi?{Dfvtbt?lz;sbVt6%-atxV@B?w0JenO#>>E_*UzcY~-3fR;MYd=+G07HVGmURNtyWflqxqaaxDkZl1eAw7&^e_&ib9ByZmK)Z$1KV~4Qi^s#jZYf zD2xxYb&_v^57X!j^`2(%qx!zc57fk-05!2xY!vyqJM>5%iw;=(K%X1NZ6x1WW>|1K zP4!-!mZitrjS#w>LZCORIs3ogHo0*y9SfqkHi+T%+OuhoK|bJEG}|9R##va^jLrnf z>|(yBXR`EDR%nC8O?u`~-a6Ln$?>byE4&MA@BknyLbznO#e&fUH&eK4N8p;d25BSp z*{ajiJ(q)u>toK-n59#(X}kup4nRlv^=K;_i@rHrds%$tK8tSHK#w^8pue_Y=mO3{ z5onC|N{Uzukt|F71fl~`6+2FDm~90EEN=C9LrMn`px8X`xdyaf;B6=j6$oL}=3Bgh z*$ebTt;=z?&lJb&Zw(o%+s}MS2z8gM|O?{xWcTQl$Gs3*tE0ee6vz&_}z$>)UMca1aHY8eaeC#=`5 zAzo1tLIKhLB)H-gS{LY=E(2{fX*Tv|kxFFlQM1^e_K->wCbsffgM{#}R;@9hw&jqU zZAqy>Au(w>ZE63G8K38LQuwxpF{bgZ`&TDNyZ4d3D{>W?E>G_|x3~OkzRz%JXO+h0 z&ujdt&TUJaV2t9NmWbx~eC9JIbJYAZ(TH)&pcHLXzaT+uwRs^no8Fo5p0mS@(GeD` z^f8CgszLvhFOahs0?E8@y6*~^2P|yTW_b>5cfXEY_4^p)A!DYWdL7r98n8c+9O|Yw zGZW`2Wwdc`c~f_n`fuR}fpx}AQ4>y{x8h~4roNF79{W5JTe3_`F2&Hu3k z1w0Dt8mR%sIqcuYQF6TepJdOl-T~-^Q+N%y0u79=nLiSLdI)8%zI&^mei%4}35dZg zB_Ng3F~GK;iL;5s;te7^`Lp*+ksfVZjw&ikO|Ava`)G zbn)S@yCQwf?Hud#6cFEd5r_ip9}Eg{8iv)BYdr?hp zcYbjwNlJdnUq$nMh2|8v5!s@Te&}-^R*yl2P(M4at(KgZqs zCO*k&)7tA;A;W#!e!za6BwRDnu4kHE$5((8Eyk=7n3?GsQ z<&4IcucsOp_OR;nNPlaSD$%*iC~)7vw#PhD^5#uHq_0S2Rjj+La15*6*T&Y6fyB(D z+lfe2nYK#0QsX1!RHv{lnyzv(?n?6!&cjoUjnzPfvyrlUz-h2WO}h#k*JQ@ zEAJ;;x{G)8G#e80_{zqZ=%~Uk2!%sOG;L}sD(mU1Z{@Zsp9S>y1V8fqDx* z5K8z-bRo$T(n|=|tF6k<^Zs&f%Vo-XhB%b&C zo0)sh%$eJtId{%H{PDsAd7qS>z1LoQtxqZBAg3U%?nCRd2=AgX4MYSU2H@d;9*@O= z{HPt?`dmx z2NsCe+)w8r6Jao`o@c2I>Mtd5;n;V{&!}Cdxcoc5$YY`t>g1D+!taeg!&JqQEJVPU z!-Qu?Yq6P$Y&INnmz1bwPhSfszMQ_fsOP~w%-EsbNBdqiK~oC@zfZ1APP%`S{51J1Xf9s z-n~OoCzn4d;`gaE0DQPhV8&7Hl3^tAB9PAg%cO7h#-|^fi}g!&{q1ESPHC=?U9tTe zZJ{q~W9A-=rVV};6%Pr%g8y=+1+;28f|-W$@#)KlTWIbQ?E=N99B0-9`hYL$_ZI=Q z?dRdWcd;i5bP7+2*Nki(o!m9HLh0MA-s%;?VR{RmySE{UFfh~}8BY=gmcesZ2wjVY zqe?^Up6n&rU6%Vz+*ji`zCOt8=1CUcoRhv*e`751`$mi7ku|ZoO(+_v~}x9al^u_UgcO9nb9IvQby6s9hc@Oy4&_i8USLIYa<-dedaz|@Kg~9hfsDQ^ppgBty;Xk`j01oXP zz>bSvsbtW@C=D~Z)B~9OY3`Qq6_IgT&xRTe(}hHn9*i*TUA$?H#vQ8xz}TIqXD&Kt z`lk37Zpe1?-(QanbQr}r-0ob}EejALi|{JT%#`Y6xH-*0wQqEVEYz5jDh$|w#hjq^ z{<|t?Mg-SU zPT;F?#O<)2fhjYG`i7*xL~f;~UZ_DC^Lqu?xrE#x$dF;DUV`~&)c{;iw@XmcI0-I7 zJb-qT1M40IZIJI8p0&mBp^D)(n5&UbSM(`r#TgphR8m{gM3y?W3qlWs)N51T{RPY6 zjzS6|ot9!yHA}m!?9snGX3)z3U^tq@Z)FwLY@y^dcFR^2lbcBhs&7y2~Ok(jz?&`8V*Y0-C7=|b*a)TgU7c+XV}e4 zHAtz~v?yDubOzEa)i>TdaGn@8+|LHu;dUNLFVH>c8H;XY^=7C~><;27X84kej@yf? zk5~uw1>b3-d<#=sb{j=NgqdH&veKz3W52-ecTDXYavR?Z{BQ1gPMnQU9qOJ55Q4D! zJ3~PCODnu35S#Y13(e?J$*d&oYIkPmb?upjT#}Deca{)om-AIU%^L|7ebOalpl3<& zw6)wgKoTXUms$O2dL!wyT&Us8ORO2ZINoPBd%~+Z4jWOct&`+*J~903RVdhU$<%6Fv1tBg!9ikoUMX*3<|U5HZc(?n7t9`T zc>as29$A2>0s8*?F-Sq{>) zb~Cy&ayK*Gmy?ixHC+8oU-w1x1>{A|ZOHKB;29_xLZ`qS9|&Y8JC&HhVaO6?m;$$U z_2;@XHGbYy4R?{3EN?{mv^Z^(g33rVkBjG=q1C94K_!n%bMVW&X2R)rYfTMkFBRA3 znY~qWr_f@SW;UenN(;D(cC!6c9k;%3(~u}3%n@(_v_oa{y&p*TPACCn(F>ep?>=56 zt0M+MPc^5kg>e(pn{cbplwY{(*3wvTN1JS*%9n)G`(kR+i=+jT?%_}*T?@H0_7FHv z?W`95RF%b^2F1+l7R{*3j>fM)_LC6LwtGb{bb))B)OjU8X%U0c~uk(kqunxZ!z5E;06 z-oB=pIwQQ35n6b!LOA!`c(1Pv^`Uxh(g3Z*i(MhR;LBwA<1;#h0MUKtizjIF-lb5 zSES$+s00CtV(*a)E;&A!g76T#*{vYUOVy5M92=k&#a{XBP4#MQ;ZC|BB?U)@WdXS1 z^CRmTU*K4mM9f>R%l5%(#nEaebjr&q^l<{-xr#(9pl*?!jK7L%^PoZ4;v0QeU`&gP zvz7#vocmMd#-po-nJe6=n+Bdj?)U9_q%MEx38>t%21H-}59J%qAN!v=HkDz)90)b+ z(oXjWI11y~Z7Y{(f3?K({T4egr{@;TEh%a8yFU7V98FDn7v^>AZ1!y?wd2E@o@$U9 zSqMMV$l+~>I9zN6=rl+p3~F)U@Yd-fgw)7#it;q$33QI#-~#&~Lo?y?HwH~b#(-yZ z!(lL-LGTEQQ`zYcNNUz{<_9*fbEnS$a6cA(iSmY7kKTq*BLYVdH!-NBCdsb-*i(+F z7FuKxuhh12=ttX6O*))oW=v}QjJGN14!%<)@ReZ-gTLiT<0fliw8~VhJBsGSin_k? z;9*x+Pi}H%vHMIj{S=9%&u=RyWSlXEw3X3?PZ$UYISD4-0TRNCAo7uiR2vnQ@xdH; z#6~~GArjAG#`=v=~zF0o`AWFWX zA8IW!T8o|~U+Il%(GbdP6`8KHnMzSMfXMZX+B?lF7?dj4(;qUWXi8fLe%ON11alIe z;rfaBG8`pih(}&cNy(-2<~rU~L+LV{8h*xfZMKlk^biwhp{k*^M{ZTXtGQ-~oQm|8 zuF#{e&BC|0Jv-@=Zt-w|PP{Q}S=i(rDyu0~PCSIfbapp#XV*8g*5VOd%UavqC)+BLE{BA+0Ogs&e^74l`el+Xt}m$PXSVE zzlc_I&NFmSyQ#K8Rx5X_w>U_Ye0Bf($w%K^kzN#2<+9wgswj?kqSb=8TQK3G;c!r* zpd`k+@k|{qVg!U?kFG`LkyYkx_J_Sf)DPtQ-J~8cLX2l>9TQU16sK!}#DvWmS=SKe zMuKe38fUrr2d+-kCXM=V?253^{!-)A6qbtsfAe~!?#c z_!UVE@dO9%fX1G(F2%0e;oUx%)taz-TF4B!1#?!}*-?gN%!DSOazMMX>x;rTtLR-Q zPO|+4duN{`8XFz)n{u*F?}4^m)XaMHMFp0xMmD`EYMfW96%1VW8t$gv(=Nz}q88+j z(_Ts9m=b*`Alhcc0SY9$4Ja`<;I8Ik=Z>h-*DTv{8 zYh$=|wypj4HJvxVkBT~7l1id#i0U1@FkpHvvU>@a-jh6F{?AKz;9r)oRy2C0h*+ci z=sk3pJ)%X$&8DP!(5GL4$lDzhr~!J4&(X^1;^Uj}5ewX^?U$t%3bC?Cxf=b|pm|lK zkPaKj=eo)pcog^ruY&Ehnn2vbNy4hef9Z-fchyC%B#ERxLOIhEW<9&`;t=KwObj~G z@C#)qg(_t?4%RihSLk1)Tg*K*Hc=1;QN6uABFK7Ci&E1p`g*`d>sRbzqsIyWfeaR^ zjk0i(wJ>!rjPs?H68gbwG^q<}*4GFd4m0kS=#&nccb+KtHtGq5yG= zsHY--RWui`Je3#wZ zbdfvftkJ^##Ikdu=DBp_@C2#>s8c7aDlBSY_gAAsRHa-;?k#du*SCy&e6&eiXLFNk z$Pg^O;P(k(cLA=<65{!Lm}Tn^sCK$5nd(O?xmjfkCx-j?l&qf4a+;X4`xD6LhUBkC z*Jp&}A-=IZ48gPZ#s=k5U<$qnATVkJfl=V^z{m%jl%Yr`@T{QvM2j=D-IQ>f#4;-0 zpKLXm$!Y)DGqh=J(ngEFO2PjL)uWN8#N_4Up5&Q+hGc9B#<(b!N6NxNJa>$fld*5_ zK!-W3>n{ffch_w$AO%Zbh3(Q~^ts^MhnR~+D}O2fwY)1~#y?eQ8|Ca?xfWTk;|z*~ z*ac8X*UNiP!jbP(nGr5{%JRVygcSbqdkHrTvSwDx(U+#&^TIdXLFSJy(#)#64s>LH zS=r!|@3&0V(ramr@2`!0^l)2?Qu_Q-CPgOBP&*%cZ&UCj4z#YyMR+h0{MS}?<|oWx z;HwImxszW9$D%P;(?CUe;z?(PN>P{Ef3{U>d=vy&9X^s#s`TRsR?MWwr>!5F$G6hV zcrw$ppoqo#ikq*Pr@c}DPM)>6K zKOnB3DAkhE!KW|G`r8klUIeGH0>CCma(+-L8nz8%LkJhUwlZOV6>nH1a|`yD)IPs; zCno-?od~5A$8$|B9Ozc%{da9b%!&`GhEW`?qIeWWE4D@Yqt#3-f!u4Xy?nlxd{E_0 zMDT{e!#A2boG1Es6f>y0Td%;UwTUIlx){l5H0|Bl;@BdWDwn2jWox?hwM=tSv^2GG z@A@z~Jic>bnk#5OM(Xx8F1tq>6gHxdY&CwAg2;bRg5Z@sXdc8h?6(}-0>s1dZ*_cC z_JJ0ZpIkRyR~+qK3%yq+`RcP!QVU<*I*KZocL^Tdn?f*~Y>~*L8L18Qug7l0eG+2M z|ICrMx1!K?>3aOZ@?jF%M>;561N5Aa8GyYHbHSs#7CA>F+gh->9p;OiPjB|vJ?{#5 z7K}Z4@WM}pT4Zer1gH-r?;NipV^uHXOJte=IL1QHrxHuqma*OdH{XGnSs_IAGWndR z;H`_TSP)%`;y9n^yK8p0*BRFPzJcHWk|8@lwTB!ioxq4isX3sgffSqnkq;KY|ApZO zGXNdu;Md~sY*u)sCf9`7pJ!`#C4ZAX{yzN{A8!)hNrFuL-omXr*VkCKzHp1AMU5DD zUnV;_#%8q{;KPruP48@~3vb5SO#Hqlz0N{9v?!=R&Uolfw+b!J5&UomZhC?b zqBTYZZ4XK*YLt{HKY9`W|3wvSWIq1}w4bcyz9x^V(dsq7TJUW%)Yx zck%cEc%a&6;ljwzsj{jkS${wyTp_B#s#K4s$WI#ZW5^B~)Zrd-zZf{ioUxOb2BAQl z2X=_-e?Y+20K)$yF&YBZfA2q_AbI>KaCEtb90tg&tvLkO9oWCz27q&F{;`9sR^rtj zlY)kczm`_D2>c)ZfK)e0ryjVL+LnrT)n=^mrVZ2*HMx5YA9% zTj-wF?ti%j*meH#ZbDQQfq$_=H|PI=mh?z6=WqFZl+P!|`SLrD80T^Md@DKMu+MXh z^R(#i^2K@bd7dwypGwXTQs>9j^Aq-Y9pk+8a$YSuugU#Wv*x^7bY3m`53CkR0VQ$5 z+`pT8uDI!UGS4W}Ue-Vi&1}E0o&KqvO+WL{kL$Fs+Zs}P)~G!B`plV6aH2RaIU-OhlX{6AU_Dgtq$`4`Tw z^D+20bAUmldkKeo3J4w`kIa+6zir;({0RtC}{oFS__rk(z$P=ng^du`(CrhvQX_caBhUCzQIFRb(dlQ4!bfhj>;fR z$pX#1#isbOsv9LC-yAAzr`MHRYC$d{`(qMNtt=%dj0ImtwoCs;vIDX_6K^L7CKxn> z3%jKFPBnG0W4(~Y1+6qx%hNh|50%Sr{Tat2qx8q&nEhRK9Bt3+QRS-cPikKl7x#!4 zUUOC!HnVdB)OK^1-6N8v=>`XuQ*T3xUYe|{e!@z&HJoOmSCDh#D?19;sF!us}6H^q0Vkjw`xBhwz zrU)(2DNLM~?BjD#53Q>gW6%*I-$zAM;6__CR@-p_*G;szTF}@8V!FpGAGDPI_2Js| zjb}4&4&O{SOm84}63e%bNjHdPf~~E2m}dy9QZ09(j5l}U#W&b)=*(Wc2x?ntVL1{f z=I%kA(II~3$306NKc{DfBZt_%4di^+jrDjD9tnca-zikZ#$NhB5IL> z-P9phd{*ZRWZbW;QD5Vi=JyBeOyAvn&HBLaCeN-)TX%p-C$C(m7CstBet&FtY-(zu zwmV6ohgB~;NlJ0IOg)8#6P}Kj3-)=aKEZ*9%6%^pq+FKo_4lky1e>Xkm1A3$ZM|Tg$XF|cxg4>_g zZ(23N7BnVp9*^mOc8QbQy>nMXRMY_TM~V7YZrATo?{S=wt^11{y$Ixk&zt!bad5uWgm~P`U9#*pNVTxoqoCv z0Su%6)pk)C)p`T*1m7v&!-CDh+|&&$^SMu-xhve2=y8;FvggfJ+5_idR$E)6Im<*t z|A02mpu``u!KO#}A3xc|imL{72yNRVC=~k%}7uf^Ww~K@tQp^4)1! z>?YJF+~QNZ0c=aVe;v-*$QdVi-QrK}YiB(c4@Uk04ILp-r{p74KPWF7NMHNQK>=ik zasio(9ZX}gfH3W}G-pLn1H4JmOgR{j5E-JTo6`N-@$#{iqs@Lp@9PK6f`486Ir`~M zNIw105}aiP+AA3f;gbP}z_{z$apY>aas$=a71nwSi=2SB^i8y{#-BgYCKrAcw(Xyq z@en*RAa)pyVNc-JB{ra)$XDDvcCy(rY>O4n?1!_QpK*ywEww!5vdh!w$v&&;fH1W} z5NdHN9%T5O`#2uIi0eZo*6Mx&A`0NO8nQ!i~)>Bh|Ni9VRu#zq~oGOTy9lBaI}HEn$ryO7jq zkI7juA6qh=$XJ{j&{@^ZiAZ4cixX&{7p3_0@HyEK$o-M2;!<3TB<92Qw>Z8}jzy|P zHyY|Jnq!vbS>!J$L?krNuRS#pnQeFvdfNzP!ymSh$ZC=w;ih8%F)4=GulL?HB~BQ} z>?lS$o2dy3hQcTSrdNyk(r&?-$`FP=$m18U!IGhaD|dl-(YL-eP&y085=G!LU_~ML zwhE!Of}(>%A1L<5`v$5ow?7#74D#fyy%HnBy>4&0%=a6Gfj|H`63TZCfGs9l?uwWb zwVBZP$Cc*0wtaLq4m3f@GGx<_V9W$7VwJKA_ITB!^A9L0TeB>GaZ8anw9wk@FTBk@ zmi@gnGNn9Jz@|HtIjN90F(}ry(&4U^{P+VNe7 zR`hG}ii**9nP#Q}0-}S`(oaBCLV#;1GieIKPPnv^JQKa5i;%$;=az1p`@X!tsWb63 z4-OqGFzD^`>YwIZ{U}PlcTKDZ?WkK8KRk1Y$=KZ)Qa*&+YNqK3*?ilZdc@{n+b?vW z^V7l{*@Fx}t#3t$6~Q;vTc*NKkT%f<^=X@5zsUvYUJ066*bo!NG$OEDh1zrH9=T+H6s zHVorvzoHhv6ZuewJbhv%L{_vMkj-}!b%xCqA3=7OM0TMrhwOa^a{K?}ShTV5#`euh z<1il3A_3{IA5{RWYAQ|2W!8LJCw$iEdQ9E?uhO+XPn9<)O{k2 z*PZF49&?#Reg<1bFBi65{}yJNfGfv) z6kaz|Z}UZJ1P1`=qHM0M6)|^z-?w?|H*M$&?HaS?y|K@aE^}CY(`oBp`Eq#F_T|es z(m_9_Vb-@Rv*E^;L9JKUIxn}^LHRqz;HVJQ)VHl7SVW!xV19_A(R;euN1!TOfCt|5 zz-w3La*8h`SR^D%i+(Mw>!I zM@{h1L%~7nWBs~!hIioK=0k7u5?MF1me_+Y12jN=DwkmCQ8(Ew9?d31Q(}`LM#>8$w4^Z>3WcV<$o^i1yBJ#f<6 ztdy^57Ba4HK4<9MDWnon0H)oKN(Jn6k7Pu!^kj}2Haf9Z&*QP>##Ha4`bHFulztXZ zLKnj(uoOqgG%f?n@&DaLAssYs+T=k+ZuGO9|7BLXR}93N;>I8Ds>_2kCzRc+*&=qbInm9sgxF zDhnX`)r}JN6K=X>>3>LO+{!Y~{kpeD7i{S8mo1%l-xHpC)7;O3vuJxLv)%{N(xT#& zujHk{Wp4e;a6REspR)ICV~?MDh`+gkPRploW@b+*R)xLtx~Tti^}A`mU)H`Gw`C7a zc@bHv+BUf@zflv2@P(N42ejr*tU&L0PNILgt@{I3InWcCL|Re@R~G4|Y;ZWhM$JLi zWD!tVyMli}mbFMCbKR+q#$_O3I=}v_kAcZ+yl&@g&+FACuAe4TB&*EwqpMST+lY3$ z`;`Y_X3rm}K9QdNA{7TZgvP7#O)2YPZWcTF0TwWgXGXO-@}-E>KF%b|_sfCP#eiW5 zD`u9xomWsFHwg&G`NlU)F6tnDnr@yXJP5eHMm#8qSf&K|xAU$3eJ7wX1@Q7(6+YoI zGyG0uHb-fKVKeBfmOwIN zx^R=j(SWsTJ!%$*j#^@aIzeV0>wd;R#r3vX3*(cEzt4X=^@y#N4Kmlw39$TZ5t({? zh2o7z3ge9e?T0;j+`9ORGXv7NH3h!&1DwSG@4t+U+EPyGuMnB36TcG>b%$EvNyP4N zy#kUAAW8gOl?^Ysj7&f``~md@>c6h8i7!mUn8?B=?~eZe?`l0nEdLWkUy+uVhZ6V^~B z_`jBy*tcN7rrZl*bIIoYoflnZwpC`2l&$?VYgy}gheKUjREH)-jpo4(Z>?Vai+HpN zu}+l|uzVfauk^7JpucOqq~>-#v%Y9GMKs&qAazg3){}GhHlL#}CP?Y1g!s7yhJnS# z%%Z|nnZM6f>UxpEeL!?n&Rd?d6hd9JX#6^u1X4v|GuAZ}of zO4usOVfeVY9TWCs(6uTx28&x-z5_LMts8Sy7eV_Zd3B)9)o`?kSHamxE;y5J)i+{` z-Lr48ATKnl z>M$<(Xce!sT&86)#El`Pdalf6s=t)_ww+m`B2_KKppuhRcPb3ahGO;gqmV21jB#ZpXI^4U=0Q9?+oXVrM#_bLkv`glXs3Rvir}n=RO==kC=lH7=mW7- z^(KCGmGUh0W&~V_H8s7LHq-rMIU#aJv+G4 z*1JUK-JG>Z=1wvXeI!3`xnBsm0FR^4nS7T$KJiS2^d0c!`2+IAo*CjJF;I4AI1H}782hFb0b7xVW*B!2-}rT>>Q>i+LzpF_IqUp>&~Uz%qzAh z0}Vn#E?p_y9s@Vm(e)prR}-mLO4^f?NE}#zJ`pc7Ob!Eryz}f)SH z4dIRG6*WOWm*idh0)iU8a~V(!=5C>1u244$?@(-c0qj4KlFJ9amzL)f*mOPKWIf_E zh;LcBANhnDkfcC6Co_=-8lY6{rg++AUXxB${t@?QA8RXK%xQccU*^aqlUl5wLXHR)=^Sm5K-@nMozh<-g^#U$y%}1~Cow!cmC+|nNAZEH< zl6lLIYA1c$dpk3kx06~z(33K;De(}UU267{_fwhuMKwqXjM{itl?5*@xqG#$Nss-3 zncJSNlVg-d{q50y?vTE1FHX}7)b07;Sg3n^qJHlp6fb`!ryeAp6uY`$Y*vT55~m|x zlo}J$z(A8SCdZ)pTeFvSICutG?cWBY4j#%A#+igEi&I6*BIzye939IEzfHRzQj&U{I;Bu6af7}q&zJ!B=(Zd3A zknZKSd>SN=GHfLt`Tgt&j#tMhZOJsF;^gwUN8Kgf3MtO&I9})rNr2?OyZv2IlvAJ1 zb<2Y-xE)nfUx;nmY$NetE9x^!)m#Q`%X+9M1e-O8^JA8F$|wcOjHEFs5RdIJNq{um zs)&Dm#*YyCgng%11lZ~Z7XRSNjjUKX0Ur%2DU4Y--^~wj2Rrr%-L42tyM~WHYQ&GP zz%E+iAswg-1d9^Dh_NDjf;&f2Ql~c0=({`RTnd+rQ1!mi;}F`_Y|w!?^dSzss{@=7 z#wR~2Ylyk13S;B$Z+Ugqj`hE-dTp7WHumL%=+K)Q<`>!9tqSnHdoJRmMV12|dBhXEfoTOYDDT~Ff0+fp9IV*=cl(?s0`CM1h< z_uO*Zm@nisNTuBrrOF1idn7dV0R32srRf&&4{UrABO+NC!U%bmbZjL0C6#6T)wbW) zkhFk~WTz}X@Jt3ljW=^tn$a!#d1@3tZ;@LqJZN$6>TBvBJ;sg{{uN+u(lC@!66(&) zdv)+$a~8%w#y1|!4YL$)Nwj4sDxUk6$dgXC9((mAZ{Gv$)CV`MF=xSW`EHUFcDPep z`AWf(Yua~sN5=-{x?54B(-ZA=`s-oBHZhktZ9}A|LKG@zvyRm9+4xI(W%LPUbGNzm{0JO;?f4VJtUh+(l4^9 z;F;Me^oo*47_6o^6)BhX6nBRtDr>f^WG5%rv0#1un0}K@u<dM1c|^vpLHp$%ugEB_J(q zWxrdE-^cjQbLA$!ToYKh`aR)j(+}ijbBBzdk@H20)+>TKu>_E$HN&zM00K6>?UHOi zNHRF#vaMt3?J1KlD6lUa}$1Y{Fq;L2JHp&z&ikq+z4XiGv9q~_!#V> zen82fTiMS~Qtv+XmdwjYYg_e!hIwGOsPNwzv@4U9U7=~^dHb_-QhWHeE?z$w1B zk;px0agw>bVA&IiYfqgiq)*s==xOO*b9RT|fKSA_8sL?%$IIXtAH5f=nkd~y*E8eV zvGS*9x8n4#ynfiXXnptQQysU@;E@m;S0oR?BOlp*kYzlJR z4G-8JM1_Fy!lk>!9LcO(y)l(N`b5u$$^yW9^+dKZ7Is7p&>QcVAAkYT-mn(?Wk`ts zDLBbTcY(@)wA7l_*V>RatC{#V|Dh~ zfkpj1Z1ta%4+qvh{yo`~t0+dOdXQ|$^~?3pBMCe}hSYubGmls=kcRx8yj;$7Z$vpe zIx53XJ?f56SqP{8`)#ptRS`H#&uKJ)<<4-MA0UC{j=0~%7J6xHwf6OXMAgDV^9!zCAjO8@54%v+^6g9XaR1I1nIFowvXi5&_r6VibKY|sqgS-cD2iol zqVQskC#y{Xj!-_ZTCE*_EO`gIMm|xZHkqzVc8K=g=Ly>>VTW{s%)J@qsZ7@>r8fhH z`;QL@y)+8F38axl!AXP}5aGJ9QC}*)P0LR+ISpR_=&zhemuI}8G849z-ml{N%i<_= z%-1UL2!W4(`@_%MMKSUy{Jpw|n;W;D%vcK@t6(CTgM%`NOz79&C&4%reXD>A6sw1% z29hppeVO-)h5dNTf`l07;;Bj&!%)-0!+~bC{sFAw2$j-uGA95F!9IkNN0}8x;KmIf1&#Nl;~<#2 zjHYgTH-`52&NxkqOg$R1#tdCbwZXrfPG8O~5H#?iXIk(wj7WH^1cGT~UnbYM10x{| zsKjp9-&s7`TcAGNEtG#4O;=wnPE-AiN zm8mQ8X#(?Elj3go0`rQdvgH2jPLRn632oZ#?P2dh3AC1+KPSnU3D4DheGYP*y)^gv zFyW2Kn#nHvDikQLFM^{n{*l2TeENM{QFnK}Y>@W@3sD)Agi!#^qP(ybnE}Jss6JzU z^3#4yh2pUmb|d56wu$EH!N3^x)z2ypX$l=J(ZrQYfYh%0zd>C5KSAM8JDe4N@>x3J zsDLgbPsxoZmVBlpfDKZRzLEe2ssk>dUvd8|1$5rC{;$`w{{QlwDxlbrz1`8ox<8;9 z7X|2KT>8HO`h-9M3#9+CRSUs)f_qqSj+h&|@nsSBd5?Yyx5RrKIi5&(a*JsyD5vNw zKjAvd1Z)ES{NL!&LUB3Pnrl&mQ3D_5FJ~CumFHKa4xs|{z zwtqmv;{+$tH=rc(XTR;zj6z6|^ULXBSBdhnz^DR&VptvkoI20b7D!9DhLlK=*|US_6Eer*z1*CS)SW zN9O!h&qw%tHl44W^C)p15zqIW^ZoTampIRx&U3``Ed2b$aelh`FMX1{=n;SU^^SfC z;Oqp@_vkdQW%Q%79%m1nuI@tZ=IZLN1^uSLialbXp#&iTmF$02=y)4PI3>PCI-z2M znefs6%flC~inBb5-l^K<-G8ieJ^gO=AJCxD^UG8q5Lusy%6sx(Y;gntMjrJJzxxNI zE>F?{O1QXeDCUEz5^@AEvPv^X5*$LHPedzw0T)^q{iLN|(#W+86kwcK8A!}U|NTl5 z_-Nn`F28C}MM(~GS>ZkgrSNcy%)cfQ+0hvjcKSQ4~ zR+RQlBiC4=1OUW%=Q7#S=sD=_9C-LY&H=K{roHLkP!AEHQeC!8W;jfUuG2pQxQ6*g zvmij_`(LfezgNy|I!U)77QH(NxV4d{dvM}%Pg!%1M$ zD4J>)A&H6CM#!(A=sP1HeT@|i+VndcSC{tE85D%TzAVG&q4t%$@XB8QU>=;_yC%7u zw*zdXwd61J!i_)Hrzvjq(Rc-3Ap3@(!h=5|Y^MnQ@-uM1{@yvsoBb+?_ciS)$?aMq zRjLaZ6azsH+pz0d@Z%4NZr^%ifDGQigvt|@pnQWVOtR_fqr%xCN)qzo%W$i!J+aGH zVAAE1cD+pFDQ&lXN%3K(q~yyFZb`AAr0#Ft0`2ufR7r!#1FIHkz~UNxs-*%7ZUZoz zWLVUW5l*XjiP;w#KLTGh%QDsno`ulfQ97z-*gB4!* z5}imedxsC+*deDgkNm|}t$m4~)yztVlfurwG8DOtI`(uvMnWn#hqK&Pe)eQ%RoUtKTvGThtm}HjVQ53IF3;=n6vTq+db1jj}uw{OFdiQxxS7?;5;p&=(k>!&$ z^4}MFd#o8JWj*tqv)2Wyo;9$uN!8wD>t`yfaWi;8)}!Z*s*!I)MI+gfPL;C&3MT!D?@vAQ?&Gj& zz0ZDfsy@g57A0wncl!Aazc+`HnSB;Y00o!h^?UhfaeH~Q4WYl!z>G73&zygqz5iHf z(&JZi_ba*Wrw!CfB^O?)t?BucvfIb+_~5SkQPl!()^ZJW6QA)z{RVTw5+1p{&p==V z8uzgaaCS^x{%DKR%J3y({ic3I>a&`6iBWfN(S|eeGvs`yBabV>mb7zG?b@C_Bv{~v zd$=D}{y;JKvS5f{oyNf#jdh*tUiU^FUD*!&Q`*ApgpxC!lbu>&keqmEFIT{13;RdY zpGW+=(%&t+eHouQQm#0FR_k3Z`(aw?)ON=D&oMf~Q46+@MAqL)S5Weo9lP8m8AHiGNCnO#T zx~RdA7-|NuMt4-w66}G`7l#jQSCwVaE9l3F?nzV>WMJ5JDW(~!`fwQ@*IUQh7#}DI zz6TLCu}cXE;ZPSqZgRNxl=Oq8jAgs136>hK>S~XRK{iq8IeIdRSh49TkVdoMi z=DqfH_1tXR_LS|BS6!#tNR6eB`UY6a9!$Cp4(h2+S^PX1o1IuSoQBqn-#R9-+w%kIf#&lYa7X${!w_(jvN+BgW0K^tHrB$rD$Wi2 z60ekBH6gxNMD`*#RY{J@JHbTfpV%~V|Cel<;9%0^Qgunk-~OeOtwIQ)k!7q_Jc1H5 zDMM7hcX8I^mjRbXYD)OkwS}l>RnNQ=!q;K-kYtHg(coSrMKHe$O+D^|j!53joNYc}?8d<;zl;qGx4v}7cBLXUy{ly&CQ*eEW7UYUznGv6cc1n=P2~0yJavJ z+Q+)i3=bz=8bDoK!~60j2V8A9BIlU;mg-B)_c-$O+}vG+1|B~912Rd@a$giBKbnT& z*rN$zMS~Ui!pC)^zkblfN_bq=g)=tN6ue6*O-KXvue~u=De2M&x$p_`(O2OlUVB&8 zLY=q}CMjQ*(nA&<)5>A)Q&b8))GA^JEIoWs9HdJhFYMEl74ML?X*4cT{%gh|$N#l} z`S#80*Vh{z4&pBKX7Z;44ZPl+$unt!0Nw)FD&b4F$LL7nt>afR_e~o!bv|yU?Fk3O zUhg)%3H80rLH?nipKo=gG9FEXpnc!_4nE!ULM}Au23_XyY2=sk@`{m&jLE7%9-HPI z?w{fhzVT}5eNNgNTc=3Xqm30z<*<{gPNNnqpkl>`VjF`m5`=J&1ZCMXKe(JDuD+pc z{%fnir&D!`R*8Q$A=x>q26 zlzQ1thb@=8hN?j)B8|OwfcwgWWgV-pPUJseycKhJ!LCbBD|#_G=_LyR=H&%kN>m3q zZsW1kcdAft_V}N8-Tp|9@v8*<7-@erm2^qQ81wz5E2`aulOCdD^QYXf(>zpvg)PeD z@2>uO>m(MibL0L4y80WZPU?2^05Tni%Ao~U(jyiKF~Ix}XzN6j{5O>)z8jm|$u5^^ zErhF&EwY5Y)%D<=Dsm$qfde8wdZI+(So-}{XC;KhSY(=U!W57}8}V_6Gdz3CC?Gc;_QE#>wX+6P?E z6_m9~0hpl?_SB`OuG3lGr2a-t?pOXegVB#cjH;8oFUdw|Kxg$W-+r#@x6eTsQPhk0 zpLfD7!RtAkvSRj)%q$^Mf>*I!dF-5(7l>biMd0BWlby*GUfYfauzl%v+T2#ke9ygp zP}8$}KRPo!{<1O(rPQHgxLrO>8{CersY7>&!7jK4U#@GJ*t`D7{dbh*_SM~SNk#p% zzPl0C`AVRU6iqMCK3>%C!-pSB)md@A^n81+-0$~fMylHdH3Xu*MDR$ZXlhYqA9>Id zV0akDg9fgzfy+wt2@)}F9NXcrSlX7uS~dpy@b6^a-t_ZgxHW{3KD&od#cg(=@gnXY zumE21YBPJxZ*jI&{V=h<2Ex=^k|2%zD{>aD+;CzF;O|vKOdyfzoXP87{eYR6lqITn zZXCI}?`B(RFUnSCUp0HR?$>%FKC+FbJ<5yhl&kz>!AO)bCxsWa>_$^`PjdLIW|OjH zDSq=#%@i%bugrN|pR@{B@(a-8cl(=PiSRfRvEE){L_cv;>@>w$lGx$T>bNgXO-|_A zJGy#bi)L8A6Y!iqv|3J-f>8tUU2BKQ6cQA=F^;v=_*-F%%jMT}c48x78V|*0(YBH^fy;|o*z`8k<$WKuGJLbq7^XSh zpe0ksPR*MvL7j{3U`!^mrHCx}$EFO~*eq#gktyfsRSRBMq8T%RHRKXSCtHw4V#F5N z)YbxeVEk`xDrpUM2X!T-3G{PcGzNt+s8JPML<{^{3WZ*L2W&P`E@-e2*81flXMBg4 zU$%2y@=n=vyO@Bt%5Fbo$KjvQ9qwu_a4faHV6x z)@VwF_ustXYZm!Z_XhIwI%RV);22eMe5kAVaj3FZ2|RA`0!5nh^z(&s`zETJ(rzLn z%hMajyIJ{is51>b_ljc=m}Z+#1}|N-j;67IcAsMhQ=74!wA zfJe~(c)JW^OBOi4e|o!I8>cHuPKJ3Q{bJX`g6sY_mUGs>u$;C3!E$agv>mw_o8zn+ zeG#pwA1Q;HjXqk%)raEl2$tcY7?sFc#o%-?`5Ehjtw(vz4;i_y+w}=sQ!!WurY4-Z6!jjmvP4Z%whKLs5Wfa7;wPe|Y(V29mCN@lYHfX}+ zQ6W=q`KUC=^ZOx;f;4I_zcX4u@UO?AcbQhiQ0&(8GeU}z;Tq))qqE5`WwzWd?k?I2 zrrDF3rm4F9H}>8;sL8i&8w^sFCP;4qK|zX0M~DfiG?8YZLlmS75otjJL6F`=6cmIg zML>i|uM&C@={+crL`8an2!VvS_wU{B+nJr&nfG~iciwk*cK^ua9};r6>pZXHJW5Ii zh9N=45@ue}!l^v)#|5HOZQO;(#_ONzrgs-UK`qX6W+!tdJu|Ib_<3h1L(sp zss06B?KA*Cbf()-y}U}-#6B&h)Fmnh=f(Pj>n|;Qu$|?6-K~+cAQ8*1(4!-%U8~-ulPLCOPGkT*jA)_L{n44f8d?;_ zT{vK@F7SD%DK(Y!&bx*hyR33;#k^*+&HSB2E-KLV@-{N_*gd}xyR)|qN&~My zjxBg`P(oHEjK-mT(ReTk+zU_5K1Xy?J+aRH#ydB1aiy*?X8yZNAAjgG8|zKkfc;SE zXB>mmP+p|)@(Nu5PT1}^Qr0a)z}O84^4b!pGe3J*LOw~(u`IA!z~r2MMGTH=`*yEIqW8{RQTfUOF<^cQp&&4O(0 zYRT!tP_JupZc3@Zd&4yFM=di>AuWxeS^P#n-yb9z(uSS}&9s#ggcG6XqI!f2WT_fh zlomxey-gBYa%n#y1kQ^W|EcDDNBPwI%@a$~%A9Q6ttUR;nqB*YvLS*&47?N%#j-h) zoY61Fo2<$0yJcQh32_qaz*uTva&^ztT8qgsSanb6p0G3rq2_72$eccgCeT5|QP+DI z`eYK)ma*MQ$g;@n$#Qah^#M9^{=_quQ7OsUnBq<8lYR?0cN`%4>;?v`Fq{ zc~v26obsW5tv+|scbEf_UG`E&Hi~xvxuP-0N~UWXO_wnO9ES%HCO~(fRuV<%w<+Mh zcWRMg2+{PJ#YJV7#zvNN=XS7X?BZ^pxL_XXZvbXmm*S?NaAX^Q-ag8&uu;MQN@{Pk zJjKW2*~q=u2gW2sHkbufQaeD(C2a@ek}x%}{%r}m5zx=Z{IQavYg6;p*MyS_ubvN) z)Umm-Ei1ytDXaNR;)fkq^PQpiBW2*4sAYE(;hQNGQb&~bG+ppL@)RMb+dQkh`>2_B z9`Q$sJV4oum3zW1_K@@Z*piy(3H?oYXS-?+-JGMH==rKl1VO@c^>c*mF1$|`3(ZnY()`fFZv zD+&^M8f|}!S@j0Tcko^ZwUy|i0)Z{Vq6?H7DLsVK9y9lct1|pFp6U#SI*1Zot&~LX z=%p8JA1WM6e~HgpZOHgwB4D>E;7O1UeJ1#_Y*9A!wx_{7{#uNQ9M21i>ZS|1mUtO3 zoPpOhc7D>KMf5E3x_VdD<#)>ACIbYsZ#{X6Mhbqnz94`8wfA(RY~Ngpph ztWyWZ3>5(Oo3MBU%GAAnx*o#F!E}T*avQ+G(2OC5K1CJ4IK0f4 zp5T|F+YAZ$h)!1t#&Bl6xNr>T;8p&sDZWlCndCP$n=C>iBSEf_WHu=q1QoEO1P{S!SMiT1MMM*%{}7&NdO zJM*KiML2mwlNM27uUZrlmV7w>E&JRvBVMULo|jEo!`QxaNrHK!0Me{T9fHRY5G$kD zDRJ99O1h{dXs-wET5GP%#}QQ~8QDj&S*-;#~hsb2-+yVcYBH*joozJYX&vE7f`?|hFJJ6&le zeaZgK8RA?Q0Nv5KeRt~4YzT0kd;zS#PPHE%pKDn{ChfhZT43phQ>Syim#5ocCB2cQ zz7Lwe@wa(b_q~>I_qos1%*^bPupv*z_Ayv$s0k$AQ^=4uzG5oLfiX4#(ff{=tW0PXDlu%}f+4%6LLE#Q}lcZqupgL?Di7)>>#bJbZ5= zrl36X#o#vRnI*dxnFmYWCbD42=7|ij2_fHxenUkihbHG>8NxkODOVhs=arW!vS_^2 z`aAe1?lzEUEKl_>^%a6KLsW|4jzfq@1Yt3Iq9;Xo2?ZT8~0FK4fIsxwS?WUEBNRGR0|o_8M??q5$?bVEz->< zD~;ZGD|i2zz1h_c$8)bW2%VMI@VB*j&E6-5GxmPNI}FEzJJ5A(%`^^QDN*W^St%WA zW`x%Bee2awA0_NzaDYs*lEN+Bcxj!FHwWCCs<3sbBl<) zy-V$`Ujmj|)0#$*>#o(nWwoKJQ6(wnfZZe6;*&XnOYHVyEOEU%OwLyCiIs=`fn#Fx z&MVK`N-KSFw;BYtL}ecOnRLSIc>9u|{HsgH+JwoR@JZBN_mX$%9-kVuADwA7MHM}q zY_fVzcr7oUabm;LYRv^$Hsr~%9q5yl=;ge24Jx_}X=PVKQE<7Tusgp?OT>Dr57fL( zW;c|L=Z?N@D#te^MkDRHPiXlw?}byt3jG>xuC}D5CwK_<)_?Km8XL$24ax;DD=~Mn z98>zH@P_w`Rca7*X&Jp)=|}ki;jW|(b>X)Y*q(2`b=3M!c4h}FNq4Y+zpSGk&lURY zR`QW%{p)Xz-4ETpua{w~`y|7-2nlX_Vg<<{Kt%XYgtiYjc1hTclATFdA5Fxa2oaZI zX|c@oKLvP7t!I}&S$oj?}HUcL<7-dXe>A}qdkJB`;+>Npz0oSGe zsFRfGCqj?5FvV|jtZSPaZkevcGyOI#aburUrX|y#?70e1tAU7aQe!WI`x{cW>#)BY z9{=cEM4Bi1@#X;GSM zFrIouB2<|2?M1P6BVL(pYujo^Hd#Bai1`Ogf2x{w_gyYi4~dmMOg}ss*5K8pni_SM z>hJ;S)TJG`IPZASFfkm+^3YJz)BnoQ4VAD#1&$il14Br<-=B8iOajU)M9L&k56h|F zcJ+UdzZdX2-JTk*D;zBSeHv=rwD2RR19pCcGLaGE%1w#Nq6N?&QasDwE;;KuX-28{}<$r&VZc^S+@cj4SlLV<@`R)vwK1v)QA#}* z?qK}W`j(9CS4qE-N`^FYn~0_hC@t?Np&XWna4hqZsyD^_vy0-Uq^9q0H(CXql-_tv-^NlRgZZdJGXuO9b&)Dgid40p)It(}2lEOTDLFcsop>@_oc&&sIo-ih zeO`IE?<+hOT#M<6uu5sf)ah(;6O~j?BIiGhri65aJJ>8hklDj_)bb>K#Ry>UC>{(0Hv zCm+{XGBwW0Ued4{qgxShr}{%wVH1M+J+#iZ$>4pZht^9yP-`R(krV9LAX zp(?9A7f71%T*^>dC!_+s81-lWY3ogSIWEuWOFYTTf2@)*T$;M|7yn>|f9>)8swoy8 z4lHx(X?QjuQXG^M_@kB8P4J^7kDl1qX_OPeh0_qodn|RHxnH*_>}IyNp7N4VKQH}R z=k>-Zzhf7IV)4;>CzNu?rD=3wQY+guiVq9a-5O*3fdhOF>uZizA0L2OmPsX`(C2otZvg{{j)JpfIzySf zl(_rN?n=F#4|Y>kreH=ZDXclX$kXUVTZ0vIRxhd>7!h`o!_pa-kShd-T7NS`q_lg;Gt@Q^;)5yN{r`ch5u!NT6PnLk+ms9`iu{SqJit8ea$y-b!y zVY?e`zlUW;9L?m;yvlfJru8cu2h2F($W$tHsITQrUmNyEKJ?PH`K_K;eMW0K2?2CbQ40c*kD?I(*WGw|+q0ZLWUhqdz7kdo{YU^c3q+TfOM z9L4Smn8#f$tgVl9HzhOjpw(Eo;k!A^Cip3+p&x2tWL+Rew@uTLiG6fArZwbIja8{d*h>$n&Eh8wRmDokAIN^m(h z*_W$yz8N+jyrU)Vrg~n7WvB8lD4aICczN^s>!)hDm|N6xa`%;#n;1t*RK{JY6h%YS ztXN!|xsUSWWRr_wQcfRjWt@O}`t2K@o#d52a3QMQJf%dxii77ef!bgU7~b9l?6EO&1&^_12V{V)o- zb9-#_Hr|hBKHycK=H{~H>OG%%8LN`j@ttBI86rr*bfXv8I~Vn-e@K`oL2G(iRXGGw z>ATFbcO|FLeiq1NM_of1$rlTE_-gr)eSRCCK??81-j%-tAi6$mrR$Zu?%wfcRp;gG z3zLK^^qWt-36mo>MOu-ByDkYuaQb`o9ZEjIf3O+#%v!S@*ZOj!>eu%~&FB`roe#|{ z4;7fle|;FXa~$7QQIAA8<6{IxuVZFJ&eo&!lf319| zv4>OMIZ{*HD;|;KvdYD=oub5wIb?U2oXexeF?I?Wg-D?SsM?H@>c`vGoBB{ zHbFT$N3Vf?m0@+J7;5y>nP+m z>?JAqnW@9ec*`;H}|P%Va{E69q;njXvAhNLfk= z&fqZ*);Qs7x3}jnJ)$1tb)y>+=;E1Q?t|}rBK9sbMDWjxiQ_T_><(_3;zhFMiXD5> z#OU+mu0dt=yHt}RY662#(h0cFvRj+v=j5j=dn|241w zbafkJw=lYc^y`fhrW9niEXDpwxfEs@s@SM4%=fj+KIE9`tLK>ouS$PK$&;6oNG&?{>73Ac{`17C!A8uAh}vd|(_dz|?GX*8y47t(RR<=Isa!fkMhx^`Y~A zf<1^!jW5*;f_;07zFq`ec~+f2)19f->m_=~Nh6P0dD{5`>=#9Ya&96Y6~K#N17PdPtYO3yAxUhQ-z34A1t4_35T3)P3!$* z+HO{WyKUsF|1kOtA42`Z<~ReiL27iwTWlg^REeew5ITPUV}&c_^{N%(9^m)v=&cv! z1+IG;n6T~Tke6=KLOwF4rzb>JQk{~J22BS#1A2j3F>D%5poo0IxH4p3cdm9(>93X% ztd~dGT3&BS+ho7M>{)59g6sD?G2YS*V-I+>9ut<5(Of9_`{B1pt6n^`zTC-cxV{G0 z*jxY!`6glcMQuxgxvj8?1FA$RsK}K#%I)9#DtL{%{6duzLRtvJ!69cXNDr|pr^g*lA zs~b~`=}0N$raSktp3x$`AJD5j;~5-jjPu8pq1wg z%bm=FB?9ZZYJ^86fhKbIn&QQG$&2y@gc;qU8zUQ+-Jxxyx3isQx#GVitZEkZ|(#d$(n!cAOD+P z%*n7FscG0|0^_lrZziyP+5=`|_hGvyS^?Yl|M$h_fIg&!F$`}48xQoZjN_loDqxJG zh^F!~o+vQm{^u#iet{pO>UYEHc{d{|5f0R5&jQ&ohTxz_;V(y{wIt;?;WJkC*T=vw z<_q`@A9Tb6qYbsg5ICk?h3Z2YT88&837{TQc?lJ0sbAoeZyOkkV^W8gpj%2qMoN;y zZIe+O<*Z3JpP7H;Oe-{URfHJiE z9rE(b0iyBR@Fy(dYzm)dG=K8@OX>pe;44e)NUOc#qij?pMcA0?LmpWo!k<6_D@Ut* zX4Te(@b2>#Vf(~RX!4t;C>>>6I?>3@Eo8)5Dpb)wN*|>z}*4PbBZWo?G-6pRPJTtt5 zw|Rff@J6!arhPft;_PzyLu*mGD5{?8+7E+1(eB7hnhyO8F!Tmzz`%4L;;_<&o=#f5 z3v|U4Ga3W0Hv+Z%7Qc5xqd!AUH}lqOG7ry%b`6v?b!Zuv5xP{*zVW9NPtCDupUg3% zuqRY)*?w7CkUf{~6#o4PZ~FG6M}{~XbAX35SP6Nmhr$2x=p*1#RJs>T4J5*te@V%t zuUw;8l=P*X&wDqJ+FbX_{qnf1^6P4G13av4eih10=cNWxBKDe(inPEYE{KeX#qdJJ$+LSq)O@m2m!q?70UJd%hKOq&KZwxI z8IEC-a>`~E^?WV}xjV;op7kDM6wSssm9daTihl|LmvE{cPdl+<8oLn~{s{F`Am>Yi zOH)!KecoB(H{8@4eLPBfp0Nm(p^gTX+$a0*$7WrNr;SE$eG(Sh5oybme+{*t+foot zZdhuSeDuGK~ zg+o`i-iE>sI1t9}VX1b&?8~G>63yVp06y>TH~Acnve)_pb7hZ3?d>U?7JYWr)!F?b zM=2Oc+!Cw&bUe3VK zl+T}UhGvUXYUEKT7jVnngZk?`+ zP(hzVFTm2@$*PHh-uu?-vo4YHgCHkZQat7iZmWwJmZ1(x8-_WqyFX#bQqvBLu8GADA8ej(JPucHA%%>Pq&1w3`%3Z`M zs^zWrTAYINya zU1P!N<{5wNoBEx5JooBiz>8ZUtbsWP_T_zHAYQ|x)^;)Ke(yTD5#B)lv)uzQ#LaI- zU{86Ap@nsRUA=tKmKFACPROP@{D9Jx?eTjw?orrA1VsthQloB?p?ioYi%{~=00IKR zNB9K!+V7Z%02F9th$`BY30d;JV z+m|iXxJLA4oLSL9p0a^lEGEZ~%JX&^jF;WZrV_|Q{b)X@7@mG^oFN8y%6Na8$}JYl zevA4%0QgI1Y;*>mg51$2v&|?Uhn)i*Y5&i<_TTu<{9CxAorgTSd4&c?xz0qLcPLHV zg~dH|@NN;xd{k(zAQh6ec~g`9N2u46r_L8-aC6rID?ISLmpRplvbIkQ>q;3w=@2Bt zENJH|2u(7l@lHK%<1Uw(SrW@S@a5@nk*`MKIVn+MK;adp?Dv8t+L@`oOZ_eEJqrG~ z#ya-)46D}Uo$+h^K?u#IRz*X%Ti>5;93%O_O{lPUj0s?1QyNZzla#M?6iCE$Z|HeF zQ9)$Wfh@{5 zs*EhHyHx!ybiMpWEonQkMW@>;9&jP&r^}Gi0sud;<_!CWNvY3%IGn&b)a=AtyJ#jn z$NlQtfW@f?`UxCNSE4kMF1I2DR7gL+l1v`b3n25SS#TvOFdiD-K^nE67nS@je-DFR6>VKMeHG05_C0qV= zpe0q7%uYb}Veu<_sZ$L3RrtaTLt@6CwetCdZr$$y#(0a5><#9Y4^B5j4YJn+RhnK|{ z-k^k1qgD_iy?Wt1Q9|B(36^9=+MWF|v$C(9oF%SS-O-CQt2+_xoiBol^V!TBaLh5XbWDE7*batl0NesNhs;(P*7t+_Fq zBqBI1cls_ao|a=#}s!YvVbvkO z>Qf@TrDN=`T-j^YPkSaReS_KmLAnY16!TDJgCB1e>iwJ z(z8ZSNe!P_hETa}z$D-$4`>0FYS3vYb{Z$LaXCAsu_p>bxf-7N^L(o8xp?$((D4gB;%|!bjr9yYI?r z_0LnwrT?&`dozDWQ_u3!b1eaW$X!YEzaW#`9l$`YkpV&9$I`?gW`KC}+wF{)MaiR& zFpo>8Y^gHK+i}?iu(6vrnw!T%CM%uhu1J3BW=?r2_vHz8(L>zQob~*Le(ce!c2-m{ z*(8Pmou}WL^#2PYoThBOE&GNa@`G?A&?-|JY}wM-mzh7%FFg4^vIH=*gCux2lxT`{ zPQVxvaQ*Cq@j*0rleOR{OLidmSjey+i*Rk%Q~X276LHR;5Vf$NR(9r`qr512Y6ftq z8yb|Ho8#xDeA*@b4rlTT>z3MRdLJ8iEGIZ*&gK_Or}A4Oha-n+!gM`KP8VXns1w86 zF6{#xtRv%DBk||La57qTq`m1CESz~z+*9Y~PIcdda28&nqt`G`v;ZE(OYJ1*k)&SY zHK7p7+1eoX9LTBWnolt-@r3KDrz_YmcRI(jMtO98XC{H;4lx_$fn>J@Q|c>nM4!}o z3Ovc*;%7Bch)mD111_NIE4i7*%CqF;VOoR5+i0$iY`5S);AD(Qa0XTWJcbbI@kde? zK7}#@#}>7;$2Km12J8iLe@w{>Psn?j#o6~I?u5=smrw4*j1H?L3>fFJLJgw8?IsJUu`(8y z#fLbG<3Giwqn~%Kb}xjC*!V=^*yTOpG!MwA$UloB z{41_JpSssnTNfT?u6vO0p>Q>E{>F#TAd&v2s2=c+&k`4KO$2ap0SxY`CT27YgBq}8#XVl{ZoNma-|ysJjC2E zcidw*ca&(5nXMDTsw0;YUJQ?&ZdfLM;LL z_sUi!9zfl>P#gS1t%^?i`ALgZpSSi=K%3C$GTUhd5~?NZ^dNZhLo}a$8M9wIX{M-Q z@WCp(>SMdX@)qZ~0Jn>32e=NlQAA zvEEA}gHDJxy#HP`0dJx5krBiK9HrC|P)vI(DQLn@xy+@3jmG@FoXU19bx(dD)bnul za$5@C6n{COJV12r!%-gE3zqCE4CFKM@*^9(AC$ancz4-r}v~HB{qsGwq)o zp-)qO2h|%uWvOVw_Iw^;w(9`O! zE)K-rDRAlS0~&ED9wX7tN{<9o>Z;KUu|~?)W6B9}#O5Z`cG2V(v90b#_1A@AovqJL z?O$A!YXvpoZ!u4zm}$zZXpZ(vu5?SPO*n5nUAAf!=P`}aVUcq(N2WR|l^NX06ciD^ ztu{mwPsB<9bEzQoKs_aa@mZ+)FQ{-C>+m6U*Zc!#ACQN8E2?h z`o+E;{>%@Os`h_(>G>==>f8+Knc!tCh`H|MyxV)GX4)n6;Cj^L(v~;w@cJ|fDCAcE z)aRsj9v_$p%C4}gkr)wa>+TU4CfhQXaxS(qYV+HWe-ZiT&_Y;Tm%(69A_|*r8`0BWeXu}0r%didr6#+Qcei*!l4K?c> zchqcm%ePnclUpx8CFQPfiWb;Izdt+OThDn60N>^K_T;-^!i?!fDCJLfISIo8)S@P>-F&{X(~Hu6hJA(khRxnJZU6< zJFDgi`QOVqn>GgL+xq4OQYoY_mU6e{ACa zUWxVzoKx$Anv=f!qxp_awvXY>KFR2=Wm~I%y}jOz|2H~!jg{(uzY_-?31GFTgKv>I zfGl&IewVNWaz%Slvu;daRdx^dUw(ETJf(ck^x|c+j=;}Zt@WtY{}lD|v1g1e9S5F6 zEDy4XfZ3o>#GK&5XwOeygg?%^E?jzL#&^SI_T<$M?e2;-<;u6ySwT?`sa6I@nY-X8 zLKtQs@+84Ap$jKb-@0H%ykv4c^GV~%1E-3}%PI?#q2hjSd}+;!I(B2Y#`JiI28{;^ zH?TU1OD27s3id9XzULA4DaqFA{Eanx?eG0xj(?kl%#5aJ!4_isz{l%(qYKoiD(}{3 z5336Bw-(PtKAhd)mo@njSznpAUuYeu0IEP7#Bi}zU?kBV*rds*o&p{!)3UcoGTnr* zo`11a31a(eh%fV}!oHQ*#$}o{zet%xC`c42^TM&gh(uT&ZexBK$9LSxe7fpSS->+< zv9sBd4JE|$O?PTQHe^FZQx-3+6-?vk2Eaj^B*5bU)!O9p$P3|U$Q7NWO-)#%_nY~l zJwIAR+D6~$lGowxT$V)6V-Nh+t?i47cg~pu#ch}MmOroFama?;sX*3(f*Z=~Qq^h$=?$}F%1V{6g( zaS1^E#L~2olfY9S9So#@TA*#y2mSrq=HC3~vik^>)V5a>T3xc>SMh<5M_5rVH1}Wa z+T)s-d6RBD|AGUx)oNC=)ybm`>QU#Is)jE-p6b+SToWzv_M-FVr~uO>0}rm*?&3`V zGpV*K(q+a}+D;Ki6DAzttNVv;M+G+E4zz3$ABZ8l~ves~Qh>lTFWL__>c7m=w-ah}9evu;iVwfhL=tet> z@ts+Qp3M}yF6Z@xt6%b>0Ml>YO(Tj%ca+>w^#)WNff|Q;?vxyqD&CdOaV|yL!k1%X&0!XVH zoC!C%wx6WmmX@ex+L7tL`=S+`5TyZRBuG>D%|B7XqU*NfrH)sdKopF zNXuAZr28H#6%1_jSeM$fA}IW-Jk*Cy15~66LgjdlUtr9_uZ@WF5+%d#qqK9B;jp9qzY>Wku_ zYJ9j3(ILT4tm2J&Uk{%i$bQtG54x`3<>K&QRP>%o+Cok!2m(){-=;l68IiU6I~OwI z^w9D3k)Cs__=MzYb(#k6zK06=av zp{i0cS=>4ax4@n1K~KMWd7vTm@K{}yoO`3>wb^-BxuXJxEK2VvA9W6xw|G-DeDwAo zpH1G9eTZt(N~@G>)ZW)e-fqgi%!rOIjyC5lfNE3i_C6j}p|swiz?6he8gIf(CVb(W zD>8Y-pn~G@j~5ktqsX9bSioq4s|Lj+sa=Tr+3Lg+>I$$&a#~2D z8xl5kR0A)D7;>!#aA-VXieKzjdRn6HUHNr{vQJ00=u z#+p%DWjdkmZXW1`4PzPRJ1G*Eu*Rn}xvRNu)ySH9`WdVl^y4P~9|9i^_%#qkFTpmLTAx{8J@$2k6> za392g0fxLo)3lbP6gWxD$$8agSotqVv%tYTL`D9^`O))kf@VF}d(xw4D~2d&LdmC< z*xduVAwqIAgQVeY)37{C`CzW*-6?@eFa6nXt@qlRz1YTiG_}oxfJs;s#Zhunv6I)w zgX{R6zYd6I%P9-+pVSDEexURnJs+_*Ob+YMBu1**PTg108W#Hty73G@Gc7L_kmTvk zw^;r3!Vu#7FQ>yt?W|{S=}3qxxb)VE~ttWdlvD0wwAkEW8c0u(K__Gjv;h}2l?|F=+HHpJeqR2 zn+r>VN;>wF{*djbLgwE1X;(XSR~;)o0_pR96<+_zgRxeily0g3DPc23mA@D|lo#0< zk%Xn1qn6|QN$xs{_o5Y9&l(4_R*Sf?%L_1VVmyK9%mM&1L?u7EBH*P~SPqRo0bscC zXrXoq)IBnVV%Mz_wbUc#HgGrmjE6B}>+!0?yfAmcR3fD*Z~$xZ&f z-2YgM*Ta|p@JxZN&k98GE}afOc(wz7;skkx8+dU z;w`m`0op?y!5eJGkCH)u6qQDxqI4I4*pqby%|}DzDnRz3lWCl-P1lukB`J{#0#V@N)@)h&LFo8@%l5ToSPZx}Zb1 zWTQJB>aSx5WIw6=7jaep2Pf99BZ}kb6I2A*zn3@#6?F1z38DrNIYokYk!_?<-~q{FSJOBa>}rQ_LIO{Jp^1((Md zczDflUwPRtGTn3nzj7_z$LpYY=d)9}52QRaJm}3_+rxVxbhb_{TX#O zQGa^0;jWGAAkc5({~`O(EsCi+odp`T(|;--E;?11d$WAltSG7f$r`0%lT)QWXv?gArR{%A+3hwei&M;)hzl&wF4 zrO`b+ik6UJ8#v%AiT1-pQRx$p%`)?u?i-6d9_$SPTq_wWC|&9x0WjW*cRa!+YW6y0e@`QxKQXcNb^M#l~4zr1pTB0L0t=w| zH1B(Y^Es~T7zM)|`8QSQe`DnaMY0+pw?5GfP~v1LxpKK2%emW0L*9b}BNHEl5sO@6Rw$ zV2T>@{|ka-F%ASTp=u*&{!yQEBYaZAD6$qeD^3rxL1vglo^qAApO)B!t%nB2j>Sag z0iFSZQIb^7%M)}f;%3yW<%CviUWC3dBfQusv%wr1JV}O~m zFD%+|Dd9ZCrI`A8X5F}c`t-Q#CcnGz-jij`2NmIHC}6+z5VL7lLY*ZF$3{sytR5z$ zg_!PEMA{y5Yfpc-yQ}hJ1XUK_(XUds{F>*YCeyr)s}>237hcA8M><)LwS^JaRZdw? zY@Ytq6#u=Zrc?hQu!fGi=((&I5Kt;2SWd>0?7BwNT8u`T$+bP9i_&$qPmL}nJTl|_ z_&MqHiyK$JGT|Kzbah9#L66b|fA^=t{8Z;hJEYXmAS$juH!82H=kv3UpT_o*hsryJ znl=&=nX=`8@a<7!O1lga+~F#a*AKlgtD;$XOu)o%zT_BxP2_N@Mq9h_ShLEU7-+*A zI}A$%7HR-SZOl%|i=j?!sQN#B1N6K-s?EFYg#2 zw#QP7C>TOmr~la{Cu$K<$^Y2Q=9dTI3k7ER5A2Nue0QKrv7=uYgO2`LMZJ8 zMhGHx4iH1~P>*f9bafoP+3{OC^>u1ow)P1MP(GjE3&v_wGB(3$+DP`^8l-SY=UQS;8Y_^;?@$OgsbCMJTu*tZyRdTx=mAzAn7m+BXL}AU&BlMMb}x-8OKDhJ>n5O6Dz%1&29jEWzgb`nIUBWs?|tSMrto?}J+omi56w%d4&?t!G$aQd4NK(lCC z*o^s4Yp)%^PpjGA%v;{hknK5itt^{!^vl&^l$@!rd~XT)#{Z>H_LIBx-Rz|S7BJw$ zvG+%Cxt;f@3>ZrM1@+*qVo{U^P}|YGy2z;tIO7_4QBk^8WWsZtqwt<;3n1CFOtwVHY6*& z)3Q}}rHRux9YIo3k)n0uJZQnEw)y)ICkH{dPFzp@#p!mThsRQZc>y~7KNdTI-FSr& zccUATAih#4~?;pL@r4 z?Y+<1>z+OCIOB}<2V?k!iAm=CzPCQ_^XQ)M7L27NBaWk;(eqMUJSsB5y;UtY4C}u! zjJ!6EBuC(eua3x9z-;9S!VDr?r1ycLUW7FP|N znA+lDKrqBw)#FwBK91|PrG~6P#G4e$ zz$0W2L&;q|^MlWntUZC#&`0BBS9@#M0!O)x>R7d2_Kf=}u!|B}Pw7+8uO(EH$V%Uw zag_|M*i&THVX8DKx>LE{CsgSZKCx@Z`n-Rxo#l`7FH7F`J>smCSfCMO;RhXDlIjC6 zg7zd!VzgErk)a*NKD{w+q*uPsX8JTJqKBhn@#&3gVs83x1wECPf5mTA97{PG@f+p& z5e4p>`#YL$-7OqeIDdzb;(2EZZ~3Ezk2WMGPBoh?|5Hf}+G%_GunhyEa!W>O$mToN zVSB*)W0Vp+K9KDxxXqXM3JYV;TS+LQRT9XAv2RO)_MCZ0rWbY!iK)51j zVEasMq~(JzeC@?zhKSchVEhSQK7$?<`weQ=(mWTX@Mh={=GqAzxB`v}^z2Ti;tvTj z!0E7*7qYt8Jg{!sUw@aCS@RpLBA+27?Ii2UeT94t#&kJ9 z|6ce#Whqau6_cQ2`99weJ0FCQ$NC8)AluhDhN9F*y_&#X&C2D&C1iE+MwRj5qcG%h zLqtDwp;OGTTZ|iJ(uJ)lCMM8ti)%L-DCm0}D?2tb11u=ML^?f72K|&fGT)7S5_POW zJ;8HaA7_JJmrlLn(>ZR}{2oFqScFGGgsf1egzZRIFE%De!k}h9Sv*Zd*A}Pp=FKj9 zLco*Un>GxBKnBJy^<9Nhfwt}0?}xUR5W%`U_378eL_Lno-MFlv@~&BrZWYiP-8N7U zyk0DNPynE88lgmn1+KhsujnS_2{U45Rnew*D#uNsxeAGU52an`6EHukAN_j3Fzsoa zNH!#G5#kq;i!vuVWjW0q==AN(9hTXoM_jt_ui8pPFD(OzVm>~Ns5SBlA}8Qf&!SY> zqML;`jHV9~)a7M53$`C@943Ub#L3GS((v(}rI{MiLYWhiI;|u}u`b*35bja0=-wh8 z?Fzz($z-p)==5~u!S$6ZZG8p`qR&^Z&H4iI#z8XIZhREhxn!!M6(X0>Vs(*7&9p9> z(OH3=;^PxZx_VKuCzjBs+ffJo&Yknzgw-&6Vpdau*`$N=x9CXOk7*oJ)-JB- zE19(}u{pxeK>z`^=Og1lvI2k(u?JX&)HX;qi-WN)`~Id}Z_V)LzB!NhY&ZrI(rAMc zA9^!L#yO#*z?G{lh!FlpeufV0UzuKJ8x!|vqIMbnhdK@sBC6aOxBqk_-2zVS{`ZL!=qJDh(FIN!%H>;fg%UPwN z|2qq+B_Yk=-=MEIAeYdSc|bk+hcg^GF$lR#c}v#Bqho%lbnB^im`#L{uYYVsIC=@E zJ_qQ1yi-Of?_9S<9cJqlP(@+%d4gvcK}WduX`pPOVNA@x7f}1CAi}z(sgu-4WNHUQ z9NQp4rOqigH7XNw2Flu4N4|=BTS=|N)64@Ls_@d(EfwpJZ(I?cB|Oe45Tmz)2PPsO z_Cmwv+(+bB@e+NzmAeDZ+i-56y}~rn=wAf*^LDE%MB|`5;={u9u{=N{qRecaL)KIb zE;D6W?jv71V9qPuu?J;C+iizoDh}JR!btnpbhi&BKZM2o`fn`V%rg1ub5!b!uy<{; z5q}zc%Sj^P4YBT5I*3>j3j_@A*1&wy!RTHrr*3(Y_DG}TiIot>S3y^T#B`Hg&_&+r zy{jG^I+7C2DJ0Ghj)2T7J(bJJc>Icm<=VW!aQe98dE>;BiCy9zoUiV3Td@wyUsXDy zS(*bccsCI0j~AhBFm}{!BETEW#r;dWz5nJ`zH_eSy@l_1<(gPj!Oc|5-TIRjnK%Wu z2_d?LU~@+dGuU8aUsoL$X6ES~JOgXEG6}~{iB_YlI%^KUHe1=OxF@|00g5PJ$*YuD zfPF85?D88_HrR>&byraB65Bvpe$y@RNN1YO|3gQ?>b1FxHy;bKgsZD&z-Gw3T<9 zIE< zUPBvW1LJKs?=L?X?t+dG2hU1J8dZ+r)RA=I1P+t6NXf9vAoI372*-=E?c0-wOORATxIxtV{A!1hPkL@_)2#88we*)Y z2W<}A7XXr`8D99Dic~kqcoKdFY71u~b$1>92C+9P;W6xF2$8ko>@%lSVF4_+&4W9E zk5m|mUq9Z0C5t^MR9l|7o)gdqXY@peq}Q}e?tQu3s9JJV?gVpJ(Dn88k~8wF%Y!Pw znz%v^F4clS7wFp13m;U@z>cNoVJuO^{?9-9Vi?Z_)yD+0q2ioHyPJ6kt5LuuSuEil z=`Q?ljf82le$v;*+>KO&--{M8>XCjoLm$V~vrK1JA5Z37THKBxtGg~D_{lej*htSU zxp!(V#=jb$cUv!CxnvzNb7Wg>r#TQC@r9`-7=mdeDJ@3hVpG7vw)Ktt)XewIp@Pgr zO3aLo+}nn8@qN20$}}7&x(Tc4pd)a%FgpPofnE3@%v3SqS>Sx^uhOD0JENs|m#8fd zzZbR0IFoGb+W73gwg~@XvwSa6L7&n@*32QDx;3;PjbS1w;?}tN$7WMf9K}0FI^DfM zaVmWa)>lAB#@+$?aO6CSpU}D8)yozyKWB9seSZjkBklx}+U1(_X29}Q_jlin3{WKP z63L_g3F;{&3kWmcFWnxma&FRYr-~7l8XVK-qdg8ZZ;$G9M3lRFJw>6u{%9|jOM1oj zmdZoE0GMdo<2$5~C3J_7=rzL6%o3#Q-rS9m+s)}qTPn|bN=}UAy!vP%BKLEUZrwSW z5~e}3Q1BZxDFLqfQC(gWS>)B8uYN1{XOzWkOW18u$*1n`OFq&7^Lpmw%akmXDq+NU zOvBI%f7+uWC)?9w-6*YbZ zc6go$x&^-W8T89|O9n%Jah(brs7puN_bI8(0+Rs);RjDroR4u#0**zEQ!5ekK{D-@ zOd1`LHaNkeI=0lH%BR8}U)aK-m8Y3tv-$r9SpbXz-%baJuccOK0Zyicr2k=1!QA1+VPq0)JlB-P zPw`7tL)V)M+AF!qXP1LQ;JE?J<&=_Q4INhm(P+D+WYIL`X%oLRf6j|Aeo6Kt84sxr zwl~J!D*&ySN7O9AXLF3u-3F(zro6o@QcGkAv5ET5qjy)!cf~v-wLv^5=p0R=seDXS z?z0@Bc10RIw5B9fgEJ@Ki`mw!viJ18W;SllAM!Kr+YWg2E!tj18`MtTQNMz1gXP+i z?~?qN2G0N(qQ96t!_&LN5zUe$`}P}BK)Jy=$yL_%nTL2Et;}$%Nb(eY*p=`=6PA!! z2Xq-c*rm#qb*ZyifoL4N2V>C_{ACweynfN^W2>|%yTYxqbE#j~c)dkf8P7NI;14%( z3)|6I=Z71gd%42ftYnAIIXhcAUvD@T`ceJFWTwZ|QtV}!q_eC?FpVk*a8p4*rt9&U ze2-M*jK|PH+c1&6Hm)j>qfQX7bfoc$wlXPCW{A5h@gZbmg1~`i zGeI_r%$JRFCb84|s62Q2b^1ruXmiW1V!xL(`nluzFYPmZ%%9`mx=J+fk+hFhkl;l{ z>bbGbgFHO$46@2V)7*ICr`Fcis9gT7&BD@O`RdTFAqHE9h zvi($?hKHXarBF9Bj*mU34+t}DNow7U>aYv3Q^=C{8kWbZBtM|B&$ow&OZSPWDmhIG z(F>j*Ue~<5a;|E0OD6C5#?1Wfp*l~yBXbs}`&UKS+ej4$3WWPjOIwI;fHf*Ss_`^5 zD)Z@hk%s1H|EP%v2`|2!KXUuJl~k7;_I$F{PK<7$vc4zRk3IwL2(R;qg8)U$2B%yo4@1K0LznxnyE zK%V3aI|*Qbw!uUJkJQ4;_&{Zt%7^4RU@0ZyA{5`eD^I#Jj;PAt8gZTLlk0t6PTjxT zW6_qu$;!BJk!ug}5^5*Qf)wZmU_n0ND4k;SwRu&)ndq%kTXLBf5wdNOKi`Lcoh{F7 zhI@9sBsx};7HF2C5pa3DNV?W!SNNLZGq;{TwyPyQI8g6-Ag`jV4-qK_H zl3~=-NE-kSz};v~Ooi1qCYcc{xw#KZzY;S{Mqu$@Biai1>l3kl9@VgGYQ`C)!2|c* zj3tIBh~P%WT7WF6$;Q^kOLNgwMtXdBcABB)PJ+16l?uVi*xiNL4^f~Mv@Ja%;>BPG zblQx-4F@|F*}hMq3NhT4ad!I3Cbillo*cN4$fjW|c?_O%oIg*BJidW!=tv>XBJ&&F zKRI$#w&+N|(0hO0wC&W>h^XxrUZd%&3pSfLfK5Z2RR_l0+m?THvV!8cpq zMC|tMx=aqm(g~OdaCyl~vdxkokb=(h6JY??Ei&9R>Z+1%{ApaY+$XVXmjqrSS00|1 z{F->dNa(o(-4bqkb@8wb9*!&9?(&yiWXM}H3W5}1_+|Gj)kcMdVdLfl4>RSHHSGkss`Kl9bRK0J4)4sd8(2AXk zUj+OV!KceKK*9-HgAK35&^V++$~+dKCo?GL+szt)R79+{9$9=jo4xDkEsF#QeCg-l zSE=FVe>i&lF&KlXzy${=S~yIP7x{c5p*yk!eOwjHhui|XKTStoP2komE{u;)v*h*a zlJ_uQ^x52991Kfu&c_#&oN>;&LR2nus(!d$zrQ{G(_qX5o51{8TTmaNUj5l?T@Qpx zIA%f8wHp5hmEMCo54qYRIL$V~4ZI_D=g!vE1+m23zTdN?1;0nD`n^%0!`PSyL(q0s zb+hQ%;3|BK`Kg5mML>cc(d?Cs{zI{7(pN`Z`C+>s+tfE|pFuy`Z&B1$sU}0COyHY% zt@As{Hwn^yOUX6ob_mKbAp_SSFDn4P?K{v}_=*Wc^`D?dIUSXvR4I87 zK8JMUmk?kKR2$Ioa^~dQ*6>CklteGOT+$Yg4ApoYbK${PbJZ`3DFo`|4WVn~2lWK8 zX5Qly0DfWOcFh_y~m zTH%Sgy-TsP7w&bPL_PqOxKUn1pU!dOVVh-(r!uetq=xnv=m^)JaV0s9>6ip-q}J$= zoP%xDP4Jtj0+HBc3sz&X^yA(xXoXKZf(OY7%v20r_S9o&Nl(Fib}V(=InRDMo%g07 zP*E@y@Rc2!+{LPj(XAq8mUM>1a zROfZ4k%jCMU@9TfO3;dqcLvTslGw}iyS{61#_5Rd==sOm5D@L$VE!uEIKrHjGWN3r z(IK&A5IosgX%HE5iLoPCubn=P`L9%uUmLIj4mdug?Ktil+k#vcw5;rtK0H3t%lm~z|t(1*K+vW^4UR(gEz?0hNRrPM`0W=#A$J}ne^Vb|5 z)>fZ$4eCyR=?az;bx<8tD)S)BbY^ihQsVy_?oF96C@mZH8hH_Yv93G|->I_3Ui0QN zue4j)=Y){NTG5j6Lu{FS@Vl`pzj#{GL}xSI(EicWBcDHtwMKo2LMJTK2$Y;&EB9i~*c|&}D)#szQT_<4WwC=Y}v5;gk2}k}%)WNjPmv z6q(a?33`fro2VVDVT#}8MBXL3ZEh>Y>Rf*Ixt5!1`tp1StZ#<4K}Ea=U+Hxf>b9G2 znJIj@SX#RK^!RGljHv>nk5HDiPTTfO>?H4idnf1_uBMv|giVa-&0M6%k%b|GdB%i< zp)#{CO`*&=EG$Xd?ffQ$Jr5NA!sQ6(h6w#yfcd%C9%wV!zctRYAoGwyiK*>skc?wp z+?xBkFfUOk1^cn)<*frkg^8%IqX#8h((Xj&A^l834GXKsIN1`wcY}!ti(!=AudISa z7>|Q{BlnB%phZFR5O;9}iW&qS)XCbmg zLyLv?Q9^dZl&i<3IF`EUjEI4G^4lZHa1{FzHe~)J5l<3mrwWkv34Wca(^>i62Wr9X z1#j|r>tdHXURm0_Y(evr*U5?P2nHLjU_Y#~L;S`s3`i z=M}B)yH{Q(>MER^?RGyZy{*!YnKEvhXCUs^bb{HD1szsU;_VECcrnHYFSSyCOs_zj zc*-;GYO?8QvbRW5(g;u>c>llI?D$WcD*x;V{IetQ&yK)9I|Bdg2>i1n@XwCG|2Z9j zj}xbZf7FLK)!gXnG5=~`>qvitw_0e2kCpa>MP|o>jna{-4bXGtBA=vW$~K9I6Z>joTqId zyj)X!@zPAM`eRLTJm*W3{=O1!mQbIthp%TDu9hg(Qi7?%LXf34HTzhyO)iP`r*R_u zl)2NfXJ)x-Lbjnpx?$*>#2&Hbx2qc87&6IL4TMa9Qxm8Q&fsWLA^cgtK>`0-sJ3Q% zC6pt*^3w16<<5%3NY&!_1?7 zO|NOGwRlbx4FMdfgbw?aeSAL^-1en$PI&ywlJ2wjTtQEgtyFCI&x78XvegyA7ytxy zcQ8=#-^Euad-+~2BC@P?KZx>)EqYJw3(XGiF^skJk>WU>+j;TW{Tgt?zQWm1JUId5 z=*qaG-J)$A-y;H>;DPKUiAn<)d!YR@RM*)*`1F zv%V@v$8PCV{Oz14EM33`;FsGkWG|Kw{RsYqN4bv{CI@t%_+2r-OCrNf2quGAC zO|R-jFIImz`61VQveNx#^Ax>%=lMprd)KcCJa`A7tpgz^M-Wt%b!-h!m%rcQ0U!-O z-5pj`>90JcG;>sx-VMu|)XmU}Q-zq$^1O<1m5EJiy6OEvxK3bp#*^8>m|fx3hj?r(M9#$0O1B8PE(8(yVT#-5k(cQr&XGZ&ngfg7!(OIHJ*`WxRPl*BX|FM^q!D|9pu!#!i~UC+Hcf-ZU# zi}#5XU+sJJjPWYXHISb@FjHec7%Hql(r0M3)YURaoHFU=Fum!}#UwqCqdfX{2P^;e zWm4^`yJWytA0xvT{_x1+&E_j)nPIduTWB@ERbzTkC7~orL;cg;U+9C76Q|q=0qe-& zd!T%kH{n-j0y`WmNo;t4|4W^$OKKpTLE7L$u~AmTIIC~lvMIPnrL`3i9j@=|XiX9< zpPQ++Ev*PVtAkW6%1$w~)q`UzVL||sAnOt-)PK>(Q4!gwoi8)j6jLa-dR|Qb=W7Kais>5F_dcezr9flPVro z+ZTJqj4SJ^AYe=Yrf_g5gr|=1plSk853V%N&$2HYKXd!Cjls|L{Md|Jeh)leKCCq! zPhLSY1e_*GEg(X|gUM`X0Vqi#-n_RUqwe0JPpZdKe{JpSB~*MR?X#005J+n5|2Mks z|0U#Y_-`R^M;Lr}=XM-~iTGeB0F2_RC(b7yAXlhE;U?8(aHVLs=&vctF2yth&pfuS zf$!>F3kX-|;@$_U*Z<2%bV)K9 zL2U2L;-?e@2%%1uRhB1RCwix$eSI_%z9KkuP4W`x?~AdAN{G+$D_07}i zyV&Tph;FjMvhK2E@rRD#Mkc6~o#r0F?uP#3y4k#y4gHN~72?M#w&#hT!8PnpxLs zIIxk8G6;m1Kj)2Jey!&b#N=1x!n-|nbh+B#Sh53-SDX|uf2tOK>SKzmaR=G~cCI)k zD4=(FYhcr(Apm(?d-REJ?B5GkeOiAJTAlE5ZoMX_sOhx<>>Qn#n3HJ)T?uWVtiBa& z0UJHfL&6X~k&Q|I3YKfe+=q!XnjL!V_LYvJk9#H`t8{6*#3i=}d_nV~^oUCfcHu;o z&VzC97a#Lv@H$*kIzqMFha5K=j7-HJ-_4f1?7J2Rpx|TwpKr>w)4hoSo&ZLuQD*5h zqH{dOCMX-UR)t~^jVTe>kF>cNV0kU^&CfKre?o_GCo@e6+Y|MD-2TQ1O58gIH*e+ zB*a?~eF-O{mt;gA5JluKj%;mPr>uI4R7S6S;H;xf>>ULGL!VhtH@5fyE|LzxqXWX1 zYN=)%-3ZH_?aaBA(D(FpV%gm<(hdg7TF9v=5m?aGSi95=f3uWi>2-%aau3J+u>0-R_YZb5}CJH+=M7`Z0K_pI}xHS z!Z#)D#(igoUZ@xt`-@?}CLF5eTl11Ci&7(YN0W6Hpr@Q^k1Yu?(cJki=h78iewsew zN#FvD!)EQ)w4j;w38cGWYt_|9D3*yTn)?_#9^qU^%59c4(Qx5b8#{<650S6ukzhs>h0<01CJLq|06Njf=bT4v`>89wpw@wJ?je>+0rmky&>b|c;*j!U z9(}QZ?jW90zXxVUIRS*w=x~Utxh1FW_|HPJJKKSu&dN=MJA&`V$czFx+k$0=1D1c} zK(~J*Y6${H-5{VlVT>lo-4?UJF;5v-|FLFI+?|(M_lj6BVxQyW&dcw7xZ1OHvO(=1 z5{&z|XZ61*9`}za(*4^v_~*a>@4I*ScrpNOJ_fMJ*@n|O=w;2%=ovN&ZSgmT-y?tP zoB{RRf`jwFL5bModjO7Z%i}jFm&b4fg(8Iz1pq=ap+@NOMbw1}aYB;i+Ex)M#G)EB zH#1ovFU>vG)$M&Ml|l5y!on4e!A&hd0}g$V2Ed1>fdQnum-8X&klbTvz!*Q)uz-qpg(WrN3x|Lwj6a6JJB8G!Oz~p%+SKt*g3IoJAPX^B*wY%x?5;a7+ z#qqlQ)*NReE744Z+m_j?yKaVG@Y%oyldTrP6Y4dzS>eQ8JX;3>O!fvKTAR6q4;HCy zdOms8HTcudC+U0;PYQcV&Fuw*CNEnE#%_cJwvf8J5mCatlv+3PMFP*l*`$qW`-7tD z0=&@U85l73W8=ch+j~EvUuyg7CkqbR)fA~8U?GeF1%q{pWjOfIH2l*GOuPO|0O<$M z8sN^dRC@XTPVrN%<&KWqOmwk@s!yU?cmp7?D0G>9gofx+5RucI6hzpt7+n}vDv)I# zsHyQT@~F-2Z)#<}W?AUHz)G7r6LdcVk0s|kTnMKGdSNJ%jQ~B`7V%~~xaKr26hcd; zP4xG>M;)skL)|PCaNH`~ly(iF<8vJ1%k@7pJQaM+VyDwz>d*fnn=Mn8O5?P8uvtYy z1Hd`mPMkwVn@oTbrY7&;z3jA5KwpKg$r<;|29Koi=2BAZ5{%aoUIT6Y4Wc!1_IX7@ z8xDEtAf%1#xDVyq5uW-hF+zo3_3wy1yH+AWnppzo;!#YLCJhxL4DV9DX2m&%0!w5H zXN#mDQ`Rnvd&fi=&c0G!K5=zEhnNVEZ$>F-y93B>R+4}_)0|*4c$^P*FRZ#jvXh|U7f)o?#c!SG^0}3iY&3m)fo3CN+#yS$3?T4# z0n?K<77KZJ&};#mQ}eybB}$F2ZhC#PdUu_L_CeuYj7+mF1M?B@v1YuM3b2 zfI0mz;6XJzC$j4&$to$^cGm9a^jpqj@^=ev|wvksdE*|Y(g7^fQ% zM8pA{geBmas!kA$YP`I7f-vT+lE?f>Vtn&HjJky;nhK3UqsI1x_=7nTpnbl0Ii9ew z^u!6!z&#A!6d86?bW}W!ynH!pK#m3G))OYNJN77e;DfuyX8W1b@$@KvQsfeB>fzEQ zWZTj)6c24Xw)DB%+*l5|ztJG&%VMnW($zJaWBW|yqAlpmN%Uq3p?%4|19JiiU4nLC zV$_!N{XgmnH7D3gX}gJJF}Z$7U!00eU*_fv_fmgKM;|Da4AeRX`5oZp@muUf8~o`I zhhKx8=9Xwnevj3@%F2{+pHMM++G13gi;#IkgMBAda?7C<%ubeq}r z@quvD2E^&^h1@2~t9>|I9;jn1YKYr=yH(U%A&R&}l>Lq-E6VVC2rmndMJ z5bZkUT|KV8FG7{VCUT;$?nU$_DhETJ$ScgQ;dYWVA5uGjrsQz#C%Qfe?;1u3qDi&^ zbegA{FF8?p2|i&Ox*8W0k6;J!{nEW$O^rKl+y_4*bQtB$5!T)L0et|CZSn$q`V?Me zb8rdAGmpdrinH7=s?+kFZZ+XMuf+Go>F`#l{8|Ak2y%%%mZODEVyU{YEd`(?7lQuL zn{FQ{s)Jt00dqZh<)%%D4Fqrtl=$<3SD`P}d^&%FW*|!)Xg)-M0f2<=gDZom-NNPrKq1&WblTSCi?L<_4 zP`A}}-_a5)3v~=;;Zr$t=X+xXAWE~cFgV$(OXS(yIKHll3orJRG5EQv-Fz$r=n&+F zw%e(^Y)KUA4AKaMJkAL-qWegRB|n9TP{)8miiTaf<$OGAzHow=^2E>^HzRfdX_%P( zYx!(b+N_|!FB{juekiBWs?!1sxOk;l3B0cQh7jGOEGGPuL@jBfa(T5Y%kri4?B6C< z0i5(ps89eDb&`PVfO3*Oh?iqOTTSj`>9d$mZ|bX0s&$Nq#RZ9;gkiw&(AcMBvUU0kQ)~ zJD;J)00FSmt+IyZx@zW~IV)J0-Br(PpU(*8EO6+)nAprEhICC|P9kQsf_baHpzizz zokXmHr;|W#Zi@RlZ~54aT&p)*%4Bw3#TFjH33{?e=d8{F&bs0e02LQ0hAl9JejpBZ z!e`i8AKp11-@!Aqwz>TDLPOXu4)?7FMVby0gXJiI6qGr+w8REr_4t<-L&r;$T8B`Y zA2yW7AFNdmscW$}MNXM5;vAm~IEn-_vT*sK!vH$w)GZ}IG^L?E;4-o6T`nngu{iJ3 zLJ66Gb7o?H=4AX)OXIamNBNc1A-< z?AsT2-<*EW`br5ShrD2uC;mqUJd!k}yqt+Vdu$7K84FN$vJg=j2m==TFl;Uww`9*% zh4~@3zz;U=G_=Al+6Qo{Hl_qfC_;(+LJt$3u{l_4 zsjj2BoE`o;&F4gn{3or4Z!3&HQro~>8UoPGH6H@2d1MjFXHiP%x$8~6tZ9e&3wopB zzTVftCp`@sY5=VjI;9POpHuHQ=tvR4gr4Zi10u1J2MkKbiqvm_WDbyt;UdkpJC9e1 ztoeixgE7yeVv=bIzBJZKq(6Q??kGm?i>Tw1`#LQoYe3cuv&Ft(4v!O7B6Z@S4C~a3 zu_cw6ve#?5e_3Ii&q3`>f$fwU^O1KYIYNY6Cdndob7p~ult;iGyd(Lr6wm{}YP%*! zmiWYk{`1te8No%k;rZZG=f;D;>NG*2zWZ^_{s0HKkyIi#)x91^>bQ&T5CY4Du_p$qH<=U3n z!b~j%1h#_#k`JL_@pdG^4?mVw@SdziaIJUfeA`Hti<9t3{{h?-z!=55b zmXLT_)Kp$J>f?1WJnQK$w^jGetEMzZcUyRb^;2vDnZ%YSS}+6{j{Evzd?1dy z(H$SHvw3c~(X!Y{rMi`+qOEjlx>wrotZgPp6!cun|A^3ymkn>0L9%I+Vi&bCJ&T@; zD$=ftva}jbu`Yk~?SjS&1MKUV;)Qj)%WDTbrK4MY?Y0)mBGZ zE40@=rM(u&aH%~^RD?qzY7?Nm^1}nl7{_`W*h(JYh%#CM8n438>NriO_InRywsJ?W z1?<}x+DYy3oDRmS7jPXh5eKJrJE6jW9WfsaPj3z-pyA(8;A~RZ_zi^1itAlRi{-xY zp#nwap{_i6Vhc{d_i2mXYCD#zn1B*`|lMiqXE&&`6>m1z>O?`o^ww3hyiZL#%~ zdbJsI(YvRILLw8s={l}>&udB;6sZXsp9LwI*7BW6d9rbU~#1rHEY z*<&Jb0BNxeU^>^Kw18mMO*m0ztBTU!neyTNwOuou@60nZ!`M__i*5Q+3#r?=dfvG5 zpW(zYv@0S4%np%9N^G|m;rmaRN`x`lj=X>D(x;4bd0_2ma2A*#R=FJ8vUW{Kb~E8O zNHX$R_cv&QQbLW*il3pQ-d{QZ-nF!7f*My<>}-;Rx8KY>HXQpEKoQfNij^j z<|eL-bF_`j!s#u{Z{tK_?(nEv1h;216#iLZQaG3xszhQP;ZuQHHngX<*=H#q*e~o> zc;(smEqz~qv6^4Jy3`%qXj>RCgQIIS&H15nSTCI9>RnSqFzXf1EwctViNEHGN{R|Z z3~1R?=d;*}1R}H*8*9Z(SiwbQlMqEZK4=>o0o~_AIx>aXpGE9usBvFSXwR04*(lQd z5_hzBw028LIU`=lzUt!s=@y^rRsBHu`%eA7t8G}Y4iVZ??-RT74d6u;((VUFas!Ik zrnb+s-yI4O^A8D~Z~lCJ+baCeGBr@h1&fpINDA@LPU@MEv7cLK0wnIVkf*dqeE_bZ z5DmUU(voSs_R!LvqpkJMjm)X|9U&b&_Wp+G-EW?*r?#Mz2kz7U#Wxk3&6fOX4AAY_ z*({ioNQK)3`Ozr(F_Qsp4`kV5klyWN1MYMQ=)N1rSaCwrS2PN{jdoAoAvgq_~k z(H+hPHM0C#t&uD$Q715vQKSlLyASx61V3m4U0H(u!3mU4`~ez1;xzQn<;43y`M-JT zZO@ruThMP1WCZgi z^x1c>72VSYg(~{4FLGW?on*}t6}k}$4|^CD5SugjDCr=RAoirEI`cC+?;{E8#Q>9g zzj{bU^*FyH)^sBPu)47xXNQ{G`x^Nh74a8Eye2zed#tiP|A^HnxqRLPrMCG@q*@(& zs&3|p{{S5Xlup_`cnv_}=NPDUvuMSV=cMbt;+Ph5$!h}7n(q6vl&2rNU^R;# z$`^!vnh>+`$F@`LS4=0EtR0q-D6g}WoE7`xl*4*d!lomBQA~LxR)QciEcH>Z`PrPp z@p}4p$Cj;Zl5(<71K+|-7SQJ^_WshBZuD|Bc)S$T$2>}8BctOF8FrQSs^mCyDpBHs zf&x*gkNW*ApW}CA)r2l>IEp!M~jky@!T zUo-I6=egYA3tfhT_j?|F>$q!Nhip76veY}4W@II2ejJ_QpqSA=9#wp#ykH%68|rVy zolc(Uc^Y!JzU{4GRqZPEH^^9g5p%@6f!S}9rO;29FJejisTANyN;w$1cKT>5S!g`6 zo49RG9c}{ZlX`=ceFQKn~`weI-g#i)?18Q99C+rk(3>QHW$ISq;S9JXMclm zM!!L1?>fwxSDvS@#2+2iP`&2IwB{tV?o8|Pfv1hJdrWlzElv?|i?{w}={)r4dcvQd zYv|AAc+**$zFB0PuJg<4)VmEh3AqqN+D+x3K%WP->{jNQRrvo?5}MehmM5?8W`)0%W?QVnNX-`Df#Q zef%`%Vhby+)tkANu3~X!v!1^})no5I11CofedtsS7z`02b@!nMirjqTQEIuz22h;) zJKxOgZI|{z@-~Im4ZU4}rmC9`lTz+fUGS_Pb|bG1W%}04gl;6Q8tYqxZj7OtC{VS>reuLnc`xWWGmUJ3+77Y5UioUe$ zQT9Cz2T3JVt$uB4Z50S|`pG(;@u%ppTlLp>VP?3i)<9Sw^BK@9{KX`?I?>AJVw>iG#X_mus=8}9awPonl9c6q(`;SC@`}2@Fi*~|^#P{e4 zGQ4NP^4j0s%>RWecvl}bf|7D(iPSMH&@A}O504Zd+k7(W>X}h?c50?}CZts~VO1Dk zZ+slPr}UL1K-Ds5%fbAT`$T7aeD{zVqyjzsm%se*$ExVX1v=3`Uio{!&zo}6^YPJq zTb*R~g;x9AF3Zy^uWETGw@0i8+(&IRtEA2YOy~3HiK-fExF*PXNb4*}|3COK`XdZ5 zSA4)Ao=Ylq+JSf5=2|A@Stu3OU7`8@S?U$kzek{Du-Nap7xqt2zuZUzRGA>sK)^ZI zO^pm|;`ncQ#u@L)Hs<99y}7gq-} z{7>sgfJb3>%mM&kJ8>I0gPs0BYv1#rWq_AID$z;igV`B_^k@z94tOOxm@Hdp;b2yr zksn>)n$f0BCp!?-#fJIO)^xtTJEcDPrCMxrn(qpX|87R}0B=E{wy=){KY5b`l+(mj zOaUDdG(|G5d76MBM@_3xe1Z7qM^#Dq7g7kjIhR%P?IP~2ybSFP4c*pS>C5n!9R{%U z_j(%?pqTVpPaYS>y*q z9ot0hFiqXOkY)ci2s)lr!)rG7h0+4hR_N1d3h!ZYu`v^2fGS?tBRp`!e)a>vow{9!TwU;bxA z#;Z5Fr}2kpK~YXZuX!BiU>Z|2-zwCY;sQ%MAAEI*?>A`U;oiUcXG|4k^^)YIS-iU0 zY=0e5zbY*oa0aur1E;a>n4>LyFmLu%HN(|a1dcjm6i_yh3-Q<9;lfDHgRhed;*j|t zhc)k-`m6BPOm2-BHY9Ls4HuWTu^1>|%SM$du3yJh)=O)%{{BxqJBs;bY;n^50mNzgbSaPm}n=CjPMi3?zf5 zSMi`NPo{FRSIyPMqCxnvtD&sMqj96sdA*!@-k(CE76Q+oQ-CPw8xZ9n*$%t+i_snM z9nc7LA+|pxX6!9EX#~e~0-JY|g-&Dw%uNRSmnfNtL6zW(D81 z{syIVRighq)!hUbIanOn7iqwl`lXE@|DbjH_tq;lwWC}s|7LDA@RaVyZ1gwJ%35I5 z{@vRGF;c<^@VH7Gv2Xsnyu}P9z+Y{^c_UV0yEjNzdAvrFZsc%Gn!j_bgMC&zXU;YG zCQDWGvvyIc`9lq3K)mq#2Y^w;tkSdO=)@FYO?ZJ@zzcEcIZLnof>8lTQNaY z_W-`>e6ua70B~%1i#UoWg2M8E0nn<%&&MwGs~)TCM z@@dEZVsOG9erKaXEhe^rgK^%Ng6#&fn>>cz|@`b1I?dJwQ5d)R=BcM* z;GBQ`l2{zQ1#9Ec*mnpNz-@5;2DzE}P-DUxVF$hm5Akv7d8!Fd_4m8bxjs)wNt`UM z%uL0m7jXBtK78p#3eXr|>clT&G*uLw^4Q^(&vV+s;mgSvv>wE2H2}q9m~HVaeEO!}$t16p zfuIQ)H&))K_J?C)zxw+tWi-+a$4589{>lTYdI3Aqof7&RWCg%B?r!d*k4_Hmf%^=P z;n6+)*2B*lWQ;zu` z?7ewBlx_PzK5ms3OhmFyMO2a^OSTC~QdDAOo20TOvSbODB5M*tsVqr$BWuqG9ey8WTyBqH3xu50x{k>kl-{~ZJriw#lXhCi*F z{|&#IuPkxk>WL2t(IXngqDO*ucG@(48DXRuvWlv&bBoggNuZUl{J)}= z9m@`ciJjvY%LVJ;Xl`P&h<6)WZU?sp(ZAq3F#9I0(D75e8zyW=JrSYIB4<~{D_ZT8?O(QaFFg*X2@v851iPWqwE|vxJB2gEC7VhIV4Grnj3!#tE z^~p>WzqR9YjZJ}ci2mP^4%X`{mD!e!x5LL|ZuET4de?f#^VV2wpW{I828XeB7DaKl zh+SBPam`8L{zd%!*U|NslWj9Fp)nZ+DDQ7*k#E$1hSU8$dA!ELlWNlUrltD{XNO#C zK4eDF20d8p%eH_8(yJH~p*^;Owg0Ea_}?(siw5DwJ&opKZ#%M2h1R-Pa5D=X-6Abd z{O0h=X-}igxmrJ(>Qd~&l+7wEzVQ?f5hUX{HvGG*PE+knq5^lO>G0HL|eP z0A<0r+V&@s-dfr9Z+pZtrY_dlu&ell`&>*AN%=T!Wpyxa>$UuXJXQW&jqU7OY$}gq zk{An9lHYjki|OngRxzp!OMUuL4jh}PBH4YV4M6sL-6crM;HVVO4m6ji=)~?&5nl@0 zGO;4*v#`1=Pe<>;VU|7eQOs}ME27EMd#)D*n?Ber+hnSH3I@M+_}oEQvUz?%;l-%q zM^fJ&V~zJRZF;=KDa8T1B4KD)FXW)ft`K@7LE_<+m=7XJc%8p9+jrL6*vg-ouM9h( zZom8Dx!6lFzVE;RQ)|pqQe)*3I^WMnK|5W@DmdjG3V`@^>mJCKke{&=|qWHUa~Uiltl+Az48nEkOioawuYC6}Bf19+&*TU;Cdo zA(>El+2JO4>$px*Yy}tkcB{Da^5sdFCD%L61@-&Wm?c7M`4vUu;(iMRLiW%8WlQxf zZ`NLg0U)XAk#)xz;9BiX0aP+1d1o9e?N7h5E;Cm69)bi|gECPy*5NpRNzL+qAuHkV>Ozz@jA zNu_BV?b&p{s+7{4e8q%~1IA{5FPTw<8)F&wTgl8v1Yi$0uCR1NqtC6IMZZ$Ogwr8+ zUW)8L=ele!H60o{83*AqYG^*9Q)H6&;(LI`|9AlWj{*dFCq`0k73PpZ?11~ENORo)6O|l3zaFoCA;Tz4&*pvV;MfcYBb(D53 zz6=#6J7&-k^E3UkfIW_3Nyt_XwQXG(Sbo%Gr1^ISXlx7WT@sHz#&YgR$&Rl zX=Iktj@SW~-Q&}}XA{DHY`d5p!TM0`La@&EuXSmG+xQL)%1bJ?9jH z!!aSws%%U1Te~SU8|L$aURU{5*kFcx4)hX;6)|ox1fr4>wgS~Qn&ur&Q%#^1hfuc{f9?k? z!Kb6(QiH27mDDO~>}C6&;Zf&^krw0b?r(LA0XwF`6wJF7-B?_+p_<3og%;${f^1E} zmy+y5^GIcPP3O&J-ibm0G0f)#KQ~Z4xg@~DYKW{8fRLJ zZI>T$Try_b39xa7TC2V}ZV;|{L8Dh(L}<$F0Q>m@07Fim0*Y_kKlp!#i=?;C*J;~5 zeQ9Z6_QzIXJ*h=Z6z`gq2E9D=GLtrKO1B9)B!KL-vL&T80SDUI6C0nW#(`3<S86sDm7W70Xr4r&k+DHbI2bhLr~H)|^>_lAD@ zNEr8<+~QEumHeX6-P^twZ+&~+rN&@kJM3V3FobSa5PgGo39pQ5wk$GvxvSB^LEb0u z;M1Ego)ci4_o^3hm$M7nZr|8tXnZEDa-zBU<4w~^+@$2Ep37p`i|@}Z(e zDzJ}Z3Qnyq_w++PA~;U*pAUd`NetcD-guiM6t^km(~S2R+Q0w((I=+<$wdy>qDfk2 z!!fqXp$~ErD<*a5yvZl6i&|~b;=8TokWrydw{oHy_kP-IWY*Wvh@U~5k)me0sJvuU zB|1CZKbCEV^U02}aE593Z`yNP%6fig9}4=H3nJczSfyNn z{Wg2wZLxYj1|BE!_8+oLW@=zsT%L5OlN^Tc@-Zm8Kl%}S) zPdn-B%V5@~N&`7kr zMp6*%6puO8y$h22imdYV1(V|(lG1xgjkD?Q-)dnUqJtYM-AdDs*Xkl}ol@X?zokfS z*a^MN@*U*$kyPf6J$@FW_pV)k)w3cYz6mIf{b=%XEYOdcWrgtde;3GtQU7hzvZ^f6 z)Gx@;O5f407mP<6-wIC%n2dvJ>1<7eTky%N|7A;G(7 z*a6y^qVXYjU9hS=B2=h~5)5-{`xzzv4FMjvUQ;MAdX3WFffDJrUFN*46b6k}XJ&hC z|NV$(2{64+sUy_;hR~=(N9he9Ec;d($&2~Gi5$sr`$GvHd{_L|qUEeg>O&1h82<%# z(1aDPPs+ikX$J@#nB9_bEoPsxN8(jWoV~u9nb|QG>(c9;<}coe3YB>hS*vt^OXgr8 zs&NjZne(RHTq&ykvAc~S3#;B;C7MwFGyUkruLB%0o%m*$vM~3Xa0^|D)03&GpVS)t zA;NQXUA&8VxqU;Ddb)A8mE5bX83!VsT`WmAo9BG0DVeI0{`sDtIV?L6MvN965$Cqt zMLFcbSn0_*yCv;JL1#zQJMDMlT=V0Xtk}l6E_F3z8JA!tG;F*65lQg?g4WcKc5 zMVa92Oc3Q%*@4owT!_VDUe!Fw;G0o>(#rwRfx^1(s92}`!#5h|+1}J9Gi!a>lRX7n zbQ7&qU~*2ifY?hk2;U~S5cv*4tT zfMOXqUl5ZH);jy0~V~NMn7f@fYoX}7% zMIM;-V1N+>!hRqvvQVE~6fwZcH{HPai%w%ZB{-E>ijJIovM7zN%%Ytsqr9hE~DQnnYJ@ zMdY^G9p#fXFssPZu`73;wG#84P(o+#R3KDQ_= zU7?1Ya1dyuDddWim;A`6o9S%S*RqYWh%YGaTc@&{>ZWYr#0&)6Q`_*);;!qm}^=zV8aZF$WWNp_OR3{%$uDMJ0rPF zy9qX+iY9G@$k~P4oY<{u?VkO$>!PsaL$?QG#+?#HemYM-vJs{7SuzdzzQ01td8!`( zu1PIE9XymM1b9urdw974`E`o1_`9*cIliQ0VBng?C*3b+9{8d@!Z=Fn<+Eofon6(? z96Sg30JTVvM4a%ORlrCGC28)fS8OjtUbg=5?sa?Jv)hvZDmwl*HE1sYCh#Za$Jk@t zpY$q)3p*!giCvhIndije)&rxz6c0UYX}p>KuTwm7^Z@08>N?Hm`G<-*xh@H}cB?q3 z7OF%nbRej|718-w^g7L{eO>l!BKfD6#eiAj9kq@i$WQ|F5gJIM8R`l+sfwoapQV?n6x8#J}WoFjN7?$Ir9v^FYi}dmr2&A3@%1 z90y-*|A!dc&$0b~^7Y?4>HcpX*>6u$%&7_Qdd61rJyN;%cgqEjqmaJVUpyJDE{&s^ z6!!XToz9(-?$avqXL~Mwd#9nX%C}LoXc-m7G@R&{1;{6crG3G=!2+_dj4}vkDv`UD z;*oTqKGg?)-U0KV%QuWAI;>fQzUps{z98L` zSZLF8oVWY=s}aB8wm}=`7|GeM;r*qoV%LA1blPTn^GRp<*M`$kTsX z`18(Zh2FcDbTob(9SqN@U)JKI6qkji+rTTx>V9I=Hjai+PL;Eh!~KK3jYR{~deP?f zVP{6J+sSf&gk95XfZdHH1dbz!2%uuiq9~$%1a-d|k)m1U3)jcRqAQc{Wi1DO+Iqxg z+%0;sXoCTZV?mS7hy%4O58@IJaT18nwpOjes7$uVnM6>?Ix!l@s!C%6m;{-%7wXfb zaX{tAfs?P?)wnVSl^2l@GCl=0YS%#+?p+YO1UMXz+#Oc5#&X6wzxsxoG`N^_+B3gp z{~-7v#9*ak1ElihE4F(A@tX^v82dR**k~Oc??@eqmy~s6*;yUKg{+9*4P@5|lQ6hP6+F zW8?-(6yoi#biDXoPA9BwJak&BdAu#7UQyx6nk2AN900>`xm01X{u8jGr-Fr0xJ@>&Ix|(A}G-kY~jI|Y270mUt3ir#)o@ap12Yi0(f6r7Uz*B~XTQNHoS67OCx(Vd`LgpVKXfoptO+v$Rs||c{xc|#>aBAx+ z%<;bN%(}(ZTYNqqZT!NX!}a*WAW6;Qx2?59JipZlZcyIq24?I5(7koaJ7nE8QD`-o ztMDUy@eiESFi%f^*AXLM6HhG)>gtB8t_RdkZFoE&*6XZ+6W|Ck+)XqxG*-Vw?wNJH z`*8c<57{5N>&A9@96QGgeB^wH%$=ERRZ_j zrLN9W<3DGe*AMF-%>;U3cs^gq%+IC| z{F}Cv{sbD)c#0ntqF7CV>|iK4|(cllrDyQ+}xZ#$MQzAH~_ED zelVc7$~Y)Bbm!B4fsj-F?h;u9czLOaW(hccankqYB2lH;K>c0k1K#fZt!VjM?RDH{ zy#&*XK({!1sJha7^Fk7pmjE{ z!bZsO6>{vLE`;YMEKwP+fubwxHOEa7KSSQ~3A9zkQ4aowTkHU41la`jwXMQPtOCq%ckI-EaR4RN!&U6BQ4s0gUecFR8s1JBs+hNPNfhX~%;?>%ItyK=Yz>(pF z5pAiI!4#(7$^yzd935`jZ6+ed`)CH3_$ACMC6Q6GF18xsx0q$1#wU5UtNF@i+gDEf z6QkBMEH3U>jo6Y-Z3Sq^a5NcV^NOh?;4k=Vprpx4n3!m zs>%BNvjDZqL4d2|kta8)vs#1ghC`CvA=z7%;0P=xJK3>TDL}c&6@HoY-Njrh>}#*= zRSfI^$8YM93-iJ+(+s~`@|3+6!@d+w;M3rHQS;0!;)ckck-cUG1{6uT$Hz$I$nPm0 zb3xRrRO|Bjw8wR$DOQ6AJ(^xYr|+XC>l|#YXAcY&MeNhg`YI#8nY;Vd4cvDD#aYPj zggiqU8m+$Q!t~Y5RM34g^;lY%tJExdvd7lidG@HNVJ*AMP~IyW&_WV6!Zcj#exDQ!_$GBH=M4b?lrf0$>G}Veu#b>YCkO23VB8%6$fu)CloHJ_+Ts# zCfQ6mJA6A~yT`@Wgs^Me&4Y4lTYq6!sgKQElupELw%X~#iTN3ckohQf?CGq{ZpZ~Y zXvFm?c$a!HX`i)|(@lmLJO?|1#Z-BKjt3_C-y=D$FUQ_G7!0&5`k4 zp|0C_u_JFl)89*S`OjTtKqS!zf)Adk%S|VaL$ZcjPcf@^Bk)(CtP|SHvwk!Z0&gUx zB<$9vw3W8uW=`fyD(nb(;pv=apu>PYH(Em7!pO_Op-KE)eo)fr>HUf!lHzI>Jh z^E#c_BW`f3A<6-K&$flBD~dkMo9qS;D#G*J?$Y)X6}Um`4D^5^l715+=y#SCh9amW z0DOYmh@~m09N13!J>n#A_d&`8kEZoKACr%#=H}e1Xv<5PN%U8{>lWhY$k{RV%#r{* z?|vK6OGfFg+sc`4NRCpiN%SY}rUu+X@PTxg;kRwQa_7ou3l9eP(He5B&o@tUb4mBj zHG6pGWnf6Qr!;q>w7vKT-E;N*BeVQP?0WkZ+*p(6w?bGyX zL4DoEPg^$z;PJ)6p(bhg$|4qT%*19*{{;WRe83Fo?`V{tk~St{kZXFMfw6w%3wkEj ztq-^f<`sQJ7H=9ALqdtCkQ|hat(w8(Fb)shM)WDQHwKVA&F?V|dOi;S2}` zVHE~At81D)738()_366%U1XBEpKggDyco+^96cuXZlnQ?&T#6ZmnUn#QZQ(Rn@I;- z`#%-D3=#A2jKKY0Qqg(z0}YwVK+vE#U;|QUQYiZM1%54^M0-iN>!yz}RTOz~HL;u8 zE?Z#-4dhI65`ra~ct??KHK3LbwUW^L-dzWefn)rfwiwI7PKz~0V(|t%XkT{)$9fgV zQ3=x!+EvVwdLHaH2yZM_Ve?cbDph+nfquP!*wd}e8shMO-<}OQ@|&IIp_N@A!njRq zW>71!24hHibI4gcs>IvGT8i;b`lwRRio93i9(xWpO=k!HgXP|gp_}9r5Z!~PV%J&E zbwl7_Shv0Gl@4{B@Ky>48f6t9Xq-Vct>84)%}u6el>l$kZV#?Iv9JcpDcYp{F^^=% z0{Wfy|GHqWf6|%`M~`6IYbI4796A&$_fvkU3~q~R{?i(&Xzx_A~y41Zd_5n zP~5t7d*zCR+6O1zwoHf__nfk&EupQ9)kS$LTaRU@P1$SLor1?qzU|E~di1`j=ENPJ z`>0IRi*bMtD)6R_2k=}-P5WWy*}kLLmV@D| z3;QBSdJM=^Htv2~TXZwMgGb|47Tbwvn_^A-uA=KQg~v}e!g*7wJ7=VjJY!^Zkitp4 zYg6ce1%{_{=-Axj`~bc;Y}Ao3;V`utLgh%31{S}W@RR^B5!S(0v!IFsY5Q*)mfRNY z86$ccDEpnub`6CUcZ4FpPFqGddbtZ?NOy4M2Kh0@p6A!oP5mSrM zP&3T(4X|8XVm1=h*nDk7EJu8v_@i zu~9WO<+>wVPDUHbH|B2L!Tq*JGNW@SM5^2OQU22;3-aMY_%@ea6-wJQv6DMoH&?$I z&`e8v@uRr7&}om3R8>Wt=p(+pxAv)ThYf)I$siKX*@hfNdi)K3MZx<+;Jw;T0}rFC z+6<)vE>4-wb3WXuKijq-HRUwo2?J^5*xk^#JAHEo1z{+&P<#F6U) z(&fGZlMn~nS?7;>m=vS)^R?eXYx@vim8Lvw(GnD(xx-osGqfTZn(skI(ZLySCwXYI zk%_X0Qwrkb4VQ0sw)^jpe0A@l#07VLbSjz)JEYQqohY9NqXnmA;RY2TWpys@qsg>5 zf&N>f-Ha%z+f^A#mEWqzq56r!uele)aU4l@t z6jZyE=~)B^`LSQkAV^;%JzK;Z5XpmYN*}3#6s6#Vwtvjga6EBwXSb0_Z&3Ta&!1Cr z_g!jA`RZU&h#wW4m3bhHa}0Ky)0Fqml-@Nra%)K87*n#;JbLO7NXr+&sSY1#mS`1s z0PlwFvIJEg7=Zpb=B==5ztD}eM6rlWCpv+TR$(tFjg{X81Y(r^YLpG^BVvk>K}mJD zEeu`mEv+EWfR-u^eZaF`sR&8pjK{d#TM$1b*U;te9R@&Jz1;J=&Nsgz8tR7(>n&J6 zD~XZ!*{s`S=SDb-d7FtVDQWa8SEEfWOO&n0Vl(ff2-MNGTTEArWor)j#VumAerH$x zkP2DeroDiu)y2^`&S%BW>{_OovpVjmsOvb2Jx=57uE$f)u!uK?7?m);PeDkK;a2sT za{y5EqZNwJaqdCAEoE`wOy_zr?8(zCCECblm93f4Gc$*;+w9uEg1Wy!P^t+K#c#ou zJOZDS8si+j%bVm)uDOWYV91PTx;*^jCyo*M6UBkA zxr91LbvFFc%({%XOq1{Ro!PRxpG;?_OhWZNfx!NqP$EU!Vge_V zLfzmE@|0m_^^E9V2Xv*l{NkR>9L5t84)^gcY(NSBocKtU{Fh(ZkQ?EwBgs@aH@>Ys zQ{aLf+hxZM0gTu%Q&N+jK)jOF{S*b{E; z0lNSyQcPW1{~U2G*dZ=}2^*OKXhaRY;SL)y-9P7;Hpo;UZ0^F23g?E*iO*8SpYrt! znM)|INtrq8fMSdFO&B)JGT^-`(TeBwKC%K75enOmiOHde~dy$j&T^KLl9t3YSgdd|~I zGqA*yql_j$xzu#8tN<7Nm6cDWuZV#`s~!KNQ}lnL7;InjxQe$?0&?xssCr_1go@6e0142_w`^rr<}TjGloX zfe5Bm#1h&3Xe$Xcb_^@jbOy09?IklYfbPrz1xRqJYYz$s5h7e1Xb+D8%*AQ8W=bYy ztSGYhjbcQ(R*Lz(Z91>mWid!;RCznP2MV1eOyryTL46@VLPD*-E1iUDhpks(*XrR^ z_Pjx-2`o=$#x%v)MlAi|6&{?J9C+9N`?r5HURuADGmjwVHym6Aj- zmXYY~3&GOUPJw!5AD-Qj{ru7bw;V{2mC&npf3YYaQoW3(mhNI47$fl_NN9~ zeLvit{g&0U2Y8+xR_iY{5=!9971?V?$RO! zBpUaA$tQ@%@OMvLdM?B$G7(7SvE^`a%7t-V@(E`jWEE9nh#=%CqKAS@Op|Vx2aSEk zp5*of;M>#8LSPN{18Z=>Nq`45ltQ*c8LoYP3Yu}k85vu!{&##8+3ww*sbBVVQsz9KYqYXkRE=JIYzRx-CavJaNj_sqVq*w_YB24Ii-4eg;0DShop0i(xHhw7E zj_k8{*)Q$>_Cko$ja!x4Yw|AU>7Xbt;P(t#?ZFm7n?!em!pM4Gh^(%MEz4 z3cOPod7qkkMd^@R$XcFOz6V89Ze4{H>pK3u=sJ4=3Dz@RM7M395N0Vg^U(kVKj4LG z+sAa5c6g28{~UZ7i*A}e_|H_-MBKpPZ&0n<2#twSr3&22KMH!U?p6{ekZ?!&`pP+G zMw~w+i@R&%2M5JYR>!k8+a42fE_J}?dnVnHA29JJb`Ix(j$;PUM~HYh4@7e)Kjf$& zoVsLGjTx|o=Obl z@ozDbmt*+497kf@#09%Z(q+56i^LT0&pbVzIq<~n~doNd^D`r#?=4fVYk zi6$^j!??U3az1-cWDWk9m_r%UJ=`f?q8<~Scp!f1eMbfz5YavDiKT+vYg2=h?ucA$ zK#5XkLqY~c2re$-ymG4uqBArj3A37Q(^c*q)a?R?=?TsAKqz7>8kB+xF8WMi{bOI| zvp8XY-N&iP?&V($;x8VScOzd=H)$5-CoMxcZ#oJ+0eM+)+dBMxP7jrxr#xk#)-!SN zb8oWf4I-T}jAJ2-GX64%A#_g$>swcwv4H;dwYZjr3rHQDZI|{3%?JE>$u@qn^^K34 zm@fPzf&YV7rq?7!nv(lfP{{@Ad3Le59dPaQLJp2p#V+!G01Z2i;okgFqsp|A7ED$F zWGaO_$cZ(sDOzS-&@DO!=q03CaE(k`O7Qu`ThGdb+|euQP;JIL2sa8YZHwl^eUHt( zVdc(idT6EI)60YT^2jo(+KHsjCBBFo4Vh?U>7yK|=|I0AGEFvurWaaWw9-r@k!J>L z4J|XfIMY6819ozTD`s+_%JazBHNGu@LkOrl&rN2^dcRU$-2NCdUpvo45atP*WO9P7 zvKc)~d#`jKfgpXco1wkNem2k97&~_x0tF<$<_B9?`;kBlCnVfTD+JKR=$Q3~0w9C; z0zkb)6K-_boDwt&`69!R@J49J0c=U2l!@;wNPaYA&5HMObkFa`ul$duiN%1BssNR= z_h^#7%-uy!zhCBy0n|0b5@IGg0p-NbLaj!EO0X2vCYrDtv22sG4ph4d!g0`43K2|T zSP;;YnRei5jBj_3G9u=y0GU|x(<*EZ(T<&Xg{;M;AO^%$LC@x!G{F{7z&x%n7!LfH zA$gkxkU&+*%q~>arm5!l0Y?WAd_`l{NG`=|%JUcn;GzPI)LE)^>?{|>>qH}YIiKd# zZO;C^{D^&*cA>=JEJ)%UjLm{ZF4zEv`Sjn&%ylIdf7BR^P*9)* z>L0ZYy4ouW*nJKl*vS+#pl2P&?I8Xn`qr+G&NjCG!8QNhw{3l0R;K!|=5|d?#y@xB z1z=`Xwcf4OYQ?dxE7rPc@3{qx<@YuL6Td2%=$XaV9qY@s{wNgrSupa)>xRlGkjeBZ zs6Z0gy=g0)pGaJlZX8@DhjyV)KU#S4+P+6O-shpBe;nN8ztd2d7oI@r($`bX(EIT7 zew9Vqb(1Dtvls z7Nc`DtQj{?q_YXeNMCn@?P0g>Pm#YMxyYd2+R zfw=CGqfv52!|BK&MF&umOS_&y=hfhjn*r}9*dE_N*{4Nwz8|Pt5R#&=?d$26poj2> z20ub&$PCF}#_he-%r27a%Pq6oX3=ptWb^26gO*g`J6{~Zj z(a&D-zqM`Fr}stZzr+1

|#Msb0x!SqN+V<`NWP&)cpK(<70esG_2uVnhgYmJ&?B z`<&W^D5bTZzV*`5@x*Hpj+f(ZHk=RWydZTt59sQPzw$U1))l4d`O*E*UZzZhFgqLh zYVe|;MZxo3#l2pN(*Hc(9{p!7Fo@DYv;$0MZDBH~E!hpM)gKXiI-G-)P?3S}QAS<# z3Eu&6!kozK3N>oSKq-BHs~AI@X{|4EU=?-)==F83%zvd`8tE$6j6GZv(kScd$zF37RVrM(nh+Xafd#nGQ!2cDC|Ec>bIW9)@d8CD*y|kU3x|V!f z?jOC7_2`?;WVqGhpdn7=TJ@v~2SeXa$NOhXgAK0?uc@JL0V(R)-&k=aptQr_&g%oX zKuW#-%R6~;kKepvV3uWk^04&GdSBUsF^^W)0;sBA5o=U4^+l~HaG(K*!v01w#j3$B z{s$kUxjkDu?lrl1t7=+Y-h112+xB?n^6<{GZcGqlCBD3eC{8AT+Nh=rcW<2|xy9n4 z(D5OyZ=LX!W=ivH@O`H0(Kmk8yX&@ZLQi%vK7uhnKcme#xG)*ZhaIxd1JvM~ezX7s z3L~60hU!4B+@*#COgi)lf+B@ntQbT4L`q+#$Ti=H?3FR&dvQQnWC_#9ISgp1PI-By zkv`V~1bYrt-%IAkD~fiM?S$%NL05@av1F-<2COCL*PsL?a2mbj*fY>74mEl*tPF%n zcWQ`8?%JLYp+MKEIST6 zQGz8OupV+z12p=Y=GeC`PqFx;ChsXrq|PmN?-ekzJ3G*Z@t8JXc3;pa(2rI2pPDGl|-7f{22v)1?uxyAu4 zp%Q!KSB(9>U$2RQ9<1ehwJ}NnLXn|5Bgy=ZW=eH6ksespYR=CGmWkVA}YYwkj`(V28!SB{EgwQA0OFNM_@p6%PDTOYVC!JGEDWL$=Ge%3{wo z5T>KvNrgw|`p^gq!5&O`j&i^7=#yXo;-EdEKyfRf?$jPb{O|`0tH?K+qRhx-Dd-c!9bD`Ywm2 zELZ%U=T1NQsL@Ps|A8YX8cs|UY`5801JJJ{xVHOOWQt6HAG9qd8hgWUOXnShyalpe zq94UC6Y7V)T#@-zXx-@ho&5U4On{qU9Q}2T%HF@TOwh^wtz{=`A7HuxYOk$P%!jjW zHvZ!v-`|v6<-QN7w_QUGeRMO(0o@pb=11UbH2yRqXos^-<*nh-xf#b?94_wvy7RQy z-ya)8UnjJEMYC1iZ*ps{5BqxvED(M%^GVIQmWGHzX`n~bw86T2cV^H06RbKD$PL@; zcH{$uo*KPShtQ3X(=%{X_8j7$PiGwK4&@eN1!khZQ{_}7pXU@*NCezmdO>yA09VIt z|I6zkuC$(L&M;uG+{#oBdU9J+=ckOd*BggQP#Pv+}`Jp6Qqh(Qs%%=zk zGEK-d{75!=-n;v~nB_KE2#COZb7@-P8VqHGd4L_#%_3`nzFjRsrn)0l(w?*$Tx+XAn&XR zt;~XM1e~=V-qmO#$J&)R=-x8mnvMK^eN+9{+fxFEW>Mf&QP4kmw2!<8^A;$c!RkR; zA$Cq`E~>KFcBe>HAA`EneU2zGL07H!c%2b`b+@IJF2o+{6OYcfn_O*TyH~C>-m=sj zQ2zp~ESeG^jb2=t74Is?^#>2~O^jNQxUp{JO$)HE%3 z<+lj3xy;xwF3?@5xG0+xj$gk0Txib7|B`>Lj_mOpPe?r`87t>iQ!6sYRCPaS`W=#P z3CKY>Na1{5&&f5@$W*wz#9Q2M)1>BBVA6@=pWw^;q+=I8uL1?3u4!C_rFiFRlNWtU zqWTvw0YW5En*)!C7rXXg?1FdyK$M0ntVB;{q?e$mu?ZuWykEM!Ux|5$y;u)U?NOy1 zr-YJr-vx2oFFM#_$P3DL~bf7X!Gd2)WJQI{QGJg3=Q~0f*QPzkJ~YPSnw@Q zfw*I|M5rCeH9a*;y$g&1VmW#iVs$@T+Ckel9-3sEI5y=|3P}unxUrr{hXfp zD4&d1q9IE_#Mz!|E)(s-rDw5o;vwWA(PzI1WhX++faRZ5c&)`!c5$$73iMCQa5pHb z&LWIoJaqPrLPd-hu0w5QqKPtKMSP1bxu3^5BBxFfE~wk*ahW+DG567#tcU$2bH>M_ z!-{0BTI|Au5Geq4xz{xTGPuS!UA5wG4m4*N`nJW1Jtw%wE~X`|Fg?bO9e63w;gyEE%zQ^DiPK}*&R3cO|m^Ka{NbiClX7dix1((dcL zpY_jrUY5kb>q>2~*1?-g!JjB4ssE0F=9xIeez zCs`SfnA-@EeCzaIGk++|iHO>_2UdJu7Q#$y5){C$hht~1JgnbvdU}1SrK!VngTb;5+;W;Qg3t8FkCEqFf z^w<3l(!FKGyeWidxzy7hMuLjIfskKsz2wO2ddvM9CQ2k;pHwnYL@?HZoqlh=r097L zwovrn-GRQeRpggM8^s1px|@hJJCED(^{It)6J<{@OgR3>ljcNRYdIxVN2j@QS+kMA zI0PhORe{fO@HaGFV-}}~q26upXCDx%0z!1Wjv`}`4OeOjz4Jif1LdHd$MPX^WvFs5ISD`dwC-8RT(Uc9HUG` z=f$qu6t#^@1lnyn*3T-peSv3weeb3?R@^#^Pw$r;rgAW-{s*7Bx6hvY{8{`l7H;~a z_k@O!&1Vi+pX|6it}S9-yb8T6hGYftT>>WWf0#k$mY|Smf{RK3HEAWCCNHQyypvBy>d=4la0|~(1LyzgR*+phM0s?E z%hHYK2X95&d?Slxk%LuHug!H?KphIAwoVz_`kg2}o^i9sa@+DsHE@oz6eK|)V(dWs z0D|VY=m~e3wMY!OGStnEI=upZERZiIC5hfmE7W9>KQA{`{_<>=X2^o|H>zk+rm8l$ zJrU?54a&s1RTziI`D51IDWi_`@+1nc2fuQxNdvANb{$T*Nh7dCxw}dn-Z}22mymd-R^X)4+*&fR^{1lJu$Pn2@!c zI48)7%faa$V`BtVzAgg#7nqPyFoB_>QQdjZo@K8H>A@^1f#Shm;bbOs3W&!l%pWiLB};$(VS|T==C{64J0cQDZh4_C`1tCgUVcZ>0wMp=eij?+*MVF}1?pYHg>{|C zyrf2dZ)U&N(EjJYfe?ycp7x3<&?UN^Z+~TVnwVCb{%6ygG4K|5<5|&xO_C(PkC!YB zoc$s#&9b8|YfKBbwd_9FdboTx$ile7Rm=Uy4TjIKkML{l9aF@#wk+8qiyLeRjqR_o z8Ll-SowK2b2>8nj3^ouR(c*Ir+8(6;k*Vp4&%F~2h5c~u`au@CWhqg_WQOhOVh^NF z&b&hV!!fo40!Gi#5{Irwp1f|KdQNUy*XHv+QqZ6%s8|^*a^zuoI9{}fb&J09!dl38 zj|qM-%-k|@`|$q2t%(t}dIv71KCImS!<{pac8sDCh7(YE7sU5rY%#AOv>UNl66M{u zXDX1G;fZFk%=fPB6|pwF@3B?mp7_zl!<`JgLSX?2Dth{R#B;W}F!-h>rigKYL!{UJ zqQRmUXZsela^lBP9nl_@ows*=*EW@+)Z5%17LLK+*)Lh^KglC}s%tT86AO!BRm@2o z#XOlLiSNF*`>w_Dk52qLm!B9uZy6YNugvr9;NB}C>z^JbfBu1LHO#&thke@u_Axo~ z#r?%?L|cY2XFDr4si9lEJ=cG5!j zAY@4tdziVm3 znv>Yo)m!CoqjAx2LOv=?DpWB1|KaY>qoMxa@NrzJR6-?tOl2!Vl9X+VkhV$obzY^c zA!J`h3$jkg5-|zMI%Ho*_9go=vK#w8V;y5=`aZl?dT*cKIlps0-}C*WBgTy9@_62l z`@XLGzAmvNetdVn-q5Ys&wjHXU!GuOI089B3>MMXBy0q`N99o*W!!8{pHiQ4dIY6} zj1=hJN{P7Jd2>b8iz7-X=uILih=%3YX*NA*zbB~2<1K-i_ZdJ)j*^e15<_LsC+Izx zS{TxAHB-Yl7wq@(M~1X?*1EOP2Nd`T9?=?*+o5H;MRo5xzd^LWjMP0>Da4~*{gLHR1bhp7Hh5{sRIa;G3dCl z?@t;A7$u98HcCeh145NW4FE7QDune(vJ+u<-!g^UToW9qCQ>y^@aX$Jqxb_z?#%G@ zHxW|(f^mK&8$rE+{*%tsc|g|*S_+I zZ^N2t>CPrOj4?&?n|g4gr1s!hl)mS8#PaP=pyR0oIVlXX;$S;EVocgu_T_4NuruBV zGX*5eb^GqM5`=G)X|g%Q>PFG~($-u)gmdtPD%nTnYQ!RYxvA>AMz~2xPGjxShC=y{ ztdFx-Y2)3Pg*lblBYg) zI@!xzH4Wm!$shDNR8S+=kA>P6P+~s5z{{It|LjO;h9-1^a&L+K;}F}Cjx9ujGy@wI zsM%4=9)h%vz6!bw3K|+}-WaER3hp$H`&p)6VT_rbd?$NF{Zz8#5Yb7SAWcbVYLlV> zEG$V`9NUMvTjg#GrK>|6Mu@eo-C6;F(Az2m)*d`rJ(0VgmqL|p_uG`zx-I(2l~#xa zGIQ~L&GRw&hN^AaqA%?`mMZrc0A$cGt)<;>EM0JbANPp9PVuc>omvnlN#)&>O8#D& zX6_yY;YQT6seojh03J-uxsLj*>SfH}HJ!0Gd$=4Ct3lv0Bi#(U(YvpGgz{Me6*FTL z@0RZTh*Ou|a3&c3Rgq+~u?02MjOvJpX?23aGOq1L zR0M3(Es&=?DkWab^Jb|5)c&4Iss(155k=`s`4eVonRTu64|@HrR4&(QZoRN|H8q+Q zjq}c~bVfK2Z6AMUUoN7?$VY8kH$OR1?SF*4=O_7kAj>>OAiMubBYuBqC+;HoB#rpN z(&cVR@{PNqiwCoS_bZ+g$lN+C^p;NYVsCn}Sf9ZfN%$@G2LVo@$81VL z0M`9a7|smynwvxY4cEl)cqtun!jk8lJWm&`vzyt4d|2Pu70Ov!&n>Fb9Mi@l5ZW1@ zJD%+AAEUY?FSATaMPvM0UOeKu)))Wi+m?nYRIsBUmttI_&`49xQ{Q`a%fJ(L3pCur z;|J9#L-Og1)0v7p_?PSS-?yRTyQI%b(GR-G_Z$MWtY|G6q97K5Q&VyTL$TclS~N|5 zuqKlY-w%u)01x(;khnwlc${E2`GIM>lQY6v>{A`dz*LYfIJ^h(6qsat50hh%WZ{pa z(*OC+d}IVzBsAY$8&LmC)Vp<{t>~Y(0&S~*;WIvYnAW&)QO^z4QH{1Q&jgt{jhN&v z{s=sNSf*F&g>{62JQ~+#minC^v}0ZTf#~hN#=8z+4|aDmZrzSItLNTsez0WN#h9E- z@dA`c`(%2_*OjTo)fyUy+`lx{|IOEb>BZOnmv;TXox4D%>VIM@RuH>@o$|Nk9A67* zsV;!XkoXY^*oivmZgN7vg7C?l#C!Ien~LB7<+Cl>+rx6}MgsK30*A7K#WCi2*%Xx( z$z|&tAfUfMPc20OGnr&`6)^*m*Gd*%QCUO!BB{(Nh+ue_q&yeMd2ZkfDBe#ItFGFh z_Yt)W>3|M^_n2d2TldcJoFQ6zSf_rCIxldAeZP0p6~V5RSF4#*>#>EF8R|C*pY3s; z8hJkZGBq4(8u93y+!EVHZ{S@UtGJQ0q+}7L4MUz`Jm<85`_bi>3tLkC2`#SUIbbW{ zsT(y40Py&!1!ihlJo-5C(sd&D1-hheEHwaa+KCxw0)9D(kb3PI7dicrvvph9(8IW= z(!n-<(Q{X+CK%9zN=6hy-)9W|!P};b9R-~R#Gf?J5EA;ufB1$NDe>Ps`K^Y@ z9?jG(g!CrYH~2E#0Y#c|BhSq+aXL3$!&?T#dOE(8@iROMERTa>g9plu@%*b&P;PXl z74kjHK;+XPYSX!Ah@-+di9Wou|52Ex3I=76e0XoqS_|t!MLp!nJqv$nziZ%%g z(;G$j1iK>$iFU3ivaeNn_V!NLqs^asRI=%j8~g8L0v2ZO57vaYrv$EY>)$WHWK(KA zRoSkx&vKE4WdK|rg17^cmoLy&UQKoCHsh8V0@C6%y$t}fnNmEigP6qG#2|HUa9ujP z>)p98b(U0L1Ta;x9rHHsorO8MNambUV|}Kg{>r~M69)F8*y@R!APx?>*rs~GV_ zzY`d~;SgRwJ0;zmwR&j#^XAgt{v+H8IBWiiZARhYzb6!Ecj7z$j23juvjRt$vDv>o zU;5JjBYwUaGIg1C&!(Aij*zO*A$vbme zy{^=@PuH?J68uq=p@&3OKF+-zbQ z`7ZfkSmX`G1rSpZ<0xko^_vrhbb6XryRt9080h;x4{|iCeqf>HxN$z9G3_pA5v}Zq zNr7b`(H4wStyKP`Sw9SL>-Iy2vHx$Vt9S32mOHj=?G_*hwZ>o#4=CRoJ#}kYFj#pR zx6)qP`T=p29bW4t{}G!mk|)PXs$Gz(kv-1K=xce7j_%8qC3pm!`MHQxm623euHhL8 zmZXBP?Axro?q{dM!n&d<=gWFzHJ(n zm1b|H415uvI(b1(1$rxa>-wOD_T^~7b#LF^%F-=Z^jGWpJCva)bYyaPi;uQeBC7Ti zb_g}=vG$Z2AqK5#{~SA09OATOlDma`|NV7!)|j`6&DQiDitzmK%ML$z)zH;<3EwBzZxa}4jjPbCIy z3BqV^)a#omdv-#gr`FNYLdBl5xH%=B#cjo{^HrCyZeP>o&;MYSP`TVMIWV`Am zcnE*l^-fRP^HoF?-!JG{I{O+h^^G-;m*!>RpbYno)eVxga!?|rEv^~{OV<8Vih#KM zv$f*Ru9?K~H4%9Al+K&{S9ve)Fy*Pt&G`dqGh)*7_j{cY&$ez1P%_2I@LRRpgOkau zWsRP`8SMrdB(S$zjk=(!FhWOAtK0D2dvhF=c}Q_eX>>S}&)x}QQ&}*non1FSEGDH` z_`fQ*PmJX|MR9f(87vJ$po#-CICb}oN68loVrUm^S6VaA()P^7bSKBj-RkBI3D>v2 z#z(v2l#1G2tW4%vfpdG1C-TvX*P5;`;p%32xDWqUvdMlyIX8*!v;Vc?EV$BA4oY=^ zS^hZ)l=@4#`Hi5QjSCxHxF+SB|BC)GB~t>~;U%-M)97v_x6HUk@op|{)eoLK?vzaE zBGRI4M@juZYP$am74LS}*ncVq_Y!wC=mCYDG-6`wr$sz<0Iq~wEo`~R#4Vh@W3~Pt z6p;TLCjNC`NCTEmL>nV&WO_j;0a!6m*>OOeK>M6ma7MU};ytnK)XVE1nRNfxnxmQf z?}!7)9Cw*U>j3tB7u?|J=aOnTSNx(2(92mY`FVY^-I0z)?qK)$$e6!FQ|GUN)c>xn z)BgW)5{fm4h!|@Q!he>4VqP9||OrqOOtRS+Q=ahcRqGZug*(tT+fF?F09ifjmSoW#yWibvq zdh$pQTeX-^oERZEd{mf0raD4Rb+jo8E$Wl}S@ny)g~Yc*H>1rSmE`d@z4Pm``HfHG z>67c|r=PrJeXvTz^~n2WOM&hZBE`k;A3zz5YUjPk-wfvNod~Jm*P(4;x;14&G$-6@ z!m;_uI*p=Pr)~2~t;YNuMD|Xu<61g~_%-u_o%3%0_UPhf*s+Pg243d3Y>;c93KDz` z)`wq@lRY%pfdbLadt9xWY|h5{JU~F@d+X(RXAawt+%q3mO7fhZ1AX}a34hSI@^9x9!@iz>=eQT zHePAWkAKPZ>`7XFK^mlLLAkEkaM?mDkeAQ8?+nU*50>NbA(GZ)*{BS%Q$AI0uy2sKLScOz(~uu5)%z`9I%h(ko85^ndemLTUZmyKDDujpxNZ|W zbUx#IA=#JZMY=_#^iY~py3u~-p_?7$-@FIv1Zm@h6WD?b@=lWnkc&t%bAV`LUo2+^ znWx=j89we&vUPs0&%JMdtFP9>^=Tf}=<_ExmEIj_5%13KDF>}+*NMoCxAt4rph7S= zir*455L`~b)qzRCo{1mCC%W$8^PW*@II&j^%{w?@`4uPOT3 zh6Z=ML|pDPrG+~>8yg@P$Xs8s501&`>K@wnc!9yER^k=1v78y5y^v@QHGSFIzRC@IEI5=$v&^UJ5x?jon(O7ln|=cJlzs|R3xYAaAq~X*TE6S zyKZ$ph>0oc9*X5e>kkZ7Y4A3+2u0z7s}W79Mu-te4@Rv%?N$<}L`;bp6@i<-bQWiO zDJ4zmR)dSdw5XVs6^GX7dgRQvtK>9qAO(b`1w=qG( zm@iLQ*S+#eOkjM{&Escef+5Taf{BdivWun7h_N2nM(3uadC<(t#m4-|`i7mlKiLR? zg8gN2%0z&XUGx#yE@(>`sKWl1?NIj#?|QE@ZC7mBOf_Zs95OK^md-$~dA^gqn0h`G ziCFsvjx*D@%9nB&U^;if1jZG(pL??e(@QunuPT6T8c3M6%%Q>{L8>Ao0MJ%l!IL9% zt=5li&YMoOLL76eS|jAsr|Q(~<6>F92O+2mBGlO)wkjs#I>;th7x8_ zZ;HzPlP0wgvd|2kTmNqeu1Uw+6zsUG7N~9iaPIZw0qhDbb(0!y50K^G)gro~Oo+M{ zpxOXv5C^5cga%K`M`T1Dx_tkP)Y9?;SD_v;0w*2I(3|a|LpL>6bmu?q!rY0vo|s1HkY^acu-eDf>(V-8t`$fA|K55uNuL;{cRkr%n;4d;N3gUN2v22!I9OX}1eF%94&-NN z^Q7yianh_pS%?J+c&)b%=%e@AuRRVv7fPQKru;Qr>xsXz>P@Z6XSQLucbv#N7|o0m zLM|^N49B3yC%1p1WMpu2Yu(WP@2D ziu{H4BetCH$BgTIIcwXhhe|c6IbE&U-s*GWmHgnRn041A7)VwwLNcX@Qx&ZdzMYPG;IU=qqV!$rx^2Vx|`rl5n?l|Ng@Hh z)ySBiG-lAvqk!B(gIL0b4i73eY#zO>%lut8UPG8Ng3Z}`WeuI^DQ)Sf7x^-j{z=%g z62IW-AIf_RXFc;qSnVqNHeObMsfnjt^ur5w;-{SE-~qPh#eyswq3tgRsn7T|&I!Va(4Byl!d{(O!Gtre?6`hn0$&>svSLr*zsM2~OznpK^=fpem%f z+hIjM6E45{qR*R^cXz_zM9U9a;}8;d4?(c53dhhKoUjhb@_5f95@!=r^7cndSDjH* zN{jfV(X~{z9P7(Q5t5_nEEbsiHGPCYlNgM)o6D?+FkXdvGYQ!LAqK$4;hkBcM-2Hy zR4Wx6?I&9E%&C3N6TfuJn=Y#V0IntZ%M=a4$QeX!SM8ULXmyZFudnzcn!T<|*A|bK{^B5hg+5*9Kg}Y8*Vd2ur@! zm%j1mhTeg0ZaU1LAwTmR|?HF{Yv;=GJpt6Q8FcP z-75ANx*vSQDy{H?e%AEWS>fdmwWUvD(h!%l2jeNnLPz--xgw>=TQQq5Ijfa_>LS*H zwB;X@x%1@`fPu++{m^u90*u!_pbi0-RHqu!9oNjfBGb+IK&C;d4OpHhW<4^V+yOw} zpET93NFPjh>l%E{P|n51x70YHI?P9RQ@DdozPuZJSfveUad8}@k}*WBLilI`3PLtN zv6$#0iT*J4G&NzG!O3Av!ns6Mx+v@^aRnq88bg>{3jh?AYDbg+e9RM07CtY!ye9Qs z^PSug4Y2ofAm$Pg4RgY^nGFfky^j(Wj+uTyuslHrJdd|LV-*vrNKGj@u)dLw;w0wW zid5-P(|VfNJ|%vEDwbrcw1w_mle6$&F~Y>JDyerJ(&uRWa5RXV_5olk!$8ftjB#+P@{Uu+h#^wXU+Rcl?{= zJ=?``Y+weqAP>qp-TLsNH;NxI1@bX(D=;`C^6TIXu~j9wLiqgVfU5$y*m}TKZ^w~b zlHX?}M*O-aP1+&i6TkBkWm+`VN-NAc9cMw&IgzswY72i8W85PDSk*bKjYq~ww6_L6 zpSj%n!sf#En!?OCNPUrEk)+fdgLj5I(6=(UQ(ac;xEgOLhi1H}(=$Aw!$# z$(SL>(~$=`tv!xwQ-~ zriO{b+AguMUFTCIM_=nIJy*QKf?_zYLjfOZKXZ+z089<-xC21R`Jd+o07yVs4Xl|G z=$0u+i=zW}UZ0|j+qKe)bF5$Yqv&(MnF3YSe|%}yZtbG-1b3c_ou?onE_h`+@jS4` zSiM1=pfKe+6|mJ5TR_-2VT*I4++Cndv{~^LBayd8X->#jkYz zX#D%SxTn?AT013o4DJ|-m;z|NF7B9RnapS_(9O*iPV=*Vl+(33gANKC5D(QX2o-L6 z^iWhcLXhs^BbI85Y@|7WPO5Gb9qi+sA%^(%kW3iljqJ4o^#Mz}Q!XtD4r~#krCFNS z-UuE1v`?sx^VAiNM1ojVv!Y=ZeqkZ5cwPP+zGBcamHkGFaH)~TjIFh@^<0zhv!xF<=m$ zN!@KRe+_aIzAf8ON3lnYTaKan#Hcr7#J0Bm9X^?$vq2U3*Ptz-b%70x8r_YgY>Sn* zOK3L|oVJo+imOltK&Kn4km2!NWB^NtK1o3I*HPc88d+^ix<$MK^CqIuQ0>J3_G0rS zWLW?{ik$Xnxeah4SO4wV~W{jQ?mbKt-iiMLIxZknM<8a0X$Tubun5qqt5>~A_)P^XBEzy{z z!?P`PQ(ms~bFF*HL49rsza+`E&0u4#d`@((lWOI#q^IxX4*^w0`rxW9_8*OsVrgzN zBE!auNXBfo?+NjWsPEb@$aaL!B_mBTE~^~J&si_>E1O&!r@4562@)ka3qMk-n@nsO zlq?gv`sl^?kcB9ra&`3K@8U%4cT8?R#!^1sbhKJ|WyGY|WX$>p{aN4;Tr?AGIkz(J z8f-Az5suD*iy)hw70=qYvhtM4^y6R+a#Yh&BC}x;LG1OVGHd)Ex$p4OI-26+f!7v7 zC4zsTABRzu2taSo+dT%aU#G{t81R`W3X$4U=sCpQaG$s0?S}&`s#T2Lt3C1dvW{AB zc}Mx)So7o!hm7&6A4ztbM$cKnAORPP+UjlvAcZEa#1;qp@x7zxTgfo7y3`tST4$-$ z-?LFJRxz(|4A<#2Z~BpyTq8lAGS&m%08>}wV;j#xwC0z z%Ot6BtbT! z3mt}ycVz=rj3+=9<1ulU{Vke%CoEH=?L$mjCI@E!q&eaNpR{vjBfl2)^!;6+voPQg z&pDg%z|ruNua=^xFX_L}%s3L3zs{(P*mIjf()xZ~yHyd1bFE1Uv?KFUp(F+e$$$QD z4xcAQjjW>*ge_3ZTmf^CmDF9t!n;5;#M^N&l zp;++TKfqMiojK6HGJQp17s-_Qa@V&X8Mt@yN%2XDdeK&DVI8=i{+PdeOHTTLV*~;J zC<3d4SUo&9_jZ?F+#k&)S{}7Y`}4XSt<#MRe)AGvIceW)_TT~R(ZqpJE9?J?S9H|_ zol&ptf1LknUf=-R;F`7!|EK5WKArAgF>x0_@k`IkFTSXM;vpwNKJj)v1?WdWR}YtH z?3~%5`Ps3XB0;kg+X)3#nEk6k@?#Hji+Uu3>=0kgT@6E*ib@N|ChohQJs%(Q!J208 zO}Za*FkXy1jX*`P?8R)Mbi>wE_yCag9S| znCC~=hCPdPlxpY0-{lwxRTA=1Q6IQ)%Ac5out53B3;Dyn5xG*Y*TnW1%!TA7;w`cZ zv)(`6ym>pZ_uD;04~ynWZpxh0`a2XrgBh4dFkaMyv#q0V#PXt6n;dlcLi?NSkKf2S zXL7MNDJRdb9_J;${s|-LDL0K^104iRI_@U;;Ti_Z>XrUsMaKAK%ZrDeiXQ)M*lu5M zcBb-W^p&ex`;{H2qVAH!(Gi4qx$j=-sAzL;dN}6nTxoy3xnc>4r;ks4Pe5N)^bDaEL*pEmcC~6Wd@hS$y_|Yp? zi|F&G^5cGpw)e5JF)Kk@z7Prq?UdQ=kxB0MTlyEtT9l{-PwAOHAWraOu-Du>_vRFQ zeugh%#_3Y~LQ@0R)jhA63QKJb3u@8(2A#Rqg$NEoSU*YF@w|CXkG4q74=}#dQ+Ygh z9t+|4``=z2k29lPK<%5;HsZsZyut~O-ze%nFJhgf^X(vybAWgmwN#A^jJt%zJY)y*(xi8N z+HV#!=4G2oE?-cc5+OUTTzpsCF3?kd@?my|mprj^QB-tZBr|iNufye&N+|2;Xu&7@ z;(1?c;BVqGuoXZ71QjGNFue}lCsGcv{Ayl|d9x;?Z|cHL_8WhH`mkvDy~jwmCofu* zEs?-ib}0arV;%I_-J+8!Wa%mBc2}WK1a_}W_ioU4bqSm0kMPh_o{MdVI2yA~-KU#T zVj&DK2#ziB3^+Gcp!QJNEKC)8WF7dH2WBnelk&3j4VT$p@J>C^zB)oK+{apvhxbWT zqSz6pOM7NWSZ~ksmVM>1Uo0FW3T$WX--~vr49u=?2>WWt=o(#ocKzfB@F&*(qT2p% zeSK_+Sw4)S(r#0E*z_PL>UGds1XT4;U*o3yuVea!O6)UkAiCng-1?Q`tAU84Yx3Q|b91Ye zbi~jTc;K#^TO-6edpDDD=k?2_&C<@0zrvCo7G>KQ+=t;HKxIWU1(x&Vx)nWf3avYy z&4qnFZo4`wJmqI>mgtOHIxS>7dj(yaGfgESaFNr9d-&H&-r!Ojje>lNO%y#e+0xP; zxwqNydW#)*)9OOc7d0j4&@k7Wxzqbuo=^-(kab%?j3k3QjOGRNWyebka556kD@Ny4 zp--dir@8i)(T}ivh%D#tI^#ck`$7Ke#Xzr{E6m6w7evr2bf821xV^B7$FrE2G^>zE z!C`;WN$^px8*guYC#LU1jJh5DNmIleh0dDS)n9jbl*~JC&s2S<0$UfpCK~>^Wthf@ zX!Jd0+yKINY=ZD?41(0lfa4K|ZvjpwN6WKpw5cn)m)l|?CoWx1CC?b(X>u-<^E`tf) z@A68aZP+UMeYv!`HmlL$U~`+PB8cUir*B_g)t)R}6iDx3O?Z`~8i(EUP$TECTH1+D z+xFSGDYqvGr&Qn_*cU|hv5Y}|RpouV$HZE-q5}O{Y)AFh;{DlR&|L;U(Ksu0OqlXv z-;v9ij913`%Duj}<;lbf%RzAvNRr@t6C4VWK%p^v{lvLS3ri2mQ3Zp4YFSF^!zSI zm*+6*$-g2?t{RD~@998{ixV+xY3+#mdXTJBUl}Mzj`sI-=Leb0pTwN@aof_YFTLsp z@Cw@G_3j~;sb{T&B=uxG`hg|c4*WD=49P=uyG{s?^j@QT3%3wvddEV4ZskL5k@fHe z$rOAY1t5}3&D&h%CUQ(5Mh*?ii6goVZjx2?`Rw8Kq9vIhwx5ncPcEY8vRQI}x}l11 z%wiUmxQ>R1*0yAd=JMx0dF&zT6y6aZ`7pI|%XHJi0$foHhgQVJQx_>P)lWrxMq|e6u{O}( z!`!YgW+-^*-_oDWy3Tg1b4~{()3THoosO*WbnlZ3klR!u)dA%7y8_=z9-cDQca7Rxy!Df2hzE14j%)p1Y1%^E)Hcku zBLeRBgB)<4f*VwStpl&IK4j?!EabmJp-HZ1Ja9hm3%zmb0yyzh`*y<~${4(94 zr1Zo+sUr?N-Qh5QBVQ%RXeVfH9+*ZgAMvH`2lte-8_ki`Ce}Qd%3v{1I}JP$wQYVm zQ!Imew7jppG?HI=sT|{n#meCnslM!MZxOs2h@G2N52xk`In~C}F zfcB=*NnXTpfJ|a5O+Zu4$|ZM6>L;5G-{s7RQcTqlSGltZuU?74h3Tr!8AGPsu6D9g z6+xi>tPK&}2`J8{HLi&<8f9i%Bqbq$0ui zF9t_dE_k|j7HtyB-uRIBFYG>53{Wg_!_Y=`>ZAtW6PbQYT ziV2>ishW2-sjD%FZ}h6XZ|XMKxYIuHEQr+r@1edbKy3hm!TYKG@4a}H;-IDBdVlNI z%s66<NSu6GBGokg^`xc11fJfI13Z;;o#*uZ<_9LvI` z`*gv?s_O*DkLAon?@u~l(>bK#tH>23m!zp5qu(2D^R(mXyf$NY6PI{z>)|vOEzv?F z0eJ~xj=9Z9io2l4Xb|Hpj;?z#Pxa}K#N3xk5w~yh<`#V@Fu>Fg-AE){br4aPKD;QkjE<29D|)AC#cJ$d&*6jKLid}S z!ehhauo=W04l-#OM|8jsD`1~P>4}xgmM>g>^EbZn_!#q%#ks;0B3&y!EbZ>DT2Lm_ zQ^#3xdp&&SwKgi5HU&8aDfVGMX_U)6ZDnPJOsq7t4g9Erizl-4%pQw}bLbWaxqonw z)n5-&;`H%81(8s?BG`M;cr$wQL&=F_Vaq9eO!k34X@p2_>)}D(b-okxO?C-bA!BGy zIUMbc^APWrhlfJ-Mx@+7`?i867}LJty@>H3OXxZ?;Fdac(q0f70B$7Zg_-9DK|3M3 z1n#w*O}vf1<~eI|JZr5<|17j#n}Lw3|PVeHDWRdcwNIpgv5&b zJ~*je7lHgqqbdzU%0s(WlX(v3zGSvSJVu`nX0MmJHAerEW-B3-NpBstQ$_dNmJ*1@ z3r~O@-!x|RdG{(1&I8yDU$GcztC0Grpuv_o&OuFes3B=4zNSV;TBA@eFr~KTL4c?A z_f73}@2J}Zh6;A^{u{}J^TR!*b(K7Ps7m=Wxdet>VO##>_pj{~P~senH>bRQpcRN4 zXTxwxxEDy#v)XgZQzUF4_S@uI(n$x^npvO%eZaG<%Gpe#(X8VOn&Y}aX>=FGs>5bj z6EEqLeCD-f$;~H1S7{0vUTk9JM&Bn$4;3BoMC8==uJ<8o-{Ywh$wb`*W$nG(mzz~E zi=nARW2$nQ6Sa(&s<{<(xJ4~ll%TTgD0Vnrx{a!vQ;2W4QUqB<=*5-j?tQNPNs~W15(Rhc{=_O ziPkX4mIpv?NCI_)YLZtAeKyg=GOJ%7Q_=F(41H$R((et4HPd&yKj=Ad^}&kc!JKqA zX!`0vNYAYO0X13)Izttsqw_RM8~+@x)YtSKU5SNGPZpfJPdvIUk9k^7Uz zYkTsstLM`AK;rvuqv-pu3>Ey(W(2GMGyR>gBdMNgd1j@#wT!MxBz(g<-71sSW-VaV>NE|mij#9 z5gf5=oUdD3cpJO5Zl(SJXK0dBO@>Oi)74Tqy8p935%zkeL)Tx^$lJ1emu7#T@S~Y< z2MPF@4BrkIg}u8+TW}|X_d6uI4U1j`3tZ9fUCw_bS zjw+xnxQ4z!No)^B#b1G(jd%tt(x_X`Y#lWUvbP^$;CbGac}hD7|307 z6R3p~*=z;EUoxly!!tybu4j|c+0sG`s`BoMM}0jU)>_wNC*F4*b*a6&iWa8CC#}ez zX^g51w>eI6F`?!`=YP^z$_@KD>@O&ZXpU%$Z9jQjwgy;#OA+GEWsp5wY~-p2qQ+j+ zg++Y3hIo&5V)1+TRnInZMSn&9_MR8s^5-(n3LCL)&NvZ22)x2^8&`4q&lYpn{m2a0 zRb0^C)PH$l1lfVPl^b+-4SCoZ@gO`f)Ig_lv3mc&cYjdhatL-{MCVp{zmY_m>t2v$ zF1hA43innhq%ou|i#1+Z=qN7aEGH&CDv&sH}I(PKt4A#-M}fxia5@sK(7LC3E>ExM@C0 zHSEv(zDyDow<*I&sMNui_qT&q_EI3A=SmLF=Dj&}*ele-+kI`!aPP7lMAbNL@QQiD zNLIyWi7gOL8I2*BUX`90t+uM2SS4xYa(_T2kK2#^E}T-Ud%x_OhozGHSbA8VucEi`Js?AMH$HvdSCD3kze_2kn z$P6?77y&Yh>tA-2Tn67gxywL3@gjlRSJKcb%ovWsbDTsPo=n#^m z4aH9_4I!}lND8}9e*9>EcaPe=9LIa(<5{gX78OLq=v;%|ylD3Gs0h!%tJSaiij2vy z_tQ;8==#h;G@ASh;N27dfA}Sv_G*x*|>mn*sr#(5#-DtM$*2PCI0LXE6 zyFob?G$_xJc#!|K9B#Aa_ON67TiSiMyLXWcFiB)$pJ|{JxP;kA$wbO|zS}eB&@hKt zKAKMg3B~(g;^M>_n$akUQYhWs|CsSM(~^hq?-ZO9L4Kus8hyDVx9kCBQdhDhi?{Yd z)rYB`e9X=;sK4!rHIb%TY4BB=w9`lL%dY5JbQhr-J*CjqA!4@{ct8I-FQhRhRPXr5 zJ}cgiRrKb5uoq!LgZJt&?YWm7WC_pb{fd+#XNe<->TNATw_Oh44S)7_>?{X&fd_xg zjqN;2(9OUGKApQ1_?ktM&(iE709^#UM@uHL4ilV%BmL(Me2hHrk+pSyxk`Vub#6;_ zlNg5(!?%IH$j++Hi+{JWyNglMXj6BOh{l}usLyg{v>AUjc<=a1>knP#>4q-&QeS9WgzRKR zuR(rx{=$Stc*?k6BWB=}YvWi6@_^?i1-*|n+5 zcHe3CXV1EIc0)x`foy@#o}aT|Ir1CVOu&6hiGCW+2<(r|#n; zt60i=-0g9es?X;P*>&$l9yKsL?HBmHk$11i4}QZJyulY$hSrY0-2Xn4Rq}Y@o)3Ep zU$imzctNj{o$y=$wPd93B@2|{CnDFU-jpAoJvxik(%N(^)D_rL>UuWvB)^FEMd-)3 z5?)T_g4CX2CV77H+45HNf@yOBUPuz-XZRgYZijM9b$<NIwzTS$@IJCfX0Oh zyX0(1(lIr~eXKJR{k7Qws+C>|f+-arnf6Cg?}FNM{tlF6jea9Jys69|x%Cz?Gl|g# zE1(tG$3|5c&Jm*o>cLm?LG#&Rrfnr9kn}&~f>4fN`*ZKpRY7yh^3nJWoKqFNPoeWJ zO0%7OR`jVu)I!+z2(H_OKlXjQKB4~I&fbg5OfXKCdA)>^D9`N<7~Dj#dn6dnJyT5o zVPH5+uR4;KciMpO-pUyhT`BHViOTF&77`2LQ4?;p&pP&tJ9pP5q9nKcRCD3it^@&k z;^FWM?;}NyaSR*|zDe`;B2%*-whD?wfVLQh83gNdE8%aRhw+o~JjrrNL9Km@HjAfP zDmaw80vDf7?K>bid8LAG@ZDbt{r{z}EFoZOt6m5}WCD${Mqk18-wy){*q~7Lrg2*R zcL$|mf0vl~wuhG~9Px;vxwNqjA|D=Ee^*QmE}Ak8u)etpq19w zOUya(;F>728Q)$zB^t+7ZShe!2hq4TUmfoJh$8l&M}2?g!-VLeCp1yXKG(h~o!ZA8 zDwKcl%C`f>t~CErv6LPE)jF#E@9Jj4CJ4H$<-VV~0_w0w)zvK{~}GEJ3dN{8=HRUmZl{l1prN%NOH5MdW{Z8$CoLQ zBmwLQLXqeZ-C3=sG3r)%4A!&Uo2Gu7^OHhVy1x#^=@>4jil44>g5ta-$LR++x&VpHft<-R95*tVykG)7IxY0~0OD83>JlbE*3= z*SnS%rP^3*CO?&yR4<<&;_ps^6>r#QkvQ0~1#eJf!4CLPxK#r%yWnMH{h#`7b4qh;wL9=AbezRTPZ) zHiW;3`(pxoX>HzA6g|uja^#3PC|Z|Mu>sn)z=EPme}$wAi<{xoS5c zPAsx`LWzc1unXNNG^o@2Am8Q%)HU}$bv^VTyMUmobFZd6JJ;< z;=zVDW;}{$I!6a)$EJZ!z_m!!I#`p}SXPk#=@o1!0Pavk?Qxd@HR_5!*=8H!Y(db= zy5S(loyV$42R;+>KBE{~-?8j1Xt;1a46wfHRS3eI7@(ipv}-Ul6Xx=K&iSg0dbce< z!HQy4(K(ld-&mz5x(u9oFCI4MKQ?M`C-uJim!l7c2!D4H4J~v zHI}fE1Cd1V6ia?HGm^n~*CG4A5e&q;Pc%f4^ZP3~HR_FVN3Mxe1;95#SuvX&jrq80Z8`(*kYm() zF_?E8#h908u7G^A#@f81DDC+uW_uC<2U&VX-4G^3pDvX=1`KDCsZhihN8A|2sR(p$ zUkS~#216G>k>Te+VKVBRLa!2$c)yVEcr|kP1?a{y)hg>x$;N5)3^-Wr&{=M>f0iW` z4%_~SA%}UGei*SG&{#>`s0`upMSNxv_c zc;Dh;UrmEP!%u=s(zEF6{q4?i+{AdbYzC@~xw-4j=Wxu5)4klcZ%y9;boUqfhtfn{ zs+v{n*8nb|%;W_dvXc-iSrEzVy-RD+r0OS$61t_o9i9gctiRyD0&@^@Y9Kqgq83%jq`80S2aSr+ z12^#eKUiEPumjJj0&Smh7*ZH|6!NvwWTyIWkgm8MzJCbNV7Px9>R-r<|Gl;lgdVHW zY=j@9uK71FzX8KcZ1~G-| z`0vM~!c*&?IojSU0-tT00lAe>SpKQ)eco2=hNyL9=UqA(Va+w&vnWHqcmTW3K}QM> zzC#>5gKieP`A^9Sqg($Dv>C_t|K+x5_`4(dZ!Ap$E$Xjn$p7$@;or0k|1q?KNx$)z zrv^hCnhh~kLY4#3ChBMSB3;- zNoE$ieN^MqPu*2Tj!TgvIV>G{%d>9dI-b~JeOY?07ROEpe-qEsh-$cWd={j{@t6gC zGs_!cr*j5ejtsh$g8BG&!u)Du+Rck`LbE(a4(8=Lc=w~@J1e3s^=^B9o7v);DNilm ziAhF3n)^ny(s~i|%D#o+d(!|_sa76U0)3PzGM;>vF|{(#OL~q;4@~43HQiM0Ziy8b zbv|a*i-*NKxgv&O>CiWDDI}S{`Ah4%ViS=OL zV%a}v`rnmPx2SR9h!UY<;=*40Fv~*_?KOaKEx-n%5~*x~*0CSWhq*nPRNzHBSHXVT zhETi!kn={tS7he`q9q)$STaUBR3Lr;Q8xlZ&bi=Uh}C^W2td9d&%vDlW^>7q#`2E8 zuggKH4HQ30hxkP-882k=*pm6+51O$Huc`4flDJnTsBqAx2KD--jZy9xQCJ9UIEES* z)8i1dqFYxZnKDaZ`Hq-;w-;K!_7bj*lIDay2=()1tABK2WHV|Oo4Q}R!&M&GRE+mM#{_zTs$X|L!cpB#VdK-kGlw)y(D`5?>X zNtDcrZExO9F*l@KNQS^_!SicWZh*po-p5bP;g)U|$$%A$QnCyN!+1eE<4fPJ-C z;t`8xHRsb32lL%G8azX&(B_?h?WEQ}efji^Am$+nyB#lfiz@c5vVk@Dm!0`Lsj)S&+lVT5 zs=w`DG%yEa4Q|=TM>Id*Iy03P?=Hjpiaz$j1BLzbx7nd%)O&@yW% zOs_E4l{*2XTNWD1s?=l%Dwc86$7xi+@SIt|(cVTZEF2S^!N z|K!|%-6heRaTG_cr^TaH-3(fO*D8wL9Wy|C$?dP<$v+evWt09m-~DA^zxgSB`9;Fd z%YVJ+Pe~u9P{b4z>=}#xF%tL2?1;J=i0u4d)FEP!3%qM6Y{WH|jkd$`m zcfd|lzaiF6H~*l4fr%$En8hw))s(~Fp34ZD;FG*nK`}i{&;mFb-20UvfHOj`;FiIp zN)QCkn@t4AI^QI?tqPqHa~PCr+_`D<53k)8gE=7uUY)2PM9R@;v6L0{Sb_7RrEgL!n_xM zqoTIN~gJH5jDciZ=m}OCCdYu?BURVn_W$h;~G>M&Swr0{ylcealNxsE(*9 zwlkI?xrCUq>I0t1kQU5a&-*E6EcI3!C{gedu}HA!g^mbbi#LA<8HyGo`5FTH5c)ZB zqgVdgm9$i!8-X(GV zW4FfU!$#0f3@Jd5L)tLs&2`e<7yFwjpSMe$dyBZxdhcNe39#JQNdc=a*Y;-V$}5V= zLEh2Der4~Ios+8xcT)(HkpAlDjm`-@k0*EoDV@zrp+az6YD-Y-6P?2D7N_1?~SuEV95hXo^j8Q2K@iSUFIrq>tDIK#&g zm8=BRGb+987ra!{htpq6YX%s4x#CgpT!5;knYS5YDSq6s-Ai8lxrayMjBm_ItaUT@ zeA8va;0WChZV`Ed63-ex?u%FEMpnf>W392wAxEOaK*$wJ3826ctE{u;ClD>IL?mTT zZJ41;=}yMPgPbFj<1Z(-(bg(#dBbL`yFObTS5Bc?b zO=!b*wAL16bAIlh2#(300ng~CL6PSV8n3(pAb{u1u&b3 z6{$DLn7P%V%Ie@3-);fo%sfc3FePJcdfE!%FHA-NlHcnYvQphYAG`s?zZ> zu>al|a-jT|TG3%!@o=c~`wL zyXW%2UYESCJ0+x|Kk9Tbo=ac9cKY49^pXodQG6L@I;t}D6u4>MgG%=wI>MEtKT!N> zH`q4w08oO1g8rz%@7&~tykAaQG7vUIa}&6HFg$|S-0dIqDnR=BGnnlAH=M#;pVCA+ zv3u;Gm-hdl(qK6D%j&2PMCsGZ`qb};IW~8H=#SptQG{LK1I_m~8Coub!hr z|A)^^*XVcoT2f<35d5$!-EB`4tmH$r^4C^ddKXwLBTJBgDNImBUR+Mho8WMqsIP|4 z2`UxaZ+r8Au%w(vx#PZr`yfJZsO+lsyetO(r=HRl8cm=>+zMRk`@Lg=u%@Xl^+`d6 z1rO*xhS{GOz)RMy*S%UEyd9q$WX0sUuOiBJf=nt5Q%oLbX1JZ)=n=bZShh8&z479B zam(c}(G`K9JL>5IK5Xuy@2)8+`(KeA3rIWDTWyX@{G2AxZ>I5fAxR;qv4_?lbL2|Y zd7j2KwbDm(7Azquc!zL#D>rA+!i}xcTlbQ+3*u7_hCSvew!e8v_wtpSuhg4VC&OM} zQ{=qu$$+(Zo;7M?0(V!|-tj*DsDQV}?n=~(DZN-~vCULu(W-^mWaqJYsvw?PjD>~{ zdW7-f>;u2mp9~5KDNSy*b2dJsk(o3x^R@d{T(>w?y_|(@giky~JDf&=9ZTs8iJ+eI z%$pDkzT&&bmKmxMy?H^tKy%@%ZSL8Gvyeqz|LaA<8Nm)0ub!qwxuFzsu$XGGi7+lT z+zRsD_pC@k1BVh3U5y}1Gr@d2A6xpzY8==2o^)Ig`u2zpG-?9N>!E}r0n{C~&qVeH z?_=y1=a0g!1#P@absjNz6Zlji6s1jV=|;?cW!>3CV}S;lh?Gv;Vx!-KeC=|RRMne0 zG{M}o%~!9SOeFJLfHL8ABy?*8f{XWbp*{tG^qAX!>4kQT774N%pt^t3^^93%fj0e{1W^9*yIs8ju)PV^D!Xw0<0>6)~u%lw%UZh*_zKJmg%u_6A{ zo#*!sw2(0CR8#T;sx`P>@|LmMxQwd@IJfmFk`#HezQy1PIUBgS>Wsp{TTu@2#(lMdL*kDbvZ=zxuwH> z(Y-BB^jgO!g2fr_=%79|aH|vJ4`c-ksH}aNei_(@b_YEZLrN*>hQ1ZIUad#j*f3@H zXdiS|siIqgZ|kn*rSX14zZVq|0lfm~vf(pxEBldJZ~LL%*^`&Xy5B7b?oCTj!<$}& zZc%T-l4bP*UOcG$K_lTeeFFbBw>O~DfnfC-&Q|lZd4eqxBOYpbu_N=jKf=9%Zrdnz zQ8dn$y<1gTNcTLw>#3s8g{xST>{L9vs+}?YP3c&)Pmtb+Pdi*QlMJWA*Q?rmKls+j zH#I4i4j96g93!aKemCWZ6`(}DJjvofC?Pi}g$HB z^r!n-#P1{sRy6m4!?^A-Q*)izmTw)0H9pIPPHq&UHud`?mc?fqa&COoaynyTH|mta z=@QdXJ^N|&8th>8UFW&1Pv1G-@;U6UKH5{@0sRmY3k{QOVI?VhORkC%%O`MTqiU)C^OG==8gg{GS)L3s??BCu3s)d~589#QP;sV~Z3uN=93H@wljcWRI5$rz{-voA9a0}l9$x-(UXL(~)M~9S;k?lKcliNMBXF3kK1ItQD zP}t2tizQZ}M<3CdR=D?``7G{IMR^y#)j;J@ z6z@r~IkbFyeUFDve_C?3b&gquzBS7w0bszp>U%!1(&80Z*t8_{{g8cw_VT40Bry3E zL`szjX{tF4+C+Z`Cksb#?$!3P4yEr*5^ENF99)-Wxiqq(I}WvE0`t+EM5DC&9P4G- zqqgkA&?mlr@tslLjVHZs`3_vtAZx!y)Tzz*)}E~3Uim?DuM$Pntd;@eWE_aq%DP1D z9;1dkzH!|4H=`m(0KE)7APXaJgXrdmAq&BwgBvmztCYJqfr8SnkUF#F$M;P&CZ7&* z++-i!%oghQZud_ME065?u>sj;lCT*lGCtL!HHF|#Oj5Xa@F>Qyqn^Ii(0h4rmDF#4 z9HBOvRN0SSt@OvJp|odKB5hE{81w}KBP1qv{EOO?7Kb}vN|p66Az@Qt97z$49> zw)2F7A}F2Adq`7Q2T{La6r!zWH1<|W+z!9)t!5sC35QGDWzpRX;pZaOt~u&{^GVcm3@OM!mC6demP)QxZrw4J5WDsoR@{IlryYFMNivkw*ZxuA9Exds@z6 zLSSY&KyNPXERJL`oPZ9bqhe}I@H&avj8R9AKZ_g#mAf$J+s>^duv%P;&+JYsAf}T@ zlyiCbhTA`RY+SHi_fb?K5n11P$3Sww#xOLz61n@UM4+s_@6As|s9zmMuK_FCLNow& zkc~09$6rmvuoMIy91J`aMBOm#mw%8k^eM_ZSBayHWxu@TXqiNfkd`c*exOcgYI(%Ly@wdDANChw-zZ%P7AR9aY4cD;;LpD z?NjXX>+Fcxwa>oo-XQ<*zr7C*fnST_y6YzH|HpB&%RTwVKE$qs^=j5al3z~YSQ*kW zagUtsjUg7NL?mYcV)ud~AK2R2qF{`V^iK8nSA2itvfha~MvMLrh}9oxl$jOzwBZ=x zST7XE2R4km>kWo*Pu{xCEh&q0BEFt!7Dd7czN!0sha%Ht1S(c+0mu<3J4Agd^@E0Z z2vPT;niha|zoy=6j`+*45|sIF{(px(K4LGOFgvcK-xbB#y>#Y;o!eXP)c6mxtQ;1I3mgU z0ydgrw-w>Ge$ePv&s+p5VV${~h0Ug*Es`>ZgHY%VO#dd8+cxIn)k*Il z^a@brSy+_CiD*{ZVT)!LC#iw#pq4nL#bCQvP0&G9jPQ$^Zo2pGG2>u zh&AkZKPjptVP){e<0H4uX-a$w2&Q`~%d@dv&T$9ftMkEUehX2LtRPw^BR=%kOX5aN zSKUgh3(&9TAdH0sg-f54R^AlC+AR;ix(hp3%Q+QQ_Vn=W^?**z&%~OOuaoN&j~g8s zu6N41dZr{_V^%?&mTs`3YM%koA%CB=$lC$h6?jGG8e0MWR6w?;bPS2^D(RBZD}BiM z2R~?xK77#FCSTurJ4L{zP@JUN&ue&@#wudS*U@3+TOvJAUPZTDu9MFxsKj9Z!6BYK}BSTEv}&UaKS+1bMr(?HK1@YWnOH|DW z-k>Pof7j6Ja2vbm(97i;;@Xb!!*Hz^t&_`Mp-rMwu$`$%k`XIR=w;P}AXI7_`R0C6 z8A~`I0Py+7-_2BpE=_v9qhybZ4%SsBdEW68XpKT20!*3RO^h^yAlUgOKZ(bCf+-lCn9s%D z_A>60h+buKj2CV-aCvG+)s3+$Dpi0NWQ(xfcdZ}{=rOf+lIcho0~i7@qo7=Eh$yke zHxr#$Q^fK^-SW|&LI8?houecpH!z6mR!*S@Ap&;_Y`J^Yxa<$SaVkdEM0$T6Sa;3Y z51-MYbL~gX(n5m8f?n@Pge-3%eWBA81Te>q_Y1HYe3%yI<9*%Z@cEI;Y9fQF_)ASv zqYStRTUq$UP9gL9;{z_`OdpQ;>9Kge8yoDg9cAs3Ob)y~U1Mu(tS@)fPo48iqi0dm z=RVh?TLp;t&fEl}n>TJ(h`D{qTD1J`0Eav-b3rb1O@}&g~2N;LX0>U!^Wm5h8g^8jIgLc!=h?T3a zJvB$Qqy%o%LaG%6c;9d47l_p*EayFz+R8|r(s3CZtO5c#NP-MQCP4kZ{K<9C$hxp& zP{Rnpg_wd2?FS9Wp;?JWCsCeyqjKNUsRVLB{?wLY-x6CrF)+YS;TghYh!e+M0mJ*L z35`P!m}))k{A8xa_FZW)!||>mwAa|DRAmQ|MnPAeJDue@Mgf z;{tpAEcnPhgJlRT3jn)bM*-x59)c^_rPMu~9`Y%G6S%FVgpg17j}Z=|K`InLWdl5P zosrV2KMRtfRZkxMDXWv)5R9Z2=KKtFvy6nh&?fc%TnJ|AEIWWk5C zeBggLl!ZRxJ5Tnz%GuVi8^kOpwoWUJ*lW=ZWp^1m$P3s5-g%S~Vfx;9htrl~yx-rg zTGPZ|_+wQ2*XNHn*d&=WZ(b0U=eL)j1x8Gp?gBhs%{P|~35fkeim8kzm55dz8Xb!G z&4R@bY9Y@@z{Jbf9(()a6Wx1b1Fhc8##Gz*<_Kfn_%(LCC?K~}_o8-hvww>vLLb*y z#qkV*1(RiwQA$KIWmceQGe71aQF8>a&cVCg86u=a7g2=1!POC6=`f?td1J-h;AK1O zpWx-~9VEjLqqXj4)Yks7H5>^35!lrOreByEz)ip;tE1+S<*oe}pZ4#8n3TV;%yiW6 z7vIL>Xu0{lK4^)Uj=Qs!%+Bm!joxlzo5Vdo$jdv5Qgbcn-|nnJ-3IF6G7SQZ_!4#! z)5T(?7?Q)b==K`|Iw!5;qk;QgQTZ|2T!;J)Bz)?MXw5n3Z!eVV3L#P}$%s@pgJ90l z6R9r=@oEmd9g>~zQ$vBKzdRUaD7DoD z0n!z+(hY0$3rO6l+@LW17T3?C#ux3temKfS4FRVF(a2K!r~Wx0DNgeoz(o}k?=^P3 zfr~?)$k45pm%71z+sOK4=3Pv@8w{uaaU}j@+rw07fZ)RJlK?DGtm zWmnRRhIO50OiiB8eyNU>r$|p!F8t>GVnU$#bJ&MgORvUUN~f~-q?Sm&?s>Z0{aHmM zqeS@LZj3Hb1?%&rtTv-yKsnFm5^Z}UHT~(#jG=Up#++Ee;g;S0O zPv)+Uhi?ykvJe5sxdGOcHMmx~#|>6OyvN^@zmiXA9Z_Xg`Z}Uy(iS&(_1uw?)7Od= zF~)2y%siMoEFswH!*h^BwWo=Sjn}-L&5do3A%?F&8E-Yg5GH|S^mLLCVm8%5 zHXh0JFe>v_87*+i7~(i#ec#shwt%A3>ANxB$BZF9`)-~4bQEq{-9y&O^K`)0#zvqI z7osv3nRx3jr#{FTRdbT;Edcp)!tF^P&-J#b(>FNjlZ~+X{fFQ;h^aXXk$DpurHoqf zMU9*Jqb;NIb|-G;B3K^XktnZat4KDCljmq}LVDSVSUO9uXM}I6 z-toH=r-Q#8@fFXF)PWYhCG@Xd$#Na?h@^cR@Vdk^aIsi|j#~CG@ZI5Y&o+dG$!b z;SlOh0qd^p|JbgM_G^%n>?%#5zn;FR{96eK*;W6Zql5mcY-$~esRhBm9a%_bzoZ!ZAv!)t#1Di(=my&F9$p}}YO^%%{V^n%9r=r$-Mc*TDk;AtrrQpB}(qOUp}3!E&FGYj7qEJ^AsWLW3{4qRfG4ho}6@j zxJ@(oDeczL9?A7;4Nd*}jzdOvmzF9U{Hn|K0P41y!Jude?d!Keac^VU*&I|@MN*< z$&egQ5j|L(+ra$EV^ZvU{H&{ zJLGd?p_`j{vK?;fVnZI6M$WX7KZ`w!v(T(uSHzG%bT$JFi^1ZpTDMA}pfR9K;tD*8 zBOn%s5i)ExUaueNvEO4&Nyg7fvCXb^PHHHWHWmvGH*~SzSU9N}SJ1+%5MyaV8hmm* z@k!*s8|c;$CXD0J>NirZ!B%=gCpqRCkqhCFLtK*)%|6u3U)`eR$V&5sWON$OGHp+G zDm3E_JG?SISpVjRBUE#)mT~{sglZzw+Yfwsj+qx)CfMzfLC!o97m*x`+1&fvDgDDt zOkM@iLYQ2=ZC75(*0$KXqJ5N*l>tA;wR96RTdCj$T<6&dN0vSV#%Hft#BB*;pLqDco$0z`EQZzJ^zqA= z_yERu9Kn!YZRvfz5yn*WO3@)qm*xVP-l;9+j^mTaUT+ZHm3_Z2B^TbPsX2Vl5 zSGFy;A#@YbED~#^8`=34v*}%M72|4^H=LPgNMhB9h|!$H1`?l9r#MVWQ~fr4*l{5D zK~KrL5b{G~FGxng0c)k@g_YUJOFcX4!zpIzF=AhEINs6gE9*8I^J#Ds-OK4ADvj>X{p{jnj>?hn9!EX&n8`*rcn`Q7}5Bj}_Z=CS< z*f1d7`x+Rpbx_r2Kia?vUOvUyJc77^DVv|D=e@*Vr!_Li+Qj(1f3mc03l#L?d<9p3 z2PYRIDl2AnF57Xd#DrBVO16eb;ezD0M^=QGPcQmMDnPnvUs*3vy=zEi44|oE0wXx0UM}@jY z9SpL2@7baJJy>3$ba(f~ENqhOQ#apGBiP3vKfA`Bm?D{5M%66@{vFLdemLsY5IH^=v#jD&C~H=+%F-KBy`)L!D)Qkc8sS_zCQ&S3s332%hs2$ySuwP zd}nr-P!uQX&S#efrIz2{wjBs16hL-Gbha)|^RrI3LWEOb%DY1I|7z1QorU2l9vYHp zfqA)b$iKTDtiu;`k6kUrJ9DxYFWo3`E_L-I^#Ef41s>-AJc$VHC2M0h`|VCTU{85@ z$)SfNL3iL+*S`O9ZnkxXNy(n>L#&BE8>^r|`n@gui*U@%9xktook6M$y;X7_YtU^U z+E+QYE3f(d?_FslbIS895b#r=Tg!;LN|vMl8w*FX`(DcUwaZF&HD_<`ejEQwkVXq` zpx1RVhw=W zB`LeskO+1Szttz$DwG;5Bs)eYRgXBL_?GBl25R$BkkZ0HFJtauO-!BhHkz&a-1jD7 z06J}dpd%SQ^9(;5(r_Y0>}l7?^_-L9+~EXNh)b&rUa07fh?k!rPR@5VTu1u#^_SzngWA&4XjpJVPdtHs^~e(nO*58aCibj1TCl))*0I=#GsDw_qj{UJ8@>B zzEbqkWG9jl#)6jyCQ$uT8mx0NCP)%}?2#d_a{12n1Q&@oQNne#_tPJytms)7B-h9{ zRShom-QDXky%2*QXS#K6^8JRO1|Q0uU@GBd)u_v2ZH-&4qY6fgCj~TG+bprNaz079 z@11M0f61$p1Fzs@AVToZpj*;J92E|vmJ%EqKW3H_6|gV)xh!Pt*h0(>+nrJn`T{q= zs+8hbgW-2XwYXZ^ivp}h&hs_9leg%y$C}UzosKa>x^Z%Px)KLe6lDN5rvyypNK@Uyl%L+DWhoxE+OHi@!=Wn1hv!k%J^)TlNS7AL=sTPb5^K zc9>%*JVI3ph-EjUEf5bgi@#vcp_Z29e$eE&Aqa1gi0_Ez4WO%~J0pZeaFMT96Q7I^ zoT({}?#f3k8^g;)K5DRilIhIu3g1bur$Rv|U@;wW5%%(sdVT%Y`=0*WkE=|9=!5i=#U!l==mURP^o$9p@J@}VzHA)=9M<^)7GK(RBledT`BYmJ!9 zAt(Am=D}St*8|fbgYRts-C0M>-K!KA+gj2E`><_lVr=`^ICODrna8ej!wz#&9(mqpbp<5yA z1$g8U!12Mh%$)_-@EM{SN!njnTX6u0=s;Els85AZg~0l6L)r;UHSG0@8Uw?|F{CZT ztVODYC4hHMrCL&J#bW)MSa{7^WROTL-`pZ?Hu7Z=kCfJOh}Vtzz3 zfR|Of{e#9EfS)QD>%ad`0sJgr>i%bzW~A)#?;Q4oPDu`jCtQ}8)l6s7e`9)i>n&Mt zLx>HhPxx)-S?3#i>a+_1N@A00gyHe&2n-q4(o-U~z2A(&N)1wRh&~-mJ*gYpZzu}D zk5lW^wf?@l#1N`H!_FkB4@2aPK1wVCL`~QlO8{l?D_)EgiWr4{Ep30dY*Ua93QB{8a6sjj6+2NYRCrs4SxML=U zso3eH+e=x!aXw+C5<%e-U*A{X8)SDJ)kn&_qH*@WKv#-mh@8d;r^1usV>S)T)6Hv< zxdYp2HlJNX<6=awi@iElq{Uq%W59pXL$GX`=Hv+5C(o@W4~ye5nmrP7wC@(Frz&@N zkTz>x7mwevevss~r$YU180W~-A_1Foq6@72J`<~oR5xKBahR-fX923i3 zND6e;b#RcZTMh4*^UMs&e$gCuzwK>NhQO#;P}TSknq!`y$y#rTyv+%R_$jH?9to8} z{;`iy1)1)u4bI0OoelfMdf01Uzp9bUgDb4MAW@BV6u1%xxBJEm5lh*mLAuCrlJge5aS(1WVw}|6s7=?9^dRR|&Oowin*JHl z9;T(`wFD-|kMeXdsRjuFy~tZ^lquiNy%Y0YjXDN==@opX?)}YQFD$$l2y&26W#~xS z9K^EXmcwOYpKP@jLQf3g-E%0p8J-qOPnc`d?tGE~XYHjz-RvtJLcwnL??qqus?Ixn zYi&eV;*e>Ja{~4$lPgnW%!2Kj=fs1S@f}lq1%*wX$y~iJdcJ(CX>5vrq!)v6u^29! zFFcv)(jz>}M}1fuL`got!{6`d{rwIo@aGn11xaF?_NP3PjTmx^e$WhdVY2EAi3@gH z6i%#DzTQ{ctO37AYhsX=2l&oVj>vn|X~gO)4us~EZ6-c4*n!>l`GwCuiwV+O@lh^r z(-R{L?c3-!5>w^n`X1N_EOsM1^?bypo%_v{dN9#9X7m_ZU`y>N4ke4_t7n|$csbb{ z!fbZ-2~a<=hVShs#0a!fTLzQ2tlHLd)?Ff*KaND_NCTB+1f5j^1|$m+qHm<*xHLv| z&BXmx{DR^!x`&l+4(O4~$uUir$lM8x><2_Fc-M)ydu`M~5gggn?c?MlRfHk1wF>gvk!w5}4 zZ*~Rsl**6zQctv{)|LPpPC?17EEilWg(^fwycuW?|9%}!t)b?$KuQikaD`ymZ!#h% zdrYdZ^Vi0hTIP5UD^x~pdLDhDXIyiYHoPcjjdgWU)li85M{|3Jp?XC zlG3=Hws{7M3jilrPcz7-kZRx{q^|~zA|#0>-yBM{KrBW8DTu*t^wa?W+}J@0w6_x& z_KkbjP1JsX-65}V^=`;Qng7;~H?xSCK(T5~UdBy%bd_aiU;q5P`mPqw*vL$k?ytXs zae1C~fNrK&`wkUJPRvqT{;=elosr(Y>G`M;uBWxC5E&y7eBLtaGNw3ts$j&`Pi@3) z&8iW}?U|@dtbgt4qEm5rp7hlCK*q%ktM-qc!@<0cU;U+_1-@1=CbH+VdOW6X{oEJz zB@U~>1!eul(rH~r&Z|SC(W6;Jr=T2D$hlkLaZ_oeMLX6kKh6c87?Q3y*3^|<^ z0e__-?S34e4`5aP_rs5l3c8k;cTTGuGjf_t+a>f+&%2-; z0nMQikXh4d{VTmNl=PD(N2kbAdMUejC_{m(M2Kl{JU8<{ud<7W!^0;lkb6a+#?uf7pM9v7>zxvx z^4(vsg0hTl9Nhdi{TsCWyU6P+aMk|MVA%KHs61?-{`TSmOzhtu|C^$N(LY2VXuSFV zdJeqxcO8ekf5t>+xoWl7Gz2WP|_u>vP=|vURsCa#!erHwjLeiR7 zlYhkRUhO+H>2Q9N+I0uLiVcp!Ma`!SMdk4sdgG1_OwqdDAtvuX)&iV`cA!c|dNBr< zo51i2+z`gdVXNTjQerKF*fQ^1jupD}&i`!4EqX5*g+}dU;&& z=xr108|Fk_bqD#jc9){5XG?=R2+sx@)+}FrH(Qa+Y|dZ~7owqR9=OM=I!Tit+qRu9 z;)bf)q^GTU%)8C`1w5UU*A`H^RhC^~Tcv^UF}L$&eNxNUgj=Y*FCJBxWFG!rAvAQ4 zP!a~(I}h8KL6W#Ix#Wkqk&TihSoi_M;_9-ew2>>uhYy^JetGIXHzvUrx!{T-t?N>l z7uO$;2IPbmu#e#?c5;{}CQTWFQ()S@cJI#3n-soz12o?%Z#9iB49CUM958#D6A=3? z*1*2ds=6X!ZJPWg$36Btm|D2_Mcz9pZA+)7r+)y&yVf|YC2tl`d6{qDc#3mQK>JP9 z866F0g(`^J{vPMeSj`<60sYo6!uCYo_vUQLsqF4UU!Sci-ezhHWlV8yhMIJwXN%iD zG|O!=oG~i$Q|owT|8icUc^xcgU2PlXmE?iv^;_%1Rc-Ih;%~I+9J}UEo96Fr)ELZ;%z3nAMnSh%s^m2%Y?X3R21Xb`b_8KAznS?GIbQKuP=G zLLYshm(2-z$b0(u!{)9GNvYsaXvyUdYD)G{!eN-!OUo>M&J!Zm2b#tt25NROl)^CbY_AZWqOcDm!so%yt@%y&+ z7$^B*Qt<&ncEy*+8o}I1u z4p=i{y+$gmV7TJRnCSfG$+B?Z>h+%PyH^9I2l-10>1C}f z;)_QYH3>nDk|eCwLpJn*SF#l7>CJxlt47~Z6amplhmQI)D9xMvL-@#C=jUdd2cez1ujF_0(ySyMeSGKU87*dd}Va(2m#e;Gu4 zQ4XMfG(*fORkpHjvw}Y2C(s8d)ajvWazi%iz)~r(H8NxJ6l{{U$mcdp)+NBH4eD7q zLKZkh{`yD0!vL?~&Vwj1s?s4ujYXJ+OE$O7abO}7PO$NW1jYUYV}#mP&{M2nlnt|w z>NX!gV+31c=tnFDAY}Vd%}1G_KY0Q>yJDvI(ro=ZHS&7QcKB|w@}*FGD&*y&WQ(Em34R@D(n)U-!y|UhdY1$>FI6ra zyQ$*Fl9_RP7ztpVhoNMb9p2}nIaJurpDX$Gu zA>kI;MBqO1=C^^a#k}7nEwZw0q}UTLHJ&psB7E9vv^}5Mz2aQ={DoPfvUB*kyh}g% z%_sg{#OVJ;^Tz)F=bQMMeCLFlY1|}AabiuK;oitFw^lUsyb<%hC`(=?g9DI9uF?t?{z-v>GC4vciL zJy&3tW|Fl`%RoM9sWu6MUwbr>>zpfhHg$P%dd7Ft2P-^xciNQE=vAY6b^$^cDw;EP z>!yldi{E8%xBbha^MCls_Jugs494ahuvI8_;-X=IO4pLV`0()MzP%p}=;GR%gADeD zjnkPw@bjSdC6v`FQgQ${3>_%DATs*I#VgmAJ;9mg+pvZu!z*xVDARsW5%fIWB7cI0 z1Z04g9d_i%ckugxGnf2>z0~*ENcZBaPS$*%WfX5zM4vw0EPXv?kCVUc+yZ&|2hGzc z>*FUF>Q2&4jK(;eM1_(MpsdRc*<;`3cV(|-pZeZWH^hOu{ucFZ?tunYV{)4BM*WH`m7F@SurFhL%K+bhPy3)dZ`Drv&}QHKg}c=Ye{#xibLG<_6AqCbA4( zJ15kQS^Th*V_EW4)R~9hx!Q-O%iElS?ja5mjW(Bo!K6m5FkFGikLRJzw0?@APdkM< z_2Gy@$6H>j@UM-$5lnT)8K)v+sv&}w%od9{-pxn~vQ+Gt3HdZGC(wb)xu`$1TGZoa zcUmH$I{1S6g(YFgv|hc_UVpy(0aqY4e~EnkS3eC^tH3Fn3%CVPAPME&f|*KJGN4kcKq z$?0;gT3~*6qn4n}qZCh|N$iLK7|0^*ph+fH6{RI3hX zD3jPE?4D68gcH1SK9Lbr1tgCA1+wrED_NbCF`;`bOf77~N!Clx_rs*Km7HnY#&^*-`*y?`wU1}7L=nGCyeETHowqy`x^x=6bw*MpG$2+UDCkM2bG^!Z z2D&+q)6i*-<+tau$cyXHlJ$LQ`ZCaMkb@GQ6YN`W<8KuEv$Edzvh5t4B#;b{%Sbj5 zP{HaycLJ%9>t<-zpfXpac=}U?PkC&V!zW|O`Zssl)(YHL8$Y0jxaxN-uN_2mPbzel z*El6*=T7iuj|8OO@e)V0wjlsLa>aX$902;l{-7zYRNZUbu{gMJ*T}J0w)U!v1fJ^=$hIdYmF!RL;9fYNdJ~0_*1=5v z947yZ{P$lubvh)Wp=}p>@*+_>j7Hj~t7=OXmu1tqgzJpnEpzmPCH! zUlGvK6k>G({<%CoxE|jlGiRI9Ll(+is6A1Q-y29f-g6)N$a&%Nn;U9hb#!Is7cIVk zO1%7pVk)*ZRPEJ+98)su^sER>7Dzo7@11f_Q>?n=h`bXUMra6?MSRK|!M@^;Ngs2W`nV=)tjI4m&9$RE zVk2b|91Xy&@83n5x@LrvKYT$}^{~y7mdZu&elu19_h)Z2-5X}%>VSx@qhkh#tSTRp`^uEk`!#9Y(&i)$><_xQ{;}=M_se)McN| z9ABw;Lp+OG$KjaE%W@J2smrt|h~lMy@Y2xWfkT7ZjJphm(mYt*DER$vGTyA1efaReKsF zXu#J*vwETFt0uD-3ULF~9oz?3<)xo{Nt5Q+11T@{v&DQzGgd3-nZeKES&H6SX!|-@ z@h=Ls^J-!(%cA`EhDmqmTpFz)N}b5<&DAdGqb>sQZeYUQK#=f%vG?ZTP`_>exKe3j z+H6^-BFdU%XQn7ZlET<0Nn%2heV7okP6$~tNeD5?z7E+U`&wk*_nEQHFf*U;<$gZ* zeb?u?@8|hG$M1U_zvFlOzJK_mqcPsgbzSG{e4VfJb$;jzrI_b1L@|+*$X&TBIm)i4 zx<%V&QN}+rb;e-kZ=*{{4Wf%rd(oO33K|jJn5Pj16CdeK@kTv`uJv8{~*sn9Rp z*%LS>Op`0{QX0|#b%|D_+d~}r_}?bay-gVOM=sAW~f_)Fss`r$u>U+orf?*w=t#9>%=9n=xeh4k*ro1=TiQ-N^oQ_c7?2 zZ-p1s=A-YF8Y&z>Ik0>h!~vf_$7clcxw(<`c|w!OXOZ^pJIC3Q`e~0eI4$peKFn%S z5lc0euIyRIvhy79clUZY*uUh2n=|X*C=Mp%$pU+9bd@8rP@|v!b?$0|Nh>a>Lh{1` zVl2G!z3Ir^Z&s%sAK}T(U|VVX zXK`^PHJzEoV63t%bn>6G$>(MC72@=&LFmSFl6Y&?n;COyq@#bE!H;+ZT!BlT zj+*0XMb^WKujg9O@Ia2$yWY7lkM(+Y$yz8?0m+W3{D7IA!l&JQTkNCw`MJ=D2yhFS zk(bB69@9LlY73-h1nCv@we(?$X!G-TXu&S5R`u!)kB9jOWchviL3laDmN~AgsQeI9 z2cMb-yT(i6A4$k=$U4ZVD&TLt zGg;p^zQJI1%q@iUc7d>Ynq|B935I$ASR<7RlXLAHI(1uCa`Ac@G{T42WI>Y5qQz4f zpOSNg!^9HBQ1V3~_QVC>`yI00D^$<0EiZ)Q-#rEO(K;aG9U;Q!%w6se@#MGGcI(xV z@$j#sG~sN@8y8mO17CvpK1^3C7k#9w{LmZlY6fnhRebYGx6OQTTA?+-ROx6Re7G3Y ztz8EGZoHQpK;(S_YJp@r)QqH48Qzoka97c4v+HR@L5#DA+QBXTlPt972zp=c{^DhUA4vvTRRrQ;b8Q4kp-pS^nz(C2WAu`$Q0SafQWC$Seu4yk z+{yd#ohv`DD}zTpGH0$ELYUZbZ^Kgvk*OdFBhlg88Pu*EQr!i@mrU|SnK#b$c|VrT z_kVGN#5)SK`^d9v-|9o(5eh#Lt*X4F6#nymGAw2g_(1+gy%C@xD$^hcsPGsY;{{1I z^hL8zJ`XJ7Ke~S~Qst~23>+q)!Z2DbBsTnbU?5cb@{d(0csmOQ0C%^VMLJB z!et+*fnPU%%g_In+z3ZuyXVX{{ai37@PT5g?ikN8qE9~;Lv{nN>(5Y51(bso6`}OY zbg9>ikK4zs;6-QzfRZZ=oklQ^vIA2fq&*8?6~Zl!8t_cU8koVU(djz*p)CQ9q*ngr;aN9fY(5R zTGO8u5s271)eG=|Yaj12&uXYDoopu>_ETy}LfU2?2NT4eOk@o74Qo8c43I4#Xse?0 zMHmRpV~nG`yNZ_*@#ULpSxM})sEY2*1j-MUTJE!^9%m|wgpN(c-+S!JaoS8hn`@wp z!DlvW!P1Xbd5?;29EThc@7X;0^94U}D!euoiWg(Qb>^VoIp2eM4KLXCJ!FQ~nckV$ z^i7_4kyy(RjJW9O{r2aa*jr4ImGS+t=*jyOq75`#heazAgnhy9rv01HAXJ z<&J|wn^dbKC9e871$((u9EXS1oD8qSDLGgtGBRkXl~hnNkSltc|JTbP!|KAea&GzfOZO6lj9g}-NR z_ThC$7?6s?t3)ISnZt8s*RUbhFHCKVpL}n1gMWYY!zOw9M{1IY>1%mGxG|Ie&a_bvwVR(5qZ2^X?ViJ0Bt*@^e486R3#WZ8`3mI$lP8uy_)y%;+IO>AT?M2XJBrEN%RftL~z1*>&*o*V{hq zUHY5nnw-v*h`8}X_!i<)W=Bk+vBl35o0_;_R!)>&yFwr#`THn>tU}73O`8oFcmRiu zaC7n31_Aj^>Of`UC|~P(ahRmVm;R+EGS8o-llVDXU(ARvv#?LNF}<*N9+n2FUZE$U zdg(#qRZ_Mt*g7_jG9U8_q7@tbWA4gHF$;dco1RUGxESr-xms~Y>&`dP29f3VRpoYM zf_Avu?UEQ(h?59DGNcMbN!Qp)^LD8-H-w%>q zlW&T)E1zk>ffAwq=)?^u2%&V*ZlA+QV_|T*a~oIIq2#-l?~}hh+rF@GskH6l<@QYW z4UZH31!)-SrX*cNW181gvrIk$yi-Y#i_aao1h>v^HSC(H3%b$G7&aDn!z$bx^p}7U zXJpp;DnYnRFB*CKsKJR5!`&`0P~}RG%m{XY9vkGKFGiY_h4CAlyxgcFFD5#_ax8l0 zvdW1+m=wVa!w_S1X8^vCCS#Ah1Jv%x;YeUpf0Q^<&y%B$3|Ca%G(_5z7(d#mofdVd zuW4+N=T=4xE{#Kuz4sI9&{vpE-w?yqvXeOL!Kog)a2KSur2821DST1YfaJQDN3ZOY zRPJ^)qsLP&yy~E=JniDtaP-^WxDEZ0Su`F@VMmk@^_@m+3HPx}PJE0fO>O-X%n5~K zgLFH15R5N`PqH=Gg!Qw>YkR%X~9@S-^On>%y79={c)x|Hv=ti=V-k1qqV>NGUbvNFLUq&`{cYfF%*-s zUKF2ANwSr*D`T6goup?;VdJ{lu{K$hm)uS%Gn;^7mz3J_7VRj&Q<;b9(49#D@jei= z8i_Ns0*SXxNNPI!rm}*p`cCHd+w1J=m-o;zD!na4`3~Mmy=mIKL4_t;Q*}(%@?^$P zck(f?#tI4sK~f#Oe*E&V-1@e@)u#=77!8s&?}SDZdX{qbY+`JCbfc%wF$eGB+2UoR zFRWl#|55WsGNWa{7GLlHKG5ruQ)k_OY6xZBU7~>p;SFc5Q8Z4)tS&KbY-{L}Fd+`_ zP2Gq@Teh3mdhXu!I*L3?dtBbL43nbxcnv~6{}eMcy|-$pQ*EXIy)b&8gtUAJ9PzQU zBh6L1l&$6wUbOs)$&ZFi+t7kihI-+6R)o=sk=3g+#H}#PM9piMZ>2sOg)}*kfbrg; zDnb)iTX&umo`lu22+myK7m+=_kGRoUj31bQca=hRBV_`T^3Mbb8OO2E!R2bLxBh85 zjnoZ_+OUT*NBA2|NNgwr zv@kOUFon6sXYLNKE42tpI`=V;^7ZzcbL~q?l6TNJvFsPAFW#u6 z@DS9>0jdW1+kuFr##)M{qb$q7#0<_1q2q*i&)qMW$T&Ju@_y9~>i8x7x0-sW9ylHU zL${YIm#VB~N7hW1;>a0SdmK&=YPyS=*hwoqCI52q6nK9ls71TO=`Avb#55p^ zhgv--T1lBBrAyP7%m%dIJxco}Y4{6bVxQwbFcvp}55{~h#8Dz$0#+F@P+q#I-6{aI z9b#$^dGy7M;C5j%lUoMmpKnHg+^z;`McJWW5`~>Y{r7}FnCy^}U#C}OFvs7pu`BN5 zJkUOn$}sw9E@wbGM|hSSP_jI?@u6`If2|6@;62c*AI1NpB4X4_nIJ1Z29XVyEg>0< zzPg%d-LWgi6Re6Q%3tX&1dgGKAS`RrvK?zjNTZk$cvsn6zk}JwRnf5%mTfr4x^ciI zP*PH#YrcIn35+IsY*bXTCBKw`N$q;Q8xMKN_TUD&8kZR{BY&3@#*G3Pyng|@VhW;c z*)4`xNH${yVyNv5aO|KKBZh+i`wh7N|#<}u>9n`$P~JcoYwR%JW>y3+-7JZGt`mDx#Q zVo5K|c`Y~b*YDB&hceM6ragedwpdHSInA7$&Kvi)3L-K(jGQ-0+Sjs89^{RuSnfwc z9C$qGxgftw0ZM0g$Y$eucEX?f)&Lq}796W?78%UrtX6ePj>nqbKBBIIkgb3%@eaC)CFhjNK)i_`k0I8-(1om+!5{+@ z(1zjO-d_O!dB!i|#Rbu~V+xSp;m=>d)57dqHnm@GTjTc)v*W1~)hQ{XUFmj9F&Ve9 zyQr^RBjNM%xAq;Ht*`HQvBN*i&;O>C{H1||O?un09z-#5%;Y3;%na@)gFL3;GVbe` zLaiwbHYRjgndjbp{+X;W;NI@JQMUS*P6k%w|Kpo@`75Yv16Irb1pq7t(MaUCf#)Z7 zEn$8~ontQM2w;`?x+tJv(rM#EyXBHax7ohJh74Kwj~%_C(}d3n&v;!D^sTK)V~9F# zXN7PngB3C|XY!YKFw_N~n z9~5qZa@q7c_&^^%*r;9n)OhJKJON+w8p+HYSuHc_IVmTKdfzH0BtP8hDtD_|^HlpZ zgui!frBR+A5_s}p)x~~}*2a|L*b?|jOw(baBi$84KWIm5d_k`R`-wI$ z3=oS=R;<~wKK4KtYt_xPYLt9}=o6hyVX3PUtQ>YIXnn_;6Trf)z_*$}bb8I*eq@Zl ziT29DJJ2}scA8(fX8CrTqMI{nT{$=89R9?XuNz;A*KXYz3leen^j3deS?G5Y@4VLd zz36%@Q`u@PejnZqly$U4AzZrlE<#O`IV~4*5BwbSk8IS9(VVVKh9(#d=u~C2IK{|` zE#Yxv#_@(%&+zfep7uso!#Oe9;JeMswbLR? zja)+=qmYIQVT-ect=C>pCdLN~(^t)D2v*?+Q`^p_RhJhB?)ngIeuW>D2hv(BKXd+$js?+dluzCE!SrjOahSBaVi z!b79p3b(ZIGmNhNr$c6Sd9HHFvYTJFZ@Ds7dJ2-}>4+UlrL2#+-b^=F_F;W|h0cmD zHaAmC?00zTQ8{cY85;!IOPpp^PZo|Io_vmXD}Pc>J32ziuj42y4pNp%ufa-d7e^^v zc=DuW|JUk9UC-^0^Hl=mR~XARVTOg(8h-Yq< zxxeubKj?1akbk}$#4p`P~BRmi_PiLvyf&_~>B&t?Mtly0~eL5_< z;1k+r!QUd|KwnU=0MGnSEOVWixbBijw4SRLr#GkSbI*(yG;rtyL!EX|6sY}xPMGoQ z>k|SGjg-CuhIw9Hal8uMrL2>yOIr5cKYtvS5H@PM0=p%;&u8T(4=%?OFnVI)A~WOf zhncl)119e{pC6o}eGpu4*R@;!GI>7A|IPig#s$5jC! z+T@)FSe^to#`D5;jbG#(p{FoQ-|-|s=FdC%fkwWMEOc;v0&aHz5Bx}ob!;hhhX=-_ zIU@&#g&1{%^Z+uo(x(?{=6SDD8{|7LpYF_;7ypCl0%Y5sF1p|Z=ln{WU?xhE&ZtO( z5cKJ}WKc0{~DMNSYlydqk(Y6dP93E|0hDj&K=_2GE*ivUg?xON%%UCi*-8%%We#9Kn_GX!Cq_ zYsL>g*Z5pYktx@tH*1lWtoi!@AN|tK1^^ekXyV!baO2(l9@_PiV7+*CYNp2YCgdbAy{S)jBN-U3m z4sA;WY^=h9h07B)+mq#N8!NO0lN_(>sfnzdi*^@trt#y>Z;b*wWIq^v^6BMnZ(7Ti zGa-bq`Gq_^EYU-)>VX%=X{5_Uh&oj=DybC{9m~g(*_Qe%m{dt!Z zhDhmG8rP+)htdJWLsgaAv8FcG)GOxQ7*&?eanD>ge}G(MU0mAbxb~q@2%h2D`JWh1 zL&86mqIUo2!surlm;EBS5is9;1O*c%TQvqbld)8<9ny2XC}-RWuT=|H)QIFj>e_m} zr&4o=ZNq)0&l-`*Ne9P|>m@5SRNktxnknY(xhHb=o3j}^x8}{^0M`EvF4@m`c;;63 zOf*lDkW)9^GIs#awmjLAmuGiDH*YdoxR-X@rqj(G;~@F2am&m`m?a`cHZc*`BIytl8PZK0X0`j& z3>NFqbQ+_lM6iIri)$_qnNj}JPHIwAu^Z@wMp$)jz6%B9OcBjMP^rJR;^HJGt+4}} z<_tRkF1v~=@4y&+A*g2C%!)ENPzn2&-7PsPhu?o3;Qse%&55SJUff?m%6|fiH~)80 zAyYr-*!Zn;b*FrI7jb5MxpMs565gpH&p))>*KX)1SF;pIUFmiZBe(DYOT*g#HdH3xl1r zv3-jVPEPasgQ)?CqXVvyRvAIMv`kQhA(d@yg;3imK@`auKo$As&%wt}O<Q;}q!z29iBm5Ok4zltS=BD8=H5=GC!4x%x4?{9LDyx0?HjZZrKs7N?iH3KRt#*Zz zqN}ZFmq+_cl_ugvfn|hwWJLl5&)w|X)YRk5} zXOx&9-5$DF+gG(-(szlEW8!A#O5q_dr9|?*d_UXBvn3Dv#qvatwCO)l^}9UT?}xrf ze+i_M*ysELNKSXoFM0zDQu~Y?awQ}r`S)6escpmy9!+v8z5M8yv>&P)ZBC3Ho53m@ zjA2hw@-?*SsG5A0hM31)T7d(`vaw2fy~tO$nFE=^!cH<#LMIL@a{6IL-x3b{oOWN9 zf7eks0%?E$-D1d|EWU1Z|IMp?4oBl3T)T}4xv3QCxG|V`qnB}>(azu|PP`vh)a5iW zAN+_^Cfa@lLmuUk@%PT^MLgcGzECi+^as-m!r_X&BvkEfirTnr=)UszEzQ&1HG9oH zl^p^sJa>=3%i69zZA(1M0}8aVOd^^ zy3Q5L6c>Ny&8u@wZ{6gTV@rUR#ktmV2| z?MOc_c-(k-hU$0WuH&mbp=T<~Atc9LW_Bj|c6|I#}NQU8edM%_cr`VG0W>nmgmz z*D>_N`UVu^A}A$tU850s6f=qL-)YQnr0?(7CQ3528iOzk`5-|6uHoalU!X@|X?XWm z2YZe>ZiylXjNC(R0UlYvDOZ|lR1>J8{CpsFlL`Bbv|iUcUs`@35{ z+`Fha+<=&R-%Um*LO%d?fe!`47jN?KrIvF8|Nj4`CVVb=v)!S|Eig2s;el%e4tH3$ zdf(CHf$J$&1xt5t?Y`iJFivxK=&n`VhfAN!GHl*+o`3r1&wQz-tpw8(SG%58hn(QK z4q@hf#oh(%mA~?To9}Jn6+1C?dw4wuBbxkb^)V_Rzx5dl7AWa%#VyR- zJciEhh-8V^V4Sc4^CCwk+xtH&0DAu{JILr#K^akC8VnyJD8f+(*GK$Lj%>iOIq#%a z%(dhSgv*vs*FZydtqE0xl;Z-Us4!RulDoAB1wterp|6}=7QvP9#JqpngPXp;r@HTX zED00Yx_~3|^uQO!kzhEP`M54sWGjOn2kA4O~~*H?w5;#cTh!#i!I*bY##$uZZ4J_Q7QRu=!>^H?uA}MDqEpRx*E* zhS({|3x@+Y0E@u`6+=j{)0lWF(FU zQ`$&-3_cb)i90OgltL+#k-Kw7&BmuCmHV#vS^1_-_9a~Y)!N13v&G^~Uc!27$Nexd z=)o~6zY^7$?La{#)Zg?vWEM37fAC{t-gBsEZ2PMv+D(Z&5X0KrTWc^-*NN5sMha18 zD1FKBt=Dmw&?P2fGFvIk4DDii)jWJl!fv}*-P-)-u}0?Uh2?aDl*^%DNSK2|!=6C; zNW?Sh`~E%?$#AreT%#c|tT?{Q%DD(C0h8MAw7N91S-N1ca_GUQ)YhAt3o?$|UJNd( zk19v;V&%9EaE3)?y6sFB1v2Zs-aJIVc_ZaRX-y7O2LkjX z8ME#4qBD{MrVOpL5xUn-jfQ)7V7amLXK#_A6w%aOE zKF695%w>K9d%CNVcE7td)p19LEj>YQ)_J@*2w0oWo$t>0v!rPG6YQjVXfGz_(cEp} zjYs0i>^CuNh3)ZHxVl4PH@Lau_IRDcr)AN)gKz?FzZVPZcK#k~BmaL6A`94p!|A?;tdL(LE-p(*pm`=rc zH{P1AR~5|>vB|dk4?~jCQ-AiA^%XYcUFy@cFL(ZAdHR%dI~Pr~k)}fusr_Pct_!W? zVk1bl6FT!*lqs`HCuyde8*eJrcsn3N&N^cpJ+1w!Kf6`ym7_Sai9aJ96qMXP zKTx;b{^x&XC1lV%zGJ{_`VVzV3Z5rNqnrY^Kc~lmUA-SZoGqu}kOpRwI8j4#(!k3A zP9j5jUxy`6;dp8GE~6^P zYK;ntv@;S$FB>%7Viu5fUm^ESNL@$^a1s8jZtSoxU05popcXm0dszts4!?itF;dY! z=JS55HEf{ngMin97xxd1(>FftJLPq_|D7G;=oQ;@kwy8bdpG5C=UtAczb|@77U_J2 zYEB%IB^;0M-ha}X;ZD^c%W(LDs|DwUOXefY!yqZzZml?!)Jf8!e9PjNFE|5j4yjs&wwPFkD>1mi9Sg;A&l z4Baow0?Y}SLhRIN_?hV>x?K!N0sd5p=I89FL|OmVa+d#Rmz`)GVIF(~C8c_;S&zcq{AuF9H{4UES!#zyE*9qES= z-TJ$H$Ke!`?$#_k@?#Kp&!w?Hm~M>0V^9<%t z_v2F=zTEM*K#}1J3^|I18Zo`_ZW(*=k9Xw(!`Fsa;vh}m@?8ad>)pQ%m|ugw5tE3d z5u|Rjv>mYUt`_thuS%q+S=J>(OSmV^&#!TjztYHd$dI=Xe;DT}mnd)qYTm~;=9{HU z9u0Qtydr$R{5|Hz+}uIHFaJ5gDJlo8g5NtKtwjx zqggfwU`e!X(=1REK=+5g_|kmDLHj+jDRusW$8!+_z4W!pt91Ryn$BzN)BJOU4EKm_ zosL2bq&l*&Dt`l`@cP)ohMPqB{!@pe+pbKci(;y~sIopH_Jh)?HvkIobl(T(9WOb# zh!&Yozy!xc9Gbfiao$LDrUkfJifV!hS;CH`gw|w3t02EL2?LUrICX)YdUlh4Vf&o4 zWLV)#2=qtY#7}KL?QON;R<``Hxr)9pxuiDzNZgwPZ6d{sS=UKEMH==ppe+2z1+oO{ z0-&{+?p(sKMOt&qyqp7vx`W2Fbb4BSD|`yMrKSLRak6%phVqy>RP^nZ-_iWC`cn&0 zJ`SaGNEh45(1j=P8KAXNq^*;JXfZ>&j4vS3NzGmp`KL{oMaCihC-@hRRDvNO?_Qj+ zIxKWh3PdT)JJ=qQMxBaRBkV{S){LBw!aVHUAANV%cqwbLrpi9}$%%F02SlGhu1WFw zS^nBhVQNCFpP9Ss7O(eqxncWuUK>2eMrEocHqxkkdOzRLj#=12;B~8{!3UQ{Cf}2a z`cRssRFs1OaedeJO^8^=HL_kvn3d=t<7W|0)~9d|^siQK!5O~DBfuaf3v7S0XGH6~ z5&;JD8K3aA*dPO`J<2)V7Hb1wx8#0cpABK~< zkpgHJfJK&4y~W^bPWYTPUFrdlWbX#hQ3oAP0(I?(1m&m;1_qx5qn|Oonv6De(%269 zgXyG#`!39v0>VNM)`yru^9RN3TrvpZ21_P?_9j>8evyc>)e(HeZDM=<)Z`IUKUOb` z4_CTv1j_n*EMeSG=GS;19G=QcpjE?{iq~>8z-hZPQsAEju8b~j)0;uIZTsLyP-Z~P zq>wkl40?OhknCvi&ZA>Ui0y$OMuf?>6ka^a=~3IBoKBxY67IMD2C(8>h=d6`BzyF0 zp?`}b@Grzod+PvTE0CTDVstk~%A6w=9K@W^WjS)LD;O%LBgd#rXWT1ZM^agI8B5+O zx&gSo8WY&e>MQ?YE|wb(AdEFlgW+p730;@~XFW!7_P#vnoN%wkv|mM*61sEE2p0^< zXhx7Z=#~0}u9=p@kPr017l!^|dh?5!7z+0bSXg)qWI~aiwe;_VHD zaU}1YyLVcw)()HrMdsa z^vf~SKUK&uaxoxWn>aJXG}9O>aQ&z2&A?HP2jrYZm@V$ytl!37kl>IOBS`p`bXKm$ zKfT2I&CRgCgzQ}Z=a-lhxBcgVxtKXCefE=o`}vCdhi&*!$sjkUVRRXD43{)uF-e53 zU>D5v-}ef2&X1o;d!gs`@XE&)iA%3%m>o9mU~l-Fc>Y-y&BKEi$_&;96T{%oFm5+j z-b09dV{IpE{v&1_xqbxnvtVNE!DIM)$UgqhI6wFgVZVVICF;zM&Zq6S&~5$1=P#n3 zoi^>ApNpr#s{eOxWI3LH8%2HYT?o^;^z!!)yRNHd`Dv;!@#&yfN5dgTz^EfzH<2htQuei-9;&z9#kVX`lsV` z6YEZf44BUwk`5qmevUYhD-`dR+JF~+bt*jR2wtR!E9?kF=UtU62^q9f#hx29jys6j z0=dCQTThpzz%&djjde}&&7`3NVo6s(O$_Fi6v%Onda&%3Fu(seBxt}e!^tX^_Q7bgP`TUO-+qFx)emT ziFl7KR6F29%!xr|X0Xv|(O~9Kzo{-=IJOiM5*s;4<0NVpU*Wx{q4yO51Pbnw@SYbCzhp_Aqoja@x9fWwXdo(wf;&+JI*==r_RMxRo%}8;x|`eR z@|8XOS6DuLblg`bK*!OpPQWXZ$FMzu0@Q<~Y*3?nM7fS+npEga_wfjs8GZ1oAqsGJ`Rj7 zj=vNhQS$Bh?FBa|s&_>~>KVBT_H5MW)f;VRnacio{fDbh&WkIUV0>p|EiU#{phz0E zCc}8%3oUU)Qup>+n`g=^)JpTNE?DTABYZKFuHeX`KL@5>Px?g>FLFd)-bpOCXIyV* z3CvK>&d@Xwkm_rqsITjTAe~Sm(9hx#fp*O@1EJfJWSCn7k(kUTp|z2O;lOVu!K;5HOLE`*;52PymbrBo zI;UR+rEKdFIs9|XZS~YHZMt!27P;I@F_1U_7nMZ2tr&p;Trv*}bdU3?TZ^5}udq+e z6m?j_c9mGi#||wu4!I$eH!oK^0o0IdXEb*^O-q01eG;(Iida;?b6EXMg{V3yDr0r( z6^_auoJN_@^B)(3T?n=`P0YOY^}-F*%sJx=(#aQ{ni?y;?jP4+S<-Ru;(cH%wGs-L zg40M5BVZ?64D4j-`w`u%n5IrZLG*hN*21&9nx&*)Spnj1Y7yGHYDSWeX4X1}J5#)^ zVnD{8B}lh36wE5%$E*_uRHIGn9n{X(AebMY+KOQHjwia(T~zqRuM5XY9-TUP)0G-v z%DvHw)**dkXu(ndbOZX>PHm*sPIrH#wl+x6tpOr3tV#{EgD;6)^+c}Q%+Nek;osfQ zuTRm&+`TUIgy)6Ct_$*y6-TE_7qC>$Us_Qr={_D=O^^_*h@F8ul!Z{1Z}M95&O3pC2JBnt~^%0I<3U+4kOCZ$_)mkWhci zmYBM^DJwFn;?d803|)JyT3)5`rOZXYm-m9(yr1Y&`ePPgFydxlAzewem;YKB;UVqq z2jiQY-!FPOTli)a?d@3i1q3JBmsE@m=07R0{Sq^~YT@ef`eq=%%<0N5ny%yX&vr%? z_dn=YCOF>QC|z$2;_RdR!IT_Kt9LX{au4JMNs*B(ltc#s*c=!L4ON1N?g!066l#xA zQ>RGKI?z6Bobu0z4xq@sUc$l;!t31rIyh+*J_7BlxjlYE~#eTKVpn0n?Kq{*^Tet#v0vh~{ za}kcT4%uFy&9`N>xE$~m94py)cdUR*A1G_QqgB{Scf4aV|K=TY=epor$ML2b#)Hpj zZALC68>6)*+Ps#i{D@Zp^Qk2#Q@xX8}ZYeVG~XE zrVl#N8rRW;5OsXqkGTmCN7{ zrrSt{{s-n+5kT`&n(xLTQM_^!fp>yrM-M^yK?nYbo8W>{R^-@;k6UpzPE z(Nn2bo)rGxg_vni!WZ3r^IL_}1@MAP-*Pnlhxw%StGc76dR19sl{DX}N84plUyl^3 zE)Ui-IVbO_EF@6ENz(P17bW6RGJ#PxVtaIi5*@mJ7A?OHU%qkExwc{6si^X^fQhwL#m;QgwvSP{lc&jJiOzvP zHO1fj>{rXM-AK5JL#5f0sD?#Of>J`-_7O4%Q^hppnGQ$kFnP$;5~%0l?C?eL4A6v^ z-c}MWTwTRSshmJ^pm0ZH)|Ie_=E(r~V{w8PzMJM$99#KuNQmliHQCnV5IQWR{IbFp ztq+T&X^xzqrz3kXx-8MF(OJRqB&sCjZd`DD); z>WzizfwQ*t0h0*5t9YKTCfBrU#Ga?zOL=>C_tfgzH7CY#qHrMkgrs0c&g8b?yXH*< z0&(2J{O5^1!ii3g$y)QzUWS<3?N)szpZn#Rd^_cu3uFs2f*H|ffB~ZdwVTnP6*fK3 zg`@SJzLF8T{_b-QWfv74FY~_^ZTI160-PWK^HqQ$-f@2$etZw?JpW?cW*1-!cKTXD zI=H_cU_us0_(~eM2lyBl0i*>>GxRmWpLRcAq<|*n%3%BU9Og?Eic(33A{9hS#F{h6{%Jea9Rv*mp2leSM2OQ9QHvBL2o^|@TWY_iLRxB32AlNeBs{~ z?_C(Qvhr^%M5hu>x=<^@TdW_CRj`q-yESgU61%AheTFZn036X7UEt=!2WnRDOsZ7r zILsuTf?03|j>+{X%_kq%-mu6A#kAldVf%27`T=W*Hy+A zl9*c+LN{W#DQ_APd&zU%Qhh5tJnmwIVgBA3#dKfyON04z{8OjHZyGms3%!zyff+^< zVaIqWSJdLNvfzED-Bz!UkE9|JL&K-YDnE-d#FgAIDRn<6KC9c3I`BbGvw<&w^O)p) znwzLcY@PHpNMyBJ^xIWq*)WFFGKEw3+TzydSQyVlb9yn85bHHv`tB6(d?0HMpg)D( zpvm~0R^2~^Oqf6mzXyFMyIqE~M1!NwCU}$ZTc1_yF8ZWGV?2?+ zm4*XGkyKeIJEp&Gr1>P}@O_2+lZ5pphQ^|?wFYmjKXR1;Pf!lM(vdSbx&r#Y1g(?I z)pl9PzS4M(_4#~~w6*0{9s9D0C{Xvs`j}{wvk=$)eb-Y(FuUJkJ*gp8aYu+w`xj0c1*+Vffd;spYz(}#00}f<3AR*6wQ5a|YM&O~r57Tys~ak`Viw(| z_}Fq*-z}3*0U)f)1NasZ0498 zJlrsw^AX-nI;A;3kL7#9%#+GNB`cmwagJm;G$<4{oIIg5eQ|Rmi`z5TS}$^0G$A*; znm|@^$1uLg;b6bQaUXUWW<0w|tucEW~z}VwvXi<`Vv-!7!^?ZF+sA;|~+wFYz^G^3FZ-<@Q&)k{`8?<*a9a+RMH$(PBAkTOW;} z`tNozk3%kv4^D@ECPyiI*1WN~7p70Wa~P|}vm1W2-m|~vP^b}F$1^VBW!D5s7G@i1 zwI}%aq1aJi!p%OR)e>~>@na4C=XV*e2F9zc%n#;e#Vjs+##bSDE%?*d2Ze@gP!O` zR&UNRaZ}xX_`?Nrgw|_alrs6mvZD=dKJG`{hGwq*)V-2rz*nFz zl1)o>{5k`Fh*A*qIXh(f&G~HiA$K+^o6(e~)0{;R5oMm=FJ?t7%%QJW4!zok;JU)r z(VSs<*fg8%Eh3505eI!pdE?{8ah_K@x}7*_HMmEQ$Z62EhT$AGo8Nr-{Y&4mQy!B> zukvEgoZvWR!;}OrLjx*eT&b;cRtPtY-VhZI_u)xtNqlvhBWL*Q7IfSv9>9eC>iI6t zS!TMQe@>Z&zuzsdj1VA7>y%dz1(Z#~Hah8_9fumKB~&xy*YWC5S0J7?r|j4#8c#?lt3ec^^fB9dRXAUedMxh`qr2qw4jo_mwS9!YcM^W^ zh2wy>xwXQ+1%p%At@++H{#3vxhhX}-32S%p31h|ROQ6~)uuPejZM*|g3jq6WV`&1L zxHYaaS|bH8OI(l67Lb7DZc5a$@wfN#ZURrHpU5*@Huj<|Nyxe_=nG}L6RtJGCc)T=B83d&95W zu>c=vBHq7|*4nWB&3a&|y0s!h*{Z92iCL{Ov3i)JRI##SwP4U@u*A~&7%!V5LV8;V zN-0!4vR(yU0v#K_>Pp}I zwk`@-u()=VkW6{~uC0CiV0a(%MsnA3!3yh~kiEKBUDP7&^UKR;WNWxqkOuCK3>E7? zm@r!hsfinfufvb47BdHyL}Eu@19lH`tlZi~1+MQ&JH$iU3q|Zq%)tZs;?R)054TmH zM2g|(?zSFF-$a*Dr@=J0G(1(Yka7tcWOOt8qv$`r6n^C_fi5JT^mcBLV-t zxOp<3v%SNPl5v(4M=v<;{hK) z?(gl#kuTusPeZBV9`P0z-_HYvxyz8;v6;VE6fOp1{UWlQ=-9eyG9xvPVMU5&MvoIy5)+4IqNwC6qp`qoF#a^rFL%u>Z&HRkmlJ$Aipt9zNj{`-K_m24U z(J-=}qbI5U8>q;jx)t`Tumv?~Xx7hj(6%z^8&AhVYgHkd3c_R!2xLls-unAiq1B42 zlvD4V>q>E)^IW)Fty_~-bk49`o4l9`J?qV8kfeGr^BcQG|NQV;K3L2EqioYs| zspuBEA?*{sT`Nl9INz~&_gHUDrl!*_AE%>zgz|v) zBZcv^0$L8Os{*gT;`mth46#0x)q!*cQf=Kh;%K+F@Psw&{!=M1{$%r)@h9b}pOrmv z#{mQQUjaz=tGq$~V-V^ZoZg;+D)Vdil3CtH-9B8$SK_f{)%R~Lg?7^AUw_!+=%w#4 z9jYS?Mh2xX?Jmee#K#C*`@r@aDSXQ&TJ4z|$JZZH@$Ds@G)-UCY>34H37vq zu?w}00ul<`EPQ1i6Am3Tfz$pb-NxLZ*#Qg=;|GxX5p-w#AC{tyRsjk6clZiO4g;7> z7z3SuFwJa1wzqH-8XJ88Rlz$jBR!zxnGK$7dk|rkhlRe4B&ITs;vkp~S;+PVVeB$g zk*3!pc9zma+;Y#o=ROfxe6cJBndEY66w!tI_0OZGk-=NA7)Mgj82uwhq{iVc*)@Y* zrZapRo?B|aedUS5M7kom_nzOz&}(_+7XL7)XR-Z<+2B39P0JlGUed!!9oPdg>n8km z8FXPBur1sxTJke=_sTZf)Lfd2HjkTH+(M4vAV8!zZ!X2_uC&# z;^43!2CVvTznQlAyg)P&;U{|N>4Sy?VXmjP^zzpyMM=1E!byVfVdJt4SjgPJu#|jB`0EkU@&0 z(w$7-8Zzy;1DzT%IA+@ksa4t6#<(tcFsf8DgjZVk{Oph-!(bE&?dxV7KO~oT|Dn!v zE#=P#{XSop4SRe+s^0q$F<-sQ0me6br``5Ml(K|z{1xHNHU&ICJw;i1L!n^EwfV>C z{J{+yQQz$@;~x8k4t`XeVk$a|Rs0o3(34jxWNKU_OIJfAbjB6EwGoNASBoYru+~Ku z!@nFoNLn8*cjkLCSYzy7)>XWRYb^ykGE!td+%P zzHZKFqk#8H+PYbJcJ7*1Zi~+)PG=v=x9`apOgOk_7Fw+OjkRMnR}T12@q*5eUzc>n zYd?XJh{;t7-`jU;$And5AcPCkbu`e)S63y%K+mkL#fwq?97TZ3Ty~_kGTG zrTy|aNzFv*-{#*9BV}y5ts6}-NZ{t97ScSUs6IR`7TDsJU(oOZ(t~#5=WKardzdHq zJA4m%>Ct!Mv!SJFCc09w^C=NuW@ncN8Ss>+OLE#2T!a>6bx-p;u{sI&iH6ja17 zezq!7Q*xd%$BXal$vPhp8^aGn1@@%2EskEcb~h}qY=wdGg!s|mZAp{sM-KJ14~uz7 z{pKNseotKPo^DXZo%3MJZ8ow>9FdtGUsad=!BJx_3>K1W8rI61lK12Jb`(p1?-` zeD~AuXV!DV#c0#JJL^w^Rj|(tvzyPjvlW$Z1o6q}-Lktyk-|P}5m|opX`@Gnf41P+ z0m2)#;HUC&K{WHe1p2Yw0ANz)00LWiXFF~^W(ljfPyQ}WF7f4<1{8?3>lFmk^Xku! zzgjo&j?KdR0Fum;+z&=`Tv*|_pa4a9wI!!8T$O=h z__>^J#jGc)?YCP_Yf+8YvtYDPA|?8G?yCix_Oxs~+CZ#L!G^etXWa+-2;BAMKnuFJg}&WfHYr1+iJeqd1ozx4JsR zSuqW9RXc|P3it9zX{V>!#lBg@u@CQwpdG}u(hkG0is2*u^>q%2ireX5C-~pK-87c% zCA6}T8~lO`n$}5SY(RK|c)BJ(WdT^>fL;x!%DZnJq#?F=lAXFSmOYN4>aFlBLP3dIP+ z_|@}z>U&_de-yk;fBdomWjLLzH=q^be{M{y+G~%I!_8+@obK%xrc4eqaF1Db)WZn` z-3oN)!%1?lzazG_adjUl7e6E6Qx)}pX;@tU+&rQ-i#W&#urZ_Z$dY=0?pQij{cr)t zb=DE7wcb82i?o6I*8`KsAuaV6(Nk;VutVLFC8J-XOI1|kWm>Ovf%y3 zuC2v9j+SSw<7HHC=kBrj9>2b)ysSj|zGA}D;j&Dd=cg9IvI$`oxhF5`&igxYwUDgXTyZye9^2)g>3pFFY^F9=$xpjU# zz4y`*v>;oEj_{m_!$KLd|s~JSCx>-Q=QEky*vc;?W zCOd3UZ|_I@*00Mq$u-2(1eMerQpWoif_JJwE}xqz7f#H!nZ7&_7+3g8jdx}7g7C8k z$4u-v`pX1%ws!^_X3^tqn&t+SJjs?k`&)L&^ifM~I_QZxwV9!<)&+9!&5Vd8 zTsE)W>vbP;FmlFDPWF4%2|WD)=)4K(^$|v;GuOTCuS&=8Zcq-L5WjvvK^U1j3aWz$ z7Sjx-pHv3?eYY=txl6n$dtYRN_ssKPZbv$Bgl0^}ZH~p;n*;{zy4g%)F+z~F(e7%L zzaIuV!a2WJSoBnHK(qX1)8Pr~gx*i68OT9g#!Ev1Q=sv@^vNQ^pNW{ilsngb%Dq6z z1VTZd5d94SCp>fAd#cqnti4b?u+h``C!TYG+!~SahrAD4p%cOMZvTGcYX5=Fiq%`# zS^BQ}GCHPFbQAB8H(W%03+k|8hIGhmp^Zi4bBNboX>NWpFkM> z0Kbiu_E)gOC>p7h_SfY}ESfLr&yPFb=n+*z>W&hn(|JP#a(~ca2gJ5FO?59{K3nA> z=eYiKrBsl?7$W(kh(%pP@o^^!rY6S#a!byz$m}r=t{XAymS!QWf1KZ^2;AleFZ;P+ zXX%}UyazL!A~j_L6nP^9aLKvD!3?XxbRMu53XTpD^mIfp@F?g-A_u`Mu1=+3RftyJ z={Zzv^*Y|y&X&zu<*Kv4FEt*}+(LST>&D9=b+YL^q8|rmnJoILa|bHSc1D9EQeLd~ zS$A@PFAWm+ip&$f2FOeH!nPeqNtQ@p;eD`daufnotH?54)BERT-Qvj&mP(vyOS0CR z7>-zGfE<+%i;;g`>?YJ_{!e4xNkfh6x%2s;~My67jp5Fi2Vg4lY3(;7D>|g6bnWLx(*NlpU)2;*mb`|ytGL$jHRxVv&!?$ ziHKYp0EZx1)#k zFQ}ZnXhH@<1L$?Zu)vFJSEpZRW-xVs%%1eRv8Yh3+1IH59bWzX)tT1YFR8qAbnboq z!Ac=VsjhSd3{Yqi_=@QBBmARzz|qLgQTKVh<&CKma1#=<_hW+`9R}gi)J4!aQh5AS64ma1mXj%o1yD6&0pTzsJBNEvXk=Cl7cdFQN)%%ei-n10K{ka?cvSY4CqYL7A`Zbnd|n(GN$G9eiZlknWf_8peL z93sC+=n>HBTDEi*jN*28Q6t*gHHLMxy1r*gYb|&dt7?_F#C>Y40+K7);!CfWpktwj z{@y_yBGG{{N3snGX1Sk~;T(~)T|w-!3rC!U)VTEB{eR5IQ+a)e;M=eLxFT?za&d&Qk- zpvut$mcfLYgoD2A$48TO32O%5Pfp*VY<*#hUiw(`M44a$I(T_&bHJFlD}p z@rK_CNKq3A{7PDvULRhXnBtapfCbe`RrIa>%bIZCBnDL~_Z&NNnf5t6$4y3yzz1uilbBUhAcj z|M?Xo<3d;F9w;%g*3*d}!dtfayfy^qf>+(>$wrx!)~BhvFz=9mvtdjB__g3T<=!Z>q(@R(mF~M^A}%oB0;} zh?lL&j+0SZ@@a9}KfwuRu32VY9j8-NUMbvG-aFUXpUw>kE|d9lt#gs6hyD-xbCqIl zj5^N0=6-q7wQZnQ;GcRhSG)(&ujPOox9^(#G|VJ=XMZ*CNs{lc-KeH+DW9|`=7Sj1 z?sH$JtzK~Zm@BOxY@4F;UrJlFUR`mw`vR{9eG6E&7YEvlCu*MD&|mU2x#WX$TT)!x z8sd_Ir4uHK!Q|ZTc6Q>(1pMVa`U(|xvC6c#kfnyp!riNL-@2b^8T2{{X^pxTvlKHf5R@lsbq}lqIH)3v@Gb~qbRmI zQtvSbZ3z>(a@Quqc2FT{zhsh$CPdKwLMbi7{1)Xzk(XZE_A%SO7XClhDVPa^;3zf@X7qneU8_y}+O%o_&i`uuVi%PqX~yu_o2_6x#VBUQ`81}M0f z1=0p3cTr8}A$C*Nt*hrZD##{f0c9^N+frx190xC%HZvF(@ifp?fjnO@E{;W-9Qp8g zqczhjh@oMZ7**Es^5(*?06o<<5aVh+cs>KB@M>#BIC3nLYKiSclx;kee%ebMyyNvR ze^Fcj?-*nLXD4v@{=m`~0Q7-R#&_NvQ1&Tpv@1P$jV>8cdSIawD zT8t}0i)7`w?wXV@0Gn%2GWsZJ{Sj<4EL~#*+6*0E&bF-#MnVS=7-#oJkTQETfb9Ay zd7f8oVHzEUmMp50o2*(HgC^@p6RL<^7?dwGT6rCvka_PdP<9o!^Eg+5{-e4VxU@5x zz|rIkn19u_TlAE>B8W;%CXI?OePZCf$n?Ah;nw3XZ2sIR!~P*dJQKf&x?6O|G|qef z{!7hLH&>3MVAeEShu^2PI4I37CzldFuAzwn`6zJgtiefx$}9Z2oKkMrw~@_~Mbc=8 zi~HGwt=Z9(9w@zGTG_7CXLg47aON#Rm7eLrv$yBBE6cs%|3jIZkM`KH4e^one}JN! ze^YL_??2^S_zx+{{}e<$Ez^+ll{Xo>>hlZfF=6<@igOd)JIG%Cx4_Z=^=PU7|E~z= z{<}o1e-<47-S5H(5At2m2nN0ZF2jT9Vt{z~5jc8sz)dRD|4i`rH^agTYEEdUmvj;k zFbyrQ{7zxn15b^HJ{axi^0l%yuGD#H#u;UPlWO#9AwBM4oh(>ft|7K-pmY2?%cPTL z;6W=?exX0W*IR&2+rJo0CS;N?A?Xl(${d8PH`41%9AVl#K)eTGNApd2yh3c3fL#z(o*tJFQ>e9?_!iE$4VoB72vvojaJt_X@y zlkeGrX4xT}35QCjT>riI4~&kmON{Kxqt4q^bowxf8=7!Z25GL0Im38}Cu=Ic<4QA3 zOxjBxq}fh(y*llXO?4o7>Vu%3LetLu znLP15lpFVORxYUd@&VJv@YOtw;_0O(eKS;GYQ?M*RVLMRsU`MmAa#&q(IP59l`}bF z_JCI9Av%tre$1A#K4yrY+2@WN;N4kD^gJ`&*Ys&KB8_DG>tDj5>RJtt2Q@nti6^CZ z67(@Vj8tFD3)jfpeR_||gEx2LHPYYA?lLk2zBEcaX8UAY`E1a8a<0y4&Qk$4*ICc3 zb6+6uB~#}qwKrVGo6F$@l0a&|AG>8ET4FObL&JXQlhV+IYo7VI1evUb zp|ZV!woIbL&hSyFf)IuzqON4uG0bO)1LV z{eCsVXSrR}4~$xxJI9zg%P>0@*}r4sAcGg@Ag#;mQxaNJ06uAj4u z|Bi)3dlpnRdO2}YX9!$R@rwgaFadr$DY`aF$Ll-T>F(XE&R#rd`|*3ro1zn)b3&n!@S~1<9@y zEWAG1L0Z8}23*E|a#9xU?k|TCl*>(vxY(s(?Um8+cDZRxt;YFMU@{f&F+F{AOP}u3sFT5o;&+9J0eR0u_eQ0U)^g>Aa-L& zoO2~*Z$F>8EzFH)C-5LkuU(k^mciM|dz5XU73P6GTY3sI*al*O~xlyW>$RYo9K^aQZ2}}YY4np z!HTcOA+Cpuw35q#!$AjtI8y8D*fDQ{kQS zDQT)W>F7hzgzBe&_AmE)JRdFRDhl`j5#@(4L`^vsl2L5fNKhJE^YbRf%XjkZi3ig3 zUoYXCa7Gq5DQJ9Aa~EK{cjGvj+Rw9>0iR=(=;N+?-Y8zfw{36QE86#ozIUkJa!&Tj z%>CSdi+Ox@43*ghutWE^%2UxCU#lKN6Omtw;TjnMWfQ(vpD2*F2oXT>qxY<=Vw7m& zijZqtUfrAfH1EpvVN}>spJ0tjE%;s>~yqVH<>7&vQ^UMO;y@9%;!_X18BK1aQ1zkq%o>j$? z=MgbtC)1br3TLjY;Vp=#D!-|JwDY%oo=n*>iM+_Y(^#zF$NvmFmLEO!lo+Lc#sb#C9&S;RRMILLiX`^**th?^O}bMw#NFPNyB_UT$237=_q&?nSiqMx>K^VFSD4HSV~^eCfTiM(czNO ztF(;=G|Kw!PPWHUhXo@N-D(Q*Y=J+G8KikOWPcpXjUukdTk1f7tX9E~A_5l_)%zEP zs&yrt26W~$Iq(PHfk)|jy~w|9s94lHcCLOQ1+gpA(k0644puZ)jVJGSdqYF&H0F+i zE+hr;(Rxdl`{*N%WDtRVs>k*o#NV87BrI9{WuX&JwG4Vp4kP0MBoLw}BS4+fP!N7V ze-@Vp;vOdE{-PKJwvvHkyS2o{z4F;vi_6!6euGD39~lGTS@ZU7Yn`E-zWXZbB)8`D z>`tvymPPD_EHYc`z1zh?Cdky)lgEUaO}*$?)QmZ_6~(vL3Vhf8JdJ5Nsw~HD#lfjLi5~2G_GwOGFH!e2=fYiKf^Zg8 zX&ElL*x0zMxv~79?u{XOfc@vL4*n1e*H z;MN&pQuV&ML-xs?vp`xn2jPWlecNe!TjYM8X>LeZcS$Dcdh>Gz7cKTr9I0A!NM;*v zR8Wz-d}oNma&71>cF`HIy=5)+&93qyojbDE>LPdd(@G5@ zrwxESE<(lY2tfJu7*KB*Fn;Fldvx^;OsUx#Dos$%lYZ}}H&H|8E_)6~zcuVm+0Xpq zmy0w6^SU|k3n={g=bkKNB_sN*G`!f?a;#P`NTHVtm%S*0MU?78Xt2=W*rr2%d;sSg zB8Y6yi7@SwQt2GVC9Xcxr6%&4vI0si9C|dxu@mR%N`5B_zAkrVD-r`JR{n+PRQ|Xh zh0}jd13><9zsEM-Ica(8m7BSfygwqxhosOcmdmX#sdD}w04&wtnb`VEb=yBUuzQ$Z z!ud(W$|t-2e}QPKC0HnG*U%62Mn^%TsjtOX5_X%Y14Lk=ST*k)(<3kFcg zd&oSV$~1L0Og;){jk-K#e?#0=w)eS074=>rigHy{q!Dp-EsqhegI2KI^$NYBOvUu0 z(MnMquh&E}=^PWmyYu1HiM^v&r9|_10PABoba^^uIZkbj@a@}?KJ0=93o9cbcAba$P1w?8Z6E#D*c zovs*p0*9Bsy9`5IW~}Br9doryDiC@e+J)!K)AixC>~p=>es=#9S@b}39T#rj)bahe zO!u-iWw53`H7;$omk~g=W%^1;Doz)H>BV&+2IYH6ybOC4QT%8RZq$7zmpbnYk7QFOrde`VcMc|3ED9_32Dk;FCOh%dQ#3RzDv>yZ_nrsZELGb zAjkNs5G1t9;c5<14fZp#x=O1vA(+(G+jNu1*S$G9KwB7M_CExcBCc)K#y*z2Wffs!$FlkyCmfB-{vF|Mh znMPl13Eplt)k6)-T7mx@_DZ-YG4=jePRnNtW^W{cd_xSZ&6>;P`Xb;fVctSmfq(*= zLk}sxGf7NIOE$T2JS^~UIgnJ>A@-^D08%z+1>~E6ORO9Pl1fxQTb;Wc+xo?OPvBTy z93!K3BcA%b_1#=YHA9l#?NUep4pbP|=@Q7Z1(>f5 zLY~2zJ4?ywINxLqy9-TfI`LXE7buPmAjZ#+6z$ew?>$^*;1a`*IkDX~f*-f$Q#SXJ zD%2@Ykp2Y^<49vgevh!F!pzNP;)0Ot^+j@#ogP>ryxp2>TIpgZeNc?ybF=a@X^J>i zuo`4`EUeOt97$I$cEd9W^755$w;3Yu*pIqF4i54`Q)r}^by{Pz3AO#X@_eJ=&d4yR?w26P;D@|N@KU-ko zge?CkG|;oPTxp7ESKbDby(;RY!Bqp`6pD1p^_F04r{MS0p7!R?sRcPZGeqxBGPtxD z_CI-eCmcXnSXi`6cOmwqA~uy^37L-Lq|FfaCh@c-fdRkMKL=Yj3uVw`SF4N`JbzJ+>q%lI-y^`Oc#mIyATx?m7jzgsTH6SQR#?wCB&31HtIcD z7JDNYo7-f{G+2C95&D=wof0HzKJ75kI%a*@3Dfm{h9({mZw2P+-A=J>k=$517&Q{#Ji7n~yEX?i9|_SyjIf6kv3kGx z=Elu){?`_oJF1PL7x!gtYTHepp)uMC^Ye(68DIWH#LCZ|UtHJ0D^0+!+eH+YjQOfE z_0r1Xz$g%W@-aWIe(K~gT+TUfL`VZl?|8>LHV#I|9Y{c$y9k_pCxPrps!H)ms=$GH|FMs?9{2f)*h~U)W;(+SA3EeI+=6DNN z$++u|iI&_~g&8D{3Jn`xQ{^5Dr8rm>pAN_PF-jPQQs$0a4(Ja0N0tupHdTPwEX-P{7;&tO?*rMYWLo3TJR?Lv=~R-z}UxJQ(=1I4I%ZDn10mh>4`ZvEmO*vTOr@_30%`}RA z&+_E$3%{l=2^G5?mdDzl9a6u&@_lvQ4Rw`xIm=UEuDc@b%Q5n@ZGkwGzSyZ*cZD~n zAxcC$zC(|O(r>7dthz)q-Ovl?SYK?}?sG)ed`|f4d8N(HEk=!PNUK~^Z7%o#X+a&} z|C3F*eq9}wxM%R$LRUeJncsB%Bj^fM>K&(D1L=3iSi&t=&DT+nE8Ju(<7glOqWSKU zb?;b0a?C!27dmgaD%J2LTx{cND!mE4=T$JpW3}%H8-2pef0%nk)K>4+#9tKokp>L% z#EcM*Z0+ouzsA05Jmgo;Q}}l5eE((-?-SPvu8r{owVP=0Al8JRifej`n;1+2e*u>mlJ1qd7x< zf@5~$lRtS=i`flm&Q?q9()qtNjn^{ZFIzoI;h3B`{>*sUEEzpgSn(`hr2Yo#h8Ih% z>Lzze=_R!z$*r`~uJF%tRoVmUOwULP?Y;ZRGsOjs(?UaVhN2tgIr*w{j~pXq+|17~ zZKlvExh{Hs>Thl;Jbbi;p{2#j@R!}%<{5x2-~i_1IRcmu=}K&QHB;K0qeW$$aqS}P zXC^k1n`GyKVGMCgRhjF_e+~v0SC!LVV0RglX>mrCk_&Ria8kQ0ogV|ciQ!bNtX|r* zGoQZ4&LPxN883V_b;J+1iqBImxfOJ5L#Gt_2;*7|aeMDBrA%IPRI; zi^}%ujey(DZ3>_hj-E`u_rJe*DvWj_%fvrC^0pT;6v8}-ww~|4v^3Z3oc(h#>Tb$g z6Q&%$Z)2fu)sBmf6%`d5&aeKWpjLQ&wt_n??ykIzPp!!TY)~;_t}Bb^hGlwv9GB*uyNGuHv58Xxo>^MMxS%&$gltwCM9r7HXUW3G;4RT~yo z!f2Rh#(>sk+g)W30aN1(Dz9GH>;0HtlMK{AD1Z!`s4*9Bv3LmRarfpsZ8;I6?&B`3 zq+Kf8X*-3fp_)7vcE;;j5bhw_H@_f8zav(bkn%Wr=To`OzSl8}&Ca>wx4VBZy}JW7 z+$77}Ue%kH=${_2jC(RA%^*v4RAe*4WlnN0aIKVK-jMq*1?2a@Y#I8MNf{mdQ%Vd2z!< znMLsjGPaHp*_H>lA0*y?J$JoosH^tn}~8qWzL$M-+L!Bw^;JQ)H_zv zaI4m`vD;o~xyz)EsY76)M|m=z}H~XJh1@nchR#zeSpZpzTU3;qVoHb;wdpYFamHj zsWNn5zTljJA)akRjZ^v|Az#w#TYm#FkF2urDMk8*-0NSO>Ra~Z4jxwweY*>+;Z<$9 z41075m?^*Y`8c2;-7m`Fe)xqd;!PH4CH7S#9nzZdr|NEN9+kvJsI#)BP<}oIkD3#? zuv&v`yGFc!?vwqNWwrZlViuL-?28viN66UTl1ND~^Rm0!cV-Z(JV#q0o*7YabBf<( zxZu}>*5}|H;-k-=WLew^p6Ix1Hr0<+sik)h81^G(h`t(R0z&TC8Ux8SbH>*GsQfU= z-AH}-`iOebr1HC1Zy#886HK}LB8Hk_O78$0e^BuZ)Jve*e~B1r%Bg1d4R){z!YLjZ zr5&Gpi8>%~>`E3!UX{Jff#AsKc4kfsbyK(7yzcCAIiN25!!mJvzExx>S#;_M3XGG_hU_UY9l z@0<*lX$_}y5B1-u&o6;M9NVjM^cQvxhn1INe%h*>E{P~M&`Ri8>)J2)i{jua;-C|s z)RliU&C~Zt>NkyOOq^aSMa^CB`)yhiA*|K%J9-U4I8b|vL1SCr-q)Ky1^n(Pbk^yZ z!5$KWSZVz-+5%x?hI6ioE<6BOsrxyOyP3?cb#AO+?ZP`MY%MuS(&zCn4n94Mb+)$t zAIzZ}qDg127f*u*$9VXHQ~!T;Ti<21nx${_%7|z776!t>3tB zlQGR^(S3W_R#Uju8#&v5=J5UTTP71c&*oF|7!Vk~&if-xr;PHif&X900I1{%++Y(* z0|@`OP0PUv;_!6lpTX!ZUnG@GuQL~-$j`(}%2(=w47C?aB6zD(vjvJ{W zWH{jRflZRZFO(wV|MiUxwzzCdOU71M@V7Q6W}XSrQC%rcG6OdE2a8)QX3PJ4*WI|l4*vBIcgsKe`{ORc(ixdwB{@8t zUb}QSs3a%Kx%tAKGo0WL6J22z&Y<6vbnrbNP_l}BSHR!g#6YEalj6D58~V-z3`(*0 z)oss0A;&>=Y(21297bO3e*T^4$3V(|61P1rRoMez?4j17W&+541qfoM{CA|vcQ@4W zUND)2#OdG#_PMbx!T))Ou931L!Iwu1259mZ*s&FA5;EMcYXu(4X$K_0N8q{cNKi)4 z6$w0oa16dL4TyE(AeO?mqHKeky-hz=lsJ?Hz)6{~`55U=esA?Upl2+sac!8>Y}|~6erYcOkBzB{JD z(X;ue7jY%+ZWKq2=V9;7zkbzCmFzD+;utpYyETk{S|qd(OBi6Wdo@xiX2MC^S5tMq zomcafZ~j?bU0r_k$yayR&KrIU@Ah@LVA8Un)K;udt9y3)ezC4z^pnsj>;q>s#>DXE z$LqY4Yx$WJl-v{)Y!rq3?p&RG>`j?ZKcDp-?p3cVQO2p|uH99}9RNl<(wT>Ute1GA zZlFF?^@w;jqBKgxXzGB}!7)#sWF=hG4&aE@tT>t# zy+@lr3!{cMHe1S01*qn`(1h5ZKhk32!P)xZ?yMk|=%TGWW$6FhP`=Lhn}09$%%7`h?g z@Ov^$LoQ2VQ#ovE-ybFqyT3;);K~b1AGY%ueipOqI6j3)Jn(ySJ$ZKXbGG{>EoV}q ze1iL_Px%#R6xHU7;BR-jEf67oENXXdjt5q8)d1nPEV395m9W{lebwUa)mRs!F$Y%| zI}zd@)z-iqT_c=prPcQ+ZuGO(Ok?FQzu@F1w}kIqF)&+jxE?C{%ndd=5boHrU<|Gs z)j#rjaRRkFNW~_4-bJN36AimiXNELrap1|5oQb|N@Al|Lo%WmbJehmV&-a;D&byhO z{lK?L>79irY#16sEIcI%w823`k6W`8ba*N`fZm`78J4Ob_mMpxC?a@#K0T^Qh1&%x zWoBxU?GkJKsGWA+k3Pob=4iMwb0}`5CU6KFo}r_Von?<(+LJ@))dg5SF?Zfzu0JQV}v&?0wSV?wko=)Now~MASrKan*bC>)OZA|pSM)&$%@-&o%N4;;Zkt8OM2+8M# zVkUXZ9u#^)uOCcK5ARa{==AYA*k%;O3chvw6+)Z-2EMSGbQHxzR4v)DD#=+n(seP0 zWJeT+Qy70Ubv;0y!d>Q$V(={H^wS-1RAa|rSbx)U5iLz#@yuf0Ew#h}h`c~&{y;l+ zn2$x-Sp#wyTRlVWtB8*$(+qy5c_JkpKhK}e>ib#iVJFtIJReT^E!AA<v(j?dKWF2&}ym<*?R?(C-Vu#6tRZ;`b2*c|eq>>7%v-!G=N{*&#+u8V1)e%KuunaOg~}n$uO``@do%JO#eSWcK+yX+ zQ=VF%T|OH0WSl4DW3AxDd&{3nQuGRmCvnj}SZadDX1w#bum46Zyh5n9d(SlmZ;#|A z74;G|;d0Q`X$0xK-3*5Ig}}o10*xto#42ouK74ladog1uS8#FB&0N3QL687nN~9eu zJFXKr*XeM%r2IS2_ge~0ib~$p;X-8&@*2{EyUb^O9wcUJ`7Z}b-Fn-b!T-#+^k})?8*Bj?OGGZ+YvEeM z7K9oT*cS)yHEmp?odZXk2nCw;W=Ykc$IZu--NZKw9+-wt2x#fL$;ofuPF(BwxfRqX5o^j40S7EN>IaQ6xBrGhk^v0s7$MA0mFHvtCOB zk$H1$Ftj&yQ$2?vmNDVviJKNSw7fl%V zq2s?1_~+7AjHL>TVWG#fZCui#K8?9co#dHK{!MwOd`LnCf_~i)l2~BltRz`9#5<1= zaGa|OpZjfA z7CSSpLcWs(CBXqsP1k$v9F}(IQ^aT^lbf3wB@pDY2AjBu)^NS6`LNe0yBjs>YdL3R zBy_(?PLVWRC?we(QCAGlX0m>-pF+82i+a0XFIFdHZa8oWptjv^#B!z&U}-@VPFOl( zOllM;*NWIr18k*>UgaZ)+Y*cg96uUuME%^O$qxCN;BlOc`XvV=9)ns1&9~DgvhO`1B4X_4A!ayi7s4$@JC*& zQv#3w24i{Nd6MuY5_9g8PUr2c-$`CpV{`M91ZigI3CGzwoWQJ_!{E;|Uu>`v*8W#M zOk}w{K$j}7X$PCra}6$3Zow{TBuMR{rUEnq3WWplYr|PkDGdfSb=lUD0Zrb z>SKqkGOdd?{naau*~yP;RLMo>3qb~u%H{3uTkF(K_yu> z`17rAdpi9cEF#ptjx9r@$W-}AnCW_DMc8S(g2Yceq7Tkc)J5J4|8Ux+`6iDu^b}kT zvwg)=d3Uc(#2(+0gDfymKU5Mp5vJZ@_CRAs{G`sg>S}vO&dZ^*+1UEHndWtpb*rIr zJp(c@JEm8;GWgE0BgV_%wXy9DBOin7Rv#|YvgmJy`B;lHfD_%?iP?tvASz_PR>0xm zEz}$*_SDFFkubi8L+Ievs|NSxi)~AAw3y=_wXAmeH_H9ED|$?%w0Q*noDrx-qR90W z)}B;k$ZuR>1_*@9=e>D7bKeOPpB1@ZUW#B@Toa}99?i5fbP7xVR{8!whUsx2 zM|GQ0Tanqiw^qp3=4^RGT}9P9yBE4DZ*JdaB^J`3)JsjA(=l|-70z#~{4l>5jC@$f zK7hh_R-!c^>pG-lxQ3Y}i@HMTmo~E?@<9} zV=4}O`_L@~Yo=|}nM_X~!vO_J#foUfxVRFRRv8CfNx|3eUR5#8x*BX+V^TzjkGNKW zzrViQv#))gL9NHIl@;Z6z~K4aoiF6syd{6#8X+g9k2r@aP1bCv8r%|F8gIu!2BWLb zJoR6VHo1|TzBaIWwPsSc>rm9NW$7Vg%`y6e_1dKFRFuEF2#e>Qe!t*ysQkyIC0#EM zezusoh82q*Po&2nnZ?d^C4o#;J@h4#>(sZ`s+}(1z3MKP9RP{>l3ujdh5TRgJex~{ z{TyZ1L=$zIouil~M}mx|5AJ{1;^)l#bpxva=+f?_Ftd8Z7jY!fvmv-l_5CcZ}AUJc`v7X%MGc>{N~BT6qJl%#Bw2t#th~GM`9xj) z@Q|Mws;RtA3hQJhj`GKrJR^$aP*o%}Y z`gzVgM4W$kxgp2qn~ zU55nfHk?70N(PQ>L1=x8wy)b+56tqO^pqJx3Hz!%oCsVb)ZHU*AE!+E($72xtdkbw z(tD`2*6}4L#0Iwr`F&7Bb_J*BYZpz5DOjqvat6`{N>l7qmU>T}hdee0kAJzq^COG6 zcDGGaU`^|MFJ_|mor+%7(8$Z~fMT+_Q{20}9^;glJ8O0uo)%)ESznEG=q&bx6+3RM zcQSToVy++rl6!DkuVi^r(O+2&nWL73m^ULyv-h5b4qbQ2~(}=}L>J^b+Zvfb`ypfV9v- zAfYB9$-Zl!@BGf$d*6G`H_jN}U*EWYFkrA4i^bwy>wTU%=QF4JsZgnTf%F>m-BFq- zzY$Iy`{rHdiB6#ciIa_s5Bhgq4!DlT>+{Tsfa-RM@kuBqK~5ENY};gp=O4)SCw%B- zE=Ykuq3i@ zh2u0KzYwdTCW+b2d`0a%i;vt;$L>AL&mD7*K$}pO4=bBzb-qnCE_3W2VY;YxQAC#& zv&c(}VlrndS4P89a6v=bLJ8e6`8T<~^B*FCWYP#fMk|6_H2jVWTF=*wp>O1-Nfa_J zl8n;Xzga{vJ^x8D8tiU&1m3C)Munky634ZEk2uczzhiqPrYNI&qY70}r!2J~7a>D+ zG$XX<6z1vqo@%UU0!oP_a8NA$dMC%v;8EzV&OV4t0}9;{Cs;d;Ee9Aw;QsRQ*EiD%z@)j_G;D|>H4HLth{Un`yK)$+gvw1Ww+xQT?7b50@#lbu`0fNd z{<6EFf#O^kW7GXLO}Q=?vusAATX**QBwldVsb+rpTsC>z@#h>`J*SFGULF z*yBgeOiyoWMng>jbXNFFua}}q#oGz#7$fEr`O^ve7P~C(HkF8*8kP{&9!FlY=-2K6 zk)%Ec*;j-yYdKhl*wW-7$pNC%WD*(kVlLTRsL^nEEjV`iAc5LJGZ`hqs1N6MSe9+T z2no!$N&S4AjSuZ^u3<|E5v3>IBK%=K!Mv3OtBDa^mjcT~&u^@p^Y^(=f^?|5XJxi7 z1)v5=NeZG8K0|ou)I3tG)vjq1wUo4pIsteAgwq;dV)nJr4AlJ36Hn7xp;Jc&zv8O{ zWv};nae8fLijC9GnG-Yh&udIrgx-1RQiLz#o{_xmXTU}#v8K0y2lZkD&ZasWeR zvCwoRjDV#?2mme5ckek`gfofp(A8T}qmm1JcaT-^XMX3d%zkdntp)wL7SuX$0MYAv~=xhFUOQ?jJTIw&2$X%l{n4U)JZv^sT4 za&5mGiftWrcu{}ziH$k2)O)zzFvQ)Dv>O!WIV%+2c1Ye%yEPT9V0~ZV;=b};B-sY0 zLv#FbL*c$KB{wQ~Kq7(=)poA))RrXO5!mTHz-K^T{AVoUJj>fvL-6J8s{lYzld`2_ zo#KfnDF@=3!ihVD#MK{}OyS1W<6j|iahe~qp8S07MY$L40lnGk*EQ0!F_t{+34XX3 zUFH`}3p=Bv(81X>V=zms?x>Rtqui*TDvi7PPHWM&us(U_J^tNKg?y+z& z`gWI@i+ry)e;uTIR=JLFIv;P-YdS6&zbv8m_4^ry%TN7NgpCUzg=|45E6=3hDy1vr?mOs#_SUCv*i? z2ekShUV#s~5&JwOBu+D0W{_@uDu_z$v1Urw$metBa_`U7e+;%b8* zH4Xp(1{43yaYD;1-^q|-CYGsM1a)cCtR{slwjUt3qrzta>L zb-E}2gTqi6xM5jY)S_hTPrK?EHf>(dFgm#b#n^)bVR$bALetf6#gNzL$XuUUzWhP= zpBxE9ocbvT&}QjY|NgKz+6yJE?w8Ja04q!P5Ju$;GCNjV^v^|r|$+SN#!Ir>Nw@SaqWd}L`>HT6I>co>~^ zSPVT6r0oJ1j%`4Y&*Uis&6+?d0z*A-Dmedq9udX z25IWYNFbbL&;1yq5m+<#PBrfzM9Tn(y%og{3#2oN+@4>|6>zcvD^f+#A5yWZpR(F&P7zHnbJndAKIWv!I4N!^V;+uqplsbYNy%dLEj5b9x0zO&5P!lLD1v zz4cnWO74dI%A-XY;GSMozNBov5EXOzk(qdj{fo)x-p#yR>^*P&892Uv{W)nyOfj_- z^V@3eq`xM8=4{L6Yq9TWM}{O(Gxe~vx^n%MITGql8sTK&isbEYwbr}GA?BQ;Jiw|kBfhbd-dM4TDV@BZY)dDwg5hf$KJ*sW zd*u?;xj(_kzLoc&eDssZQDOWb%*(l_`bopB4SuzJLzBAOgHHl^f%3+$G&B9#AX!?n?&Bz7i6l?3 z$wyg!vKk^N8~n+}aJ?Ig-S#$EI~h$^qa;vwOp1Us$IiOUBW2H`-c}=E9>tEsBd~+tlJS&IicQ-I)$ml3<-*2S;_7{E zI(t~xeZ~)-KN^qPJVk>Cc>+q8M`kp!{e!z4yUSr)18G5Sd6QMWQRZxB)7-6D%(;w1 zk(;<%CIn-cF3EOoc5>C-6xy4DbN`NT^?>Ya9iiS0KFgyxjf+&DRr$6%*Uj2|qiaCt zgUX$e2FG@tLR`!<GSysfV)gag5I&9P3(>K^=a z?`4$onwGFeid61yjz!#^$Q*99f&>JDa%aCrxy;z3q?n35)9!m@Ik%&P2)WZK*H$c` zV5|tLcz(MJyteBj)3SNF+P5yFy?v}$FoN$y-{aw$52phFrHc>d40!458tPy~pL*$_ z8H(uK>ZHCxtT>Cv9mq@W=14@lZ$*iQLjfT?tvbJFH#@OZIXc*Pcf>&)8}E@0 zvC&V}XHK`>>Ra?(O2e%9Q}kN*O9!O#YSRU!lk2<1z?4Ml*?6p0#M@rxGnSu4Kc^Bm z=XW0h8aRutH_am@XHGt6IQ)28v*~OBhE*By%@5mNCr(S^8=f{xG?fx9WMA19I_0p) zwHcz>a?R3|zQ1WDdL8kOk*&|B^FeNU+Vd#HE{zTAE{| zFhr6N=PW0lnrLv1H$r*j!RyYRXh|DP=M&0_3Acl^`WX@9S2dNTRmg6Q7}ZCziecT0 zW2rsV&)ae=%od12B_|N=_$T<#Vo{Zim;;b_JS&aL&=)}Q0`KMxh>8++3c2XQ zqJ?)zxLzCl!FqQ6woZf?+sK_afs66lo7U+n$lNW05vFctl!Vz+o|Wp}3^J0;?)mjq zz%;q!Fy?T}0S?0te}%pe=H4?_l+AYj#`-}`r?86xTfoOM`;-%Ny(>zq z%!|TA<5+^8MsrY@lc`+7d8;))XH?T6l^dYlGo=*R!OKfy-|;Y#9T_yJkS)UO@(|b+ zmoP@X@E6Gej-i!6n|Jl@eJ80KH|~7M9cdCLYD9>%S#t1KK&`?zr*%OySnLai zko^nFD~!KdUP0^I^I3YJJeNAx50dpb8qC#MpcvqoBrj2{h;C@7bs?W4yT_tsulZcY zgpEFAhZ}4fZUA(G*9Za&ATuL(EcGwn7%<8f7d8b;5BKYfhvq-XyvEbYVz4=)15H(# zYf}t*ou<=Xp(YmxI46<;j%Rc}-Dj7gh26>tdv0{T^?dV1*@2KO@`PMVhBJAy4g9pB zHa=+gzGwtrmFyj(Xw>1xYD;%Cl+*hpH+9z<4b^5q=+xDwwhf-O7PHX{e&+{|G~+Rk zRr!jP&@1^${e-paJ-e6oEGVIXQ6Am8c(Gs+sSQN?=+6^FB z?QER_Sm$N%@2f1dH^Sb~bCZ?}+wY9ustzy;WWAT!G(zC>!Fb6&(o`?)B7nhSu_S@0 zpq%v|Ov)DReE`#Bv}GenaBPp67Z>f>(H`Bpp1ExY+QI#LW8&#&m1OxfBl@X2zZ9OeXW8=gDs1w8>|(K!{Ew5g zW&!`DN$5f8k>op<-!kBH>lR7vNj1saJ!jKz!78NV3i%vGxK%5Sn#ndFXE{s7ozBW5 z0Ikj5aAU#XYauDP3jSB$ZK!U!9n$a8*z8}q9_CxM`RYx*yWqQhPrts;=NQr7lSBMx zy=w4RMH711GTol#FPf)Ii0A8wy;YV!E!+$DpPc&ZF%FbH%ms#qGQBXoGBHv6Gs|ml z5X*rvzX-t1!F#}n9cEu)jlm}J#^)3N__%Nb_>n;4e@}NS(4?tNsvv$^BKRO5DYFum`!!aXkf-#m3QO-TXCjFf_ZG_MAS81DO@%C($em1~4uPz1t3X^OX%~D(qHHG z3jJE!>sF)Dl=JvqF)d#8$jf5~dJhyts;g5MeA{9C4TcX!b?n;89E{C|6Lv2Dm>>eH zSy%*@OAQ0SBdYWyHVH0|FDYDf|zP(&SNyKr;&rkBC%$N)sM~vo7|YVWSPqD)m}FXx6HNOpGWA+A1(fvp<|? zeBTfP6BJ7?%gqJQYWoA!DflDM1qm753`g||>B1O@9&BB&eLU@_{>jU?rMEr_N6^NZ z+Uy2=aBAzl*Ew6Hl!8j^B@0sW_)Sj0-~~BAq&4*#A}bO zf>An4(m{~6uLhwxVKtOrK?QdxB2zUUptgw3dBhtd@>-w@N7r93o;FE)z zvBV$b0E6zH#sDCQKVqN34Dt8BqkqC1$-#e(2x58a&$f36Tz6ei9zzZuONjy`T7iI{ z!sp14HWXp=@-c8Tuy+>v?b_AujlLxYx?j``IjrdjHrO&ChCm2vmKvaYKjjXNjU7u5 zi#$LKig<8)L~bD>uM{X^#3*4lKrs^+ytpwpQbW}_ZfCJ`a-ZTH(;7xBZb+raHi}wW zTedAe$j}H$>R#P4IQD7x5Z=M4jU^^z^CILnv^3~czsl_T0+vH!wClhbeKBMc{ve*I z-1m4CK@dP|z<0OGm!iK5 z)?Ppy@%a)juLS*RfiI|<|7YwpO{30Y55}+G2FAxbKYUSZ$EN^tdMpnJkpG}__CwNM zqexP~Jsk#d!qo`~AgsC?&=v^-A}nodR5Rw>n!rtzc{TR;#I8K&>3e;>Z*|`W8;ZRn ze#1h*7iGdT8JOG0DxQW9T-_4T&BGfTLVwA1!-)YzhM2K^LMbN)hX&m7q2@K{K_Nx z9BEU1Rwn91RBh8R=V_0G8jsYKL3eyMGVx}i7`j^^{LG>#r01$(1zJoOvNQ#Klf5Y4 zfZ~FEUsbiHj;!W*h@D&M64m+M!gmq&Tc691U8ey&CFS=-)63U~$DH%%)Rd`k8GWY^ z*#jz0tm?m|tY0vxdA8cj)!u;HWtq-ycY6nE;XkVhIY<3ARUTlLc?s92suFn4z}XH%w-&8I#NXH;LoiN zTt#Pli9~fvSs+XD9FiT;VF0f$N zZLW&?dnzMlgn~`Rp=Z0_<(d>fUp^AEB-FNq^)==wHu0-mEv^PHpqYttAjqrT1hx_byXYd&j`-Os|SUJ z(6@J-3@tdT8=IbMQ(AV*_a4dK|sEua<^Oqr3j|_k(#0!$_S5DK&aHL^e z!TWX)`+5GyxFcL~4%*SB$+EC8BAI_4Q3O%)WsRlZjD%t+rRfD`Z7h zBM?{XA*FxTo3cxZTYKtMvm&lSO}tFg6D%g)2evE_&^VU>BR07uCyr@gSMcUnfC@2k z3J9IDo^Pyn&^8e=d2U-&wARJN+STD@eD;>^oo?lL*kIE0wGI?uq>@4s=(}BPn+yFK z4f(YM45S(9n1Wwkapc@(|1DoUcowRb9EjGI036}!-P=Pfu#6!8YUE9X~8Yq<<18=Uv!_;hXe#A*GH4Z$*7(9=7T zFDbe>r($u{u~psL3r)72m5=btWq$JVTu11v?$UVlC!<1f-jp>sU{4$OhxY}AqS0R93=K* zAIRkaO9J~0=BqNVzJusaPSlT?OgTKq%A$?;)4{iz#MSfpoPBv>W zuUYnJcU=QScdWr4KQk9ZX6WGT2vdz^|$dig=Ze2qY-=a^wRgWR- zBaP;`gf0{5!+<~N=JiluUCBhhI+LX)VPI)hvjx|oOoIn8yCT4ywQ|AiGgw300yixO z888#FTFWdKqo5oFWz4;iS5-mHimK3upA~niV06=jf22-^NDXLVnsrA za>+3+w;hfj@PF-vr`x4cPZ`7LPt7vzYA#-g^lx`kO#&@mhnLQOK`fZ~EqFB_U+o;< zl$%eCMzQTD*~CiFll?CsK4#;C#z!a$gzO^O+2`snr?lnXluP_1>5Is3Ss_B;50yo5 zFwn)1dd=ky*KbcVoqq6fCM32giF6XC!T!TIVb~I%2F}h6%dqf|;S#HrO)1Rg1?f_N zIy-&~|7XfHVjW(oJbM*;jhb^hR&x77x~oZW|0+0P&mx3|RthG}jG|RwZwTR$+i+dN z<>p|K%2L^WAzA&CF4syC66B{BJl-Uq^i*;v*^sMC4IVikcqm>Sr6@N^X;+Ric!SA( zjL;$L)TE7hg#dAxzlQN+x5b|rZ)%#nD<~q%2nC7YKnGL$)L>0Y3_KqnPSaV&QkxKK z8cpHY1MtO96^&C8cl-swBSr}!ojRnz1)pMJblXIIEWJ?m^#w!0D~~Ed(@!fqG%C}y zZQZMBZoA477oz<5e#~_PvdOHaYYH|(|F`g-U;LVtx5_EcqHRi=?J zYBZ6H7qx&o%{_lT+jiWrl(2GUo|!%~6q`wm+%rg+H;r(U$(gp#kR4pPIaAR?H#vql zC`8Jmj`a}a19B#c8IU18YXKfWpo}70E_{oUYy~rwz2YpwME13ucP-(G)hZiYwutXQJ#QM|@MY zzL}dETt5*VO&{%0x92=e>s@YYh#TEjp5<@VbEHtvDO1Lf^<3GM?I1KBpG@>xNVbcA zRd`Jscx#?DZ8#4YqZEe&$U4_}haZzEK{}WY|J(hA>MPwGi&Kb7jrzoUH7{A?;s!+B zPHxa{LYd7%@w9{1d1}cWipv)C^HW3!i#iDuy)w6Bf@qiQ&n#E>{71q3*IJun%zqX7 z-~PlUdhFlb&^rH8)f!dk_8ab4k6p;krM@n%EE!%YsxRAH-QI7f^RDwhxQs=2kqSL}L3*82k zrDD?6Mcn&^1g^Fr)@n{$X7_vMdhXMM>-yOVnk@x{r zrh%BIy(^W09{4~HTEx&zrLiWzyHL^HL z7NBk^Zs-0NIWk(O>b`ML$vLr*K=e$)>V^|rNuWP=~nr@O&!tkr_ z!R^5Gu>$w*$bQJ0UJuK1Dp|IZrV(UBt8JmgI8Qq%nk2QQ`1rWDC{NA_c@)$?U)p|3 zub{kV?xO=~BWG$zU1Hpq!7ZCHzWcUb9b|btSb*bOu!9x!$W|C0$BjT%vF&YUG(nhi z2whJ&;Z8odI5^}f&4eHqDScMifKYBYWc%Uak%PLa@@vykcfi{Y#AWF9^)c&Aw`8N# z21Tk(BU!Z4y#7W@_3jJ8bJOFK{UL8|J08$YzHQHB<3rk!<=TGun%#{ zbC)c(Hk}1#vUeIJ^9G&;p7xtrq+4G6WGg@Ds$Bw?Wy^ND(c(PQ{mY>+|D;~*$WAW)H)q{ zO8(x;?-x0D6ea50wf6;yU*g%n+jG9B0;;8pSB4*RCf**uc9hQq-yut_2Hg14V7HDm zCW>5ZCx9RN%G2kr{sIQH5(r=`w~NG86U+;EOgxF#*L5)NxN)-ic-kx4Bi(Bvn`q$W z=0sppPhC-u1k0p&@iEuk&7)}3K?Do+ujd!d&32^4wC(nWl7JUEVXHr2yPh2REpyKe zuQ{N3CeuZ9%jK~stt^4cTmYDHvLl)M@kYm|^)&9}Gf9eciLCxAeytdAj${+rO8;?? zWv~a`SLE%;Ev)QTb{JEleN)W;)xOGuk%hGR!>kv`pEa2hTw+J3C!96fm@8%Nk^ zqo-`Hdn#^AxqwLkdvY2s$+2Lks0~n@l8eN4kO~F-$)&-6p+5Y$)Y;TEIO;CX{@kg7 z!s7=Og?nPoDf4zzj5Gro`@zrF4KVlAMse?M&%5=g%d9ss3Bf%c$O{?(!kDR~#buLQ z>yl(Y)KZY##o@X(-9~Q7zC&v0U0D-oZsF3;Q@`PES(prKv$+1C7x5OX^H zNRuCarO+b_{-AEcg+l-)xpni2ys)kiJL`ut@?_S>V+S4hqxKA4=yNG55>lI#A}?{W zKQ1Btxo-*EHfeQC!YxbuY>huJ-7e>&WlyDx?m_CW92DAC76-+8*xi=1V{fIr%QuXs zrn{U=#Gzo5xuhifxSXVASEl$*{=1K1jl|rmDV+fc37O)}fz#+`0q65ZW0fQ;kuL`s^hMf21$DQ9XOh>O2U5shvc%Y z=_F9t&*0FYk~X;Mj~{mXKP}b;an<^dPeR8C$spqj0A~M{YlHjq`l0;7Ui)Zbs`8ibwnL zQa_xVz1)HLUGv1?q(vLHwBIz*;gY+Tdm5`=W`?|s^n%{}WKty&U4J7ueC?9}e*=d7 z!fqj{B1mNs6cN86XLA!T`16~So1@Q3+Ou07@<%~4jkL-UT3@F4gk%M*`1JPuTb%3< zcR!=i&juvE=!(ibCepWlwFH4s0^Ff??2D%Du(`8ha2wZ>!pX6(O z5^Ax}$*+U9fjG0NEnPIX=F7EOoj8d`HeDd5>YsaIBQaqY=6b%uG~VC7YMV5pIKQ|< zhD_KYn8C^-|hNzwK@9V)$v7C--^uZ=`w5KK5PYR~pAejR;a4@McEwab^s^gQsuO zDr}`rK_$aS~^rIff1t3@=w$Im;ANCo(WOLEi9|%3t|idxD_kQy!U8K(Shk zq#|+;@_9Q1cva<9IgXwGpo4J(e9#=sb$#WaW3G140@;DsJizIH6hg*xUMLGJ~R96@13bElj365eOz zPW^`0!d5UL5Ce6BX3$s`&oj>gvqi7!ds1VmM=Vok!zp41HY;QXl28&}!PYWGj`n+Id!MGiJl=_uj%~Uog_{ ztK9HX`5QyI1YLHqfGG{PEdUn0q6eZi2?6E4*e2?4!*g&Yf}43gfw2{J4hC6JZ=bHO zxgRz#+nfM>3cr!*)NtzSr*BSoa?KY#L_oDri-R)02g&y4#o1h#W>VAbmYz(_chS|M zE2ufUFL*ghWEDvJPe{ew3iX#j^17bm^4*L7ZV$|G4Y2Q+MPMtdY2hlczZcfn|Kw8} zz}f`gHuU?8gE$bqI*@lMV4fY@W34SxlV+r4_VrOAQ-^q#@A+kgy-yz)Z!plQgg`8- zN6?1Nxb>_L4}*k6eQ1@dAo|&E+8Ma2i}a~Vz{?1s!}^9IDsD#ydZ5VqbLjNyYh_2t zXz9(Ru_8oSWh2|xKK1cn-<~Pcw}*-Tp3mT2*^y?;5k2`gHfv3)Z-nC{0q#ODXMN~} zEK|Ye6m`R?&b}y~>serUtae6DqF7?&jcdt}t7b_L3Yad6eZre!iLHd|x>o+@Qg3-K z=1xt_YPut)c5PVTc2!~{l_a=K&J+u~Uq-RZ*8WxDpAGdt3tyak8LuORI6%}pm_F6) z(Nbo=CjiOw%fjZR_1L65l3#TJ?r|({CFTR?7Q4u3MaJUgj%t5_M8Tr-8r*$p)_POFxhi;J-8?n=8 z4=~Z>NcSCI6cj$)m`!ZDv9gTef->qQldG#vYe#%j5&Tbl6>ul zD59Q3#dP}bU2yq)>iI_Bojv;$1(HX!UHWRYcM>l1Vh{vqN?CVQdG+&WoXIk?sWR3I zXnR;%p}vIMb==H0=&mleByZbDxfaXE?8ygVhjn!Sn+IYAqKX1E2oJEd$LchWR-&K@ zLCDwU1}7Pm&TE#GM!haGwQ@!m$W_XXKKjf8`+w(FfDyW5$9BZN1mKwa3$dX2cVgwW z;y-N>vD9)LhI|>Z>k7uBGibsaJijoP;iqHMBYC@lDc?(dAq-x}2n9r(nq>{rWpnseuW-?4IJMf_J5iAC0SW(4ig9F_&@ z*X7kM4-`&dJk=XX*Q+h1QU(_e*U5pfe?C9tzi0y$SpAlv<$bx8JS_EA&f$VU#!I zS*uJf0K5JyxzVCwYK=6_hR6f#%f}!%3Liq#0W=Yq&5Rvdk|tCg@fV7u4{>hW`%JJX zF%o7eS@F2rw@&hok*n(y)JMF#ANs=pV%B{QhlnHVGb3v4FwGdLjhSom(`JN|(VxrI zVHHQ*LOkOoAgys_4H;{MYefAkpL)!Q9k2UiC9l*G+86;ur3ddi0RQWp%wF=wVCXmM zdI0S`F$$0fQxb#RXmy%E0~}wEcx@vv;yq(z5@a*L+g25M?XAiA%1NXkSZ7{URNF%i zc@QCyb*x;pUZ>J>NZSgGJs+U|o8*;eX}#&FUv~WLd5HTSr%FSU>X8+8=@VS=)?~g| ziBoXyzN?poG0{Wk!eBh;*lCtFTvb{f_G)wHYhUbv-GIUrp(}uF4_lgH0UV|gR5*YT z0u6QfV)i!LV?J-OX*9SSG=)>5EGJ~3I~Zkl-$k2yM{%nE`r-FoDmtSzdYoWv*ol7V zxbS+nfnKw)9v8aELnNB$5IDpp&}40h96wvMiJx~$ofW2`(1su3aSEi#oy!XgeNO;1@wj2*7N*FFY!v@Q0I zRkQRVH*Hv;_pCMZHSugy*UOlk9rvPXwtJsGX?A0WAcF_38EQ$xnB9j5el`WEZNQ5i z>_|Svard+A&UFP-15}Wbv8@H|lz`j+L1rh(-|s}5`E>w0 za}0>uo4qji)|aAzeN@c?BoMt!4%uef_u7W%FznJQ`zG>|?`Kue2W!+a>(^z`1y@Sj z!^l#P9!d6YAl(U7g+k!daN(RE&7J=7i`*jJB3WB+{n2MKf2!NVR8^h&!c}|jfx57c zxEN>L&JWc~Bh1gPRJ>gnf79*b|LL*jOP6L4)VXy$D^8mKoWp}cA+Ie3MxV(vWp7YS zz^5dAnC$Tba3Z1)@2RNzBR*=@L^Rlo%BE0I83d?_szC?Cf6#p@s8rrs?GG)GS!RT=T>Hp zT--S;mk@s&-myq1nZz8L{hRD{*$}#Q06Mw~gfu-G8!Ap++TeQnMNHuvSiO@2Db!^9 zu%Ax%6_dT;*I>`P?EME_D>)5N{yVbyA+=t=ih2~TTI+~dk0jJ3H1n`EkAOvi~?Qk^|ZCI9h`{m?1l~CFG@{Hek z^~n3N8h6N8Rh}Zd2St|SQ*-{)v8@AU4a&F*qzuiYG$mL4@WpXS-Jj0xkeP|RM2ianqj7JQV!rEM+zF~tb=m_bhs^;yv#^@3!AHC zA)PV{;%TxW%H?f(5MPK_eLDQc#Kp4u<@xz$cXd~aB{5+U%9%Oap#gF&m_4?q0JrD#EZ_3efaTU<%uW>0>^+vjkXJ#p7##o! zu4=lvF$M7DtxW5Z(k*QIZcNHwo(ra^Wqi%lcP1t`F~tbOAA*h$rok->FRk27Y{_~w zt;6$*8d2@w^TLn0+4_)qW7kVoO@Y{I5fAU`?Z=O$u=Y8M$GC>k$F;QyP$J^m^zs)= z?+9Py>=q`C56Bm^6hsxjB$huJbD_!V2V6Z?8Ey=-5a0Nv_~q{Q`rb`Jm2z{#!d{KL zYq^X7-;lu-4OYGGzOilsh4&y%A(s3?s6dhOm93)b^q?@dEzavsUxJ$Z?||pk*Biq1 z$J{DvEmj}b#80@?usvMngE!+LjFo^;iYY3D>ro4s z&?(vCVk>Iquv!Py4&3M95q(7@Z@)xsab~;NWByngGuqV6g^3Q|OMZkBwYiNMXxQE~ zm!dFN43EfnRn?F_J_>7bT6*@e2|o~k9bdc?8>?EWUYDCqH6X2+Cu_1NEadeZCc%_m z>B+1FRh}+6!IkkTz5Wh=Pm*AlXDHnH@WXV37gR z3@-AYn4Q+#`|@_$x%)8GdcK@*OHi;FK1Ub{>_zf)3LjKY3LOinIS@p6QR%T_HaCkl z>*h2bqW##@^Af}U2U;fjyTvSjne_5c`C*M{ z^MXqaHFbZX7H}0eWT$N1$ywmOd8F0VamSkREl}K`-`#rPhX`2svo@j09-@cJCikA@ z_clP5f+aIYmpUOK>k7dy4^qWG`~}%J_)8sEGTNg^{x6RL4h&_$i|GS(l^SS?So4|y z|Gh@c_+sOn60rT*TLt|4luNM>QpSEJ1qJraxxrBUe+W%shy?EnjDxqrX>TCTz>WQI znEwaeOTd)6{vPO$vHyR&9ilnQiHOC)GnKO$_|-NaP7984Adk#^Bj6RbqkEgHTXedc zW4ZKp%B^1Zkng4Ea583{6~p%go@G-MYlj+GN{E44SP^8Kz!l52EkC?S1THtSr>Y-s;-f%sQNOa61U#SlgYO5<&-B3{Wd6 z7o7i-sPSeRrlH0l3fbzb#Bzfu`PmQ5w9|DDxgm$3G02xVT=fu97bnez_%V)M?Cf6x zQnA8+Mj=G^!IQZ(2k@~MfMbtuV^vS!i)R}pBUZ`=^6%c2#+lQ%sxP4ud-=j3aUGCBz75X!NQmWW zGRAAmNa(;^BstMo7is0R!w(h7dY3gDO@yyxKD1?I=zF4p*vS-Iv1Xz@P}HOb19C$C zAm8JAU*%pW=5<+iX&>@h0cL8~*1)@LKw-FffY`sWx$2k6*;zd!xvPqOT~Pac-NzG} zM%5qEzXU9lq-tuA8f9~S?mcE#6$a?Jj^zO3svXI{9K)facvR6Yi{f@BHKmziH@71@ zmD+mQ;}yl?Dl%Erequ2lL$0AbYe1gjVuM&A*P>|75po0hv_cP6Zbx`)_W_02Zw(}Bu_yx!zDUGJC z08E40hT9~Xrd+v0UU#EpHkmhR#2(^gH_GI}PxONHk7dPzLabl%aX+Oy?1J>e8-X6Z z%W#iYuMIR(JY>Fp#E)QJ+AL)-sw?{sIwrFN8Vgluuq5;5J-q!Gw?fc^nbPUyNWfV0 zWARBqr(c#+Nn=pf-@j@Fd86UuHcZo;v~}7O-T>yO#+N&PAX!=l2N8a=b2khCCr?vV z44pospM}Sp74bFs&7`z^%9$du&4}m2`#-snDVH-U1DTt&JBwe@u9*rj;G^Pp6Y}VS z+Imr|HlYPx`AtKS6kEUVUw>z7Ljw{J0eJ(QQwv+Jg0;~fhHpO$328YfMsVE4L>z=a zI9Ivca?>?tV$Ko3i`i1_5l6T^eZ4wI(G?+i+p9E{@m3k z^ZZ-$6q`E|JCZuJK8H_?L=H+MfO6dG#KvM9{wW*zJWXm;-YU3qIceI4^Xp^l9a5$Z zTQR{TYdVGSaAM?^9IMW+Po03v}i#$%tm2*e>#zer*b?%H`@H8^>36RbrWcCN$WZ{a# zgxmFDORd`4RI_-q%o~%v8gDr7xu0!5ZvBkT!o1M-_(nA43tGx3r){)>i{UNS=|(0+ zuSR~S%p;{*DatlR+LhxN3NY9F0pe8%yJ+QyG#p!1Qa{KSq>F2MahYLHrq#gal;VJB zZyf^YEc%Xle!dvc(JzMvVY)V?nJi<#W3~M_qD9I#VmMmqZNs#+*`^@ zMso^SgU!~+&tS(-QBJZrStSekHw;-~ycdGs)k|_9x=_VHHTLaKsoq;^#sz7Zt;%JR zijS&n91d^OwB~O5V?BKA;ZOAFci{i}Gz8HLOsjmwzjOf;%?ddvfc<-*D#I|`NFru| z>UB@oLlg!(QJ&m_?6Pfs6YG`kQ(_t^Mz}eim}om2rG7NG=XhXpOZn=}$71G67`f*d zA*2I<$?W{tAs%!>4v+TiuE1!Sb~%j@4@8G?*GC^Jq>*&9WWZsIxI2h1YW53 z%MXKu7r`8UUfS`^8$8jhyp~;k`3+K^Ta>>nheLIUM+>b*fGR2C;<5$)Lr^SC5-0Q) z$`2bTZV0M+mcIIxm25H?hS`1cN|H9D6t%y6Fkgz*H2&=@^!9_Wwa;p z54t;0iTgzx5P|az!;ZC%<&~+7k!-=n4tHGzH&eF;2|8Ror3hzEDnhMy{`bj(<;3f# z7D(E1^+sT4B)8`8nS~*(Qy=NKgNJXTSqR?L_Ku8&UL)yuQ{oEeI?R4U_JmBI);5d0 z=-#G7I}kh=554gDjNnFK$0^5#5mRW&`z*(dk~Cy$!>?Vp8R-5tC~F&EEVBdXguw;0 z+;ztvLH8wTynb*o*B8>qe%*^8p=TbQId?4oI%2g4n%v#n0-fJqw)OQlq~DJ;XQzmb z{#BxLETIQMoz=0%VFXEt2t2L#PMguz&ES`tKv2y-^30!cHDdn_;{Sh4{r>mZ>_zZ` z($F7twH|JDz_HTnx1$#ewBa{2Do zq{z^X{sLZ*8%EUsdBxCwSMC$__MzLq5OXPjQ2&2&)XM^R{m1}w4}riHf963iph=f2 z=Ez6LKk56vO;wrUG_9s5MJOFm7fQn3_5rCb5Q2c=Tx5Ox zKIkRe5Wu!71!v*nqo1k7#Ek;tkUe0 zlouWhhcPG-v!4tn02kWjrVl_X?=ky#pf7DGDnoE^ zkv^{e^vBr^h+6-J`P`jsHrXLM!4*C)V0L!cgk zOmdg+fDZIF_tBa24bnonHw!Dqn^uC*k@6df%Q}~q?S+LXO!Qn8**wI9KcLSNey@8iTD#(ux8k=QxVhwS1wFFy1sf zrO3{g(d*e@7kx|RmM&`;_iXF43V75yNca^z*hSf1{ zB8vkXq)y#(k;ljpLBgHhR(A_VW4D0^;kstuE8`=z%Xg#gF8q4cDRT-=NyWqY-_vyQ7Vk<>PPitz)mb&-mFcxLSp2TDz$1vn z7kot199S$1w{YH6Hc3l`n1t}sZy#(n}L@CmVsB{qlA@rzp z5D@7#6b0$hrAC@`BE1uO?=AG+YeEeHvc9?Av){efTKk;+{W;_N@%>i_HkRCv^r}Ybf!L^?C~FjSm+CiCEW6J183Di%6v+ zzoaq8Y+PukuJw5n8^y$%3_kxHYHa`Ob$o4Hb|2h@9>_-5xiLNtFDX*d8J$g749nHj zEfjF$F4}HAh(*3oZO^I-sQ`GB_I9_q@wK9kCwaxj^0_-6f$6tixeEr#`=2%>*dgKfAC6cwcwI%;3{4?b9D-YAXL5{=F$cbv)F*xXV*4{GDEETw< z7~`wnA3HOe(;s2ub%QBC7qRmqMm#@nX-#ZoAxZkMP`|d*nJnT{?zZyjt%700;2Euv z=-QG4_K7y*VuS`@$~eF7BU@dyg69PC!|veZmy z>J0gpox?nnXszD}UF>P5NM;3*$kB19vdD}phj(If_9Z9QE zb%+={I-mzmg!q#F*Z$RNr1Z=upKak{$I${vq}6Ds(esI|+84YMgBSediUMo1=vI5_ zK@Oe8Z&5$uo^^8YYzn!HXY4;ApB?6zrkr|q+}y#p{KPFs$2!A+2T@~i42JZUo;**< zt@5#W2wT*&Kam2Yx7P7Vh)Cyt#aBPnY@Dj(nXVHGexHk=&@e_`h%mug6eDr}YW0p} z;gU12N5C)ccVp{60-2(-ULu${V8ao;u>6_OTh>@pWl%^KAFdk)TUo_nP6o4M*GcHQ zAbE3mKK$(^42v3XqS-4>mZ%U2eh0Z=>Os55)aRw1oet(H>tb0~7(7H|$`WiV(+2MC zCQ=NN|DH>`ps+soTv;!;m9t!M&yJ46MfGefuRHp6XN4!IY4Dz&zSYw$*Y&c_RYWg6 zJLz{x3Y81#=mNb4PZt2`gHbQVR@>k1AI3+^h;>2_Ij{lvW)uv^S#UI{K5?QIKGWN=oc5FItIBv!U(Mb@59K)&Hg5F?=LhJ7> z^UVdj$}|+f`QN|b+--Wtgd@S|?;ju*h1_Hv>THX$00)_e?h@iBh{Wb)^Sqwat2yTS z{R-J5ktR*&bb~m@@)e)TwfLawd06)t;>sJCb7d76$E|=_rvHqSt3*F3hsjEH!3Soa z-tDf-F~%Zppqx+2u~+uRGi6z-o+j0?=k;Yw%B<6hnpu%hl@)2DWz{-xh-6jp@o4Z0 zGRd;OeL&VEEMQX#G6DNoFI3PwQHIRXIzpU$5I=u^ep1OGU5D^g1*>_rE-hTf?(yM~ z6eox*q$T+LE`ZG-CXvUqhy*wv}A7Vy)TC5f1yRiZO$ zTM!ZoyXmuDQ|jtfMtO7@HP@Cub$Z)fus_BAo5c~*BK1ZX{P-u(xc)vWGQcn$!l$ZQ zB_tN4(pRuL>norxNCa@g2DwIV)b;GnY{{w;E75b$mrDEDWet!5#xcd)6OKNEc;=B9 zSNP4oWqVeuQprogM?@YeMz7t+Y}eyO_t1flggU`!MYJ-AZ&}Ak+>Fvb zTN&A?X?Z627p&Jg4si9t{E7qx@@vF<`CdJM9aL2~L;q6DQRDJ1(ay(^2SA!uW9ZdP znfbvmVq%Io~Ryx<*}*D>ak@f&a28`U_JVYcVY2e|@K=jqlsI3Ckz|00t(Y zw>O@nW$Tvq*bURbcisH#a;hg*pB<$tYzy)mxxdTzyZe^%C9!sd$g@wvy3o6g3~tB{ zNywK%kpn6{dG4S`AL?rv(h|ASZF2{2aI#sYhwP6riZ!`D))xlXOwHCu>~C;Z=JR@p{$EEh)&9d??OGhr@&4}uA&b<^ z^vb8=6qIM0fJ{%F_c~A#!vZo&doa=F_~3sq9AKxg3(w<6ZZ>%NJu!R(aIFr>%Pp&8 zPeQMN{mT095axe#54+nK+Bft;qN6Pw`3zj7UG{HXY560GA)+e>^)0>A02UB@7T`9x z3wo})j3(6BA22oiH_C$UB7kxILs9Bw?1(=K1h${c$T2sQt6paO&2zZ=W-Fd6jSYM} zFnT_b;TH%Xkvcv>jJ^5Tuk4zACKZdH-PX%Bl=X;rYrsFKkB{^fM=uBc08#=DASDnj zx9MRMi#ZB5Cjs4%^E+4+Lggw6pvj#X17v%^Rjq%g5?g@YX+6~4Nt1GXGWLbO1G8qz z`pv6)5P&KT6Lv>#c_~b;dVjar*?_E9d)?Sv$821@R`bw<*RBYA=oT}=%aAp6^0*wv z$par*-m2E~+w_3hH%li6Ods0|TwCAg0VoF;;ck9toW-tvEb|f{%W@2`<+vmJOB+a7 zY~&bSz9`egTkaHk$3NGo6HEU5vCQ|fs%m6E?9&S}_V)_9hd$5JpK;u4MifZDB`&KZ!uzD- zox3h|(Xs_A4K$;#N6Z81ZTD|v4d}fs?l8Qb-zklQOgY>!ELh6QWbFRvW7h0G+Pv~} zLrQM*paPjyA))6bnvV@yG*8a3#1x{mQ@{a}5&d7J*V&B}w-oIKvAcOL5N$Fx=#@NTK3e7o1jM6SbxK zh#pb_8dqmOKwpZXfen6!f?@z3M=n@$l=N@#^SImO zZ3GRo>GRH=nxNwX2ZbEP+StfM3qQG2vsn0=_C9#r0EQ9z@@(coP(zh@+|b@9E9sOefM_R0nFG~6&xVy&DS#j7$xd? zQ_1v+Kz(+m3&m;qT-A-z%bXjSsCI8VF(S!CA5ffSY_fMPJjKjCj+cHwIJKa=nbRq` zAwG4_GfeJeO9)9+|HAJ-Wb`uftx^pW!#4|?OL8L^UEJE+W4OaXnH#aD_ zxEC586j0Cib>8lD<&;+&pCGx^aU5YVV z|27Ob+Yv7JD`?1doNeCHdHs;vqCw{0u|OI=$W@e|0||* z-s?G<15Yq+d)!?-akO{$c0m5e4!s0A4xEMTNfYWRTbzofleH!4McvJBnZ~5|PCseROV<(N{ZK@QrlN~yN+pK< zRG)HpiG~;s+D4&TnSv6qy__h&!g-6}U#A*?s`5C4=s+;KYUU~Pgo&e@vWe{~5QBoiq zRrI(G)`-!(Cr$9CW(xD@nbS$xMX4Ut{O|~^%{d2A26#>O2$mNhE`J=W z2lz;h7IOSXe~dq@7~d|Qy0yUnUhjNEMy3FA8s#Dj8qVZiGsq)mS-S;zEaAr$Av59-DxoGOV>eBmSlM=y|ByH-hL-64H>j zj$3b=)yV~%hR^w!M8y^E3#;hj?7F+>?A>K;knF)~7P=7wMiPQK2FjTA=bY~uSAMMx z+>vQFAJ*!%m!@px&(*fOTW@jELmvY^8;Qa~x+RnkhOVfdgt_P-0nRs^UB2K_CQkN{ zN%;^GrYRp?4|Hd5;DV zrqU5z&r*bGt|_g1AO$vPTl56PT?F$DF z6Kw;?500=jlLwz9K3rq@fk=^tt%zaftCL2jNx5{XsQfOQT;e4L#&^=-Md|=6@04R)77g^<|LYJ8mFd{Gh9 zr$a1f(y28zL{#x7Cd^I?te4C{FV&cE3E-$-0AwE@Pz65&^s!(6qfsCRgnU;-_TU3x zyaxcdbV%Uk_cs1~dB7<6?|!)mn}587*8xUQv8D=Jx5x(>E9P=v{iw8RAGrZ?@s<5u z3-khfekRzW?K0LFh1GvOnZpNN0is=8aIa<^>)yGxHZ96J%Bv3ohAcM!2nNTE8n+$S0~5Za`476 zENXDc*n4S^9QxER=P^o<_OW7<*~%Krcm+Um`s-(<#+}JzI$A4s5cOfyJp2|lJ1fOI zzB~W+(nWy39Q^vaEV;KhPrRD9qiuV6&y~DDAF#~dY}>eT2bua_Gx>Avu8RhOnp9># z@yEY3mcvY%DoVkm--SL2`|ao@b$;b8tx@_g@#u7yD4x0dG<&#S%478PtumMDuaRUj ziWTl@g9~dG^*i{3X?E3SfP4022YJ%703hfhoA^~EU=6X`QUfJ%b2(MNYxB%T%&Jmc zf=uFE-I&kh1$ulI6_%r0+w@$WZ(F9og0X{iY+YNzLdU4kD5InTbZ0L6xC_{gZk>RZ z-XfPkz>nx>J8B)=Qe#@U>v1h8T=u@8-4^seorqKvix{Lf&_Ssi9as8G{cV zI5Wn6FJj|)_cm$_)Ps~*P7%S^ki>gc>L1Tu)wS+=i^h*-_SKBWRGd`tPRzg^usL|? zBIM3CXagd;aA$L#tJFx!sKTJm*@0Do|MBA8J8uJql2)bF(~XK!OC&3L&DS4vvmjM~ z3H@JA2>Bv~onwX5$wyG7aI5?EoqlJ7l+HFIp;k34w7@F<@wXz|qURj};%~x;cLw#b z&$sz?r_=`rn7tcyxN)dWo%E$!F3v)x~G2~T0=HTiBK(4pWB zIIk}X7&NQ@l0idx*QYseJpRC`G@NZLMPw4NP~>HbruOzP&N%;Aab9*pqSpx8Q72xn z**|VXl6ZOn>7p}8^Mmv2HlPy>2=8Zz4EAB#^bqu|mtQuusIH5C`}XL06fI8W`Ws(j z4l0TYCm6x8-hni`rziF0L^>i(CTfwelxINw>O>t=8RKA&F3_*`e(oN}ZP?Lf3*O?9 zSivA_(*tt&vsJA3f$$Q@tIJ5*MD3h5&94vjRI=_Jl__?O)rileQHbL-HK0G>4Ky6p z?yP}xzB6}rS%A6Ypy)R4G^zBG5#3R;{%(4HuIU~VrIfIR?7R*bWPK>U5^Y7p6I)k9 zlAp;D(?oD53sr#)QW_IE@W)9lBrZDUJ$`za_pIjL-5Dh!;2SJTc#%Y%tp6FWrh7iS zkKla88(MZ6iU~uUr?Eq3)M_Ht)O6%$Zj{n}#~cZTRUzp?sbI}fsM^J&Gm6t(Du;;X zqvo^aTE;qa7Q)f}zOc9R1AgEyULndXC%rGoeSGXffTnr#1E#*QyVXir5V%gGIyBUI zTAn-r@lg?tnScip4K1=A=Yeh^`Ds1I4#D8-vPCOFlxgP{1HzuTae8ZANMeZsWUv>#S;nR8ke8|r2jWrT`40)Bkmr@x7YTPzqU~-lNV0rm z^!6>@L9&un=A3n{4(~4Nh&y2I_K`R!_o2mW3(+5DCrd%&jg|=Qob^a^?0No_k&&gO z;l1ET{I);(=?3@vGK6aOqHj5%=(uwca6X8^xaq`+M1_PO#p74|ww?h(6dbL(4go~nC~s5lzpGtCuS)Y!0nd#Ng>lQ1*%<9bxF?P^0@ zj*Ym_Kyjx;IP;XlCc(!K4fWcVEwOnGYjofmkf3i%&o4IW_@TJBe-P|#fh>In4%?FM zQW{;fxw|5tJ!N0ZcrrTBC0wj8ei8yrvTcqM%oroqn{kT z++y^}w+V0Hnh+#0EDr*i9Z85!(%=@~dE5~`5!HHIb5w8Q>V9M1H4(-U$|EtCs7wy` zsn}CN3W$Ql!L{7It+)I2#_)UcXM^}_1KsSs?5gZEuAD#8veJ}`50>mNW0^_ujR423 z38lU`6Ip_d?>Xvw;+;}o7rPiKKwl01`stmSH97{!DVprht9j4cA!8!goz{st=x|L) zDr~%MnL$%2)cT2o{)qHPBvFdVSz21$V&#!SJgvi%R^_#?EY%!IJ|%_LlKAYWMfrUe zwQ##L*6S|08l#if^x`VbBAqe!3Wn$BQ}4WrB)h@a1r2$cuK1)30BFy20S&q$hKS^GQiw!s zelX0>sejg>nQmdg;VW~nD9-kqZ#Bx(&d!Uj?3sxoY9ZqQ#)(n~FwPg}B%G*AtO z$gg;$gHp6VlYjI2{^wkuPYp*^+!597o(sVQI#t?q1dj=Ps6#z11sFKV>-xf>TWRB+ ziSMZ`2VyI&7yPzY^}>FSKm72Z))mU-V}o@W*+vaPUn z%o+cm)tw8E)%emQSI}B8`jxO{cZr!FEa4?Q%fAm$7+C)Vv&wG!Bg;c7*fHo*E z7~ZlprN>ogV7IZsKM_ft-LQZ*v?sh;mBxR8Oflp%H_$@Sfyox>u_I>oyvxc@=>gqh z<@e96%#NLF1J*DATp$zJtrw++hacyfH~6pXvE9*mle+}2RoZUgn;oYT*AyS)Eq+Gu zj)0(v`Y4f>Ja;sBfU8|@Pe18+n#6}=Ot<41dQj~U2nNn*hUc2Sn2GMr!ZG2!Y zXlkBF{ra40SiwS%V+7IuGTjwl!cG5hruDZ=rO+hkx(EF)8)U?l|Q@bLyI1#r@t5v2?Qk0^t%ZJ4}#3#A6H<D*PuEo^`x@K^lYhk*yr zW#TJwB5I>Ff^hCvSx_E*qwMH3!>;+<-zUe267S!MdJ0=R-E#n}b3;mFTF->A`s(-6p^wbwK+J{>Q;#d9IXDKJ%+!x~QObI@G^XWmS@0r5ThNj+k_||)_$gvN8Zs~yU zEFr7PgYE4097UQ5qSXdo)7O5Ne3t^LYoQZ^&?Ui0?9o-*UXj5PDWA?wGsb79RS%Gx zKJGdJRuIe9-f2mT8^J0z>nDV*AcGWN2(0CIY$nIuFL+89 zJ-I3Q4pr)6aB|)dsxWmLfO?Q_t%vg0`sSI%#At@SA1*s&Wo5v)Lx1jS{UMXFq9&6}2()-|yp6SzrDeHyv`;pm7tY8RM+59jlPix^1zW6kvr(7ic2 zQrvR0w|CAgF&l4NjgzV2UE;OsRTQA`L3Q?yk2!~vzLBf2E1z5=lxT7_?{EXc$_vGF zn0Y4?Xv>|TiKr#K7Sq$4Z`I~K>W-~mK5Ka&fb-+_PWLE#PA~TqJY8En|q< zrPL`|Z<^PYn3E3CwOKb=6q9_X@Tmv}ELIi!VR)oDPy~cKdc5jbcw3f2_G3M~Bd0S{%1GdSH0l!f@X`^D<8Ye%_cpl28O=IzKx+~J_&vx(dvA;?Ze`K^K zB^T&?Y0UKoOeJrD60D5Q-W?sTOA8wRiD8*SxbtzZ&LgAw`Eo0wg?!ANSUm}*@E;mb z0ALm4fgZg2vW9o1Uew-|oGAO~LxP1pcja2$zSkGPM7%Q3$c=LHuPf~dRsf);C`3Dp zg13TZUf>$^5Cgu9SB!*=z)SmjL2Xj2>|92VoIx7=0#|(NwWmnmw1-3W>??zNpu2OV zQhG&ZzDhmqFKRS@?hz42s!2N7R@ZSBSZu2Q&BkI? z`X-y|l7Ou44@t=y9U@l$J%8fjMr8UZIAppq4B)!n_!FZ29}!*sy=u996(FjdlunoBgJW%XLd8z@t8aDvJ zrTmv00UO}KgJS^P2$l-gheXyH)|S_av1IjLnbXM#tlc^PP^#S?(oLLShN8mc8R_JXNTOYVC+S~jC72dmbTmb&-F%^W= ze>9U2eTs4v)F`u|XZM?YGL;hPB6?`w{7oMk!t&Tm{BU9|a&%7xT;wf?z`C3W-eor5 zL%K-t%cxg4$g*Ie*_H{ow36)k*=ASwd+x;`&a+<9q0n-LwZ2 zu{N|;l{^5Jxp=Z%Yed*8=GZP-S>{~D2jt44#vnCxjy!@|s~cL;+MMeV1|2w0>okf9 z{`+x&S`AKJFd_T|T1jY{gaNG%UON{<~1t zyNI}6H_xTwj|?O0laWNzfLQ(NfSV7Dm-g!5qD`B6E3EXPpA(O;a;3pOM=X9s?SD?x!ee!X{7jf8L!5K5lE(> zzP9T*HPjZL>Of`f4`q?9ef#invoE1u{XNDyt5pk5-u5e9(S%GIVZ3v#k!|ahe8HHr zAi+hqzNDg=TeU$qzis4HSq}l|I*|DhKnwi=3;-@X@<%K$)4m{xk?c7H-<9^d5eYEM zJ<0ubwOh**^jsuWa;Fd9{d70I7m?U~S+Gfb33jhk$EFy01F@YNdwSDBL3U{^YapFk zGfTjCwTsugps)WoA7xM+%^qxc_6+Oz)Q#Wt39nei@rllrx}U8d@b(Z5*@Q|M?$LV1 zrTk)%>+_-GA{QKMC1hA}O?D~T?A9BS4fcSskW=u~g{t@SLBofxoWCiq|BeJbet=sA zd~oV87Ya$OE%CHeP_@rDS1qYt*)pa(Mn(SOJL=J6GE1m8%q}X2o*scM5$!R7pG!Ae z6**mZs7tr7$+e8n{8%nE+VGTN=bfAD%9 z8J!yo7_28^vb}B5R3Z$ogl=9-^7MTzA+Fu-)m=2%`C)=fHi71QpHHF0mAv-@wA725 zF_9m?RJZo-%e3AM3u8`Ku<@bDnvhE9PVwRTaIfd>#DK=-0PGZquVWtp=;!5pk!Pj= z#5xdn4A~>h8Z>dwoK&+f-FIZa?(@N$ zZy0R0ck_gA?lX>Yt1v4>Jo9Rp@Y4r%2^{1eX1~P~dT;BuXkUUoXKwTN)gyK?k~8b2 z=&q%w6?K<;I1p!0r!vM&-yx)AK>S0AEeY2F^>*$@#960ueFX}Z%Z zdP9t6U9<_$wmpj{KmivJqclLoAU^m~&Zi{ESzTF?&?WypGL{(2PI#4Ov#x>>#@WU5 z#qDQO1A_Z3ZwS#u1pBpt%FoFc>ysJ~e%to3eoFC~`v^c)v0{ObdtJ|G<+$S>x9bf@ zt2dp|u-rw$p{2ymdKwu=*Q+hR=qWP64^k|Wux(UZ=G%P&v+%@RD;gzv4 zqvo%@PIXnpq3jQAbC#|pjM$6QBu9b$n|eK3XaGai*LE#h-Qy4$swuq90@50hq3Tu~ zC4&`tY;<@(@>KsS%TybUv-ujm&dUCo1hG%I3gbQ3-^ZO|+w3je`zQ;a9DJBzm-{N0 z79vHmv8*4Jqw_X;{_!Q{R8zT`0R`Dhc&lX9{(euz0e~X4e``RZXO}LknN&#kaJk(!&*BTPFQoI_K*$u@Xio&ryJNa z^F20pYA}JJ&KX;c3?XN1R)u7BpNZc_yBj@;)PCplD@0`G%i6HuI4QM0ZxN5o4D5?u zultdHMfiB#kF~EaGVkQ~vfot%Iz}kW+PW$!ZIExzQTb`#{>=K5j~0aDgNE8%F7NU^ zSetOB@SRtMhI2I@%uljvYKn@HKLAAbR{Eb9m-j!$@#IWuY)U&LoU=a3{($CC=Rj1d zZpTx~AYqO<%$%p;dz~!!spz9r(r$}<$rR$Exa3*-W$3!R^F&gfRBcxEV)Ieq< zQcxhHHohPJI+@}&L5ebE^PnAA8Zw!78$Eu6{m>p*<@PgZQD_C=p8Cu*PgWURyEQTG z-FcW3T;ud$75w9mNdv&={?ka{ZRBN8h5E&N0EW|^Cir9G z@2mktIM^CQ5k7sO6+R!ue?e@ZL36B$tnj_3K3De0GCc0K}s43_3 zlMB18=?KIM*bk=@mD?EQ?Y=BWMqBbTYA-ul)mb@?@mdHK9bt0^58*LQqwrPSrQUKR zmD|CJZNfKa&xY3vN*L|&ida`+={(!({o<5uOy9p<_oe`96P(V$lh`7$ zYi`J0_@kP|8R#^dyC9?9goup!RruI9%Hn<7uiZkWgubTEvu~NmmVIT*F$a5sXyK)J zfJ8;y8SbTL|31HX?sSI!!5D|X@&0C&)wiOF(S8yFC0hPXuQ8ACZ=utqw)>pl$p&YR}m#Ff10t=Re$ zeqWYzu63n%@3W_1X}5Qlp?lzrf_C_~LpdX4JwePG&TC3s^l0$%-b#}XIq1eB#w$s* zuy%Ys)3NbTDPFnB&%c9uk}BQdJ1wE~NpZsB^#c4DyqrXwtVlMRSlXXVbBHJ%aAdK- z_rZ^)n}PHM6ztNxA*Km_&+9fuuyxd}teK(SBQ5Gi=?l&#NHF#Cn*oZXT!ry>${jL0 zYv5_IHY&HAHXUf0SPLni?xohU_}TVZHwk5yl?97QiN;|!TgtWXF5hp_523&pnr`~N zi=Dg346jF6kHKG7`tb!OSw;gn9}zd3qLtU;I)u!e7%%R&J5+@boE3nlH%>nxF$91_ zC*ouHjuIPLw~#?RHzBP6op0C~MWn~c0VDRJ1$0MRLNL_JCXXvI>-liztWbRjgC_MY zb{&&wPm<+&b+pM+_C|}rLZEua^vvcH-EJ}!wm0v1;r!c^U2#ej$OD?yb)ydp*% z{q*g}rbu^shtWt65Q7ixrR*@$OpwYl$=J`r-YRQT)6n4NSFui!$3c28qU!o)f0o76 z)O1UT0wc@&X10p>+x-@~#(tt*Lr$DWbY|!ToqP%DW2PY@)E@*yvVvH}<&6MV3{g;p zWd(iVa-fWSW`t5SVhX)V5R+G?~TNw01JCD zHemw#eJJx$NP~|T`TGmetVWrr!b=f6b0tIQGkpmL+A4W7Ef%`+#7TnH4d$KI1IitS zfw~K+?P(%E{B_YB-&^6*@d?v`*%mgAOi!Bh1RMKwl=YtxByd(q;Mr4#h^q0kNAeCA z0&c%NUG>x>GjlhZ`@V9K{x<$7^RaI?o(a079`|C4?$({>LM%l?i4_zZVs7j5Ls6Xp^x3N0k0o*|&!7M&x zP&cAxa&)x8CeTg?LMkB)n*Hykc_t>7Orp7>VtI756+WI!MI2wl`K=5Fvyfm-vYU&r z2UNTmh(XJ7SVJeGpI+o-9ri8ukJV)ytF^tV1a;KXZ#RyT$@*;uNi2I4H;_jT? z^73KAPgw6YNniIM)U8cRT?F3$c4L3;&pPx$X!+(kQFvDjYqS{A_fmz)cd)Ms(@0UI z`MD`}6yvCXg<(kB;-^o8(JqVu887|(cPPl4Lm12If^;?BTti*+2zbB>@^)*vXsrU% zH8MAJMmUpG;no?u1w`oLfmHgp?KNSW5P~*|wjZn!*V7sxZHRE7Rr0wjPy@K&@rVRJ zYPL-0;~1@t5F>sMiLp}xmI=LYB)c>=7~}2n$K8SZG@L13nbFry#9EOd3wQARR?&b$j6!tA z+j@S^zjw_}k4+#oLn(`t&91p)jlQ{4<(L;cy7;IBt<}CjQ?I@ZC$XH44Dj;#7)@^S z0vWVxDqkzaeIdC!d@Bm2&4mSVWLDcUUVG)arE}(Eqf>oKK912<6%Z*%Ulvj}R9x!u@xT&IYP~#@i47`vd0tXOfw>=s%oncnG)h z*9ba;e}f12cM6Ijcyrhh0O0JvDNnsMX8bQ;OBUMl_aSm^)PEm}T1}V(S0O9F(Zihq z8rNqc*PyM?LqRL#$pHxO&u3R6E&4akTjl1TI&{J`RAGi23pW0UUHgrAhnTm_4*N6B zGhx>4;m4O<#gYFt_S7K?0lHq$>`E(YaxH9&+S>ct>f!er%z*T zOC@M{(eS7HowRv*whD_B2lviN(j#NV^4JfHI=`tGtKOzxsE4ftmKE<7^T$ywxx{mX zG2tM|niP8OX=@Dm!$AgYe92a}Mw%MoJ#^j_D(70uXP0sI#8pt| z?C+DSMm8P8QUjT3D_h3`jOFzS0`SJZ;1x-&DLNY0i1mHIIPji>OoDXNK_g|-e@8V?nU{9fY$!t~+EbVB@^dB#F+BuRkqP3`vd01P87@+qd<-q{ z`N6xrM-E8XLA+*@>x2<+GlcixVCdw@crXa@&$7JL8R^kG~ary$58cOz5xR z;e_bgTG{}Abyi&;+{9sf-p**blzu)Jn2K4tVI{4X(GkFps%)ZgOd-jFcbTsYVrPuog--$b=(KfSS%b!H8$y7vK zd*Il((?sJtPNkQcxrL}rA1K`XZ2@2!o!hf@asoqBe6I(3g~yQ>L^A+Ca=gIYsThn(E-Q7uB?6Xr>=7|^X zD!J5d~#I^fgr93 zW}12ojoL#Of5u?d3q@_A8+gS{?%l0SCHS2QC3KOO#rF;N7_KH5K=F<9VmGs8wE3-< za&Eq6a@K+7JN8_BOu9(U!_$cq0`(`=+sI*WW?13Ek|IicZ~^Y2{3^k#?jIRlxjpv> zSyK7-%A}*Sy&jQk7zn@!4+}#>a^D@1_2?#wtZ|yg%}hWAMCRKCr@453qT}Pph)X&W zd~d!XAnf(=g>+V&itRB2YfL5JECbC1{9hWpy!nM+R1Jh@&Znb3vGaORru{q>Jz~;F z&X{c?*$<>85-vGguGY?@H1rvazK;ps%(G37Rv?6iJ5?t@86%^p zT_K02cv*jAU$Spg@uT}kdAHgZ{K0@_eLO2E=$R6z4>bY0f9&oA_f5N3`k{EpNWm!f zofo>x5w~3L+I;r_tr(I0Fp4?5m=WZZ?1zhjPt^B~5?s|klo7ar-&{DBnFsZw#(9_7 zic|FjFG(^YcN__uB(vk=Rs=vAW7&80r~622ZhSR%XEi-TIA5iUh6}$LE~hT5)_b)& z$I7fve{T(zj)>dtLpiFGz3JOsX8Ovn+>-KM!Q&dBb^kO~n73s8z9_5qYGKu5a-QYu zM{g!q(UeC;^Jjk$2vrzVSKfVL!y|F=(p>!)hf^nBN^29RJy4TBZFy?;@+_KP$#BW_ zNmDlg#u$zXZPj6+K@N-b(tVK^Ii|AZ2>dR!Z|AMdRb%yTjaRs6SkG%}^Mf$&r?%0S zh55LJZuXx2tf<8@$cGm3kX-Huia$08hfHHmMD^*(gtB^sGy=#jn8sp+%3YRTtE3Z6 z9++H)LyE0r>|9$U4aRsfQhK`D>E6dK226$o$dv=bCfDitvRn#6H_z7Q&PjuT+K1&9 z{r6OS_jk_JU&l3}UgXz34bLcG`V#U^CMad&JZYmKhgzdYPvephy~`NhARM&RjLc%Up;{PCMljTPvT6UzmcLki=8pkRmY_lxQV_uJ3 ze(?lgSr%cDvo}Ph59-s|@-$!EQD3I;ua`l`v^sjJV_hvp=r4sd&)x)%hZrnu4gVJF zcA#-l7L00Gi3|Es%A2~TT0r2=upN3X0KNm1G?{IrHLwgwyek6W%zG-4O+lfshJ-v6iVJ=y zgy#lIyyDUt5#ZSg7%=Xp0{bp5`CTPElY;8%0+tC=$t=Q%R)db3!HEU)oELzcO93C7 zRBwrEDCCz_AkOlnVXLv>tMQ6_LCU89&*YnzyH8um#z{f~%=+ghfK)bLQeA@0M;XAV zyLSa*lJmzhwEwkKi?!l;?$|$r9EQwk@;t2aT`GkBX zX%GXtgmGUv0*j?N+`-5HH^6!GoRV+j7s{)$+Ww@Bb1j}O_wIdph@Q;ad8n?&?+|;0 zFY?L3>4f3r$14g92fs`|n@LIYNzn<92qj;O)hZ9m4!sv;}h zve5^skjcgOjo1NIn3i~$$0f$%4+68|i^YiOX-M;h*yn6Sj%Q?o4BaAO8dkpoGun&- ze1yzLoz^v6Z8=<@@O@&AqIa<=mf=vIyES@gb zW5lian}ria=H(8&K~&7PQfG>j}9JY&<8@#ORxO zR`Z&vm(DvKf0}M5 zw~_5Z6zo*s35%R~qu;!iRTZCbLy$3LFKPz!F$O#FqtYI01?Krqk&HxN0BplGsFIPJ zdTUkEm@HdOiQ-sM*F#yQ_vP=L#Kq4L5j}?%nqhhJm(*h3Q8@d%4$_v(dCOd>`ERD2 zKHC=0(Eqtm-I*6dXa9<^{Z#o`1aG6%V@SyYp}PU|sE=6=0`wkIODw4w)1b;!D$BLR zG25$=tPtHt%8S*5m1^EnjsO9W26j3MSQ{3uYdiBS3jJ;-`Ywk+bgqdMI5uSSOmD0w zicu~#{F*I6pt9k~0J==-i+$n0TxJC>ES=N0KFt zz>6O|6aIt1n*pb|3~IHZxVx3cjE+x!JW^CN;r^!3=N&XPVy-SBp!%%XbvZDA#t-e_ z^lsTC=PFLRsZOl4w0OBjMzIEm>R$QG5bM4T4iwq*IY@wlo30r7mT}^7Dq%hkA$EZE z{N5`5Xk|`N9S#}W6DiQlu$jvAm+2&}RUi@a7IH_7s(;0U>PMuRiGmP`H`^O(71nN> z4Ja>SlCXnI$X4iGV<24&qoN-;O7${z?dUG^a1XX+!<7+I;eEA)Ksp>lcZ}*u=oj-ye7jf3xW}!xE07D0>tzv~sfL{wFDUjK1UzM;71zwHEo;>#>I(ENPgQF@ zYiLC22@yV@#F-|Lq0;n4Sl&ZMW(P{nH1LzBPcc;-&L>?(>R)}U+60Js264~8-$lAf6sRI}wI^;8~x`K85j_X5_|8SHSkz73{#=0;TkfyU-w8ew$=C?1{y)sUcUV*1`Yj5A zq97tjuPVKZ^cn@}0wPF$ zea?B#Irl#IKKDO`mBn0huFN^!cZ_$u2p?tH;-f}MZ3mtG{(TkxefCe$;l;E>BBeUj$ecH;&?7y)yQ^m|xNo zA=&NOA1HI_x!v9J9A4fyE(la0tqGYJOGZ9pqbSZJw6M^>VHfzJ<4vADBN+F6Y4;;W zHY(z|yHwnnj**wMqF-PM*q~vm^g(?SF*EIhS_YT?Hrfc)KACUza@asiZLorm2udLX zHP;R+-W3Jxt)aGCje@?6gJscuO&Rs2^Z z$$(yz+IZ>A2aXe6mKPB?j*7&Y&j3cFW$s`DSTC~;q<0%N?$(%$Y~HNqv`|pv&&Uj8 zw(eG^x!6!dgJC4KE>DBb0{i8|3X}FU!8_%Y3#;?M8Td&=PTV@%P$x+KCOVX9R@sJe zBPB+wK3BgE^9Z37ylTp3W+zs{B^7GS8*I`4eSY9LdYttnOR#mKb3)hVXp?VCJnK_B zQZ%-rV%*nnt;%pg=kXqS=ae&Bs0E9iJP&ZjDTlZUz_G zoKq*I9u#%h`P}-f#KQOKB=2-ynQ@= z=8Z@6mtu`hDB?Q(0sibN-*ZUf#s*zC%Og!)9FMbyX<8A|D3Gdr1^eLaz+pf!^65RZ z{&FJL2kI%;lykh5lCLB@fucX{D+VpA+pO%lkA+FI3W#Z(0_r<|rx7Sd59 z|FQvaV@w00?aWZ#37LQ7oBl+`|MT`pp}a%YOQAS?pwpdPNP(^8(fjt~{1W@ECb7_6 zI-~Wup8KEFzCHb(-$o?y&|j81ba<@VUeI)$-z}BmkoSt@*OY<4eYY5HbEh~9H?~Kt z88_0WtFiamXJx0q%BVu_B<33ACCUzVd|i!j39l5x>$^|(@(JU}jb^GXda^DkJ}OMF zcBI%bw`M}{)$A0*6>o{`VZ2cNU|D0RpbSq$PyfQFzO|$=B&ZW~6D5C4Ds0#kUCogH zv)s_gSj*QF$vrzpVrQ`NBc1d5xjj+%roXfi?+*kl93!-5Ckbm3%4gp->VIALT*Srr zC!RSaqL*4F8!wkI}!E%2w?O0~Ep^HjH?G_hDrD3e**J zf%ejIO4)y_xtOS{`Mg8ul6yuZAy=Jh;G&pqADo{d$~j^gjPV5WD$^o3`4Sy%Y7v8x zH@bL4WqGf<3=M`H4l2X`g$t+bq3A{!8Yt+fsW>M6qaLTS!FPeU#Vi(%|FB z)u7UC&<;`3F*CIm)rKkr=BHj?*G53(Sf`%=_TB%$%cby+5e?$!91&T4CJuyFuo7o8 z5&#k>19{88{9bye^QQHfX~HdMw^AdDC${dK%Nzk2=x@Lld!*-U{m*iXVVGE6=dgQl z8;e5vzs5H`0HMI}uQFM!f`Z&M1iICejddJ1d;RS-D&r-$$~IfF|8?mU8_V7rV88>B z+8its>NHD9jT^s;#4!%$eXh2(5$~UBv2FKx0jQ1n#c?gMm!Xl_e>p!g8#!WqIZvpV zJl?{;4zmj2ZbM45%dvfs1waNJsMj0;pg6XxVLF0cy^>C0#-33iw&{t_iajDr$G!Rz zA8f*IV}fJ}kh?yOxo2P4Rhn%*?MNOq@BwrG{nmJXKHR!WAEDSNuG{1v1^d2sNq_P1 zMyvNaP{H4ZOb8T-qBY=%E-tFB6H@gUj? zw}*>Ai?xF!rFHQh+igKvfVR)~0Zm^6WJlMn^o{b&GCR>2+3};kaS#y$vy=p%0Yh3_mQWW<6Z~owl2rNK?NWXzW5jQC>6q6BdWHs_ z8cz>H$=|40A+9)?9YXTznQ>+yK;;2$NZ(;E?}gH#0XdgZ%e8u?*tH+JbgU){s9SBI zbl9k8e}bsgqr))_StsUEr?Xz?WX-4}vi@j^#r_$~EfJL&v9E_&Tx4%HzJb&(v)a;G zh&_9wdn$ zn2?!ooR}~^4RRDYcfu*76w`Wcn6Fl_*b82w;q=h}3Tw~L48@?Z-+AZ#M2#7oR!nMC z#xkOCs3X*7lTgJ%Eywhzx*qBCuhurj8w32>DtGERHji)Me z4UjY5M=KNqoXSg_)-5uct}??e?%HqE(mxYvTTeU4)YkWQ9j{sI$vs3F=WW-{#gi;( z#4FGdV;I*d$YkzJvUDvGZp#s)+{`Vro_R3IKRW599OOFQdei)!OWNbLQ!6aL^TTNi zjPW^BR)6xD9MSde59ob_=&P**Mtujf1N0I?&NMT04iLVww>Yx&Zo%xDkA3v~?XYzLw{Y2KWRM!~L>KsPg{s+cWxBtt=KP@BQl z?Cu>9-!!l3yF?gA+={QUdeGfz`VPf=Dyo+cz8<|MdR}S2HNs-sBq%8C{Lv7nxfBv9 ztI@_|dx!+gnBo9K-!%BuBNkLW#nsM#pGd@{j8*&UC^Ty*D>cUV?}8fH1hj)|lwWv9 z1~>af`3*0MIMRBOQg7Qgk8OB+eO;D=c$VVreEY>Z!3-Fh*Aa>5^&+%ZCt<<}Fm0VJ zyxK6^=tD)=VQuB9Ro$o}6 zlbm)CL4y!HH@E3V3_hv z`dnvtFp&#?RX#bbd z!la8Pk8*GP%9?j-YGWN@`&?RczL;k5WBPhvccI4C+2GSSp?qS);CZVx*}a_x zE`k&23Zxb2nnkek>KtCOqu+94vKn@8KDh|1{L)v`xIpVga#pCQ!98v%4!L!<6b#rDi6rrPOrPAStV1B`gyz1sSG|mO3wq4=sp1B{Laj^^+6b5xTcAjH z`XI3c^{RzxcaaA%X6oBr}KbbV*s^*%YXzYWpJtn~mC;FyHyK_5x$8-Cz zWx@6t{$lFE5}y#kbvLEs+>;zHW{GwCW1)8v&mClgwMw?r$sVCnmrL}!p*MXLmPrK+ z-cE?x^Vrd5CnmlzvqT z@%*itAjij|Z1-R2vuUoruV_94%#v}mVdm_9M z4k@}?oS**H!cvy3HsZPxiwO)7+}j5>_wz4ToIXcpL33@=}G~EW7gi50(R>(%qU_fby{jNV@Ilt$jCX zFZ;akQ_o!aNO))ruO_f)>dGjxD|t`va~J5cQWR_(sl~p3-dB)#=+Tv_Ik!Ni*y3}} zw-4=t1(G(#SZCYWy*Y)>nLmaiK>-^PR!6F zf*853On*6BWL`Un*0Xjkp@S^6-HyZ#bklgTG{bRAZ7k7Yh3k6#@e?;aS(9$3A>ppK zo^vbu-EcCD#srSrhZK1uZD|hCnG_#UdU^0}*dgPkx_dBRcq?PncL%l@mq2uTDBh*= z3QZg2YdnZb6cP!gbCVU@>O@C;-uY&@jqHurJ8zrp6Xc?hPW8XNMIci=gVR71%)-@ep6*iO2jIJdp*J>U5eWB#2-?qzHiRyUer+k-o$1fthq zVGUCuuCwvg4d`*@aK~n9Dc^i`-+O4@J@y4`}#6~atr+552 z_QXiTA2F(INjn&q0yEEK!>|?o;5<)`ra%YL zjP1PJL1|0Q8Y!#fru&K!pr`R%sSc@5rdzeHduFwtH(-0diek;qIHt}9b44hJjF3rqM5~&ho^7{Z!JCj&H#=A zb_X(|#By~Qg;Hw*;+UtjkfVkXTpwV_wKc?Ybq)vm?`i@pfZPXw4X6fGW_ZH@uqx_+ zo{|USuCo69`6^I_EBx{#7>9+Sw6U={fC422ba4O_!vUaog$J>PC42)!8+%)1A+uAl zF38|knh&NuY>!kH&4YTuDQ>Vlpjpqq?w~&irYtVbL=7O$X%E((#VE~ZJRv=a%xcE- z1)}L&`oW-WRx<2`Ne_l?5eVkK#QwbUk@od~j7uH-?Kb?_L|`u3)ob^^t&IjINOh&WFd+~-HYv$;arU#$$IZ*~9^I~Dx*Xa3{S0GSiO=KOK#yUJDQ zom6HNsS?F=ir;uNR~!wX{Zv7m&kE2B!wv^%qHK%|Mvrkdp8D1M9^HxmtkaAHSh)5tl(eB_#%zx z_4R?Vz8>dzICR$!eEAN($UGUcGAOKJm#AR$%uf>u%Reurhj)kU0Smd1hjr)LS>!|@ zknUw$FM@4Cmxw&dOB?_F7-g;tg3ke+6@FFlHD1W&yO!YDhNC}5L^bcld5b1QSEoYYQJRHl(V;_Z8Q~aYBIhbc&=2_lJebq?jF9I1HXg(tF>P2Znr<7Uh z`Sv9L!@lqjtY-CclW~qY20828>?yV$up5u0vAo1fomTJYa2p>P6C%!hnYl6^o>Rcw z13xSy=$na5k%GZV{p3s-$$8-*=Ad#pj-3W)Lz_naw&$Q>g_0#7)y_K2*DU4NG zX+eP2{IoPq-LP9unXsJsfG6fef?v=OvC-A}ZGK6f#?f@4;|?^)CJFwv353ny3Pm&d zbJ4yDzfLu7W5g70RMKikiz#E|w(chTDKx~(X*BM-QtyENS@*Zu&iTE=I7xPrwht$=j6ZUdxVu;9DY)+uk}W2?bo z!-di4U`P%QbTP)uOSe%$w5JP3Z5*|)LfQ+gBR8HCG(66Omv|YbQNKb=ax1$>W%O3w zeB(6XbWbC{vC$Dt3nD!$p&wY03$u~kSLf;cQt+OdV z@K-}TS%%r#_qTkz?%^a36xJE{{Mw<|5;Wl&d2G@Z$(1y45JL^%M}ffpbt$}$fOI^& zWH!14TVMTPfUM1p=lCdzbL4#qq*C*9g()~!l> z*1oz^qwXMv(v3|#gU|_V9@*Y7!ShmcS`Q@D@3cn}p7k}RAspryD`x>FMj3ZONUFBH z9!8ngmg|02a=v-IkfNmI6vP%D-R&~l012!`>m{y$NuIhnM{An5c#(g-uCCc?g}ByS zNXnpG9cX`0{5pt@R#zoXm+T2%YoR_`DegvA7Ms|I%pau&pXfiuJFD)@pO~O-XHpjO z6uI5e-zuC&NV`O`6`6%~^kYi%9>X=*^Ky&OR*@LbNs&DO_7&@Sa0}EeU!I_E=9v}!|w{J1{W-`BeBCD0-Q;V1- ztVeDJ-}a{xS5aZTf3QV+(d8pf@l>Wi>1f~*TJk|unK80Bb`E+BQuJXc^3+nUv6t)L zCGoa?r6;1-`r_8I-rp9$YNc6N!wTw_%^A2ye|4oVj7vndQYYZgb5-)EiXnrM(kk7wn z_@tH5mFg03B(*$zl=krIxagw^HnyX9Loy7Vw!StB<%7SapxWAEUC+*`Xlo{dVt_B+~b}eG>hX-8c9RP|a^x`7fy={`) z>a^BF^8WbPJ`8%M!UhJye)iI$?1 zlg{{;yJ?bBC8wPy&k`;a#g4_iN`}X&(BKjjM75gxv9WT)OOI<|Q5~U|K>5)fw*$~> zP1dUVoCLhW$~!2cmeq-q_t=dk&KKa=H=Ij&boUPiOBw{W)=~(dfW;=2k_q?ml zPXv72Cr~JYVR59CFkc^5$hvX_BhJBw|c-mR2x7A(W&(oN)k!%GM?&tAA>U8eF|nSy1GNM zA>M((GR#a4dp}u09mNK`@;!;C6Rt%d-rinRW@s>cT`Q3q8F6* z2S=_;g8G{k&I23ETBcnK_Hzu&-}iTEew4iLms{r;m?W3su_jA@_C6=6PXN!0@croR zI|^b#OU)nk4;XCh?DMT19Vt4^xI64=ho($i(MjkOesOG^VfA^!wjU&_7f(iv(NSz4OM-%RkeCJ)ucEWN(}xPs>qCnu`*E@(kC;p5Y$l*p&i@I+XF z+}&Ts;sGK%?>-skOx+sblEkOT(PMgy{kE#-@gSqLVo@_G1lr1rON1^scd8 zgAk*iJ1DILwm`6d3r`>u5 zL^j4t6uVkbh)P5coopGAmJtGhGKw=TD1sVB{N!3$%`wtzRP|cO$y@Xa8{~`bxmR!T z^{#&!<-ORbo+L${!ns)Yl|adD=Q|Jq5Pf7LKSSFr?5mGtzAgHh(#mPQ&S1z5lb(mr zL}i^S*nM}&10ALByV8oLzwr_<*N_sSuI7W|InGn*eK&H!+p0+yd>jCWal9^87ZszmUz57g`NI%OyJqX2YhAw(- z8C$d3aJW=!KjkwX!5^|a)jI!(AVzt&!rxE5R&JbpXVWhgB$SB>&67p3VV!1M*8DlY zqQ}Dp`|Q2F082@QV+if`GA@H|qPDtWjN{N1^pYSZvO#uOUVAfEtC+89mvTR z**~E5uc%E0pjYkxBM_Z0(yXI3aW;28en_#bY;Vi(X6W(Bnkei~C|9`;cz(L=t=JPk zqo19YrWEk=La+LZ|5uR$PYyQ4C?aAADvoaavB&|K*-sQ;NsR$Pn&Yxusem7ib<*{d zn`cmBV}&f+GM2dI=CX*lvCb)XE3Q*?t8{d4uko}_iGLD3;Mg-kMIqf+Hy$7sQT+Bt z#&WtfA4MBQE_fDJ?+r-Zpu5e=mt90)7Q6}q4pY)5u%9TH3T3Q5B`E(Dv)3r`?pwDM zt8f9I|4G%l6%pOFh7ec#hifVjvSD+=B$ICpdR5ln=c?;XgILKKZd>}jtm~ElY81Ph zY>Eu5m&oNQ@zQ806swb%y%Bjqd4V+VF($r(oiRgs91hs-XaXERk19BXR!#H;f59m)oA)Rquq%S?{wPR?9VOViP?vZ17)t%M3Xb4lYfa=1sf z`%4UydnR-(3L%v;5GN~L{WjcLm#W2kkzv9{sxo$ujb_>}l^Y&4s0?!nJuCf^of9gU zjsmqc(0r`YOg$EY$Zynh#XH?Eh6D13y-!^8^nKyT@kasq?^)D3~|t}lo}Jz3k&FT~Bb_pIqMJYfVAuvsX5$daiGHME~LGp3I9JOi;`q0g%bESD+JIy5dFiq z`H$-x&|C!H{(g(7KQHHn0VyG%gI{`w9nzy!Pu;`i;k?R>$M;IukldbmD4G>_ch3=h za|uX}JO%I>QY#dGY~;rTP_{}y?xf3C^a*1y=E2m`DLEYI!Wr3nM zt(f{!;`jdjN6LR*PBjg*tM)BM$<4Titzol#oCx8^2uOo(Av{>-!TZ@8DND2B9gmbe z>m-0EtI_^{uBHV4dDs8-AON@IFZQvTi9s1YPPN5bES}!_8L#-2W)An)Ys3AA!>#fN z=|c;VHuNJgI$^$}oz^YnHNJ9UQh|t#22H2w9YcDLFZNMkZmzhImdm%bTf>suzD47% zd-|1?>Rvbb>id`^B}@M@Wcfp@nWJjNEwsv(swz4UTDb#+jOh`8MIj*zqSRD|(-DZ9 z$r7*`-r;FJC(!!6(6f(!RrQqPAd^#`QI)%q;WWUJN+Jomi_%;a512;`#-8s-d zlDP?VgHPk#n~D%Sg!>St6v4f9Zwy1Of8?i&XMF?x05QREU4b}GAmh;cAY5E zwWZ0E_7YCP9_2Wti6KPd%UV}=j}>DMW$V7zGSm$-(qE7qh^aUBIsmqC<=?!QH_I_V zTz&rvAOr&d8D2I5GdSJkN7(fs?n)_~?MJD%D4I1S9U zRx~i%US^dkZfKn<$jv#qR1zk6l$D3-kR0;7boxcL-5k(|oY){zt9+0@u)MM=av|C2 z`LS%DXfW@2$^xHTL)b6pGx)ZbT=j~v>jaY`c4y_|kjwav4)oSs1flBr)gm}A4Okpg zXb5~`4o8ILnfT~pVm&g)>Tf{Qcd><@=B)HNQ8{g{I2V~ax@1;d0pYXbJ`22^aOeoJ ze353;){%y7m@L3B$DdR;CDbNQ6I^@y5l6M;d4u)N=eL~f0vvR}cTDFjKVJ4FH<6)O z>=_ZuSn9`Hit71go;}=ku2VfBK>DEwafD$5BfeuZPp*`4AHA0Lz5Eidn%qir$!>Nz zJyu{kTilYMOIf?rf92(PuAWf4(D-H#^{z%FRmzL4w%N4BC^pS&Iq3v}Qt8%w6j6GJ z{aE>{ds8@tRVubQE)LCJaVmjTZws+>v}dfn2yF3-jgg8>dBp%sgVN|K?qpNZ@vNr~ zt)GQT#+kyGA!B7Qzf;gzA1KasI`pND7o&t!8sIrWuO-bKS9;y4c>^9Fp1n_(Z%x_C^>#Zgt$jWYsf<@V zkr`ZP{?QG(XY|HKY}`(PvEv>U4dvMgis(FGWc+!;v&Kl9ypu26-FizAfj8SO=Iw^% z?(W>%W0olk<*-m8jOp#2Gkn2q5^(mKci)i(1y(?E1hd+SEm_;Kdwe-^#@fiGAy%cz#CN^`)EJd2^ ztTT;ltdUezemAo45djwgf~p`F((Asp$B7FByD13f5m%;o#OhzN#JoL8ntXIxFl;cc zxg?e>qdC6f5#+jHCunJfca-jGOvl^k8GYBD{-$Cu08q~+V9CF<&tyVVtmG9*HwGa zR;ig!QB_i+ZuY2BgI_BZWOYC+tG+Un=YAo6cYh+$z3agT>g4Nbhr;&gmtX1PuhkCZL% zX4gnLmb3o4*)*@IeCu^aP+BRk@>XeHv;S)Uw?Q8B0$y(QSH4q8W<720My@WNcDj?s zM@GYTcY&VH?JpsxroQzEXd87UhI>rZs--bhl|r8@#{OyIQhE=gpWDh3K2R2of@&f; z5=xD+xsyqK{f-uw2}5oDRDf|+N3M+?Jxjp-G1t1LWMk{k-N}hH5i{KeF&{5P*FTu) zjiJY@Q;bgF;uwPMrkAUP>&It)mM;Z!N_wk~8uz{C(eU##kt{vZU{%>?RC8|4c>6pY(a~86JI!0FLK(u-q0FPcxxqFsYerCsOuoPeo5sEdO!=zgz0t=Ui;Y z`{H5@Z*EC8c`NCm$}^HFu~W2u{x&tY)e1CtD$vt%eBpcKaZmFnqJ>Di&fdqcH&H2q zeg?4ao=IfzU92Nl&MRIQjYh{z_+?gaiwsJweYm5hVOK+DTHAz8GG2adz|HvybN#iZ zt<@VzHDEFsDv2aj(oTUc!-}FfrS`}LyWOY|x!bWlQY`0e;_^RB+6S6)kzZ~L`74)r z(sW}VRFU2CR95tQHuyHC-GnrwhRs}Ji;D5-r|)JDYXuwDn0zjy^KQH=1GZ)}yME!= zng#(?=MQwN%A-@Ifd^kHM}%isi=%OV36_G}{|b&UUfkl`t9ruJrm5u9LMaPEP(EglYZJ zlf@8h;b#UouLK(dT*GgmEaDm&3MhHJ=~jy+1Yil`J^)kQ0H*W7fBfw4pU?*@;yFuk zL^A^|6?vIDFPJT>3*>x~dpm~5BY8{zCMh_B{mYMqvYK}ifK4S&Cn{z;kH1NM3pHlR z@$zUj<#z>ugc(g!vfZxK!gr~ftVfLV;9V{O2TSvb!|0*Mnv~ZW@y%-aoCX;>cn+1X z8!&%iz6nchy1u%EdSymchO%3W5Px;074*H33@|5cL z4&0(m{6`Pm@wWXK+L%8fNmf|`_H~u>*eB!59f`x)NDJlLSMJ`y^Ey~hi8|`vYIbiC zL|FIjnd zqXNkr)*Z zN}>iS-&fk#E}cl+(Y!Gt7Xc$dBrl+bSF5KxjhP;wJRjLVh3x;xX{n9xx0ko7yO%(H zQuw^wIPjr(G>uc-`R8<@3pKng$O=DTrAe3$q};kb_+3Xb2VS2JeKtM%!1tY%N$o4^ zH}|97@{U6)M4*?)0e=x>`cM>!#&~uE5}asv&weUT`@4?+gJ3+Tkrzakf6!aXf3Q5( zf219iXcqY9;Qft|0+MR~qQ{I_W~9KFOMueyo5pbcN$NcTVCEmR(MfRU zft_kSs$-d?PGBvDmd#P9!$NRKY8nu*{c#K4O<1U&p0or<@c)1n{NGBpqyw%be~^?vr+jo)#D@c*0xm29@^VrKRgvow{b&|CrrYppMxiHd!yn#T{h`;$-ZqOP>ZXsoI& z_qk%qeav4xVUui8%BaYo$e^74Lr}frGFbb-1W=5IMp@m$9SMSuC5MTC&qsA84X8pW zf+Cj)YfT8u0HXF^i*G^`;5sPTq+a4s{q+-5%j2-W;z!6I6sr;q%6b*x7wMG~cKujK zl5ozoqj!NyQYv6;{t~7DykOh#qrx4*(^T1N!!>o4sM{aNA@xwTKHSgo^Pnp}lNh?+dG;XWiEkjEn{OY5__2+=s~rELKPo2hB1OEe zK>2HxU?t1xc4ZUf?y0ENszY)fdwg=m6#HPmt0GfRwX0D}K$elA^15MXRRvXNqe&9qr* zYu)b>>_qRFCrO6_ZKzhNsLQ4EneMtAZ;jQ@GsQ}qzmypSgHh&;s4#`&9<-~e7Sfl; z8-92D=}6dieI)3wm+)12j!o`F-?>3WDbc*y@|Cyl^sGY)TNrgYq%aCa10)&Zs(5h6 zm!UGMz<%-OF$1dBtG;^|QS#xJXTr)9EW=x8ri<1Mh6JFx8eH}0>Rj=JJmBuX=gCXh^087=S&UJ zn)+RC<@dU@=w43e$-VLwIIxI5qy`pLv1D0%)FQoS5I0gP*fZrvi|h%qKb3BN#Xk*6 zY$mR}R_x94!fxb!)q=$AB1=VAJJ~Y$h|D(BS6uUozu=0`lm(c+B{lT!srk#~FO%co zg3FeB_oPH8jWowJcQ8@BV@=YC3$5C*z@PCKwI5|{7-=ZTH}M$~blQ4ut5R({c?NVG z+euky5X*Abx#&i@l@C^!3NXlkT^hU>Ue$F(%dtFkJCRQFRcj)=Jydv;)O6<#}2Kqs{5-m z`@f$41*-B(Hm3=2wETQ?)q1`5wKm@Finp)b`qCFYCoo2Fv}6-6_mc7RW3q&Mb+i`` z;bN<{5JZV4?-d-aeU)%c4}M2Kr4dMZpol>VGSUFS&H1Al!@9v_ya=T2j4KgDUylwB z&!Q^lhXE;GgS#`a&kSP$qgB!=U`Sw-WTz#zr=ry_)!hPD(=b;adjqdED9(8Di<3OR zZ{ZX+R=BeEGE_(3%xcBc$6|ZcVy7mxR&#-j4;8(#bjP#jN$u!Sb$+_aVf=@YZ#_Jb zbq~sW2ElATuRjkBgU3boZsEqSM<&`iR6xA@=j$boy8OaDqtwT#r}%5WMM10VfE{NU zGD}bU8krOGl)gJpD}3b%f!zC+I3&jYa~|>I0;(cwtZ;krRI`aL*O1vRkU7{k_aUpV zu0qbgTvRFVI`{g-KOQ~SAQ)kjuCOY5cUya|i@F`$tN0Z9$ZeL$wjXsJScJ$#2Uf#t zvudr6dWY8OdPb&=k+>lS2)U=%N(8Gac@QOjOof0;6yT7{OmfAnb+uF|cSQHlxKXsOY1CkE;^7nvN=SPDt;;Au`V%e)$!Q3Ncpe{yzQf-_z~O%CEH=9 zL%JrJg*cwuS%%PZocpOtZuH4E+YRy_r@RA7-}=8_D4)t1{!pNjrwz5qVRr+#K%L^r zCXpnM;40W%MGwX-?c4N>(8uC0U82KZ`wIAz9JdOz1Yd36!}2ck`L)lS z!8+{qHL4IS!4K-{RI*K<{tA+iIAKYuN1gy41uWN}GL`hJvw4(DypjYtE8?oHRyG35 zQC=(5X$y3V2Yu{M+30duId`;BSQ{InPg$N1ROSw%Zx{MIRbzS2l8}F$j}Rw66`WhU z8b)aeD+@`M6M$KxKHh<@2-26JtKqZ7l!}|&rt3sk^?MWg$pbmYJuAkhAA~PY7Uzv0 zufLz6bdvfel(MOi<|_efgo%*I>bE{NU=!-$A4)9;zrHlgTt*1D!M?4`mrtZ@E|y28 z`Sd|^oJxK2$8t3IIeEqB;#Hw(qLVf_res}2cUFn9TJmWD>*HrCw@1}(^0hRog-0o* zP4-VuvfyDrK*bOtYUz1HJ2u-`o^6)yn;b%ItJ%v-;QCv;E1F65Bh_jYP!vbGS<&`- zkwcc91T6_L2RDY}vm#YZ--f6+GG5n?(JG@t6J1HqTI0wO!deBkcXKw2gf)>zERBkn zR9bLCXsNUw0Hub5r33@+Lp7fW&ML6AydTX)*1vL(lp?LiXnZ94blNP`#Q)Z4i~=}& z0g$tm<*mQd#NIw|dy1o|>AcqQ1at$*&`A>+*7|16)kUA|42e4EKeL0RF!k?|<|WJykn_<3W0c!$AAu zPQURo-=vd0`>$KT68(69+9w;7$8S7SAgiQ6bvpU?HlXQbxkYh+`79Izh(M%vixOs+ z5HW{sJuqzLKh=I@2SA*Zsz|AAN_1E6@2+H`d4@^Tm`^akbYl^W#T&W)xBl}b4V`~0 zGo1T}dieIgZ8Qx0fFVT{K<06owW9FL8yf$V{Juds;^cCED7Z5@rDyLw+l9I{s;iBPpZ^x(lDGs9w6frFsh9FcA(t=T2*utPP2Ad{2Ior?0tJq6+sxFBG9bLxF`iT4zESE9%uy*!|cjo)~q5o+-5-~g;) zBYkl6i_^D6TVQTJZPciUZou`FaBqxFog00 zlFjdSV_G}*Ht{hLUU}_~f$;YMPmo7*nAm&L$q-KIbM=P-aL;Zf(3j~(pQniC;cy>* zmvp1&@yA1>Wgdk2zgFsdWL8T&S)SA*;DN5#8^_P$?G!EFMM-OJy<5A_e0J=1oTLBf zyqeWBzX&99=oj{7l_t0^jX6+VV{pNEzi#HR1pf03MCyAL1q|M3Gt}?o^;wr|3;opv z3vg%-<0pS3^1xM}x!@T6thRKaM}A8~7J~iECRE7VvG^Oal9nds@i&!aTsIqm+%_3bDT)l3Wk7J8 zb=?_UiILZ0apYt+&Hf1z1@8e>{&p)Z^=nM}N-d+a*Qsq6ZQ}Z4=;9J+iAJV>Zv0Z6 z=(nP`Js7}5H4jzN>q>;C?5T~uuss@7LBn*%leG;dEoRDM$7SsTqTQuAFH~9LK1O)O z^;ejv6hHD{h8tnQ41~qQTkms=Y`XTNq)|(bTRw~vw709AJ3A1Gk0E^1w-9$acOgVR z8o5Q6?+%X*6v$IsEXYfpmgrzcsE+p>D8>A48c^NIBZ;eEcgXYUWZYl?jz6*tvdaNE z)fGK9DCE5|qS@2HxGSrQg1LUzl{Ggx7=nn{fAu=I3? zKkoHs;RkNepYtj&v=8?Z7(9YfM5 zxX9Q`ryBBB50888=-v?T3D?T{xLw9oC-PfZxY^-S1aN4fjaPKd!2Uimm&1t#(!O%< zz&rXN>$rCa#YF9I4PTyJk@09_2GtRptTY|)n>X;N8{OBi1v#vz1t7rYezS4~hNM)N#CywgPIRR)@(1J|WiH_c>373ivMaHs;C(C|Rv^QtbD- zeR&TC$zPrF5<11%S$`s`{@$|=vk)_nWir$y(h3`N2s(W(y8ZoM*Zy}dTmK_~ z@vj!}f9sR;3M2YA$I@au@bL{ml>+mEoD%%!Fa5>-OP$60Z5*=xVd}rNkgl)90`$%3FQ+R((imNGS%=R+<=ZsGSQ zvX^vnt95tUd2`z9$$(sLuiP*We*dqd2e^GVbuRXgXvIcTx=e#dws7bC+;Z)uzo)Q zBQPouAm#fOSUUpQ)*U7U{^CP0`VxT4@&#_`ck-;seBo#+CEjJRf!}x`$Ck!nmJQQs zv+_?|Cfu$POG}3)`Ie16`(D?tpKT>UX3|euMmJ!|*nO9ACcijQDP1pPi?>3WPT_ie ztOT7ZxqT*4i6(fy!c&#Y<9T&yV=i|+#W+?~dI-i#Tzx5wJ>ZS-^!;b6E@+GDxXq$J(0TStpyZPghKTr za*?LrRWH(pR+J}@`Yg{IgH^rlGu9sm;d`y&>FQ@+CYwf-c?R}oF6ks|oxSvsUa zq(@;+dk{V~==x@Lg2IYf%=Q&S$f_JGFqu>rSZ7-pAV5o_M;zu*7i1X%h5wiMiI_eX zKczLG@IUq!@yIlPZrz9*QHXPSr#j11V?WA>`Ck3{V^|N*0w=nyljl`Clg4hlEAqCnh}q2(zEcT&fz!f4M%&f276A%D+)8w6>uzU- zM!6~H7h#X@-A^32OSG^Bnrbk8IeEmlNDCyz2)(0cb5l8wN^`5rc`r0%@7T^0tn5K; zh;RdDklbHQ7xbqY{`xTz*yGVOyRq4RZ&~71?(+H{;zTl*koI~<;#vOv+5P8TPtrJg z&lR+JWTwLQZEmXtZl;sR@o=uw5~Lkz;)KVOx>pBK1k7%G`l8Z<1J-7S`(~w*Xmy9^!en5 z*MfLX4cNU~^fN)aMx6!Kx2RS~gtK+ykAQ8hbz!q8rWEoyhG&FIt>IRcZp92GcSt-p z>O*>OhRdgUo0quhb@UJPo*?ReajRj?f*-woNDg!q(mw`X6{Cz*N|GX!jgWt>UI0N2 zb%Q#;l&O4g@$}loIJy zKsuD}Qd%UWTN*}6YREy6E&&0hyJ4goq(P*lq+w_fm>~xkcz(Bgf6u%1+54^U+sEFd#!Vw*IMWKB>o9U)vLm(Mh>m#&-nBda}`>O2UPA>Z;e;;+V>Y1n)U|W zE(#=)A%HMMNs*gC68a3iA{}8&^%2uQHuPtdZ65aPzf+{G01f%`6Gok$Cf}mXuzZ$< z5#4a5_Udp>K+k4e{DaW^#ka3~Ta$&9I)=Poy@kA<6B&BA90HTi2`ZPd5knibN!$82 zL#gP7!Ox%6G#mv?SDzW_JN54#?;85h<;8Uf=((En6H>Z>-3DXWBA#TEVt%G1+tKK9 zWPwn&_?{KRwrTr*_kKeY0Kv9A^qcw><>Zu}_v)Vx8&p8AML*yW)mq+uEV^r21)Lhc zy8?Uyji6g@Q>}lC+&?qQuP?HHmLvo)w%-4re?#`qlLRTo@#LSDAHO9DLyzBOutC3@ z1J9MI6F~c03xKfN0=>X3|4-TTFH-Wu0M97dke%h(G|u&h`XjO_2i-%H8D(d3 zVWCGEX?MWsqZV`Cu2^X+Dw!kL4W;sz`WbcP_AfTtD?*5Zq}Qyc+IQ324+Xa;yY1>e z8A9p}Vir7$Bj#6`Po-+-rqw7#dRSa69EyhqXkWG;(kBDwZWmlrKw+rc{``tFKsG%< zMiR(E@+qC}?(&5B%B)2fdWm|lQmIhF(_^#ZW;m9;AZ4zpD&y-DE-;6PRs950h-Rps z@!3A7n;z9vIZjb5P@oGwLtC1zUIH}PeM6ntmh}zG(4uoD1}2(cN={$@7gyPiw5GPn zH9={%^%Ht{NFh1iNX{mjRRZ7>kem+ zHpMZct|*^ZK-XR{s9(i5#nC$SB+6J$ydIKgzH@h+a-tvO{1gB9RO}18{D69*3RYp3 zdR++=v@a4^c6lI8s;L%(SY@O)(|+E95rV+6k+=-Ki%IaVd(*Uhb3EB+9J+u#*Gpbk zsm-FwTRQK}RNMP&0p9Z-u2i_G>&n6TO~q?;J7A|he)1uS;K2(ym23_E>4+mJq2chw z?S+yJ=6IQBia1{-o6btRU7phGrZpLCJF-%0nl(Sw8I(GqgTTr)QR#K$; z+2FM#FLd7Q3RhbX^Xm3k>*Lsg5+h#JaqAx-SAbx73RG3Gc+!_Z9iOe+Zo7VwHuzKf ze6PUY!3ZA`33OG!4BFGJr^?&WQ9i;~K5FP)Kceno9LXOA#H5inX`r_-5N$*8CGy1I}Bfd%uC85z3ij&K~A;p zA^~ur8r}Mwp}4}c%~!C0d#b1J&YV`HSy1C?nq*HRqR}0p8R3!=zw5qM4*k5{A|UNA z?!(4VF8*Jz0Rldl_ouiR`^gl^&20hbi75$kWWBnHn=+zxl3^g+??snvpotx zYqU1`DhMbc$Hw)x7{dB4NNvB#qNDvDUn*^l=3OYX344DlH4b`=TdkmXY1tcD} zxGE3gu{0TEW}(zgANGxF?c}xNr4{}_(*FUXov(9U|NdpkY*7JdwaR9`^HerT@KviY z$n|`>PrAB{k2hKn`$a5Ir@ZuYF0*Cu9#x!1!B|uYpcUNH34|OA)3OskMA@mfflanHA!& z48z6<6k#S#xuEC{yviWi>O0rfDU$x497^k?zQOsmdzrMdl*C`#5KRQdNK>IQ>UfY> zkbXslpy=IF>+3eFh8b9dD_{p45r1pZzwr!*d3qMmJ#lu1w^Ny*Vtl89F1b4WvR=P^ z4S-BvCaUex>0m#0lm}6)O?!&bM)mbqJIb3;Qh;Du<9-+A%*FJ2et$&5`lT5Dp#aYPk`n!H&s{Kgt-w$RRM6lxB1t`eaK0L?{Y<)r8n|k zp9lGhd4k&A_y7YY^Ahz6eqfwu`mcK6&&-m0tS3ubPtlqe8E$DqoBGDgW=*vL<;npH z2d^3Jt{G4g!kT>b%fo8{Ax-KL7dlC$%iQ{=K>V(fJA3p>VTS2>4MM>ugNQ6)h0T!Y z21=BsSZEI=r4JQ5PPbL5n-(Z-rAJhCc#}L?$Hgb>CA@`Yi-5J~hi-N#q}NB7y+-Xb zQT?bkm~JiCY?bZ~@B?_H0f})y37*p%U_!Y|kD1h~2(llP2uU<+Yp9RuWe&F~qXw6+ zm^!2w#@y2>`(MOwY}oL9vf&xHn;}E{Lf*eTYnCKO*?r0D$4>4Ji1i$Bz_z`~0g0Hf zZHj`j_^Z=zyekjtvat8#OH_beGy4&&E`dXA&+e1yI?4k77!tp_@0~9DLSxLZz#7 z+dD6jo0O@(#3t==w;*`!j>k!&@1YhEzE5u_3`KJ-Z$y-r2?SB;C%{fv zLSaHVMt|uV+xpHMh*lK1=I!5Q1LSlhMt?HGe`)hF_!SFmVVeF$Sm6fndqOrK-29Dd zFHyXHJX?762?#SBdm!i^W54t6u)(F-RXA_h2~n6=0A|FFsy~BL49Gk>BJbPK-LIxH zL_>iqG7|W%1EN=QFras0HsM;eKD>PosCtx^vxqYPEct*7{%;m0oNXu%Bt_x@?<8kzKe!V#DO z^#S@#ArQC#idrsE{Lgyw?_H$G$X+`TzDI#Rg@ib0@GE+fRwUi~rXNW@yt8f7#X|x? zdyfhe?>|w9P22CC->0RWuGD+5A`Xsp-pj0cg7j`m><({L*;kkgq1qv`lmYIrBdYy6 z7=})Z1e%!UkkBhUvvW}jl%B0T$Tf2=fEP%KV1U(N1VXcP9oh^?s?*ocdYZ;jtB@6B zwd@I#D_gJ;PF8A*c))Xjm932*uaX9mw8IP^t&JsB%+=4Y4@IMNNpG>vu9^(S5Psw1 z-v%rUep`td#JpBVmi_2(i0ED1F3uac_Z4Juvdc?AtG})hDfGYAma{r1P=s?qvF9ua z<_J9V8$+ngy*Qxj@k0sq%QoKeb$jU*h9hBo%RN`ZKS0qw)?|cfMjO7V3M~`HBx4zx zmZgwrKh@$v^3b-G)wjwNvz9a<9Q%bR+DrF}giGfhKnsl%AsznwN^p5Cp%*oW`hm4o3NXP`BJ_KE1I5`iPrts)9yRM~9yXKuR1bAIsV4Sd8 z`tXc+7|L;+!YOS@Bj`9Z#p0l%H)-eHI@uSw@0{HN9M3|X6YJM6jTOb8`HXqH!Y`YP zoPZCUS+|U>?RE>wsNE*j*p{CsXK|SFrRQC>M=KT^8eG*YUfxx7Nj2Id9%u}wG{(0S z*)_?UYB2>dAfnxY@ls#|6(XZ+dnSW(jGi1;Eqv}~OuGO0)iZIzji%XrtW=v<#jIo$ zRe#z4_kmB41wO&!KR=;ag(FPjuyeP^(5I-J#`w3sYvR3TA@mYBGR;?yX$9r?nB`;Lpt3L? zZ9EUKZVEx3hf1^Ft0KeT1#mpVy0)hxfy+x%g2#!)oK7W;e0lj)AYXykUBse9(dHG{ z8NTN3n*FMQ^@@YX#X-6X(taHXJ?AI$*&}1n3AI$;yS+4YPVZ7r&toF6z1p<^jofqf z*(~6QibJ(2-4D9Nq1W9Hn4}Po0Z-bMx-*_Bxb{pu+dgSr`_=(Jza5ut=@1OW_M&DT z<^2`~!BkGKvAX$4gIuNSF{=6U$$EFn;{($Hz#MCjGR|q$5s|=QVdBSjSjn z)Q=4FNP4v6Rn#qGFIj&+$3T6Q@JRDi{bzgusD2g;m`P6VPR%7X&`~#a?c9Il$Oom^ z9s)|jx6rRkvFE;FmJI2(W%Rsv!=ay=gj`RHTESCxCGPjF$}3ZPwbi~{RF4R9E>Y~< z;gUwgdi&}N1xyof;U3f1u4Iz6QKOQ3>vX-mJeOwTw&nC}muP1Ugf;_r$L&sF<)xTC z4~JY=7Ub~U?X4jC97SB+7a&o=5MiF&u%JRiVfE=_alV!VvNy%l-|Y_vOos>d%P0BQ z)kE5`{dAVoOTYQb_;Beue?+nfQ|ZmY9o$sGGrR*#QNKh< zU7(6aKRB!BcD)L!cEMzD2sj%u9Xc~`y)=bnf z)I5u?MaS$I^s-QT1xH==j)T>vMjDS!Vlv^Pu#*qv$0K|Ox4B9;6(n+=W5&D8&?1dP zy8xqBUX<5W{&nP4_Eq*8fYLQz9@YcL;^)JEfL>a(1&x+h+@G%9&JX`4n*TPJ?;R~x z2Zy?`ubz>?#FtM;#ne@?tu%9b(*Xs?kVvX$(N85XRl>@Hqwtu zR!3pN2Z<;Z;rVKv;v=)mqQ?yHZq;W;j7`dgk}G&yLJ|PDd8W zp6k7jS|6PZ`K!u%kFI>Bwf_=dxIgG`cAl|dDm=t$n8B#iisF<%>{IU?+3@Qsgs68* zk&m`WuUtmbrm~M&LWQK4lvITtQq^s;#w(Fo!YKkJ%vM`(x~u;Hj?wl$vK&c#&%ewL z5!yF1ZG<@>T>ba=&ZKP&jSv9v&f5N092KTyrpfgyQBIx_QTQhB^93MLCy;&fwVHxC3^p2TWvV)S=d693WTF6L8yZVIY z84ue&GvYQpNPFMKJoQMzDG|$p3^y>yNEczQ2U{ z|MnAEg#5oHye%)E#0|jRC_nc51AXp5bwTSd`>7jWkkd@`lT*4+AN32)4MY&r7#Xp> zbWv&P4)j-02*?Q(Eczc%fr!M zMBaY*iWgBa9Py9ggVZZzQY2i`tM8f=NqXTm!#Ym=n5;gXHX~qOoao^l63hwe6WDKy zt5$AgtO_`dkB~k?%1ujI20wW2ZcfmbHL#uX|8yl1XWCPK+85Td>MI8vLLo9Erwwxs ziJ~HyteCZrv%&Oiu?m%{&P3;kGPb+wwY|;Bv~+n^w%?i^|2ymg{8r#e1eni4+#mJ7 zQ=rnn1E|#EN)mu9f5A@wkGQn4na-#i3^#eEUn6-&M9ZkuHct5+B>wD&`Q}`{PwOi~ zkJ?U1vtA=4{Y4U|ON@4~Mb;YT=xQtI@vpZh_t!mVl1_@z`)^#~23oTe5@RH8-8y>x z<;}+3^3xAXZkV0Rg#5d+t)`8p2H$js9N018tGca%`d7~1p-~<>n@7UrY8Z1t37948bG<+DP8@g zj{zLE378=`cO*@dH$AOfc!Fqb_!#KH;bIgN^nKdsIU74C-}Z#7SNK0z6ipnef4H9O z6o&O-BN1HI{zTSaI@pu~#;FzyOPzYLB{}ncNx9n7b<;$vbYj~UZnurTmaRM+RbD0; zB{6XXdDp(Z=B;XohoThE&Ofg1x#C-GF4ngHUbH;NWh-S;Nqive{g>(`_P3?>bl0KE zu!K2eFo``lc0)l7JCSz|@rvf`PG+&y)bo}gF;m0XHHcC3v3pCW{k{mpOaM#)Zg~I!`CG{lAE8zws_n?u8ut$u7RS37)wVyg2WY|3&a_xczYJT=%y*-w zZbs4tU&BbLd1tNdnfMB4E!Da9maX2mDRy8>xTZKTe`}A{| z1kW8|V^g5!lm5#FW6yUCf91H4_In+elc|gJLnsrWnSA|vI5X>>xW>h*?6Vb&vf8?P zFQ$6KucTNauZa)%-E5`7VV-Dk3X*n_DO2Ygtl}72{k}P@H`6W}-KXKw$lxVa_}X;b zI5UE~f`>;y_YLbic_>NSB&$QR_6yZ+_<%VfQF11Z- z^NWp9?Hc+M5Z}SWBcPANmA*aEN0N-=?OlhSK?=SqERLC*oT%;)*?`aM@Dq4*`dRVt!!Tx;y(csZL8eE7G_qvr42Gxy9-IEDz_2CUBP zxY>JiNdkko^?#Hqn|OcQXa)%R)}F%PACCJucGfbNaTBHZ{oPjdK6*NTEwBH&K{fu# z^EB{(l^1J^e)=~2Pp{vCiQfGc@g?FJ)rtYzoU773pR{o$sc~v=Ez&1;=fgi7hgUj= ze92y!lGHWSUC&%TQKv{PHdacVs91AscmcQq3SbcZww3n1r~xnMuTaX97}nzHnEkKD@9+DE96$qXLcmi`(s%QFdjxTvwy@wY7F?nUu^ zy7z|E>W&Zf>|gNEKh6TWI_Z6pdprfeiN{+gG&MtN`cCgNpXoeJt>sj;Wa?;^~* z!q{?MY?{p>5+XshI6!BC!tVmz4vuVEg#l-_$(jvFe{`HQn(5js@jqi;qK`aTW>s{G z@Wt2te1F~0)CFTm+r|aW2uktv+$g{dpcHC0%OATe@@4=hGE3iS0<6UYCW$W_4p4f9 zn#yC53xp!Y<(YBo^lXzYirnc{i(61@S+`rdmrtD=U$3~}P@U|Vu>Q8&3jQw7FtPa; z@W})=<+WzICa!2dJ_krKdwG!e@JBcJm8g{&O0i{VjRAG}vj`U4moS};-(f(+t~w*G zDtMm)e8jF^o~FSi+YC?d2(gqc=gx@q1TWgm-^mK(V^L!{g0*F(GXQ(0ITR`JoNR$tspx~;#+w=O~ zxX!Wq@rUv??Z~}H=-mt>&xen;P22w-$#DBoOr>0TM6ssf%Z`bBWI!(C&wFp$5euvv z@VwewCH!QrFaZMbL+P|#K9>aZ{+Kqh)}qnkSOkv~7?(5XtSk(XHGimY(v9a2DXBj~ zwVlF6fdvEjrq%!8n=&1aQ#ckxY6nNZ2yWyLXB1{Q=(X`q+|!4y&>xYL#7N4 zOBL*Gp^l)Iu;MckKA3E=W@@C7VEdxq!E$D>R~vDArs-HsbGYd>7x9C6dVM|M&T<6f zi@C-_c87f2{{^O@k0@z0pI*haO3d}+J}QU|abP&dybzHpIU4xm`+wYLgT*BKm=U9^LFU zo3mQz+}SD=@$9`(Gl$5>qN2j=qpQTwu+6rq`}O0(U)ITWj&kV@*jxuzszl)pKmnO~ z+s_PfaH)DlaryZ2@xZl6&djxI<=mpr z-8RZA@ds$$Ib6tlnBl-XI?a;7-KJIRKmkdHHkJPa6w1^XD>}!FXxiMGA^UNx$3w%` zH5L>sG9&2)EnbR&8#}S&lp5x*#p_iWa@Q##=nT8e>aPdisNFMUG*7(rHx8N)B}6_&hn?wADQNt zcuF%zjLBXfNTL3ZqI{=e`pWQ2Zq`fQOH80`3<}iOfdC4bhN99 zWsirCxat%=@0v(o!D}ex<}KZZww#amN)w$xY4y7bcMZ8pN$+}>uA6igdfq;I2kEHj zQmsZlOUU5~b$J#I?0mA$?e-UzKns ztZaAKFb)6hBdc$JIkNieF%<0kZ8Y?%SbD7)l6pPgejIZfbpD@;;~dxT>xfDW3W|ET zQtY~av_6p!Ig~uz2;|a7K}{$2@><5-=Te&?Zf#00yop_RzTh%MYt>e^wl$J_>K+;! z3%SzTnGB0cvtc1{1karN4#@m$5%kTtC~vL{^{zRvug$moZ>!Jk*S%Hr2N>B0c$-5- zn6NM>Q0g&K{5=&888pa)V8+zP$f5F@IES;*JXb|Gc8y4mo2(nwJhX{)GPJirheW6+ zeg?pZjL{MrI~nLDEb$VV9LxM2%~qf4nDW8-tJ*tUeY_OO52k^ix8P#I|vxZb(z^^6!*dmm7R|(oGkl zyt)vDF>Qw~!t85bAqILyF`S_YYO`yWqI4odims<({(W<}3v1`<8+*IyABLrGUtGeM zx?^Wao}vJ6AL*{-ixwY~*5iSfpFR&!Mlrl+$9*MD4Z9SC#P*o7UFDW}@KP2}}&tI6!cqcaq> zKtAZ%pLulZ9rqwcGwmgt9<%VXhyJ*N(9?XxV)#mX1Q2-y#zXQg9erKxDQDjqKc)F9 zllwY7JB-$96lYlA|1>f|W&Uh|(rTTk*%KsJQ`q5r;udUgq5)`%5^EZOsTFd11TMFD zK6ntTR485@s`*V`s_7f-Buf*z56Rt^EZ+>0&isX&Zz4K3AnaKs-{aHjMFb{5VHTh)Crie+}6+=C3tcx4n zZU*g$M3m8W5IGb@%mAcMGbTTaCT9{9=koeyl%?@bO45gmH-m5GSfZcK_0h;;t+~E7 zt)%G8oau#YzaijNj8~@QC>hTG82%Y_i%(hdCdjwiOuziAuA?huhBY^>b?=$xM9C(z zZ$vF%Y1kwXdYXyyC4KC`lI&N4W6tye!%Ra1>$w_1jk>EVV-Kl$wfLQqEDz=kP(l7s zO;RWFfkI#f@}9Xy0e`dCQ_HmJ5#W|^p9~l0+d=W%8*iIIX0OHdIM3q0++}V^K+wpZ zU`;S!Ja~li0-V?UWUu2~SMVCksLaahC}>i!rM~6|9&4~+&H-qE6H$MK%>~FVFiKx@ zl+A)?$2f<1~&7DT0)PTnCfSXklEuvXQ-;D=K4CTHl(69(E zGC}xd|Ycuo)wRd`D8{ee zGCEnU)lEruj)gG>+`}_ul>AIS4{?G{N(GOEL#D5Z)n~!o7t$Ba&J^~#7w;`AQ@)=+ z#c8}+r3PKdEE?ivTiD9J8fcMkaAr?)nwzFt70H304v)cOb|5SE-QntWYPd(XY=#fmBNG}d zfcam~q@BBx>{mo6P6LV>@ePf2)xrBt>q_@ylf2AJa_2K{D#gU-43gj|f=V&p$!^I4 zR!h=UewNS`=OA)^0X4X9yVQBYf#M6>%RUJN-a{JYRnCaHs|=61xpvHf6$1O}Qk2k+ z3F8`lzT@!y7uN$%*(6U>KxbI3)$u#3l$+5N78n`#J~su8CJxI;o!GQlBAn6b4BbSR z*!s||2JrS5P`SOYEtDeDQo`%eTE*6Pu{94Wtf%eIhJ1pd;Sk!SYL`2$?|c9`}RRMb+hoyp3FIu z$>AW=Cnb2eSWA2}heA-2&23S^%{B~Fx)co2bO~pTEzcgMQEZrInqZ1}ETB8KRX;6M zCHyha?3Q!GJM3pJk6z2rcgRD6E9#LWy5IDAbg8A{EJ-UrW{n=0y80nMEl>|Z-u7Vi zxHdh)Dk)Ps!ljA2Q14MkDSI4XvIJ8k7eI2xwNv2YZ|6Lmr7~VT#&fhNLH2)R&fPdq z1|j1y_DHq)jE%aOH3m6yhgEvZK2zKyzc0HZgN6KVQ}7qx6wXcwVYpc&1F-;&87)fJ zkwRfiP^o7)9qbQw^>%Z!`vW9!cThs_4l}k@*AHT~PsB-c;16?2lKv+!j~Wuq*o77=qwo3)py5UEH>0p z+P9b|rMkzpU^ez78Kc0~sAz>jqMZw$317$HJrl?`pKNm$C!&@P@XLvm;1%oUI#zO@ zxo3?fPE(VZ={UHoSorRzE{Db!}RE&RirCH@WPtwtUO;{GDh2%xdOa(jKHY%HVq>G zoCe;o48CTBQ#6<*&RozDtk)c>KK^p{@ohB^DfNQ~x7QhC1iY=x%jB3Gbzic>1=UH> zS_shiNjoi|KL8dZn_MVQU$y_FZY!1H%%Td(zhS>=^kGU*ll3A=`@v_&W+j{^O;g1Z zNP92!S)>OqHd^e>7v3ND@9r38{^wBqqFEw7!p1>Mz+JnQ_Ycde-yJKRHx(Se9(U4#kbfxWE2U`wl0xCxdS~p-0z@sD;S> zW^b+A%js!7Vc(&GbQ^|2W#%UwhXVqylh_gyJO?P7-gbd}g2~pJ7{dMh^H9BW9c_|z z)or2L^(kOFqo)9|MXl=JX5HFTQ}yh{tMWD>60rt7toxentjb8QZdllz5jX?J<9z!? zdV~C!UK3^33`e#579rMmS}bN+T^65bihcyW3VaB{oM(e!LL*aZE`-tm>7=pkrZxHp zRfy|KYfh@EMdcZ)uUb4qyUcujTs+!$WAfvS*d%9tBYpMJbH%!(Un(}ubKUNg@cG8e z&=&-<{2i9_e_}rW@@L>Cg=pp|GVQy{A|S%N6p#nmiXla000%Gzs~p;!fnWDcVByC3 zlZ%&PDf7x6PzQ96L@XZ1>@7XQpuzMyvPH51`D2EEG4_Ex|LmIo6Kaxb{%f&M00X%1 zKwDV=t9mJ#rwtXpx@+f-J+?2{e}W5tH$+zB!+kPT7QJCk(P^~LjFFIrZ?1!9;=Gzs zo(GF46Yyv3Ww+#)&0cJQpRc}X5==qfintZA>TYpMM?Y@kj_@5K=bRQZWmGYcXn9(f ztwwVWk4H!i9TJRzCLgEr&GD~nFC5EsOhu`my!Tt;Tc+ppk5G+-SVd4-YkLpqPky+u zcB7k3QVEnHjz5Yu@t1J|QoXkj=k9iLDMmY?Hc>>`EKhWfLtJFK*wn=o{BUDq?Aq=@ z5~Sju68=Ewcg9q6Caylla#mozlPCwtzg4z)KT|XHaimg=%c2eJ=GUayyTjTY2GSagFEX(n6>~W262%=00U$5d87UmBJ zhu^yvVK>@QBgF$q)=pF6Q9ohNS@POqlkNJH+GD;M#`(uDyu4{WUQL}rNjggB#&NiTd z4*u|=L#KoU749S(X{n~#6(rEI@w|@VA9dT6p+45u7$W$}?&{1qNLoTghih%kbnlJG zy_k^vUs7dm@s7}}Lb_Xt>A|IBo+DR=wzEL0hSgq-psia>M2UU|h^58%vou-l2#w#9 zX|$lnN^>-k!l&^v=bCd^VpeOj(MKmnZ{}JT`UT9#Zm{SAdkqgE zC4(zxulMdXOjS6hT%!-S-)|1v$k1ltgMhAR3PJm0 zM}YmT&K>vYmr%7{xw818%U|A-P{sct&m?t9$vQ^gxaa|$Dr21n$akF#|BU>!^1NQF zj6Kn0xIRhm7G$fdQ^1h62&zU=|Pe497RHKV+HB_zwn^MpIzh5X>5FC2fKI86a z1zNVW+CO80C|+=ht4bUzhAp9VfU_#kPSnOcH^J0*nIU`}AjR=uZH$nMSvrgwU%<9f zP9%?C-jP1i-wn*<{oy7e=`tK-wTVedit&R~FU)WkN_*n?z)$Bz?l$=&RjL#M!jnyk z5gLNAvgxp=8c+3Ta_^{=bjLn?-7u1e)&?loWK#s6ou;>sZ3MTVas^X^TD|Vk_6F(0 zx+xAAZ8W2bl=DXyBC31zWAcw5^tABF<5!hPtd%)@#dw8bK}w$oQl|^~Yh1`_Vd@Ha zfT~22$krl9f^Ej!9Z9sKjJK!ir|Kz^g@s~|WVry4c0mni%+Q2qt&Q4+4$2}LN+@w2A*?NSLnYR95KKn5<Ne#Jg#h3HaJUrr=a2ZnR2t-%L2YoK|QKc;I6Sp zztX~a7n5xFjV$Pr8kF)k*z(gRkn=O_nt$h&pE4r(tvdt3LS1mVgPxD$M!7$Jol^as z?I01J4I89hPPe8V-ddiT*%TQBzJ;GGY`|8DgAf^nd&s}Krh$${DYWf^r$dolH%|FPR`W26469*z)2IVuCC{(CXFOxant?EcoaRJ7w0T^Cnqyp!B~ zJz`Iz%9=1Jo8nDq4hHv$v+GJ)%!D z1rKW|>cMW@-RjIq6_|5bb}r$75sX2i?2US~!78`6gpM`h^7H9CW3ZmkkiUnUSIeyZ z)C+V?{1Z7CSAr6iMYw_R%ecs*rw?(8G)+~rQOSL?>QRm@G^=x9@huNk?}Ez7n*RLy z*Ln4G8%Fa+s>~Y{9CNEL_)CLbHI*dbzB3Kta`uGlq4HE>C?2JGNqZ&)qbL~g2(&2Ya2 zj6W8u9&^3wT$g3rF(9i7S}1I&uZ83;k=$Rb zaf^RZ7^=beWBmV?D6U0I=d(0cY+rCCYkn`k>eisz=#StPd~@(hjAN7eLRTIbezYEm zXcG(=YGZFiSXaOVGz_J*s7MDbU1{9f*w(Rb`W-&U5ye{5?c1Ztv1e6ANO$8NhqXwS zL#n>mpU~iH_ff_(zP8aHnZFytN_V-o&l-Z*+#^PYcFr183njxTz^XcUH;f+>41^J; zQd@aa(qg1x^l}cNpTiICqUQ7xOp&I!JcUrU+g9&4Z}Z6&QR-+W1Z#jsR6rLbE7IuH z`d*@fL_h~5r6pmx9i4b^6Z+kz>kp7(zFU)ebc=d6%1a-Sk?gx>4@?Fmo__ahm>W8z zafgbtdvr*5D)>$*M3y<2(og=V-J^G|Ua}uobj3>5IaMXWR+#-<+(FfWhfU$qjgQgl`hj>_3su31S(z) z0}bX<8-N%Mm?#!FS`+|AKcfMvo3esOJ$i{`9a&>WoM|``_9E!VfX68~fbm{*n_2Su zF>89`Sr*}5yL~L>_xYL&;ZU@A{z7k`^-x_h_0!Nv{Uj4b88rpMg^hoOoQxiTmN=t_I|!#)q#mZAJkMX`O+)3b<5wL|6|Sg)n! z=o&faoZ!h)7@B6Kom?${9`a^nPs?sv+1Ji*fa!3bNmTEg%FA-BN>bioms%kewyo#G zQh8%Rf9c8S5Q!~e%tHkQ+|`DAF}W!qfLDr%R6*WG^96T%H8~Z+X&x{gC~}dQd+wAK z9t!RzF<`ZkY$s1$BxSDnitv;(E9b;%geWuJPYjim1w{muBfu-Y%`4IT9Zhxbr^TdV zR1M{(^+dy#V^7|wt$~yEQ;H?aMBZYF@%hGlYkyel2wKvDPorp@EL}v$#FyOQZmKlOC zSYDrkFR8FT4V*R?e%~G4^H$01rUoJHodTF3N2@W7tGuC}vyKa}poB|m?1@F4u%|Qp zEf&P_EA{NmtM+uhTols8emKaHC&{DaCc>&g`-nOAZeM!&hZ%J$30wNw@ll+w(5fYLWCk!nSedHjZesn4w+yd;8n?|OF5Sx)7t_~Dhro(Mk z@y-qJ*}T3+bz8A5o8!B15D}w@I!(DJj-xHR4PPyq+VOjzYVkLnYgrqK=wdX=i&*O z*}b>Nns}#k->GoN)oAiFie@GBbsG)$4un(fdYJlk)X3#_!J@@N^%c17QNwYcDY)&I zVgLXYZMs+4CnU?#Bq$pASNr$hoOsGA9-&6UI#wfj#&3& zU*^4qVtA6;)kao%Ju<$}>}@KLr0`%3!yu@d0tbW<1o$O^{ED-U(Hh8+9w@1w7Rt*p z-{kxEYCGN7;(>>nWR1FU*~Em8+lS(Pge1X=o%QtM2%^6CO{ie(Ok0X>C46Zlk2m$w zo~KHik#Gth-TS@1SLY|N;uh8w214M!4(-YyuO=7syE(%|cu4o{k4 z@^EKh`}Q)-^mPkC9UNKKdAaqKnbk>=SJMPE?5m?`^n4 zAn+G8m4~+jQ){>VoGaAf>LgoG$BBwJw;o*TeD!_nKpddIu0yGTbE6C*z6@AZgTM@P z*LTO-I-u({rna+A56&5ALT0S2H!;&}0wq>|XSfrW)h~I`_;Is}h$(|`=Bv6RfM%$h z+@{Axc`YA=3h7aM1p$Hm$E9zv#0=$z#N zdz35Sn#Gy3+X1`tt&QIAF~T|!(GKKPQv1^kn))yL9Yx#mi}Usf4}HcI@qRLRl;5Y) z0!!l?!=$%C)re*=WCY$_hq@Gpq=53=`FO$$xbo;f;0pnYh_-3d)p`$>V{yJiXu3wX z$oYpL9qSDZO%3Mb3ns>*N#5)CFq)Jp>qu{E25Xeuh^hNq^7=E0N$Ei&p?EEdNS6K0 zo{UFh$LahX6mEtF`inc{r$_1~@*gUA)*}@4ZebhJxO}#5nQ{5tx%#yX?Y0ah4j7h_ z9|4B5yAN zwf9Dwoh|cVK3ma(l-c!)%+#X@7|HGC7Kif$L9dd1KBaPN@p68ReO`r;Li*0hB))^w zDZ=bSNlln`=$8xb-^4a-cP{b^!t+gch1be!-`+J=kr^QF&j0>;>9yHI2+`+uiW3HG z`4T%Rf!Gcqu?fb*E~QBISS=F{ao7@<0(|}Ki^8?`&1k7L8OCyd9ak)24Q24FH%bLs zF=c#I`sQyWabaX*Ed*8_ZfFp(mVQedQUTU4yP8dnZnmY4sSLPSCAPb z(iYLo%yG#`By@R!v1<{#UDS#31|}%Io{P40-zo#v2uBrv`oj~aB<1jTsYOnPx`e7+ z@qA7%ot*RAm{kkr(BjCa7XzAweSq6{ho9r&6GAi5RT{UnX+|l6ZRh3GSp4~&S--}EY}<_7t#=Azo+w!zL4ig_r8E@A3}h`ta%VTt|i zNMB8Yyc}Ke8HwEx(K~aT>~W#P$O6;0Vw7S&M2PpcEIbz1OsYXbCEVAwbe`_a=WSF2 z5x=8JKF*iQn>ufaxHWI=(P77~F#O1*LPhp0fsdf|tCg_>ho|Re76ugJXLNICfa<*^ z-H|~CDS`-nzm0YJg!}Fvpuy4F^P8*oHoWmmny=J)y6-xl-(3xf=l*cyr9?^LA|pxu zYx&#)2IM|sfRRn;5BtVHKoh?DFzX8b9(rPcl1MN@gjAolR|5yboF%HoFAvYouWi!3 zU#xGs%*587r+!*;bCPoWj0rXZemG!i2IzSvzDi*e;D4~QVDj>ooi1zi#$9RRIFkOd za|Wh|tXR`gQ&ikCoU~c#8;JYCgpp2sdn1x z?Dh3(NavNaumDE;S?-X~vYG9;?UeV1!QIp$nJl^XKKB$4vr{Pd%Wr9a$XnfH^0R-Z zAPXZ5M_$=PS4h~dgoZujDPBpPuxMPv7MlYpZMnE^+1=m`n$taiV`*9L#i0t$_UIcS z+>5lpO2&Fd(tq!xgo$5e*VWa(d}q1Gz6_xK=GCIwLvGA#?B^74&Pfm04OGl@k8l7# z%C~H8TezvrwICPl?{Jpngd^rOUby&f@&s|qTp&=|~T&YybNbmC5mPvF+ z0MEo?@X4oYBmvr|pb(2KS=d$f)G9+qgIbpGT%xrvy}6$*8U+3&DPpV|LeXg~0P?nI zi~9qVKO5{_BSu~iKbYn1yO8Y7b#(I8B|u|Iq|%C8hRa(fWZDMG;};z^u}v&&=7I2l^IWXNKedcA{`7F9HU=bSAPE5jI+I`J$R7`mEh$DAEDg3c=mBQvay`2%-I9a z*Sm0%@v9!is~)D(w?W(l4d#4{FB{BXxSjxP4MHa~#*YZjJJZ*9;TaB;1lsY_27n>g z5AtbEb;fNNT(!|Eaf~vKWUkVD&5`5tA%Xv{ss9Qz^pu+6t%o{}LlsZCb?MTsms@FH zB|W#p^~TAv_52KOJvN`)k)KiAs>M}?g?XkYCY*i|(rRpvD7W#&@} zPd653~5u9ZzJ` z6dNdvmK$$l+ydy)8L!e?Rg;W|E&V6r_o%X!vDas~?$&MPpGTA`1uOX#;pF@hq(Hg> zuo6SN6rKI*0F^m?*$3Tu_;Iu@w8)vJ|M{x5Pn&~QGRhk&mycghniuQbTE@isJ`9qa) zGhnqj;1AG#lx3S8X6FCL-h0P2)o1JCQB)KZ1f(ND0a0nvQA$LGfG7w^FOd!+AR@0{~{?>XnroX_vhAAHzpJ3D*t zwZ6}K)>_Z{KY&7lSvEd6{huC*A0LD;?31R}F~eTiC??tMN^WzDe7Dj#14x=mDZ8E? zMJhdf^2Cp^&*j&2J3~wTSw!19A)MZlilRZwBDcCc>1-5BhY>22@sa~5N0zJ)`@X1r;=h-ghe1av;AD|!Z8r;y1A z3%6F{nK`*s)bcu}U^{Qpe4>=$p3%Wrf5CG@5MKwAJ8dUq6`O<@+A%Kf1qI_GJe2GH zkQOQOk|aFpGBE!1WivT)wiM021ZFFP9QKFc52g6wdO&Wk;cg!^A!{g`nu=(-aV(4A zBv%BlUD-W_cJbXBV~ay1G})LhDbE(4isp99$?=bB+HisAzDqEbrzy zhQJgNNjcS+$7bi>p^fp7@JEm+)dTem!!s|M1&)iD-a?(nPvXbjG<+DmXFEud-^z^y8Q%hmJsyJ=&R=iqmGp^{WN+wcQX($BPCNe2YNDBO zG(N!ivg-krnT;Jm-*q8hg})A@JLo@*0@P}b0A$-YbS7fa_pMEbl{+Q`j*Kv+t5R&% zD_ngad5%(R-6;|~90932-Zecf6Yu%G6VgDe4B=OXsn!%tppXMvIp`su&NrgdE0Uh? z7Wx6*5aoLeu)~p&s>t&K0E#g7h&2Aqj zbb}N!-XBY=cmZUwS9T65#ZKOi6N!85De+eS6(kUkVz@g}f2N$+`b^0?a+*Y>h0abtf{KKQqA$*e0F4G~7M} ztWM3J_nZ5EmDtc#Ilr>W`}wx00k?M&$n$e}G2$tNHTL!XF~ImBS%*K{-L7wrlwW z(*d&6O;`ObSCu>d*cG~eV+X>IIZu7|XdYjpaqNXEv+3Kn4qJe5)c->~1axbfru)x- zh7b~AF2TZm>i{)ctBD&&D)y5#=NS%UX%6N0t?SD&N3`(z9V^TcQh9H&W%2Lgn12u7 z{N?#gbQBTRLr8Avb-kRrD1TGu4$~3)bLB72fPn6QYjB9@fQPEsyPo%)m*7bzT^>gK z@C5dkz~uGs_MCk_QzH?sFjnK3Q68sz!alJeQh{H@#+eP2Z#W zSzKuzLz}`*LBLKZ9sJEDm~BY)WNuq^SrQsiFwu<`Z1OM5qoj^JN%nR){hjA)Pgl%0 zt5uNTGuVN}N#R<;e9XgRS2ci2rA0IlnhI)UTUdhL=~&QZPQ@E-Ts@rP=ql;~>SnlT z?D=*TeuU_z({I}FS}fMxNIOjKVBOGg($1w)t|@m-ibjo@V~t2 z3|(%qfT!Z2_^NGI0so0;ppB<=@xy=uzqyuhpV{%2{;#sWU(TO+slT&!>wOUuyD~iz zeQ6NB4}<4QCvt0G=pA)tH6X&2gqo6J(Z6&F%cM>Z!r8|u+uARR+8^rY+j!kYm z$Z+6VKiVgC!(^=Ch?J<1Lj>M%q)|14=A(3i1YhfP!ELIOzLwhpYnecDXY>DF**k|;XRsvK|?UaG;(3`5IJXR z^$aT$8Pa^|UbSfCVF{MJ5Bfd5(_J2P-!j#`0d;8eRYLbwk4z8x`nFZMg z7jylqL?KliPlM^zjY=dmD`u%910GfAiZr#|98{7NXB17+;Z;&pQtHRDiw?NcPm(_b zIeZ+L?nl3#s#n}S%_yBljH&8&ivu6psvz8K4fF8Ptx~Yeo6f8M#3cG{=<0L!#S^TO z@Z#JkuKonT{n$z>rG(Zap~~-Q{+_r8p3pSYKOJb&DZU2fh6o_z7 zeA(7G?>`_7yG<4DujS_s6(T}Iq)}FaqSxvwht^mqbyu{k7*om%`=8C4Jy=hlAa@G z7c8hxiKQjnUp+roox$UeC{dq#HR^D62GnF8Wv&zwjQTbHuu&jhr9%-R5R&|NV5Emy zzG?9YFTgYbu3M*oeAS!+8wHviIEO?B%L4+bfcmsv?Wc{+%6qM2)UI!yRTW^xseNUD z>`Md0?H-WkkDzy*;Ny+o_9!erP%6&RvUvm?(e`}C1dfb*#qrdw#bnczJRX=`mi}_g z^s$80D!Qv{&`G?&$?c*^pzQ$`hGr+xYjK$`#Jft*i@h8l|3I!FP?{iY@~d2pSBLLy zwQY4>m%M^Mc2z&YptSC|@b0ayqkbkKxi<{GHsk8$+k}Ni!mp9(C0Zkba)ibv9`4(O zt|b|(4Th8Mk3}9TD{sl~fP&a%khL@@~?7o#cVWK=>P*z_S*r8JX z*gf)2Q4d4&rR=rN$WGqkAr2FtS2V?YpieTHra>`i1urKS1j;!vv(mD zN30fDxN7r%|9Z+bW#-=gFUFjgRZXuGrQQn|-7L)5kXlje(*CphEO%r9K-~rH^Vp zB}x#i5D7aw_2hJQ$gU~Z4%fH^^GAM^lv$>69INsjxwy1#zJrBDj_=IO$zgOgBcrD$ zbW^4>4$rWe?w*)oX?d)$@Jn1yEP&AUIAFjc2-E%n>R-K5J9d)UlIXX} zJxvpwHmK}>=w>Z;Io`J@r2s80?EOe6#0vV{HtgcT#-fiPW9n!Q^mZ77C)s&*A)aFgno z!HNO%r3<4S0Y`hkclX@4_g|gWVb|d~w9$TI_ZbPcU+^paFjf7M3$q)*anfH3Gmvjv zw19+*P99!y9RojsYkJ}HJo(`o!*_n@k>Mq$_NitFF#nFW=A%xOYReL$+yZMdC_j zjLC(4Q&^7eh=-7q{;5`x!+6qUM^im3yM{AcMu4X=v%E{xPNqx7_d9WqHSD-`%SvE`AMawdLIA6t@M2}=I)aBvaPOu)LPz{uu8(7h~KL4alwORl7fLj#lr-b zN4_Xpfy7>j#cV}{0`u-Pt0e8aR6gaUZ2nzi6Xk|k`xt}Er`+BWdmbZj=j(f^(D&FM zAbbC)#tK3AEJ9TGp(EF$^32~)W;k%OMa!=7TN_^X(<huPn1E*cr)*ZF&GsL)~y<*_o8(;hM*0^`@$Cn^?tPGW1I7(^N zlYezXX?1+IRZwnDU|y}-fC!zM4^x*!I3XIK-U05k#Q6Z2ga3dZs>F9)_M#k8!k%Y+ zxJoSywmqruDs4|4-YH*Q&VD=w=?|5ugr!vC>a5FDlz$I5PUJ5?4}?>%qRblL^pL)7UG2G##m9vC$Pk%kIkFZva$TNw&VMFn%D_C~W^-PU32N#zj4RC};0PIK)&o>a&Rg|(Kq92fLxJGoB{r1du z5afG^Fv!-?VaR$j<_AawC`&qqB(&0>yl$nDD}Z9iNCJ{~Z;^_MUTBo%F-w$jY z^&+a!pFV=net-b+kC{h%=oNLq?Vy19&2qpy<{@R|WI7ypC6C(U_xs-!#60E6>xw$A zD(Q${n!_htDXhE{&ST%Ao=5;wWm=oo_l_+50PWDH0pfQVkSq%k zJT^!NQ!n+yh4VX(jwA^N?^HOsAX8Q2W3q-=p}g&}-rMLzJhevneOPqySlYmx=};Ze z>TL;zHu9`Y7M?cfJmq3}Q?g)W(YycSyGZH`<^2pg3RcP*1(%r~AZ+5+LSC0=?H9f- z5@|j$G0ou~d)C}6Gf>p0F^phB;ka^Kgh)jAd81!yzi*5?*r#$lD{VtQCaRyRO02}YM4xO_h&DK^GVE0CH5TiWU{6N&bxbs36DK~-ILA4*CcD=Vil;u`M)}>q~S2WR8ffoaaj8|@o8K$3n$JQ)a z^FA}!cl2~hO4vfn)B~0K8f|;S^=v;tso#HqAX|s?I=qwxT9VG7hJ+cJ9=vv~nN+Np z(^3zU-e<%L$#1`FEX?)Vq@=E==I51Yqy48@x4*D&|CRqUWe~tiXx1m(7=}YXe6zV4 zL-=lU^+ZnPA^y1UY%er;1)d&*{Q(CE+xN$OZUfnx;N1v5f+z)ob%CndD*Dg{99vfF zSJVA6C{P<~>5 zWum4t_QrGvPm9Vn8VNr(HSyE2tvWj5PV$+5DEYLhsMaWc9X=tGr4o6brO+!>zJ7`_iidzyN&93Mf#1vq-$yX?;^$V$b6cF;CP8Vv@1hLtA8o{ee^8!;@(J>>2y`v*;Wn{Wk3IO>M-3!GWg@{ zW}bxFMH7?M>-6#wBGaxoeq@yak3H(YMO7sSzCBZrb2*~_`ROD-KmVq2`$#U6Q}!y) zzpj5~5L>nYyxm`XB=RZY;dx`U5gUt-UfPF*(K2`XW&^&~`0m4nhLByWB`wKYYz?@*Q-snn)<#O6E_2>Wt6EoQyjEWWwDa`F zG$N(iLwQUtR$I9D3sLY;f!RP3Q<_m=`ATiKgo$Q!ww{q$b-dfV>+5a1o`1|8bGp#g z2|J;0dxM3nEnpL7cx5T~%y_N%h(4#WT8SfZ8>S7gF~;V4G2%Q`yeHxkyX&Qw3#-Rt&$&;W6y7ZVdZM#a zNlS?#Zu}_=bhr%2Ois@40(YNg(1Z0aBka+Lv<27{XL3T9FZINs#`aUMl z*zU|Uk49H%A#FT3dY68>D;IZe;GtdCxh_qy55Q_HygtFvtjP(N!(Dw1;im zg&a7)FFi-$qVxd(Z##@SY6ZKYJ)9SX}-s4%|6TvM0_V^K!GgZV- zDG57jPQ zEn*NigY-&n5O<#VZCd|<>9mX>jd;E@6p zqLYe?39&H}1LofXQU!5|S?V6I`T>W>XsX;0E3=L6t6s*ZCB-roh=!-WbUr9QR77Fcv!)ny0N!DNe_=CMwro6feg>$*NBFL#S2 zuDbYY47EIHjHZdtqyN?=mI8Wj9@#_&+;9Ql`^{`?F&lvImu*tbnSQx;w(!xZmJ^%E z0=UNS!6J=6M0O>%oX@Z`_d!p}xJhX-%>ypYFg^hO*6upS@;hHR*NQ((^#^F8$I zn9kpUU(??DKjKDV`v78KwI!Moz*Ih1GP{eo9sPZ=#l3tz0H~9W+okUhlLi~H0Vs?~LG(nt zNpTr61o1aJ`R&8K1biWKdi^-r!krOE8_EhDBH3rm&FKic;DqPc0=zJt6G{T^ysaX~ z;ph9PLAZ71t8b{6t4lG}MNRju#&twuM1JWE!e?+!H&RrVRNtZLzbb8YnGeJcS<#I?mLma|7{`te~}{lk7=Ue-_k_y z{=OFPuO%}7^zHl@Ke~p|571{|lAgtOTp=(K>O6wXN)I{9fcf?K^TA}zS{YgepctV< z(nOJ3sdQQfq?FbpP4fV56z|#P9sfnkA>yIz=~aDATK=DikS<$XrB^+FR7+&DlIs(I z31N$F1Au!GV?p33i;*SQvQ~8d^bwu)T$7!L>)tKh-sjC6b%tUFKfEjNt64_dukML6 zUi*onb>|Tc7xPA!Fizuhw0QhQYTfR5O@~v9Z<=!tEPRc_IX^b^LsrEGQv)>;4H&jKXL^(|=m z?+4^P*}4$8s*+;EvU}t_0)%IHV+j$*t&@s%?(@j~vC_=K{LwtcKP7A0V!P_kv z=Ww@|M!I_dEzu2-X`sMs={^v6hJuL+h;3!W?3Fs~9|32jdvqoGlaGig?OpvGTgW`k zlX@@%2t;|6et?X^>9@tV#r(%>WiAgA0V!qYxWKczdh@j2abdH?|NnS6FJ@hZOO+}LH zRhsZ(A(464(@bTD*>A2;9{!vN3}duixXK}>SD>SAnwg;MLr$MKVE{`&WugvQsTR@4 z;&l!LKa9-GblgEELcNoa8!3uYo?TZ1nU!(%0>-Y#QSukFoR4?f|sbVR78DG`%F z9Aa8ih1?*2?7!i|{!^h&ZlN*P3Ui~Ujy2zkf0L29zm=g}zaTiVxpM5`g6-n<*B<_l zOp?t;E3$J<*5}=_3951-Y0IWLyru6gvPPgj?mrUX#d$&bGmt zJwNb3?KM;^-ZE5{G44XX{{aVQ(6q2vBRawY(NxD)tC{IKQm%OBs(`0|TbrhRP>mGUX&@|#C8PX!iSV2&D ze>eWFdyoJL?jvv z9K1;(`O^n>R59B%ezcxs^u%ud^Pj|QayNi!y-iDog1aiZD56c?yMdb(eo0-?tI9oA;Olz=%OXv-Z?&?t zW_hp3>RXDCVyGCS#&++!oUM#H7@jK?qU)!1_&R_gsmBtRi=56jg*P zB?GoMQA3~a7+E9s$9MvGoG=3TnOK`(kC1mA}iyrU;f1Nq@}I z@k;4LG&Pi62P(lsGAGMjJ^6R(Pg0!*&t89d=4Dzo&$#Xf5<++t7>HNzzq7wKbvpTf z8@msNd^Px|vB%5=J^PcfH)EU=0b?I+CCsZj)j3t_KYtBaw!&15t|h?(Ma#>#3{cku zESv7M8wcL;;`9!2@?hopI>onD9H3*jP??F9e;BA{LbvWd+J6q0o1>U4KIB<4-cz)j z_g8mw_E-1026&x1h%07K46~~<2qCY6fBP_adz

V*mo(&vE3D!B417dIf!d4}fm} z%>|!8v;dUcAOJ8EnE>GW8$^Wsn}_IW!~$P9{lK#&GkpN~0)*Z^=Y9+W0te6&p_cCu z4UVLqRpb^5vPyRzLrl+Er&6(~HTM=oIL|cavO)t~ZUnB6haGUO#zq6&)JlQfeD~&R zZHmwa>Ba)8=4DCrySxVh-OvN3ez2$W>nv@((S!_-Aio!f2Jo$|4V^clpWM<#uK32Y$@m>ujM(P&FO5MX5HIOv$G z+iD3@=oqTsG%Jrk9^0xIllAq3LiBwzw$O6Z&999J-uwGxA9^qWn(ckY$HO>N3b_>6 zt8?WZD{)@A{`AocW&xH_^|A)4CQzbImG9$jvH)7lg6u@t+DNVf89E*(+61pP6iuo;#&D5AS$Qq`LTShSzMMqPIIAa>;xEh?2KV1~@cX>= z{s`nF9y{!M?!l%Q_r!;?3a53fUbKwN-W+d`&HI>@!DNSlfgy3%nQ**6 zHHBBhq0K%2VXPLEk@0U%^91Z_D=V#oUTZ|tXAObONz*aQ$O>6gc@-vxZTVhK^XE`g{Uh>SsS{Q%wFRreg5R2rZ09IxnzRdw#X z*d6%rJs$Zp^w;YuXIfm?9%e}GH7Tgt6VYQdk75Y!vN=db;8Uf!T}H-U0>vUuO_fPp5xYgD z(EHlAw!EV+%mX1WU!=H1heDJ6nOwZntVn_1_@2|Nqt#(tEE$4K=W0g%0PkG8S^9;j z^fRLTmJA_;(U27TnGJR9Tu|zIA&vf%HnRseM-yLfc#^+j_JD|YA?e6 zOTQ`tP!OL?iXFP=A>XBotNuPVq1C6$zsUE*6w2c$+_j*|hiHL~sa!9r7?8(?ExWt9 z@im`t&AhwM;JEE=R;4$bJ42j^WnYfuIGXKouqQGyy^m zRBs@}MNeBZ-8^75xGi-bF+(gZG4R*b12K$Gy zn(+jsm7ytEIQ?8P(KjY2=VpNI+&$grwm>Bo|24k@p{!8lMi#Fv^xKc`kZjXAIC>6K zfl-OcqH`c^$JX?8gC1ts`R2t5mF?>BwND0`ocN4SCaljfI5B!t6>xJatpIeFSR4tc zV_3lk`K@o%BdTh6=n~G+!`qx$=Qh4xAE^kcceoucZ{_Ls7n;`)LBIVM-sPulaRMr* zMKoY{dZz#e_6O*TDrN<_7YLy;)A!;2L@xt`$m|v5uOY?fPjz3K zP|lY5#GN|zrL~jg^2~$#nnBZ=fxL$TBo3P2zGY|_6{YV4lt0AHI-9F{l}{=0a3<3FW1jTq>vTgOR=UdVZPsW^sc)XW*N3~XW5aX9ONxQcg0k6Vnj6lwCg@;$zNjmMaw zYvCZ^p?_BgwzhOOnd{DY&Cgv4H=1SU7yaB5q5*c4s!v3=LdX_$h2G5Gex8Y}O`-Tv zUU~=VwpqWO6>*azURgEz2@J3(e{*_(M3C9weFKBYNL_~{Y|4s5)wTP#XAg>rotYV5 zn#5_IFbxa*&8(1Lo1gw~Z^Kl{^Q5~3=_b|_ljIun`oXtAYI`TE_EY}V_82?GuZ=}(jPI*v*hR_vmjD9aecPyX~q*Z9=GuQWt?9E#VwEAu04q*%ZjG{+8Q2Y2 zaipR4e@^QF9TI;|>-@>?>Lw$X=s=lZ_VJ&|z^OFsD0qv6Aahe0etrL=o-ly8q&JlV zUZLZsSAd{^d==mn%I|FkBdRJc(j}&&O}05ZqC1011A@}*7xKH)q!>qaQ#163ex|)P z^`@uRqilG|Z>QE{!||A>TVf;Z4q7451IcF#NCtP8PhYwE4EajnN!!d+B;vQm1@kk2 zd9lc3pTE6`zyIkYj`PICx?|A>JI=KwO9w}keVUf`+GJu4&ioPXNPAyq$sZ$UNcN~u ztS~4wY}wO#W8BxmtG)S%P3!-|)`GS*0p2`hc6KjXa%@A4X5NHNT~5HfQwI`Qe`|!U zm?LmGBBTW(0HpBT$Rk%0u4>!+CU-v7b#SIgj~dwPJoj$x)YG2YkUbcL)4WD*NproO z9oEp0nr3BbT9+{k4b49Paz3J|{q8XxG56swfVnlIZ2mehK=U{1VgKmo|ARiG-v9q5 zyMIl5|7-j24ix^U@zej{>%kED*XFvraX&yWNC4e)53BpdP#v+{g;=)O--K>wR7uBOAAQ7MLrD{s&*jKYqW~&I5H`1d0_op@4d-e;cOKTrW@_ zwnp!0ONYOb;^b4@d@aN5!7hDxMJs^eP#xbdY5q(Z-u z4Z*j_xo|OF<=uqYcc?dmsO`~4?c}nn|d|f?_ zG*gO+H0u(Ft7&Un2J0EpcLou&4CM~@qw0?B8BM*6@BhJ}*}rMP{wKb7U-K`z@MYRS z@l)V!K&%$LchLF`>yN)3C)(39?s;{9YA~(;RFG?E2&)HlCXiHKLR~r z6WT6v?lHnpkn-)YoLutAlyAzryJFBm4=Qxl(k9kw&h9+$Fa(<7-m99a z8pSh~-_9zb5-T1?M6WUuD7c+4xG)|lDXMEOPoBQ$ZdaCUXUFEllV~%QypQenY@1}X z;IRjga!+cibSz?d68i&m{Vh6`4ly6Q3t@&hb*a-&RF+o{)k01bX}^uvo*YwfbW}aA z@rCi@8`cNUNW;KrRE=zxy_ndlEZ4Excl>I06g{^f#RsiSUxgKE03c?I23@=d(%xZ> zL~yIJ0TBNyUYg&!R_6P4q4K8(X{!;B6-=Zaz3_B`ev+W{rEhR#J9_1g~~_V>tS=5H(3lAUo< zaBt7oI^22@#2b;N|DC)$h_F+qHB`p#NrT-#BtwP#}c z>r>Is%D~SpnlVT!E6#Hajq6I!TG@*~#$Yp0c(XZvNhQJQI7(5Ve9{_eX9|kjiTS6Z z!@q|{Ie)1k6u0?J2PMdF>(A}%oEyn4;4xnq19@!#$Jl^nw5Hx5KT)fyL+KV*#rBKy z+wAWM@Hy)c9*q0M)+Q|KcV<1$jhOdYkcYgOPaj2<3MK=3^CZC|BQX<+P0v5N#x1c= zx;9(3xxfcrY8y-+t9Vui%s&m9CqPKAqVUO4Wth-<1wz(3P(K$T2oCiqln-1qoYHZb zy__+LzmRwS(48YpjKT*tk?g$~4K5s#m znjAMs)LLkH)8k^t0_AVI!xGT&X|)AOEMe@d}_H8d7hN^~7OBVpKp{ZA!0~{IEGY0pKqkz^F7#QJ%ib}TdP4~mm0vlhSDY4C87<(F0 z^QqeMT)innuFSErPjY2XKu;OlF~vYTl?sug0aYRvB99D394bRc_y+TqCQ0hi#Y2iJ z%xiZ={8yi~N+e4}F4uJEu|V&Qd_x^yuV!Jz z9wK9$6Q=$kcGDO4U9q38=hE9y$DT&B7U}tKD|vHV7OWzGr%gjKF5oZ=SkwC9%yQD} z@Dzda&FSun%Fsw1j#Uu>wii%iAz~VG=FgajPsBN)UJ+nOAm!x9!cYHwn0uq#yXk|sT#w<{M3D9w*aCEP!QW}WKms# zd8SuXoKlkN7&U&?d@9Ju}XM_SWmk&F$llq!mzbG1zu&wVPf~VihMk? zt^M6>2z3Ky%a7FG?YCOoVa^aBFewg?k$J~HKi1&fT%?szy`~TQNV8KJ$>RBN4T4jm z{EPBZX}f{p`LD`x1(GwdI?f=1V)WkQ2M4iFfiB@%djhl-#YtD9-X;!(N3WNqEAsld z$Wu5=AG)ibTT~Ci{K-*;yJnxAJT+% zdsltLo=#4>o_4CrSbdrw`XXtpQh2Vx*#+aZJ>{rIPk%r9{rk=A z&W?i(`mXfiiEl?ZE#Dqh+Mg*%E>xO=ta{>+$Vs7IWLW))rC~wl!Bz7$IMhd3*UTfH zef{WJyF)k=#|r^(UwZJ4$BN_7qdQ2l7AtrBy$ezJ?z-@RSLJijqH0@!oegh<%|6Z^83Dx(t9yS*Mbi;Q#5rS zktUA^K`0s-*Jv9@S1!CX)Oe*!DR}b&_E>wa-l~LFRqLN|A_0$zs6Vj02Ww1zQg~-r zb5QQO?fDztH=WoVV!2hkg{0o<;VM~z&kwJbsbI(=VqqesETrM;%y{e21E`XiyNd{r zh{FsKR`>(NrJ&EELs9a_nDiA2-DL4(!s@u4FU|)$#Wz#qvmc+xvBrdgT;T`EWsI%* zA;E&=0{0J!NI{H1KQr{DUSELRSC*qsZYjcRX?-J{%iF{XD#s@M1_jqc3$#Hj`ZNb6 zTm~Pa-k?>SAt<$A*x&~VBaJ~uEjD8P1*>O@qx4>Z%lRYh)2Ux(mEr^doGl+xuPtzN z#@FT&Ezg5`#Lx_2qBSOL}oUvEyGuDh%Wy5$Xi@EHJ*noxR zX49FBh$#{L;bHP`Sm8YRCDw>m=_nYRC&mTCOz#FCn(ePu!>bu-Cz#X{S+eg+&6Jm# z#+mSI>llcHcR1Z?DjXO;difBxQF?6LbVOa!0?tbABP>yH07?i9ka}1-R~rQ=V`*QN z%eA`lv8LhuanO0+7iNmRSP@|88kTZ-EO>nKAQ$SGf_zfP&7o7Jggv*UWv-I_QHk8C z;cNC|7-jQ)p2-YpM)WgeQ@nEso65iSHGB?r>ai3=_KjcvkOzE5C8XY`p+|_)=I7@qFXde9L{(|?Gw)a0JEy3Gn>?=MYu0mI7^4x-FX{b{76XN7*n83H5?A*=s<4qm&6(29p zv9}6&_2_|ApG}9ippi-7^6*B>K-Ta}FYWdCzF!JP3RZZ;`9~gzb+pWY_#MA3=wklF zHk!q4JG*gbEGg&BVs%nUc6r=K>+!1!%?aAYW%>O=&LM3m@mJpxC29p6ykIEI$p{$v ztn|u~KEui-0yX%n+E2g%OM?;g7T;A6=3Ano!p+`A;SUveuUyH^Z9O8A@EPQzu%`lh z0WK*>D#;1%gp~P&hgYOupSB_1wamYh!rK*(GZI+yh>D%L#n_7>j0q|cGv4OGsQbwZ zxl|}2__=D*rdu4W(VnxTu?b}tW;VKZSxr#@rO_P3Nw6n9CNKi0*a!|K%oKmKSX`CA zlwP*9Ys%--Tpj-U+=$5jNbeWg!WmTzT^L(tEUOLuD~bh!<9F#iZYFOOLEM>)k= z{Lo$s@BOtqaIW5m_R5ngDb@lR!#4S8iIV-tUS3(d$5r7Rrh-}JItT;{B`;Vs4CRI& zUur{mpajY9wd={Dgo-Jr3OpuUk^M%tC3o-mO?ls${NgWm_5mWs=O_51ZOhZ1Y)N3s zoErcq==_9oBPbVFduEvK60N2F%~^IA?6|SxzFE6IL0e7Pt+=*w{MDVon4h6y2W}oF1~3i4$#0Q1<WA zo;R`$BK!53md0!_xEQc?pYTTM>h?*$Tb8bI#+fiU672T0BIun@tfuf`r<2e#8|ni z{_NT&M@LU`%(vh=<@t$^Z&>8;;3uW}&x4N)!;j!kU~TiqYgsOLXtlnK^QIb9veZwR zM%<(*NJl`NI}nfQYD9_GeFH;WJ;ssaD+3xcitax^zWwd|9k2=1D-S(oh(eI2CI~Ku zM~p)7*d_#%{?z(&ET5-B%q4lvvPelO{)d~Vg&mSDMOeL_pc2#I-|U)VgPdC-B_q@` z!Te-LpzC235G({xPgUIjS`}Noi@Hbei7|4*D!>aU-f8vNt1z9u%Ygl^< zvJtfkTY;?}k({BJ#*xx9uH{LoVq{g<#*wu$BN>TNM9Zl&D5BhT!oawjsk$Ih=~-Kl zxLGOaD?_633;Ao07lkwuB`*22Tn*elv2Wv03Ppju+$4CRSJi;nh`&KsDkl`Q+qe$1 zxe^wK@pIDK-Zy7uHV+F4*rWL)w7BL+-af`7F%ZLHG-1k#Y5l^niV#*Yd zJY>wR$CkPU{*hZ)tORIH=GyFSD6faMLB``W)A{6;bjqa#6c-kVJ|ZRI#(3!i{*eQ% ze047d;VL<5(Ipnt*iYS+4an1un38W{u*(z&>8)XBUD_vg7<)Nq8 zfjY5A&~Jz4m1mJ>9;-f^g03tR)=2?Inn0(mCC*hpCdDUT)d`?TPGSFDX5+Y3b8dEom#Lw$AsB&w@I( zXspuNs26(Gk5Vu8sq>I+@N>($y>>2ll8Ak4yf#`dl3n`c#UvgH;p_N-yamwV;n)FV=aAA1#i5F4vmsPzk^FRDmXwsCC2*!IpRc9mHQXPrRfIIJ-e$DTT z`uZBKdmpbn;&4Qm+IhMAJIf5y`aeB8$k1F1A&57q^OGm=D5S5FUk{#@Yt+)r`in>$ zzkKA7-8C7Z2yLE-WYJ99ub&0%JAzLKx734(D5SKy7}=;P_%tvnCdT;6UG5B?LG>>r zT&vt$=)q5y3csULDg@n@tIskts7BAI>huc>VdBU~mP6au4H-oBwfDD8< z_v4)z9_7aG4`u$U{8fay%t+-x6R(+!aE z5x#?)2f|}@=IVBE^H1ebVHpQ{0VM*%9`$rw$5{!scKX-=x`@y68BF`qfj9pNiAAB8!YyGUx z2n;fCZUXZxa3}9@W7s{DfM2Pba@zl0BD|CNew^@m`gWzS_drG!X$d%Z6bro#bBroS zQN4>O(G1JV)4|1q zrzo;@kyrUcBb`il+jxx%Xh3=&ix0K)AENBwXNSizdV z*`kb;gF|zYgDg}|x?QW{O76SXgp#N-J-E+F_;I^=`KRxrdv_y2WJDwO6QseKp(U7$ z96u>Ul*$2?lcO+N?b66zr1iwzV;gw!5($Z;*H479MU1nI+Gz{-2df2l)U%K`^ZmV> z{l+GXaQYnnw>wD3EU#2?Rk#X#!}2~rw!Y;N7E5Di?cch^_>5{mOIJS$cO)aS$P9SC zi6#4MKgHTfPhV$pAOBz6y=Pby+qN#;BtfEN$w*WI5fqS&fh-b4a!|>%&;kMtG&DiU z839EJ5+#e|)J=|(lVrNdAkYm64K(eob?(_~?Y-_j=iGhI-ut`H_oE)D>Z)0@W{sL- z%rV~ajx61((iWa=b2fLwVs5RE)IOOZ%FhAb=?!kNQ3a=E2rI4ttqM5U@(!x%@VZ4D zXJA?5U2V=Q^bZTiZ#jIt2?|vIo*ac1XMn=n*aa6OU{4CoH`#T9eNtJy4N7iy)RWxU zy%(Jb>RpMK8@@2XY&E(&Ntoo=W=llRpjb5WVnT9qqKidnDR@`tTbZB#RbV*#U;qCJ zcp3iV)6y?cOu+3CB>^7#Z3H#e>W%B9zHuAEiuMC(2#^cr96O* z)DN&WZsX-~D3&O1v}QGzNC!8k?3m}g(8i2&4Wdz)vL#d@zVD^GF=>*;Vw>)?cqbs54{>ZY*4*|i({<*hRm{L~anfhOD{9W%x|wWv}paa#P9 z2F#;Ul;Vxi!O6jcPJ^vVu(uG08Kon+Y*BA6PqFKuk+HQJYIOB}yycLcMG`J@<1@+9JLRQJYuYliel4s+95Ahv zQ#y?b#6$fICm%LEz3N^xlIz$>MB+$Y_fuhNf#vAvCWbKep;J`@sNqrTF!$WJD~w-QgssdoY3T=;kR zgE2XZE&@T0-7QD!P__A{&Xzav`@Y&SP!@7T486JlYaY!m#^NcCnx7Q7$IhB&g zR*SmsfxrtF`T2RdTbFsl0dkWI&C3JuwoFp|#k}cFC3bp|ErLkm?hMT%C#Fil*It_` zhFxPK(sO4vqBy5B<4q+&961^p9iTYM(_V|D+x4W8w;CJiSg<5~WQFNT%r<^o2qLke zp4H@aRc42N=tffPDkq1P}%3!R4*$` zlI+s7*a9cEfIU73m<_4JzCF3X-yAoP z-d&hQm+c9k(LQRz?6pzt&5)SkAgBq<+Xg>pmIj#bo9?6LT!yl3pWWp6`+U!mb2X*~ zC^cK28X z?QOEnU_(N}M!svKc|(4HixeuR*<`3s28f%2yw-OMWxJj63+Hcuh-{2e1~cwcF*NA5 zb{y>{fe*`6DXHsq+-~&x-TWeQk;1)xh_h-yxT<~Kr*WApgbdd6oH#U4TzxgtCw^D?*WPF zLcjab{*YVIqV5Yp$`@im{Ld0d{cHt41{(*HJ%<(?MVyzxt}Kbf#hxCwi!naJNJSj% zzS55B&&7<RG}VKkO142~Pjr+^D7QZl5QT_-JQ~(q+gO(zh;PhpV9XDbfEN-O%J2F>@+k` z<&pFY)VXl}^%(#MH_w96kK(5Y%iz+ii>&;<|YVQkXJEFXb~tK@qWNJs9LIwSbgwm6ItZ&}W`ip+}T z<;YG}AaA)P0Bs7a(sqi=oPL2Sd13;@RZ;Im6~LHdX7C+)sEKeBji_v68Ikbx7W8? zx@qcMWxHOo$|yO7xt$*Bg;g1ClP8WBUV>$nkHn{OVE&4JT&n#W#iy1GX}t{B-- zS2(QQnIoE0^jlW$&SU^=W^QJjF2+W!1`D`vZAeom+m4LFbiR%~=TY1lPY7PUc%sjk zBHKDR0ZBnRWzIlQi(Ny(sf)x)OjH97XqlLpva~d&6q!p_CK^I9yEI$9mrY(%MqgeD z40!5A++K)_Vhk8iu7U&x$l%&UqB4mxFNlXoi9NcUTW@&-eX#>;@#u3&KU4R`I@-p{ zSzhi=%-M2zCk;RPXIM+-43Em_VLH4++GAnEH$O~abTciynxl-D>Sea$T7=|^^j&M+ zw>8qAu(QkM9p&uM_I7xsb?9yR0g)E((!9kj1Ydn!s7UD}q7PBp8l~|y&-)4!)u@s- zgJC~m;4H#~aQ@%X?)3qw<>7miKOIPI-y}>rxomw zGR6aa9E7xmT*B+&OtD)>*3`|wdzBKWcRW5Yx6pwUIh}^9J*G}jJ6ns+pJpvvuCfhL z;q=~7$7br`kY`NbwB2S@e;f@^z=I%EF#Tp|K{~{f0spYcpM$TPZ(!Ejsr9WC5#axl7Z(7iUeJU-?5_bPYL*-$${4p)gD5J4$N;N`QN zrr3ER4>!v3>bFyAm}-d8qq5m76(8RdB}_kZS0TC6`)mtHAYu~^h0LN3y%=!eyl3oN zaiU&oI_*jxD&@SE^=Czo+-92S{w|}Td)$!d#O;g{{0qchsvv^BZ-gawOy%Yp18Q;Yj!nH?{i2Q!>(x|k1?>IBPAyIL)n z?n;?&?!5k_^X7}d3rXI%7q2VhwJ<8>GUv$(b|@r`(;`;@>9%*KEw;HfLP5U>?$G3Z z!SA9?v|yXk{2PwifsdEmf-mkO|4-FB$N;Om!w4Y4@3$E9aQ)m1H4{o-0yD19ZKSfM zYNAbEG!19Gcpq(^Lnc2>AwZkmEWfPV*-X3m1V>e5ii7B2nKNVYcc0gIv%e+QEMeKA z7w|SJ$h`5gw4FBj!~-42Oc>x6J%8`1DJ25?kG(L-+@9xslGS0jVfqjX7~kqk}RXHHP^`|;|9;4$+V za~79Khm^Kw?3G?dFGP3%+)`!gQTlnlFb?8&rm-0S8sR{_&^R91Jg~AT2wxqwai>}D z3!U(|eBi~PVYTu!GfM%EE$^D*X2vMxVCdQypj`)OYPY%R`MHyLCo}(?j--owYpOQf zt1sy}E@Q8Z#qiPa27jWI%;0_!IPv45JQXQh)^n`JAXdHMGjd+XKWa=bRx z6P&y9UmmrnkPl~4E6L%_YH@RmmB-yXB${8Q%Bu<{UzKX#P4fYY#-Zc81ypaSBnVir ztQ>M(Fs!A+G#*8s=MW@d53s^vO;?smS?H*t;$5u1LWRW(7T+@L-jr-a*9;d5mQ9Kc zcE+V1s7XvBY;np(Fp8w3Eevyai#ZieK?0t&H96wUYJa{oUe2eU_}E1#Ylpu?qtP^! zsm3j%U@P;xEiNa^Sv1BaZ3^3e(#^mCrs^mKC#-je74?_82?e=>%``+%q4-P* z`~D>Q>WZ3CtL>GE7JvX*m?L)T2aw|kC~GeGkI%TI6Z5H_n#}#@)>Z50v8U^VWK%uz zX(S8WagPHg@nI=53{SSSZ5gnRHe~Fmq3>~dAbB558lNkkp3qS$zQ>ka)0G>PpdNxy zguTM7cbC(_7yz?e>eDp5&CSWjBiossth3|GNsYwGnEBuh6~SOy(|gf7El;Pp!B`oB z*6|N8VK?s=D8JEpSJ_egL}70le*O!@3f}}`HpAltCHyN~6xO&i19biYf9uo2rWGjt z#ZQ0p4jGwC?TROsLr%^$eFF#uX8Wgx^n*kO^Y9--*n=)4yJvZLgK(_EExI@5O>~9r z`5#R3WoS9O9yYF?m5~&F?_*--MI$_L4p@(tU!bq*o#oL1M7VWN@z9+nQ#2UKywUKs zW^BOav$^wKQ*;bErZmor5cO?Jf#|YCc(XBfdCN!iaShJVr?huY$XeiEA|zo8u-&Rz|uQKdw*XcvX`j zw#vmXxRO4&^WkO@(=It5f>0jv3uMo8B8!~yCbY8)6B4ymxZ)xM1PBvK0@#A^L-Ak& z@o3dKYrCW<&y132;_mF3nFRx|dWj`nlNi?8sc4A1yJmLW4T(_Zg4--IL)TG+nEVaj zi5Nt6J`TV=)McL2yx6g4wYEu|zvy$K7OlaoreVJy#%du-0b=wI=#nMCNM1?#PT@ZF`D7T6^h_&7{n6fmgFjDL(tI=VGRxQ2g07zFZX9qx^Y z7L0K`OhkjX!#RB%Q}1~iGWPPTMO8%FMc9+1NyW_=A0^o-EXWKWo-+tyaXEIYi zcBLVLFYQ{chT*-%Tk>@eukj`Ysqj=e_ib@e*gcC^zegiiG;MbOnH|K&l_GNCb_@wa zw=*?$Qturu=ZSw=UIO#UV@$i=dA`7T1O6Gfq!R_HE{n_MdVbZl`a*TwL&gAI! zpjH7}TaxbzHdvmctf?lSa@^yk%%Iziio)&p1pTc$-7RPD77p;G*nWwfm+X5tV`o2` z`NA%fM%npXxC05o8v{xB4fkJokr(8>eLf8HzIgV*Nr0w% z83=kDqqw@eGb7n$b_$EZUX@NW;kuMg7jG1@q)W%ET}hb~lW|P(c_fo6EU@HWTKfS< zx1=2x2?sD3JAxQwEBlklVNd``5jakSyT$o)y>9L9tYSp zT=5c;3`jF$T;jQ8sZCqC45u?u&FhNS=F>PzgE43$9vD=hOl;iPhO+K}j)&M}qLbQX zb}*;%7hw-B`Nw#wr^xdQp)zH&jN@<2olgkg*|Gkx_$42SEgEY>GD$%$2}QV}=9N3F z_)_dOcNcKZ!{Mem5tgKIn-+P^XF~AT##@R3PXqcgBmrb!6@Z!`_G}#b3SF4($hbP8 zT)7}iqWHR^ivh*s+7>grHb`-&l{C;-{qMF>OLUn4o3?uBwE~$NcF_uFm+iE0p+d=j zFv88}Rfb?_|G>qy-gli1Hy=002divd-}~{&4}A`xk(aZ#l zkb7)TMhSgxOZ7&9v9)}toC#Wfsp#* zzXNkA;&ODbG>Af!0>txqkv%|c!|SxX#nND?-4_yGg7-#KF;~`&XEEqJt<{U*Z>Ux{ z@p$8U`BfRj`utOX{5DS&Yb=2XCV&Y@_&f)-pM6*#WjJxezP7S9c6Z(`e%|NqN;ar{_s=@KllBky`L^DTeuHUXy8hA{wopB>P{?RTdB}3em&jF)}gjs9G zQcqly#L?^OVdYwM@YZ=Ly6PsKSk*9=c#LVa;}fg*J$0$wM7$U77;D@WaD5bRkEGdP z<>tib9Hu?lEkoU{4kxq#M4FuqIEGtoUIj zlatv%$C+pc%jo-sNpo8pe&xPl24bFfEK7`l-9mcafXnnblf&6VIBCSH?%P}($*W31 zn{VnWG-R;Vu z)#5Nc+C8!GnMV?XbWB9YssX+M-xjU_1;~~&XHq+T@qigIcH5fSVqsqEF60K4pI4$< z%L>ZyBQ5yw22u>ciQ_Q9y>~dXg_5Ynt9}Nb?wsdU-g;rJWNuF%>e@FDxB%p^ewq## zR<6zr$J2>-$6dtB4(H)u&;8DHuCE=T6x6-*g-k4jI>2^LeB95~+&eB^SdN*W>VVIz z>y8rW@XEL_Orw@1Tpw2ecgWj#X{CK+4lUc=r<)aIo4Mk7b*n%yl&H%DIu!(Wp$k5shrkNwXYYtmUj-ejJ{Aomt#on9(qWI(j*dR_*|A1Zev$((fz)R zifdN}A`o-}*tj>pal>B}sOZ`6MIAL|XeZZCyT)C+^gc}*YDsx$FpE)9ipP0l?;pjF zA+CMLO9BD{CIOeu40t%8>8Q@XE^7yCFg2PlaJ>E|+DD;guQ%{U&)W`ao;gEu6oyO2 zl6NC6pQq!+QPWjJVXWN`5vd(A-t_aMn(f1FNx-27HtDlWUtS51a(r-i2_V2qb-t19+lX0{W+{ znH20iunjaeOX?l|LR%U;o~7Kor)DmQR)uDm|10rBdxQJ5x zW1MF`w7p}28HuqfcUx2U&U2Q~WE$3CWO6|~zXF}*NH~fv{5jy;468U}!xf`5qX|Tg zX~&(Ym!}@O&{n3AoB+Z4d zC>FtI>CTDEhRJ{ZsV3?FU&4x?bA#DTxlJZ!_}>Z;8y@xu=V_~g)G(d=xv`N$+**l5 z*EL#OcRC%mQ)C)4qTu1e)jdIcXKTub1q3gbngE|e#j2702m!B1!xP~-9;OxkwtT6-!8^JF`^{_ zsj01!bDd!=N!_}cz)V{7m&l&z2ZEMB-keMud%o(x(f$hKS1!})=vA{FEa5D zN&IW6{8#@Usc|JaVKv4wTlSucpzgz9UxOwIjZOaS%I?J6K&`}^Fa6vM^i|k=S7UYy zvKKcST<~d6<|mpe?)xz0^g{RxpCq%W^n80AN)$vIt0scd{rlyC5DEG=^2%Sn{ZuhD zBu{l^+9_Vu1;Xv|^{O7<4PAw+#8n&Ffqa+l3CR#kF?^x4CUMiBRdh2L0YvvpN20># z$Z{c>n;@3%hZ(rStAI?bE108rXrKq}wFm}mXeM`Y&eru}H}Fj%*(WAjhoAl!<FLTxw4jWC0Fb~HD)@K)Bh~5*`FBcOX1QO!ctWbJFn#WHw*bnl zLqP7TBorb44yUIp5Kk+gY zO#fu1YYjyFeKTp1S&->Bt_-3}r6jYi0Id8Z<0MlT@&H8{dG$)I zLA+W;p+I~VQBtph_ib&`TP_DVrqY-C>qtVmD2DDEjQr~a=KnxA|I<{FKUI(_C&AXn z0PAk+n1G8T+&{!aVcImEm2Eb1imAxszLU74g364sRyp<9l2jQh z`iWBUATBQmg*tjw`8p`+=shk64GB}?8GVW|Y~yk;$S}RNnp;<072Yd5vlJ_t`|;&1 zt5h^kILi^7VoyoH=?Am9rp)9?E4OWeYs4ey*$<)tAI?ia1f4lRdF!!eR4HTI48u&c zd#O^%>k9G(ayw-$%z(P-crK2Ft~E7P?puHmH~TZZDg#`WgJ*tlg1aqHsT~LS74Izr znYh);;p|nyrc>FbPx2(p-$(zj5{O~q73H_Pte=DfB5Ewq*hOkPs`XGfkXp+x{NP`t_6C+OL_$q1_VF(+ooU*(c@|RKt^d zU5i9O%*Gm|VEzn6BPk|%F>xf9epj$>u7k+_x#pc+9gru$uPNS%DQ+>EDYl#jr!B=! z-8|{1uQLrWO6NB)f5w{-sNU=+6jQp%&0EQODt5?UI{j{r0 zGintt{DL9`h}0Laez@<+dC8p&p=Gn@f$j#z| za+A2FM!W)$4j~-q`Mvs`OWbrdW1igHhfa&!@|!L=N$M{fM7%_u{_im%Mw35>LIb$R zn>0Bqwjj<7OAn(HQ^Vi&%hk}DyDo17PD!|rk3+$%rVm`fYZ6P{l=$U={+8(%@?}#w z;V1<$^yBL|57WwthsxVzr5zq`cMe|I#C0`UFApHwdj&8%-5oCXwzTiu;3_>df8z8g zPnLuz`)+-WTdWEjP)v$$$afV;{5eEMS<`xE660bF5^vKKz3ejyb#p3_qtRS17QgDG>~y&Yf*5AX3Xhv6X|8~kYC${ooMEbyG(i9SM0f)=w$xVu=W?}SRcP8f+ zw@j_&qiz6OYF|-N`O!6q`qOcJ^hbt=w);1|{iyXIq)-Up;jgxST?!KiA5m$~OYfpe zhcg8QB)&!4Jc_8h+QaEk4f6Lc!G+w#tbT6B(&hnWgp#T>JtgnGK6Q1!;B3z2pgw*G zco$uyPpKL?Sq{+2woz2$dj`fNNGXixPz)d+;?am3;|L)~_PB^0nBvPB*HkALTZov; zq>x?m%ELD|R_}5IZTYiuL0Wh24UR3leu6Anl0P0%;4gC`*iKG)#X+czhiez2Ig5hR zE5{R}i2~^=4(E-vCreoXI73tD>2w6KiGZn%z)KZ70k?&mG0##CGgE7IWG%vkv-?W0 z;amRLqFK$>BZaB90OfitPYaT13jxA#4tM6`)n{hho3fRE^w;N_*gsW%Ls{E5NA#w9 zPWFryg$%693`NjYU^3!m#J=vPSv)I9uTNAvwi)&U(FOaBcX4oAE_vclw;x|tIAn7yc!Viy&>+Sos17__DizaxJa0u=PpGb zDolkDCIYyIAxiXM1YhIW(g(we*lDrd45?cW$~`YVja4(e)WONURsr+#hhS2VJY1|> zyk2e`&%~5QR@ID@I1JPo8o%8k5wGQySm9&`oeJXJmg;_5|F|Y~uIR}-8_?&P&A7W6 zC+_v*?0G}v$@atZ4+vIx8x>t>c@@;YDP{gaLxjcrhfs^EZ0jZ^*L&-b7td3V?66AZ zXx&il3-vEucMs)z565mXn5l+sK7c(%B2*!YL!h{ys-N-TOL1lBl!m11j;RO9kqOoMU=fSra@wkyVb^KJe%Y~3@ZQTm*JI)s;WXK?u#JKA)x zJcRdy3H$i~kYeCH<> zRiof|sq6E!fw6C<2@^Tjre3~$sg7_nWmk>{$gMac#BS5Lb{k}frpC^!T}xNlPF52z z?nhRoR5+nF@Qj$g!5KdY&DK~wqI!(){6|VkK0ttzAkpqBOTtC&yq_6vh71A#YE=W) zD7qleI67SB{rXU_)ADG>ynyPRSMNRvwsLtBKa%*!7*98B-t2dOY!}fE|2%7MaiXJ~ zEZP&dD&(zM9WJj%Yo66wcst!eW2z+8V0bTU+=cUN$>e@XMPdOp$ravssKOL5g)+j# za)1zOE}kQM8{0GEIT<>pO<^D1L%$xcw(@jGfTo>4&))^LNV8$ZIJ!KdAY;d!X+v+X z@%*x$$x3F&z9aF#LUoSQvEKK^#o1YlIgu)T?d$wK718xio?oPt6?nx3mKJBjDftl< zwU&L?w6#CGneLh`!G_Hgfx0we2j|vj5Z1UoRH=Ly52q446k@V?#PNME#V)4rR#$Oz z{FfDxpz4ajQ6HwonOkjE92o0%aXLo=%h>YRi?|5UT!{y?RPR>FVl@TgbA)FCm70*_ zqPL4R#ed4bE7^GlYO0j$W)FTTbL$ehFz1BR&jE%>dD${gDrn9)8^d&lp}xFkpQvr- z$J(d$L(cnjGsXlBcqDQ=TKQC;3+mnFM_&Y!#YyGg9yO0@oWIT}D)#lk0q>K!nCVb= zIvW+9r+1gdDJ-_SRcfY0?Z(IBT@2QQKGJ`T6P5lE&51QPuAMZqD@mP+>F}xPQub=FMr^t>1SSvKkEf^=+vYfBe!%mIx?uRLWzS8V>A_&iCcUi#|-HX z+Z}TGi>{iAL_D_tc-A0O_C0)X!M`znZ``Zi=C=eAr2&pT49_`*9y$Df%NxX;B=iBg zf2*w6hf0jby0?nf<)q*agf zEMIDWkMD5%gFgUCR$x81H6>*J!O_m0VUA`nD3&~V3v7NSeM|_RQa&C) z0+6UI8eb9qzd4fgH+z5o?`2+{o<4Zh!p(g!s}Pm#>#S)=tEa%h%E~ag$3J5*4FA^w zbqf%TdNTG(zHwW}YSljCu7Sba@*2sf(CVeJy$p)^YH*_IU1c$(f%0}T$aTb&XSx=7 z$O4dNB7kgl+{B-DLEwML85UC6Tl=`gUmzX-=PRwG`f@ZWCu9u=D{BqnKV}zCsVshh z7Nr4Lv~!@;MoHY?9{SVn|J4~4u=})y1skcUH4V*u{gf=gm@J)B1I$1p7U$_P_{n%} zp;C9#5fgmPxkm3VP>Phpe^U*9jBAZyjDo(ze=^UCy*J?6*nXh|ra>Ibd!M{pQ z+|70n1vWWWR_4yfBBBzvu3d9|sd9L8qExA0ukK4thNam`Qv$!ltgGzxi@=cFCK~yN zh{As@a2B`0#kG-kqk|P5N1j!t_kPq4?&Ahw9}ml-{>g8pWWavPpmJ2)b!lSP$mChV z_>bB;kBQrDgUDge^L4U$TyqLPjX@G`cuEe($L${Y3Mj+iX#T#zla( z;^m{u^kdP-olqZLR+L+E^PAga{-fs{N|oVwsc(|sZgoz+wPPBEmCjq26*~FTWncaC z$N1?B>&vYvucQS{Pxe-l&j~ghrZalq_ebR3aCwYHdt??WE6F(iRTSkv_E*;pg`L)X zD#aPi8cYZtl-Z_injRIcsJrorQ-V;2w3($KuirD{jir_q@9k+~0vd~)_I-?vyjUMU zz|qubpzR-rl5$?U!yCB@RQ?Dv!5X)x3tLU?DoyyiyDIcbUw55W2uEHg(|@D)8~kzH zI1%rHWeryl#+)8lW9ApX@9yNjhI~(imGgb#ur;zGB<`B}|x*y_+UO zn%zv&MJb|e1O$Mu@2_2$70~P1a{DM*$h(A{e4lNhoUwPiQ3HQ_QU%-_T(rqDx$hFYWo&wZXpVTrT}@pndE~FG9gnn_^Eqz5 zQBl667p(V|dg{o%I>>zZ`N>Bc76p(d$<9p;W#;`d1J4hd=&*K5!SI9|3$nom1yn$S+VAI1!UO zGg+a8NKI>QwN!qbGRa%$D#7XYKn?EMD;OcXI7S?($R6247y!~3vZ{y94}QDq^PhG} ztp5Hpiq{rnx#h9QAgP!e{p~a36`h`750R7jS=mK~#JX=!lAVh_axU~*Y_y?P%}Siq zeQqNTw$en_0jj*~SY056$GaGpYM#8CjYyoAR&}}jz2ZT85UWMST770rhCH@A7_W~y zr)xKxODM~&f2&DWxj#y3!N|(B-CGC%&ggB-U}ZB-{+5liSf$uFisugRwM7i#6d>N- z{n?F^Q+T@Kt0f-!i%*x0>W9vNq$9jD0*n)19^!;{cP{W^jw3t^EniL8T+$HBdG^&H z=A-J6^LD5G%^b6IHesKI`KqbM11+}6Zg-qjZOuNUU5LPe^T(Dx4e8~ zYCd?o(U!UkL|)^1)b9~or16_l-r$xP9kIxda4EFfI5PHTZ|>K=ef5Wzd32&AbRYcq zAFdoizgk*B=(o&XrFHfOFWfN&_&up&efB*pS&v-y&dJK)dmrEgDHZlZBlW*N@HadE z<*&vx=qD40v)Vo`Qc=lHH=akYiVgLxlX}p+vi@7k>>p%`{aOEdATC>48;oV34@~O4 zP4iUlna3rYljPnen)|1ydyiV};!d==n3Pmh0_kwi4{7t2uZQcs10e6)|D93%+iOkP z@$BiLF+M&n#d9TJy1wP+)wgnYv4gfhmwC1I>t%E;_;`X7_tm1iU62Mu%KKfVfJg72 zv2=MkqqTMHizs(es&F4M(`4_GHjlnwkpau_`6t*ip;M{G6{*_VF3(q02j<)~-`2{i zk@LTU{u6G+b~rnn8)LQ}bUq7w0VYZ;I_8GME|5K`sMIvQvv6&02u0YVdMK{pQaBCZeTGHU(xhgc+BKTFA*VJuh0kH;HO&hDhGt&oYiqY5$ti z^0+P7zc>rGjf(N;vZnOhqk+24?xZDesa%QjRj@e(fWKI=VCre#kXtnKT-L7es=|;+Xw1U>zy6>!&IA>hO%5;^#_9n1n*8oU1MSd~= z^_Kis_1v~ZY{K8*o~Fmw9%)#7+^6v;0$T01L1)ewjcMxA=q^QRr<+~ICzMF%T&nWV zaZJzZey6!*%$w1uf<{D+WKTRDeSw-1j@{PNh_}$Dyu9qAa`N6VQAK!hf78tHH@d9L z3ara(*r_*!>oDcYDAOj)z7A;vY*f_B*5iV~NMn2!DqJC?rOSu81y_C-cYtN>w4%gX zm7@Iog&$%FrI44e4x2*33G6*_0{a?TSBQ)$qYfNot&)%YnLx(D&|9rmLZbn4(1R|s zH<`&f+j^#gxddC6wc8SWb>B-YU*6@Bj6N#&19&UE0a_|+j?K(VB?01iiC`{ugV^!! z&g?Jt!&T>NcG_)Y_KQPebLoxL>16q9?~gWrjRc*9bO4#6ahL;iXfc)*Rg=II@oJHw^U`?L*vpjB+~Jr)QCIxa zgXwxviQzcqF^GasNx*mpGZNdC4#}PS1+uf$WR7N}3azQVCP^pvgvfMyKfnsqvYmGP z3&e&OPVL~nq9l$PitwbP2J2GJ_KQ_se3o`w_O1Yf-J`jWHe9+MY2NL7ujRahJ)`!R zx6JaU!kg4TXGMa~MrFc7=~CHa>7ksqhoGl!Dx4PzP~s6h)C4}Tl`PIJPZjOL(@7?X z95X|?6hxwk&VLwa?#H`pawa88wD1qZSc~u$6P{G(;4WXWny|b)#H{6wYl&O?nn81x z1=JPPac{5fjh9_ist05yArD3~LEyiO-eZD2uVG08xZs0oDqd>kB$UxhRHw4U^94P;7ijfk<>CI)v}3T29i4HLQ;`*0el<6 z?X3Fdtzgnj$JixJ!G=Q3w(es?x#&+e3GsBkr_#VHGn=(K&wGWWEeK9+ot*N`mGDmy zd+AJ1^}I;>;tXSij14(Co5Bu({U?w~fTS%&utJ5#+J*;R*NvQZ3;4NRI+_}ZHf}O| z41$KFHUTM*2pC?tgk_BlVBP|psnrFuo*WS_@A(cG6vEexAZcn(E`u_3^l+<7<%|GC zp@YDPSe@&emR!E<@o{)S81mvkG7B61SqDv1D!!dSR35TI(9KDpC zJpInTkM)IKlBpf+Z@j!wFZ?BEx%Rb!L(vw33X2TB{c}=WLC{ZN&(_Yw%yD%(y=;Xj zPMm8<7PQQK_=Jh*0N^4d!L_4->?4U~F}@i^lCBKS4?8tSm3>0tw`{)E`doUtJA=Hr z9L-4*-w$UZ3?uk5hP0zvYXX>OMd9G_8|605anC>2MN&jo%w4!(9=*(qZ{P5}O42H| zNCr(iI10n8$Cj2HZ+NjkoUEMlNr*|3{Xq4l<{-U2KGy@FrkZEcBM*jd17GW8fD{19 z^!nhi?j*-&71i6+Gu3xX?TM1w)19RAi935{ncPa;TZ;?|Bti*U8WBa`=i2=p1?2Vf z!uXwBr;Y9iM=#%D4=vHY8eWD(4A8DwFB_k!B zYvF8sdX-n=PFYU(A}!vg6la67O3!y9cy@>GRrOYr9Xz)dl7~*%$~q9$#6$0O91Zy2 z1$gOFcW6*NmvHMSs{#~~%FMr3zbg~e4~6dqoVXK8v(~oMPxp~RuPyGSaLjH01IaA z1Eg>}H5>I}C3$Z(MMV2agMbf((vZlvp8hM}v`RjGkcN@D_M$~BjZw%bB|QAK4_99UJbdA6xivTiQ1Qy`Y;dvE3X9MSF}K)&a*81@Ual8tL5)H!9Il1^+S zuLl75_{1pm-YF^{zQJgm{l|d+27LgsFGZVvfy$%c=d{gRO{_=2G?r-U^9y7R#7|GT zK)#Ewfd-96klTGgITh7X$Rzon4*aqFwFpi91>z3?{N)xq$U{a;e|L6ke9Y>- z{snS_CxdkV%F9oF9Q1ognZg;A27oj$#Yok?%6Q7yfDg53;QlG5CrXu(B@KKptKP zZdMK4Z03y6%m+Nx-`1caA@5Y)zpL`R%VeJ-p|q@;OwT5M73-r$DawW)hl!UaJ{f{{cVp$PjG` z9&4Y8OZ#qL`<84s;IS@pa#DIj<2wbRnjRA)0itoL?4i|pCcNkKUt2F7q14yFHI-j#D|Eh;fAg)_a;^y`-Y|3j z+`=;a3$#I8R~`;$hKXU&L*FLy5B74|B1ieC%Xi@o$Q_<`;mE#HEkoXIqvXW-${ad1 zgkRP@MQ6F>&;&ccDQF>BTH273sCoJNXYSmo7Hc+`E0WD~mW!ih#-!lVdH#h^LePr||GAu0ZDf?{Sb1s$;<`S*O({&8({H z0GV7sP{?T}t>KfR{u%V-i0=HGiOgGDKaQ+dhI;j#UbzNzT=8&gbV#k*n;&yss6Q$n z(VdpOj8d;b^c%GuEA(QjYK zdZ}+u7~%eL=Fc5KEI$odI|i75e|M$*#ieEu76tP`DU;-V!b;(4KInPHPA=S8scWcz z#qs`WZIG3dR=FrWEx9lg)VMo?3Hzxjs(os~W$>fEYpH#txp|`vub{#<_)V-d}M-q945S<0H6RsZNq?M~n3xaEe)%k@=w zqyKGnajaotpgy@Zl8bp7ih|F@%(y9pd!P+b4<|6qlNLn=gSEHe&hHT=z*dePvs&$ zQH*LESjyx{PC>AY1YnAbK>lF$`vuD8u8W)ABWSX``PuLbH2U6=M0?pk#F>W@f1k2! znUna1jl#XH6#~c}$T&$khbXbMx)yxkC=S;hPY~btJzjbg{eC>wS}gMTmainiQPY1WGYfz+HKUqz zluwQgfQ-|Y)N8;(N#EKa0A>eNZYLdA-d?ai2aQ{C0J95R=eMa0c}Q=t{U83M!eRLD zQ>h0*1h~cnAdfB+Gm|sM2mkm5+Kz^wQk3j50~22TXJFu1O1VS`cfH_0yOhsK^~g?t zgAW6ntTk}Eo>Dkr3%&uiysf;IbUuPFevd`sBd30Wb_)oh)noSwh060heSrf1zN{pP z$9Cfbc?b>dfRIK4^Wop@Zv2G?3^*+n@ADl#^%Hp@H8pi*SsZr&Ow;At@HKGx?_Nw= z^-gOnHXA;L2LKk97IBft<>d#!SvX`uS;OyBuq@prpqm~lz!UDnH<*Fe&^#x!jyyjX zbkB(!PJef;u1G;nI9~A_yCM&Xep5fS_?QN|vs&v@P#NNn#;Ar)xLyI;?Pu!Al(og1 zb$o0C{M@%6zx_wIsKena%77NMnp`~P{8?`L+wB7m2+gI)L)mSs-<$ev2>)g~K!V`i zLs7~Rn3y9&wO9A@buODzP zvyYYh5BANPyrWFbFRk&HDt9Rr)pbrr-u}a<&b*TEIOoEFqA5qBN+28syYk zO-_Opq_rs^b!vfcG~tKpF{I6JAF>W5%2NHakm{%zQ$4_Y`y$_5G$T}vp^+9ZkM69_5I7+bb1PuUR)dOWI2?7$p9|QBWm!?tpXVBEG zT@dLbig0eZMc}JL%fD~}zt|oFb4!6=a7v@8yoL}d4}tpcL;g#X|H{dKb?3jP;J@ba zze~n{m(l-v1^?$gr7_18v_|`lovn_VpPg9t&bv8uRlh4=)K9W;+;wK_W;@p%{td}A zakZul@%0%gTRsXXd}Qr0;vpsL?EWs%rGwB36M4a=zB}XhKq}&|w5SH#UPc^A+g}>9 z&Vn~`7in!J%l2JNA7j)eK9zv-#s;1cb%9V}Wu}+a7uG%!|2P96CfBk)vsv7Ih26!{ zHtX8-Cz3r62foX%;biu;Id|csy2uU2*n(3P+i%rQ?>`s$pdYiw)w)j*E9lerG<8>o z?{^lEljG)KDrRJedd3wX^Ey}~mVez=+bF@(oHVkIAixcyaQ+Z6jl{l$kl4d`ze!Fs z-pISfM53iYjptjqO5@~=>+X%FdaO=1E^=1hGwC-T$o3zQC||R-mR+l_O8})zZvCo- zRy>v6Ca2ix6PcKfh#S<*ejjp0y4VX;4|E4Os>;Pjeq6t0vlmM_#7>p$%W)i@oqb>vK^{i$2@90IHwnYA3BC3Z^ZRqGWi+bW)`A_ z0=Szk9$phqZlwFW9Ss0hWQWNWkCN*;Ti;KY$x93hu}zTRZKJNtA5a z>3V4EZPO3ECDrmt!+At>r#7r^Hc0a{1UoiIhIX>4CP~*7qi&Izaq93RZuEPh;;DAV zql4VJbVF$K`x!l7^~(RvXhq|Oy~PzI+7#Ow_CM?u9s5xK(c^ldZ7!=6tU|-Uxn0oE z?py3*X}nN_6N<=IU$ok!Wj_z3<;a-lKnE+H6nsTTBKtkVEe^s8GxQcYvYDV{Q( zXhRj_yTW7DsQVsmvsr)*@4*g?vt<#W6Z_r_&nERqgDwBRg8pZjrRa$Czo5d;x4P|m z7~Y2`YtHVM$=I~RZ<#7+_fgcl0#HD3q6MIEU}MZDff{gV1@IgDsIk}I5s;r>wpz9v zdL){t`f^9T(MBq0i2<)?%=VKeuZ#u*@B;<_>&eXYg-{iqh9Y;g$uhA54#3~J0_2p-E`875E<<@89 z=bEI;L^#UONJUU{nLL;jlctO`o|pS-gX|Ggru}2fO^TLU z5R3=F9n%t`F|@tm7fP((>6^|PX^NeH?Cy*-cvh!u9C0poZ@%(YJsxk-%)lCjB7f9q z>p|}q?Z<2iFv&;n^y**m%UwLiUe7#}n$>$czvVTr!Q;*G(K8}m$2Si5`5!r}y2VGQ zVPk32o1*(;1$)Mm@oVkEym?8x!s`UHeCjVAPS7UQT_g!U?@OSf-l*@(d(P8&t|Lov z|K`CLZB>C(Hv*2q(odHp@YzS3qZ}{YlpQUqJK@Th%T3D4?ZZ953r0?{ZdunmB z`FIiQ(O!Nz&5fVQN`%=(E6TlK$7{m;IouoX75iCpI+G&iIe@j zv-w_@xusylXi>KEn`#X=K^@;b9GVLA37}kyh0J>@E4O*av=d8?&iA}0jN-g!HQvD7 zbI6CL{PYyDqv>GTJa_Ca6poYfQg5%osB|!e3#}x$;O=- zB?8x;RKLm9xpqljso47%QIeEFbh(>hW2lz({_c&|F{;Z;WD$RjpZ{*c+y+l?jDi`X-UiW5=!2J-?ZQV(n#J2~L!Wyi36ITIsmnRVAH% zY;dWE_R=ObWs-15^A%UkwyD?t`a(~V5`|P0U!$IGTmWs@3XXoIL6s5g6`&Qf1Ub7* zuT$5uouhjw#29Tk@!qYinRdrjt3jn#Ok<1G0qw8~XWMI#Pk}sfY6;P|@2(p$iI5c< zXPKv;o7*DuCWC#f!_<_+f4u~szEJRM=bKGxw}VYpPP(Y=XPA|A(~6D55uq}JxK|Q< zzJ&_$i6$emlBVXny1z8_)J{b|C9qXo3Anaid(Dvs`{ER1VS;FB3S+rwMJM}RX=7RA z>!F`s*~A+vIh$NQ6*MSv^>mJD&cwnyhZpTk1n7Qt&&Kp{(ED&aQ#WzA`BShJhh@7< zTc^*P(7KE_GOrINJgUev;qBUC_xYl6YB6*1^C7AlEKQotP}06z9#Q?oGWopmrNWfh z!lZ!D?FtuKd?hCZ)z2Ro2&?m`D<>J0?5Z7qx93Uzfw9Dc+TToMiIwf=TIDW*BU>I1 zhw~kN*|$60|5NA0;XB{2r9TKx{eGKB%!n-&JHr00+^1E&Q~uSH{bFT1h$o`h3}H@a z<8q)KwHuTj6TCsEuCh6AwyezDLbY!X`(_S4>R=u}x2Tuu^{^Rs6{Ss~N(-h5W|{fq zosAJ9SIrAQhyi37d9HyOEVCPMzNn!cxW|wWlK2c*l{*_2WXWJW*C%XBltSu^^7Qw{IsR%R6~+#5Cpig$4R?fnO%hv8mXK2uIDx z4C6PC+k_Vj?lybCz-KnRcFz(bQp!ksY!QPIE!AS}*>!#QXh;-}yZ}O@{Gh%4+RMH1 zU^M9MRDm|>-gQ(<5?e^~{h6+c9JNmh=&%zas&TGbWu&xR_=EKlLQWT6;3i#|QX?(S-lgS7XKhD#07Q~uxDV}Lt7DM-E=4(kk{z7B94Y{czePF9|&eIQ*rWJr= za<9G&IkQ>&1HKeXK21FUJgc?T=Tvd%$Fx;{kgX_zI@2#xzm;s3VSGZFvc=|`!1hJI z&S>Ifpy$CnwihYhw*p*m6YzQRNipp|9hGi1_YRM?PAP(cP}ya?O>EF=-`cBEr4XRsS)Cu z=)I_$I7evxOw=ar|HmUl@sYs@E%v?fzPFI>jG4q*jY2CkOB6hknaRi42h_-rors3KxQyZ?j|n)fXswz^u5OPoZ%su*LAaa_Bo2R3mlN`ICVl#=lNZ8;z` z@-4J}R;K-b!KY$+*8hou+>&dq5Zq5}u8$7ID`mkuU2UH#w6@WS`FS~AYpZ?Ir%&QY zAc>NQRWf$@p||y(8+VnHmcIsJsBy-%W9<4J2c;j z%yU{pOD5J^e zX0Yt83<%dR&iw0fQ6YPs7nkAEg&K2$#@lw7S zT_^f5oS7+wMfXJVdBZ)pZ4f~r>Zh=9+Wd`=wd)tuneSLIew=0JqMjjgSB56IPA-aG zYM2!t)l%rQIlN&k{%r(omX&q$v9?_{$F_k}&a9j}Y!>y`tP7p=r&{-6c&rt;$gUC+ zmx*0U>t4OF($RFp--zVOc0^F|o?|%O24-_|S-f+2m)@#scQp z?vjL4?ZHnmP65x=d8{w$^E2ORI(oruo#MNXLpi7JUq_#ww#iI6q>o;M?)*Hm?z`Ic zT%JBB&1&b@kvFcy>?53)k~qTnfmTbzp8x zn5dN{-!vlxi#!NZt5HMjsS&mn@=__fNwzdiB^&mI<-ekyC^Ux3JKJ7zk7vDmFz5Qc zIr*DKU36z)3ce#Yhjl7{p~(6BjX?Go3&c;!FO6_7NtVVE0)0K@R7MuoZ{L|mQuI6J zmTkS|{ncyRFGbx89ZYSm^k`ttY&7zlo4k)7LO+HR&P zGX#pLY~>{PqGq@;G+x&0aAWV3uAJ};z_C>QQx@gXBNrpLK9iXFUZDHtz!#D8gEawx ze6b|T;@Ld%1A-++g0JZm$9P~~I@HV7A0O-Oq6?pTG>iT2-L492Bt||4>1W#cNqy{{ zMU6Kywb{^JK_MFZ&JbF-4{%UItVh|$dpV#6yeMr6p+QOu?Ag{@;PFbO{g?=`DC}|V z!^bLaQ=b^tUJJaX#{a^!r>M+-|GDzZR=U0vnJG)$GvYc+ZRUu^luMn*;-|H|&2z|l z700v#9rc~Md2}Q*IuP?~opgs(?G5ecXhjM~td)T<3IhWai;_uvV zA9YG;&UJfl8ZMAoLhPJ-4pL>nLN&BAYqYN5XsYWy3qCvJH7aY?FHAS%o?L9F zEvFehAzFT)6oq$S34tG0rMn7phyVB*->O`o*?|b>>^_FRZe!oOLc7I1vGPa0wqK%q zE=b(ru7+&M0kF>kYLcS!7cVh`9@7Izkd=sj1PxtiZx3-TGbAzU8?0-%hMvhmeanN~ z^XlpNh8P`6Ra$f@N6$zB(n3wxzkF@<5mB3ZXBU8%o`Gy9F;uoE08vj;aq~3s$3)xg zKBlH?EK${EC0}z z_0kNZj(^O8U$Mt&ww{~@E7z&QIhkiy^f4L9WgRJmVYgz)*ub^b5GQ{mifwUqdQ^_2 zlrd?2vkjs_MEC<1HjJNFHd%SVa@y(u9dY*jwM&d3i|`RB+ChybqbN*HKwNl$xTKDj zTO8bC79d=IMnPkg{;7;<(Yf^c_*btntWMP$h!ky?Q*0i}&)CMx+sLsvGw_vUZ2nh8 z{Fn6oIos>Jb}Bi*6skapg6v|+rd(+4Xygv`CjCTH`XPJ|KWkr8=t?d+Q z3b|Y{V2z=%xmT^wJho9Plcn25M|O_&>bcH#3v<2ha2(1dZVzKBeLdojBvOCWlwd1r z+z-C3XB7-MGWTW)aRW88F<%k6)VR8Y@qf7eN2c5Vna{r%NTeI@K&m+@A?Jxf*ClD8 z=U2?_DjX{9dj%{Qg+W^Fl#V9|P#8(q)7DtcXppqnFuY5MQm;M-?!s43yM}u#_AQ<$ zUvviX3wxaVBuG5whw4+pX>&6S?K(zSTt9YU_8kBxda$S}8$qvt#N((k`U~>cuDuIJruVi!LEV-c7?rlEB~`o2g&94b%rQNj>NSo&Q>QWNAVQ zo1Q)kFo*aB_b~R#ftz+_1pyNeHf0n^Ywv>gGEUc85CYrWb?GVmVO^8)G*%J}@XIvS zehHzFg`V$fU^Z61sQ^7h4Yt+5i|%rNZRD-hme7>!C^Y-23xEsKMZe!akkL_K33hJR z-)D8!xP)VHFG=%f_AFaO)!!$mxNE)c3a}*~$HREiIPYQrHL1416cq#!_==__lkbdeD=YFQrYu@gF>@B_)@Jj( z{lMVCGaOt;;V=hOxeCkuu$fTOHW3N8Jt z-4DGryxLhKXaH?rLO5Qalb>IR9e4G@3Yk}42}H!mCyn~By?1t)7xvNook^)0oq!jm z<{kmoKupcBVBXt*TmAm_Z2ySYj7oo4zGD};8&}y&#ExtqPtID9dU1R(*0?F2gr75x z?6Q_{JJ3G}P52`w{|(ar(O?*W%NslbQHW^THsfm{B5K!y?<9<9xwclO^mQS(eVXd( zc*g2~>s4?3GlL!EMMKO%O%FAKz`A)x;w+_2M*PxL%B5D-jRUy@I&*{`I^mIXPe_M zT`HS@9iqx88pZfV^eunHLuKx9vBIJO>v6Ng3534c@q9R6FaUj+szZAvjonxf!G7v4f#}~+j{$B{h{>DILfu0V$`^x*MX*@>#5fFtJIrnX8bM{;>TNC zf%P%(ZXk>aUSI+zOOIgno3zoFiUMk;TlrP*6fuK-#p~katVzxzHt(r5R<#!`nTftt zl-Dylcj0j_!-VBv1YvuftVXt`rlH5Bi%0s;CPoPjF`f8b>XM<&_5GMJ0adTeu-^AI zFOJW2T6cb&vO1zxq;+lnaV!3!Cw{ZNg7|@mxzq-18Tad}1>qyk$9So?r2(Pm!G=MR zJITZ&S{ppXkTt7Ke^E|u?&*2gmJVj{qJ}nZs$iWt9!!+97uzct~Zc3X9yvxh^y|m4<<=*ws zlv;*VN;~|dMp3X4&FS4@R}|ETXCc3<`j&U$GO{t{^(A4NetVb9)a3Yu=;!rJ`qq2r zGe)Df(>{a)GhcP5y^iF>D_0c1Z=`<05?1YY#rU6pAF6k7Z*R z$;Rtb<;DCZ#11VbxbPb=_pHhU`~vR`s7C-?Ji)#17uV8jZ9r4>-x{N@IJ*|Z=qZ-> zAaH+Rc?D%BjEAU!3HUP9#ID17Z&>=aPE;RuN#LVgrNDODk-5D;aQ9rhJZOI`JCrPfBN z5=|Nem6PbUB}6SKeeny#_kSvb@hMA)XY>U!NN%XA@&WiA2@*DTx~Ty8j}794@qMsL z-pER^jT~?RLG5n9jT@z_18T(&*fc=h026Eg{&uxh@6_QT z)I>Klw5oJR3N_%=3^t=}=s{4p!Hb(ZHVleIMy00dMs`pLc#5<^nuj4Y4H; za^Q-1MKX*ePnSx%#Cskk3HXHhIN3|qUUZLqDMNba;d@XcYf9Pe#_3%(2g|>XNu10H zx_DrGxdPu3VoPWd!$bwQ&&pup~8x0m_?@rOmI zw!}{zu|~}%DF0r8Ow2mh2=peh zcV&;IB+nt~aG}q(!Ad?CRL=OTiopDOPoc6iyYVv- zVpDun)X~1vj`mJhHm2;_9Wr&MYSwR)-9BBpeD-s30!KACY>HA|_kQ;R-ZSg*{@w** zGitJ>m+FosTlNZqNV6{hg1mv$^tVojx-*xh`56qBd)Ygt?DfodCAP~-9DMX@!b=~- zz2_C~ir5pTsh^Ttk2ziI>JW^Sz59+|ey?8I$pe8-Hn=n3jKuLI*ZT5CSMze7!^98H z@3vmcVvBLKGD;GEKYP}VnfSfvMOxqt`&vsu23pjN7mG?Oug-nRGH%r9Wm6le4?>82 z9)S)Iqg)mb;x7+AV=aE$q)Jhy880fa}y zLUXAJQ%jMQ5j&Eq9@?e-^s#At4A%)B_ZOFa~B-6Oexam??S z!~_o$rV8S>OI^A9=E!-79VH%!Q+DkNEeg+5h>K2_lTG{V8gud;Xb?Y@PRadM}jmR#D?*(Y2=Pu`PS9YWIg+ z+ICI`-hKIa`%^0vSAjjjSRCHcD@YcxRM>uh!@loXEg^H$+P7%K6XR`MfvTIo>#m*V zKC8bD5rLT35U=wz?H}#tGzQsD1e{*cFbOUGB;Fs=ZL} z4%`GyX-c6+`@nLb6vb{x#8Le_Wod4Pz?5gEP>O&0zVfI?O~Adz0N-u&Wj1`r8n}62 z1qjB2`%;}jweT_?y3!Jo4PA+w5(lryLCpvmv!+o!lwlC}78_g!Om+g`ckwO&mAM;0 zOspycJs~v}wc&>sy4>h=fIcbqCJ%Dz=m>6hl6jWS83V^pp~0itUIC95JI&qLeiS_a zN4JeARs18RmtPb~5!@T(X9i5NUt>CYhq)>7fjd73*wxIuj`{l=fJz7|;T ztqOa=TwQd)T-m9KbHU(o&!s2?Bz#8<9AWzh8KMpUKn{~!+%RUyXU9z&Tu>*}GeSFb z0&J__shw3nvL-iw{Kk8)=V!Ke=X3|cC*Ne``Q{3#`GwlDE8reXcqK)rH`|QqaRhw= zEis_d+=XxQ!qF2hVrOGhw9LAw^{8gYJ3X=yTB`F8Z-iz&j`aKlXAT!5_8(c3^UQYq zhMcL3o5w-J61K?XDYcRd=ha&t7jsXiJ%kmk6+~Jf0xKDX+0|NG5uC<-u=W^omv%QJ z>@IsIM!?z+`HVn3ThNnsEfdoy7^fXbf{m}SAoBHt|Vdp$Mdsa zdmf@i-f*;RFONh$vgRv$`{Y#lhfQbPYQmNf1`8A3Xw$Z$7YK&i#oPl*lI5nIZywHH zX<}JI*h29GUXn)hOiz>>Kgv_`kEPgrV`^@nZWyEkGBem9z$Iu-I&YX=uU zRfU<3Sig$)r{p>IuTLnsBs`rrJ9yDHR$W#!-#v>vs1LQkJ1a$*3kRAc^IpG47%&++ zw;YfF-q*!z>aU>cWw2?v7XpJpX6UM~8NZ9s#Z{Pedoq3T{tW2L{wGa_&KhM5W5he7 zLd-7idHwYeL$PRwnFP)HpBoOxoI)|NXlBJCB$t;s(3AYz`!-W9LYMHzdhyP zo|uP1cE;B?m5q(D{ITnigl$)O8bTM!kl@YPn z*c)mph~(DXTvYpb@nwAk${gB4_MInuF|d7U%3bSzq5Q3_rSoC_g)*>NC{m>8flc75 z*)BkHhAq3BWO_Ntd^G1T<)=Q+Si9{1h&l|~L3GUa6H|p+iq91d8%{J((3{_C_}Uw* zS}K?g)}}3FxqZ)0a4jCnCS*KST$IaIaD@qT2_R|9wbB13iUkfktmQ$_O2Zo z*Z{iz6yPZg2?fAC+X+8k4opHrJB#j---|D*R-$Is`z;~XFq3sw(dW0ihjAp>r8n$E zQr9C1s&m*tXo!OyUdv8QEdohcQe0MoQVN1CjKlkVrVK^vkpu%ilUM!hV_l~$_XiIE z&snl+UNzJ9_XyZ(P~g^4a6Ha0o12}oAH@PIlZ~M# zk{&3eaT5)4V&@guP#$q!Ozv~;bP0T59K^He?x05`qS%z^AMiJ+Jf+P$BDn8x*zuW! z^Iln-ys3gWy^@mY$IInf1Ivw}Mk&1u5u-z(l!b0C|EgbqG8nnXn~^P3%U*tZ2O6IAfOYRePRjWJPK?W_iyzLbY|o4O3E$(IfEu<2j7Zn__=h% zf(f3ciLYj#2~<5A5B9mq_)nmn%b!U7pK(|7R8H$)yAxxQ!LBzvZTV}|@piom4#Phu zSUjgU6D2kBSxFo~^jz4|CCw1I}jk zv}hbT{QFh!x=mGTt($RAj4KuA#p|cupe*nUdwW2NWfq!p)CxD&;ec5ep19$g+K0S= z?N8YW^KT#vkLQ*}yT}G=X6oFY$)T4OX~dD=`tct;@(iRP*7_JB z1=FC<7jNE;EKy_ptXE~m;d-MoYg>xwWZDN96-;s2^6|*nV-{>+ni!_zgW<9`FY5x8AKhn}Bvz7Wwt$ zO^(aU!rp}6WJ;RB-9Lt2Ua)`tV*RR~H+J3!TQd5uSOz3uDc6vd_#R5TG5v#oEr@^2 z-&;~Yp-uptULjAaYKcME3+V?Z^6o>g>!{YW$rOaklM6`tSwJg4k78;0#kJOI6R@Ei z2FlZ((OavhL_hr8?cJJZmQ!lAP*wqJIx+od^#A5fD`>qX-SCpxxvktSiQl7{zMNW+ zq*}w$427$(Fnt4)W!=7h3B*V$;5gKR1-@Hw7Wz9f@@EpX8bwQGKhcswNE;pKR=z8N zs>FZ!PjMAEAgj=RJMXk@j+hfg-*Sl?iebq3?wn}QZ3fZa;vVICh~7w z?Wa{z_`Z3v0v9aBuSuJ>wKazl`8*xpL>Pj;MFYLMS;P#cNwD5*8_mqV9C()B-gE1C zq8N6yF+{Bx*#@SOUte9-Qtu+2d2aIxt;_sW=xOs!-MPzP(0W}}Jt`J~`6p`o`{n|C z|8>8k=6}Tur^D&tzd23hy`nYbaN@_Nc|zF&-$caIr>A=h!pT8O0u^e8r-Ko$__H^wnnK*@$M*Xa!ThBx>}V%Iodz#wFgRn>+7% z_6!`WHl;4Fg?Ugxm>N2!zis-AVf{NVUYrYU-|2Pepid*J91+cPj!`r5)F#7^ChB8( z6i<^v;z7$)g?$$7ylhQ1;>Pdux760QO@*k-g=p)dqgxkhh)7tY2zcul6h%wYJ4o*{ z*&Q7I8xPAaoT64$hTdojK~IlE!@3IIv{=y`Kf)LLvX_s0PljAxLLl_6;Ly*MJtD(x zY|P>eE_IlW5taDu`){Y8XR0C?WtU?%V7I#-KTD)+e*f{+SJOtbzzGu%#-P^}P9x@0 zx)F_K$cItK~dQ8g*)~jaVcq@4-uVx2I++ z7=Etd4mojF&(T~|Bs=4Yjg<0roN0IB!#6fI9&yGUJ^f?TWg<7tgVUe*))VxSO1-5# zRqkr1KQ0m6hRG%e=&kFx>i2zrGESVQN_7W8RcHPoVz8ywK4VBHRO7kxTyN}cc zI4WG@v2setisxSDU%^(Ixodnf%vjc>HT=2cvCa#79oDF==QJq*XI%#3L|#sFm%e_P zR>{g)YKuXqx_D{M_0KsTe%}hSCNnVkCqH83I-d6*57x!Z@bG^a%11Mpl27#_HAXnRrYx{HOn7C(qcDm z)vf9gedc(n`&)K%zz(yIn|C9S3ENVq;VqKp@GV~CeCI_yU<|_Upy8lQkm`lCNt`^I$;j+pCQRl!*D|fdY(wi_Du`T5cgqX-qpTXnw?UhoJ@<5 zB0u`#RRK%S{G#8(e`NAlvRWpNf-r)5Eti`g41_7{XsrsEJq(w=c)V5)V{)t=u$PYE zZ^6EHrA_pIN>oxkDk68ZZ7&NS>EtfGyKNwkRGoG<)dB3 zfhz^5guM-^TwBpEN2|C!G)Qc3rStboZ@+urbP%l9j(R8ha0|RnND8o@AJZ(6QyW0W zYAVNarp$6B+ltz*ditAq|DHnMi~^Ji6o({=9br9^(R$x$Hc`#&wHKm-)gCC z9PYq3UB&m}sBf>*v)JMJmRk^8rV7hT_wVS85rbVdgJhQk6c&_R1p?h^`h9Grb<;p+Z=n> z7$!%+FGX}eZt!P6uqni}QFRBD;hKgbKVMWTiTi@`RS?+GDsvqp@kMMQUD&bQ>d<$5 z^Eb*&G78~c-x-+=0lxZGd>F{jdwDOq@x}WCaxg8QeU3N&q=#O8-*iW70Ew&AP z9i@;~Gp$g+k07S+E-vs=t2u7&4ta1WPfC_->*mv4Rj%F1Sk8LPQ2yZ~F7FA6g{w!p z1kUuzOt}*oX42kJHxTmK{UYqr8{`fgdDgNa)wrSVfIvr~JlAD*jd2N^lw_*eu9VlB zYY$isCER}n!W6{_J0|DWj<9QIUrE_qm~Qp>0KW|WZvMqskC(}(tm0f$);;JbDuV{e zpZSHH7B&z3>N4mkvm(a&un-WCPPS;ayLJrO;Xq_QS~revFNdb2M$oh4pku5BV`!tC zQZg$~G3&aow=ulO+d&Q;y`0!spr_V0FITfj(GV9?R$IxR48>nxBldRge&q(Gj;x!g z)Z7x*ma}EEi`n2XR%6uqPr9TU@jmpF??-ZWmA!U#DAy$9Z)k0zxY`tSp9p@7$awae zqpL{Pn-u-z;o(~<)}geQp7mrM6DCeAk%>Izgo7I6B0={UCC1dRdW77xVC)zI$M~Es z+H;rC>%W_LM6_II*d-;-=j`}g3Cq*NNo%%FRx*5zWas{Ra0vH~NNhGycDXKt-jg`8 z!LF(TY1!4Fg&XH|%g644r=e^H~W>RQek-cTUXgi1Fj;0*0fh49>RmFPL{53 z`|H%qwO(s(y3N?gF*V@}(Kb z>prQ%*^wvfBn$L*D+dMaXLhu3R*CY+L~sv68}S3GU}aC8$5Spyy6TsV!L75&4?r^f5~93!js=8C8`+Ii$L7KR1e~OFL&K*t7{-@l?VhUF z>%O8pef0+shNB?|-wapuO`xg%h_r0+K1`=Q@)RlCr{K~WcuM9IVbo?%uqrHPkG3k4 z*4rFgzDe-t`GaqVF8H26>Y&BT@cC)}WspQ_p-8H?Uu9AkVgFY6UHnJd<#Ech&PnQR z8FtEiIB;2HczuD>fnH8AlA^6WpZABC^jeXT|7{Nbf37+BQ9M~`xjTa5XJ|5lAjd)| zT-P^RvDl<>fY70&O*`15xJPIICI3zU3C;}J4*XO{aMm-D_<)RoZv*hHSEd1U zbejK**F3XH*~(2e-aE+(Z9T2RzK+k`Z{PYdpO$VtF!>NSrR&KJ5A=R|mqyfVe7RM- zudLPE{Dt?YV@|%WFi8PJJXiSJ7F{;=+uZ;u1b9-JMsrHheplTy+1YC2afkC8RTwcj zogI{iVyF%rET36V8MUECW_lloKG+bUwtA4hr{!&^37^UUdWD~fgB07*q$h3AmvHzI z7C1r#PGT>O@fGlAEr3t1T?Wmt`N0dRzWy{L*jcWss%JBCCUSjyxx43;ilDN+7Y4Wk zAE_28yE2cB#|^vV${>yavT&rD=+jg}*x3fD;b7|Mm}Zy$DT(a3GXGqAcpA=2EY%mWjd15F3iA3JGaQc%*q^0lkMGZ9?uTb2m9jx%g4er3p37*?pda2c#D06q6p@+j}8V?ZMG@6(8Z zIZy!cYB+{m{GP|Rhw)d}Y8@w8Tph{z`bzN?8OoU3*)4paWy0^&v3dwX0;?LKBOH2H z1VS1r6(!&)%pkh!*{-Dg%WH1fKu<|R%i9d``0@?;x4e4ik^m!qA9SJuNN%C;oc@;u z?NV#X7J{e8=qa6OaW%^lQ`pMd59^2tv4Kt`RqwG6!l7x0*VSoaT zLj-jk8jjJ6MSa1%O4Tdwof7-(1s3YYAVr>En2=viiGI3P0^*v@8FkJy$3@3m>NZw} z7L|6>eR-?TWC3TO{tpP2y(+BK8BYV-%LcLWb1RjN7K}?R|B4P1>L23RsA!k|n)=O_ z2F8^!<_XEW3_1@#iBa8Bc_EZRcw$h*G=hX9>-7M6e`T!;(nP@G;ZIYKF5(m@b`)=W ziLq!ugo)Rk{Ss3_HQ{B^W*sx#cPjd2XQQ}OF*OV1Eqe9$Qg|{^R7PMh?t@0}^H^xe z0EGXEY#4Ya{+vasHrY=a8XR2=9)`+oE1(8B?>Wz-L@7dVz_^^7c^|8uW=?bR%(tiS zjVc>?Eq7X!Cg!ZAjSXzFN82|ldpWBk7(3+r=#wy zg|B;?CQHJX;y1X?q*Nxclz84^%F7noRXoIYVv|Bt@Mr%SfAI!TJKlCT{^H)$T9Vy1 zA5X~(hU~K2_9ZNo!Cb2M57QpQoa35fsOhr7D)?bnGdAA@`xWcsy@kkO9VEsZxlqFD zWT#Yfw`CK8%L^fzD>KZcw{H8+Qk{r5>#eeEvN4+3nBXx}GX6k$lcQDZyvgvc?Ubol z=&6n}_bBf0vJ>}z;on~dCRYXvV6z};55osGv6^>??}e}O2#eX-*;hwx!m@+oyp%x` zPrGK+%!Z^!@d{}aucZ!ov3y?Q@mNlF^D}!+P9M&_)m8ldlPJRBMtrV4NGwKE_xx7m zOxDW4OaLqeIGY>)*43ZzFca{A4Ov078tmc}P)rli#S^a2&$PgN-3EpdC7rJ9_1pIl zXj%x5!F#}=y|{9=SbK!4X{Es_h3!pO9dS1e6yrCiZ{tXqNvWqF*>83AFsIG-#c8T zn;7V3KTO4wbyiL7AFTxb&8wW`W?0!{2{F$M(q>iCI5#+wc}ws|D_T z2hTeB;0xrF+`B3}_h;R|k+a-LRRY~StSFHA|7uuFe$0py=5+Sa7LdOMrUR+dq8UfylSn33E~Gr@-dE zzqAX^6$OjsH@hrBnZ^pAw|ni&K|-qpdNHPR8BzP4QD>Ra$^a;RZqfS%;G}pdkY@MS zFVg3sf}U=jvhx5Et^sBhKx&7Ym;)!TbtteVo6-ld}!le&;_lypkxb;om&$IdTz{@he>S#CToyy%a#Rb)DnL4r5GYyNMhOYE8)frBVI6dk>TW?``Ub&a|N;CEf z%{nnV+qdli*G6vD(b-hcN^eEG3SvMGTZXC+Rdj8Z`)N$VNdLFZk92k_yo}9N%G<4Jz@=V$k-3buD;`vlKuR)s6)7k z=!ciEQE3ll@cf?2eLvQAhaDlK`6tt!L%#eBUEHpW%69OBB%BxzR! z)Fj;S-E@&u-t}z5rxv5DA_@iRH0#kq&PzlW-lUMH-i5wV+ zoSpviG`Xg)@R)p2OQDDG7%deY46(J6G*gLpr(jE|vYJ%cGSpq~tz>)K@Xb+0tsVBe z1_G+>NQK>wzhF?Gel z$}~;5Fcx?qtk>xE;3De2r)5mh-Gl_qz~^q@?eOqi6kE+n;N9=W{UZx;I0;CsXwRzbEWW>sZJ(kC-(+uG-}2 zn1fY&cc;R3Fy#A{wVIp5+#U|^vWv2HvAu+qaNC_DLzX1^I58wj-I=k!xoc4o^S11^ z!g!BucfM!cm&wYp=?HM1a`7BUw(wN{%BoauBGpQj-(5&}a3iDZZJ3R{b6dO3L1jB9 z7Q@GL&yqX0Y1~Tvest@^XZumU$lDk*l5VlVZjh(6nfzSuPT%y!j11R|CwDrJNPN!> z`99R8Um{knTptpJsuKw)*;wxJy^zq$PEI-@-I2|c%oBTRyjn~Z_GzWD**`bf7Xw|ca_~~ULdM$@n zbYGk7)AHP%hlGqPj-vJ+_ex-_FWM~W1I|EnA@r`V!I$!BwZ3+?kKxVED6)7^l$mu@ zEM0XmZmMDYp=vO~XtR9VcVtFa4d%O?iPgKwz;UN92SVlVk6Fb9#|$4^_Z$;rarmrd zo>%le-=5B0zFQwTo)1D?5G=kJgs{gE!7jAL;XLrZad2f0*0*TeBKn@>1eA#(Oz)FA zQyt$K0EbF#JW!lwm3Ewb0VK7^<7h5m|JwA%f>96VKvIINATlSA*ha9w#L`rNKfYlJ zG4_vk>f}?C^+JnT;(Vo{5PrcplrF#^M8y=|+h!j8!tQe>GdHDM=|(M?40$l(!7(Y zV;}!y#vtVOOODCaR0wi*5XMj7vz&zQ8{W0Ag6|seq3^wq>qL&fRJ*=)3Gr=^77W+@ zD%bl@m=3~s97~9Vws??JfT9Hw@D>A)r+HW zonx@8>byE=`kA7=obipr%v;q5Eqm7ZC4>`dMkX}^Ijas2!RHx1uIBUVxBW9dPA90# z-tmf%EH{uBRPRS`ywUIypCz-KHfFRv5L<-mNPWMGa-jPR0!4EIfqA9hQioT`{=Ytw z-*1)d1GH>KZT!(*Mdr)JEd%ZvcM}X zkIM36A7|GXo3AEKIckc(0X15H1AVJzmZR77O8f6HGm)$-DAv%tXnO8=*yv;X>QXkG zB=Ah=r7p-2R_SJ@b*0e|;QYitV;1nZ>2JoaBH3>t`fzWqs=#DU@XGIY=YL{r^*D{i)7`G)kw96}p^7XBl-bgtf)XZgKOmo2K=wJJgfcxhgx*U}h!jPxXwJ zE3eO_f)MpW6K9n>`T<1P+AzE2dSgn8&5j(mu{>un-=Os>h;U$Q#3eSbN%Z@jsUQVtL~_j;ftdNO;d zJH%v!=d+%Jf^qwH)BV-eh7vVFI&=O>EVte~H(s+<ZbTQa;E^9~-~(`SMIgxaZpGzb|*X03$xEFT=5-umExT#Qh>d3TYL5m6G0jF4mJ^ zaA4HygP~(&Z)48pj^~INwW|~jVmiEJk@h0+$|g>mS3Mo&%Ew-vKU-9C()T2nz${x~ z8Znm|g=43IfW*vm>?bv$*fIL{rilZz))vls{m9~hqSurr)_QPwVLK07U7LJ41f$xy zlA1;~LG$eFpMKHPe7DFZzI8L5>xs(96X&#pXI`R|#W14dR+7&!H((DAW7Hgl_4P@J z@il4sFF#3-iyRC$+`or4{vy7eAaC*heuiCy5&((305boZ!{n0 zxZ#|0kECT+0e|ekJ5^w;mHhgyt6|~6oz{Ji7NLFuF)*Vv?=nPG))x7pKEPK)QJPUp zB4zo@`p)E;U-&rJxL$j!SK=zeSr^$4Um)v^o&E4p{$l&2s?AqtJv!7c!hN4@eiv_% zZ=qF@AG^Q7%fqjn!u2QBq~tzSbNavZivvt1H6WW|_`qS156Y#X2S6{S00)y{2ncD` z09kQPMY=Yp1wGFgIN8M+)pJ!Wma{Zri&Sp_4FNZDZw4@IwR$>cD>zo9{mA5Bk;e^o z4Ij5778k$~;_ID^*9?iMc^tR}VR_u*TMYJ$*lMne;Jn^tF5+{n(}88HG<2 zA-bk9_w)G0y=keacQ_3b^c_r}!Rtk^)FujgDt-DU=fFIN%rdE|dMXV4lv}c#XF>?2 z`eo}cT?8>bgO2B@KGjrICpT?HMg#koR%U;3hOTkTglOLW*9U{PZWi^@**t1$b5jsU z^%(#PFJHkjyJ-1KbPq8QJFBNO&W%nkP$K9{FrtBE0tf0w5lC6^hfQYE<=ubW)p|^F z1+Ei~86@a0fdF|h^v7)qrvo6lKOUmZaM#&qo-+S?)z@bJUjzOyp8uJ}_tE!no5k)I zi6|j9ey*dlm$dt=3*O`{nLfY0>cYFhkpl_?t+=81sdavIRuBpjgtRh}Vg7y6$-#7$ z^Z2wckPRauTyI+v;*FawcnwT6ZjpbjdF$NaCYJfzW0w-9)a&*f+A7S+TjiWwX~svg zXg^cN?<*BJMd;`dNXmczZf`Fp!sbS}p7;kX!ab)xi_N^yyHt0)zCf*`#nYaHg>P14 z`J(seS-7mNdJ?=N`;+EAsd44=_t}bz56H882Ani3gC0f1lG?f99o}PyJbRZPWvtuEUJAK!!M=(&5X-oD=!!iyu$BC5WH~ zFXCwSz_SA+!)KWxvOmahmh(NT|JXHeG9L`sBwhZR(aoRz3@Pk@rD?({T#rc%2&$fD8&&)tWAd zaia)+r$I|t4S;V98yxmb-NUXfC30S$(CDO8_MyO5e7IbmH--;LBIReCuH#R?y`+n;!w=Ulq^){nPQ z8&Qt#^=9qpSb19x(rydGt{k$Jcs~`;jVZ;NKO_59V0fuot=wK7>y2teWBc7+ z@BdWY<>l-T>kw0HPIA>Rne=&e7y?T9qQ_`tRc1?pX8{Ho+%C_JX2t$VhKlJaJ zJL`(cg^^*69!7wl1#t{1xol4#H34h`cz0%dF^l{17yCdbeW!MG|Gqi-3pxb#4a^9~ zE6kTa9>cT&`5SWX-?u%WsNi=^I~Bv%I`+!f19t4LaoZIXv2b!#_D}da8&ruc9``Cb)|RXIa?6`8GvCufTWP5XE6Ch2sfWOa@Q3;^k}H#+8@G3|iz;s5X>b*UEb;{m{FzJ!Wcc2?wXZ*_2vRKY zuP~%EHI)&M4Fz!DN<)j>+?5j-yHqowX8RtEom;|wW>6%7gCjTRv!HCigGZAqYZD2>}fBpI4Z2v(j!a`^@Sg5jzl-`*sfPm1S|vYM=rQcI2TD=y%LE4%9s*Mlbn!Lw+E-^5PqM>*Q5T09_+G}EGqlJ_R@4^MM-~7P1x(-q z8AS!?3Lq!a62p;-c&c{YGd)kTOQBgb&wF2&z@9ILo?vh`XVJ#gT zx9$(z0X%PgV95btE-5VaBK}2NmbmfEwyP1qAk`w4El66AZVN+YiBL1?>zpU(2ZmQoTY_tOevNYndM!OByn2H&T2i;=cwGfQ}sQ+U6b$1IO_D}=ViXt-E33T zxVKxnQKtLVfJ9i4#p=Uye79mhXG07~${h89@f#dDZrG&Pg(c?+tF^LbnjBSoCY55p z;K)|EvD`zd0m=K+#5uGqcY9k!=QS~@qmyDm8x9bHSE_^`?c1*9nfYdILO84a9EMW` zcClAj7e5y(Q=62L&e{c646BYE7B4c~*eyRM5bJ9pTj1?m!3NP*UqlqUPx0 ztywx$sAOCnp?U0dmDscPagCbXoZP+dSN9>;d-`6|KJbC3pfnyHhA(-H&%w;+9nu!rcj|Zwtx^EFK*+KJoNj zECJ)6aazLg+Vh--6-PJqC2ndmAa!%Ku?w$O!6uCw85F%w#0^DXuVQ<;Drain`YmE9 zrvk!P*Ku!VA6DN27~nz|B;xo1RfIt_{ob*$QMn^cHZ#J4r5c$Q&hn$q>4}!Bcc0$0 zI_Vz70tm3oq$@#)YE?w^0ubB)$OJtTn3oLI2x}&}B-8))&Z}ckhwHZpt8ep+=^vZR z|H;Faz=?QE>7=95VUx{3zW53`mkJ1*P`DR>Ap%5gxHQI!_Jbq;sQ%Inj;P&^1UpAq zTm1Ym=4%@h5sanWs&i-``$E} zoIb|`s@8q^i=6~{oF!$DG>5d=pqaQwBsN0_T|9Lj)VL*d{w4i7y^s)q8qWg$rW2zc zr_|mv6s&O$A%v=*O{Oi{l0hJhHbhrD$jwmc3zTFn^3`N~?XgdD6Lm*-ITS^zJRfF%{e2N=o}yS46d+TQ=HMLS~byj zfu6N&#dLWl77CDl!(^O{cPPLZ6_@i^1*6zH|sS2jR@ z63ALJc_0&3{%yPexn6xtC9CQ&qSVtj3Lg;aG{srOvR7X>hZ)W0_sOFA96;hlhwXqQ z_LP z{{o5917gVVB8K)>Xs~3Ymqr5-Vun^rWNeVH!qAMRxyDe3uQ4|9yVgY5c(%pV`uuxuX4_al>s-xT$Z!jDjfMUnHj{@MdpM2Wu&D?|zRZH*6Z_cZlM33@!9a zllgToD<26wA8OhH%MlN_WRvdXatyu_xlm`F{x>M<*f;&~3%7zm-?s5H^v|dvUMEl3 zm$sjr#4CX-|Ndhf2})Z9N13qYg1Zl5dSQ#J6V|1E^HTqvpHlb8es{Cc8x#^PMMU2> zu)cFAw4W;I{#-GR>n^TrE8r+Kf`vK)O3vhx4V1}K0Z?vQ?B&!qQ4#fel{YKcR8v-t z9+N7iULio&BNa#?(Fw!zXf!4IeEyQ~W32iLGeZftyF?x@2vx4adiNy$3#2?7ztlaR)*3n_XMWz`Q=aRD9Wc8ZA$B~}~iX=uIM6Id@<&N*f>9=UHtoS#zqHIvw zF}QI)*O%Uw!FRS6D@%61f4S=M41JCdw%XaxsEATZ3}%}e?+4U5+I&y*N!UJ7PEMKlD$nH zz75aKWkuhT^Ulu@o8dV4s8p&e^<@xCaH6ew+j3L~9jUUHtXvk2+XO#d_)eym<~lKW ztht4Fsn+V{+d~y7*Y=fhI%<)lC_#lMXG%e~es-a`R-UzPr-M)BF86+!eLE}M4{i5U zJseO@>Mp~22sce+m{iHPy;goRQq$`cV*`lVr^nW|0zJ|MC*174 zTW8L@Cm{XS&fq#rWzzoXGNl$2R9a(aQ*O=x-L+{_po!#4iQ z6-J8Oc_-YqtcA?QNR`sXDY0eggnBOI0aD@OQ1d9katG$}26?-Z=VQUfUbWtn%>tw6 zU5g^f%J>o^>-_!MZbT`AM;oT~<9Nfb1}|1|Z0%Ybj~^cWkih*^@;SSx{0~0M-~2Gk zzL-D$2S+_aecT77@Y8*)JwX;T@RZ!Y3?s?XRyVMvGx#D|FI!`#*2dFU0@v|cETgwr zU6{sj-zG=p)w6m^gL*MK7ehw&`^T*@Tl3&R@lLTTrUFh^wt9cuimj)wAR&BJILOsy zg7CXzn$P8&ZFOTTEL@V`J({#F)m53wtd{qIaC{sqgUESG4UJ4$%P5}G4~ePEXnUI; zRJ5getFBjYUshv}S!M-PNH;YgYa|q84bTvpx1G5s%jE1`MSrU5D)g!jpPgdY4IaE(wdcTAnv_!} zPN(U)Q@HMd^AE#yf_21Z#R#4%)rZM^KJn{S&M4y!zp1r-@!%41-HIA6HJ?lGfp+jQ z6=mS@6@X2?F){OgVVO?mF1O2mzT-5fyyV^tHp5Fd_f{$lA=(Au)P+M}tB0jZt_Y>> z5D^AW^dRW`XTa4E$_38lTHH%x29smKQgmbj*rr(n81!XJ#B@J9Z3#h2SPLUPaK26V z+JnMf3T1dD9P<{Cy-#`5X=uQR8J2{qe*m2yN0oR2A6>~QJZ^DyF<{gPAWe3Nnlw%p zZ8q&}@6nWhLFl?`xT+VXFbR~he-uor1a+-jtpta6@5-5LJVGGg2F@D5P%UdAGeG+z zT0OK>>Buds^=UFs-feyNal69DbR0ub&iC*(W#y78LTWA6h}U(6o0~-RqqA*+66$}g z8Zf*hX_B_;0*LLEo6$cuVFo{OFU?^9*{4<{(5*U&!U5;{T1X<;YHRzUhD6zEb5h~W zDtdS_q1!U#Dc3ru&&rkadbd%Gl@4o`PnV{=H)5;XvU>d86m}`O!&8Zyp<>z$y_JSz z_)|%SBB*adF%%B^qqoou-I%x}3J81RC1H~t0HO{7jy|8nQp^mGH+>0!-hXsgAwy4w zGMWTSIigl0*e)@TIa+4?6ApX-=Ycm+?vWxOt3-6j|YJJ!bgFML!{m0F? zWk;xgftrf>%e{ zm}Teow@%@8W9l#{W8X6k;*8tnMdG9-XWNqoE!)H@oO@I?ofz+ydQ5hTNaE%kS#`dI zI||hwvj`lp*~ygn(q|=~g|)b6b*SdQ5smyX0;l-r!-xwY26otwb{bW!Kp&b*i~rLS z{6CpS`qGhLED?3F^ael6tRM?#`s`0*DZ42PL(#K^kKH7Wh1vC^=ZhY&G%%WWfGt1e zYWKQutCJU84~VrW-Dw!h+~RanM=?3-9n?%p30umCopMh*nd5(Q=9HipYTZq#HEq}^ zrOdZc{3l#Ps;A`!$xiWUA1`_>3S5|shQ6dn=3Y#t7?$6P3F4Hvy<3jrYDl=2uVS@9 zS%6pDRt8n_4ktuO{169es!JjW^akZ1Q%=iUzOX%F;&GXz98-5=)3cGbsBVOY)v=82 zcBS})?m1n%m-BX?_H5o|ASY#j zZi?|3LsbAKlv*ofJ09b+D6Ql^A54$rB&_U!$R=IIF-!;t)Ya{>IB^WA?w8~huS z>di3cSg=PX#nu5h_WJRqwLz1@$WI_Q`^VB@Pqs8dP$f2^KjSMu!$W0IgUL<#OiM4z z|C`BNyGs#`fozrrqgcVwgfvpJfn*#CT~ z<3;GU%+PdAG1eSD?|3?mCO#jkFhgF=X4uLYa0uxdo4QhrO78pSWY?8*JXWFEsSC~g zlxy-M)AC;_qCrpZ7;a|w;*{CMf*5m(pDuCv;B5Qx>BN zcNa|J;naI`sM(%*5H?Rr;F=zb!c3QgztC<6CJ}(bZGUx2otOSyz1=+=OLtDDP3UJR<{fWB>N5(6>f_ff%F^6ltD&<|#${*Sn-@)b=_b#+8AL zbOkp_naGKBg5n;VlR`Cq10U3$U-nQKneky;#dF^ct|rSV*phV2SVv+0NuHCE^L0%X zO;tiOLL8a61ha0-uP;eAKJrAPYpBF=!BgVXqw~$U$s15V5u6s~Ruh)aLsOjJ3G{Ht zsMYi#bRbnAowP*M4p8A=-+3javqxKS(DRw7{hCD|R@6qsb!S_U#RSsx0Z_F;)A_N~ zy$7!Y9@8Zi`N2W+j`;}naB)f8m0dGMYhRu<3&)Z3lIC;~3xxRjkBQ&f*iaIksI~?D z2j3@fA<7M-5=&$J*C&JgCLSXtgCz9amH3G>wBx7}SpKfu0D#FbX+ez!8is&v-?y}U z$rY#{Z3w*#c7k!BLdbwOnh^NX4m6dlztnf?+xyZson#jxTI-&|TTgA%&yO#G6=gST zT3Yu@Oq`hr#fwCG6HyDi=r53NeHy3ri{%wK9~KGG{`U6_^+;F?wz%KlruNNTcrO9; zAkxm-N}j!JCp`y>hvIm#-%LpAAy9SogwpfDO!YKibZ7M`i_m*mGy;0Z_t{AcaoA5#L%e!<-o`kN9RGe`6GuS+zxBNKr$gRr>E8siI3> z`ClMEY`|VHml~YUy#$6D1gN2WA5)BN`DfKQ5L<1GpyGhLYF26*i0pS<55G7TmyKrGoJSalZ^k?cL)!SDGndBv$u}#Pz$2 z4j|QhkVOVXq4g%?4*+h;s0f~~88ULM;o_;!EMuJR8X3njsW&JnrJyx34o;BS;sMOc zs-iU@P^SXhkGf>1mTg&!s?WzP9@j4fI9eKoeqe1)8>WvF1u_@Gq%g#mH7W;Yxcdq~ z+`z3GoP7l$3)4iWCv%f`ZhlHV|B7BNS!s4ObV2nrEmm#oi7 zfIa}RHt+FJ*G2A9QaKFdk-ktLx)A8K?*Jr#ip5smQxl>dd;d_i;t8}f?#kXwE%Drq z;Rnw6o@UpWLHiJPHGd#l5{a*j`on>SA~*J(6sP|FPotNoo9vKe4>twd2n;%GF-%j# z#*VW;3+_}~w7Jn#sP)`s5qC{Jgu|0r?LIyEv62AE7ozu>;xrk*!)$-YeI5J>lbxYa zftun5aYXb-7{Gyj7iEKFXcWVyR!7r|K!gZjx0x=#;llt%jFA!r0FxNv9RI?QE#HGG zM@{Xvg3XqvCH(k3Wu-Pab>26Pr`d%rW79lReu_tYgS)YBW1=7FljmOGrg(pyq=0JV zcYxlO&;SgP2wYic3MMQgKaW|2@_Q&;flhSiyj8>o5LAYx!xHo$tC9E##RjZYuk8Zw zb+vHRP(=dQ+i&*DqzCu2U}OTP>mNA94}f$2cVGKU=mNJ@{2Z81)zh+AzF_|^`P~uBhBX~!M+%pNhf~l>|0oqJ2|18>$sFCGQ%<$ zFL3@k)tyvtn7?ox-*WYC`#XbhJAqJ%+-?aHI^jzA0dtoZ7w&mx?mYnE(vfI@%MoSS zhkBG6EJ$9d)ybYfz)*8(Ps?fVOHtQ`4Z*SxDHzEHRTi$#^KwHTtjqmDBS!gM%g&_0D7HJ$kE`wt;7cNuD)q?c26pth}tk> zrbr46Fsen%%lk^k)e7^*;dh8p$8?s`LfBnjF2TA?^hstt)!NsZ&%9N9O`qA|f9uL^ z{XJYa6*II=;7^^o@w4$>Y)_r%Gp+J-(A%rr09kU<<+=}>Zl=Bu0pg=UEY()9wF;&% zvqCa|X*^FAJ}-O!^zQMRHTKrhe&GZ~QFD<#i_T3fVM)&K2~Lq@Wi1jc=t}eMt*tVH z@4BySN&Rr-m3mF^NYRs>MXUgRKqm9k^DIFg#t!Q=Qhlk^6IfX_QJ>#B(>c`GS9IM5 zd25?>t3UXz*>pMR{s|pwX6>aB#M%NQ6`OSD`Vx=5^wJ2U_5PRW^#Ow`1X>~8VsY(w zZ4-NSkxo1MGyQq}DRihB8->42aI&;oRa>9z`!um+s*bhroxaSmyj8&;;#^}E*-PjG zWO_78x<`^6Jx6Ou;!sw{IM>=u1t;7#=qepH<@b?C@C`L|B?nNo!MS2n_}ho^V|GR^>MO(U zRdvUoY05;N@0BCqHaOd#D=3;8L&B6z-6KAA57yu5f8$n}q^EGgrhqdHk^$89cO zP?B08`ynVCD`0CEw|ag^TjyQORsbNOob08wxA8SfZPJdqQ8BH-MS9P#8bo}md^GI| z&wz)YP(I22(pjIbX(~i7NC(MBM%WOf3Fz8TQYpQJANpxmLl(t_e50aCNDATawk=mP zn14y|Gydd-bpDy<)@s@#+(z!^6PzT!rgPn_F|q$`3OMz8bTPpx8CZ)j`FcjZ|LK$+FPve`+w-&*x&>?$$p=rTas@P ztfEV{FLliy;Hdu0e-osZuSzm?Z|M!KPIAEnBG05MdS9mtQS#mppzB7RPDM%3)ogAO zc0P#HU$w%$B~AyOMiuSVr)VN|2%+oU-;X0m;U?O!=>F?d$U@xi@mh>>KHe36kSy0b z@QBQRrYSr7J-xO?zS6YmUI9*NhfT?Tp$LH+8kdO@-f<9x0(*S@3d-CZuEez;8<}pr z?S1who&KJ~dBx#!^_$IZttP%u7p<91oAigWKJzwebu0#2*lWT!xAm_u`|x%P`+2t8 zy|Y~|!*?Es8j#VYY5II-?bdCI&DZkhV%BydjPZ|74(`oY{4{jSSuOd^)+v7IngV@- zTi`jmsB_+I?d~Rnhvzjcg`XW8oLP`B*aB~s?cx-Om^WG*&(=AgK3aJ<>!t~+@(BjDcs0A>@4fLSCB}@*+Q22^ zH25|OM$w(cJVvGEOwy|oxIX<<(r1sp7OnB?XT!>&7r(vkw@;Q@=RYI}GQkD!2&MCY z6oSX@#W>niQ_fd!++$=PyO&RntMjEWJ>xH$YRePGWQF}DM*&^azE(9maAGIjUc@5) zuw1`;W?nr%>SHz#SQDj|JN(us_8z;#VHB-w?#HN9Ra!a3X7~$)jQ^3xkYYJiImclO z{`UkRLV6fSo3a+p`jh@Q*sxybU# zVA+e!EDPi2Lit;6&{!Q-=el-z8T+rg$@Q2_0nsoczvI45%@vAo)YTDhwd-793idq) z`x}J|UMG^RWs8LtJ%V{N#N+2P9S0K@QNcECyYmP#+@-G9^NaG_c%M9nxpcpe^W=IS z_u+C3WwH#%Ti{6cI~!2~S4#bOEDGW!FuG1yEL2jRWl@Wkrl@}n|6dL&$YKQ?D}C zpqg@DhkD}I&jj|y<6oaQ^xG#3b1rE1Z|zo2NtQK<6yq*{gjiDl1&E3D{r{;I^L9KyrV9Tu8H}t?*QhDljI* zm%D|ov=S5veEM1Zr90~7H4=jT*2^mwUyhu#EOFYp<>Qpzy^}$j9yPa24O^S~;L6_v&Qjz`@mb`W#PB9Nwz@@b0S6umWT9O1xS$X`%z5lJ1#_ zv=->6QPN8CL(&wAlCO`{^W|Nbc*8As3&`@uo)PU2=D^gL-ed1p-ecwX_e^qw*f z<-E12?5A~l*Cst5HQHZzO?F0k$zAJnc#%!kOL-Sztv)deGWubGCE3JfO-;0niPbOo2>eqe(Q7B)d4BS-G;SZQBt+k(>fz)4 zj6!{)2M04|+JK8J)YeJe9EuyYvL~Y#YawSag)QhkHWIk%`r8G3II?ojJO?`4uF>~S z1iJL9S|7{g+yw-)Olp%9p7MBS=Mcl>DE;kC77wpO&3fpzCtHt;zqQ z4nB_kidMVWt#xUVmv0V7trad5Ea2#U%%9JdOyp*#JT2-ZR2zSkos~JGVine-@wTH+ zymdXRu08wxymQoy&AB(!#&0zz6`-xWbRy3^i}0*rhs3q34y|zgN!&CG^*3cvA#k^DanH2f%U~xV(G=VF^k(q@6qCbgEs6M!Wd248HIKR?w@koUR$0W z718VL_CIkF_UVm0!*Tge$W5sb3a<;{ToK<;?hUiYatnwqV1tmBBy2!|O>bv^#?_<6 z6XwwIExsY*O7&a|J8)rC*+x~hFA%v=!C$udB5~Rr{Vb20u~v+trU@K$Dd1YM2jRC2 z86zb))eB^Zj|PFMJcGSl&@ZswFB`IWd>b;VYPynQ=o_8lTna#{g2%;N_I`m>+tC`w z)rH3al^K)BE0c&R#Bm0`?N5bAH?he6VIEL_oP3Ll8Bz`@*7Ve%1BkuCDvDa0G*g@_ zKj~8*Xf&XzKP+bgSpY)4Kkq%^Ws`Hz4@U@u5C`74RjWeS5|XjOVD zE&j4EASzeDM#~90OS7fcI?ndR$+oDN*ht?VKAEy~U4TYW>05%IXKJahG@XCib$?MB zex}4A9fG??d4)T_N95}AnR7HsU+;hQW!&q*5z$^9SkV-~tYWd@*7Nk;N^L&I$9#C? zA1-T-F8(*`#>vQ1W2Tdeqr9V}X-FLN*<<{QTg85A-A@9pe>Kkk%2JqI^#$P*U@pRu z9-Z5mr15^Nsn>=0!XvVc+0y=Evio(y?6JC?gQRN$IZu9@V4^xcxVN1)!r!?D4pTaS1@RN*6Bz8(@OWt|bEc&g6x#n*IIr z;X$jbxjcj7ONt8757lPL^M{Wl>uxZKRUu&qUc zU@CUg1J7{=jwYV-JxYxgm|mK>PKDV()8*rHPKQ~ja%N`scS1xtm2SZsX*BKxd{Bq8 zisPsbt5}iGO%1WMnk5tL=HYr*Gcev(jlu`?S=IlupaILi+CJbFba!uy4c~b5lP) z^_xb^-jy>lAiaK(bSfih1NXBPKyCai;M=P75g&}K46~ui~d_Culh?n z4tFlC=z)fpG)TX-yUp?ia+VhRm7t38<=R1DBLW3+6~ukikioJ#@=r#=T-M(k{~uI* zS+0hu{2BViMf*uq*W&?AAGtfkVI%SHQSZ%iP5iu2eLS)>WIeF(lc0c1*7#p}@rh~B z0&07kpKqwTW4i~Ik3!5_xBbN3vyypq0RBLRvHexy{--(O=i2?IRmIq<=hP(07{B<^ z4ln7qmrHO{V49-D%ybM=!SLKZ0D>|mMxJ}JFk75m96HsnMr;=Ve0zWt!2y*C$1vH3 zDH{k!ehc3NR{ft__V*+E?^(t*v;CyS(fsMnmrHf|V8HL%1q8AID;%>CH!=mgmIb*O zE2+(Ff&XHH{OJilb?P!3qZk<$T%a~RlyoJczvy0kH|#ic&id6)Z7jAhP zlidQCAZ4FYG)lt(xBOhwd2ivuqQ`H7>!*A-?xo9~iL)*QrU@;lylb6ac~0djj+l4n#Q;f4a`d2>AISWM~I>sUTl8 zscmGUq(yUIJ!IYHP}YT!qv#jb#VAYI!j9PoaT~!Z%)l|Dz9pHumCg$g*TDSU!I2PX6 zlE6Q!`@dC0ee>n^60ouZ?RCNDPT|rZEfoMV8I3B0k$#7LC-CsW?xF@ipm4wh%E1wr z@e&RP%9`;cZwj1_1PFhyu{D4)T{$|O{P_`ZWkUTbiW!Uxk$=wfhy_%6@e8ETES8Di z$8+D93J>1=;&r`yRlzWDxXp)kq%pt8b{IFs4ItSWVDp&}koSArIEv)KYqG|atBYhT zm&-x%cFx?X$tQ1y*w4)H*2;|tj*jbh8+%hc)Ly)Kd0vL^)o`NkEt7i9u0gbe>^fyy ztX^~Jd%O4c@gHx{tXlm_5nL+=aUJaTb={MpWrHyj*RuczfhFdF-6gmB$^v ziAu)C>Pda!T5AIm^Ca$^8lwaE?$LHEV3>8_`vP%0I7vBF0DJ{%E9pxp`mQ1uia#OV z55a}h^bmW^KR<3QSZ=c=)N`SJQpoP}UdkfZY$P!>)M!3XLA1Btd#Bw>9NXD)k zz9wj!0Xpw+H=Vr$f54d2GkJ@d)40324)WF)2>%qivT?{|-p6AjrIIk05GPJ{v4I}G z+P&%6+RWtVl4bXkm&U{0t`@y(^5CcNYrG99)oW^#ypM9P-aBjlUM${dG6k(GRkiaX zT6tP;QsVs=$c{c|ki5=&Y47mqw|6C}dqTnx15{k^U)T%$EwcEp<^d5~g6%4bsR}Gb z>9gmN-P(2Z_(_9qiDD5xo87g+_o{s5=5b4jwAAh&S^)5@(1lg)DD3LxLxH*gvn~Cl z&Hg)PA5jj58lb0FL0uHm40vqoQDm{y+QvN04Yn|`-v3+k_g^mpuTbkL5+%b4Bhoor zd2Dw=!TAB~#Inu&&h?0bX&QAJ9F{Cm%ZIKy|K3oJ{ej7u=e+a^j!-9Q#q?Gyrn%D~ z{VD3>M*ZdE6(ANd%KGPX_TQMdVs;0wE`G>q86l(;Xw3XqtY{ z6Dbu$xs8WX)C)MDQe?+bH|P~v0ay@PN6S?UbZ6Sz)W0CkX4HWt5V`Z2wd)q8$F2ie zi=91Iinnjdm2F>A$eQ8{Xo1w`T|_)k6w2>4uoZFE1{D2VomxwX&qejA*A}B6^)Ec4 ze`>Lrq!!#Taf#E8eh=F{{`tC%>Jz}!CbEyT|7p6ll0_g@DmBBOpt}8FIx)@|m2IqV zed|*I3u( zM#QHmg8qKXILm_c#>6YUJYB3ia^V?`>a5O5eqeJTLfNNVE$e5t?X}7F2Za#)_C zO13>jws=4uR`S_fSI*L5c#WF(VQ^sA30l!}2SU*d7S*=~1XBwzi!5kf%F|wFq%|p= zTL#rU^Zsmp&h=Zzw5OA)oX(5AbLcq2@l|OTX-Tpn+%d4-@p9t_gcN;+1cnH(y9Nhz zn~V^BUmzTD3!B1f1@$(vYqLb&`M7mzwc5<$!r+`W*?}%b;-52}iMvwgw>4+g?9^Is zK99#KbEgm*@R!l%iT+R5k~@x)hIs<{bn*}nyB$o>#7Ev5V& zEXafp+{EJXFFXMk9Hg5}Ng=>W7{4fp)#F|S4gd$#t^p)imA7XDkO^>*JWY21A^KGP zdF`o$HQyXreJXa~WXBE`%GUtvS_muiUoy&E2nyzgMqZeMGJ7TfhT)<5k@c_*F_6tObp#N%iVwzPF8<&iRkdp;uKt-X*t!u=(^vjSmv-7@@k)sj*8 z>6ERU$~tE5r$6tcX*y4q2_Qp-oxS$i)8lWUCSQW1xzVic+K=JXbEK8qKJvNVqD%3x z&+=0E%<}BGV7O5F1JtEK+4xndsOFdSJ=#O9fGV#jzH%EAgG<{JW;YryRRP$?9Ps>@ zkwN!w&+ZS>-B{*&Ky;cn43A#m&%!Ovhb|*gbJd%3k_VDQiq?#pZ3ZL)6kGE`FMiow zclW;CRwYxlwYS%-u-p?{eg;PL?Vh?-*O+3rMj|XRT~M?p-^aJGy9)Pg^5D z2jhES)V;+QRE*P^sl*J625$?#6m;?Ig|+U?oF0Z+ z;CQ(e?05iO&-!%Nk6`!53!{QP&f!fgxW*p@AS~KvOm${PA2jv_2n_`La4X5*fAw%d z_t`!Uekkkq84o-7Pw-s zn#ryjL;3noo|mJlg@xVJDK9+4tBIV-x{cX|N)@d(Sq~=_bYuCm_62o+YOK6?D_#2e zDsD`Fuv3cZ)Qx`n1{_rrNXtm)CQNv*dN+bticFRdx``S1=-pRV)7 z&ZrMIi8tAsvrYkBOdoo)olNCiT@rsxP3H!p)o$Wr$<7bLcOjRZf1;T-+xe7YUPGSu z@1?i3@PEv0_MMIATJoulKF}A`>jnpNH=_{F*YGTpm|NIm*`SCWvh(~4@JG&-C9x_BNL;62#6@;CqUa>|oV9Cb31X-b zRW^EN32gXe)o>FH>^95Fi}x7Rm6$8`FJaIBpbC(4sE>oG;1~x5WK2`K(6tuji!rF{ zRwjqIVb=yROA2lO{f?^VU&iJC2E;wX)S?Qt&ZVV1oz7PD^9t$OE7XGRi~F29HOzH# zFS}}i^M?o@Zng>rLd^NTxx#sSItb_WH;=H;Ulb6qM+#fgdvo5jI&O_KvQQL$twjEi z5nhj{@+Bx+Ee7h+@M?Q!7OWq&JecwAuIoR%DO$mF3H9iW#dH1J5*@Sq#p3a*`q&0H z_=$UADMwfgoFNVwhp3Wai8|uBFL$MgMc8{nMszw6GE@ypws1O-v;P9&RFKg<*D7WQ zjg#1SL@2*3RZi3m(BGLn?DBxJ>c^$wCT%~?ynPkFImU4AoO4fSVIpsLQGC4l5WpM4 zHX)@*1J!kH$zdK(b7RuvRhGbZ|b$f!j?Q z^pw&~>btg|;62Ob33<1oK;cgLRoPhniPOWZU`pRe<4%jR_N>~N9Bq^#nyyv&89?S8 zXQ|C^`0o@6Q_S%bG1BrJ7vAn!VeMBOb4fYV)s@9lbyq8z8iK{m4)fFRGS4s>Cpt?W zk)Mn#n~+h3{alF0d`8D&{-$lm);>|VoG&_EW415dH6+{epw2|kVEQk|Ph0ZdzZhV@ ziKqTzd>?iMm`iz+6@BI?sFSup6^+2(+hx~G>nKce z)>SfnL~n2Ly%?HQza&bFflQ`k{6aPB@61d9eD$j-cQvNZHJ)Ao289+JVWz`>7e>f= z0?Sd5rI44Er|6kjo1et4z9P1m1`Y^_GEM@2J)OR?nvYZw*!;d8AgGzhQI8euK*^d zqBA!Y#E}`B{01QZLBYRE&nX{i!n~dBKIGAEe|x~#=y8|M0B>piyt6>PD`fN<1F-0u zZ-)8j(>%ZN{ZpL_q5Ltg-V>EgZ?C(ee@R&o67v&1LOPV9ff_lOiu$U8Xrf(;SVu#b z@Xfyy=w8!Sb zH-U#2yHUf(v)wA=_?KXJ3N&GJMTrv8$)-8jcb{^$nfnU4u)oVt?uGk#;7GYz`Sj+Z z+@2e#+v_?-1>W7ixySDD<)sGpMhEz||71;LYg^t3>S@pm0f4H3^K9exjTEy_Y@0V2 zZNKl{uy#-Zv&N>;wsXni^nu)HU^!i|u@(;zQrtNEiDh%T+9P@gm=koY>6&GG*FtnJ z=aMr&!2?5DI#Y(Q*&!-z_edsC;6wd9(MhdPYC>>9l;~IN7p)&r=^^+oKOX+NQ*5bA zK3rR5&&Y?JIDyT=q&){$WYgDD-qsVi?nN2JVcfY2EhD3#+4vWlD>o!ag>&63l~`q8 z%u8Mnw(v@w>gbR$fjac`q|oF7Z&u5fYO9o5jU^g528HJ;CK}SwwkeXo8t>L)`Za;S zz@Gj+?M2A#tpBVRRuw6Nc@3nLFGpV++VZcVS|yee7vNFO&0u38_uIw-?rj$~kEO*7 zZwH(M5p)To_(R|J70UqB|7c(FdkiXaXXari`sbF|`8PLj>{FwBfn5Dbi}1g&TVeE% zj#jB4@On{E5iX{Ng7_d*f!?z6;^3I#bV^wZWvW)7)S8jcTNp z4mE8M=7-`g0L(VU9TsKr{!WNa*UEr|(6hleo;-2*e4|Cj*7lm`xMwMkfu(5a`-xj+ z+3uVu3%*2i`{Y`B#(<(frrU8BO2bT#1cmd%II{FsEcm$A-<5gVn0=1>6Upqe)*GF? z6RXd!@HZE`NOZ7^yl{;L$FX}>M)oWjUK)IQ%0?3@>&Wu@U7o#6>2U6A>DSUtinXHU zTVvLbM_mry-2V8DS|~wMMbrMdKKor+83&r)+`Z(2T(=bS1z$o#rS!VN8#h(T2R9SF ziAAfhS+VK%1Xy`z-~+!>x8CgB!s`pC+lSi5n!D0&-_w^K61co4=xz8bPPe4U{)|Km zwx?IN#`Fv*L`j&1geXdJUjj0WAJo`~^5gvq;ACZ;*4gU%_1)r`meSGiAjQoR%`#$k zTdY=SuD`lX^weFGbrDfAO-R{i)N^D^u&u||&W!Af_*xx#H%{@<`KXgxQBGJ{`J9^L zr|aeFj@~wuTCW~^t3My%hnhO=u?zLt7=%Kb^`2bPNdoM@lz77bA@9w@p?v%Q;Sted zYmux&M1?Hb!;pl8lwGLElAY|skSy7kP|6ZQStk2#$et`A%5KIk%vc9wdQSKKt>}Ay zKhOR7-QVMQp5uA`=or^AUDuq~dA-l`I?wm}wbm>>Ev;{!0AVyi`$ReJRDTBWW3rlkzXD# zU{@H(*-VL^5;|iX#VC4#e8pH!r!9#DTm1~hTC9Fd|6x4qnM&+PglPQRhyAKd!lf0V z>nl4?ubyYt2vN#C_DLz_eA&>qsNsm-M-bo`MM{)`>-~V-tBlw=L1h8T&`e;f*$#F* zyXM8?gk$5$MTTxEKOhh1j%|80}w@HU(dQANET zwKX%kbsA~4Z^bCGf@-Kf`hAohe$ag9wLAX+r`?G1#n}_nP>1|R2)qz?pM$2s%j?Jj z%+e%c!$*SPl(1Mu2MRChwZ>~p<^Taw>3@oe4&tqdYXvC~gBEEt^-)^qX#*X%Ef1Jpvoz`m{#hSbP(Uldt7IIqu_ zeL>aJiuZbFWXk&{H4+i5VuO0X61D_NgHO9s9_^yM-pb>1V5q+2>Rk%ip)u3s_*}1k z#Q2Hz^Z7XstzB4UKEJ{d7u*q{HT11$y8|P-7vDIqM_rIIwv4R0F3+WWv%K|=j{-dT zgz4vo54VFidLMa*1622x7zs`aE5d%=ai{mq(#6?q!!29*3rifmMn93?GKs$?HWwYE zQDSo1n%Dm6NxC_5(P_@LCF9KEO&v@`AY-8KqcFw|4fyawR=Zcha=I@LT_B|{r%1N9 zTH+6kn@9hE9JFL*Cl2<8#%s?`p)nB>GJ24Yt0PoygcCv8j9+H%%ZOQeZ4)iF?C9RR zxsSgut7556s1hDjMRB=!>?`)HO(eMbUqTNXf(;9MQ10foBDY`vAFrqWnLxC;@qfJH zxCB@)UdcP&U=j0rccX-%Ib!|32nJFpW{V{4O^A^HX3_f}zoICN_W_wT_AQ*PVz74p zVzWX}U~-M<=3>5Mf=3Of&=~x66#qLZxV`!|`yx}Mb;Rv9t6)#et}pS3$v=IGA!`*p z07$tvG4*$MgTLi+-tEt4rmshbBDlY=dBIJ+%XGjxjz;ySR0Z8eAOSpi!+T)83tW6SyQ+aER$sb;te@37k-m>(6rnZ%|_kjPe|0Z?ue^U;6O5j))Bd8V2 z1QLzFn`4r2VqBGYIv)Yz&JpVC;JYyh(iCe6|7j5H6Z>BSCF(tyWy=As*q!*ixuffC zFHu4Ry;Ep)fjm}3Wo?sZ7w`@_dGn`=($CMovIXegUzkh{-n_i-kNXmq91PU>R}8*z zcMDS~oUQuuph%Aqtz^;?OeaWww?Hxy&vF^V=Hq3e06#IFOh=HyA{(Jc1_|^v)x?i| zW>$Q+)nC|ALaS;69*A)SQmZM!BDtkq8>F4|caq?SZ{0%V{gJCM#`GkuY*?FBcGgD$ zO2x(W4qY0DTCEm}Zzpv=GWPo-YoOj8H5V9j2Z!;2w~C`zG%7CYB}`p8tYx>DVJ}!G zVNKchAfncbz5#ml(b6f=w^^kjrrjx3rM(HuH*qQxyirYAO!u3rEMGJJLM#U&f^DxQx_dERQlX@+Y;=%%=qGdU5Y5L+D=cN=&B<0H=kSVS znmuq0tj_J;7^D1(pket7qe2fwC(sXGb{M`4n&z;BQ;zx2e4d*WF<7FPLwhlHD)+FL zcP)~Q4$!NMw%)x^9&#UXYHADv3%tF;A0~CKMrJ3yGr-Q7r&5}aARco^(dv9qM~Fki zgFppdt#ERsJFe|E_&~)XY;qIw3G3<4H=->2Tn5sWnJ$a9#7$e=S8tE-fNZ@B#-pJJ zzY%12UK7t)jIdQ#5?f0AZnzdcc4rHB5&gWdLTQ~_>}^tXEZWgSM|ylIAOjZau~e6} zR(%CVHmKM5&db-9`BMI&#~kPCpA1bzrkfvq6_o(l1`zWeoyL9OdvudOo=lqOD9ba_H4p5}p=AP@bK zZ>s60zO%Ilii=@>k_!EGWcWbOcXRZt=L$ccxE-*Wmwx+L@X(EXMaINnUs(;`%FtuC#F8o3nKVD&9&QQ-QnE;z*g4oqqA*^v zjd&cv;>P;b;@Ehecb&`GisH~D!`?|_JNR9XNQrd3 &LSiGgV=Zb;6RaX1)Ql4pd zHKz!TeD&TVTTBhw>!se%-UT}!(H{Lz(E$uh42P^EZrmhST?&T_-Ww~$`--3wkJ%`E zec?XL1kohvcRCI+stlLx5BdgI!_T7mB;b1END>3Yy-zjPw8KW$HD!TW&de9$n0@5t z1k3IwCoL{6PVIzON!UI5&X@5;`>+9QdGTU~t(PB!qDyDyyF)Xj73K9~v~^G z>4U(-l#0kc3q4K{tsurT#}zbk7!2#THPLC}Q?mO#%{Bq%dxb;CDHt&DqrwBOp+qT#H;8vW%?_uHAZ2dc z=Cy0Q5?eYZ+tP;euZf+EqHj$s`L^W_PD~wrKqb7)*N~n{Ei@69#@%(3LJPZs1|_Z_ zg8W-InkIdTE-8&SO@&nlynR1NmJ)? zJ!xZo?d1uJvN@SY&7lRo9QF7xFz+%UOQOeTk9VGcUoYw^h8`T9G2%@6s2Du0Q5+yq z^YKaa>w4jrM{R4F%|+sK;4=95F&KlDkvLwFFX@dH`Y_qXmI6+2@s&A0S&GlmC^)#3mupbjc0z-3*Urv~@9;TLv25 z*E=M?fq0_)YIx`RM+AWsyfcopeqp@YTbB{tZ`$T$<9d9#N#Z1>03`kZWx zS-a0POw?&XmUVheJ-DKz7(L$^m?Yk%9tc;kvgPMccwx&{R?1fjcRY8xPFiBALymMH zd%`a8(!3%jcCxeSP-7-=c3I+92^iM=%5?Dl(xl9^6WiTyRFBC$ChnI8NO73pFikZ= zIQj_g@`~;vJ~cxg)%;+z5_`z^QIFe`v1bP_412`93cXAE{Q}QKr{93qbfC8|K?&Op zkgCr-oJ(GnEV;8sePQIA2_06Ft#j)2?G(XCa>bAF%q&?GFF({z4kU7@N==F!Z8Z-# z*MXT7seHRERg6J0!_z5mQhZHjs9`sP3t-Rohve9BR3{w17plNPck8Tu`h-)Hl|Ksa zj~lJ;J3gN<>hm2bylpDBb-QZ~GGz8;@8vg+a@Ma~p4S>b8mR=zTnb4J9ld@NehWmXMO%0Uz4%@cw)<%c$8-s2GdloL?*IM(2f{V{b) z5-FSM#}JCX1j}BvWZQwVr*=0Un$b=)aB@r?GKeM_I}y58JFAo3R&$0Rj-?2Ra$aIS ze2VV^1(#Mu+%QKF{JTM$GqZ_=$QTSF!#cZ#47HQtwWPf>o5c8})h=~VRjNSx!&vj# zPe~?%b!7fVFMdE=c}dp29EoE6Q`k(uV$=zPhQhWtdg_nz^V#ow@c$4k0YTT8W1){% z5fR8dxGetZ*{0WaJFlfPURcB*uREAqTu!2=EqcrJzGFj@QU^~~uM{`nqYzHx{lTN6 zedwptbFTm;sfpxhb&LxjA3XW}W6HCI$8M_-2nl^4@hsdOM-iAq(8<7as8+wP@i%{Z z;^y+yqkB2qE>mNm`kUdGoCk5^lsT91uuEt3k|Qd%%niyzBX`S1fmLS(lZnCzw3 zPY%fyVKG7!B=*g&KWB*LutPJ)L^X>ukaCjtcvHajjYDeal6A|9Rj96fbfUy^oPyA@ ziZM)6X0DKD;0JM9GRN_9&R=cfwup0&9uEwE&2n_?Xl5SQP{%~*Z!V2r1yXiZ8e=S$ z2-kG@)vg6=^Y7VN- zxA)9Gx0XWljzEf5CfpmVrf(Dz$gQe)lcPh$cwCFpPn>MU*U%Y$7GqaZ`h5wTKqa%N z!Z6)i6x3%C&Tv#tfKuAt-rfzJOR&RRnPXpG$;XFbS8k);P0^F-@3yTB|3H^ z^-67GgQ_Sdm!AkIDAzvL zdQw%KrYPHXsY9Ygm}GbmZitV?oaV2R$2Na(*@_S31&XAZ6%_)Vhzab}hydzbhwPA= zcR6W4k#$#FC~8~>`PcN-!e+TTSH#W+M4V11E7V~h7gTN0SY60I7wyK+BO7#4Qqm$TnVnNgZ52V_LOEu(mBYBM_G zY5(_iKoEv29lMaaqP_{y+e)2!&WK%% z_v3BVC)YpQKql)R)JmXh5Pc$08iF>S4#N~dGE;!J_2Qbdym98-!kcZIgl4I*k`s?V zq1d^j5l0|fbHO0JyYzHzG+2ffYxZ<(J|N%+8qX^2nGk)rmV|4t=iwl%9YJ2ywgi&vR+h|E~xb=Zq&A^ebr-dBHaNIu8g?4p|_ znz=>8k{wfv%YeRhk$>2_Ct0j6 zH5&0_Hy*E+ejzBjmAt=r#jmZU*8JhQ$%r0YEs_oI3sy=aBt7WvqT!0XG7g1hI+W|- z0a)MrC1T5TKa17JDjsY|6Y7-nRa|mo$lH`cxUlG`qv91uy4&d*M5YL$v=P^3DfkbF zR-Ix}CQ${dvjx5V3>*C@4knUZkk+#~JxRA!E&1Jiu7(dEvwh{^EelrKQZJr*i#l!> z@mra%IxtJjkHQelYfu%q7+!H45wPy_RdKyqwdrhCtm1R0FEnL6*2#5ytL8MgGsusW-0Wn_M0_cO6fQjuh^=p zi7ixpUq8U~*ZAOT^RDz06|I)GRslrqN?uL*QR7V(Bx=5!f9D9Od+wo(3&2649Rc=Iaov51dzjxz1^e0sPN|aW86qz<3AJ*KXn#0qhx6h}Ca) z@d*&+{h*cDeTPNAuW}W(46Ooab}PFbYLEVXxv*6dV#JE;uZQhVxd z&)wvz_OjypE^5%6wB99@9{qie_EJ{sSQ5FtJxSL;0kZBJK{Ds?JoWlH33O{i8tAd;^ zap$2>tuD-Bfxv~2_8)5>laOgo{y#+-LoJi_!sbUnxAz)>+Ex@o^`9igW4Gr2$FF3W zILQwJZTZxillZ|gQ@kEp6^iEeV?X86H0gRK|4{sq8@^Yj*-dZO`ZkfUY;G4+#beZc zr*RvP2NjE=?w37&R2KI9cx9BHy+aF`AOi%GN9D$>9Pedv^rF^Mv+jnfi%grZ6Y8xQ z?Rs?a`RzhOMvsS85%O4KP6E*{V?Mv>zUuV>rr0Zvnm2rFxd%m%#+=@r2wM0x=a*H0Dk{{=&d^=Pgjt(a#H6NUr*Ha2!#Cz-f+=>n zxHLpRlfG~}{I~l2K0+B?A{w;Q3c}8%{F+ z65?@_?Ad?$z&QSa!T9s|mlnMvs}QiXxJ;F|< zr}%?iHXjLwVN^dT@*?#RiN6Or$vz@#@%%pS*{`dS+~E%*(i22+UidF?opBp}-RdJ8 zD|z3?6XZ|Ngm$DCkaqM$y<#>!vY3@Z`R$R-8H^HwrqhHh;Z7uzc9F%i!Qyhm0FGl! zln=abl{JYsU6?E`u1>7dj*H@@m3RMWG(1g| zS;bYmgY7DT4ps#(w`bGV7=zCt`_!k|_SbLK3Pq5cq{d*9-{p~}93e7-Xa8GD%E0;Q zAC$R+>?8q)THd7k$awCg!G%R4Ack{-gYMn4_qeW+9q{{fa~jEU&$>U7K#DJxc{*OL zSU(Wqo^p)+ce986av;3eVjZQzDctL~C=&6}x^S7B`D zh`4yB9)9Wm%_sa?kZ>AiGn54St@^=`0LhSH$Kp{>gpuL`k+tVd!ep^y==E;?MCnK0 zV4XPg5mpNPi*f!_ZlhX$N|nc@Yc+lD`9|HnQRyT5VZ}G8c;U;YGB&wU)SD70+oSmu zf(Hy3u2#yKlqe-iQfG3SjG%4_tKQh3P)3Gfdl0Q-zSNkwhjh# z7c5KS<&|$+o|nmF$G|ozzJczZfNs$_1CnSbQ4W|uZBh{Z4ZSp;Yy20HLx+7Vd6P^9 zEGj|K6=+BG?>h#0-=B;R9d=K!cFTAP)^0rrphf6Nk6M;7Hz6lYf3A{WH?zRo^GFOl3{PGAPjqHL=z|b^X^EA77D91lafxnio)QWqR1Mlqj z>G;Ke1p($z6HsnUifGBcw+a@|*F>DvH<)IuVzFX`;1Z}ZY}Mo@h2GEMqX))Ai}wn# zIv2GnF>kcbsH2O#`VURqf9T?ZDskjGSg2UTR*zI_Mi2sv)wjV4H$`gz@}p;e=?i%j zpE(wB%Fr#Q9_*84puA#?OMy@FrC%*gHoIF1)L&!9<_WpP_;xteQVgFwXzCr#iapGYsS(hU16Yj*yXILL(&_A#SG%v7$s8f7a7jq`zHSh%xpNcJ!M%+nMth zKZiz>k41)mZ*XX4z=#x^pivQw!;NvzM~)O@XZ1|BAFO>@b-fwXJ0(Xc5IIpc?l@#l zw!yll0C9tR)8J0;yiUip76G6IlN;ju8O)~$qKr%z)U17wxFj(c0so9qr@%*O<2%P_ zPFb!e*~X?SX5~(dm|9ITjGu?Ld}1jWf~R0=c!t0up}|J@5fn4vcDU(~rCh084t+m$ z6=g!p;3o^VtC7}}hW51)FRpo-y-at5HMpRP$SX1K3!*cn2~wi@cif{Adw80zf2XRt zdwwEMdVVMiTrjzg$blGO9Bk&XOr(afHAu+|z8h}!dYC_=VCi7V`E{6Svr5R666}D~IfJ)?gCSV;$#<>zT(`6+PxbC!WTmu=Bb&4~r2l ztk)+O?{nDdy7GMY!pZBo5kd@yCV2drcN8JfQaXTWg$!32Gni*lug$D4DMWr-S`*KU z#>fRNdp5k0s>!9*VQ+=d&GR@~>JsGPUMScp+kpjHym>m-V1PDzaf8YGo zr1o0_Yvk&QOm-{5xPTS8i?5(nnRK5lYtHVZRY{~bk6R(h2CX*Mv)DtlRj-LJOdb)< zDQLh)zPkJVW1K`S-Rhgzu(+3&bODKWEV)HTPKl?=Mx}=PYS}~9=^pw>A;#(JJyE1H z9-%c{<*F*iIJ!^q6|X*CVp8TN6yEWDc2iAAf)h)A&>a~FhvAIJp7htSV#6ONUBbpT zEIF8lxWtue8(Fqp8?lwo^RQWp<1`Hh1z5;t=Le9;^hMhRe}?PeI$qk1H}rgG{Lrr+ z?%;YUJMRAS_2+Vrr%$tf#c-W_gbGPMTDw)#!TOzMY;T5$vf&iMRMp(5qPj@N)X_a|SyYR7&8B?S8meJ5540?nuZ7&h z1Ì$RrdZlIAh&i)?9x2#Ty8i*@-u2kydQtq-@x^d(h&%DPu(rNyqZenQq;F@## z@Z0#12R5&{!mQ5bC1=Wsmq!ZM21e^?+Q4{dMYrt*S02=-!d3CNF?PpC-MH)X@ul^| za}{D!bXCFKRw}LJrb7?L(qg{a3%1mC_KILeIPub!+4KG0TwQR5Y&Y|py;qK+zHf)+ z-#N*ol_W+f=0A2U)Wgh2DsZd|cP$qeh>mRzdZI0v-eoz<|Bb0##oSFPj8vdr>LJ7= z^wjNe9<_D5@V4=gP6$4&PP9fi`avAo7pjFSfg2H#1uCSFk1dDB?!K~jYTclLz$LL8 z^@*bB`l~CM3G26|hlhKUY}!9HXMe;q$kdYbVmM<)Zc?-m`S8*haR8gnhWc;a3NyJ= zY@zqKB?`H zYwDV}4#LFsHmKO4f}rF@VH^#S880yWR3mcEr0N$`=|07|>kbca!QF$JC8ph4A*LwowfZ z@nJrKdVncW>_e%L*E@O6hxeawUVjpJw>>d%@v)T(hQHoWrSYaNzPauqpOpD{(~31$ zi-u6R{>T;kBQKi;v2QC8x`kB+u)>Qv+Q_< zbT_8P^}De3hu=mF;&ezIZJdXmyq)ibzDDw6n;&~~S{fg!GRGpd3&igC4XpRbSFF7L z0Z}PU5GP#r z_amHCg}h?tE|)xZR6D3V-WkBk*)2or7E<3I>6O+v!c~GQuQJ3W$|3iZioPg29_ zg~On73koH*M^oMOdlpJlw+8I^KZV1Fn;P1>@$q zTLRP!zP6e>wN|4RDKv>visQymH*U1iYaIXh`U_}~jg%ajz?_GbT3pXzzJ6<`D@)7v zX$RFgW9nLXM)4e4_{7{forJ*9A}c%UStI>7#m}D`8osy`7p-h^Mj+kc{dtK+!OjfO zR1u7!JVu(B={nv?YSHK674qb2U*4i$B#WG@iYb#HJLg5T9EW&a&amS6LE;On6OnGz z*ZFDDu=eoP{L!2U(spOmr>Jy`joq^ zJbVM1#kraLk>DwbTh<8fyngiZWY?N8L@sAefICj)w6plVl4Oh8s5FNv)+gbIC2Ea^ z)$8m6JxYD0&S!nYC+S`kp>iQ&Zj}eq-VhY|nEL!#DKGhmdW}SrdLg(bKb#b=jD|tx zN7$=zl2c69vlW*Zj?3@7;*NIo*1rGxTiw7{HoxZU>VCd){nJ5 zNFmJB>fafv{`6dn^Q|@;y!J#)IM3XLe=mQ&n3P?uRdA;9$6M`k(={#$ zUKMQ_)J!5*o+b_LHvsFljKBqcgB0 z(Vwj>f^66>pv>fbC-caiI(VG+G$Z%Vl>EM5{0|$Pr>@#f(1z>f_~xeEW2sU97L)vJ z9jxIT8I1{HTcbNpC7`9M5a>6fi(r>BhZwtn2-|n?58>>cSoG2sbtdsZCjRrDzkwJs zP_F&UQ)BVXDmZ=oLwrq1XeGJgw@3u-M((@~cpzpWMDDj6^hCqkJ@YzMmjAk-0331m z^wfXY7k{38EG7GRlbhz84Qt3sD2k9Myw~>C-Dn8c-jcG+;=Fz|g zlZVYz{(d&!GRo_H4R_k*o$BJNll{hi99wclA{9;l$RT^m6n|X7Z&GCdx7?!5JI zF+yOQ2YE$}qvnmi;CQ|-9fB<(E*YZu7$$x|s8Qy1BG5Y&uhGX6+1Msd&IhF)|7?&j=!;`*;++|gyg7{G_dFI8*gW-Jvn$E=&q zR`!U@JGbbAbislbNQXorW;7xD0%`97&}j#1&U;$0zcnXGTK4{L;k^gDr?Q%(ZTK7g z>bHDt98SCQN3EN()D(MO@2=8+R4F9VK-wA+QINTzSpA6gJjkd0EBf;}Qqn!0_~)?r zGsypl`9G(`KTls)DnbFO$OjwA9}t_2rT0jfnOc(h4ZsK7GffBWr9J;Aow2VU0d={% zN>`ZWYDw9XqwbC}pI#9%N2*!#s}#ZPLIfotqi&GVBKMInJaHXo`Pur&d=SMqbnUhf z!rrqlbI$`Y(5EN(3WJjB2B>t&(q;dzwNULdpR<9D>rKWt*lIo-tD>MMp5E>&J$Y8S z|3=Y?RL}MNov5hLa4L8wMg8azHz+3AkE_eZR4gGp9$7p**}Gz7G<4F6D%Id+V6?Vo z>(@6W8A>-v8ofk}RMFC68p#;$$0AC6528DIZuD~`IA0UIx*`%0mp5i}Yn#b}o7=Kk zUiO4rb3@J9`cAS|JIkyQ{kAM4<2wyK&F1DO(jT6pZJkoRoUgk)uE|kpJCL=h zz11HEr@=^5)O%m-O~sK%pkHtm6;^Lc4jLMa)|16wIUjdvQu*Y2`SUuOS3-T25CU`) z-aKwh^?W#v>v(@)_bQKJAy*Yytt^(jKUerwDaO5TrP-!HPhV{Gg32U^P%>ZeGs;%^ z49E>=YF*ZAv{Lb*qWQ{sn=8E^rv@|=x*0T;?hhc3E<}o3^IYt|OFE{y>dlE@blVzW z&#XhMnqC*~yKpgQq_2DAOXGK|b0#?k;ct(kBxsFq!zc{&hepyJn;kqDl&t;M-+E>D%QS-* zvD69TCUo4$9a}IkOgLuIeq8X<$V4myC^EdYAWa8klg@hRDy@1e>O%zn1_6LL<;Vyx z2|S4ZW?$|j02{xAZQda}v22G0lJ+o4M3)>47;_#O{b}m{%a%TcMi{pg0>A7qW_)b0 z9v;D2VS~;hDL3=?VUe~9!_%N`j;iBzQooIi@ZP>=>w2DoL+tpmE)T}l(IA(SwlSI$ zLJdA)LaUEYB0A~wxYxZ(OJl~l0+x;!o^WoxsZM~m$+AX(PDl=b&9Lou0u=iKZ$f;0 zcvB1yic|l>gB}8mAk4=jwhnPmjE*Z3xV%ATI_0sqDv=KL)4)CL=N-#3o0VgrlqDfq z=TKKvtJ2aC_S>39EVzgO9l4%ZWY;k}S}Cy&#o|XE3&0GX>-DY~e1tpV3BtbDh`Fd3 zbSXl0=Od!I%J>7U;@koLJ>7Z^eA!kHT_Nr^+HNPhpXipkv;51*{U${Nu(3UBX_?oI z;~8%x4&q;_I#8MhD9SbDKNe7gO0Fc;KDO#7o`!p2H$r(&bvE4MiqE|&Fxwd`wynXc z{zaXmM3G}u=9k$$wzO%OLIJDTh@zU`0;AAiSmof*_BGlz%CSm_x5n~Vqu zv$OGhr04PcP`yoY1^G8Si3fxlS&l^7J^|MP!JESZzG*rstArmUAobMb0 zh+Eq_%JiWG!AWkyjT4LE?X#;{QZ|+krTDzBp5M5QSKI*H`&>SF`p!qM&jgNcAHly7 z?timwp9ogFZA+P&a3AM{v8UlEtV zD!W+>m>FbZ`@o9n*KDIlWGgBr4tBvDljkRh$96u>{|?u``*{A&EXHd74qsod*ywR0 zWOX3hIVNDyZ=hmy&luu?Y5!xkDwSis?lp)UhYr#8|C*9W|Mv6wTRkk*vDC!MC?oVf z(m3O@&AASy{@ML;c4fJ*pLjx5${hqWE6}Pvw3b)-Yl1$#XTAqDy5H}Ce=(*0?Ge*4 zv(@x6d%nWYvgc`ocHis#B%u{dR^GdOzWc#^9vdHJUzC(|K%mXAH#Wbq3O|dCKR^Gd z>!sB0E(@*`Zh6k zmn5#D^%A}0lbNhy!4-V%+p#F+m}!H*F2uf0 z?bi8gVyrTE%tz@qK@nnHr{CVm&@>($@vZ7|b7}09X8L-@0O$%+RLhRw|9HDRtE9M? zKS|=-t_{Yk{Gc(2?n388*9LB-8VcD+QnISe*_|F)7O?=-1Snt6j{tO<_ON?tphzce z7^&LP6s1}@sEmixa_|6g7)QTzbQv+a`cMAup8nnYRe7wg&D}%NVUED@@rX>mSP)b# ztoZOZ!#8)8dGk-}_%$66JU!dtv7-9>q_n+_dwcCfsbqE7%>8$_P=zr%+Ab~?-w6>< zYeWbFnyVCYd{O;vQ_eQvRx>W_TE}41G zLEdJ%B}zIc>V`{nv`wUGp1s@=#C5`wa7IDZBL~gZOzE>b@I3GSNFBJ`GtIVn&o6Jc zyJnp;mTYt3pSK)DGcCyk)!8qfWw=B6WsUurz>GVx!)f~(zQEEttTE@nQ9a!vjYX|! zZqf@Xr`UZb1ZHQh$!#0R<6QEyXZjd(UGa^-kqEdGNxO19n2iEo*%{SSm6Grs=mlJ`5R_J21YiB%^&J_`U$F94F^!k)p%6N9xaMIY=oJ(|)|zi`57I&&k7 zv~U0KWIgoQ2=ae2ierxwXWpD`77>b#dC(If{LC|Pjc&R^rW(z0H*oui;$-^DB4pX?Ji(-8^qnr6{F*7}}g|{O_!BkENh%FD=M@*;#KOmWN>#7dG}> z;2+feok}$S7Cp!ap=S5q!e4&(2RZ~0#fwq2RIWxGa*;8-daND}%fQg__jUiNNo0!; zA^LB^mbNth&99E`m0(A5;$+Ivw+!Vp{|9&6-Mc?QCH@m7-FV_YG zCgS-`N$93+Y#W~Y4(zk~WO=dPN4*p|M%gXrQ@ zuPmeail?aeR>$M|D&-|>zE^E-R}0FrIR9M$Oz(aG%+KWi|C>Mm%|gb3@Jx5EB~(a^=s=-;yTaXab2M_!oI_EgZE!#!zVVN_Sef@b=DS zb?TUrD_lu)42+8wFH(ghp3p{JO1&uK!kMr!fXk-Y6jya6HrfDOzY22ZfU2hv&irYW zvjcX!&J{P)w0-jZ?;pOqs2(1R{*SP+^H zc^Q((b1RfJqucY)f(gf9N6{-UxiqwiS7tNfNYNl7VbXa%yPHC@#7_wotW2|gc{=>k z5?cz5S!(b^o;fb7?3wqpgbFU7^G#*3rvPLbHrgh?)30vQQi4q9YKS>d*oxe|Wi(fd%|8#IX zx9baAi3&o7^7)Gpa|AQ%zZA6XuRr1T$<0ogc2Ms*v6Z~ra3z8dG;nBz-~c>^N{lxK zyC@759AyyUu<5=w;zt{3FEtZI7v3bN=#~?_i2-o~Nf1>EyZe`@KSz*aQn`^vX zaw2X+?cDB5L-E%QD*o;cAIUkMsNaqG{Hr@ss4Ide_H?5=5kL)+4fwYgW)9dfB3h6L zA}IZK`9Bgx;TGt(LkQWQs!U*Eh;hc66~3Iu-x=y_ys^34mIS+Xi+rmG8}#_mZlBBg zfBVNs^HR&9zdPS_WS)4jkW0|UblHQ#DI)jARBB%P`R-?e*s5DkQ*|66SSiW~wr3_Odj;`k7!lj?%2<~Xq zd_*!*zHs5?OnBRe@PRh0qWrxKgRjUIF(r@udAx^KWyniUan6d$*t-F2dc$8YWU(l7 z`|vbXIpbtD#up;XaTH&*ElC6onihLhA{LSNh@WdWkgqsjqDL3#9dDgvv+`H+ zs3x7zjhX>gsXT(6=f@|qG$N!Gi*)9{K^3Cy-@vU|FVVrwc%|!s7Q(=Rsjtby*0qnNb!)~z`>#lY} z03xa_u>yAOEWp_+*M)ch@9Kq}YZ#c3w23ND*@>F(O<+1!kJnHf7Ho$}&Yr1?AAcv=pb&P2?Z`U;z`nnH zWV#AKbIiX>3aJ+n%7zlR59xK8%qv0 z4^ZFpr>Eg2&i{~i-uxZqzX!$t(FnD0%OolcD-z3K%h>~pgvZvX`6s(}<0l$=@PlB3 zF7H(wH+ps@5WbpSp-B`YwiOQAJc!_u0-kiAWm(qfTK@N%!nx1E3dtD^#_tCk*p^7# z56F8Yu10W5$KXHhX{H9mGXI?M+QP2N7SFE-56UaaO`~5W!?&&YQ0DRCuCqm56ejyP zHO12voJ3J#=72Obe6r{X{7Lgf#!FPn|>p*!)JZb`mQ`0AgLU7$E5pM^F=rg+MQO1GS|EG{VpR z*9wQ(wLyK8n+P)3?+urTK^|zBIck$rpS};!8LmIb8%1*yVgQC0@%7>cK;TO|Knq6+ z(R@eeoxvtVsEYy^T>e8)pCmR9R0vPp1e=3h(`5}eJ3v+vBUS^7T$h5%Uwj9#*}a0! z3xXQnIF+E@tJ_YLEgeHpV_Iw8X>&ndmE z`7F=^%agBMEC_fn1X`zd{%I8y?e8ToTcNG=^lXgTAu8MHMOu6POPiF7$f7C=43yff zUaq79yEA8a=ziD(SqDzl(0g6WReqsZhe)NjRqh##@1zm$njCN=+)8A%+W`;HM%wRIPT7j(Xma?deB;a zSu1Y}^LP4Et~+%QC0n*UOgYTU)Z?>pYZhuJ4K8i|W1xPL|Tuc1~%_K;Whc`@bfL!is z4FtaE5-Es&cYp{7JhpxEp1)Z-`y@Ea_t`Z4zLftoA-n>ynSpl#8+hPb z8}Xn}D@a6NAan% zc_J!@^7PSmqSx56$-ka(nPf|jm2G~X=cfv0ih7Vb&I7fgt;2x&h!fX$PHVj-W@K(g zA&?u8T=bfoe%Qj>5q<}MnfpmGm6Jht?m$odda~KJ!2QnZsezo=&tv>0IAexLMxnRi zbhz;F9}vn%rqxx!Tm?!CuZsKvEg3j&car%CCQlmq;wV0>8wp}_M|{1*ah*kBFNSh5 zFVwTxyEji8$5-rtjJLmtw8;k*L5tM#!|Ve$QIQ7sI9_>u$E+bcF466Tg}4O^RBxJb{h@^shaAX|~A@i}g3H?X~Sl?#m5l6W6js>@@AQ8ljvA{2{T8@R#P`O+wyD5jXzs8k9ol z<tM@=;)ylpAn3A@`wp{6kvYT0M7hEpiWRyydKD;wpZONe)P)NK-a#6I(pGBC zW96P_9+wc3138AIL^;gPRPMV3=m8#jf}V%}9wH0fAK4WBSJTEnGq!j3_{-_Y?LE3l z9v%w+y!Z~U9JA}eb*6SJ_ojS5X?sw&}Ar&V#hfU zg!5kGHYgpUx#=n3sLu`TR^K`O zW;2O7)ytZHiGQgdG z+n)N%|7UmikUTT!k=ZW%$EBuu&Wx?)52I?2&cdUpiVAOpk`E8I-KdRjxJyaWDlKx7 zq}B~NzLxw{!~k~Oaw~Q;uz34W(xJ7ZG~69l?uX1CYnBCDya{?wN&5Oi{W;C(8%Fea zSDV*XG5A8=FGwk6WV_wEoToRN_Txd<=UoQPb|(d|hK^ktcP-akVOi&rKK}m)oq)7B z-En!&96H5Y55tI=c6ZB6D%4sei88T%EXBl9wuH&5G763!EVM8Gt$t^mWMN+Tsre+^ z&%~0mc9T2DXZjb~l}d(SbAlDv1OUERH-o`;xnvv$4=glO_(*9job{$*oo7xn8(3W+_63 ztExwej&wFby7p?Bna#U1MggwI#+y^+ohbb+&+*JU)lBF3ga^ziF3dH_G zJKe&hyEG-IyCt8|!s3qJGDut>M;9^%>pu0&w#en6n?Plp1h3z~PJgE)e()l+0t?tY z#m(TE`B%U76mF8&fWvx3hbsR&;6PlDgS4! z06N7+*z(0W*jn8WNUTveY~lpS0o~!)om?Rndlvg7Eg-+o(F+8F$52~BfUpdJ1&KFw2=fx>&?GG=+jAW+zh<7gVGaWt*_SXP`&QN>rYs>b*&~eXOLmbZjD2UUgPG|)>wc=y z{oFnG@Ao{P_x*gH_n*Qwb6w|hoX2&n-!I1r2|=7C??IUsA=L!fTlZHmYIBYgEL6`J zFFgcar}sqIe3^-`w+jKkDCDnGYrEY0f2}n~gfT zg0cQCQVQYYf(|ZYXvA_~&ME4&x6kk7Dh%N+_8r>YSI8;x*q0gaz6a%ih2w)_!w5F! zZQG0iyWuyH8SgtXSyTJx8*`82lwOswScw>@FFJRC%-?S3g}DLcIZD+9?ww7Rd2I1h z%S@MuOC!^+TJ1<`k7E`vnmip*cW%WHsesEONi7X;)%w}?d+nK^pVHx4uIwQX1cHyTleL$_9!;@oV-XE|2jwLf-U{@p0g|ByUS#3 zqOa)+i9}Ss`)u8?>GFvn6s6>*==;ogk0LXx?YpHd`H!Nx7MnRmYe*8QO4?GafvxNT zrppcroEF(o3n;jvcGJy%FFn+@fb`yl*#nBla)tZP&V^I1psXmQ0yq14B|^ERR_-L1 zgf%S-q|$n)&t_lmhsNUH;lqw%U%nV##>8QGqxNQftY3al6kW>-7nG}!xr(St#T%Zu^#jWjpor)%S(uATEoux8 zE`DsvNBXZevL7E_ob!@=_vE>Y_L}_T2SzJRmII$FZ<;jZ zkvDbRk7QF&hj7=d${4VG!9?bsG5h$v!|lt;emp^X>4fNQ}g^d znPJMXAaH4P+xb9Uui?b`1k2}!x&qWvH%-mV5Lb8`brw8|S3`r4QI{BCCWUlZEY5v9$sAjjVMiczo}Fx)mitexUBd zo;uFpP~VcEA(MypVLNf}hC}_5n3Kb{%$r^A6g{0taG7YI?U8$g%7b|;9#hRr%trbc1(D`^^w(``A-3`mcpIp3gdS=-g7y<_ zWY^Tbekvv9!?m1a;=3N+y~sRok%Z#I84uO3CLB*%Nx`LJRFOXS!8%|BgT+_(C*NDS1Pwq5@#h9 zwV|)4^!dOfLMnL6we-W*v51_fCR%sgl}KVk$f zw_xD3eY}uT_+AS%1zu$s7Rsn4qT`*U>J}sNVfGyhL{7q|p6Ait2Vr*A?(awkn8;Wx z9??Ww4&CaDfW^Bom#L~UyTm9%B?*`etrZ z3m9E=EAF&43G|k`#|eionFItSrQe__WV^k}Nt4jiOs#}#8HT;zC=pAUT;CL|(kiF| zB|9P;zGfzoNjah3H|6=d&}W-PSU#~ScAenAn6dtGa%IJ2p6ygMPQP#0sz8@kI&qS78 zE|pg&c61;n=Am0SH0U1MCad#O;j6YBlpvCG0lm5m186GL+dnYsMxSn0c`=eT&2RJa zU=4a;vz5Z-XHZI&&nj&byA3ukNNs# z-r#nskk1#W#)%-p$56-5*-8>Hnjjv<`0F_x@-D)*!G*k{`OT0X|SS*D(Gq#^edXx z&sKEQxk&Q4Gf` zUlC*B6-pFG2jRL~x4I2THN)>!=WkS((4WG6dSU5P?^p>!9DYPyVO%%=kh}5~4u3*!IbHv~k;U(|PoM$1N}@{aP$h$pjfldBSD2Xn zWaT$!UH{SkoHd<3vOxQxSYUpsZ>su{$k-lQ2zBTFZm{l@47tFI-xszu$IlPm*4s;K zfJ=hm>OxGC0awL61-kgy0=5Eq3rc6&zxw!(z)LL%WHp}eswJ#NA}Zto;&AJZ5#pB8 zwZo zIZ2Fsw{IY6rlx59hRe$Odw5w2gnq}7K+PdZt5gL}B z72c>?n5~k%KYO}1JFlS? zcHr9{w__Ot>|vBswlf@@CCZZ8#kmREYZ~f}CubqNWztxRszHA^Oa56mzQu-iJ$8N` z<+GCPBT{#CC}`4&)8Vy)DK166mU6oz1)Ie3G=WEI&|-ZA^TF`;XU)!Iio?(Hf)Py{ z!bm$bc5n`Af?11GXw6&5pMl;Z=X6f8t6?=6xJhVLETSLl6{}d_z@O`4tVYvWE?9BK zjO#gYi5+*@=S@tP_H1cghM&Wq#4%!!evKSS_Pf(F_e1k8q6HM${I%lI3(zN9_NaR( ziD*?g;}NZsuJ&F%uIq>gbHn1uebMhRIr0nAM>J62ocOp$;qa?c%9T!sOl3GrJM69Z z*PB@!C^`P9M*V!K%Q|1t^!;vkQsFxK21E;6t^)_y8J6j=x~d)+n4=3#xSoEjzTwK; zGe10!ZVLe-1Y5{UA~`B>d36d9uKFpXUpu&doU*8Iy}v)dnRUN9fwPOXGG-b`jW_dX zF}4n+mSqP_;3Yv2ysG;GIeCtm1>x`T_~u>M8!PJh4qV4pB3Nyb!#cULutCv>=mzPX zR(XBkTK*i(pQQ9G_0PkiNHUCXPzWTd!q#zs^W(QBXY1KxOO{4V1Ly&i-u7rM`G4R! zzpA7Uz=CGKSprKt)|y4f1dpGorN6I)rFwJv-HFqz-u^}tu3sSbTC&5)4Mqp(CMa5K zK$$_8C_yz~77ei0xxTg>{s$D^pVmz(T=<|lz+;aCj=-mD6SK!pmI89-Ou#t$Q>Fge z)ckXKx620zHIRHO?B|C35tky_*;*{Mzk_e9!$Esi&jTeT3x+b+48wknPP1hN471-b zEr6Jc?cC>FG=T1>PbZZ>ev9+q-O7suuV^e{a!0=De@NF}X%pSoNN`dUa}?2zEXXf^ z#8<$BH;Xi z_5A8tSRKT1OP%UW=;uztKH=e#Gkm*q@DmB@SFo{8k@ordpAiNjEwtvvcTht!eNDXc zte=?G0uDdEeJ#dt`eIhsKkAJ(Fjy!iA!G3hvDR>rm?SuWv>nio^NY5nS`bo6-T#{6 z`KWy}{|uB8Z-=uS)^k2+-Zr&zNtJhDOT}vw#OIks^ zXneM)Kh@@YN+=>zLF=RZfS`SN*Fb8vx_Y(1wVK>>GdxxzuaghT&hMGIs@2ZxCaKv9 zwshyO0KmWg{v%8E{}UkaPDaf(Lp^0P@y~RVb(J5rZ-h0JIpB=3{O^fmi!x6EnUf~k z2s3!lS*A-+WL#QloL}K>Xp<%r>hLlX<>9yI!grrPpUc!o^*U*aCG_IW$Rwq}@l!>Y zAai$(xLb!>DLip7!@`^bz?#!%;Co-13afcv1xx>Q#ht603vm=*+5D)5$Nb*qJ0jypgMr4$nkfM>d7Kr_`eC@$I!Z(~s@BwW-gws7;iZ!j-uCNK%15nVbz8g3IZrzka z@B@=#jE2WfX;+O)-?^=ilGnZoafhJbD9J?%yc;f;Sm^}a(RH^LI@UC!`GH&IS1WSo;ttXZP_x*Nr=f?p3)OU8rT6~rbs}y#=~$I@IdXHn?x+cPMSWa0<4jL>sqb@a<-Qt9HK|G=TCKDjX3LyGB!){D|b4EdERq)z}Mt05fywG^d^NISfb7W7NDU>poZ%05X~lXn!2 z?cv`zc))?~`rFiE!Rz7q3ANSGywpky3#Se#c}xZ+i&=ePtoXs4A=}Ic5#d+(C#d&3 zcW5r2H`hFN@wD~ZG(lnM)((_s{1G8p}^*#QTh+)y!g}RphLCkt}uFo7qF@p?c z6j>QmVPbxnYXD+H$a-8oJl$T_Os|u7;ORoGE`G-Utz(`=BdfQia?WyF-zv;%DCN4Fa8P?NX9PTc3~I3hxqEW3+s$ z$fPEsW2%xja~bmuwGPZ*q<)7iMF>2HcsOwteFrNA8QYz>n>CQn$8u$b)|6 zz-Rm!m4nE`w|PuDIPy1T44%l*oEMF`jvVk{kH7u#gmdu|9buomlJ=Q(b5ODHQy{yK z47KcYc3643ZwgDh7U1_8#QRHy&T>bQSH16C;T*NmM|vSfX@m)GEnf+6pF1}(lAKNp zNln0rqow*ZM%Diy#o|`rY(|5J*k}D7H-YZ?yW`Hh7l(VX*_zBxDR$ln9IMq_vo}#+ zQ6bc=h&mr$* zGi3@wRXg_7^CYa!hd0tumfhJ?mOP~CVs%8bvf`Cy@#R73a^(Xd8GcstA{JdADo9m} z(t5=JJBLWeBQvb!ndL)gwz^Ylxs<#veiA7>&mU0=4oe3(F(@HtfWh@5_SKS^0j>AT6yB7cXavb&B(IjCUz~P!QZw!t-g*Ah zhD9qs@<$T?)96t%54uH6hahFgaJ?@KBw222kRI#yZCQT!-&b$Hs730Nt!{#tuI~IX z3AYOZZ$`$Zmb5@#;^&%Nx2N8q#(5Kqf$sl`!Oyl%BCez>(F#zfeshfjFzf*lG--33 z!XD_xk}s{j&&nl{fpq5#qReU4O%Pk&V3|oHPaxUTB$Ce+vmE{HwPy%QoV5|6s+hwM zXA!-q2xN}Pm~8^|wBLj^;k5Zm>Q=~#XakOwD6dDIjNCk#wqx*rH=}K{l}j;mb$psN zn1!@@L@Ai5Y{}Hj;ijefH+&T;;K0}22VsB$=;khiSBgZr-c$kz4YVGggCM?i|5wK* z$%C*?p)1xDjj2=XC-u!)3bLA`CETBRSbnVcC+9vMo)*BwSrRUyP^6~Kpl0WG3d?`F zdv(g~NV-|f02_AaJ=T;rok=0LV|#SeQ+%~~Vh4;BjJ1}=KnHq>&{#|GrPsKm>i}^X zo`gH!+mmk~cA;I6*yj;;bjj0JFCEf{;+nJ5td@nuTZ)F6!% z>;+ymr(dN7;$nzwhTJmj`Bf~X`Mgr4zz3QmuTVSwSW7r1vyVYLH%MYoa0)n(5b z{hvHL`t*u|_^1{LB2gBE5E{d_mVn*#NAqxnWfH=6gosDuzJ?6`2Uq{pl!(Tuerh==;#rD83P}$C(cHgV--xm|68vZ?B4|vQMhLaJev@)Wz@>F;S=lVUbo;m>r7Yo_Bj$&q4IcN!L*9oP6OMx!U6- z>gN7h4(ex}nfV3ca#)R$o#F%Rl&*wfm}tL^!Sv-T8uyAuv4gGmwJ3+#ozCSiHeq|D zBIuweecs62>$ zLO%~iK)u+MUJD4*)@^&to-t)&vj=pvgh5Dz#WXuF4f3dFHm!Ukdyb&bE??56dR1(gN9cf@y4 z7!s?=2Q|0NJEn+vXdk%Y;BogWQo>(u=ogtK+XB>UUL73RRHMlJcOLiK?R^ z(T$*tu-{UnF#e_6F_tgH-l6R-T}q97i5&RMz5AM}q~iMxtzz4YVVnN^mzP(!l@w=h zJE+Hcd2>x(TNHciQbGcIJE14{I=F1!@st1&z8{1qpjqbI&r-^3VT(0MBh+3p0&S-s z#_A2`9_du?V_-Za8kuf;=upW<$e8721Py(!HV{;51jbMf^=Uwx&#S#$L?u$0F8e|a z)x98==e^fQnBKEz6rrReFnksD`6%4Qltth0O(b$H>E7KJBCH{OCBu%(N3AaSjdSFY zEm>DCP|p<$DQ|G=*C+NLut(y^$pFtqQe-%AL%LXde)Vr|C!Z->4>&}(b9E_jpWD9 zhqVDKMu{Xa!RgJyClmNO1oPcxZl%m+VJ=~A`Pqht?@0C{W4q__x@EsJ<9m<2dJiz{ zR4o_9j4#hw?6|N@kGUND{L~|H)`0SCG=aP0_1#RP_SZgfy@rhrJPm5|r+A`U z_xNwzqaSTn#tmkbFswN1~q9KVB@pi`HG}wvxms0n)P}-HrgGN+*@sIZ0;p~ z;{k(slKjwdF3(pKH+BQN)CRc-;uoEEz-lc=5^LUu`Ia$n!5<3FBl4{=%R<6e>QZ}y zh*83OPnM(1QIa?fU{pvW2~NPXT0|P?W{5cgM+>qYZNS?IBPis<@SwzH*5Z$h z>jlKLD9ADlseXYNDdM?ct1O)$I|6Kz?60M#OK=Wn_$+ZZmhKCrcQZJ_Fy67xaOC(W zpHL>tQI6ra3l9w0w-6PdS!TAbY!12<5&BQv4ZbxjTmHVrCSt)b|6jzgi+~bPBqr%1 zi95eQ)=ncTxW7QuGr>@HAhuZkNL~MS{8Gz$i_xDM21#PG&dBE8Mr?A_UBEvOrT(M6P?Qm@=_vdwSd=hfr`m3?R-CLC~SqLwH<#X{tSp4-JO;@QTqiV z44VyGvv8_LED4#zR;x3%(aV9r6tJ&IClABWe1W8OrGnad7$M|aHT|~TGu^gw-_V3* zM+SHHes;b>dFCb!13C;{FZdya95+}_OKQ@tU7 zUkezZe{R363~oGu^D$q%x(ZL9F@8FKYUhGa#!**k84HtTm`sTxhT4CS3VCPX=UL|MyM+-9htPQVtF%@4BoOPy?q+EyXMM~3;cN; zJ_Qr4_Na6#OH0&^`xbRqop0~b;ZbWSU){;E2X(lcNaiNgy)J3Sts0fP+>`CcydwML z3k2J{fgAl)!iHRP>+oA@SThW>o1LiIurl67$W>B$^#u|WIScY0hg@MX4Yik2rX;Qu zS*(3#ty+&+8y)%JwR-#Vh7GF+C6O;WEhT)HHk-Ud`-e7uP(2zLKqZa*;| zUv{O+Na}1=Q*n`gl==OavEv&V^Y`H)ZpU)ZI}$&2#;s@Hvl5ZPfhbyW%Qi96CxBE> z9-7H{_h6Olsgrr7&QHZW+2~`;_LGTodQpo`t*)M)R#!*pz=R5@Np^x?IQw?${nS#o&NdM@a|6l@%ECuLRO3S&I{y<-MS)7@{1?J&;Grsf zyd#g795*b!8X_b*=ugG#zhB?4bvdLb6}nATeCxjvh3_@0)@4E5HCMF#Y-@FvNw#C` zYe43wl4;98)b&080(s{EeAm4*f7H;BjA{FF{~f&;zEJT58|4L_so$q$w-ySL?e7ds zwyC+qWul@1j96++bm(sVTIfLXZuqlyVe{Me>V)o2U!F<~dfBzS~6G`48bLS|8B# zh|xT}BuG{Qo%SLSvT6Q3N4<@nFjL|~7kcyn0yy5l`v)N6{5T7uJg5{d9r;>zVcEsE z4M^~3x{bveAIj1~umuius|x7wHA63nTfkNuK(Ul8DVV@^cD|||CtP+HL7745Nv4H} z<^*5PKs~3@)5x&^`Kd=E_HM{?l&HsObHqOVtdy6|;m!T|C5~SpCh?s%Oqzs&h4ee% zyEMo`eEzajzuTuiq*Z>iv)_3GOZk%Q}PDf#KklkGFvaQ|KP0~?=M-;$A4lQc}A7L3adhe#pu@h zmPirlN*H<{299r;hhXboJ=GVl>Ed~Gj>ti+yQx(J-EVIoS7FH2VeijfP||h%5>K_P zpn&sn?O-+=eDkSi_#;CopUs+dEUn5S!e3zElX96^)KZbxq4!T{n!9cvXL6|H9ErvE z(MFEgPOO@Xp!0KzCl_9pA>&{+kjsl9DpRC~#{n zhKu>prD;-{bAsXh^~(^EprW?@NE>M z89BPltA?rYpAq_NaBh$G1fK2{>HC!jCwm`uT@sJ^Kyy~@lRdfkOjk+&0z2Co7LDYr zC^W<4`Fk&~NjXX|nVybQ)f$Z>P-j_0eMF3knE=ZO;3}Rn(6Eio?#tALstSXmL($Ix zI8}Z!ELy?Ap#_H`$~OV5o_u?HlAqal|37TWZ~gBMPFJr2L#>Pk7iNcrR%}Y*yjl$3 z}9&ZIrrj0bS99_uN9WEVezz>tsD=J@3g zKkv4kujlJ)gcLFEE8D8*XYTV$2Va$hyQ)!Az0Q==EX>|Tv2cVqIk)=yU8dpiylRe_ z`4vPE8Z#+9J&sJ2x4O^`FO6KDCL`v5m@YD-E}0MIIHCWsY($2Y^|P+5om(~6tc1<1 zq`O6`*8r<&mD*7aj;Q!7i%l`#yeUyj*Ck2Ti)Aa}Z2J>Dsge`QZ7oLOhun;?=xRYF z)jfEwoY;MLHjLkGaFTY)TKU~u9VJ1gmEbMB1J3m6YKgkd45@-7|@|dEF zG&fZ1!nj#gF8r9L&fkm5*eYmcYoOfD@YD;uBrXD5wn=>u zf37&GrHC7kccO1#WQ32nw+W*9f0fivfekT4|wXUbZjh+ z@DdL}r3)fPdWF{+H?-o|GH61~o!3q6y}m&DkZ`$vM^J>f#lm8! ztZ_ykJU3=xw^=~4vu%?gr@j?;%ws>SB!~7EvZ7e&DC$*;(uu%_^x-gRl%d0&>FW`; z_X{&_%F}5){e}VL1)K&A!7bdEFm)ei6G!o=>kgi$HZI3 z5))Ez;FA9Xe2tC_MH;HS)`1>{T7r0d1TZ1!ZGYNDqk@Q0M*uB{x`Z*+8$3RmBGthY{`}#3MuaQ1AYbNPWPWaJzM0IrmYr76ICIdHd24EV#N@J{qhFJO&t-5I-`WJTv!(v=_g z`Cs6akdBp3U0q+5{`DdK=02_fT_8p&>%kY}a9TvK zV}KDtYu{o^wIMWj(IG#xXuvy4dpK=Q%o zd(kwoQ6w&o_}&SS zNHNv3Ob^c>=LZ;cv2+@F>(RfSFjc~DRtovgBO3HoxTG2x76i5V-onXCKG~HBA5$|A zk;S?WK{E7Zf1}yK&m5g(y>0rnBn$%}oAm=a{~WtC{RfiYLAdn?aR(iwg%#M-nF!)E zV$#CkG;BJ8(1y=w`q#7%&#AJ0Gnc8EXaKOfqRKK^C?S-)qgem2Ebm7V1y^4M&LQt6 z$-fYY*VFAG7vHKFoVqK+mBRcgFFrY61q#JQtljOx3!86-`AMl1i9P=`xp?A;C$ni! zu6Of2nXl>>1}&pjF}gp@uCuT*@#L5w*Jqv=_jH((uW`+@Q&Y5#mnCCS!BV2UUehs? zs&R~Wh`6<^1Oz=!VVc)#eEwLH$KGOwYpXkNl^GW9X$!QbJbNuhS`B$p_BMPNBM4!bTl>YtX{i;Ib@Y$em_}cq_>{x|TdzZ@FLqal`3(q{ z2B|C(-^7U7S@lp};P4ZaaU|Xw}p!0vEf8WOAcj zj4T|{z&Q->$#YUk;f!q7&~`9>V(owr>x#NGL3b!ZJ*~KR8a^s<`mF^+KW+MX71Pft z`&Q7s_)^uvG?Zsl8O-zm2-_^sg~(Z-w~D?2JU_US#1>IB6QkOzSM9-G6r+j@yS3>_ zWahr$Wz3Cg*pcKteG>Qa3nbHisJ;4XyyU8dYAG{}YUX2bif7feOsmH+rR%U|nYW~1 zZ`}>9sJX+7%LZQUbjp2l^qp6;6X?S5b9*OQ zv~v=mMQ~ZUS~%$l$YBq3PyoyT63T|7=F&jQ4LMe8KUg#0rC9<^QWdPIfK6oft$0ip#r)cO}R6#v$o zN=e%AQT*?)e!gMBuXz}!^7NEW`L1q$Ew1@F$ls3ox6?J1+m5X@X76sK@^SUEY~EDW zFu#S`rxs=xY^RpSPZN)%8}6IYVA3N60z0rZ#QIc*-=aJ^NOz_hYFtvFCK@0&9}=oB zOEoMVFlE@iojBq&S;iWi716SqO4ziB@@Sw6C%I}_Vj3rwii`t{HQ9Hms zGvm5Fj&hEp`#<&2rY=i23&Laf;>`|@3q^zukap{S{-^rRDx7V14e}F9(>#AI~Tk0LA~IG2BPTC1N%T*4Cz`Z_P7BH3f6^OpF-Ds>xCm(i#C$Ud`oAj zNOadA7Ne(!414F7GMC&6YC#9M3F$3DwzrH!Mr9gDtT%{xpxi=*c$#$4|45+Z0Nhwa zCTyVt&~0uTHVD&!(5g>iaJ5hB)BC21y+c#14xd-Ea<`2>7}yR_;0wbaAtp7x#^Fdt z{HHyyjZx-iDKzg#NH3mCx%qX zx(#}b1>jG@nL{vcr{7q$(^SXS#S%8LdzV6t5Cu3%pZ<97Dy_#~S(1ds_|o zTuW!FxhHX3tD*xF`|(YlMMs2{WrpJ*T$y_{w%^64@($B$c60ONu_`h-u9CWa1G*NU z6Amv@SGPNx)WC1*`OhijWtitE`8>L)O-I4zsmeTgs)a4=IZ<-PP^7IyjGs4?O+*2^ z!_rCdD3$6xTe51qGJXOQh7tIfStWyry%oD`N0dJ9BXo( z5otg3brh4E<>DD6YO**pbFs|ZB0#43La;X{zML)W#rw4x__fM@SwC&2`{n^k$D6kn zZhN@bJYx=a$yd!#x1x0lQkwc~FgIb>M|yljml)@SIr16aVZ!EIFi{0fZ)#G`ZFHP% za;HdgvsmnU0ouTD%EGsm9SZjfB^^KtewN#5IQxmyi0xS*E_HaK`Y={zzIkER_k(s` zBa4aCe7sSODzoIrmMagngc&Eq&4%<3OBU&*jM=+#u-q$?NVuDE{blFP56Sp`-3^c+ zO5Y?sUZ?~aZs+w{4rUw@WSsZqS0^c^3BE8{3Trs>*15&wy&YUns}hm7zmcM5VQ+K4 zf99wj8%A$1ib6MsKdz)!FzFhgLJ+x@C*94$cwrm2fmOVEpHx8@?2aPE zp9CFe+my_<)&16yNUF=iNXGItg&1zCS$2x}xw~MX*L%Dg`MGdiCcleRrJ1(EMq$I$ zn={nJ2j7r%{+)+8Cv?}`a%Xg(1Fp%HAf}Mwp{;Hm67`m{&HhpQz3QvmHcS!~|Mhie z(*hvmLx}NaAl=%?3to6*cv_2*Q%D6E?3zap#5WRD7rkzifVqFP#sh*>xHH^LTI%hMx{>(uZ2O zpH5biX-wc9oDQMr*juzsV z&Nz-nu#6%F-7^sX(V2jSpj7ljOmb5Y)&S*?0?<-2YKX!7uQ-4?r%@YA-b4Zlr?v}< zQ<9Pd(Jdy}LbiVN(_w4sFov1S4twqy6}%sS6)R3k*bu`zAt|Xr4o~jheG%0o zST3Xch^ph+emixEz~i%BASSKQC%_TgKn>Ho=u@ zRO^UN3ggw>u|Q(Jw16n@cp4^gLaYZquTY&;gwYenzZD+hNOi8?j9sGkN9QT}xbcMD z#B>>0Ye?-1CJu`r%0@u>jIqlqSBc_PtaTVs+#H}xQs+Hvw|9`d9eMmk?dy&QHHwRT zumvr&n;4D)10%1A=?xdk&A4~DgY#sKOsu0g0we*xbj? zEz!zP0@?TPPV29VT~AyEG*M%1E`sikQWO^!o&1SH=8ti4!v!S;XCU6}Y4>vIq~z%!xp2Tay-UNS>ByXVyE#PhYdJN<=zihaX|L4XyG?xf%DJ}kQqs2 zuRteUy<$LlAZ|cavn96WZ3b;v{#jkRq{RvedADNp85<7~teD?mrPy@wL!L-BYv=$y z+HyFIiV!P|d-5LeiPWatac;TP-d)K!s#8G)j-x>LR%LH7s{yyt)@k^pnmY(fMO#d* zpl)m=Qj-XWBK`;mJ23Erq3fSA2LEDs|Cf$zg*Gm(a->EgF8)0G%RMKY!dz!aRfXLe zizXb-skq@5>2L+==BzOGf`7jbyYDgkb zirZ8JB#A_Ae2sB_*QNCLw`nIR>Hcq9%q*qussSj+=~Bep9>=(gHBMm*T23UxDq=O@ zR|YrkTJ@-X)DY?qx00uR$!-$PLWk{uW$x{;Q?XLrJ{0M+`tqB1lxMZMe z@5dYLe<47BbG-kX4J3aU1ojMU%77EG)zA9Vbra_qKyE@J3jD|*B@mv(h{u23wTcoK zP`fHE6V?46JRtvz`GLw|L%zjjQ;EQ=Br@Is_z3i~Xa~uVUR%?EL)9wECGcXW9;mPb znl821li$<>>|dMxq5~5d0~WzV9LiKBInQdQx>q9YRixsMB1Q&;H}0VOIX-5P2=M=- zbMl?@V_^7u`cqfB|(>A2Tr%2pJ?v zMt!sBe*Hb!c>mrvby5e-OrP%sHKrhtk^st2-NF`8QaBw;E8?EKI{{2j7i^V4D5SPC zf6C!k@)jrLPcGjd?8fxB0wJjH{KG%!yETtKr-BZ%XXfC@t3G6Na$hdx^nor{&jb*V zmUE-HOE2D4y{`IfdmUtFp(5S)|3!7`ryo!I!7Y8fnvM+qor_lSJK6Ya^>a+Re}fSG z`20KI%&ExzgY`ll13bZ%!OeA$FqEYuOZjKh0ZaB5vme8200!%`X^;0J(ycM{Nh!KO z)!oq{j=?%>@!VU98P4C!Crqc{OZ##`T@(o00$>%s?~KwvDBCmW)OI2MDf4nRaIVVE z-B;a0$+4c--7NgHZA5q~l2 zf4*mYwG0JG=iJUCXKn%#Cai(e_mwRQV?)#d)%KNDfe8TR^8X+y{yqtfz@t0R-#?xP z$01}taxXbZpXr^4%{T462;<$IKbanRJt|g+{s?0e;~i(~P$He1C@&=cL!%75Igja0YjWtyja(U$kc$@{jkhLiBQ! zzHwLTSo6W1wL?l^Gi`t(wd6Sqd=b;t8Hok&68>Q~M1@>(pZWeI^Sm~GiLi53R?r7o zhZcr}EA(1c6QCi~ukBPgV7u-e4W$)JaoGDNmbe6>%%d-)YJIT`QTn1vswkca_{Pco zWv3URK`Cv0m9ajVm@qYe5wh$W+38Snqd_K=!T0^qs*@2RJarR2Fqbqb5xGV)*w3>= zOkJY(Er+bYYZvDCun?0=TSBxTR^>>elDh;qs*mu}g!Xkq3anhAl4AJO5SBm*-1UF;_ ztjGZIU7E{0I?K-D1v?#7IK~I{$OKcIV%%u6${Qba$+)aExfX~H#DV{msI3dA-n?imLK(FWlsK5%=+_m zhW%7A+KK=CD3b@2*t+aKsE6)Ix9%UcWQu=)~n0Xc(1*JIjZg7sNMAFTE=4snS zp9$@Dg(Y~hKly`FkE^op#hVm|VF3)qzz@aH=d_6mI;ksPT)3kmbM%?q&h|@eD%{Ef z=d4YW_wH70FePtG%Tlc%9xcZ5^?GzXH`j&stQl>|9#$3eu*+?T)}O0ATdN+t9Gc_# z-&l1E_`q6-vTKOd4GMtbNquS6K4y9Cxq^IA&MW3>WR0? zQafeO&^{p45H*%#X3yG3nNcq@YB8>s#+`{P zZN$5qRL7qEI0yw;9rk#nYk*QnMo>_fQ$gU@&36kNqs+l^8 zrJt@T9p|gl$Ql!q&T>18H5!VlURkynp!2=7FmI*v(i`>U_Kl(|o|b0?6RMtJwH4Pa zOF}a5W{Dca;=WcA;8JCg5?r%^mdacyHYOKF7~r*e95b1iNoaKmNs z&R&Njnl^(k8#@U*Vgi{{Bo4l~;;24zyU!qhGOU3$B26bCk8(*W&bwH~IZDJYtUWhS zTy5;))r%Btw~e?L>ric!WqUrss>iTiy2qzZJ}g>VMvAN^n>gOMB+e%G@*??z!2IC_ zZ;k^c`3z1c>&NOH?mI!F$E4DQ@QJuuie0-~6(A$*GecwLxK!qb8AW4{jH1m|7=y^F z#Jw?(C6&G;v(K)X!l+RfYfyz`em&x92jAjYa-#)h%$!XFolD^Kj=P5g8lF9SB;iU6 z7n&i8V4(r4dhXWxB^Ggw^4A09JT=sNI&a(F``{wA)!prg%oXG=_99|{e!>Q%(tjV1y{A`V_J$S z-p%yqBO;Jb*7T8SIIST3^q}%> z*SS8Po|?}RNGuA;5@a`uK2d3VD$QZ1!iK4SjhkpViK^0 zF(#=VD0_#hm-O#!vGU9=@1Leu?97JTkDZN*su?>A#mH~&ZSH25XP%Qk`7O;KwvA<> z2)d2v3vzp4TE)!iz?-mpF>Dh|_0<+yg0Ddh)71_l`BxyAYuX(Vw&5lku4d;CIDmeL zg0l+fUkSequKg*4LDu!1d1gDe=61rs6*AO!sxy-RDd-TR(ytJG)Px;bAGtx1!`wCE zC&VPZqbKZ3`+s%jfm~RpYXwdCn9duG%`Wd$Wg0=(e$Q}LQl_ne47*81HFfQH#^r;_Gw02K1b#+xt?yFyZ z;y^hTG3Ug4#^6Co)B~ zdG4~2k-=>Yp>|YlS_Nxqbw+T?ky9TBwJ|Tl_lKqFBlR6n$kS*_yaH|k)o*Ud_l`y8 znOhO_-bjP^gYuMTjXdu)yDzzFH4o?3olniFL|JI%^thJBtY)u$qN2^X8dzm?`wh-x z370}-njc(EpYk|MNUrLJEr#L+b1knBjm9!N6m*;r9ZcCpEidj;-=U8iTIg9e#9ny; z3}Tb&ttpd)TdcEXcc!!*b|_Mel5Yicyi`=$q%9vjfM;IZ?38=nC)}%JbEGp_YRO&i z&6+;7`3m%m@IqK~6r1l$E~Nq^A(M7~418i4uNW}0mr26a`MkEC38kLlkZpH#cL=>e<2`T9sWMC8tX#oKNDJjX3ZctjfyStg89AFrKTkkmH zdCq-5=Y7ue`{RB8V8;Cs_FntCuC>;+uK0e9AcZv$LiN7AJc*kHL8^WHCPgJakJ$pQ ze*j;b9@rO4xwC7GMPpWee6FvVP)9|#Ecd@ln@HOlIc*%mcdeU~c6foTSvh3yq5&BK zjihW`4pZm9Ab=*K%h2mkdw z0o$v?D&*b3R)CjylYe`$c$9Mi5c?MU0GmI#zrDqcXLlDUfb}dpaa03#v9^><74U&) zSr=rG3x1%W|C>mD9h@h1_mAg?{XPnPUj!y}czS=84saKLn>YHtx#7TLY&?M>?*-Kz zKg1cm|CcZP%RPl;9{`2lTd6>KH~3IvlMQ=K+-ScZ*l0EF0?op(l#@!!;5BeviD4f{ z*fCE#aKW~XZA~V-{nyj@ee92U+xz3Qhlq;~69znI7^r$*XogVVRcTR+7Vt@lI0xki-jU<}()Oc!K+%n;}pX9N|TMokDy0YuH2^j)Wb zfN)D-KlBg!MG(KL_%yf8z;8blXscFzfi*gvHAylRG`LQS-a~4TYY`2Agd5 zzWahAUvyBP!Dbc>n5s6+{w8B+eC9eW9iwKRHO*$!nCOFR(tN;CEm*kn%EdRK#B7As z=6a~5lKHBtX#&`#cKU=(7ck|O&gge-uD7Eby#)~GUsF0!--W z6Cx3uJ9vDiPxVIETjpYXF{%KJjc%Av^qkJdetW=DA6r>P|JaauP)BIrk$JR5e)ReC z?q=bP_lUKuKu6Ci)EJ=I6vP&JS)ybf?y> z382BbaY~|>??S(>H9uAJx_W7#wh4aP*TTjfLgjqoar(!VlAz-}zMQM9%Sc313|0nh zH;Z6cD@ExVT8p1EP$lK-q;|ahBKBS>{Rg=-%QVwk-QdFQAPm=>kvEq5BXX<6EZe|R zS0lnu!PbY7U-1#Fm5x=Kr}qe7{zi00i5xm^ zqC{VxV3t(WODiqpX6UMhCOHc%Q&;u*DyhhCIGQO zrEh89a^79k1pjRCI~}1e2UAeNr zyb`_u)PeH1d{rs|xfN?&{EIjI4)pUAmD(v*yLO>|G2wqAeb8yT^8JNGu>o0qXD%wsdmna&%AffrH^?3ig)e@T+hFxBI5>~s)dh$)&ZP&#RbywM(PfAevAK}@_w@j1{X$(F7c<$FsNswT7D z&-Yh+r1M`|@*u8X(jM_-5HDMQaGH$o6bqh-7R6nw{y?t}YYp?(>Q)ak=7ygPFF?P( zlE;GyglY8gZ#w&xsZ^NiT-R6K1yz9E$LZzY=Gtf-+(U*8YzjnXPmZLLQ`X7omFQ~NId_L$0Kye>!m6G@@xx34IR`ta~C#! zq|yC0<0Skd1hqXEZS*zw#TN`GefAX?-AS^jq(Vl>ToY~k>nxW2`s~nY-V1p0gZoB5 zY3GzBYn`9LkhrD2GkTMBSTqzCP*J2`vd9=zoI-cV4$Ff0)>+x`aK@2os?_m5&vIf8 z{^W*xRRq0z*ME*^Q!2ADdKc}o^%iHhBvOW%)Y0UT@U!7T$j^p%ia_rHHdnH4(%4Gy z_}L;H(fc9q)#v*$CslS|k}ev-o}z2gt&VK7=IYM8qV4b0&@8Da72Yz)^LKq5{&~i@ zlDN(~t!q4qBg@PXIpoXG5>W&=1%IhLdM%y6XvTVvHxiAiUeQRmC#c27}fCzi9 z3Zd{r#+COcALQv<&%nzRLOVgr*h41++ss*dmsrxVOz+~H*t^18mjvvu2#-4*4N*^@ z6+P1`IHM|Ra$L9pWU?HT2A!ZQIoAyftapgs+7IkUGp51bL>h%=G=;TH`+V))HgI_) zyJ!~pkPZCsGlmSNg(8{{3z0R-RpX5RFkz{yTX0TqytS9*InvbhLJ?~w7v4uiDguaO zY?B+(S-`{&6up$9ywHX+M0|1p-wG{M$pEQYBp!em7pv}#EO~Xvezz;6_sy~MihCz4 zjCQW*mU{13T!KGJZb`ztH)P7!Ad8cQj(SdJuJf2KE=&fs6fWd=CWNDaTvDY2elOmGMwb?q#2X=yT=IqAy8Kce?L+@W!TZ zi!}|X$nHwysPGmQY-z{sJL2Ky3RB=S9v;_Ip~q+CeTf=|kfAAaHX;h(Cw+p>U!Feq zz@Hmq=-XN&3e`3_>6hqa@}a2E!IXf;k(bD1=IR;~E)56i1&HR|`*57&inw{>tb+;sVkye9qIU zMWT+XS|@|$jAoN_jmX_-?Tm_0Ji!D9vG%KdHxVMak4-|ANW#6fi=xy3#9k5X`E62w zOwk>(-I0V*wZ{^p{QRw5=3a>A%T5~W&{0Yq8WP`Hk(1FX8BRdsVdY^u-$0M6!cNYU zfvkWtpwsO1;U*UVqlo1UL*In~^^9^rD5VYnENJ|#+Xnd#?S}PG{tmkcK-vJUq)EW( zx!dvwwkGO3I1>P$!X*!sw5kL`R$_Qcv~hp-N4KX!-j;az1`^Ov=oUFZ>AurzJ; z7QzDMj+KznJj5@O7|J}2n*wm7{JT>5uSOhUIVFxnqihiuqwk1nqEGq{h)&|xA+u7* zZOlO=fYnvhY~bFVV|b9f%`*6k(Rn)BxV$2hQfh zdPz%G=0&3EC^o?ctXR*b($}(urgfMd_%&>&9sFkt&Ij=9<;(y!&?AzS99hVx5Y%*2 zIDqK%*CTrB_)jJrS&b_5?+$ie#V>rM*FmwBG<-4!P;?2x7n}iPwK2OTZqP3l=wx^a zmGMGw^g!ieGl1`Dik%u%=1BmM7Fa*qp zk!Sm!SN_y;6ax?yT198h*%7VNs;`a(xaTbq7%(=-u`HU)6A{8Xj3oPrD<$+}_ zcB@xmVu_AZ2lpsy_8J#n*n=+XRppt$B~HJ`QC1BoX&+?g5yk8a{{`MJsS%H71EM*U zrmVMCQ}&>Vyv-V)8gMSx-YPwDsYT>xE-y~k2YF1>Z}}~>%pO0tWsjD81-$|ECh#df zbg_%Ng^%-PH4qh$C$J%=vzZ$^WqnznC|LO8PQc+O=6s4tSnO4FgkZd3YBolI9OZ$E!Uk9ZJ_GnKVo znHjL{Bdzs{E6#2siaS7%lfx#JVg04=nkVnR`88x|K=PBNd5zOnQ5uTb*u`^y zsOdh+jWE0PV-`2UWz{Rl!{Uv5)v76M6qz}n7zxBNG%KG}v57S&IUSfWsemtCXJGP{ z)c@F`FX*V?BCWMaGh^)@do6+0TqyDq=tHv5u5fyxFbFTJbv_T8k$Ef*ZEm?2+I%+0 za=1;=UE&Jvbtw&f(-&(FNqq$zCyBM39M!@t=G6VSMHcvyZ`C$_e294)xR4oZPHjSB z75SB>w6-N(@6zl$)!R{vuT;ilFv@laBbg$TrX*$o*&Eh#P`>com7)g23g_w+!-8FTkNP(0afmGRZHR@ke5Zp8G4e=zdqmt$Mx?}m?T^9Xb`PNR-mW*)d z=vropmzJSrTxsA?(PRThR-)wABQM2;*S1RPavoaQP=Gf#Ks4$66gqn^zN)H5WCOJI ziTu@jDRQTz1UJf-Daj_+MP?fw$2Sft?nb?OM*S$+T6Sc10`$7O2l~;_aq-^HXw?zh z_AokR7MgjbT{P%9IV^L{!*i^-Rt7oF?GksLxK;ka9?+Iyy(>RIXA~q*k|3fNx0x1m zJ9J92H@3R;97x_i&2~yruVQ~jYbp_pM$7^QnD{vf*T+0s-f;vJStVhuQ_UulB`#JuG(BgMZo`shdNvK%F{!7evY**1n1` z_9CPxaREwdPiC;X$vUIOI9h0Yvc*-0U7*8wQR}s2H`fa9=JZ+EL*&<> zYAw&%noIZQx(3b3y)*?1s~z(x%<3tz`z~hh?)x=8be=iZ>#V|RN{Vb#pwJ|gd72WF zRo12z>Qx5YUoY_+HCN)TY36gl1SYVn-#|5c`SHKF6sY0Ni}|CUQEE3UOJ6H_KNi;f@19;ZYsblkq{j^SdD8KH-wAjl91H)AOyZMGiJZqrhU_pIKl*xlyMf4|taQMNNQj(p{0zA9cW5E*tE6}F1JkK^;NF@&-%Rj|L@MWE>Q#}F%**tHX|_{F)g}U&-F)I zK3M4{oGg~qwCU+9v)^lY6vIcE`MPm&!OYATHzBz1qG8>vc`)4k*d^>Um;_^xViglv z9792tkwnZyG2!h(c#$s*XB9v+?fOvO20AS&1AT4(=qKi`Q}cW+ z?8vE-f!?dEf+^@54H4+$#)Cecmn=24?+QY?cl^!^VFe!~S_6p^A8>u$?hIrb$av1> zF#S`~6|@MMB|;)FXo}-az-`VxeYLidO#T3S4i41nc0UIet?6%|z|>!^^Ya6Ly$K3D zM$y1X8v{@SY}F4a2zP@>8Nvt-*KQ#)tUtpoA%`3>N&6@}pgN$3$ZR%@0R6;{%u+7)zJnYm-(24 zWJ|hc$YhfVyNw=nhR?G0(_Qa*J5#%??TSk16W4vPk;A+Xwc0!tu`(l6Hr1~fzNnU?`@*5e#U?15SDx;8RyF&i9C7pe>&b>BAbdx}qm3#jU40^PlBQ{mV@>y3j?R<>|s2Vz#0$P%^Z1VN*}C_r`uf@rubrv(-1( zkHk|(Qo(@(x-+~)>-3s1{R_00l-j8`2|IOFS;>v}=JX=;;q+xr4nQb&`W6+ub_CEz z=&l11WAUI{T@rN$)7#$-5Gc(3lb_%o{@)qKe?W8b=y*V4Y7}ear57mAhVRy89uK-H z1AIX7Dn%y7!jFSbzG7DZWtPDn{MQNbdxU>K4}R6t|Dm!MP{1AmP*JFXM-+RFGnu9l zv4B%ncgnn#PNmdW{(a5-t#We!N`f=UN(<>A{#W>VZAx9=K$F1gk@5?!8j;9*C|iB! z`_+x13z2q|5Mb8Y!%K8etJK=UXa~O^=#T)Q%t4xYf4>y`+ZyK#dcPVcUiCpB7W7=j z#!8PM)P5J>Fkd$?J6%7WO!n*#aNeK3CBK1Q$+Hdt2qYnP^;p&+3jci;%|1N~7lKzE zL(LjvrVn*48Gkxo96?1T!{gKf|q$-46KEG;858i|)1Ro0COOHr*NX`}u7^Gy6Z~ zmwSKZg9ZHwU@W=)DrrXzqxuc>sA~G;+K3t$scFaY&u{wHv~Zy=GMO1d58 zfmmr+$a!qt99FS%Z2_k7Givys$%59?z}^)U-_s3FSAlxZ9iZm%ANY+#7f<&utiK+t zTd?VYPXZ-NJp?X$KW?Z&cH?%xn1v5x-aCo#m#>Qy3%MbRKn3mP+m~~MQzX9k8QFgW z85i|q3^K+>uCWwfwic)2AWjYM5g>4*X*Gnmt^?KQ!?|@`l}A*{Iwa~glfKmV?_2|_ zGYM}(cEA855W?E-g#6j*I?jE>pThZGaeC^fTUMP z@+W|`qqMLpY7KQqE7M2xKvnvZQDEJUt``e1CPm;in?Rmoe-w)9#0_I!Ail3J0Mz49 z!s*a34Mq%D(huT+mJHPQ-m=p5^9@Ejf|hmr+hD9WukD|RmQrM=VU-p{rw^IMl8zPu zhbB30edUkNLA=+fB+N@71Gm0IaY9mgv>`$oh*=ix+(SzG3 zE*01(RdKTzJ_)EJ%G0a@oRyW>Q3D&Y@%@eZ;X(j4*dKCs!UBHC?_9_R1V3M0W&D{MIhRQuc1_&Vu&nrJ zCguA#(sL!YJ(sAAAOdpMj2~^9C%baRG-yPRQq=>w^Fq}#N9H7wY-G#trH&)SYy&T| zM>pM8-Ez<+tLC4BoNH@6D;j=TzKORfY-2v9JKs=6zV|WTaA7m(@QuNH;sKBA8VAR( zt;hFHF$n6@PvrZt#f&Q+$iZ@62|)$ya#UMlfz*bPIJB+_;ui-GfSsj%kx9I1R&Y+q zhF8E#^9x=)^T4RM8wPs-c^ z@5w{WELoz4SJc?bXZ8^91e;HWB_nux!~`iSNoE^wZxfQ&C52D6LS1EORnF_!iz^@+T|qya zeS?9*l8du%wQcOH*V5OG@WaV8DBB!G6TjB6&3g&3ZMk+vRe0*vw8zSqqqrf$~t^?Ku8 zausv*r+M}rH2|P@PH^Cx%RhjF{wfB?(YvbhI!Lgay6W2$jZV2_yHVyuD#O2xyF(MrF@|Qr8Efj!BC3Ja ziuj|-Z7nB{vm|W0*3}}0TTk%#=4BDgv&tNnU((|Aa^yVq`a5>Q$YFO-(>`;4!`Fw* z%Sy`F(<^L{A&o^8{DtN05zsZKp}MPB0(3J4J(^_7{=G~N1H1Pu(bdB#l4F`jUKK1k zyCiiAkSBI$jVxI4w_!+XC1OfQgs#lZQ%dvFw^f_2tW|Ybr2t53n_Ks zt{!?a62o-gjd@ma3betE$M#Qw9h9{luBC2S!2Sopd|NY*Yftez*80kw`TdPEFT(x% z${@23fy}O#K;1Ubw?4;>*KdIk25^8L0106<;1JXTj`yonER7nFQk`4J9;;z5J(m*? z+!aJ?e9!hK1tE?%A?Uz=x{LK^$j|YeVZc!%$plihC?MqQ33vvS`rX0u3{czlhBCL1 zB?EUo;!rsMJbO8Fe%I*S8-xvsZ+=CoWX*#EKV*z=wM541GriJGI9VF;I>m&d_NNOK zGmkXXxqZjJ^)QF$%B}SR_qA(y@m1jYE^nsj18rNqOE+k1lh>%-TIT%$B#S zHo9l!KtSjtFbz9Z!@TjhbX4s)mIDa-C=nf(>;<2@W=*0uuK2NuzWzNZ zJJ}UvPvBmM;rWpzZyG*1NWv(9k=fr9>k#14V{jDB=8SL~z4p#MZ=?1_@DnqOX@Ll z{d!X4vy69rce1Qr)3|W6lx1-_jCRR5TxFeAaF6+`F;e6FH%4)lHU4jJ(}(WIDgbWB zS0bzrZ%PT^&05q=7%dt~&Dr@glEeeMZa-y-VW*7{v^U1vQFq#&e>LS~3e~^6%hahD%>L{_5_$1um?}Jj zf%6Q$kfi+scm9!tcQ3{Rfg}IvhyBMU&pQ+x6Qt{Z;s}sM15Vj#wQJY1p>2{a$2b=E z)VtZP$9uQl%UUiU;l>-<+dd>3hYzy@4z7Ye;&KsMZ+10zCN{G|9hGyA0mQ&CzmNkt zWrhDS4X>880wKrNO;5nB3v9b|qT!vy8MC zHO?0G-`e_r$IAXY#>xT@gfAEjz}JBR^3#PgTY>(#hR;SNmAP$-1`l7=;W&!lRrXbpT_VRb%;H-=O|2A%>-k(71tK}j7 zy62(!2D}~d?RpCFUJp9Hc$RaSt3A~MUgfmh|5Jivjn|%H1I`CQS?3E#uhH#)s^ z0Ouys;RS8oPp-a;CVhwmthGV?&s41K?&f&f%ZKnG3_fz@!hC4>|PaZ4Ya zlpQASzrhrxc{RE4<#Tc8y*!Hll{vk48mCZS?ey5K&_)fVtc-*EOe%zKI*&>m`)K$R z`bkZ3KDF#7lh6IZvh**%d6?dM_1p2+y+yfax7Ui44BIz9JYXwGK9kdhrGdAG0Y^N8 zZTMccZZ8n@l)$FOW&WT86#I-;o1OB_S)m+9XCJxTRILBWA=iwgFhyTbx2)~|+Y#D;*5dONfQE6fe z>Z-hF*bEA(%?{s^zaW}50RQ_hnbG_Ava*A=V`VoW>qpavaqH_17!~C9^?sW?_CV!y z;mI46RMd~~5^o(6u@&Tpo$?j_TVCtm`@RyGpM!6MbBsU(YaU6x4L@6iCBfi?ed&(B z#Upp4*N|JG!l_FBt+i;h8 zT{||hh|BeSmws{HxAN&C$-IIs8k!IP>-VtLwQzUeKeKDEMaundj`(N9Qn1+*f zq*s9e>$>1>er>hZ6lany)gG!lk#k|A_qk)Aujbe|xug@`xI;vL@DJ zR6ZE$oX?u@)3bz?7>i>%t~ic->Q2Th0X@s3X@=JW>*W1TAR_&pLl!bi2GFQ30f*r4 zG^!w1XDtRGdun&0B&PO*1lQu!jw6aQI{Yh*-~k+c1HLAe4gbNAtN)~GCPb?GezIrA z*4mnz<$;=L;O>pZ504k6^z+s9?-3Q<*Eh#ptlMWEg!~}U1?m8%k3y}Ieh~2X5Sc6d zLg*Xz^WA}_qEqozeIS|p)4j485kA!+0z2@{yxVF53{2J&W_l!diN{;zR`4=UgXAN` zdtkf=HMGqO>x{vMPr1k0%EyAjdle<|N~}*y2(+=1FY@ysflc^l0pS(BFYo?&kV-!f zDf!IlK<@gNTv<^wlzf)SNuF2YO?$AAd1F{PD?>pmOAT+vufp#ApU8^p6vB(UL?BG` z>C?P)#uCi&*5f-W@-q_PYC@t88OGJRf7b~9W?cH_FZ(-tw%EEPy6r{w0oLT7bd+yi zpU5b2^fZ!17dw5vt?F@-aYA13JW=FCPjfUD?jgck?W@8O7E>_3%$V zkLYRTylvtzeYmu6m|ijZS7RDH=jzf5Q~|fH0wNwRp6+-EKgAkd5no3}kI^uB3@Byl zDOK59^NyUd!cs}1@?;8~zf>=ISK$2VIRrj+=jSrg%NQ*n5WfVhr~0}@r+?QN9aKMU z(z9*rSMZ^8@n&bCN=WO-MgwyoX!}%jTxtBlP$e0q$bxW~BbzvKQHzm^0ZWFQX?;Mq^ z{%yONn=Wy~MApb)>wz=!O2GOM_I&-&HxQYua4?9oo@l+DCm%JJ?9HXVK&uVJR((dfb&>k!C^`HD|IVG*r9>Exzgz8nQK_E z>;HEPP`JipKL%!2Po+z-i;6dJll>Q8Fmx$GqJUC3zn^%>Pc2&t{;u*65R`FfV=a0L z{4sFD%qPd@+@t9!d_kfKCv_+jRzZ&0di^cb)EiHV%QnM8&>G-5?vS5$*QVE#9uC-y7Yfh( z82@d^`q(&-vP=Z@bZTu-OIfWO)$Kef8`~0;&>NROj+Jd;aQDOcG#{Fz3tImklKu5( zE)3&eqIUckI3#O~rTqg@rHAO-;QtHpR&G4TmRrZH1zA?Lwg6%Tm}_A?B=L6CdZN>C za?{h^e-O>lgdtWT<8cbjQA>xcUeigo8#a$7h2`X|0VGcariT=_b-9yq|2RVy!C8OQ^D2lo`-LMUl-FcxN(s*GleZarW(VMDRWi%6%d2{#4^+4>Duc?dtzRYoBiNqXX?_wH+09Wq!N1AYkml;Ls?>f~^TwB~+ z^_#46h2PQaqy}Um;kY##Ud7M?d)EKqhyRr=oe*;nn&6VTy@TyP-%`aeUTz{eFL^Y$ z#*LeiWPQMN)|zD;+fAiZMc>tgC%Mefr1$>yYeF^9J>1?~#E%&6UR$74`7JN49@_mR zRSt_e>{XzhaNdbi?DG7IjdK0x3H0lP`jPiwY8X>GAP=)kcdvzq*Da5i$-FJQ17N%S z&_7UV0cZo4=~Zp79UsH^60g9sc{Umqg?@OUXqAmb$F5gbmD{{6ob;WSOXfce@ZkWz z*W_a-lPdGEOVF<-l2d8s%;pRkmeEF|V=nngbsfi|%B&kEuD!wY@66Y=KKyBoq?<-+ z*uj1Q=k4z*r=AQM>4uq_@8*cV_KY54vKg=bLkj?^A<5{crR?TQz}<>n#(D5YrJaf z%2N}zd^zhU-`_bLFx?c=Tw&q`08ID02=JKn`XHr`%ltccU;*<!0&zqd-0I6?uJ}dIz*(;5$??*cJNxwO!*QeESF?;Fo_Pm6EvrysM&?peZ19OX6v@4Zz4s1^4l1GlRX~e?yY0Pf=sI^+wPBN{5mV9m-447=D|09 zziITn>k&A%FNYF-*aCgOC97+dO`d!Cv8)OOmBB5 zP~W#{QM;V>ba@)Ss$~HP^ru>)Q~z%6Ee_iVEA#r=d2BOK263~U76A)BcxQqD-yZ$x z&zkG{T`XBoiBC*(rY}JSk zqcjIU2VDYh+MK^VkkKcaC+w*EMZJ;FO{AUBRYf%Bx(3;WN8!!!xDkuBz6eV60~cCf z1iE{6@R-2zgEu7LYCj@2a8kid65$MV9bC5{f zYxi{9O#hAEiLnQO^d+Gt%x!bFwW2MzH-d0B$@Vt`tClzN6UWM+%3is22kKf`jcL5= z8MAoiInH-*Y0zda!aAdLqPKWQzdEZ@VXB}p{<k~!66yktwm@NXf! zgNYvzi{k7~`^etNMZatTQ_wX=o=q@OV&&_z<08pA2ykAVDLucb&vUpEbom~L2?=WE zoy?Q#tiPPrub!g(DquI2=u5=E^sZEKlL`gdRwsbqG~I6p$5{x%7zlJ zn^E2@;Y~VQ7L<=M$hm`RG^sTZ-)(bmy-zMTN4zVcY9UK0qh8Ds-F zSD&G)+BPO}+YK>31k#~NbQU8LtJhN3Zxg&G1_YikEfxo*|$D<^lG_X>QyU_CQm~%_Tn0t z{u0Wr{#6i+ z>-;c#6h>%PhDpk-P?=G3yX5Sy(529ARgs2tm~TWGI>pl z|G=8p+H%e}jm=kh;tOLlz4xpsZxa7UoB(7@QzeCWoq0FOdhMl(W4S_|7DfWi2%U`Z zwiZ&M3xa%w5^dqhd@aw(Ji`3@^=z z;Dm}8k;A0W@Oe3j*%t3wvDKJebdnx&f=aG7pg?_KRd`EzbCZpIvzA*T@3T#xtE`Q$oO8pgPlTcCYjqvwfdB9%2zI2D61C3f?! z;kdb5`!{5Qt6%H05tbRW@KWm;-}gOJjAQ6NyM1n5{sIczK!9SeU)Jv3x1+L@i1dAU ze6!M7WcQMb&m#%}Infn!e{eS3(`D zo{~4a)@MchD*(g=i6A*TrcnY+Vf(a%DeP=FkYqdujnS}-CN*rzoNAGMoy}LiWA0%j z$hXqW$+b9v0|IMhhUE%ko=lPS*rX3~72vPA3E9wAx?ir8%y^v?=Ppd_izl2#eoUR> zwNvYa)=r(}tewVlu|p#(BqXp*bzX)CgiidOd5!+Dg8ZZHA_ao7H*c#*cQmZDfp|M= zg&?ok&`#N9eu-4wKimdW_s+BPufbo3ydB)?h zhZ9GluRuX*6&hvxt0y18u@8#qfpQvBa*Ib zMK|>EHcH9UIqP#(7iOi_dtvz|&1%lnJP9|nHe(A`>+39^U^XxjMD(K{)E%gkqD z%75u2(YQR*pxu5A?YQQEb49Ub{xx`^ej&r$^}F2??N1aO%QD!xP8Kf2(+F2ly3MbQ@&7A#EE*L3G(KB4-maJOPdFmco1*ESn|eK zsBSH&x$q~Ab85f(d@LPn62j4~6mEKdN*p0|^6sf=UoE|O7{}&N`bj(jn%HYA_{2IW zR(;%{Hh@t_FUBTY*m}~s*$e;q6w$!I4n?4>5kQnpo^q8rtO0=!b(X%Nm0_H6u+%Z_?X8ty<5J{3CWv-lJ{S~=Y+^2>+( z^i*97w_#EN0si92XE#`v-Q5^Wi>^$*4UP}ub2xK0p~I8yz_TiFeAX{~Mla@U4w{O% z8uH=#s1B8Z6#kiu54@@7IF#q*7$wxnd0S-F*wFs^D7EGqT8WxIm{<<<*%X&Aqgsj1 z%7eROZ;%A|tp#M)7YTdZWJz=0y}Zg=_W?$Tgh|W#*BiP9%~_a`+2r%>eEHy+A{DY_ z9v-_DY0otGwo;OhWM=p>fVLNbUY}!4m}h!Rkk?!=6T-7(ajiAbQXrqS;ANMqycWXh zp!*zjyMdJ%={zMxv&mOt zpvYZfougCR?LZjM?5HXjbHOb9PB*jbV(Q6UfQrCm)$sW_e@rNlD}irVRsfMv5AFP% zIie7MZ|58xem*_TA(x=Ek9`}Ijf!hKH{Ed==qZ^kneF)UuZ=T`5E$dWreHxz&4kz!++`@0j)$|F{@$`rW>t5YtotCR@EAQb3l$H(Bi{T%G0Kn;NUq<%M{t16r7#ekbBQB)W1U3Tl%m=*Mod!swx zC5;Zs5~!K;QBrZBdRu%VETzaQ!+nOG7R=jy(bANG;Z|#iV=d#PLNau`MILv6{)W3i z#?fG4sn(`A9qMK`pZjceRSKUibwv5~kzLy7k|$3-!NiC6)d*I0+5sta2g=H2PwDBZ zu-%TPt8{D(eND#^ae3~Y#2fNPi?<&PHWM2fY9)GQDUfcTwJH&@4i0YVt35h8dR*~t z9`-}6noQ6_2`Ny7GD+asDjwb|i`Wa3H|COFX=#p`oY}QUGh^>WTts^4c}ftKbeB?B*A&K*baB?ndn%^RPGr3mCLb%3tT91={{?! zv|q$^LhY}gK+Ygik~LYk+%@UwRBPmL{V;`Hj#Y0gLf}0_J|V+zlzj$3+3b65T-uG= ziX{D5o?GV)*#V*GPC{6L@}vHD`yJ>Tis;!{*Fn2&Q3~?4%Gwg=uM#nN`3Ci-c#Gcb z0Me|Q0Jky8tN8IMs;;dLoKAp4x^mgS&P)xjGkKzL{aK!$f4a_u>hVF~ym~Z~Lb=_Q zB#|$>chbb&o__9fF!Y6d62!}R^FUET*X@rB#a8DU+#OVD3)krguwnT9uNDsGHS~!OZvAiC2ojv4pBv%Brvt#v~&|{ zhB6$@-$xDm6n7Ir!nRy)>0Q-(>{{kf7Ry{9*+yS#-XZv!Qi~%1=NdZ+2$|&P#Z!wa zYXp;^<23><4CG^&uGEA=_AT9M_U`NRm?yl~u&-FAN0F$)5qn+la&&^oK)1XA@ z;^M_&<(rEo6=i`s&s^rqK(l-eGP00>P29rW@X+w$Eyk}LQE1D^OCD7D#~B6 z6aghpx9*z~JZN{M6Hs-+vh}l}ALr{s&9Aw$A?t#iRrSn*TKg2tK8(~C3XZ;vFbRY* z8lbX9@~GZu2rTRaFoF3>{$A|R<#U(=X3f1HG){ogIMVV`~(j~s7S?Lo>ZszV(s`$NEpo@VD zq@lr*JfFOYX97VLfAntso8Lq}R!5F7Rdi|=T1`Jf3)$kc@pLiz3%kX9?SNoM3!MzH zumV11iuTCHlQ%Ky_lfgdl1E_ds9|3Yq$3NYq4siB1Hj@MHC&!}0;=|&|$WlVG zGrmI?vgZ`|(%y$>rcWx&!b7F8u~{Fz1sO(0rY!r!uXno`7k+lbEtv1PZgT5gjz6hHu%yS=nDgIP5~8UpAVXmzRI(pU=0Oek7{e9it?QCK7xM126-L!1Kw$@R_XOvnotoeWkcWLK;; z&>2IWM>RdUCOn})>QrnhqJZz()qKfS=C<$4r=WX-vN9Q%izqwa5+nSklbGc7O7E4U z=-4y7rgfHqivN5dFzyQXaG$LU5lNy^AWz~76*Z8v@lh`xT2}9yTA9>i-)7buX^^#} z=MG+H4yGd>sX|Gwv3N9LiO}#z_168_4Y^a77DGz>;t)KYoRarBTK(9+Fa-(e1Ba6* zFdmeE?<|GSnExx7GBjG+GoxlseWzB9|Nmp}z2lnN-*nMnS3m^mAW@1SReF`!Xd+#D zlMY5eq)Uk+y$T3OjWiJ{0SO>A(xpZ^gc_>!gc1UTc$UAJJ>%YYX70ZC&YAN$bI)H1 zab=nHUF-Y2?RlTKbox`yr=Y&gR({T)o>u-KRfcs%rV&^FJ2_3omQIGO0%n`Gm33pn zjDndl{MnDAyF+gz7|J}odDY(%D-xO|K131=kBl{d~`+-43%ql&ik&kKov2$)(?HLFyb8-kj} z3D*|QA7m+kJpXtQe!coKwI_ZAmj8zRn_f@ne6_;y>*Yf6@h+D&oJ2(p!uG?1H!jOv z(!I>JX{EP1!)+_IZ36|2mB< zexfN5ghMqSo0{PWaW>h+bggPT+(uhSz1MzJQCV9wgQc4NY|@GGBh@XcmF0Zj$4 zjsk*{u*+jUxdAhw{;UMd(-%-#Q)uShIgG%{QcaBN_)P}gcvlT7g) zA(i%*CW#-QmY19s76kG6F;$uF?6$bxlNf<-XRq1cRi=J~ZyQqNB^F`{0B?5f3$H3~5I7oH(xYE%2w`o@grnZKRZ3T>45UQc}_?=LefS?%RSg5=(TI-g*1 zK?rkniM=sxMM)$`o&P!uGLHOILE|ytf^bE4oAy5FM6MOi8zziTaXT7^J`itzz-SR% zD3BE#tJRF6_GgiA%S6|;|mG}cTRP@Z%cGuqa zFGR8}#RSvHNxUh3cc7jlCeC;3wjKAgbHD0>G--`jesFMoSCk`u^vP{a;ow9KDM0ef zx_aOa6(Nno0nF;^+I`oe`Kf|7*T?O>cw^CRfv0MTujYT0D44tKb6gnkATGv!?~Zc5 zr&E)xh8-~_E>nXBjjS>%`r!dbCmf% zFri@|EuW;BQCD?MVtSt@SL;l~C7cb|xkmdbA}Hax6>|x03%eCYzD|^XNKC}@cOOBG z@sDU%!>cuWYM)9!0rW(Xx4phvg@2NW$4>3biwaq!V_~5fNT^4qW`C)Rq=ptSo~EF1 z`ewo=zKi~JSIGVQr^4rAbY+^Z(c3PcoK-BXwxmRAoH+=7fevxb?yoeus6X3K7Gs_G zJmb=5XS#MR8n)aZ{a05xWIvg&zIghbBs=VSeMlvw-swFhWqOL(q}qq+r)j6t?(N98 z@upuSbc-n$ulE_jb1D&FBPTLnIZiXlS-8B^{AseI5Mkpp()vzKZG5=Vo7wQul%!X* zMW>&`ZxDN)SV+_i*F&*FcLcAE_*-G8MAU{JLL6hTgW+nzxyt;X`MxB3L@(qrw~)TI zD5heCWJqW5bv~AQwQwEWXq<5CNNvN~dExmH%+1fH7Fw_B&H2w(xxXxBopiRW$Oi?; z{$X^DvN&J28p_?r`RHx~x>+-hk$&CYwV{xAE<#k#zxH`8pX{-4%aV}?)LDR<8^tX|Z9ItraX{%uY|n~v|b=y_Hu{@LfxwZ!KFU`<)qiHUQR zS`p&6$RXGo>jC8P83A7edFI7PnZ?x9t(ed1JOgc4hG_X5$2sLY4>D+1AQw>!xxWay zW$kcaTuK!GVj!#Ip<$?^4CztzkdaD>Px5Mu{^uQNJ@^@MaNpD)=@<^&$OxHJySceW-ep~*Z6Ucc;%b`YXY5; zSw)36#^EdMyqRM^B<6nYawKn@=LmAJ`wH$$Vk+x^;P z9a*w?rce@Ic-{lQ6LBs_k2p?@%G2QXQv{y`Z-weGhqrkvlB3O z%oI;!DmS%6uPiN_9>@Lu$+I1Aab>nimZw=%pI&l2xeUL!R@+xHw-uJ=5FxCwiNN+k zTTL{2yZMXEqSTG=U;c8NUSVYE20I@}{w9dZu%Tmqjq)dmr8)>n>Ex-len&2|$5g2_ z>%nI14O`x&tcO9@Zufk;*Vtw&v~BhR9Ijwm>?3Y8P$^3SI?m5mM=0DY6l%UlPOi1V zDBP5|Yv!wqcBXPE{_)U?yg7T0pcN9SNCyQZgMj)-9OY|w7`ASIaqF>z*d%6B)i+^# zDx*T4Zm0ErUPW;aND*iU{RiFQPsO~y@!3h5R?Hqe#9z9|2d;ezv%;7-{eY=3c3LXF zJd3DjA&#YajiCuAi53u}=3)v6%Tb<=Ty^zz$;Q$cEx9nqr!2-A#oEi=THLt|+I_~| zqI<3!K2M3=U#B^7FW=nv?3LT>@bj=lHuu}kmrWVXP$}3s&2534UhKA-OtrX_tr~1! z`D|=4@zJ!%_f`97WrDD zoG?rZl~xtnv{3_}^pETv8Z8-(B(bD=j^I;T*`8*~w;fKHuk_vFa8An4EQPaoXV}Z8 z4KKyiO1J>D2wu#`zpm2BK2pG(v$__eB0a9?23|T`y~EMiJHJ?TeuL>V!e?VR1vINr z!tTqqPs4j0oDM8#Lg^g>z$qG3zA~iM2wdr{2@;StGt~d! z2q>CaG7=M8m+yaB6PWN-8dZWUUcxKvHh*Z+2UGMw&y^NtYP4R|8K*8#?7`%xLu1e99 zbF?a)ZbO1m}1 z6B@VQY;nq5MH`vkh9PHvFrDeM)p#IuW@<|~=S}_l2Nh~|F7ezYsdwV-Y?f?ei#os; zJVeg=XzAg@n%5xfTJ(+v+cOirr(#jA3FB07hep9(!no6Hbf`bVK9RogGZMrx=0?y*^)c;NH4 z)<0<%-i&6q@F~!!|L(8u%#k`bL-nodMMH-1cCtoI#z#h-<~< zoev%CiRz2rSZwvgJFe6KZ2~?yy4}?-d)r%Jk|NHsMg0m*EBvAa(O-IwqIkvN4JCLI z=2$Ydpkrn_#fg7ja#FP|?_$66&fB0=#ACqovzpE1A5oWQ9Ini?UJ;b2JLN`OTOZGH zJ5>3U3abiRx4;{(_UI;~mj#r~mNb%ll^;p&_{+NHer@MsljnFuoepiEM=Ka;)2Rw7 z8fM)c_$DS6AJ6*oc`YmqxfPWqiZ{rfH;OI5{sdi|0VdRAZ1d829jziwi&AI%I`f?J z9N#`*(cwIOA{DHl(n?8z!+;rl1XPnPiOT*u*jrxPm;qzfM`aCz+4tkG3afMLwaK{h zT?y|O)vlXROvPbVg-*6l%iZVKe`fx-*5<_CeM|ch$Abt`Zh5x*P9l%+fpL zfDsPV*m2hc)2VyfmDNElT9rCHi+ddjG)WieD84REMdm%B{MPldY`0t1iyvbYB0%9Xrii7CRZlHtk?q+#;&Ji<;+CwqLEt%p^TomJB*_8@?cm-WX3_i+_k zsF?XYO(FiCe}CDlW#t5hMDPs{Wo5gwAF1+UDXu)(qo>$0Y-u>z`r>SdB?Aq;Pr`hw z;bBdPodg)?+_4F1+xkew75OzX9 z%~UhQE~zp`Psgd}nPkHk-%v_telR!Vu_A?+0c^rADLQZ+Y(6u6` zoCMr(46OANQVCpL*~q9tSZpInoN|16=MHF52OjWeS zx-NP_j2byco!EJgXAjD_Pj*PEH)v_3fy3lV7BJzs>KF+MEBEJ1z~qIA-o)D#s(}tZ zc{^6qtTZFL0sJExcRmb@AeB=q1Lq{vf=dZ{`gl*8)%ekHB8FFkhP-3F%i_ds#!wgGh zb{JIPDFDBmmUg9F_10TeWW_IzrHvXkSn*80m{4X^hj!wmqw;SP`kpbgn8p+lJl@k3 z{9EMKf9tdVfJHv~|39u;sbMSm49+3OraG?ekRcf>e!SPR}k9Geb(~)iXwa3 zhW_;csb)&4q(_VQ-sw<+G;hnc8PpRsaMju)6Rj!aO9R4^hHoB@iWz+!K(@pSh~721 zGmh?&ZCV(8-+!C_!yxh~Y9io)pHW^t(wLDYmA;d%JMJb|d=d@P1r$B&@c!c#GJxw* zR@buL=b%kzDOaRPeZOFgt5kbjPk0#**yi)5*_C4DXBv+Gq)2yKw|3BDZDoBE;cWdS zUF$9S={9HG9dTt`Di=3htWEs)&$Ao$7SVU=lt5O*uoezByh{cK5{WF*uy1-NTV4N3 zm~#^42}~HX^^*DaX_^J+5bV$7ECdhsFuldf7P`Y@s!9iz#K&wu1=Ycj(ruzxp(?z|8Z$|ra&&B=vZ-6xsHx4@nI z^Ud27zU;|f3Keu}+yLA37fq9*`X>N)RZDgw3OHudgx})(+s{V=-+HcN0 zc=Kc{j3kMjIlmH-Lo`*d?B6Ahr#i1=RXM4a1t+*bc5&Q*LJ`y@612014xXX(q;{@X zgJG#T=k4P1qB30S{8Xs?*pG~Hc7{aUEUX#=DEljTFA6rtx%?eseK z1G#1PawP>Vs&Cw}xAl`k_21&|oDHH@1JaCDDL7ZGRK&-L)kBrn$aMKS=^dek$S(Su zko{~5H#Yaq+tgmH`(X-@xe_T2w$?+tQba-59Y4Qsx^s*I-*~vG@?H-{0_0^`7?lrK z)AWT!g?%M_gE8RTBYpg}<6nJSsNK$69D@D}kYeHVkekO}HgBSTaGjDX_ zEA#h*$h{#pLsO%Woz^nR_O_je589H&h?CmlbWAd*5{HI?`zGfMR{b(8RM{QH$20Ua zHR=q0I|anG7L0fJhO_NXPLwV(bX#TtWB5X_3v24c zb=BciBq04v!x2XCB~my`cmuL5R3uUQ@|(7v{l!U1Zx=4zmp#Q1T8Fqy6lfw#2C=y^ z{YW zq2Bf1B~1~R8=x4B$(TutjSZlsW17=Juk-Q#{W~H0Ixp!dUY-T%1N{KYYr;O%e=&kf=Z|iHzfo=*!RRIWC zg*LUM!}#1&4bxY9xaN_3Q?Of9R;=jMw>SH4I~-y@3_m2aPjh`(S=MH)sSKYHRHkq% z_wT=a(}vmKWz6V=q0F~4xnKHsqZ*XaMs?*u&|`%MRiiX?$nSz*&);`ZUbskq2BW-A zwOhi<2l{0A!UYcFDXtzVH*0v?7*MAgxy0i2**8An;wWalF7zUEY%uUVoVYO?y2K(> zYGZ7b^gvFAs<`+jutCqwzBW^QjYDAJA(cua0J><1sH0+xxw0^Q%y}{3P^0lj^akXs zU(SA!oCcJGXCnNsU`1Fk9X4|ttZ3t)vpbyX-{{=fGo?S^yl$>{2iW+`gLwIY&Z)&6|yQX4-wiZ=?8WLFp}wN z^QVtrFMz(De92#xn7bl(Ci6^Ow(*?;BU9?~UqVSOtce+%S1;spXPy27^ZyVc{D1Ju z9C!`a5lS*Gkcyit;U6?Ev%l!PbyYpQT4c&O;S9@_STX+UP=6LI`1FLj1NI6dS~(P3 z-)S&#>UFQbjej~wcxFRNtHA^Lk4BO|I<5Zk`UoL01<|>ADc z?#+QNU>0EvKmXhY8c}?9jgX% zZ~MzSGKj1TdaB>pozld}cdNe@?f!-5_pO48i_bk)(N-1+{~kz3x{0gGXDr9gTyRSB z-1}k|>LbLdcFlmK@byH#hKh;r>!e%5$6Z7yt}BjuwNP)2`aR* zvdT9yb$Vfr^zt%r%e#5!f}ooc{SyXSVAttugFufdUTLq$DaOAaQ_26~b<=4?hsibw z!xPFQChC3d(&^KwQ5>?|iJiZ;%Wp}-6A3G|FP_jnS%K2csMFovr=Wk(uiVqg;_t#& zynZ+FrBVFQDD5097;?Ve`4h~Uh;A~uj{0~2rNMYoOJP=FLg)FjY|AU0Psg5lCKdqg z=X*~hjhksg2P;w!>zU)iUhN9e=tdaRt#&-Dtjg4HXFksvV{}4>OChph4iW;iz}8y^ zN?HN~ZPiCa?uPe_dOhpqQmzlS_vCrQ;gu0RJAc*CQwe0Rmg6{p!5DR0xe!LEFHhXJ zzWYW2yqYG-bxsd_!_0Xj(yIipXnRpBFBYR&V=|(nab~R*_j}pS-qmT+)gQV!A4I_( z+5cO|2+(nJ!_=uZAhUOBF~3;b%(AxAjgzye=SLJPGWv{Ju)emYs`lC20^rt0YFtWm z3APe%ALCH0Un!htW82g zI51?3p5j4(v}SExC91YE5}Ho8G;OHlbWKdSL$sVlx9gc(bSw=6M>cT16()*BHmBE& zG^?fB9h2`-{kTwEzIoQ>W$PK{tC!J$3ICTq#@}n6f9E~JS19U2ATtrJ@9f?hNH^fA zp{$X;Df~1$&G-b>RWA>o=2>Uz%rB>G7-sr!Q{PvOw~n~W31YbYi>IXb@z3l0kKU|b zn-%j2kYF)sbD->&Sbvk#{s9`n>Xdiw_UVQA@u}JhQJVxcz&AMyxA5m7ntp^j+K zF()S0qAp%&q2OaD6RU4Y?>_vq|I8wZ(I{}*XR11UODI?Kb^I(SE_NeQ;UZWzn(R;upKh z&&K=A?1UOGciNWY`wQBDC)%*lm+uOtH}rnR=$ph~`}Y31+v{hH_;_8kSY$r7$K?qZ zV-Cl@N~*jxm7ICo==;j(OL%9@#7Vj)_vi$-w7)#y7xFzntc(thPg!8v+@#<)vc?yb zyn^gwuwu%pJnS7@*>+8so(AUEK?|l^9NsRYB?dhWqyU8g31f0bOO`Hn#Y&TKo82k{Z;!c^T}v^a{;vhZ?Cw&;)S%Lh!jE7s@^r!{Uflk zjs-1FYuPpK~cNH!_7lxV#FtSU#Td zQzT$UXEfrxa~d^>4o5n$<~Dp|`s@r*83wT9kP2Hf$JM>#i+8t=(bvQ?-yOBEpeqCH zjO-7t3sC6mw6!E_Sd40E7%Cr$n<1oI=ACN4IJqmyKOL#tQ~#o)#IVTmM#n+}-`$Ft zdELmk7RZHLM2}0dX;t{@uB%BMq=+id2u5~2FIM5&L!co>%gd~E+5a4HdDB13uU*Eh@pG<;m~+vO6tCLKv17X&)O8g0gQ7I-J0P)@ z-~}hy$NlA0F!HsF0HqOfZ68>jcm+#h1d(O|J_RPL{#)Nh-BUb6OgF{HHBVh`4dkjh zvdHx&z@^M1IJShBj7Aq@^pzQI&jwb`x`9tvZ`V@NhQBp&!ASd!Hx_!`DvpYNRukf& z9BrFsI4r#4;aal>a6bHX_FgFk`jAWp_TnN@NME9jd(o|!V8;lGF}qknW6BST1`DTC zKcD|5)1eJ2eNAT0B^QX03Sj#h0n3MluW}!}#ebP`xY4E|tfsC~RS0Jzf41QtDrCQY zIqPO(l>e$+$*$v3Z6s+Q4_;5q^U=FskH1(qpY3d8!;+vP=k4;6f{yZV&=TkM3+wj}r7M_baz* zAZuqF(Cg&Oca@6G-gla3_$J@Kh8deFOUHZ zibkNYE=8ZR=}Qys+A%HL9#B+AEiGUrp$!K8h2W6#Qc5ao&9?S>--5tyvNY}%U<&eviAE@xI z_lnud-^z#IMzoIc`@jqP@;{ji7;gdz6_=vIy(D@aOBugxC5D9=_IiyHqCD26dJb~Y z>^M;ec_f`A+?wc`>Q3*x{jIowzx6r9WbNzY#c+fc4fVo+A!cJjFP$Ya=7M<1t(Xaxa(!e5frj)7}l zh#N5x*oKl>ZRWj?so)2^PCr5IK*{VcC!D*qWVy2q_D3CD`sT_jU_r%KeBt*`PdoSAE`*p#2W%lj2Z*?0= zr!l)@#}{)NX@N`O1GS=l3@<65L< zmL(^?LY0#qKS3)Cuy4Om22WJaUjtu20ZUIbkisBSM(DWmmYpJH-5m(|t^)2ayT(Av zGU(f?i0`X-nK&+^3uVFl?(J@RQI(~c8@UrUBi|~#qYGpouqQ-km^7@99PUj5#fpQS zi9ZQ1{IiU~f6Zrz#K7x;LzcMGde=f9X)-N*gJ%6Fh|SeP!ZoJ zzF|?SlVDxGXU@(oo0qiVf$aB$)+$U6O1I1we)^L7yshr(blt92`CbwO7JP~<1#6xo zj{{a$nxH~Z0p2XmZ*VwZ6wrqb98-;!jqVfda5hEYs7x|ou7RUa1QZ6cIKNJ|94Da; z9j5KXJD2NiLb})ZC3CN18Y8@+Mdp8Ne*?GVUw8QBw!d!izv2UQSA5nr zq7ubS66#k_nUjZCu1dq4cI?LDH@7$KVgikZA6Y%Rcy`U=EG_Q3h`{@D<_M(^8@b*M z>8mL=m~US?DBl@)IJT?h2k365C?913+7gCsV57f}06!quNXY8*pP(*pfCw@=4g2(t zRiaRfEr#$s(7B_+~5F`SUW>eYu}zl;i3?BG52ea;Kb1h=e-& z(psY&?G3cX8$5?QY`~CUp$(XwyrrMTQ6jm_ z-k*$Kwb(7_%W-gWg+mG% zv2-*&<%4kN?yz?BgB9INPos-Qy7YV14MM$N2p7bsV{h&?#8tyZ1n1^I874F&*Ufw$ z+We6di9?$mNRt`hF@UdF2*cXI?16cM@_>hK;Q`sbxD6TJ+JTJtqr7JdB=?e@!EqG| zfNeaFA_>4ZM2H_z%cGs?WS;|IdXJPZY?bWVurV6Bxi;B{A4?;*tleB(47Rz#TCcyp zH+?)3m%U)e{O}god7ZbnS622(X5NP9hmrL?z^v4ZkIHC8z;VH_X4*nna|d8Bz4swc z0F~&Upj(N671lXI@1+NB&9_4r(8OlQ&N0H$nDh{*1z_u`Cl=I4t^^n%Eb$2N#(*H{ zVZw25XU2NRG;?1-5pev!T2#qH>2ubfiB@^#KA5-pJ-e?HG>oBn7Ih_uTv!=6bxang z2z79Qfz-5`v|oc{!XTmtGJg)3@kwX^MDFi|{cr+$j86bu z+pGS|L+-AJOgcY7Uu^(p_gRt^_<%f$C?>N2cMw>@bIQQ2yf9BznwbemIrb(IVB*KH zEyxHk6C{iok@#n%8i7%fF8_*cn^;|62TuHQdzS%uqHe$*Rw$s{-RIfx`>}5DCkGZH zZFV~?N>-wk{N1Te($^CzMvY3mO!pl6aqVnuRQI!Z&uhMQuGiJvH*w-Eu#BJTS7JEm z+;;$0)DON)e#JnTs;bDhc;cNVSu*5qv9fk~@_tEjxK>^FOl(L{!dWx-@6mwtI}`lq ziP{-LT#s(u8*GZ4vRFa+l^8!%(@Zc10VGFgiH)X3JD-f98csqDJtgU(d=6sgD~6lg zl7Eb&&t45oYWwo!a&}muZhis#Eny=-o@qkMlBcAWl!`MGw?IM-gRlA24$H!wg^@ci zU@cjx^CA>JhI6N*1zvAe`ws0W9 z@MvFL*fWzKtyj@sTACZWuWf8Gef_CK4t1k_&mKQBo!orZ%lSF2g9bEURc}Y`i$q(V zue6Nh;KBu0RWGZ5?l;*qje3{PbiwSa&87L=gfQxUw=7KxT$E3J+=Q--G&0L1WO7+> zY(O;mR_7*ki>OY0mtK_N#!&rYSWz!{u|f^HR9UT+rYI6oG?Mh>Qic$N30z<+e>zJ` zi)gbY$CrbWRNVcyxFr8opZzaKivR2%AokGfdRX-#Qnn_WuUmP83ECi4?VYpA(s-H1 zbnM7?d%I6qF5g#R9NKUsq`4W%hoIVP{_;10s;>0%oI?-{Yfa?R#ko%WQyK?8a}el) z=C@7Rba}XJuZU}R-!LDyeP1q-srIEK)x2b&vBdS$l%v}cZF#iT(-!{~nf=Zhp}U z4Fv30i05pKbQ8+?S7L4l1zS8YC5>lU*3{r*h({eV=t^%-;-&+if`*}{i_Vgls}~~I z@5%lPaG@POtz}T3Ggm+$|FhSy2+!vWL>02DI|-Paz}dKocrDBcY)5CctJD1lk%rIN z#g3%(Gz&WbtQvv=a&K7%MYe*h?Lzc6DV-f2)DLY*Kutd`H4KjMeu*#@x&-%5 zL0zFd&3*C9ayCmdKILV-0RBY#74f#Jm;mUhtS9ZfX$&*bXE=E^T2HuT%=2E+@P+?P zu=88=-;}*a841h={De`HX%T(fsrsyB2k`6+fOFX6 zXpBQ}32Xa^ufj+)TA~}pEO@H)NDbt`+;`^1g{hhmkD^#+p}B2{v+C^hfbSC~_lyb~)pnzO-Wo*x+K(^uD-now#fpm-M#`f7MQs*n z^Q|ezLk?w~;Ff;EbhH2DU}AD1Msw7}*#)f)+=P}}?!m-V!0iB}t65V>g0KT7m^%Tn zH?uRYfZF>tebTskldNFh_%ATCw+C-Up8}kfHk@+)>`c{9ki$=qqb~%tO#))qPLavU z)V`hoZ1ghb*LQx;mRROK3l$IvWmRX2DSTjtR+Z%*POep2tDVVFyc~fW8=vFEjXm(R zab8IDt{XHKk|VW{Z^i-J846(7AGq}U`BZ+7g$@xL>?!2Xq@N(%4FJl{_64X&0rS#5 z!K5<~(wje^@eqxLldger!A>Y#DP+O?cz+x=3PU^WKt{oDR!Idc>(8_>&`2)x;T|k>_M-nC`yT-p*hq_-)zh%b6;u4dAdwzDs%rrDxj`>9r zzlmGRz+&8vIdX5?WCGC%Kiwj^rd;WBmno`5FO@K`rX^iKl3mFX!!d;fUNvg@;8WF) zD0!UNYt*`Pw+{cZ_5RA2mzv@kw)2I#O^qjTh?XN3B>KFGWK8k4ZwX75Q%c}Xqx0s3 z1W=!72GuUZn2sgIakk~yu_Z=PYq73+s+<7V&D{t6imO!i$3;b8SX25WVCQzvA)XW4aOwKQ>^aCeb!ebRoe)Fm z;+L9;@@M%yXGAZ`$?`p$Vq#gC1)r;q`gUB&_1aHu@uPX)^WGRP7MsberA8p*QZ{t$ z9<6SsdY{pY`hK#hL~+T<1 z@c+fH7x{h{vHYY$zK1C|E(7cy{{{&5bKt*mT-@s#kp*n)5YHpsRSU5rCPEkimUt8= z-Zitd(#bS~A-0o_2QzG@Q4bsux!7U?^My(t-4tZ>?CkJ?hgV}&K*?biBPp0a?50qtXP4f#bfQs$g>($MMJU1lWPkkxD8dPMnZASyg z6u`<+z%htp1R)MM29iR6iI7t}>%3DpqP@`^TS``YB&N{-)G*2um{xx<1NArELSW0Q zhw!jR1jU(sIJ&ED20P&lxx8bFg~y<<*5sjD*XYK!WB76Qa^QT=j+d{Ad)q$YgL?(U z!$xKwF_6YrLQJ_8s#qecm%l!Lj*lnMnPH=`Wu2D$h`y)3=NM6asS(SfZp zpdzyqR{IQoSpf(QDVdsBQ+|bpz{!7AM7n{q>v5yppF7Bl?uUv4q2ZsnTA0E0`~;nX{V)aM*KQQ~mL;?$L&xy+ zoq~OM8ojH0IgNaW zEm-g$k>XKpqXV(>ND}b3FZMB?Q)}Vr!`W}74M}fT%_o;i18WoaVOH3+EKepbX zLs~o%75YTG-Eows(L2@o?d+?s-TuIx`7NUOQ{p2)>-r&LII!W$R_{7b?@Dh2sssrh zsjw^#uvvYM86kA1j7(S_P6F3W`!8tzdsy)T)oeLlp+dBo0}~$zd-_U75pz9W?{iMe zkeVvfT@ENkb#E_&F1NeAKCzwAn0(#CEo(rpfRAT=@rs?Z%kB?BTX&#jM_Uv^a0{LX z8pebW`6-3F4S=mpgZuv-8r1)+V*;VTOxbFZ00NVui%%It{dyGPlfGHNqbRMZ+DIO4 zmtNk?nCIBrOZgi>q`0LN5bNCf#ms^(lLcF1a4DMpV7i~YkuYkhzLmW_?s6x*#oEC@ zCT^-rH!+NV+H$&KcW`7Hp!Po{jsxi;?;|!c_Y4MT7y>-EpP-buiUXJfP|I7%NU!$I z8td@KZ@acY*Z``*aiPZyOF+~=l_xq;o992}WgWl6ph-a2PxU=%miu9E=#49IlD-sk}N4v?Pj0P8X|`~R&x zU0pHwvX8@kd~Ad}7EL=5fY7WPUutWZpIWT_hLU(%e6x$@>~3lOnQQ$B%}+@gCf;gN z#akYkiPiS&eNXL98$_>E`h6&#LU>k=6b(Q12l8cUfe@PU) z%LYWfr+{SZ0vafO>H{b@V=DGvyg7D63}fWLQlhO7ysFY&98{|nTG zGshiL7ybkVIl;EA$a#R@`(70`i-D6Orv1f%q^1DklvCrcq$XhYGJhpC0lU|;ecX%w zm+E#gn2P9dALkX}j748G4M@~P&MYc5EKDrPtIkM7_dM^>;@vG-B~5_CJ&T4uEqY&{ z!l>2NP33za#^t%To_R3v-~5nNobcJZGI{g`%J7TfzhnUSJ^j&mC;{3J2tx!Db4CpU zILNm5;rpC(vi@ZZrY^*9kKf-|2m;N^qWf`yR-xGjNKw7TmtK{8JKnmesnH|5ICdYO=W639V-bt>z2PlA+1dXG!P0Ea9hHl^K3i&nr?9~j&I~N#)VdlJ?0-udSEdi{`$4yVn60^ zoKU^$V@%dkaF%l`&`Z!xNbqX>1Oys2|Ctq7zXpb|QJkeY0Vt;(kCKj-14{6&ji}J} z8NIe+y|E+MGz`MO1q*XVlC*kVLj$D%S2F5Icct-{D+v?|G67c-s1rD$0ap@;Vvm9O z55HW=oa#9<0`nnqQ8T3?`s(|n1EgV9mHEcuy9E{n+@n?Hd`(|XO$`jD&QiUx2|%|# zj+MU&ZhtO95W*tr#dCZ*?$_X!!#5g73(VJdtf%pt+f%ldX;J{${Os#P!$eq%I0-PH z*{(qGTNtSC?*$+(x+$44oPEF?p{Ft3m2ti5!T#ICmCBNW)NS9bi-T<=`n-3=KUYpJ zav;=|g*+$!KcQ3u@L&FpQw~vo0xIn9n+CcD`CBQzXRyBjuSEqG4mhG0`-SRdvE8np zokr40AFJ!r{X3l4mf}-mjUyJwNdUWxUcdeOqk{0H`9%NAiH~>D&!1_?$S-y+ z8OhXo)zl&ObqK*~Wa-4xr>B)AI*l2x-CFfgr5n5%FD>oOdzIttNulkT z-_ss{&h_poq_3j%4|ly$(q$>q1yTfZ)pwXlZ(#XeTwyoVbVAtD)ETzszPb9ZnbO}4 zmgFi>!h_lOg0XSoFf3oq@WP15T+9H?WR%bR&Y2PCx&^^DwUZejEtYmF%C6^!MydVn zerlQ-?^7z?dL~uY4Zbgr?;*1b^7|qx+-gQkam_vNkxXQUGT@Yrq0gwbj#Ec;YNOP z`wU|v*km;jGNQ=oFGdV0RLM3g2)8$FxTnzwkMhl}ZBD+cCMX*7vE_8+8w$X!4=0QH zFg2qt*9+mojB1i0bUy9!sUAO)86BsDX!BA$UN1i6&u}Ptkr^oI?@PS$N$u!4iT}Gw zD}sSLOT^r3TcF7w?leHJq|m6GMYiA!kDJ=K)~G2yjk2tVesuCxcPjV+lf<|*9NwDL0fEw!{Arps)_M~ATIpMFr%61Io8A2Y( zAfpl)j=G}Q5+x@>Vb>VGtI*8`kW=X?jwqD;PapC5-zDZ}`fNGiY=_j#apEbT_bHkvXn|EJj3^u3%_FNp~c? zMi>)ivkpq8Wo=EngRp5Qq}t2sO{_khF=OQ(rt zU*idsD=>{v`|Q+^;R2$0N@O0raT1?9Ul6oJEH5USHf1s4Z?*DCtY{HZdIh@~8j#yA zj3Yn7)y|2Tx>|J<1$HZP6EkPEFuZeG>kacjHE-HyyQO7>@kL`w;gkCft8Fq-J+W%c z71H8-2k&OH1zxs-X!;)rY)iMy&`l7Kc*Dqgt2x(LL-09&hU`^<^ms8b56XC`zc*G5 zs@xk&$m~gAWYcr9xDX_j4AVHygg8;Vw^YrrkWS;d+Gzdn#Ley%JsEkiSJApv=%hTg zd7`e}nANd4#J*GhT_Ja=>e&?`|Lh|ne^0!3`#$k)qX4oT39N#An>2<8K9-vd5yMQd zvAT*Gi(fW-tvR)=UmREBw$S_8g*}y-EnVFvA z(8iL_KLD7_uoVJhBvTS8TSk(lSz)ctCQB{a39GUejyYlLM;bM1I{I;G!XX6QHV%Xd375iY@`+*w-*-cQ4%j{t5v@N#S71)?2utcf9)J7cKpJD>-BkfC zniVx<9H^i9=tu32jNo8p0@eT*gY+40*(W@6UX*F%+1iEn$Zm!V$V)XI1Gw62eg z_bGgCeq=qnut3l#r8s1KYZVuo&i*i?)mGm>Wu<v-3>2D`8GJy-Ks--)Q~6L+3xrK z?0N}9JocDhgYwRxhnj1tr1aoMb~UuFn#S#!)$LH@cOX`k_GK9s?hJzl5JT8pfR=7Q zAa@C6l}bAt$+mG)AP_`r+xOj-SZrUJH|mSsyuge*---D469j^Ch#VI*a{1r>OpL=% z7!wWW%qdN41|{MsSEM^XFqdGjDhafU=2ot+F+QC8~2_wN3@q{ z$|LA~Ts21e5=+Lv72O)nV~iz?ik-tR_7&4W{3z7AU92f~(JkQ1#FX~H6aDo%^EA}G zSZW9D;7<2^Ext5$!^uYiGWvbDsl?vDe0KY4-CJ z2qGXL(vgzTQIXz5O(;rFL?B4W?|uGj_RRb4J+s&BleN~KHSYl@ImwXZx$keeuFsW( zJlFY>fCm)=eOy@= z-BVppf75ym^6+?h(Jon{E(gBNMByi~(D*$ZsBHls2i%qw#5j^dHN*@LLT?>3KH^n- zeXt#p#?-LnCkNt|XWZw;p-;-_3IM$w(+(4$S&{SFDi_M%nRpKyg{8@~h@p*!J1xEs`+->_ENI? z;D+r$=C`AM%}H@d2#KY4J79T{#xK?j&<9<8An_FrT$in<-|_BLYA9H`vbQS#adiiQ zgkLbL8bHdR;Qiwv0V=o)EkAB`Wv4zxbMUuYJg?|Ce|FOjM!|E}TmX0Vc+O_|Uf(jT zuPtN=yY-D+Wk$B1oXl6G%zsELTyM;-DD!c4uTM4$z3}AjWH@?1iW(s<@+b=x&$i~Sk0y;@4 ztVU^lnxPhTHQOfCZ6#Se80GjH;-c@={m5wW%Al~#1PFyaK8p&bglsb`!T6Cyy*62z zAr4vzDu0kRo!TxbmD?^cBX7+W-+#Ro=YN{}^_@ZV=c@1c>5#tWkj~v1#4;GTLNxQo zy7DhlnD)9$e9TuI9o2eESeT}nPLtWz8gPPujD&J*LS>&?u1cX_roQaR!6z<pxzh7T!;G&vk}Qp8IHA$Pz;o1Qs4GDa?*iWQ*sZqX8{W_}sB5RQMikNgf~n z&$d7OkM~0v`+t+|S z9`Gxo7N7|ML|z2g5aC~j47bSny`Wnh^6Y)GE1Q6XaMdSaNZ4LKECs{6gK>wk z)Bcsh5(Q4tyYtDb*=?m+KUZajr+%a8=13hb;B=b*S&5$#x7E5@VNv#w}cwizJWBC2R_F(G;8V`pHOn+ z5~Nef_88m)KeHpMT^~_srl_doV%(H*$@FBIMUnLd-N(m}XDWcl;X#JHbtxmP;nsumf6NPKq2b`OGYU}-1zN&ol_vV2t zB{r9QNVMJNSRQz2zlG9WR^i#;|9G;NI=P^2Z#x&$eprSBz&6y?V=L-Kl{B;_0E~s5 z8X~#p{Jb&Js{2%FOHqlj=GxD`v_Bc`e%cpJrIa!ZKqt^b9P3FQT9ES5HdKw+Re-`rHOk5BaJXv?}e+Ng@suXQiKYrZ)@t<#2dq_ z=f+2vTUd9S z8{=0la+E$CzwmV^tmS{kg#0gHh5vAK^S|?c|1Q+wziF;I|8OASh z>hR{1BBAQ~Qa--Nm*`X307n6nP@((@Cqlb9;#;91aeH*Z>>p6guI~*q{yFJ~*nUsa zeV#A6@t{{7Khd(Z>%;$m{1Ay)E}SHa`|_qFm~7L?65ne0sjY5u(x2nO`G)wx_@T?? zS9%voR%-9Q)5~A5cC3gdE-Wleo@1&o=kyu5bG?+ta)#k3``DR^HK&>!+tQe*Gt2%O zCDC=I`{4@xzla&Gjaj7mA_J0!`_F+n(CHTy2D%pL7s!Ba9cU8jYB9B{YYO5sT3>Fb z;www|w&pljkReAP+J{In#)49xGdjGSot?ko{3XRbrMbM+?aDq&;*k%~fWwZ}QBIV+ ze?Z;fcu78A*n{!z^Q0wjp*IYMF=v^cAuKsvZk<~BL`t*@!l!Dr45qlvCpPj1H^iF8aR2N>* z3w?zg7oF4RFyi5uWHwk_YQFltP0Tf$Gv$xl`C9y=-KYh1sxEo2S6>t;1vrI7rxN!i zY*;7lc6hgy&o|o5BxZqLilIu<6G6W*GxETSN(Ao-5S;6C)YddSHgVR zHCt2&NQmnAu(&Q-GUL(I_@?r_f}-J96W@;aZpS5J#!p1dO2`;e!%GU5&aR@8s;6B7 zxA%NWl{Bq8swxO~$2F#|LPv{I+8^?32VD=_l1pAf*9;*LPos2SXWct}iThoo)Vq$I zi{2pG&yxiFN+Wt4a-iKvPi}rs4FFu^s?-X%PF}K;(?ixe!WJx{OWF8IyX#p`@57mMQeTEgKi_@$| z>C7lCeb$jyqYpp`WAmRZ1vkOcMw^`vrbW`XTwe6BA1`TZkJ#HRLn1PcsfGBP z4y+E=O`hZZ!sqTkYhvDD*topVx=0%8Sd7Dnqb}+RP;||8aLJ9l^8jJfk^c8=hZqVh zMi!>L4NC<5Apjp?4RmD%?}j}V-o@wT1!KtDi_^38=UPB~Vgn!m{HC#=_Agv9FZJHi zycU2&ex`Uk_njB}rjMUnhKq29aV9?WvTlRZ^zXh;I*%gtPf>F}S{y5)rya7l!K4BG zjb$Rf+evcKVtil0Jfw-+4>OcYtgUQ0@{NyM~_i zM!0m)*(aAMLh(UU4LO9tb3WZRBO)UrYe62uv$UJW*4prirwqb+(RUnYX9?YB>k}TP zgk6-s2Za(8|;VvV!I-x@;Soz3W9zwLuACpke{Kc zyl}U7&SOog9q(pB4qU9y1hnxrs^f7Twj*7vQ!8}Uo|E_VF1Tl)OI_i&tf z#$c);rvCXeT8)g>sp(thM~teGFu*0yxMsU9eQZ5?T+9cjBj4%$ET5WcHIo-xtO~t z8!!TL#j*%$IGTj zwp)~$v06_+eAlZ%+Zunp5D0`1Pt>GLtj4d>ch+deM7hSs=J<{P z(E+g_Tg&hM8{jhx*~B59Uvwq%K{gPhbw8r&=47RFTV1Y?+#BEsGsXJExXmSV3!l>A z;d=abX?qEf?MkiPOA~%gt9l8&z6KZhNDw+A+;Z=`x^+Nu6;l}3I2*K7WK=BuCG7Tz^{{$(5n z3CoFu^5A4J-=$gsCz~Ji@6)go$&*s($_V-#X1Wp(?7d&ZVGASmADv`kRz}ehnCV*J z2`yCEOCv&Qg|@vA3zL69=2mz=pLV=A3`ET)rEUEKT84v=`lr8429=-Vs(Wf$tpI(U z618(Zm0`n^aTykiY3%EO$6{(5p9An;>+uSjPa!hEqCWb1HtdUI(z;aLjm8YN8)$RG zCjzb}i-*GgZy8qzKuCfhN)c$C9H|>%35WVvJjthYzrV%)%=;G_3q_u#85dTd*`l=< zR?P^Z%y?@w7xgw^(SfkGG7{muj=eefxIiVf-lbDDp&^0?>0R zFhV#{wrzQ6rhKU?4Rw`BBpvBXd`K%>@ND5sDfJrg1rr`xGbCm3hRC1Znlm}e*rn2q zr^-pTl*~BLu2Ao?B%lI!Yib%3S(U8id5lDvOe3#bpSBc=G7;xf zu)^J@#vOy{KQIC)w=(2EpguIOuesC2jO50Ww-z};^va}+{{72^%DkW>>%bw0w07v{ zsjUt&5ZCIFhT>g17~3C^G*qhW7)PxAo`3aDS0--+XCt+C7l@Td_hGm+`0iI z4_{LfO;`H0z(eghz5yfd*HI8Xx-<28Vuo!lVysG&9@Vi+v=5rEzT3X++;&NRq3=db zp@bitr&caST;E-WJ(nTFdIMwae%98E?raBHPwG2Yp7~t(zWX}^yiR*dj2ZuUL#UxK z!?h_x+UwC*lZzbkSD%C>53~HlM598EO;MUB@iZyYt^o2TLJXNd4B4*Kr0Om<#Q*-m z*z6JbvO>~k<;m?)R7aBr>LI18XR!v6s-!-EI8gtLA4O|hL_8)0;p6@sx9Q)|_v9}g zceEOpzm;JvX)qTosBzqHSzKQiVCFAwTGOAhbB6hG;I|Eb0|E84ZI!KDmF+w%LGlL1 z>o@&pUR`KzM5Bk~#J|AP{}s|8+cg;N3_F4!B{4x0I=)9`=KI_5OyVCIDNNO$ibDgOTX>W;IY zRg+iQ#OZ6}=NINZd@>Vm{5X}(t|cJBaMj-ck^}|qJj$lCIZ^mVV7nDAQ|I156QsZ6 zCM8m*?xuyUV>}m+j1W=KEdgTMssOrc!~l}JSU;FaQemLtm+;%LLN6`Rg_bZP~%q9Qwq#x=iMC0Yk>q+bmDRh8YV1=_U3W z47es_=sl2H+n!xfZwL?W9MY`b0%OgSl=Q3=ZUFh3a4zle$zc4Y_@Wq*UYNE*lju}Zu)y^kJq!IzStH$&LF1alHXgh{J>TpM%j{VJ75G8;^xhU zpWVPjRyM!pv-FKa&5_<39-S?BF6zCsNC?2;>cY?L67{j`_51(;oZ`u)?g`3R&J}o^?9O4D3JQUw+$YGw`19OEhQ=8zN3}-OC`!PbgK_`YXny zE^(XAV;3&Q_*RaFd@)UTe*qdi!|becD(tN#v)1!ooSsqev1MzlJ!g|5xbb$@tHI(*%tw#Zj~h^{zd&@e@s`);ygA z2HQDqXeo}AHePb#;dj}QZq$n^lC0kdIx-@+05eXbV)O#~@`z$!((_vKhm@jC(WeqO zPkz`gctc+gPB9fbb?M(PUfymLw`AGa7B~u`B$-fUNV_R$r7^nvhQ1QD<)fog^DHu? zOE4_j5$UajReAl`%`PXDZNb5KKK*+t(^EJrS`T^H0R?=Cy()Z42k$~kY1&g7797U8 zwpKwL$jiPGpTF47p=N`>NSE^odBU2T3cynr6@ELC;9JE+42b@%89>gMLwGj3NSDZ) z@C}FQiBE6#+2o#I?Y?|{RM?qgK{YrS}g{Ln(b=nq^{0~Ye zHMMLUQJv51Yp|CHNV#EK$7K~3YIq-A1P{aqZ!c_`jSrjnp9wJc`|&OsB=c-JJ(Mxi z8N*KddJ?aH1}Sun%+j+f>1A`*uyVcmFyW7#oL2VjG8gt6-ypAkzkc(K@vtJ+aQ+wL zesChK?j#P!_A3CzO2%)qRUwNkF+%@<+`SXrO1*^;N~viG90{B?llREgfU{~%((YyE z_MS<%F7$aYoK+(&0mA)>ehwJf3cjd$&_xkdaIz9(`}8O~JmdGxhIPb*zV$g1?W^rj z&sJX&Skkk(BvuI}M!ZAPy$ZWy9$^y@E>Lpdn{MUsOFHsR1NO)rtGd1@HLS1~zU+~- zkLT?>mZs)!5Z^TeZEc`Wc{TqTd(pKKnV^eJOd+bv!k@xaVjMhr@v{NT@E8=3OrJt@ z(uI-#jIL{D*Urq&_sis*-pgSz)07(wRsj9z$kjChq2Jji(XLaQ7r3cOl2UPHvO>bxPdP%D5)L7&GY9761ifr+DAzT(6LWAF_brT^BB5L$ecfn(Gv+O zbn+z64<2%P=2VssDa!xmn`K}J4Ou#t^M%8rs!MwC1V&O^@A12zUZNtBEW^2a!jByX z##gPboOk@d{;BPrDj?%I+y&F%j((553w~5RZ8e>_=k%ZuHKYfBx*8{x!i@+I7y<%u8f1 zgQ5Kw2x=O1;yu(ACRj|8Fi@=@-Tu)0`fUHvlXJlYHjyhVj2b_HS$z^P(?~)yjG(2c z_8t~;58{Rzh2*7P%aZUq<(F74hVjGmaJ<9(eRvMKHhl;qh(Ms6e8_%ul9MQ%v_n#8 zXbivbzSc3JawYl4@TXPhm2y|+`%Hbc!2IMYTAXUw39rS((3}cpzA=}R`i0%bi^Fph zvTPf(+>$G-zEXN3x6}CwuI@dH1b_Keh?mE)(wNDZzM*)y(xWY31qg zIR*#3-Fv&%YC;Omtw7>n(|w%{^t0pY8dHD16}B{g&MW-lnVRW78uPNY+0T+zvonXk z1PF6RjL@!_&T3%#>kgZR5Cv1nA9gLPhxKJ0HX8`NdHxapp`4@Z4TZ>(fv&F$4}-04 zBRd!w;;RHE>F0qyX^hfB@}1gde2=W{5d}XF3atZ!y_&f`e3%r(I=h=Ey?M~{GU!KI zlvcw%|6WZd(8;^1T#OGU`GG8zjSTL_oS{K{JwnU9z^>Is=nONz&b1wRNLY{b?I$&S zmzLjpNF~yKkKwdvkI3*XEQPV_Z%}WeN01LG90R7>+D<8bQmQ{as$m9-nJ-~qGj-v9 zxA&W`33_odZX0_njnbxS9bZHy^+GBSGUHHa3mZu9ozNAS6ax+DBzC zN!u-GD*SPd*qq$PFB4P_g^p|opbcrX96(xfd}l3FMK zlq;-f@YrZ%tewv6WLCT$)dvDL^7-`yplVMJvIyz@>#;RK*qS|Z{$UT(;{-WAVOe()I5IiDMh&n)5aha>cen2X3ElUpRtK8q z#N0~Z1zdJv zo5br~eTXQXt2Hj1>j*TkNNdWN`^}ylHdQvYDg+1(W{HBC$D!1r8`Y!Ad_$c%nrRhY zUxO45RpFtfbQj(Gf|((lM^CfU8s_{RA;NW-s0Yf_DbhTYlv+euOlmH!nXPP$L#cX{ zYBa_c&ip=#?Y);!@*;2bz3+^8=@L+oC;=KUa@=xX8jAlIOnXE*zvwL#me!RQU_D&t zRUag0*%+-kGRu(}(KpcOlX0{9T=4Vb^2Az-%7GObABl0D$B>HC;)943VPsI%zS3jw zsat=azwDGGiM;dTRN~lmyHgvMEYTBd%Cc!8u*cE~SbuULsrr&MP71v-6s-Yc^%T5J zi1%Ab6v^L%#Yk(Zrphe2df)6&(M+?w?)@6^9Ah`RohK(dF~M8X?{( zE-pcCjX&fCaxd(0l|)4_0ZP_8=xzwR9!CL{k#v}*UVh=-cHy#szd0`i}{|e#>#?1xE?V#fWZ*Qmsz+ z*GLig54~zRg=xlpG}VC^RL~8~o?qXa{ zbT)~Ny;Z6BT6$iY(<1V2N4A5^WE$7p?vu%ckA&9WooVNkKJyfC?lU!4OEa!7#>0gi zI~IA-5R~eu`nw8mx_0vyZch5m8Lxv&M~zyD=Ml4JH*V+}*1A0D(*Xhy@kAIO?Jfo0 zaiXnk+_JP8u8=#4obde)?^`okZF*$4yz6ERu0(=6?8Zoo+YA)L<%r?Ex)jbd#$Ois zoSI5&8&Au`d6mYPY#)4l{jCQD#6BN{QOyhKbMP~Pm0D&^sk0rL`I5NNd-V#3m+o7D z(r0FNL`s;x%&G=SfEa-@*P$X0<4{Xph$%zHS#tUYws?K|z^X>S&+60=Wqg~VeE^L> zL8$gb?Ckly-bE38%{lm{;E2!RUz51HX~d@r-?|nKW;XHotWjf&?&n$QSnado+|g2< zI;W6_L`)xaW|uUuy_-~t=31OHFM`iG%^3wx?z*XG`S@$JroYJ%n5z)ad>q*y2^?8` z)Q1$IF5nucF&5HH$TfpGx7)y2KV?gkOBOxoARA|H>p!wB->EKiwoUTSPD#F#T$&1L zn|aVw11@OSrCh8qH+QS1t+iIs?NOFVsv^td_)Lb~x113}224*N{6OZeEu*dqJqmR>`pjQVdX8cUFdX zLj;195H?Lo-J(68-X?sUckKDV`g7WcS`c>c69xsZ8;b{|7Deym7|rFxzI4M}I%WP^Dl>}A z5X9;KmNJ0WM7H!kP^Jjc_+KuoJV*I{H2dU@;chiv5TXyi&dS=6fLhmhuB0dEzE;c) z$S}Q=aocP^jZj4jGpBG65$$O_`oI^@YH^Spga^Vz!Z_a4b5JdXeJdLs==;uU?=JgF zNr^`{L_+V%SAECK48g0(CEL0#M3pEVp(J%|Oup=>GSd$B(~S28pOTG+cg1E)S%1D< zQ}pIZp(`h9jcO>kNaP_q`Co{gB>qE!@vx?UI$hFgi}n? zsE(pX;k>CcrQ9jQt9(=d2aL)HQxT98*h?eY#aMX_yzcjClJg9IMe?_{}lSwu}SnZlW3() zJk!?b3Njk-4*X5&SNB%>A#~&P^l^gH(Ami26{kRbQu7fddQCVBL)iHX?A=3fuVN^f+1TJ-FGb|)lF@+z?0f>M&YlOjr;HlW-= z19z{Fe3%%|yS~6VYw}q6-XTPD{*&3o?id39&L=<)`#q7CMa@3GfP}Y0Yy9vOtG;d> zGBsIGXXgOVE-u`t6_4Dit1AT0jA{9o#bYex++k9-viTB@&a#w{E z&P9CGNJ*K4H$O2r{Xq6AE+b>^o9|Y#^rS||D5f4teF&cNCrcZTw z?`3E=K??|aDlS6HtQp&Sxaa4$$a>t65g#8M=iW<9ML#-8`?FaUyC{m&p%48HE)Izx z!kC>rvmNxb0@}J9WXrwx`q}lKH?&`mygV+r(Vf4zlty!c6NE^JK6`w4F*PE)Lx(X~ z)lMs~>{eTR#(GqFDo|S#fNM*u1{UM6ye3;iW#|z}5Ji1kDLeWWPHosGL}_iE@YX1% z?#A1##8n$!`q-2uXu^7 zF+Rz4xzF9wJk_yqdK{4GH6Bz@@0OtW$h*lXFmX`~<24QFe-qB(qQ)#Q&4#`h0^u@D z)Qa*3jaioqf?0l5#hy-^y?gNH>{5zC7xb@R0k|3`vbmr|(Tb|u_91D-0?~nn{%a>C z=8j)?dvM?EUH$=?#-EP(B|u$o@mdl64JQ-6T=Xi5Oi0?sM`m%2yA8MMQ z+KU98NCFcbx(YR*>kkKeu3WK)y! zJC`fmfq+k-G$>CI(*af^`iwL)EGf83PRG6WaVOVxCzYYQN*`~zrKkMxM_6>112&C& zPYNTU#-3H*pwVc^3HCg-erbZLLaO4Sou~A?BG#vIQF9db^xJ&X7OzQ%ykx!A%~((J zAfO*gUQho!^M8ee^WT+g{x{9ws!9vmjb0QAgb+SLQ=y!zBYuGpzV!wjSiEn&nKIvo ze5D_w-U0fiTJ6!$UMtD<2S7EX|r-KS`N7n3AFF~UqA~eqcg4vD@^>E^~5e7vv}wZW$3S7iV0&Hspqz4vS>>dM(+mgTq5HJIDQf?4 zC{i7koCA%{ac-Diu!bb-x9 zew66TUm+N7noFto-$5+&%@Ykf<$Vn!I~kbD_untQ`cDlkzTJj50J+33aeU|@QG6`M z)eh>uD7JA)Y3r|kNwL%Xdd2S>(Uzfu)Z{XL&52h}zB8OnHPP71{JFJz?7fLO_46bb zL9^iXXDwb0?YH9k7`2$DpP8%$<2UYW-0b(^0D^nq(6vWf<`}(1C}2uK(nm; zF33YRY1!whZ^@cE-XmU1N zILov?Ol|l^jPUD7kS;@ucKg~FVBY{}hiy9`XQ(ysneWG4Vc6%u*I{0sr_${k!m*wc zGpKjam@xtW@$4(-ord)G2V09lz|UI^C3uqmo(zuEvoE49+UWpAC;I%{12{6tG6B`J z($r=sCetiwCI4bO84?i5dha(lx(4NY@)E~8Llq)<#7I+i-V_cZ_mbu&AH*M^Dv-Se*w^%-K1!>L*pUBbVg(xv9S*( zF#7KmG1T4od11R_3VXN#WH?b^P97CmpN0pPykIXZV9^GMftgOa@YwI zN5*y7T#~PsT^ojr)Y~h4NSWGJqzE+J2dlQ0PS<_>c&gW|Y79uRUBgPCN3^iuR#*`6 z3>CIq#h=9V9OLQPd`ruQKQvtt+nl-M;y%;j;>Ik@J9zm)I8fjUA(IYDs7}j6&qX1? z{D+^)e`lgD)N0|dCyV2u^Kadc&-a|ORu2POPF-%7#Z3eMk_P4H*6cEuM+cl;1993X zD@U$-@p;x~qmqSNH5MiJ-6xc)ls>Qv!p=X5L~yin60f-=I+%1+r7oU9%Oh?3fG5m% zEyk@F$!yy?8i?R-`%yhqc4@jX#A^P$xUrgxkvRW1$R$KQaL%RC908YkNhg%A_1ZEt ziY_y%X96tD_*o_XzyyZE_;2M)%R6+MYQ-d84{3Fk-_l{c-`KS~W8Q}rqbe-#vO6J( ztyh<;;vHK5(g@f%)U2bE=7FsVuS0?7{JSGfTyVCrR1`y1vu6Q#a95DVeo~Isfv_PB zms5*~91t$5%e~r(Pa0gj2-kn{=m$Ia29qSdyiny+dmH^d>dW1~QvW$pCkDz+8#$b` z6McNogwjWR+pIJgO>)xgGox@dl*x2D83|Ojd*ly&Gq%AEKk?l#C8mp>)12@~xou2hf$GqF zftK?Pxp6DH;+Lc%r7v2A4w>~*otfFpf_P!?O;)Ck*FEE<_({Jpu$Q(8XjE*8SAl=t z>tGt^bC@@-HB@^Q_(a6s&bgeiv6A&ImnXhDbxl0gXH!(3agTZLAmpSz2>&dAKF~;)CA(vZ;S!4L zS%l11+p5wBvyKIhp0YNuivrFcUz#TF2<^mbD{axw0bZwIDxokClopE@)E7j~9wbngiDGs^vjB$I>#q(nxbs`y zuHf&^9|XTV=^H85H;KJ7Vn7mTN#b!7VAfEHT;z17YLlZo7rCV#O8+&=;rMcGC@=hi z+Pryw32JbA$mM|sv|~33Cvfs{UnO}F>^9=pssdC1nnUpDI~hb3=VO*{YYwb0iuidz zEY4aUSH!2Ltu0wP8?a0OCZH-x{{qDoUz>RtyANff_E~1z>cHa(J3?a`9B)In5i%d~ zZx~+)7x_v>w#~Fbmm0VB??Em_*}EUwhqV>y&mEm@^|h}~>?{<2ZdRj3i-v}1onsUD|Cj#(A)i%(#R z?V;xfO+T~KfiAyvFxxKS7>w^DG%M<6y^z{_^kafVbC2W!WBBJ*>5D%VpFC06%#mj# zr5LtD8-yZ(RbLj>mIMP+`?@4$+OqymF81tsGjE!P()U|cz^|q2ZlTgM-UotEJ*}Je zCyA|2EL(R==>W}O>aky|*~bTcr4}K9()DAL?ro`S)(MYNZA>TbdV}8Po8DhSrvmq* zIyHM~Dq$Z>0HP_xmHaBEu9FXOcz;Pk-{{bD; z4%77vgWT1*E_aUodTKBvH_P>-QR$M~HId}L`bg=oV$iNspY<@{2FXnW?#1kbmImtI zMjMx|Vd^DnVEY4DLqx@5{UkXbp1u0m`SP0IeD7QUoP++=248OtVQV!~=`bJSa9X#e zOHJN+G?ld@^5^*3C7TSb23z}OL|c7lreDmn9yHz-49r!DfEiWr*GptTyMyy+-DXnR z?gN<&`)f^2zN5Hxr`a{Vlq*?DWrWOLQ+z_5`a3{p7GAsY-GC?j4p|j&_YD zEZ@mG;3!WJ>o;?ax9nS3vr+Y#l*SlfKQFUA zu(cGoCmP-%R@bilwJ`ak=I-m2fM>w`E;QmF&<#0K2%Qyh`~a6pH*l^PBI4*5sl}8j zB5Yg2$CWNQQYRs0-Pkg^ALA48KwBj}b7SB=Ah< zVOMinEG-26w{0PfB-;|VrSpD??Po*Q1U&4ffrl|iGSh=MQ2!AFh%`#=B)9b%5gWIl z*GREV`ULKVCe#dsmTxxSy;@!O#r1-pqQS*YfzM~pKH>;?d+H1lld2D*Ph*7i#X86= zB)>Ky2bTS^gGne!c=W*6o3t=^BhdeJ`ASmBv|it0bkowxTl!1jjV_`rsJ@gv3ra%* zg=0xBaoZ{IMBU34CqX%LXi@CcoASl)MMqO#$cXvYQKZO~IWeszNc(PrqzJWm8IuZd zVGwocq$<%4FVDae26xejwzixY8!tHy>Do*8`Y#=Yf?gQFt2_%%{bw*xaaX5sKszvu z@h+ubpu)c|NZ{|q;`_M|{$4}RHH#@$c!~nkyCSS72 zK_M`I)#@WmwWq*4r*LA_lrmiP6$;-1Kre26jsDZC6g%TWC)IK%UMf)a#0k;+oBJZ^ ztU%WhK-ue_6wOLmj(}NDx=)@}Vtj9FS@I_6Nyd2-&iE0whH^dcr&>JPC#OnFqBu&W z>zh;DW8LI`EveOf?yzBrVqJyi8HICph z5mfF{uXOeX$}vIKQqWA`?RLL1iDPO}A2gz`=1J$KDik9_|Fm4?Fdy;trQJc4N-rYx zosuuG6CR5(D!ywSx-3zhZ)7|256A^J6McJNq1slVEXdxlck?t|O6t4vO&t{7t$o6|Wqd-wIa3d> z@ZF+BkwXg|P>-Gk9+2BqvtvDPy2weo^GPlWQQF>E`|EFhW=P0~+HZVZ*M91f)ZX8k zwYk$Je>zHAd}`1ZTZ`W=ucUEc4}9?yt!Q-sfj_SXceq zcD2gZrc%NiMq%JcIu{ZP6hC$N9&mDV%*5Lqm6VD<;o5{iqZzj9Luw zX)v+Wu^k<^Wa;|b*9A|4aOn9S-vG3VDHS@~0Bv6VL_$yZq3hsbQs_f>U#;62=Z!zg zJXwC6^2+M=SWg5CS&WzrVbSM7Iug4dOJhxIuZ75K?^U_UI+g{D@4lJ;vgI=c6r}M$ z7qUB}bG$s&4P`{u-W)F&t56hqtOq+?u|dd>&$e8stl2W=yfM7Od{O`L&tyq}UqzK# zMkKzVs`S7)_I<9Z0ONUZklRGv%J{@0?2GR8*30qFf=)BUTAGS5GBv79oB`^+ys36P z0Z#(=B4Q!JIQdPqKx0yS;}kbR^OcIA{I*{q@4!Wvy5HXYE6XmAQ?4sd_G^UOM&sD1 zAM3Sp)8nqN5z`l^)Roi%NrjDzKryt}g{{WN zr)~KXu`_$!VXL38%Mi8&^N;IWpyge1Vmadl{rKIlXt5Ic#x;d!Cty#lRN$}qzm9Vf zo)Em(xgb5>s)J?GKj&j!a+}OCuWKCTqni>Z>B^@;5H#L+U;1_XK`7-mTmU@-0nsqz z@-9ivDKx};O5sCND9XiMebbAA%m%O4x6JE4wqr%Sj9Q=0(EnHYXimkGg?b4GElY?b z2wPB+5`+YKrj$RxT#)hY!-nSF0{cooRMi4d9i)CVm8QH^TS00qP4yzgXb*>Rm+2!R z&FX5%iyW_8p@qM={ADWcE|g^Z2L&j7FV~CFx4tK4!u8_X{DBkTQHZTwf)X~{^R+0( z<>a*P9HAtX1}W)fI9FU!34Sjy;r-yRO4sNg%;SIyLwO{|@{+kT_42_j%Fif_2ghP8 z;K6;S6-Prvlv|J(KdwLA-Xh^Lqs81}*k8KeyRIe)R;ewtov#CGBhji3Fu`ZpB$W=? zfJ^MqWZwV6f;HkbnijhX56m|=TU=flOx-MG6@J!c$mhSvZnF62&j0;m*#E|vqW?eA V^1nUO@_&cA{9j-HchQ$K{}1{cQaS(t literal 0 HcmV?d00001 diff --git a/docs/source/recipes/librispeech/index.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/index.rst similarity index 82% rename from docs/source/recipes/librispeech/index.rst rename to docs/source/recipes/Non-streaming-ASR/librispeech/index.rst index 568a8016f..aa97f325d 100644 --- a/docs/source/recipes/librispeech/index.rst +++ b/docs/source/recipes/Non-streaming-ASR/librispeech/index.rst @@ -6,5 +6,6 @@ LibriSpeech tdnn_lstm_ctc conformer_ctc + pruned_transducer_stateless lstm_pruned_stateless_transducer zipformer_mmi diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/pruned_transducer_stateless.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/pruned_transducer_stateless.rst new file mode 100644 index 000000000..d8569bc5c --- /dev/null +++ b/docs/source/recipes/Non-streaming-ASR/librispeech/pruned_transducer_stateless.rst @@ -0,0 +1,545 @@ +Pruned transducer statelessX +============================ + +This tutorial shows you how to run a conformer transducer model +with the `LibriSpeech `_ dataset. + +.. Note:: + + The tutorial is suitable for `pruned_transducer_stateless `_, + `pruned_transducer_stateless2 `_, + `pruned_transducer_stateless4 `_, + `pruned_transducer_stateless5 `_, + We will take pruned_transducer_stateless4 as an example in this tutorial. + +.. HINT:: + + We assume you have read the page :ref:`install icefall` and have setup + the environment for ``icefall``. + +.. HINT:: + + We recommend you to use a GPU or several GPUs to run this recipe. + +.. hint:: + + Please scroll down to the bottom of this page to find download links + for pretrained models if you don't want to train a model from scratch. + + +We use pruned RNN-T to compute the loss. + +.. note:: + + You can find the paper about pruned RNN-T at the following address: + + ``_ + +The transducer model consists of 3 parts: + + - Encoder, a.k.a, the transcription network. We use a Conformer model (the reworked version by Daniel Povey) + - Decoder, a.k.a, the prediction network. We use a stateless model consisting of + ``nn.Embedding`` and ``nn.Conv1d`` + - Joiner, a.k.a, the joint network. + +.. caution:: + + Contrary to the conventional RNN-T models, we use a stateless decoder. + That is, it has no recurrent connections. + + +Data preparation +---------------- + +.. hint:: + + The data preparation is the same as other recipes on LibriSpeech dataset, + if you have finished this step, you can skip to ``Training`` directly. + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./prepare.sh + +The script ``./prepare.sh`` handles the data preparation for you, **automagically**. +All you need to do is to run it. + +The data preparation contains several stages, you can use the following two +options: + + - ``--stage`` + - ``--stop-stage`` + +to control which stage(s) should be run. By default, all stages are executed. + + +For example, + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./prepare.sh --stage 0 --stop-stage 0 + +means to run only stage 0. + +To run stage 2 to stage 5, use: + +.. code-block:: bash + + $ ./prepare.sh --stage 2 --stop-stage 5 + +.. HINT:: + + If you have pre-downloaded the `LibriSpeech `_ + dataset and the `musan `_ dataset, say, + they are saved in ``/tmp/LibriSpeech`` and ``/tmp/musan``, you can modify + the ``dl_dir`` variable in ``./prepare.sh`` to point to ``/tmp`` so that + ``./prepare.sh`` won't re-download them. + +.. NOTE:: + + All generated files by ``./prepare.sh``, e.g., features, lexicon, etc, + are saved in ``./data`` directory. + +We provide the following YouTube video showing how to run ``./prepare.sh``. + +.. note:: + + To get the latest news of `next-gen Kaldi `_, please subscribe + the following YouTube channel by `Nadira Povey `_: + + ``_ + +.. youtube:: ofEIoJL-mGM + + +Training +-------- + +Configurable options +~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./pruned_transducer_stateless4/train.py --help + + +shows you the training options that can be passed from the commandline. +The following options are used quite often: + + - ``--exp-dir`` + + The directory to save checkpoints, training logs and tensorboard. + + - ``--full-libri`` + + If it's True, the training part uses all the training data, i.e., + 960 hours. Otherwise, the training part uses only the subset + ``train-clean-100``, which has 100 hours of training data. + + .. CAUTION:: + The training set is perturbed by speed with two factors: 0.9 and 1.1. + If ``--full-libri`` is True, each epoch actually processes + ``3x960 == 2880`` hours of data. + + - ``--num-epochs`` + + It is the number of epochs to train. For instance, + ``./pruned_transducer_stateless4/train.py --num-epochs 30`` trains for 30 epochs + and generates ``epoch-1.pt``, ``epoch-2.pt``, ..., ``epoch-30.pt`` + in the folder ``./pruned_transducer_stateless4/exp``. + + - ``--start-epoch`` + + It's used to resume training. + ``./pruned_transducer_stateless4/train.py --start-epoch 10`` loads the + checkpoint ``./pruned_transducer_stateless4/exp/epoch-9.pt`` and starts + training from epoch 10, based on the state from epoch 9. + + - ``--world-size`` + + It is used for multi-GPU single-machine DDP training. + + - (a) If it is 1, then no DDP training is used. + + - (b) If it is 2, then GPU 0 and GPU 1 are used for DDP training. + + The following shows some use cases with it. + + **Use case 1**: You have 4 GPUs, but you only want to use GPU 0 and + GPU 2 for training. You can do the following: + + .. code-block:: bash + + $ cd egs/librispeech/ASR + $ export CUDA_VISIBLE_DEVICES="0,2" + $ ./pruned_transducer_stateless4/train.py --world-size 2 + + **Use case 2**: You have 4 GPUs and you want to use all of them + for training. You can do the following: + + .. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./pruned_transducer_stateless4/train.py --world-size 4 + + **Use case 3**: You have 4 GPUs but you only want to use GPU 3 + for training. You can do the following: + + .. code-block:: bash + + $ cd egs/librispeech/ASR + $ export CUDA_VISIBLE_DEVICES="3" + $ ./pruned_transducer_stateless4/train.py --world-size 1 + + .. caution:: + + Only multi-GPU single-machine DDP training is implemented at present. + Multi-GPU multi-machine DDP training will be added later. + + - ``--max-duration`` + + It specifies the number of seconds over all utterances in a + batch, before **padding**. + If you encounter CUDA OOM, please reduce it. + + .. HINT:: + + Due to padding, the number of seconds of all utterances in a + batch will usually be larger than ``--max-duration``. + + A larger value for ``--max-duration`` may cause OOM during training, + while a smaller value may increase the training time. You have to + tune it. + + - ``--use-fp16`` + + If it is True, the model will train with half precision, from our experiment + results, by using half precision you can train with two times larger ``--max-duration`` + so as to get almost 2X speed up. + + +Pre-configured options +~~~~~~~~~~~~~~~~~~~~~~ + +There are some training options, e.g., number of encoder layers, +encoder dimension, decoder dimension, number of warmup steps etc, +that are not passed from the commandline. +They are pre-configured by the function ``get_params()`` in +`pruned_transducer_stateless4/train.py `_ + +You don't need to change these pre-configured parameters. If you really need to change +them, please modify ``./pruned_transducer_stateless4/train.py`` directly. + + +.. NOTE:: + + The options for `pruned_transducer_stateless5 `_ are a little different from + other recipes. It allows you to configure ``--num-encoder-layers``, ``--dim-feedforward``, ``--nhead``, ``--encoder-dim``, ``--decoder-dim``, ``--joiner-dim`` from commandline, so that you can train models with different size with pruned_transducer_stateless5. + + +Training logs +~~~~~~~~~~~~~ + +Training logs and checkpoints are saved in ``--exp-dir`` (e.g. ``pruned_transducer_stateless4/exp``. +You will find the following files in that directory: + + - ``epoch-1.pt``, ``epoch-2.pt``, ... + + These are checkpoint files saved at the end of each epoch, containing model + ``state_dict`` and optimizer ``state_dict``. + To resume training from some checkpoint, say ``epoch-10.pt``, you can use: + + .. code-block:: bash + + $ ./pruned_transducer_stateless4/train.py --start-epoch 11 + + - ``checkpoint-436000.pt``, ``checkpoint-438000.pt``, ... + + These are checkpoint files saved every ``--save-every-n`` batches, + containing model ``state_dict`` and optimizer ``state_dict``. + To resume training from some checkpoint, say ``checkpoint-436000``, you can use: + + .. code-block:: bash + + $ ./pruned_transducer_stateless4/train.py --start-batch 436000 + + - ``tensorboard/`` + + This folder contains tensorBoard logs. Training loss, validation loss, learning + rate, etc, are recorded in these logs. You can visualize them by: + + .. code-block:: bash + + $ cd pruned_transducer_stateless4/exp/tensorboard + $ tensorboard dev upload --logdir . --description "pruned transducer training for LibriSpeech with icefall" + + It will print something like below: + + .. code-block:: + + TensorFlow installation not found - running with reduced feature set. + Upload started and will continue reading any new data as it's added to the logdir. + + To stop uploading, press Ctrl-C. + + New experiment created. View your TensorBoard at: https://tensorboard.dev/experiment/QOGSPBgsR8KzcRMmie9JGw/ + + [2022-11-20T15:50:50] Started scanning logdir. + Uploading 4468 scalars... + [2022-11-20T15:53:02] Total uploaded: 210171 scalars, 0 tensors, 0 binary objects + Listening for new data in logdir... + + Note there is a URL in the above output. Click it and you will see + the following screenshot: + + .. figure:: images/librispeech-pruned-transducer-tensorboard-log.jpg + :width: 600 + :alt: TensorBoard screenshot + :align: center + :target: https://tensorboard.dev/experiment/QOGSPBgsR8KzcRMmie9JGw/ + + TensorBoard screenshot. + + .. hint:: + + If you don't have access to google, you can use the following command + to view the tensorboard log locally: + + .. code-block:: bash + + cd pruned_transducer_stateless4/exp/tensorboard + tensorboard --logdir . --port 6008 + + It will print the following message: + + .. code-block:: + + Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all + TensorBoard 2.8.0 at http://localhost:6008/ (Press CTRL+C to quit) + + Now start your browser and go to ``_ to view the tensorboard + logs. + + + - ``log/log-train-xxxx`` + + It is the detailed training log in text format, same as the one + you saw printed to the console during training. + +Usage example +~~~~~~~~~~~~~ + +You can use the following command to start the training using 6 GPUs: + +.. code-block:: bash + + export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5" + ./pruned_transducer_stateless4/train.py \ + --world-size 6 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless4/exp \ + --full-libri 1 \ + --max-duration 300 + + +Decoding +-------- + +The decoding part uses checkpoints saved by the training part, so you have +to run the training part first. + +.. hint:: + + There are two kinds of checkpoints: + + - (1) ``epoch-1.pt``, ``epoch-2.pt``, ..., which are saved at the end + of each epoch. You can pass ``--epoch`` to + ``pruned_transducer_stateless4/decode.py`` to use them. + + - (2) ``checkpoints-436000.pt``, ``epoch-438000.pt``, ..., which are saved + every ``--save-every-n`` batches. You can pass ``--iter`` to + ``pruned_transducer_stateless4/decode.py`` to use them. + + We suggest that you try both types of checkpoints and choose the one + that produces the lowest WERs. + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./pruned_transducer_stateless4/decode.py --help + +shows the options for decoding. + +The following shows two examples (for two types of checkpoints): + +.. code-block:: bash + + for m in greedy_search fast_beam_search modified_beam_search; do + for epoch in 25 20; do + for avg in 7 5 3 1; do + ./pruned_transducer_stateless4/decode.py \ + --epoch $epoch \ + --avg $avg \ + --exp-dir pruned_transducer_stateless4/exp \ + --max-duration 600 \ + --decoding-method $m + done + done + done + + +.. code-block:: bash + + for m in greedy_search fast_beam_search modified_beam_search; do + for iter in 474000; do + for avg in 8 10 12 14 16 18; do + ./pruned_transducer_stateless4/decode.py \ + --iter $iter \ + --avg $avg \ + --exp-dir pruned_transducer_stateless4/exp \ + --max-duration 600 \ + --decoding-method $m + done + done + done + + +.. Note:: + + Supporting decoding methods are as follows: + + - ``greedy_search`` : It takes the symbol with largest posterior probability + of each frame as the decoding result. + + - ``beam_search`` : It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf and + `espnet/nets/beam_search_transducer.py `_ + is used as a reference. Basicly, it keeps topk states for each frame, and expands the kept states with their own contexts to + next frame. + + - ``modified_beam_search`` : It implements the same algorithm as ``beam_search`` above, but it + runs in batch mode with ``--max-sym-per-frame=1`` being hardcoded. + + - ``fast_beam_search`` : It implements graph composition between the output ``log_probs`` and + given ``FSAs``. It is hard to describe the details in several lines of texts, you can read + our paper in https://arxiv.org/pdf/2211.00484.pdf or our `rnnt decode code in k2 `_. ``fast_beam_search`` can decode with ``FSAs`` on GPU efficiently. + + - ``fast_beam_search_LG`` : The same as ``fast_beam_search`` above, ``fast_beam_search`` uses + an trivial graph that has only one state, while ``fast_beam_search_LG`` uses an LG graph + (with N-gram LM). + + - ``fast_beam_search_nbest`` : It produces the decoding results as follows: + + - (1) Use ``fast_beam_search`` to get a lattice + - (2) Select ``num_paths`` paths from the lattice using ``k2.random_paths()`` + - (3) Unique the selected paths + - (4) Intersect the selected paths with the lattice and compute the + shortest path from the intersection result + - (5) The path with the largest score is used as the decoding output. + + - ``fast_beam_search_nbest_LG`` : It implements same logic as ``fast_beam_search_nbest``, the + only difference is that it uses ``fast_beam_search_LG`` to generate the lattice. + + +Export Model +------------ + +`pruned_transducer_stateless4/export.py `_ supports exporting checkpoints from ``pruned_transducer_stateless4/exp`` in the following ways. + +Export ``model.state_dict()`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Checkpoints saved by ``pruned_transducer_stateless4/train.py`` also include +``optimizer.state_dict()``. It is useful for resuming training. But after training, +we are interested only in ``model.state_dict()``. You can use the following +command to extract ``model.state_dict()``. + +.. code-block:: bash + + # Assume that --epoch 25 --avg 3 produces the smallest WER + # (You can get such information after running ./pruned_transducer_stateless4/decode.py) + + epoch=25 + avg=3 + + ./pruned_transducer_stateless4/export.py \ + --exp-dir ./pruned_transducer_stateless4/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch $epoch \ + --avg $avg + +It will generate a file ``./pruned_transducer_stateless4/exp/pretrained.pt``. + +.. hint:: + + To use the generated ``pretrained.pt`` for ``pruned_transducer_stateless4/decode.py``, + you can run: + + .. code-block:: bash + + cd pruned_transducer_stateless4/exp + ln -s pretrained.pt epoch-999.pt + + And then pass ``--epoch 999 --avg 1 --use-averaged-model 0`` to + ``./pruned_transducer_stateless4/decode.py``. + +To use the exported model with ``./pruned_transducer_stateless4/pretrained.py``, you +can run: + +.. code-block:: bash + + ./pruned_transducer_stateless4/pretrained.py \ + --checkpoint ./pruned_transducer_stateless4/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + + +Export model using ``torch.jit.script()`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: bash + ./pruned_transducer_stateless4/export.py \ + --exp-dir ./pruned_transducer_stateless4/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 25 \ + --avg 3 \ + --jit 1 + +It will generate a file ``cpu_jit.pt`` in the given ``exp_dir``. You can later +load it by ``torch.jit.load("cpu_jit.pt")``. + +Note ``cpu`` in the name ``cpu_jit.pt`` means the parameters when loaded into Python +are on CPU. You can use ``to("cuda")`` to move them to a CUDA device. + +.. NOTE:: + + You will need this ``cpu_jit.pt`` when deploying with Sherpa framework. + + +Download pretrained models +-------------------------- + +If you don't want to train from scratch, you can download the pretrained models +by visiting the following links: + + - `pruned_transducer_stateless `_ + + - `pruned_transducer_stateless2 `_ + + - `pruned_transducer_stateless4 `_ + + - `pruned_transducer_stateless5 `_ + + See ``_ + for the details of the above pretrained models + + +Deploy with Sherpa +------------------ + +Please see ``_ +for how to deploy the models in ``sherpa``. diff --git a/docs/source/recipes/librispeech/tdnn_lstm_ctc.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/tdnn_lstm_ctc.rst similarity index 100% rename from docs/source/recipes/librispeech/tdnn_lstm_ctc.rst rename to docs/source/recipes/Non-streaming-ASR/librispeech/tdnn_lstm_ctc.rst diff --git a/docs/source/recipes/librispeech/zipformer_mmi.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_mmi.rst similarity index 100% rename from docs/source/recipes/librispeech/zipformer_mmi.rst rename to docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_mmi.rst diff --git a/docs/source/recipes/timit/index.rst b/docs/source/recipes/Non-streaming-ASR/timit/index.rst similarity index 100% rename from docs/source/recipes/timit/index.rst rename to docs/source/recipes/Non-streaming-ASR/timit/index.rst diff --git a/docs/source/recipes/timit/tdnn_ligru_ctc.rst b/docs/source/recipes/Non-streaming-ASR/timit/tdnn_ligru_ctc.rst similarity index 100% rename from docs/source/recipes/timit/tdnn_ligru_ctc.rst rename to docs/source/recipes/Non-streaming-ASR/timit/tdnn_ligru_ctc.rst diff --git a/docs/source/recipes/timit/tdnn_lstm_ctc.rst b/docs/source/recipes/Non-streaming-ASR/timit/tdnn_lstm_ctc.rst similarity index 100% rename from docs/source/recipes/timit/tdnn_lstm_ctc.rst rename to docs/source/recipes/Non-streaming-ASR/timit/tdnn_lstm_ctc.rst diff --git a/docs/source/recipes/yesno/images/tdnn-tensorboard-log.png b/docs/source/recipes/Non-streaming-ASR/yesno/images/tdnn-tensorboard-log.png similarity index 100% rename from docs/source/recipes/yesno/images/tdnn-tensorboard-log.png rename to docs/source/recipes/Non-streaming-ASR/yesno/images/tdnn-tensorboard-log.png diff --git a/docs/source/recipes/yesno/index.rst b/docs/source/recipes/Non-streaming-ASR/yesno/index.rst similarity index 100% rename from docs/source/recipes/yesno/index.rst rename to docs/source/recipes/Non-streaming-ASR/yesno/index.rst diff --git a/docs/source/recipes/yesno/tdnn.rst b/docs/source/recipes/Non-streaming-ASR/yesno/tdnn.rst similarity index 100% rename from docs/source/recipes/yesno/tdnn.rst rename to docs/source/recipes/Non-streaming-ASR/yesno/tdnn.rst diff --git a/docs/source/recipes/Streaming-ASR/index.rst b/docs/source/recipes/Streaming-ASR/index.rst new file mode 100644 index 000000000..8c0ffe447 --- /dev/null +++ b/docs/source/recipes/Streaming-ASR/index.rst @@ -0,0 +1,12 @@ +Streaming ASR +============= + +.. toctree:: + :maxdepth: 1 + + introduction + +.. toctree:: + :maxdepth: 2 + + librispeech/index diff --git a/docs/source/recipes/Streaming-ASR/introduction.rst b/docs/source/recipes/Streaming-ASR/introduction.rst new file mode 100644 index 000000000..d81156659 --- /dev/null +++ b/docs/source/recipes/Streaming-ASR/introduction.rst @@ -0,0 +1,52 @@ +Introduction +============ + +This page shows you how we implement streaming **X-former transducer** models for ASR. + +.. HINT:: + X-former transducer here means the encoder of the transducer model uses Multi-Head Attention, + like `Conformer `_, `EmFormer `_ etc. + +Currently we have implemented two types of streaming models, one uses Conformer as encoder, the other uses Emformer as encoder. + +Streaming Conformer +------------------- + +The main idea of training a streaming model is to make the model see limited contexts +in training time, we can achieve this by applying a mask to the output of self-attention. +In icefall, we implement the streaming conformer the way just like what `WeNet `_ did. + +.. NOTE:: + The conformer-transducer recipes in LibriSpeech datasets, like, `pruned_transducer_stateless `_, + `pruned_transducer_stateless2 `_, + `pruned_transducer_stateless3 `_, + `pruned_transducer_stateless4 `_, + `pruned_transducer_stateless5 `_ + all support streaming. + +.. NOTE:: + Training a streaming conformer model in ``icefall`` is almost the same as training a + non-streaming model, all you need to do is passing several extra arguments. + See :doc:`Pruned transducer statelessX ` for more details. + +.. HINT:: + If you want to adapt a non-streaming conformer model to be streaming, please refer + to `this pull request `_. + + +Streaming Emformer +------------------ + +The Emformer model proposed `here `_ uses more +complicated techniques. It has a memory bank component to memorize history information, +what' more, it also introduces right context in training time by hard-copying part of +the input features. + +We have three variants of Emformer models in ``icefall``. + + - ``pruned_stateless_emformer_rnnt2`` using Emformer from torchaudio, see `LibriSpeech recipe `_. + - ``conv_emformer_transducer_stateless`` using ConvEmformer implemented by ourself. Different from the Emformer in torchaudio, + ConvEmformer has a convolution in each layer and uses the mechanisms in our reworked conformer model. + See `LibriSpeech recipe `_. + - ``conv_emformer_transducer_stateless2`` using ConvEmformer implemented by ourself. The only difference from the above one is that + it uses a simplified memory bank. See `LibriSpeech recipe `_. diff --git a/docs/source/recipes/librispeech/images/librispeech-lstm-transducer-tensorboard-log.png b/docs/source/recipes/Streaming-ASR/librispeech/images/librispeech-lstm-transducer-tensorboard-log.png similarity index 100% rename from docs/source/recipes/librispeech/images/librispeech-lstm-transducer-tensorboard-log.png rename to docs/source/recipes/Streaming-ASR/librispeech/images/librispeech-lstm-transducer-tensorboard-log.png diff --git a/docs/source/recipes/Streaming-ASR/librispeech/images/streaming-librispeech-pruned-transducer-tensorboard-log.jpg b/docs/source/recipes/Streaming-ASR/librispeech/images/streaming-librispeech-pruned-transducer-tensorboard-log.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9c77b8bae243ec78c33ba7b06d28bfba408d1724 GIT binary patch literal 560358 zcmeFYbzD^6*FSoOp}Ua=NlB#}5di_|Zt3oZK~Myg6cA8EP!K6$q&uWT>5y(|m_cIb z-@#A*zR&&Rckk=@=id8l*mK^q&f06QwRW7f_dav=<7xpQR#8+|1RxLqcnSUiSBrp; zLV%+k0H~<}TmS%I1LzPM00TrI3m^lb{TnMnI05KyIvM~(IRfZ^@~DH?>jX;sUFMG~ z+IzIWIPj$2L;u2w{hVyzGb?@=j5;3;2yu5{`wuYkeLxtal5TPtRJY3L-0l?MG$4gu39+RPwF%#BL zP#Xe(9JETr(%RcYPE+&YwamY+zvch^aWwv0cVL9~TGrp>|1&^pW9w}V8gB}+TiJM9 zyMnj{0MIk6J-mDX0P~v8;_u^ejl)1p?gc6c;#b$${tujgjjjH`4ZnGGwdFycL9j{i zEiJw50f2A@q%-+h+k~$g1-BmuKNbQgVzIX z@JbH)@Q1Uz>NUm(@wS_lx(0|zL5y+)t@{I`;X3#}&;v2Zk49(hEf4w?pak)48_WBu zAf^McqMM8Q!$0X(c8-e5AO?F2T7aXEvMz{kfcUwsw}Q@}Jny~Swg0ROt;)_zUhYqN zx1|?|{-|rt*IDnHel36B!B+WC`B1_j2VDb@jtc-#RzFAm>(N07Vm@zYoojtS-JnuI z4)WLZ-#lKv+SmHCf^=G$Yxs*Y;k^L3#Mx zyRHi?i;nBz@_U?sbo86HE=t$*n;;hQ_S5;ZEV{mrm+mzk)Dzv#!$kpQRl?x+9XIe( z2M7Xo05_1w`|oo1e&4kP{6PA}-_$?K@&b0h@A~~N;SBr)Yj6cr0T=MQFNkfyGJn;k z53GO`$G^Y-Rl6l9(;C#*1e(6{R-`YHbAQZCTKIX2Ko)!c#Z$8U-5VA+W%df_HTWh!FF-{o98ckuUq`O&g=f+ zeY@!P#_jVzr1;nd_<;Qxkazb8^m4R!@L`ezpANQ6%5K(tJWK+&1q1-#`dM?$0|37$ ze?LPYl)eAZ!oe{v?RIr_)$k8Z^%MXqn8D${@DGh+7@Q`+dAd8%+SkkP4|(X<7YuNg zAOc4kHNXI{0Gyyb0)Pl03ETw~02M$3&;^VDbFjZUfc@MH@CRVP6Ce_Z0iVVxKst~G z`^2M-Y368^jj^gM>q3AW4vP$OlLPq!dyM`3mWQ3_&I#i;xY-KI9Y)iiVFy zjz*8hj>eBBh9-xmil&QZhGviEi57(R3@sk*HCi^>XS6D`uV{T}V`vL#TWE(+0E!2t zgfc^Up<+;Zs0P#+Y6ta#K88j^Q=!?=5@;RR&LhwT=r;5O9TS}loe7;6?7N5|aZ{6da32m`<3%n6a4eFyWXjm_wM$n1@(cSkzeDSkhP;Se95mSdmz7us&lo zV+~`iVxh3{uo zz*)sP!zIDxz?H_;#dX4cg8Le`1h)fs2KN^p9^OqnaXc+NN4zk+*LYv>y73nAPVh-KoOD=@)0T%+7N~kz9p<8 z946c)!XaWKk|Q!F3L;7+DkmBs+9t*!W+#>>wj_Q`{Fb_Svc88vJSFUa!hg#a%FNS@)+`B@_upz z1u2CPg+4_9#aoI-iUmq2B|D`ur3+;OWf|o--mtw9 zeFJ`D?8YfI6SV@hGj$?$HT5hFIt>?%7L7kmCQS#;HZ3`=B&{`VG;JB}6df8J7o9d; zFkKGaH@ZW526`oW5BfLsZS>m=lni$noETm)G%;*2k}=9KIx@auY-ZeKqF|C^a$$PS z)Xs!prel7*t*z`**Vw^*rV91+1EL4a42&GaTIaPa1wCJa(Z%RbB^A^yd`$a`PREzgIs7_ z!dwnqZ@C7z(YQsp9l76e5Ak5|i1WDdWb;h$;`7S!`tcU=F7i?GJ>+}JSIxJ}&&qGi zpUB_Ee|cN@w#)60x2Nxr-ch;p^iJKKeE}{3Yk{`{qk;s2ih`knHG)VXZXsKtOramb zWWs8~QNpdl7b2n}ULqwT8=|bDk3=&>C&b9bG{s`Yy2UZX?}>+sH;A7|h)DQIlu7JK z@=7{O7D}#5aY)%peUw_3zA0@f{Xu#`hDqj;%zK#yS!P*F*=*URyDWEY?tZ$vCU;BD zN$#`U?!DXhyzW)rJC>J_50P)V552E+Kl*;Z0;z(YLb}47;!Q<+#Ue$7lCV;+Qp*F3 z2dWPeA519ID_biUC?ixvRYFwSRq<7IRMS{L3dh@Q_okgRi8-TRKGy~ z*xdU_CU z5MK9;{#n?wugHlg$*8R7=+7OW_rDN)@#e*4v`utx%FtH`J%tBkIKR0UQcsvW8qYK&@zYE^1m z>m=(c>-p;Q8kig2Hc~VuG~qTqYr1Lux+$mtG&NNwWGUJ zv9qm9uB*9Qy1SuAqNlc3thc&Pq_66m@VCnE!rv?VMf$4-Lgs4;}2e6 z?pEyG-TQ_xM64pckXQS$zo>uZ9S9w?9%>%W9l0Kz9Y>?6QTZogCtaror)y{a=UC^j zFK%7bU8-KrT)AFdd02W`{*DE%*WGA#;3|DT4*&>^0Dx=|j2}$?SPTBgw#>)z!sEa6K^z04<(>tU<4tF^K`-z4SF2yI~V{ef@X( zV|@xz{<85uub}0Ef+tNFY!W$W=SQ2lBx1LJzb0TKe<2hd3{NSW`*VUlTCV%_v47kHBR0h{Gs z=~oKvAq1E2}@();Bhh`@aqjkB(6%r`L8t0O%iP{cYKQ*hK=`g@%p} zMaR0f3xehkW+(|d2J;vg~ie{;_Ks47<^OCny>k6axx{Vqju`2@@NP(l9Zxaj}0V+&>fk??iZ= zi2hDjpb!Ws105Y53;ZX>!@(o|zn!jT!40a4t0@2<3IRJ4lmw6k4tqJXnxjHkQe^e^)4~K#ov06=nByD zvAmtaJsQX5v~vYmu+}v_|Jkv$F__CWW|rLyZ!X*PdSy^~5=yu;kW3Zzv0h<8m3QX( z&aR$}7D}sr@d~)WZ9b8f_Lm|gk(9EPtt&b&s61xiTEb1xizv7PUR?pkX;RwJM~(Af zUF&*RK+nedEaD2djIcP*%uFbF1X4w^W(}@@6s)5Qi&ON1>b&D{SlFgggJEe%!x|I!z_I40pU33q8>cbsq!kQ^^XE zy)+ESrPvAnm-heI#Hjrhw8>Sr4t@o25$&8dAK>Whm|SLH664KAN>g&bS~CzeZK8}0tBz2d)Y3Q@(n0^}+#H=ED%Wve&A8rd|OvaSH0 z|7nzey43#<8|4?`90B7|4360-<4#A~KRcLq6_!$6%<~(WY}b*$h53M$g(Y!J!}?M_ z(3A~FgTP!^TyH8tz;Dk{QB<`U?m8Fw%KKnb@tu=)-uG;;)$$NHwaWK{MtId)#*hNV z3Ij?@fgJd6&w9IV*_CShS?_n9(`(1v>{*%jM+SY*)q=!$N{U0Qvv*+}l&Gk*mCIE} zZ{177yA`sD+$3j^P%O~{NPL%2%WeHsgYx-aD1@c&?YH-tOU%Yo6thHD@(bubS6S!ddH6HeveYu8fIE9m9^ zE_{-{SWYKxx2*Wl<~%(#Q6-wQMDPn|b~Nz6=Xzee#l1L~n)OBQdxvAH@%re%L_Pzr zWfsJD_o&vf3#1CfIxnrRWG?CTMx6kk7Dw7)+d*>=d z;O4+l6QJ)gzWwd3Y1xxBUGWnQr=55+v%!E|s_u92;fwfNc0v@s5niS{=CiIdOQ3s~ zB=1FYzDN6?{Qo{5G+MVv1*mq!N-VY8|LC2BbHEJ|ROit$JXb(9(FM)ej*UeJ7$cd3 zPvwn*EZO$?-J)}!=7SYnA1wvgtwC^J0RB5NMUyxxbm?W_*HV@rCu2`cwypryE1<*T z2ps%5TN^d17sxAMqOSSW$v*L2C@$mSe%fWmlciT$3RWQje^23NU13s?l5>R##TVHa zYB4|6gx-*J+z^}j(ZiY?hQ)ejlVFRi-tK~R+*T0Am z@4})N2!Q)T+r-Y*84L@_-#+sX^4YEq<<6v(eDURm%YBc^&0KMp%B|V$CO;Xg>}c** z>A}C3W4TXpO6*jF1E!2@nDvM$?Ks{8UTQ&cQ7u6F=Ee=qfhFhdKv~;8Y{wpQq(P#-(^{$Nv|YTLT+y}lA&?E|g{ zkJ@BjcfpVghi&PrdI-HaMiN?O*zqr=#|G(xor8w4QgEvMY4r~ZsA$Zd>_d0Ci6*$-1@@=KU4@>AK zZgxJh$%+2B{P;{(vBKU@FLMD|>eV`?UF|wiI7_jf9lft%c>P@BY8>k|Znq6>2}&pawpFrR~SPrAW4p2JKo64H7rWMA<6$uwS&PK@^zs06OmyzINx zBbzI)KyixCl>N#IyzeQNo>Hh%Y$VFQJ3ad)JDMx@_oQ51QGB<6vYdfvmZQ4z(a=bu zY0WoQXSr6TA&+L#C4tYTo`(KaTOZSOuYh8Py!TqV6=zu$)7;hb=Cqq!JCD7Mlb_A7 z)ydZFH*}^75@}E<%n2&Pv$keH3%>2)a%P9gUJd_eL;xL=B10y0W&*>kLsC)*-ildVK+y1`<;R^D5~&0@~GjFXI)axyGe zK*sI?RFC@Y+mRGn=~S2&R6LvpaTD)AOa$FUZg%N$mYs}}s6TH9ep z^0gU{^0=;6l9-sP(%1p&qh6TmAs7KuQ>f~tsOSRq`6VQuliHC%zKUc3k^QpW{%o$Y zBUQ6w&9GpwV>~%eeLuu3H8pzol>W^-7`tjj1dQ!QVG1wBkl9wCOkcq-%$zf6nEuG; zDP%=Tjjr~$muhrLfb1kes!U(9xluMMR3C|1i4|2*BlZ2>>NiI+66zTUE9nX(mZ9=+ zuWeTT=y@6P-K)7<6^(VnB|k}Kd0J(8N#)} zaR&N0>G6k(a!3x%YL9#N%zROzc={<)>I}ln3<0fL>B>`T+#S;c0~TFLQPfridIeGK zDjBOiw>mXli#kLdxlaNexAMMoEabHh?ChZwClu`(4fPkF_L0P~_}OX?yZyl7@(@gp z$52?>3vVWtBD1_GDsgOHSGnu>!8e6C^X*Ya{s1)#^|J5*e1{$>-!Vrz?sEkYpq%#3 zWW;wAzLhItrLVplQ}@hCOlAcKLP@+7EzK zG%?F^K7Bg@%U1YKWm&Ia!H(5iwf+5dxNFJzT&5DsGX7L1ZoIuQ`NgYrH9j|2Gx%g6 zFKi%8$d$|hN%(;w<-?S6u)pYRJR8r9FUKI^3PN%M-g71nGm4+sVVv|Vha{8Ev|4~` z>**7kL--`S8ihqWNvT|sXp*jo=v&WvNy%kDQR|D0qtCP2dq=SdgP7(;rZN3v{gK5i zpTucKk#R#&+psvu{#SDdC+;C!kbBXEv~R=OQ|05ybTeBq`a3u;?JsGmfoMq)}(BHLZ@uWXmqnJarb)>s@F){{t4bLmP-du7_>R zkQ(b!@h#D^31?q0IM1VRuQ@>ckkEbQNulGaJaLgjtT0#9V4? z4$4PwZJ@-y^)VrmC|+1(=v>%^uVk40^@l{36Y?;1(g0p zCZt)Eyy#hY*)rxX9W|CVBhJ0I3C}#-c_f?>YUMJSwVf+)00GuHWt% zS%w0){7wAh7MI5qlN!%F3}0A!-Y`=AF_y-nJ~Q;Vs;X+|B*Ry%o;E6;n}U3 zCwL#*`s@`z*5)(tgpsV$WUKFnP$I3NY#MhDk2Ax?hpHmClNy%Y6~YyRv8?iM$s5hs zd#k26#cu95>=SEAG=5REz#LIq%kKE7sb8&gHSD&`!~Ba`mv$-n7{z_ zfEYvgnS_)mBMzK()PzLU*U!$ftGBm4fngG!zi3Km5!ufc+0v`-89SQsIW^9Be5ZK+ zJ0pLG;dUEYau-Q_tgbDmSpICIPUGIZDu1#}NYzEY47US(0${GNFSkfc+cF)%ERHg%q!b}r=V zkfQ01?9Y&-`PIh}1_9!B0m4t)>PUaqdnMd-J~8k+IyuRrSP#N1Uouma=8!8-I^BJk zNN3FfE#P<$trbISbXq&|L6y2YA0_6pe#C!0ymWJ!VLD9qu388d;_-OfV%Jy6z2^+~ z`a=EjCwMgUNy=I8Bwx(LW`qWu84O6s?m&_3ZS7iHqNh1?)6|gkv6TC2)KZ1{P{*!& z;{Af@CloJj10=}V-t@u_4Mgjm75VLUE3bg4lMC;GUk`1AkbNwKDjI2XjjL>R`qP)R ztIMK(Y5G(GlNR@mj)_YR;uW|y+w!>rG3%SL&bg5D@s^^G8AfPK?mqanVjIm9|6VY$ z9voiuTkBWAR8E@&Y|1ZTxsMYmw99Mjc?Iy24EB(?dsaqD@-pgn1KyRlDt;djj z6`Yt5vhx`C^PwK`&Hfm=g(2uBl z*3g<>p}r#5x+JjT;OD}NZ>&akD#q{guCl^eDDag9(10}nGF==bunAq5-@XED^Vf&+ zN8KJhz1t7D0#q6xZ(jTVz=2*25^#Ea($OK&e#e!(^8VwdsH@9?*16l%#-~+{=yIok zAYU$>C~12NHSP`HqO?VuEHghT`^CLhS;hd`CHp90`d~F~J2&cwNQpk0WHM{NVU<&~ z7PJ8JfvM(86~ebyz@wVx1K61JLUpGGqf;oc&;_39=v(cwG3udVThkmD9%^hLmZDSQ zQKsNiu1j#2bzjqG{tN>Q&U1bim#Si1mW&55xEW|Ua+^u=O$k{>1k4{sM|XMs;&L~f z`)+nTCioL~yMYH=@5CGtwtT$P^Mp1CSg!&Q4_2D>5d6!v3Z zJabj?D7Y$Z`)If8tKGU1&iPRq%gtyoeZYDHBus+NIR2B~qWP}{}EU@cdb0Ua_-GK^VTR#zsQM-WE;8yn#x zV#P$UnRGE7&5Y;Li<7@r=1fd@Y8FmZS>a6RK5TLwA@3fFP*%xt1^Nr;8gZ*ksC_K< z%`su5)*y?{$?E)M-4`(AhHOk%tDoG=yaGHInut*f%X@Uj7tQxZaLQSqNPR0+Y^S5o zea;&`;UH|#$%xnBysU-p+MDI-!VYrT+QGNw0!fxGG(({ zQQmhYNKcjs;a1+aydzy`L2FU(Rk>5SGZNBc&SXtN@mh8AjQtYt5~+C9bkv+uCBsyM zGFXnWz=6lJe$mMhQWni9zu#r^j8l>#=8%Pjnr`ieaJNI9NbVA^{>1&E#WRJ~<>yNe zBl;M|h2o90o}tdH5im9a%cdeKj{JG;Dro6I#c5QZ%#@12f7|LyOK)GMf%d?!&P zg*%bLP2$kVC2cF{W2Nwx+#$1Ht(RhKD>GN|pFcHRWrK8Tq<)W=>kQ$33D-l7EoGrp zAG2&uAZ5;+4bE8UyV9`w61TtCHq-1EqGp!ja03coA->nJ54-!j@{f3Lmj#V&S~`xl zfBZD;%A0anZT2*p6E1WBe*V{ulBO33cq>1SF-l~*tC`%lz^7PNQ>^moq&IhX(#LmY zAqzMFKQu{6?pI}q?4s;*3nG~1f-`>^Cx1Sy$?N80JC>k#{(>PF5P%_Y_A$8*DmK_%`?pD=*)ZW%yijqI8ysumha-Iv5Gg)la*k>?{u8rM=dcH_*wL;-b3? zaR>%uGfD)UbCAj+7G)DrR%MHPB@_K!wB@Y*-dn#o&m%%akE9vhC2BH6S@P7rcgT9V z9mS}L?rs%l$^UGw;ts~vZnFv;Q;Eh^3r6U>DSt|9{FG+)aGLSi6@ViZkY?=P#!osc zQtD2f$lUahwXnhStfsrC^FhBv3jB?`%dXE>&%xP`Z=#gLLfXumB~i}MUF;5v!O9Ek zHi-=5{&1gh|1Xg^Hy^#EqmUw2VmdXjTyUki0+#ICgK4GJzPG2x^SOFu>O(Iwn_~X z6b*0UpAK*{UFy@XHtg9P5-J?3UV27z%E#X0@g;?(=D`$6T6tPV&QbHu>ItJ{oA@}F zXR#_3cfZzA;(lz?z(w88aWc#X!H|O3>gZd%UvqQI~P?WJy*^fSz zJj%HOlF}G20_4EsJ!K1Vz%W89DF86|=BZElj;^zjMrBCbk!1sAK%re_cCXSPQ0&v> zfx%aer&Ipz??ko7g|~`gBuxS-@w_2-%}w24d-~m)26?6v;ad`Va}h;z>(yITOP!9) zFGDM7Da4M%mYSN78K(oSSr`Fit?3k>%jTxl(^HOL#ZLN2ri(DqMfzDDXqgi2@Gf>2 z!?Y}yd8M(-k@R&v8P>-`h2e?QoSGNj78+R|hq&TFabi=&3>FS&zre%9m6D;8EWwvZ z0;kDvU$!8<+Q#&qZZI-P=I&c@o<3(qspbX{$swcSd>a+p&a5~GcAwS-mb~o^`%(`H z3%c}uQG8B~_*AsWy?5iB7YW@GIW;byG3kHHhMVI}^70vRt=O$!kgo$+6Nk;{dGRth zC=QJf^9dzNq#XYnziF*zF9Qd0x?ky?$q$E{?cq^D75)ZYhGSA@vjyh$Y zp4t=J8xFJcZ{@pWXki?~=gNy)dvvG~Jqg(;9>zX~9$6tY$gaQQCFh10qMH7D9d6*C?#+@~t z;_MA+Ks9B=3!SG_KU-|^ai7Zd%VEVQg@z~4#|^~DUu3FOoIDmg?v$Rx6!S0MM;CfLVyc2BXRGdA7I#*}+b((lp zB;omKSikD1aW4LBn_!^;JTs=74v;mpqI7<~ETHtHPMiq87K@spk%(g5c;cxH*}6n* zx=cZd`Sffj`AeOkxE7M6Qj)>bann=&=R&eC<`^oG>>n<2eOXiYdEUN3MCwDl78tNO zkR@wfj08l{4c5XdLEnpuX0S)^b_o*XW6gH=duHYOEeh0Y7j1vp)Yd7}UN5fNGONol zY#2Ut3^tXIHRYQLsyLq9XqNA8@Lga>BuzO{8aU@W`Y+%924H=al~NGuSzF-Wu3#N|xm8~-K`CN)ys>tprF(adl4=_AM? zzVglpbuq!s;kDK)0Nb5^XUVS_N%nlgDyfw$g}Tf2L$~Y>?rNYNqw)TkFlKV757)FV$}g5Xvp* zmrQ@1LkEZG!d4aqD(UFwu%uDC=?>g*+%6(IYIYRNH@p?T_TD-v zg74pH|2f8hma!^hbSu9H&S8IPFKG`3w%{un(fFz6L`eMN$Aw1LwXtD8u7EJ+(`B)( z=oSQ381K2llA2Q6>6w}{+cOu#`?~a-i`Z=8XyQ9&)z8k!@&d?=W7_9fkQgsj*}M{4 zSKq`Gs7IkAGyqbk`PN1=VQ?|VG=5W_&R?@WJM#mq9V&y`G7}Q$6)cX@+u-cicI=R* zbFDYQ}UN1?slbTa*NOD_FtUHr3oU7U9qA0`W6%y)!KB-Zu%)3|)QaBTJcG7T)rh zI$Z7bjImOs)yr7$mkGV?kE~gy8Vbd^@}k;x^Ov?$idlCGw+HmAd^W9aA8ZtGR|j_# zt)Jzf@keEnB90Or^HJd2FX>ZDC;naTZnckY6XHyru&_u0MAX7@OQC>cRTHBHRe+MF ztC9+qOVgui9-DVLQiLne7MAdCDPPB>Q0V!+VV;V48a4K`LL>gvH11>CFadRJ>dps` zZU-?qry<>+p35!r((cgMeV$5)DtXKHKJi(g%cD9>F7GfXS|HtW^j4>($8KzH)W=*k zr6rAd_30i9=e^wG(Uq@DrMMN=&kL8e`WBsMoGZ+#tET4Gz%v;N#M;bXN1P0A;Q`za zhnpBT76ZslfvKkY*_^JJSFYW^jy_)qEc5D4Yp97*I8K3>K9YlH0^in|E^F`9PI4u z5ZIKs5GSaKO7L9RuzTamoa9583;U+{F*d1YXMjSbuf=PNL81{6pLXkkEBh;h5FYmWCVT^yA`Sdq;ubrzu$&dnE)qSxSYOqB*E41AhfQoX}7GgT0PKDBlf@Z7~l zx8K+^E&FQ8g56(O5%M`XbXI?G({RQ#k5kJ(tt>|Mtp7B94Sd@MCri@|ji1;o)PZWh zq4j;+PkUPL)h81dpMxiE!)QOCygR5Q7!%xWDytsAy4kC>VYA=RGF-ecMeXCUJ41qv zV9X0dw6oI3;g#_13q?PHkW|Zrj9(F0_~L7;JXlWq{1j6GF&g8nUyxE|-viyyB2yKU zZ}n(@Za>cxXIxIt7|w8%(>o$@Iu)z$C>b6IuVrNuKqpqXyCT9xjIDOJ1;zA7@0LpLG%~?`65u`^;m8 zpt=5G(hqv)Mnvd&qB|*gpx2&tiY~l1kAE}Tond|BHpczj$k2x8-;8mB+axPB zIAPAZ$tJP(^fqIPHv>Kbv%1h^ax0yo*4f&h|IKP%dXGoFDvQVa%)1=xP6dxPMpAgB zIJx4rw3D~suG;z|qv{Z3V&;qT!dMs`(;uM98zPdtt+dVC+n~R)__2Pmr7zLk=TIsRNZZA zodAb39*KnQo^CR1DyBug~zFpm>YSQ;H|>nrQcnR;Vh) zZSqW*V**zEdQ)UgCyF+LeheNkqV?SAsNCfJSiAAUnpPQX>`3Q4?w)*kvgaRO7-$*G zOgM5JOJeB*B$y$%C5mDjmvrEC#-st9;^bJgXD zR2uTRz5<(6@$Ow1P}{<~obDbah0AlvJ!7a|Q#wo!7nt|&WfU?!SV+wJP?FUVG=hp* zG|ww0^)s_bz9F4sbdTu@XwtQxE9>(z_Gq(WxG8h0y}G|Td$Ggvuur@|N}b+vbyt+Z zUB)$S=B*eb;koQ!R}>4}C$;EmerXo#S@8UAjXmXMq*^O=2~asHx%$}=+>70=b< z&ld+fMU_VyJw;?vdI1wj#d@u#b2}3?V+RIRJArxL9S+&iDDU0h$E4u41d$EG|B2MC zLX^Z>$79yU+KTvbySQ(XDiO3Eo~3Wa-NuBn2w>t%c+|U>jLu4$nUns)6EUNgKR74WW zg^311zs#yiDytV@QiV(4IqY*i$C2c+k<=az``jylG_9=}#|Z4s+744WeP~K}-LMAk zNJXvi8bIyFcUQenU7{RQWoipkHY`)gF5&x^P{gXnkoq zTz_&a_|W8b9|za4Pf8$%t8U7tcgM!YvkdD;&oAlE`IFm*19F99mOr&m(>%nF(TVqP`jieH8r_iWcqag?>c#adAnqq@L-+(?xYNB?pcgL52@7#_$3^?e< zLJv>uW`Vzln<2Z}*5(8`_7cpn-!QfBP_tGq9TD1#O7}Za zkJbIoT{83KAt3YNg>;>@@I=OH#iDvcp;j!p+53^P`t7lKH9viE0hiv7o0_^sDsC#j z)ErW()%!}`mS>_Oz=*b0KZv(8OK7+%m#pEb_D8f2*wd9t#|fqruEpTf-F`9ubANNu zyovNm@E|}(XD#c;G9tan9~ZihkI1_S?gL*R?l4PhFBbh;y^&oVT^T6U`M{)?qO z%WZo{zf-I@H+3+=)PPOwtngj!yQI1a6Zn8P=TE&nRlEUh^Iz{y^rp*4RAu|Dd!&pI z#~LFrS*jQqe?h3Dqxx6FvK^nEDI%9&U(wNIQ$mSRU!;IyYnxpJ!)u6GsRjKd<$iS!Oa=b#e~x> zs>_(9+Bo~MPsDEZ~G*^lHIlC#g<-GrK`(vFf3X2Fa@b%O` zPPho%ph-2fj-2VmXOnj$EDEo=-AVIlup%K6?#zYC(yy;*h<7ArePHBwt2N{Kbn<-q zMiW|S_gDz{=9C(^@y`OM9t~Y#_Pzqlu8)|7S1xxW3FAWTK29h3H>;Z1RY|+g8FcD@ z*rC(bavI+X-pyQV++rlWq*@p8GdTTJ8bVIXVLNwpX9nRey6WM@mFmgd2?xCp?$>wb|-VSSDdcI9D>(}OgydX*zj z@_eb@OaiUwxaGOik5lum!3-h{xxT5R9d*NA71fm$t{hdT{$eJKxO2`j+-YqR8sGY) z5~ChRCzfLTc=CftT?4IABoqVOMcKUqzEoO}n1(vu`mq~^op?di-4eWt+Es8pXNwkG zhLf8<5-NRj?{x7K_&yi2h1dX3HJP&1wz2Y!aSW;>tBDl!Sol7$H%%-=4Xsqgmbq&sNOk=%T(@>G36e5SZd zSiYqRJFU|=qFv2{m*?{5WFh?PFTWzPhOJ8HE&xGZcVUgK57*duoJ(5zW{r2 zezKFkPw)JoanM)w9iygZHDz(-7cU|YxdW!Bpyk%6kj0S4^7-c?NUW4V#y!$GeeRoy z7vm>3n$jNg#z-?;Ux!20=au6r@MYJeH?Qkg6~(R81L(X*@+iK(TRd^>)c%+~J*?>; zV-Y;DJMv*R$x3`vf7q}Yi-r{~j$9Dy6b!uGDQUue1m$!MZ1tz& z7OS{@=Ovl01XLqwIaj(X8fAb4_ ze@Nd`lB|iCZ?$i(4Nf5IPpT*}yU*j%5wCRG)LDJWDVH}56GWWTJp8%&ojq288=uO8 z7xAFwQg3Uhhy0cmhgQdwgt2lZxQxG9;|gL_cIAte=H`fmmCCFARLHpO3s`9IG%Xc5jiWce>Y3nTliJ@)|1Tktc>lZA$0i zJR~KzNSj58g0`j}O{5;#&5dKRFGh#I)_wGd!I|-1^-vSJ!)fJQPSN!IrfBK&rRr+e znS8N<$9YM20m7A~#@-V>Hl3v08=Yxfu;X%#)QiK3+yv5iXT1^cJY(&mZ^IEK=|8Jo zg|@tVWf%hFq-^1#_Op^5j`(Uk{jc$X#k)M%Jyw!F-wlv^5lZ3bN+QW*QmL!@56qCY ziBT{>?2*!MTe15WGp5NKIw+H`>F#Z%vhjy% zdi*=v%QjjLxnJ#D_a~1##i+$u4GO}07rAt4KTu)O_Nz@6?;fu=Bydxws|!2l8ml8XbYrWe;X+a@ac%ao`0R@wK$o6V825qK6+=53@ zUE{_Kwz`J2T#4xY@=ylUg?v|)wBpV#o7VXqHvzGMk+2Zw*##uq_pXkZcU*>a8v-pZ z;_|DzcTG5|d23}N8$QSIw0&3i)m!H1p_**2+#lJKHa!+>mun?|mrF6gHp^fBYhqW% zbeNZ=dCcWA8=Lk|hh?kAyMDNyndcLq&s9;D)(bl#6S~9f=>g{QtzA{OLP)KZ@?qo<`=q|>o(d^NNs`;wk3j0=CKBa1@5P!=0*y>&Y8m;y`jie z`GfD~YRMx7QE-*dp@ENe&m=|N-|Dl7_jBpAYZj?k)K?3djSV;VNM`W{$hHkEWsyue zyf5pL{HbVOL=)Gq7|@SJy%)VL!-m3BNBI7H>zLCg{I$uT`&EHB#$Bir+;n`Rt!5xX zihqfS&a|XwS;yYxhWF1TL2#qQ0lwn2*ZG32N`@-s#tW-z9@?hWAKoXUBYm80-0xHp zFP6)j9cqn2j%;fh0Bld!c^ab9njHX;u%-BBE8`nB#O>M(c^XilYzf~F-#hjMbXf7 zv(}Z!IQkAF=rW$yXc{^x(@76{v;||+7h%qe{A2<5I`NEzshS4!#3OS2_WkL;U`nN+ z@3&E7{yYUhU$P+V*wfhDI!6M2Smt)P`<<|@LS}go*npv;o8=mawt>;nMdn%h59`6vffZIst%8%>p z5{fOtWMN;b`J{pg0If*>gHJlXqNEOxZOWg6bK9Xx#77jtOZ zM)&-lp#29K?W1FgO#@09+Gdo-;yO$#U*?_D=h(8FRP1VHKn?Ef3u#G%I8=@jf+8}b zJiy*0O{-pj>-C(-UE=W69))$!`9|NLH4a+2jP~m9YRSgVCY<%tL}x$rrC8Y;|1aj= zE2ybI`W{6Q5do1V9aKO{=z?@ml#(D-Aaq1Jk={ZHO^S5sN>z%8lt>AoBfUv)lF&o1 z2{k~#oA2+=f9~A>+?V@sXYRv!J(D?S=X};)d#$z0TZ%!M8xYm~Sw;Gww&3sGXQ9X0 zZ8y3xqEtEx2Nd0N66tQD>`Kpm=BF(k5%?tY7_gj~{>x_nNSN{8{VU$T?F&^gG}UkB zYA%<3FL=3*2)<+G%xBdwrY~CSg#1+1V02kp`}P@4rEl`0(o`9)4$>l_$ldu&43@1H zD?GNR3}__OEALbJ>@-P9VxbNZ*W>PEB4Cd98`Z&C^XHp}ukCr?{UfO{I13suwNUqM z@vc6c#CQvfE?e}Z-$Gr({UO2V=DtFefi`h;_~ZA4dPPV*j;3=Y@|rP^4d!LsicIkl z2_-q#^~lfQwo(AM5{QuV({MTwcMcK9kTW8r^uVQ`Q|Wi4?-|qvENm@WiDbxC;hn2` z^ov=KKtV&>wTP^`JeVuA75O0bcA3Y6$ftA0fZ%Q5ZH`RRA01~S)+PpeR($$F|LK1Je{r28-3aL`Tp77^Ckl5fEOjiamPF4Pf=vFAxN==cm7)6D z;@(`h5QR7sZ=VgK{mU)#t0 zLVunglPviufTb-&BNVj!mj034M?u4fZo%L8dW!${eruSiSWQD-wcb=cVttYoVtUOb zv0g_&&R}K5uhD;Xh>i)4+!mLTmB5!!>7~z8*&Tc)Ef!j~Vpm`H4x5*I{_LU?=7O`! z{v5ZFn+)RU7Tb}TaX#njQH}k`Prp-e79HUC7Ej=*Spn@~b7iQ}gyHn#~dOFj#3G(BCX;Q|#3CKr&mH?&wmb2e9M8 zfbYSAQapBnj3xKqGAfw!YtpER`51t?oA-ul%=c%^HW9mOTJc_;1JxCVBeTc@fZFNi zx#0;K{0)I~ShiKV^(-qfjt9wV*JER!R;J)Zl|%PUjoBOXLw3*j`Tv7uXQgO%XAd%oQ;0mR23xtA;Vg~m6P%>s@yyzqljBr$yYyV zq^dt^q)LD>5yY$FB z<%K%W6~J?XI3Mc}Wq2*QNC7QaQB=|9)+{eCz&#jAO3Ln^5!-^ZVb1VJylQF4}9Xj{!aHshwvgo*nT=af&TA&+uK`hcc1xR5e z+5)!CkJ$aga6mhAUK1;~`vh~&r^f|s_#jfbw4hfyN2$~w($yy9=y$$s< z$XA;}w1C)}-rf3md)HOBGadO5s~)h_8%3ZfBM!rWM1n2ypt7y%}ggbM*p z)Iv_CgjO2j>^ee8_v6Kv=tooJPT{{sJQp1jA}Op()lz>J1hW2&xq zdW&MSabdTm=|va&OGgWNS4gk)q}ml&hBz8`!(1QdGcQo+Yb(Y)u3DkGSrSM5y_-Gv ze#Wl)A--4dmhSGnqTA>ft50;4j5cmD81@*G-Px89=pJ(*8$Al%FRGmXgwr z&cSuWCmr2sq1nTlu-~v!Sh5JRk!b(*tb*B1Q@bgm4_bOw#J zWQ?R+Ss>62J4z3+H&Ve`RDvI`6oRzS0rE{md?a=rmK--R`~mf?a<>qUGV`2 zju?Hfovr>T!QSnk~*8KC5Q@d+rT27U54GFCE%}N%UPYu_EAFr{Pe9wiu z(Cb)8IVfWnf49=%n4+bU?&E*1Z=*N_GM~)B=Kv`h{HyuJuCoL=ycLF2BzMW`KC!WP zE`N{f_Wi7mK3My9f#LD#b`Ih|j_?Rdx%m0zM?BHIC&I<1Xkt*F9GJoEt#$U5+P5P}iJmLJ#M#iF7gwYqx?MD@!6}5B~ zW&4_%$hX{6n300t#h0p8sPan z>~qW)sYL07ItxoRZfeL9cD`WQsEstaDN#bv0LqY+;rTEY>Cf02U5@|I)2Y&?7N_6R zY@0(nATKK+zH{cU90_~6qS;38l+Pe}Ys?!mE2(%-5+Tgk45tOA2}OvR-!DVrOwaeJ z`o)H(yD$Q>Z}bhuYZ^TR`qAZ&;oBpSm?O~6>}aQfUTzM$H6wN`0HnR}5x*K_4xw@un)BJr| zi^k>zcIh0hHdfK~$K8B~ zGoI`p3GY3ci?jT513TtFO4gs31q?b>q_gzb8uhg**j;B}M*m12*9b`kJs_vmXIulQ zB8KcIfIGxbug%AC0?QBL>2U-#d#S(1mrU9ASf+B!UB6-=K>c_DRn8Zehp{vdu~qbABp*`a^uJU zNW6FZbUZ0~RDN6AW zng8SS_6a$j9O0I`W>DT)6TOlp?Iz;0t_y5I>?u54nbI2n1#Vth4CpN0`k3Ugwj^I0 z5fM+H;{c(9-6rSbZW9`G#sE60Us&p+FAhy}Sn}*N$yK0P`(ug3S#$O%&J;BoBwq5riF*FBrOIscl51mS`1$XGX%mMSxnZKgJfF zx(5v#&8xfAxojqVD(R`|MyiW+_D`9WRvNPHuC-nkAsl`Z9$_t#+RZehag|6rmmkou zZ?7VA_;momj20i~wEvN0laho=KtLUx&+!z?S0UhOg%?kL`hcIdQ@7t(b!T(_xfWNG zaXv>PcH`jr4vF-<+Twk;3%jlJT>i%IJ(KiJeca7N&KxIevFTGk z19Ym3mn&)(*-~dgqXdsWsbb^>GsU#k*8E7**U7Za+Y*L1Z4-GpVn$%j|Il+n_qDtv zeWy2VGmP%U7V>^cBq5#5ifu7lfe$lypVgjPjrNZnzKA2@WM~RPyy;Cf< z%;@L3t7%bQDbF#Gdsq2p@E8RCy7d9R#KFbO8ul?^B;)4I-Xx``>45wd5|?dtkL{2L z%}>8W<(D2*E$^|pSJydaJ>0ARk#)t3&Rf=e>yUFJMMr8w8B08G1c+kFeKHgHdS zH!rJZ3Xg1Y&93`L@+>I&qUN&>mr4aYSJn3t7NB5qns*4DsYx-cRpUI7sI2a!nd|sZ zTHP9XmEX#22E|zBo6LNr+icWeo2oY{u}*)v{Zq+)~ zj{W0A3t8*brfL^A=9A2R*i75Y;QoIk_KOGyfUsfCZjJSyL!5S&Gu71stkN7~;R?*2vbn>X?4NTr;Sa-0%}6%?{U zLwI<9j(4rAh8S-_Vjzw3Jr9hwIigC0B?N~;ZI+Tc^QpJ=whkRnFLExQXRPe=epz+h zm5!p=Wnw;4m_b=%{iA$X>zT((R8=f5K{F|hnwonv2MsuxczdgqfeblAXCtqq`s|j1 zUGy~i__C2xcv*)vh!3BNJ=4UBChjLhT=0nHj64i8M)PH`%HHU-qaf4yrt-Z7~zkdv}FF%o>fNPnz_W#nat}kE{a+!GFV*KHA~!X zo%HzUC&*iAi7v&H-+>RsjQTe3U_OMQ3FO$k^d(VBATF>gpaEIe7?$>`;c!~?soSu4 zrswv}$gP7sd2G5keXl}|Sm#8hkc;feswU844^o+VT-6>|5*nDxjJ>p_I-UOeigBMj1NjwkmeN)(K?+|8)R)h;uK`s` z7VF2t4#MZXL|2k})@(7`$Sbdmft7xpr7F!%V^{O_dV7FP`z?^;H(Z zZv|pq1|8{&8o#hwxHG4?8at0{NPg~irfD+#~R=w8(eO7%8KGC4F}HSrl4^ z8FuAQTc1@?pXuoIG+MwO57H2-q#u7e)1v;K(@4>ZWBELLPNv88(((;sDC6T z?)5)L=)Wu!!5#IrdIS~s(`#pa4(6e4+tUC_erwOcHwkdL3IX6CBD=MzH8)><6IH&* ztn4sv^cVOedrxpFpZ}@kbP9$ntjxeH*LP9O}QEe;NtW$s=u!xnk4^` zkWZd3n0}6@I&uc8)2is3j~OMXavaP{;VYwuc^G+U!{znC#e{cb1*1xQ&}i*Ym!v%< z&#+O-NRwZRQ3P-nc-5G|^Xp_=Hv{Ja8}8lt%We)CjfC#;Y8=xnVElbozX0^&8FC5U|R z^73~mEpN|TDr8J~TtZUtV^x~usJ%itlP(QeYKa^6S^+75u5?US)6y zK3}B)+xziE`QH4-(1fcUk|&|5xlWLcGyb-E75tkLfIyr=cPdArq(6RdY0k8o9BOu> za3n+|7fNV_{Zg9hA(!oEfxo~{VEddyYL%l*?-=e9jIcm4KEXx1hvrm7`c}wLr$K5c z_-`H`Yf{4TB4x+#?n2W8cN$=&wgz_{t!`3TFzUkbrU)Ki>&ydP8A7$*tG_ zezy(&a#fH>K5HClwb1F;D5w`wT|aI#U6m}1+q^mp;kJMUDc^siN$WEudU!=c%msgr)5=vq-?gdyf7gW#-dj8(Hf zub*+P_%3Q?Ine*#jr_16g=s_*1L1Bs7nx;6Yuk4;rGbo^fN%qwZZuV!!dnuJ-l074 z-i`FD)y9OA;cqgOypu@)vgMcZ4`hb5C18*etV2?e-V(>FZ6&rQF{a-cPtL!k-3uGV zdi=$zzXKFT)@LwN*#RfpcwepT6{WzyS+eydQ6)+36Or;>6V9NEZp7I6(^;2h*yU?W z=eEDUq;So1_rFan>lWqMeePfcUln*PN>!E21!?Le<~eqs*$zCEl4@iR!Uaa>g! ze3U@%&g;?F$IX~4y=Q`%6{Zs3x(IrcoFvIp!nyo9!eJ+Q&+@4gj;0F>9jD>4)Ry8f zB|joHzo|-J>gd)X?RYO`xQe{wN%74d|~veJV4^6$4k+OeYZ zfl5{QvO)Qfp=Lkozy|!~BPkqn8$yBC#AG9KjB>0JhhwJcy%>mtV$0l+JfcHmm?c>^ zRR0xoN8-HmFLuL-l6=EIJc~8yd*2(&-m57`R{2PeCL>NJ)L}y-kz5o$`D|NY zB|hwQbE9kF_t$+@KkU-Pp7*s+MU9?6gHn9tJNO+_gzPxed_6l;OpbO+0e!ksp5*3g`1Ij%8g0-wz zfxy^$EZM>k0I0QN@gYNBhx8-NC$fG)PfvAic~&CR)wLin6J%2ctiZ=^Rqn6e@Vif9 z$iZmvRxs(g;zbH%teF~z?hU?GW$|R#7(salx7O)fz9POEUtUgLr@;|zx>ed;&u;kI zg6fqy|FqNhj4EdoVgu5)l{1Og4S(J(9e6?;;HRR)Bi|1NA8e z8B@`^lGnd0O0uzhilIRPdas2h*2ZnboSxN$SHsecHMRpveeEA|UV7vaz4;>wH!yoP z%fodSr!iFYb3n@HM$0Ctkm@QKwU;?R4!|#C#hBG2JFIhZ1-&-XvTEE63wBelkx!76 z7#xH#uItB7#G`)Z67P&8@I2z`y-!;$4I%yDen^YUKa!#`FfpI8X3|ZxJq}->zkg&& zrYivyuOuN_LYU$e41Yk;U4D6dQ#P@Hvhm32`-9AfV>ZlGe**Z2f$J@lze6|_%&93$SGbp7o6T5+ zPalxen>R7U(@uV0p5R{NE0J23xSk53&AnTZ@r6aX>(@~7`5*NyOJ$7m#|2T`@s?$cMh5%_Mgg;qu9<9ju*FE z2CPsuz2hdIZ6=QK`**m9Xa_v@8T!J^7zo8I9!8K{Ok87R)}c2YcFtGdO+_JLbiopdb}^wAaew zZ7AF8iE9>gKg5CUTmn%Pn8$1w{9KGex3}zMSau^nRBA887^hVF3<0-LTw7}&>?3N{ zNZd(wNl1JdcNgwo6$be)YY7Uoj4v7>k-Euj>2BhTYK(q`BG1omoel%0ihE_8@$qKm zYd`}%y)>URgMgq5^u(XTpV()7G|?S`f>IpP0w9So;2?ak_$nx=u+70Sn7iMhY`axH z>q0~1HNX)$u3@#hU~31IHI&cHp~9f*X~LG^y=`m)OJP4kww%E^s+U>>P#ieh(5-5a zz3SuiR6$5h?}yxh2hY(-o`0Z~@}2)kn3`3W;jGTBaue@&Tj$G#;v`*DnANj(xEb{x zzSR`*x0#&I-*L5|DOI7*NoT2>%czjwIooYw`A728B>uWdpo7+@AVn|WNyYbhB&;`x zm|u_rUlG07Z@Ul%Y`6I7)Z>H<5A#JCKjRn1?RS~hsEm(^>X|a)yrMgVDM;EMtn!5yk+UBN3ALnKTp(c- z{EK-Qe>QwuDNtbLwUmOZDEj;6%(_;^6x!*UZ*+Z5I;J$dLylT-Qr5I}xc^BTX)~UF zNtp@RmGIOcA5@R1o#0r_Gg*Nq2di0p{l)T5<6K?wf@*qZ&$PJpkyC6T&sH)CdC20B z32WP)G>%@su{LW~_`NkE-A?>fQcT@ZwUKCj={`ec;7m63U`4wPM`~4zds4>!0GXl3 zBYu@`!Ek53$f&a^ch+W+L#Cptq0&TCur$+}Pe57F_k&*ECVnn65_iI@$%&wWmu_+&|=joq_CJ&X;DrtS&R%ZlR498_^64 z7U*6V&&QS)t6{I-Jmgw1M{?l8Cc(r``M#1Ev}Q#$A%hqV)9h{{{rT3N{?^wn3+OKm z?TTWu3I-SS+qSemH{)JssqLW zRckKAfARzU<%u(&A@?;u3m^kJB1bv)a+{=zE z*q=H$Yq3f`-RKp5voalX%Hg*>KbiEMxz1em7OlC263|4+))(~~!yT3s9OTuxz2 zNXx-5q{LkN6H3kY&vXXPzin5wk+!93n4FUc6%?^nM7EI&FVdm-gyqrXZo%(F`ECct z$&kBp)xS7HSMPOH;@eK+a7146_}u$PJgZjuwvkSC83l)qp~#;u((7Sm!DAf3B)m&& zu1ABlMW&1NT`LIOA( z^Cpwq%3rwBIJ6|^xilltgP1rZ`(9OACI9kGsCbI&qPM+b<=ObtkzI|WKW zuev|)mm%@81XFvY3q846!8+qZ@0z{4niX)ZK3@CnqSbr$~cVxvfCvWlGW9a+$L1LD+(oH{6mn&g+wUmVE1>HnGz48 z(!TwvV@gUArhwt!L2t?OrQ9&$ez=tXc@qw~axZ25bZSsg+9SfARg2;C&2BjfSk!bR z_TwO5E)2dDe?3JS?#=n?R zH{dCol~7!;h=GGTJ)fTRgNE#&N|7_mrlqyf~elHMGHS%FOUX@Cj z2(E%+MFc#1HOt%}TMZ|oMEcW-hV{kDkUqM@?-TVCig=@q0ibkwh=6;qEY%C^T~#Lt zr|d*4fWXjd#jx$`d(Atx&kEHCa))hDjO>qo(%zg+BikO3pXp~XXnoc*Y#3a{Cu5+m zE!+e+-E)l;p=ei;Y=tvYOn5j00Zo1f!6g%I^oXmw|45dS!o9PpL8$eMiTP21_3eNaxK_Amh$*XzFzi`(R&Hf}sbf9t8KfcQ{@+_Y-*Uowhu8wRX+Y3$hE)=20Rx;2C(Zhq31Ik@K9YQ2qrbHjIwI&exe~o<1T)e`)cwL^KMQ=W1jwx z1OCmkg&`tE*ShoIV83-9<`MWy{WOMCt8C0K#a%JVxtIam2hwV)*}*WR6I*hIsQ(!K z0C*=C^VR}q-kBPXgby^8Y8!nkLI#jRE^kbCPj<^}j39@Nto-(%m@8h0h^NO@lV8|M zEV1t`UtD6adzLJ!fzY63LA_$rtgNI?-wE(E$~B3eHWg`P=s4-_i7!vpp7)u}+gZ_J zatt`@wxBnt%Ao{2LB_17uD+}`@_kk7Zq6ajruiv;iUKbY8SE`r8%jCmcKYOCJ_-3IQ+RBb=4&aJ; zFGXo))>Om5l&sJPXRYcv)ql~!9f<|W#VXI*8OrZ@3qBTGL%t1j4z4%jd#yAOQ-zsd!aH-%{sv7RvEPnMH(yPp zdEBfNYqi5`lv!&&u9xO|=4^2#M{{LXRCJZS7MEscSM9O4GYOczrf|-(PhVl%LTL8A z#{}yCoY1H$b=h{Oq#XwI@WTPZy$VuLVxVWt_PQ}+%r`kYUwyvDN0oMVvSfe5l@NPM zveVv8NhDBh)bKuKcJ`}}gLfO79vO!|8l_WLxDV;vEbCa8D8Bnk6JvBdy}DA3pUM?L zV&yNkY!*%|0KYkvqegT`PuP_2S@HxvK_t^L@Kvfk3-{lCnw*QSbITR$3*7%f08`Okx;nNcx5mhG*AoK9A#E-AK!nsui`ODO1R05zHPb0 zlYtOLC1 z&1y$DBKuNkGDYEOIP>8!c@kblr51&BW-k@@?rA>7#sE|8SZc(iMz&K@f2b~!ByJs z%L)uIq&Y}A0j;eHN3ft^T1Buc;W+$p7QbdUhS{G9N zeAg1SbT!F~e47^S)iUq@#*L}cH4^+`TwDUpUf$sI7V zS0oE{f>aNunx7QRYM$VGy1)U^{yfN9Q>4Xj8EnYo zz~AjI_X!~1)7bjO3p>&Iz{u~{pr6ay7w$w=C+^1*N_OPj&Jt}s_w+&*S>7=V(}m~>%SX8DBrojowr@@MRa2Q+w}V_6jjT$>J(CxG zmR96kU5aI?a&<&C>D|{(R_~+t7J`G4#DnJiXNUjn_sDLIH><@q60}TJn&r ziaHg9(;Ss(O$84%9d0cU)3)k~MDtO}*#C+h;8KCIYEkOpt$#ovy3;P+{rmgONlFGr zB>pbB9>7Ab8w9ZRM_> z1PwnyVQA*!9*8g9hmpe$mGC#2knG*!7eF=i}sU8jA{cUR{nIwBp#lcsMg7|_r+QIit@ksTC#$@X# zz>hEUeTIfq5em^RLXna2DdS%yi`FP_b#{w4&HKz%S~JR&&(e z*YFn)V^uvP%8jXOb2w?NKzq#|#dxb8+?Jyr44ee5u@2JaX==!p#`l77=vAYpI3)4Quy2A$v7N{Ia&+ zqU&w)*mDv(Pp{Xb2rFuQ?{8fwpZVk)V9Uhn;;r!|wcihVfWJ4pHTo3qjDWDQ0MV1! z;Bi7+YG!JtndimAS({l_NZIJA??w<|j)1bt9-l@u8efP$C-RME%c_X$@{38C8{ zbgE}vVqD!(f@UP2B(!DNiGPFPJ?@%@L2;gZeWzORJ?k%ljvYA<3-2uSTO2gG?EBfB zyH%xaYrDQlubj_7^JMha5pDIEsUB*NE^`B!umH|Awsf}ii z7w35uIO*;S`}$K>nrN@9^r9zy=vx$QH#~1_+`WB6(Gyl?gApOtLH=K>+00M#=VCG; zPgGN#M9HZE);&&dG|0@2j`t#&8uO|WTScbL z91gC!<$h~HhtCQdT))KJ9Q5H zam>`ZhUHCyaWSqe0OF;fZNZvS{N$QE#6l?0uS?hr)Pkjn*xfyNl(&Lnj;QK=T+{PP z@Q&TB;-M2!Qq z)2xgRw)bId7`K>Is69;i1y*>D&ID!lxHhKRj>r8cv#sro)Fe|Ya>6SmQREas9#e*f zK59fh7O=7vq?zaDBSAij9-SGM1Xo-mR$AI>-R7?~FY6E$Hbi3BP^`6cmiHh;F;nUw zaPCZ*+F8yrA6wSdC^dA_xS-he?eWnY?@*za{7*A(eF&G4dbz9~Q;RlgBMQ0KG>$9u zpPd)K{g~p?KFE1{?O>Y>I?p(yqf~k~&4C&2;fZb}2B{4H{6Z}|Q9%K`+&|krMzmrU zo5`;}EZO5kYCK9&=5wZicKb5P_Y?eVo&2{_>v!3`A13xRgy9N%LpS!sxh~)3FI(!V z(yiRZm9QH7E;GWir}r2DMLD zw#5aR;vj$zk2knUJ^_2g1eMvKje|P);u2__W@F<1!UBJWT4YzOZU&u*L2_n$?Tj$hMZ!SAk1~(7u{u)yKa0DD&CGOt!>a|K}j5%s#YJ~+u6#N$5J8P?=K$K z`7J~3-ED*9oD818*pP0W`;JG>mMkq*S!jdAerZj*Mo=_Wv`tv@5s4Z#l$W8(T-B)T z=bZk=+PUYR^)l2qofjw|-9L1+zOc~(v1B7c2iM0_1gQ#>&F>nMKqtZSMsnYsc@}xL2KeWO+HWORgx-)L&BCJi%kueHbJEi6r0l0o zcZk3=F+*jt^k~X6o@3m1v3S06_rF#Icl>`Y6kyOLFCjdd6p|ca>Gml#?}pro(zOH| z&+xcQblt@?8%j!Z@=c2f&)dbC4MwT!vzVBwJtU^BB}V;ntUSk{JJqRO=v}XW_TCbs zcxM;WTEXk4?zQ#OjmVw{NsiW15MJ7SA92F(5)`aC4e!j-0#xHjeNGrdAw zR?w8QZ&e1V-bbz~eW0Vy?`W=$J@zme5*)J}5BCHD7AzPuGOISMJ)LW9uX240(7x5% z`zMGISRGN96TL{o$?^DtvLeLzcUrq(-)OjeIv*RuQb;Y3UEfh|Iss!vqP46;XZ6D2 zhg1DCFLRBmsNeER_S_}>otH#Y(eKTlNwiv0jKp5b-#@^t?4%X>#9aP%PELNU=lc3i zaNIb`*h@f--y>aym@s$7&U)IGfU3(Zb4I<-vRIXl;b^)o)>lm?-$+-qG!vs|X-5@U zqEd0IHn-KN?L03x=w_*9uJ1|YwIA}Fh|tW*|F4}L|6iw11Znl}{bH|5G{^@J?cG<^^BO9$Z1k zHWTGR#sk)QYOTmFu(YSBuHZl_Qz;m`{d;9Cy5cStSYLYrHuY0i5~XIHECy-KmN^xk zKi$%}X2vu6)Z%#a!LDxwqfS?VR;J**jHz{S{7&G|If(X`nhqhP!IV za?cM({~V{&0n=YU+(w}-@@D$-9(V)opcf1(FW!13hxKUeSv}GNhhogtI9mnF-(|sv ze8yz04A(1X3i+5iCL#Kr6RCd1gANPRW9RbRn@JHhj?$LoIO%8~nEebeRp0IvB!*V| zu8q5q-5mLJ9I*oszzz*K=RL;3TC1qY{bnoDbv~;QZZ3x z%uCbvE<6tx*C6>ic;s(RN}T$~gA*KP{@Gv4)_)|TK|07#@um@$tu0G8o1)Fud-yNk z+TAM6oPG6l+yL`EeO00(efIYTo?%ie!@2NeG0)a}kW7L%%)~XQVATa3AQB533c?)_G~^rm79;g1*;Mw`4ZFn`k}^ z(hpcTP}_r-d^7O{T1GGmykq}bIFKo1lur(nUxln@&p5mb{F+=3R1F@_vXOKs{eeRx;p&j3p=Ft1A z*W4YG7@$hdoVo4V?F#6T8|Sdn>kHPw<;+k8Y5u7oyOlCCv_`A7i}S2MObNq8w1(S; z{EoCH@J=bne-C;cI930E)p9V5-csYEN`50t3}s`pQEZ2Tj#_#Eqx)ya_BWZ&_UZ_$ zCrH}_yHAruuZ)7Kzt)iFZi;5Vu;_XKZ+?Og4K+cf8K)J|<|yhX{^|C)L(D2~CU-n- zpJL*jEeo#hCa$}S=It0|ibXV#^CzjfX?f`eF)3nS{6oF#wTO&c1+<&7t3Haj0jR?; zNzzWl4o2)*0xDQ{wtfDQyla0au;Mf@9=irBYEsr3{g24#t^SXoXKZad)FBuWjjYNS zO1}WGkkx#A+4Wq=F$^Y65z?xBaYkGts`dKtDSY|mbt_S(IPrmGn5d9K;2`s$qq@f-SVJz(IiE1>lV;1N3c3=ui=B#_;iHkE?qsxB| zQO1q*onUZ0^13;%WFfegmG^unuWu? zvru54`92J=v(fXbyZEJ8XZN@nzX+IfU;dy;M=c1ZbT^N-@!a<@sBC`dNWjcaxNcVJ z=fkOE7x448RL$uM-s+~f19ic4>rMB74}TMBod3W_;SrXSRrnc*d!FRzZlx3rO|3|( z1Rz@Cpi@vxUSsyiEB=!^KvQQ@K^b?@iMmcI+oW@{nst2}4hfa+!MjveAV3ndMZkT@=}BNj0-_7(Lp?hndmn8&p9df%1Y0K*@$WC)bGu zbY;l~2ZNr~=lqiYp`m<~?%>#<>fMbxF^rlOksTV_a(_QEq=- zxi%aLaUT;8RD^*lIy2-g!I7WeVr84J8(z|~8uH@QTnDbqtqh18mwGoEf}e7vz1;S} z9+qh}~)aev9yjgW{yWxR$4DAiF~^F&mvD zx13V8LMUn(>4CkuJeCF~y;O7O)wG*`FHYe}@PkSp`JDt1?t_jGx8@^51G2f*Pd`c6 z^FnV}iRP|-ek~Zm4>sh3OJz1TnVcvU^8CR;OJrxhZgpy#wfUblZW<7Gu5FBb{U77` ze-?&_Ai(kgsTR)nT0lxsZOWBhoq1t;>)zVim`_ykZQ?xGA{+LNm7f<(r@{DqmnB#L zjL9;8V9e*eoBf(~t{US=!LM_(>zg_^TRX$5J_pFg^TvRNx1?+B#;y3yEU##YJ_pXn zbrv$q*^M<$Zu}=nL$Fiin-PKu?!jeO+x?Jb<)~Sq{l|F}-{R@ogNKlVbAkD3wzyru zF=;0kTPr(06_li#YK)xVLs8Y8d6tW^`RO(P2zl2qQnk|D(HH!uE~^m_>R3SQHqO`` z`uzmD|2%p|Yh?K<=u_ERav}wHMbHoq{#(wk6V5`-rs2doure?5?#TElzpoGA6FEs~ zkAJWE2^`Uhc}kTb z)M?!)Ai5GBcw4bv#dUJBWbPR~llZb;PZE*hQq1t=IS;>m|0WOLffD^$D>$Mym;RAD9t{AFCUer-N+uS|$|7tDGP1s_kGouch zy!RFS?L%QqqU^eV)Z-k3LjO*`zUkze1G|4Dw>R?RE21*U?7jqZ64kSgI(EZ()s3i` z5Ufy;4>1_Qh`4s>d&P2v(#D?esYE33p0st^v+17OYs1r+A5z_uXJBb2q{Mi@n_^i?5_t1@4MFE2gJ9;raAt%WX07ncdkMlPcZ5 zKPABAt7-SvC}cS%^YHNUrMclmE@sR}-x7ThI#^oc)v1zBd^PHCJZlDrWwES4x1Yz> znK`anyd4h3Q&bQXiSlq)!b?n9F*0LX^1-8T8P@Bv@^g2_umdC{Ik0`%)c`A^+Gww` zZ$Tdaf3Wx7;c&M9zUT--bfULOl!zcAdY$M*)I@YcM2jFKq75TN?*u^@gh--A?}liJ z5}hE*4AILB$z+%@&hx&%wbt3cz1G_MtZQHAT>IL`A9L|APrL8?+rRhcq=^ zh+gZZs8f(DlB_X8MN**&d!?wjt?dHa8;5Vp7wgD+47NugQmVG{VHM)WxakLTrBp-n zi)I^dm!MgiS+P>ToT5m*gK2(On$2u2+OxA^RY_3g3vU~lua@lPuv@Wp4HNNSU$G5wH%E+7EBJb8DL zsr}VQRq3Orws!uj=P8Xfd!NiO)m>DPAd39jg<5s{Wr=?vj;ta04?bLLg85p@ET7zM zA?%UR+(1`-0Vz-ody*k3p<=Ua%ubY=vaGvK`(C8j!p-SPd*8H4XoY_v2x!-QvGk3^wr1gwSANdTv{$9p5FHrN_ z<-LO&9PIwnY$11@%T$WHoPDJ;|$Hxzu)>})x=c>l|08)!)rnhjxw-GMo;R6OmO zZ|o286V!U5yLWFl#^LZYC7XXYm`x>Qd4b`uJCu7WXV3`g>R!aGMaa>j?oobwGp&H} zRkG=b?E0iD*H+JHiT;%weA_}A@&{+R^#|%+*k7s4Tm?fzK3e-I(uT+}+{;+1WC3FksN`UNeRvW%Y}lUPX4=1 z2Sosg&e`YxpJVXf(c7;6-a^nj#~sYrV!r(bY9Lqh>HEnR7vgvLSNYX4b}>ZTY6jXs zjbe-33`L<1&*EeRI`M{=o!~XkqpOqg-mB@(LDxLugX68NX006I-=HsOCQOa3qY^sXYN1Em!Bd@?DyoaurBxzLPyg_s-vxsJw3h~Lvn z{d3(HpO9?n+nV{Jy$3ozq$I zAHr`CCH#~h2TrFIf6dp~ z_x%b*(Yn?CIN+uQjYV;o66Cs>x_$uk{;`JsZBU*}#6?xx2~5b-wx*4anTY_~{#5 zZ-4}=1%pSxEt!TeXg+56_2l-iCMY@Njj+yrn(zo|0)i#8!2Q(?p!_aPYH zR7LN@$87h4t)jEZ!FHvT`g`}ngL970_2Zf+1z%n-LA@@`YSn%=CV6r&_63gJ!tcc$ z->Hf}6z5gV=dU&0eQPi#vwNWCErX(K@7tyOx%W1XjUuw7%zX8tH2jdqa`N|W+`u4*8)im13>)b!N$4_9JZ z=C!vOtvl`PU-kxS#QABiD_Gx_-);RJ{0VrKuc{7zy-K;A{m=l?J;6Bz);$UTUG^w$ z^Ya5|ouUu?(UCoE5M3F7z6I+3__%be${0_D7_*r)oa1b1GqTG#S*j}4jv97aq(*}T*at9z~n_HGiva#$koaUUY z$5*Q&JR6_GohM2;wKZtIx`oGGEw%*u+{jL+&eZ>zx%)SBU)nF&vn_tdbTqsifI2Fq zn$Fgi*?phW5e=v61@q7MeCjWgGGFFM67lZM)`!G?K%ySQfWmx173o&QDJZ5&si?gP ztM)J>^I9p1O7c)e|4Eph_YLCryvY9fV#JV#5AZ~f{y|jolC%u=pNHk5UmrM z>Vy)0KmjId2TrlVD-;nYlt6&q`IF>GD*K7}MQeNt0wo{u$N%$p!T)?Ei#$`^mA~gv6UThI8p3*fotDJev1c5@zhb@-?rZe`xqQHMdg44eT#6AFBR+2qShwFgK6$F&@ z`@fHo|8`u4UpcWH0=7J}my*aN4{O95`(m*fl4eS=55SYY>3~up0b`1wy)( z{Le?fo&Hkr)(AWW^2^zoXCrjB63%4hOeg$rp*0)Y<%PeDR3dhHP4Ce+ZK!N2X-MU#IufHsFBg^(@gm@WVD@i;8kj)G=fo-yHNleJu zED|s%2UiM#UEqK#%@hHcK*TH35&Yot*~|aGI`c;&J^8flfCm!~M1?PdMo7ZYM~EFn z8yLI41{T`~K;xQ|+efNoS6tP2T>=#PY&*t*HCA$^J!CRiQ9kC#<)&G(d~M@lQ!m6g zFh5Y+4_KM_f6S&3oM;aIrz2pFkE{QRA|;H8=b)=0RL3Kk@KJLJkpZ&iJ85zEuZ0!+ z4>{2N$4dC;V^X;b{1+|oHBi>k*}4BbHvj$DosHGmx&pNKnH&KYK2=ys&z$YEu;MI$ItvxgBJs1F!&!FaERXbm zdd4icYLArhTjU>!pudwL01Vpb3GpvzpR>=|82lfy2FOpx{)08$cGwGNC$ZraW;ZK%@s+}lZR6R$#~ZV?=z;TXg6UK!tB9I2 zD>GFuO!f=ye(t8shb))1gRT|ggk;!Ea3+~&)Myz$P7+b zH-%$F)f}`^ZmK*i6Sd$^VDOcyOJOX2|DSNGUacC2%snO1$f8a`FP4~o;*FP+X$)43 znP+zlK3`WGp}2UvFpYJVw@T)<;1e>9wDSfJ85=?`5Y?wh*I?oRIZY3Juts{fQdBMN zSGhbfY1LMtWd3nGK`)KHpH$GkFjb@f5c5N9#19z{ry)wCgQ1iG$$i{9S!kXKmabAe zXczsDch6^;^z6E;Q=4oVUt^?i@kjm(KHT4g2Y>(0b|}B=NYDo8;BUZgM5A?y0G^0QB<~SZ@^cg5e_ZJ_b0$(XBPKfLi z14+GXX^M7MFPyJY^QP#U-r|q``+Z#>9d~3DxcdjKwFWXEN-ebxf7t%etJb754IdFo z>bdqYOkXIB%6h$x3(L_7W-i~8WsD{XO~4F1+%LtsKbvo=99N|N!8Uirp(5f9zt1hc z3|7-ylx!H0UhnyI5?k`FCTeN6k;GKBE4fP;7MLL1B^Z1?K-N3!HHPHwjI<6eya{~! zTq>XShXDLFbA8)No0fK?5FCB$EyTp>T>;`AUs>DW9C6d^2`(m1FY~F z6@m%Xra4=zBifj8E1y52wu{?=zwFTwdLMrc zy3>k+GZ_?G$twsR7qF5XaOu52ovuRP1<>(y+>!1w)h+@fiA$QiGTKJ_{)d%EO8*)F zII!g-(LT|4-l~Y^Kqwa;BbfF!O=>tqvD|~(IQmy>`%}BNXyb`;C-=(h*ZCgAT^)7K zTZo}4@Dmz6pTUNy1i%F-FP`KQb+Pg~cpDULg#mO@b{;RYaeenO&MLOlhg1Bp&G^&B z!E5rol^#3$EcmQjnCB58=w#L{rp+Lu?2Cp^{hpu%;$*k1l2@tBo}4#kC2Q6Wb=v}C z%#pJ}gVrI}7a5tLHiB#%qu;e8zQY5=Yy0KYs>yeu0Yn)*Vi}si&s2w<>8`L=b<%17 z@j=G%Raq%LZ_Kii(x3aFU#iqHWQ$|=U~CyK0k29>+f3Plu)@^&b0y|#ZL9_A{`m6* zsM$(|e;=Xz#>8JP#)$JjDJAidCW5>g%ZSPilMW4|UV;tT_0&!;TL&V$5_H91CNTJW zs!fvpa{fzj&p!VjjKSzYPvoTx1%(~=$$xJQERUmOp3z?+bsJ3iDSId{m}W~{%6XV% zlX=I>GBw6ktm)xZg)hG={$nQdzuC=;?^~D*Xp(;sEG#KF_qSybl^1;wme`%9ygYht zOk3t=<~3>|C&8v+kW!pzVMYyWMBVaAbpe&Bwyc$nu%3d#%A^VTzZv#-nI;-(WL^<7 ztXu!Wo}O~E`V!M^Oh_U~;(xMUz^CK?;Jgsj8@Njql?}eEcl1Ua z+GAm#88INOVycy<0r3M59kdmy)5)Ry`Qgi_pxxhCS3~?yM*@U2NxD#v`PNuwtM2Ok z)UlnypTm+JZ8>K(t(E=yYMLU*EQmQx0h$D_?SC06MjFqFPC5lSC&$luMPMJzx#$1U zOGB+>NQ%#VX$mNbwOe_7uBOlQRa#FwPh>Yl$8 zE?bzle~KbK^^o4n(XiGLVxeJB+o+m=#>4ApcEvDIoWm$dJm61PkWu_fh;9#ZR@%K$ z%E&WwxMJ^$zPZ^($AKenHg#X9{p%w!5>3n|$1W5zXJsjHCAl=MXBi|b_9o3^3xO9a zCFVv)A$J}Fn2Y1ITOc$*QP*`#O@;Ur|6v(UzmX3LU;3Oi-sM^O$aHpLs7x^-J4f`s zCM$C^iXw8JQ#^1^u0!q=)TN;&FoF#|T=w7ju>2k7SiYGy=>61R#q39Ux_YWYI48e~ zkd|YK|D2wJU6af4I_`2~+!D?J|2Wg9Sol}LDd<`lZCg@r4ChyIUt=>5j15i&v%gFK z9_E2%jPAIA@7Kq8CuL5rDQC-O+z)=0INBBdNNd7@KOON?_wEUq1a7}TXh$4)_wjqxfx z*hq-$$n5@SW<0Vl^~)z)+3(N8Y*ikZ(n)DyYhP(*^X0KU(Nz$F9V4p|yLOeNap0K{ zSZ}m^#BV6;W>zM5at7Z}=ARikCG#Hd%Snq8YLw{?mkpSc%&o%ulpt~NdPcw*i-*={ z%Ks{^v#}dC$TO&9z*=bz`?8?1t`g=mMw&)6X7w=6 z_TwVPI(Xbe!4vgx`VkMuz~|PY^QO=816dw6I_C*8m_q!Usec51K_eBB?vhkQ6O8<> zKj1iP;~Q6AV%KAy7Txhqqij5Q5WGH1rrm!{{fQn$G_Moe9<OOBA!YuI?m=6+aY}z-Axf|+TpNiRDhF?{#yV%eB=z#C7*D?=nTfP_oNEhDu z5daw~$6RzuE&oP_lN+}GfIWY;82E?v+}i?Yeu}YNiiciI^EV(Sfybj%Zv*^smGNl! zjw>2E$EY>tnQk`zY~hyOrvHmd3Ce{cKCGG?*X7jP4L2F4=Zw^=xKA?SkLck&OEs>& zl5s))FIALLxrlT0Ivx@b-I`!SHGPSl`x`RV;>Bs-p7dJj^Al&jL25iMo`NU_@aPL) zID)?OFJUEWWesouu+gg!(&#DZcRzq-I&s^C$_{f#qLBo=IvY%7p)+OzC0iP+V(#?B zJ~&>?qjaL;@toAp=;V>>6RojQgejhTj+X0S&CnijBP15-hG5%<(vq&aJ#UeZd)DgO z8f74qdG%cq7rSUt8dr`^(mhwIeNW2J9ngRfpth@5R6Gx47W(16^SBrSnXmX*1o;;+qA#*2LZL;uHK^aqi4!>jz#z=&s&#!2KeGEj2-2r54pg<#1fH%T^$+7YTA zSk{)P3EVLCY2s}wNKIL|_hboUJXQLH$XyZ9`~9_W#ha>!(Y-TDwf_^QO>hpK4U*rfeY@G> zmG`a^xxyD>w^N>t(`41t$ljwN5)Am>mxZt1_|d9g)#!nd$KeBhA?69F%qld_wBTbv zklDBo_iXP6{q-o3gmb(Ir(sYKIdEew_@kwpFN#ur%GtJ!e>{r+!-u}sPpP9&)ubjI`u%xzXn!NT zp6WBTL{*7#^~bqK*M>eBO8mG#(XG@iAcz0+#<3G3M)!eh>7?546htn+Yatwww3 zS?el_tkRG{v_gmB4KvAbx2gy=waGdNomFGc{`tAsik(Y`^ce)BO3wkoFG99$kCXD< zcQj$&c_5fOF)O(^&2e;w=G!gqqMFnc_1EFld-O1MKt6*-NbHh5vLjTt>lBWS+{k|# zgYYZ3@Mzq~Br⁣M#-*{=F={nfo*S)J1A9dr=mYIqbLNYN02xU7`Ojv>I6_vfC9x zP=sB<*v-TfRoVJ`CD}GV*%`bQles@_wgat`DSehCsIRYetEbsaKH(yimIy4XSCs)} zq|YSbQ_$GmNjSzt%Ck6h5zgR;;JyLG_s_Zu3;2ixJo`+ z0ti?nX$C&JlSjI$Se zn6`X-yy|&|rDT4j82q4p|I-Ql;eg#<7qf;`m5sNvo+n+9C~9~&U`|#K73ie_)KyCwLT9weje^La4Vx06aY*}&XcaX>lnOr+HvDEjrwzqm6Z#L!% zWPsSyAKqH_sq<9>6o0t4&3re38&9M6z;f~VRKCVXGJo#8@$(Oq@;;t7cA&Z?cG$Ws zNi6a2h2S`_lBC%DFy(DVyeOhyG7`yx_>8Fg)BRc|Seh%NVn<@F5gE&8{A%O+b$-Rv z*A6Kb_!vdIdMd!tTP{M@BfITvZetZFK;iOn9p<*GYbxgs6!f?5W(HU)NNVjk6$PErf~8vIFn?ePMx3IYyZJ%m-{HuVKmv7eMn)Ca;H zlV>YQa!&8M>3i<_^}oS>h)`|P$^w@JgS(O8h{n=hC1DRtzTC82c*)g-7!49OJwlnyi*Y8GbB?$gf^BRfxb{7%R`$-vt`hF zUsJXGCRhp`!*a&~*39cS!uCs-$P%=;2$@YgMk{1E=l00A{XGJCsSYkUHbesxUC43LK53+T?akBpQam zJwI$Vn%Z(oq%G>JL~!>%ZQ_U#$FrOvvXRM|3f5EYVj=R2qFK z<1r|`AX|wLKTv2nq+`C^6#q!9TAPhv9|!4<6Or+p0C1Ze~I^L=kZ0H#7UDk#O=h}3h#|2SC8-RkQ=)c zoN3^F!^T#dHC!|-$!vr*Rk`5mx7n|k5}4HP;CZkfxS^9eq9Gc9SC6v4Io`n$;G~sC znMUhz)!C2Zye3VQ+oC(~HNG8t8IisBU&+5B9SjWb8lH5(?v`#V6>VEb*5Jd%a(|h` z2i~+L6GB;p2|3u3b&jx-?E^4;U3w&WU2K2f5_KIsy21Wh(}ZNbowRoSgg`ULQ(LPa zbw2t}?u?Uj_aq(cBIK6yL3|pPMSiwzA{nozjJ%s z^TdGu{>C}uN)R{_6w*zAhR(O_1CYzS8wg4gOTZ&ZR{S4ZOFpypV4Kfg#Rt1+X00y= zgQ8Sn8?2s+p5OT|xB@ajc~aOx3)3BkaS-f@D%GQPC$@s|6)H*wbS!EU3O#?M?o=b% z;s!c{-jGj_W})W*uCOQgr4s#RW5E542$k7HI5!xopZSgU^anLU=3;8#0tR`y?hr|LSE~VBQ$uU4H0JWqu0l}% z(~;AYrFO}7$Nr;A!ylnrW6pX*(&`P-TfV%B+ zWzU`^r4_54&b{=}=p&!)TilROKu+NWdVfy&^P(~q(o3rS!){tjxj%C@p)wnbBHxBu0>W~s>X|EE zam7xF!+6)t4<)`ueW*4fTrNf?Sh=uyUG?x5|1N=@G|wD5;rob0kM&B(=fifq;FuF> zeAyh1{-Kb#n)quO&@Utb>`3yj zm&?F2FJVnK_C?cG{ri`bXx1x2r9(e=fKQ5un&>4~cSvn=;xFiV5A--ktDzP1yOsuB z2Q8kRy>EiPDxGp6tKW8D1>mw%P)+jgY!%T0pAn1^TS<;w!T2Xj5ch>2SUEPFmA8;AHT1cKNf48K098*iUbL|n5)J2t`CtY~%=FThHKY4gbqi9SE6p~|l>#aK zUM}8>8CfbV#SV?y^KT+Lq17Y1QXIRUC}aYJtqqRibM2Y7V4ur>OJ6aR`sm{KTkKQC zHf}o(8gdVAtX>rfp9Rv+9Nyjl06L`|R(?*v3fAu0la%{PZpPBJo9>rjbphj)@hmwS z*_)hzpz%_HV$*LK({Go3_+iukC^L&)bLUH{iucbSt<7Jb^VEie_R|6p^Xa%mY}QUM z2BC*t50@d(>kv$)mnZC1wS@y-n$6n|S!vzWns3x7VCiwlS7OF_hW@n<>>#T@RAf&< zPR^yI6sQSR%oY&WEzOhP#|R^Ojn5MW>x`)^=llHljDnTF?!L=bvxyaWn61*pe&@Z& z(0p?C@q2;=3c8!&j*Z4yV;lSS={HBe<8!-|?BA-c1_WD2f@ z?Ako@6$gK=j8E;#ek?2SBqSp$v?xq@n1@7@OVGl1tA~IAEyE~tV$<7)dQ+c~w>#YS zT%7~Fv8g;T&asero%tSPMy$fIQy;hy+>4|oP3BYJ8J{0)^E;N%%8izI@dqyIlWTA}2dzi)KV!JH_Qysyhb;AqY^hqL_lvCkGK2H2=)zO_)IAENUdy_A~gc>JQD1X`nw!ABK_ZPoAc*$~h z>TSLhFQe2v*HP_3;_4fV4QOZbT<&K$(DCLPjsH{G*8#2@f6zku9#fqv;atO0^8GvI zR7zPvU#LS{eL!Pxz)BYBLR!2)G{SoGl?WjkO(0eqtL`TY=Tt8USYR>d{?Pu^MZR@fS z*yV+QB3XC>6Q`#wKS;*U8-I1be;=?{7iA^sLI+Mk44!QUZ{3BLLVD|mr-pl&^WM1} zxWr!8h4>iRlqw2){9FnrsZDi=0J32qBgyM=%-g{|CCg`|-8$jAux9;3{aqVN-+YFp ztV^tJv$DbR3bjWuM<=PbkWFN(f0xS zI7@u{yA23?#k=hNXR$zpWOgtNi;s@8%vhA`6ccym(ihjK(JNrs{;#LiAxZ#g^M57T z{V9~6G=;1~bXCm>b}yYndeBQ6&#X>YidQ^9>Uol8Sr+ww+SuyKe!7>YdI}1(qj+;M zrJHI5x%@sUHFOH7f0ty03%DHziW`B*&6>Yn2~#6Xo`O7*`G{O6xe%0B=s1bVj;HsJ zQbJ2eX6yqx4YWqdvT$*xZ;Uxt-XA#-O8*i#$r}jt3ZM<|QkMzf#Gum8?N`9j95n*{ zzK+d2V+c1j|2cyzhX=)p44-|%3`3BF0MhuMrCQg|Cy=|Cy=;#U0nt4iF5`o{AJ%N< zV;3s}7y-3hv%m6-byYo#@k`Z1P^JzlyqT)5Ce|gd%nG096`_AU z*_)T1uUoJ@^_-KN%pk_Q5N{}Ybe``b@s(8{UlI@9~7heM-x-lpL(8h#`=MQCWj!=q>?!#W*1o3S_pSl-d z6DoOv`6+0?q=WAi^xf$**3qW z?%n{~Z|}aQQ_McsQ?u+vJU#(Z7C)~-8BRevQYi3=SX|p59KIfF*tomvtorVQN6~`q zDag^F1p4R_kJudwtsGD6K!=&FVC|ZmEqR8@DAoA=CBTHw!BHGBGFFgBl94`cB<>l` zS2@$;YXd*;CiX$p+*9XnbiU5@SU%{!H1i!^-@jW1v`P%3PC?Nsny6fbHq)|zX*q;0 z^hQNt!#!PjGbLiLlz8wxhYlmICx$R}>;%s2q=h7kzdb&((%6fnwAP4XRmfykp2tlsKM&!WHVkPx3u02S~O?~gu1(kAK zBU8h7I_HO$c|}e%l(6#}#mnL9pr3wA^CP%n{Iz3n9NY~Y$?c4Y^!vkaJUeaU{za_y z!51$tEnmu`218D#h*4j);gdqRC*nMe8_$EPieCg*dovlXye>-Ds(XGwUY>fvTK^ko zQBOJfd&+pGW-f7v3M^#BdwvEdI3eRsy0Vo6O>U#P?o`t#A1qmTUC>33*stf!Jft^o z;ANV0!Z^89H+%91;=Vt-k0Mp9EOE@iq$Xu^jdw#R_i0iuDqZ_=>+8V0cOjTaESpHKUj#&TfG?@5o+@+WB|+v%)L~Dl!ucyY_b;U$q+MXc_pi{Nf`HKx z#&~6a*tWkk4hzIgf8E{n>{+xgd{muEO6EE=6RR_ec(mgi-$}K%J_C6K=m;>InhNO) zgwcz19xrW{ZitWbvaxO)_q@`YHuYHOJ8o$M`$k)$^J!b;^$!V-uT7{vFV)S&wRsb) zj8~TCZ2n|I8HvoQ?-Gl8F9{A2v<3=qZ3sPdQt3}Cd26IYeUbq9BfxFG=HaDDty{?B zC;3Qzk|4|kKVL{1c)kRZtqT9sL5s;q$gGm9dg=rxH}I{A<<@_gFPjV-54F990acAk z;)SDuOkYe(;|nLe)yDuG8%Zj+=UG|Lxi_Qy;ttRr=!cQcj;qz0f(e&}Jl6{g7-zT- z;R*cASjJ=^qk!n*XPdInnv&hVP+r*XEPm;5#;Jid>f7f|CUvS*`&2zW6RRZ9mK*uu z1P*!<*})t7jl_VZTVi`K26OsgKX5Ske$lEheBLNY_+;_A|4Xu$$%1s}sHq~WM)vU{ zM%nCz1do6TYam)q#S0A3J9#iYr&kLu5P$2``Qu3f**bN9b(Psk`6ftFUv-GEJB?Y1cQA@Qmj{?u8qk7cnF-IIN zd9zl(?XVepbIu_(tm8qp@o+*)S$=pn2H(95#)}zMcvO9nu(&r>Z!G%rh|JSyK)<4M#BNPn z`?rt(c%cYDKV_AX5A%}MDg=;1ym zX?6v*pCC?zo{ucB+B7wKToSI&^{!)Hd|dQC=pv^!7Zp#)_5ftLc!?hAx*vH8s&z>% z+)(44kWqy}=hcf&(iX2)&MBx9^!;hq^ud|%0LwRDzF1}obl|X!y~&iM!6mXfBr{Q~ zMM}45vmw)D{<_*&pg${38O1f9tOup4ug(oeQq2662U0N+K>HG8xe&~;dQzmOqJx7o z%6x4rqgOQ$w-BXGYj}Rtb-;QG>hD>wt0_BVaRMDfmqog~W1+0@>6LHaPWg6V-s>7u*gf{kXrxFMBkAWw#-Dd2=l_iwo>mx z4l~YyVKQO3dF5UL9p$Agq2$+bgs~-q(Ps^5I?+ZG&o=ahGUBK&MFt9C-p9QC4pJ4v zAXk?9;4uh{2{@jK^bJ~z!^6lOZ)BNFTHfsS{q$*gp=(QG1+pbT7+2#=?9~9N17S4YPV+0q<-P~=ro`F9XQT;aVZl%ewQF1aB!s*`Gx5T4N{#pV&c>qo1eG+2euv_F>P z79;lYeOQpF3|SVz)=i0OXOjRUDr2GDHdosql*gM&c|Gi=Gp}hx`X189~T1NeumUeBi_SZHsx>he%z{I(ej$OGddr8`@oEu=Zgge3(yzF zxq@6S$k$CmdhJK$(~W_5I*c@hjW7w=Hu_mqbJgCqFt&m1q76s$33{ z_?eLU0sFk9#AZb7hbU-)%Kc&UiwiuT`?+^X%Yq`m1aY-^o~pn6O6ocUbpT^;LyKBg zu)p_*L0{9#%k2%!(i$I~9AAlVUsAlZdlxDOJddFoCyWiVzcb@l?l#AsQS@k9DWjNM z^ZR`Fl?%Z+a#TP+<&~k2lwF?FP=mZ!7GqhrSQ!hmqho_1G@z&GV$g1 zJVz0g6^DC?w}hic3sz}B%b}M^!%HBb_wRdhPbe5$nOb58{qB}to>SeB=zYCD+w>(P zUCsFv#W(L0ypHZU?M2w@0ofA*sJC#|%pLWvwDEBC_0EhS{8EzegVSxBoGw>UzOko;Uh z4qWlvZ^@T%hVrPq(ypRKcbut+IKP{nfA!Y5aR9<TfZs3PUj!APv zHQD7CdqwK+1;k<3j=XlZ6wIuifI2~>sJ~mJ{TXygSSG3 z@MGBU^<{YUZ-c~(-TZtFN$WlG_Y+vqsCAk#>po%zrQX8!VK~Dm<9Au?pPex-2urA( z?UK*?#+X@NaGS)`8UfO=P`ol1b|X}P^o@OqUPccVOi-R%vI=O_%cr`b?I*MSUTd5( zO)<^laufra3>2m#9j@BlzXXOcVp=@UkMs+y4XoQYWTMOkS*T&J=)Q^!@%V^5prpuL ziQ8?nz;b-OmQ8!UiHcu__7it);+3<9Hax$6u|q{)j_voQ1ch*5MOLW7N$e%B>;*D| z6E)2qG-R@y^285@E4_;K>HICk(4Er<6qA;^mRRarmuRi6A1Sf7JA(7BexUw0Tg!~6>v z-;;QWPcd0NAufmA?6{z$h$<^Bk#wV}UBi;GmnkZ}c0)9)35WSXPB|m^7X%HwU?;z@ z@%fN+Gb}iU=EB2IV-=w@DpIK_m78c&Csv)ej6el}aD-)fSL_P58 zrE9@`mz4>wjjQCTHQ2QXY%T5`Ms_9|wqA{?ObrQCk!O~XwpqAse?G}_KI7RH!5`1M zWdGy=flJcjlKshh(sg_b^_1X*VD)&$E3)PR$=o!hJRFh4wvXVH%`@&^qG_|MbQ>}| z?!VoZ$-!qjedTh0OUVA`FqQ2dWZiyPo$7N!_ncgIIlg9HU)UR=RU`jo?#7mIO3hV4 zDrH|X3VwCAZ8awR3lx+WuLI@{@cL!zF0Hfo@KTxX&-I{|eDf@-ODDP%Kr}I%+?}Hv z9ZD5o`YgkC>%PyDtzKJQ7Kb1#GTFCyXT88g*<*HbWjmqBw$BgK&xOJc8;4!r%!uC-TDB-)qe)7RX`s`lRDF^Bz~9SaTaTI_j~Cf#dS_u+CM#3M2W< zxvAJg;iMbAv%wCOLMLKtRr5Ne=XhOjl-R|V!(Qw9Bm}F|zT(nalB&GU+bwO7J4^{1 z^#WtDY|*W3%&hAp-h!{K4_AK1qUYtuGy}fjVtZ9wlXsOxp~@Yfg$4QwXkfxVD-hNT zIeoL}go>8vjGX(_h347n+vy{>Jl2g9dyr1bkN!-zb1|t*CI7D%sopfn3Lat-P;tn1<)O3CJa4wQAGRB36 zQ7DZUN6ue`W}J0et^$M@AQs%V);e6;x~uP_9KT&mFO=9EW1a%K82|z)Tj*lEQ9=;9 zXJzy4_kV5Mi_Vk~?pJ(JI4XU8SgjnBE_C(y4c)R2vd%S0HrG!14&Go{cO<^;`P%ro z(p($JQ9$!!-@?0>_>qCpqPO_L7(o=|MSWFb3$lli%u3%C$Oco@T3RaG8YmzqpJ(Je z8LxFwn2ND|;#;oxrD`;2g-E`f$Ag-Q_=Tk05N#MSc&WpgAd@_``WSw(T0CU5@`%JOu=jya31=)lCeq=b5v4fT< z3t?8FeHu~{FnFiA-p20M;)rnX>bTy=O@UfgMZ+PDkI7k_ed>hhzI<8CQujWC9lokJ zE8e-(s?sVq(*@$k98_D4y65WwZ2q7K(Cmfs!0bwajChILoJ&MWbyjx^UBsI0$^2mT zqGp%&Zr3v7YjT^F{=_y{N0c#D0LQH+(g4;qd2hE4*OY)NAN`p7;V$n_TC&tH6e;SU znIxhoUJ@9ZFeD=iUxAr3XovLjgInvKx25g5+vD{G`UQ;2=cz+x9X=K$Bsiek<jM69A!-1?yH;Do7>NVSf^^XzX#la`q0&h0s z%)S_cp2Im*WU-8BqY4z9dgJB~cYeB#!|Sr!+wO1UP;brY*=n5I#@fPpKc+@;Ei=J* z(2#C#S|Tkz0yA%lm&YZFSk2VGF83<$ot)4Oap{GvWw`jGd~f=K7Ft1Jy3=shj%xu= zQIP1Y`_tsMi+1v$E*N{;h}8B6w6B?t>+J%8zl ziLtKH^N-cWyxq>$6H*RNP0tiMx8Dn=2q2(c6Mq>~2R3rMb~h~t=iNt72cUBfmZg~I zFpE#r!-h0=6;<}g`9;u-Abjw0bia30CYXxu_+h?CQ$F8aLk?+UI-C!sgmii3Dx29V zi%-UvE-Hv2(Hh_U7A8LKutdJqLi)IU~dh2>odK znfM(dTe|_wxU-gY^(?}O6 zxg}9@vdsxAcZ@sidwvX?`C~oQSR0r*5^}ZS>(XtJL-1%Av zGsG8pRsIbB5jesfICD{F<~)!8rPQ=9CbiOZ>q4tDg32C<`p}wBx2}wnnBDzBZ7(_3 zXWm7&TvZTzT(0z6t6;xso=0Fi}Jf=s!m-^i$= zd9t-DqGx+2zC!k1`aE8=HvB#xxw>8&MEt+ld-HH8|Mzcrq>??dZ$lIkvR9U&vZbO@ zLX1Sno-i0Qmh8(|LMYQhlI&adWf(glWUMozWEp1An9Pjs>+`*r=lMRz?~mVe-1i^% zecb&q2Zv)^*Y$p1=lfh==j$AyFXH-yS#H<{C^ z$7CL|yg(V_#&VLwLkCPhV_N=wb!iA4L)6Q924ZjDT4>zd8@9fnuDn5}sk zIk2N|9@tLcNy`lq~f;$a86|qu+G(kW( zVbM#AxTk3THQe{s=vmH^_YRj9EPu5<86c&JJDuTb^9?&cR!3Q&3?m+pj^P#Se=e&5DWJ4OY z{RVzP?H`rrU|HX9=tlAWTs*D`iJkVJt1qkQ)bD9FWp@{pK9WC_2~qrg*?F5Dc%-n&Nv}XxI3T?W9XsO#Mx9-qUlUA`yjsMWMTo=}%e?(eXAN=)?o9 zuX=7sZ*RhN`)iPyY4>SvO)CpFDf=$*$4%SkOiQpKuX)_?D;u+9@U_X>MJbq=hnD=r z_Z{z*Br6@OkVMFgb6N6=z$mH!Oe+p3qUCf?9b7Ts7K!xD{JV-iJ`yz-f zjPA7V`;KvVx@<|FBvzYUuf{T&T%+h@@ZpmJ>uYAtZu2YEC7T%T^_B)3vU&K~4E|AB zhk>1=&LQ%iEO~niFY`Lc?-PLtq78-rMm}Lt(@-a}8ZT=N9;if|@mwS)F*GSQ@t3{B zxYalBjWRXw$n99N8AEa%uziEEIeyg( zBX~bq=_yeDD!cvu>(-rzd|el2xD3<9mbe(bk&%(YY@766KSw-N6 zU{L)yvCYXPIJN6}5QO=dd76Ss+(oj&ptmVOt*TiMiVP7&uf&1Qw zTTUAm>-Tz9KBpn^@S}PF*HbM7e|aYt)Lhl&s%fMFGfVmB?Y#rCEcqGW$}8_ zN=+{afWEJztZWPkLut_zWbenQ8McdS2RA2(MO`_(^!^S$#!lx-JC;xzK&Q_eA7}U68=zib=ip)A+G5%43 z6<;xXf_myp`(w=8^_!kscr>58f4{HfXgc+STSL&|_kJl_-+19LuJu|z44G9b%j z5I8D;7`|xJ!FwEjBm2kJ4GJ%Hk6Mcho0EJyW#L+PuVCYN=?@>?A3@m8m#7I21=c}> zV+_qdpqLBfcNm}>wzjZ?2We__Y6?mXFDmWUFB%57?I@~S?>u(P6d-K>NI37{v=@E1GtSL6xpf3f8EoV;*IKQ zo~Cdo6%CVqym`;8ouBVHy_;@|Bf5W8xPu#&IX|-LXYh=+O6LKmKx%bj5&B>;ARUhe zbi*T{v%i_HLy`)n24e0bE9lq4;KQ@X$@+zs-HJAV#w88u#XDnpAlBT~#2bsm1L&G( z_H_PVTY8IC;myN~qNqoU2Hc-V0y1~YE1=AR_6CUS=k`Yxq(30}!7pCBmLzEP=I9K+ z=1$x24^->-6}NmWg3=Q%Zw!LVlk=$n9y`W?u1;H_1^`zcq+1@GF!wD2x(Y|O5yR9$ zcW(+6SqsUn>7PHS3wY_D`~tyo(2XVo1r)c%c=RD8UgKa2_{U$WC+Mbh-9EZJj&!6P zj;ihd;f>w;l~C|9u%hJyu!?pdojr5=K-9D=Z!Z?|uVoa|eSWF_vy8Zb8IgaM5k7+r zNMB7{YU4k^pvk4}9*g|{^s|ux+u;H9N*M)~&H=rPz@Gg!u%s3aYcJM@7wv_FRiPDp z)?-$o;WrJziS2c9ok09RuptS4RLP|TP2+>R!(gP%I@LM_`r&CV(50H%%CkB~hiVfg zHs?MG1InrZ%6}kCyMp*)z(?~S=U^uoKzRu$)P2=TyNvPwaNW=3+p!Q00juAxF9(Ra zsY^%McEOY!sw(qU+r_nQT28NlC};W58>m%Po`prip=4J1>?=>sfk5N3$3ddKf^JLb zjbJkC($PX3;Ohc#8=}37JN8|e(!9gTbC_c@Bj2jo4`V;xli)W7g+rh!2VbGpalrR* zIT)cl%fv@H7E_whvL3ismE5NqbHvBSD&7d-3kvvf#H9-xZ8*wc*Y7ITAq(w<_TBR~JV8J>J34L~iOEyoKnM z^lv>*W0i*)1c>_Dy|>Qn;@#fUo79TtHu-~rf)kXujZvMA!KJWHJh!mxTl*1domd&~ zr#aV#V$G6KN-}djopU{ntI#{ZxR9YkozVVe5p<%E2+V`%>Zcb86_ulG@B-lJcwk9pK3H8vX0U90KYQl`9NQUo4o50iZ9RO&{=W(M-LSy%D|4 zd}_Py#`LYZ%^k!9&-}upIj{5BCjpUb;JU>PAUx;3P2RE~^D^Yn?+rdLdwp$H6H6UB%}xsy6fvx{B!PYqO{Slhxo7=&f!N>s-H`i*!bqlw*$Ej^j5rBtL-9sX`4)#KDQyze=i8qrx@ zVI<^&^3=fqxTb=k!Q1s4A&%@a`U7HD&Hpfpq*hdiqohJj^7Y!aeF_`{nqTe`+Frbk z;9wjuD%v zO$<)*8m0@$uBT7Vxsw%`uPlDKc1%8RZWlhTc?&=CF&SFW71c0}2`vaoM2gYq| zQASjg3CF-i`>AX~5T+3F~{9s`72>qIU@b6~-N}_202-~8y6&nWVY1&X-g5p%6)f6iF z(;Q6|TYpkZtSm32?J;J^(aWBwOFTbw@?xa~zoA2$Fdjg(-_Jv*>~4wv10E9ygd`*t z32b|9-nEiE={+b>GwQGV0!1;2)L?Gb?N@Q=9#fjzJ{#^?p{Y( z&P|^^S>N30X4i44&Ep(gq51*_sI5C(F!cwym82H&o0fo+$y@Zv&H)(?4jfi zRE-9}S#93U!78k#54hl0TFIx;A2VtIBkC6R%xw7rjF_z(f3Ju6ZBFNnm(6;(0L!-g zMj2fKuE)&QoLzmYn;W>u|9C*QWzy%k35g?Zkn5)61Bnop_(2VAPBoTG3?^mbfE|tJ zjW5t|qP-#mDJTeU;OuH)(`Pd$joG@PIOVgtr((@S`%hI4HR_$&K~D?-8HNe!LuRBV zGSX4kH%UsJk^(R{Iaam!-5^-FKj1Gh=TywL$xVF8O{#U%)p-z$`v0xVj{FDcahtV3 zSbMK!1dEItK`3qx(Z?-Xy$+?h;OF# zkS#h3RLFbX>svglj`G#9TQ|px>OGs0wMU&?lFyP$M#f?Sh7uv=B8$6^@#m(fAM>`a`0wpu zJJVl~{N@fB-TfWJBmtJ|{2G#MZavM=J>>ZZB($~FVj)qIzn*loF;bcDm=L97??LCB z(SoZi2$a+hF=YfZ0lzH>E>3$u&OA8byWvZ-5wfynHUQ2m*lXL{xqk6`_uDUK#GlJ<5?XpMCX`7>FnUsAdqsa0Kv=VL-LN zxLcdnPM+>*Xyt2s%C$@ew_qJ`it|5;KZ1YaCn(jy+<|^q4RV5|R`F}#| z9*hNksM{lt#@C)-Zb)6}qg!jtHv4Wbqu?%rURka_I@*{0QH_9^8TkoWLhs)ILY3kU zbOJ)7nR2wqqr~FlJx|XQa`gdvp8VLV!Lw~HMn+skIevOxP9iYvTwvecef4R9_~R?Y z;Sayo>0caIJeuht_^zy|2;)~%%kN_SbNaqYCTA~ zuWwR$GI z{iGJpdiNHEzoIl#4S@4egMXAlZm_Bi7hCMAJS!$S`DO?g`u6Mtn~;-NT#b3kZ)~up zb)4IO4_rP1Lbj+7^(=jWM}gsiQ(6G{_Lh?UYQdHElON6;w$PgU($&GwdMqD>UlH7Y zvjn67-0(TknkW$M;`D}{cAQT=bADCC)u))b;(}GKoI5Fr(3jOpN;`mgHp;nSA`zo$= zuHnbx2Er-#{b!DoJEt(NT?#xF5>76Jw$GoSe6RpRN>D-wzFlB-YQH|~gM5y^XQ0?H zSIJBKSSMTBF=CX)R0Ko~a37>3D1xX?ReoBjy{&P1>nCuPW)CV~@Y7Y)$rF?CuKsj} z8Kw@A48Kjb2F|R3I1G(v))(UR+GHonYk##|QLk&qJ#=SuL=u%7E?|F+UQ1ks>*JXx zy7DCBV&OWjLit4=X+^DdwP*OxzfJbOW2FmHa*luq>~7KIq)H2~POL)2-dCqjduocw zn)C#`GLAE(ypCq84p#=bln-raRYAfF9OF7rpgVV3%6#Y2Vy;XdRIQME^ zrA!nd-T>}3H)68AvOTihqxDuVR?YJS&7z0OSIBxZO2#kqvgp~Qv%17VS+#F#YaM>U zpf~WiB{bv2Wg{}A+k!U{EUf1pNa4IQUYJUl5`EzQz&>6?q`)o1qws}^;hEe!BkCf4 zNGsr2MkBo--wKepUOV!_2e{vYa2#f8r%U=^RA*53#v0b>AQN;s21E|%)w3C>Zzr3h z1@Fxk73K7X&B9KKm=VY2uH5`7@Ys-p`4Tu725=2Jk7p?FMn4fbgQF8%u|<$%-?z6Cgx?+Eq}NY)tE;60+z za8nd}Lt`CU2e@G?!W)b9Be2XRaF4EPf>=&ZQoE3m+mY_mU7qaHrW#l)wV>7)+>N7i z7k&xqkN{I~m_n|=E=mD_(8I8|D;)^s&u8{-8?hINCp@%DP`HyVoOu29Q5P>jp*ILI zg@H%cZGVKM0X>ZeM_}s*LKVsI>;A-c;hxsAa`|Vo7+<-Mx6HLQYxJtKoZrloS>#kU z_;%SyYIIeG0rP_%m@Q7v#kKmY=2Xbj`KyaZ!_4JCkby-dEZrtl>hobiCy{Vjz-R)N z>45!WH7)v@Z$i7k7TgkFCK(ZJX<-wIJ8qNTcG~0nX4hvKG=t)-hyfCni~1O&?-gMiSYLnsz_)nuPG$@Qk^8TTFJj@N!M z_g6e^e$pjdR7RcOBEr@ja6BDj=w}W(9RVsmty#2`=f!5)C^jJO-Y@p0<3%3_(dRlv zLo9Qp<@A!j$M=a}KIhx|^S@DqDg6EU9}nJ;2hzcJ&WY30~uHb5xa0URhGMTvRl#j zNTK)3q_C$0Z!Soug`S8|4;Z6C@HPxt@QD@=iu$_s!Iwsg@61h^ZM}->p2zvxz6%SF zmV{rGrnK_Feyu66V%yF#q#9`1`v!II*LG;W-3yXgMH6)ZI7HK)EGeaKZI;Y-#*C4~ znzgj}Y(cw*VvJ4(pnUr~JvvQK%(CmhbmRp-e{{1h7av)rwzE(xRzh zl>6Z0foLg)5@mUTC$oE&rAyB2dv7V(K8CxLCtYf;^?AX3@ZvvyEi*QhuRg>~!B z5!)>|87q)UO8+`gc@#@H95>K zT>6YV6*Cak0b&j4kZ@sz8jTrBnUWXr1vN{|%Qq?DBt6@3!hEo`86hJiqP{*=;kR|{ zxVBoc+{^8d^H}@f>|VMqY!CyWYh7s%DEup3KtJ}8!3778{)gOyrkdz=?v`rs{N~SN z_dkFnzlrQOC4wltQII3v2(EkURKK3Yq=nocD3?-ED}Ej4O&8bS37OyLEaxSe znO2q6lU^!MTdxW>*0!YJ^W;U%Wz(-#S8g7YsQHF zDpUsf9t!06Vbo6HQ#UI8c!+&i{fjGY6x7Ba(3^G0kS3P9`DWQf+&pD~u8b&!ey*j6%MbhjZY321AStB>Ch~SBwa4Qs;usD92dOz8DFAc*uANiD zF{fxAD=DRY<`^+L(#X-cm<(7lioZ`obJ4Y#69qkP~uOMV)JxEIx!$spB4A099C z5jvhcn<|);cF|^1(~2z(=8Lza3ipf(GHetdU zrP1pa@&FkP+&hUVxKQ#V$w7KGh-zS>SM8$Wy3_D-C4jq`4cdz3t+;8(ccZ$6|5Qtq zqOF;!(q|>U&e`;rQZql=RDpUuh63$udS6-&3nNgXZ%r-l_^K3s@!pirZC%a_c4|vJ zc5@Z_7?94JOV=&Yv^7^Z`Q!xz*NpWca*eQ)F>?PUIF=xopSxpRI;+yJUlzHtKu zS{$J_zRs7_YJbH?Q;Wh6I19b+7V%g&>>LI02L+Pa|*T;UG%2xL%2|slI zkj0I=wppU7hT>^NR>5FPngiuv;Q&YY10r5_!1bKbtVHq7U|m!030=qcJ7AKRh-KlY zj6pybGwPWtg=g`FS9b2kdK2 zz9}aNmfh_FgG5z+yRI*`y7s*uo~{znyx)yPWy1~`A=G(g^`^f7@4N3nnfbCMoWPry z0~UN*kpP)#=n(nzFw-=4_g;sIw(-g}YpKRSV`uYUT8p&H_!VrKcy~}JINZKW#G}fi()PDFVpL3I29@8kHP`>r>pqZM zr5xzPrKS!ve+oH@RZ?!#e7Y3v>_1EaOqIdOA&-f}N#8HAYE&%f=9ebCi zdr9ge+`2bthfjG$ODCq3iNoN9)0!Y|O~3u#eOP?P-?%H zX}TP;_*e)Zm-TA3Mg}z;+xw;h3MA@xP)UeGXaW*55ADd_M^kkl4~mJS1i9d02mP=kj+^`%SD~>*l4m zOGg$$=3w)QRKP3gJvhdYhej80Wp))?UmV&K&RiIM`hb2R!@M5V*mliHL3DmN?8rZ~ z;Q#f*_UG%r!3O@0fcQHGtbahK+v#_}4VfCJ3?(B*E)$0($0m`%jiXpiC(9V}wES-`oG){J*02 z_u&0iJAak#Uz79K;QnBZkLpBL3w!)iG2tZ6XZasJX=_1orctIu~I z_#O|r3QVv%A@yq+U5_Sl2hcAtL>nnmv3v~kPeYodv*)?0hEHnfZ4_@EOSb^QR4b*} zD*$4FM^KtP3Eh)&6r8eHJ5}`I@qSKXM|&)W#FL??5z();{0Fpd(Z}=siYxAOzg(6T z@0Tr>gKv-;)9$jX%HS>oNM+vs6CnkaXZ8G7yO+*-2BqaVhn`uc+l&YP&Y~a>*pX!2 z?ggnP-)H^)OQcYRnX(24%V5{D5q_Cv7z>5FdOjutE4udLp1|9T3PXI6(r^5STFG@W z`H6*UJY2d)T(tat3qjZXum(;S-Ts|lOg1CPM-*3T2rkvUX$%su>923EZI9Oqx4An3 zS0k@mP34$2X5c)$-*df5&oz?1;@|8cZDWI9OfE2(3Y_JZx1v4xxEeGaAos-K>hDy+ zhfyT@n~t>*0hV2T!If-;)|3iD4oVZVY0|hp+3r+j~gRJjg!xIYS6wda7o|spl55P1cKxiw_68K?#|5O(acH(k~ z*BfzGgQ7rGo>gyFiaH0e+$2Qao$brN_U@nk`|lr{7&r?JUBS3RZYwE+#V+6$!X&TS zYYKh%8GD|X(&@Ch5_Eq#itUEdy!(GbmF)~eu}#=~3a1I#wLP|c>nZaF`>C8eD2Q9V z%ltFs^W<0WF1XSMBYG1-yB8IJzKQ?LM=1>$XQZcyzYaD25=h^wUWCQoK9Y>Eu2?bq zlxXur>CE=};PDGot?4iSwJkw#XEXwF5>`qujc=2OsRuUz6Or)}pS87x?~XO(6?on^ zIhCecU89=)>70O@UMMJ>+ZcxI0>|fJ+XFtq*m`A7Qi5Za`Q}ABteaxmC+Dy&PUmF( zSXreh&F*^~|NavG6H^%ePP+J?n_}_b_wUcxzZ;K=5miANz=GII%7?@-Zu|Jw4yGlN zbEWj+cdO1sj8#`+BIfwEzE1xnA&`GS!&D@CpY;pd7chAc378F_6w^ODdZ2I0P+#mo z?+YQfgLeLaP&QwHUVC6t^1puxss_7_u~Uuw1M>Q&r;7Isc=+C)&H9Ko%V?jdc2iH_A_u-Tj6a|+QV|eCg}3R^f-*OF&g=WRvwnvrLhIu?c%K!><#gs9 zG^o6*nEeB)s2V2~wr2A~VlRt)D_bjS;92r_Sm{LJRPn;Sv%1Zw*gV4GeGeM_&dagZ z);aOdM+j+$tP6C?ZS>sqY*wHfcPN+fiyNz@v{}3=>Bt$_WrFcSo&L~%!TW;ycF*{a z95W_zztJh}EBYqOHdmPOjsDPRlXeiG0p} z+FbC7VEkfTf7c(7zaPsKO`f=!PK2f$oTtfn7Fl|SFTuRu=aJ^G-A;OL$|3;Ai2oif z3j2l7T?LqGUg43m?Zluga&%V$0mZ+5cmJuscu7l&!GJ;;31_}pcvySk(nrN3FUYXY zeM>x=Bu?oyq0x&EM$xr7vAf}rTNLG&(|OwO-soYcd}SO>Q>NU4=vANl6^yJn((oTx zx9pUFmPNpB6{qa_Cf)_t>~-*R&jRQD$Z%y_OW1EQecfwy=Vdmv)ob{Vo#*<@^#8CH%1>L7hS_&3>5hIfpDkkDuJ=6m`R$+PXX9_k z+xDR%tHJ;nvU8bPe%64+ldA23-{SjeS>mA+U(H$E^bD;nD$&cO*zeK|GQl;sH%skfDu!V6ZV?=xM!Q z&+CoF+XlAU*k%|N&(uZ;#vgMKec{G-_|*)Hfr-|RN&{grX+fIVh&T%HceE&ZM0*)atpP zxcW%B;fH=scY_ngmuh0WbmU67YrSMJR%|$(g_GoioDVbZ`vt|G9>@|Qq57R`>BD|D zfKL*{B`2yyBXB81gW`w4FTBap! z9bYj+q@JPqQ#iZ*TGv%$pNiWnDa6~Vs-CV&?&YXDtk6}X;e@{ve;9AcQo?x5Y-WJ* zwgzWm?f8b+Vsh&zs8{pY#-@smU0wCf)|r`C1eJu04nwKgNXkv90NjcBF;G&L z3Y}>8@QoI_+cfp+Y<-4<5u)5q$+l2?N9K56^P|KkALgo)b#6vGph&88V-zqVdV=TAUaEwbmrScXYJo3B=ln`=zZcs3xwcC3IG&p@Wv{+l*m)}+GA&3N~f*Pp042hvM?yv zf6x#82q5e@-&K;bX+nT)fFQ<77^FKY)y`&vLY^WL?$*4wES-2JJ2sCGZRUe&0 zWNn3KQN)3{pLGZcm@_oFX$u48#leJz-({~@gf$jUE!AdC-YH7If2CXpH2cu`Sw??^ zKbRrF))NXvL7eTO;y+{BYZyk_DM;Uz2i-R3vsihCSV5)_4Z($C;kY#+q3NGRLO`o?@b564g%s8M4`H`KE{)%eX%a#O#`?%93w5W6oAj8V)m(JuGvu?l>|VQI*6 z_6STk-ryo)ww*%{;4R{`AT(e`xffxVX1#DJ*(v_#C!f;q(no%n%f$?Z{H{MLX?~{B z6~q*O+oI8p2DgCU!MVXwwJ?=BrVwF7yxAg5fiajL`}s_I_i;B*Ht%Y;%?CAm{!oQ= zhy=u^m91lf;qmoA3TDU4VQY_cpL74Q_H$j)sq{y4uFmf>{ZbZ(@?&`(_Vs`eX=6}* zAm}rPaX@Bp5ZsNe;W*6J`=l*t(hg(FSHYg=QZCT^+9a{maG?i2jR&VL0wk*mExNvx zP}LZmzye20yklpfX`uT0__yN@E(y6q*GN(vB`LPFE-qmD3$!M#cOFbdkq;k0#xzTI z!(gf_mFMOzHfKlPOmaQnxd1t{*`^G?Qcb%~UhD?*Yl)kyl3E~-r*9dZI1(agCT`sT z+F57=NUfvnfE&+L|E@}=B{Bf}f8Trkr|+!J!F5l94KwN#;0k0%T$E?NX*p5EI8?=a zt9Mfpn}eTi=Cs&Lmib(u)#?Hek|IMI`brC4^|R(td3n}xZiy|B_pokxB)Wki!C}7u*t;6FZ)(`Op~;|=dOoxk+6pS zY-Re*v^1ehrb8E4_52;=No@3+u!^P84H6mX_?JpAoN;ud6|9&gZsM+Fy-ihGSM+h+ z7IZP0<8nIvLsl|iD#GA<1rPJh!TD8a%_znKonSYTt6kyk?SG^(YJO7gbP$dQn#}WS zJGdK+W1~)zePdJ*8vW!6!)mf}V%b-m!{pb|@?Pg2%UnfMJ`?;6+e;dzBg?wiA0g%s zbjI6GAu!tLXToqqgGInSHoPs4S#-PPt2FZ6R{@;MyT zk1>U-TI;`lD*GiazoWeQ&itWIetL(i-hQ;9(1-uYFo6*SNLh8ACuF`ROQ2#>jezy-I2Ad5_6DZZ1|03V$AX zU#Axpbs$P(k_Ye5ya`bc5_j2^%dH++*ni%3e#L?fDiD35qv+Z<;z`--Lh>QKRKKxr z>ah~g96HHx_QNAnnMd+0jClg`?fWKM(pECLO47__jO5aI7p^iaXx@RILtCFGbYWM=<%ur+t_?s`>5IhqQWGjOI{?GBhsoUqm>tgbhq!OA z39HFQ6&8(V9oA<40l_fe*F?>Roe_ft=YheGT9g`F+HK0tWg&z#C|JgX~q_Y zVAcb`!RKq?T?!8)oN7-Be4*K>@+;0g+u114<~jd4-fPp-1|qa#@*Xxc$wNwL+WVJM zPfC&66?3OVL*+h+L#8WGWyJ8oaWXogZZv+IhBB_6bmVrO(r&I#d=k-87^AM(p-+K2 zx-T3fL*sk^q-xwQH{@EQtAcA?U7{!lKf6QU>!q~w)YrU{Ul`$p(nZ|uZeR|vFRQz@ z*e!?4zTZta28`S6hd5jQfIj(fK14sb4~~&-&Zm~sa=4#$Ev)Xk-&MRC%Py1nMM+Rk zLP8?A6d?-q4);@hnGx>}?mmYJAKcg@2s5Kd?Yj=IKp|WPVzefthj_{5o*TiDW0$;B-r1!fq-#q@IsQiAx z?@ILM%8->MSS9RH#6_^c{J|KS{V9}YSy#b718cCXaq+%{*+}rM){!w!ykhJukUlE1 zW}Te)(?Fkbe;R&)+>CM5rG*49BfTAQBxQax$Az#E*Kc=?{R*@0FS$saNo}X=gE#E4 z6iAOHpQQZY=@nA&Ov85b@BESZc1NXn<~xS|L%d~V0Xum!2j;u65om9Qf0dXihq&_H zUTHf$8hgG66xM)o&?P>{?Y7C$vOdoA7A@FUC7IPH`_!e+*s_{`kGiVz-S<_jVpIwveFUu~D!(m7r(?rxYyf8laWEVT?)GSLAT)01SqQei zO07$|d%{Cp{D)2W_j}1b&1{>yqjg~`DZeUG3}h%6nk;^@nihPmCf>rJF~r9;^{!L# zU<22uWSbx!Q#mO^2@0ed&kCiz5rte_FE`s%-=;FSPrvy{t0Aml(Z}KEqq`drLSNu& zTweXf=cBe)XMKL`_AfmK9>BUW9EnC%&%FJVQ>H`> zC5NMX_+Ihe9p9Zb8G2%-hw&E6XlzS`Ff^6<&Qwla@k~c) z#J9dvPcGczh7XrBS_^=1}xzj!!wlg%ad*rc=5kl|U|NIl*cNa0e(6&i|I!8;F7B&j|V zy14vvS%s_S(m11di_(oNv z2T6*P6uO5w=b87o*AwX~6WLwT3Qw2rU64McF`4t#itPcxmk<7gCRHzWmQv8}O$2jW z_@zPL;)-gg2)Fl$Y9IO}kKWLTtaXzeYtw<-f7RoA`?Cod_4@kdZpX8^-n)|zp0m{d zxbOc`*zom|Yc=e!6k89prZxX(jeJU*;`(TzhKZb53|aeo5_kTy;zz*`MYj@rY5lIg zOX&FaqrV(3J<2p0e|fKU=mqXm)G1quR2JDEkWKVe&;;c;v)DkAW^s#xctdai*eQKI z&*lid#huTaRzQYv#(#xmAe_t3b?`-9_pqSoVoT{ts03#QFJA};Fg9S@8A}Y68O>=J zq`wSbMAWUL#eF|*v>m54c{YxVe4kKpHVFc042oX4xn>&=3O?QD;qHQ_O3fUfGk)#c zsm0xW01^Z36Y)RFkid{FjNOk?`Bz~dR&eRjQ3^@ijRB**kKd(8P#k%KE z>GH1l)wFOmM?t#moKdhZaj?{lAEVtM`>v$V)yYhHpU;qonOevP*QLl6+17goE60dF zjF&#o+PEwDt{9BFOmN8w9j-3|V_vD9a;d8+==V{Xzb)bcpIp)ojnZX&Ufx-(jxB7` z>GBov^!`x#Opk^)Z$DBWmX41`Ogfje^J0qKT>!kEaqs{MC(%J%mq=a6V1w z8ZmlSPXMW?GNakVWx$I-`h{uTVJX?=|JafaqY-& zP+W^v+a(WVJodHUsg!wmANDW$x)c7^R=X)yS(S2pMwD7!O!fK$%7~%F9bn3Ut89K? zT~p$>ZcoXr zBt1*T)UVfvtW0>s+*r#sdVNG*k~;VDXg1WSFE}DKf~@u9&8K-<5aw33r@}9B({ue? zYZHFeT;6l8#y*I335ublYGNG(cVcNF2b{{dOwhFyai?yZXo;(NO4gWUR9+vucqW5I z+J$&H$~u|YM#B&mC07eL#}Nt)^%lBG&f_Xa#~Dqv%EqeaN?Tk5F4sRIybB+g`)Kd< z(K4x_$?m~!J?(;ehEM&_t6=Z8;kF^vuc*rnqP>p4)hczapKi zbUmzI=@lJ;S`_poj-pzSU4Kg;A)jgl*a^nFe>l1Pd>e2(pW-!d?+5Qz<)%6$XV zy*I6O;igyFryuDvr~%A_4ymG45NnAv2kl*FX!v*32DHLF=u;_7+KxKr=6a5Ed zRFTrg&d_S5@H7xrx!P0e&|HuAZxktJk9U19>mJLnzGlt$dD*D|rnIKcvdbIQ5>7eM zg*@?wW=?TjTf7@?H(&mA_{OnkryU})Q?FQk}Iex(jBfxk_tVn26qcwz+lko4~ zdE+O?yl?&3-Y3XtR{S_N>t%+I1!aY?zI@f;B>J=>GE_T>!im|RYNoK>o(gQ$YoP|= zQpSg6vG4c~+nl2A?0A>7+{A-pkRtH#54fNEt($Q~6!!=7O)rbt8v^i;=B|f~??*P8 zotO5(9sym)>w6M|quWl}(&T$*D39VkVuyLEJRdpJH9nlTWz-#eorM(#+>Uu;BZ_!0 z5H@I`67Mdu8Hmo$y60S)AT`*Q1_@OfgR3JH|*$>+4oALc~X0(T&65D z+D?xVE!YsOO=m%G3ts-_r6mWiG&4)I%5=iu*d21J;NAFq`zI~5&(m>hK4zbNCp0LsopT;#+G&X5I zcgHqU_Ci=@%kJjqZI4ud{DAv!&*a3e3wd!Njf^oTaf7hc6pe zeNA()pVbvl5?WAaNg>;O+2R@Olhw0gtA$GegX$O_c$RRYH@`QhZCx&DCGnn1jH?0* zKejrcD*9oGcw}c!=+VZHm^bcCu?>DHa(3Y@4LqfT>RmB?Qp-F_y%4v>guKP>_G1on zv#;23l;Veh@8nrh>8(a@3JrAH~E(vc=00g)yjU783H=|n)9L{GLZJ zB}BO5)M-)L@L(pfJ{nNOX0lq8jNZ5!JIv)bC@bu_qICa&=gDMd@%fS*FBRsJ_QDJz3ihs(q5_XZnn6>^0(LwX1WnoNbG_F`V~5zN<&_>CX5ulx#@#V zdwTb2=SBS(()&)qvp?P6_w!F9L;M|%w}Ft?|au8C9B=-s!Ozg0caK0 z;trRwQA~eLNi;P!?w2u0d)7<0cC?JCXfPyg&N)DiMVC4?IyVwz*hoKr%AR4$m4Sg! zwCs1_93-GoP3k1BlWe7k)!Cw+*>NBut7_Q5ZY4|hT_uFnWiefd6yhMg|w1a$P zT)a9Lw`b4CN-AICR9j#C!jNor+0HHNW53!cdCKXeOqYGZaIG!UwU!u=%vO_ne68Lr zglFJtY}y+Up)-tQ^l5K}Dlps%Rh7R%u)xF+m!bE4q2IDDWtAaPLz9}W%mbmW6w43v zg1h5X9Z=PhaQHxvt1TcfU1?Zcaw)1mXalv_5hwq45Us7FW8q(`0E=%6HDzv&{)A9novMMS`>( zmu;W(x{46cfQ~!5cZ9b4X6>w8U7`BqebZ{sScopUV)yo&8vYx0T^{~`erX*Blva=* zN0ld6&tLbA7}+>jpk=0K-U!yy0h9o78GUWCB%rXGBl8M1JF-em_X833GLD#f=8$r5 zhCqrfd-r#(`hJqKKoFXA-zzSFz_eG>q)Q5^J1RJW^&&Mtx3f`Xv;Y}H^rIfFHhrBlJIrCUpoo6nOayt_a;$(Z$o+4pk&j^I~e7tav$iO%;8h%RLcSfBqb>A za19mlLy;X3JaL#A$BS@oj~f(F8c^_`E>dvUroY}_m*(Jmg^ulK^`*baBB;u54GVow zi6OnG)x(1-C$1O7<$45V>5sh(DpDeQm%(jGZba8~thpz}z*Bz6kMZ`z5Oa#-wIN{z zrC5pn@i;vZU2aI-UXPdrI7y&gYU+n9jNtuYa=p$;(noMZXHZ=Tmye?l#;KY+bKSdjap+*OvGm;Bzh2A^fiNLRzLmEY4BV=g&GJ=aW@ zCFPM(_zE2$?l`3y>>I8Zn*W9Sg)Vj&e6laQ_5BwoDZ%y)!ilW?(m8!f;oU5IW@}Bq zS}pI7t{RH$*3ciM%Kyu<`#&L1{;S8|lm7;(0@Ry-wT}U}T`%GU^bfd|d_w{Rz^$a{ z*xtJvI($3rKyc-+Q>ryQSv3{P!&$9#d*kCwUw8O{U!kbU0QeLk%>H%YMcM=qQJNBD zs0o3T9H2V9rb_dJ;a}h_vjDo~Ux6b32CnopD`r3d+=?|078mvER-(rfBjUR8%O_8A zv+WW@`PVrE0M!Izn5@h6;;$9@mpqB0DNZqK zl+5Z%Wa>7D-MI5J2ku9iH@vrPCIBbBr+@VPtlg-EifQ#2c6R#_0fQY&u0}lY4SGv{ z-TTw$P87_wMWeAo1SDkLP+G0mx6?udm~*rB#Ayhpsg{6;jR8P6@`}Gf*E)@N06?kC z#<5QEQTS)bj~ls%#Sf8mWP{FPK^-?D+N`xb{}!>%p0Lubnypq2yZa`^%`n-;l=)J# z(O#PMnqoASi-_sPFj71}qZLbMS~(uQe5_>3voN5*pWC|O^K_9WL%sRU&(Te?R{1t_+J+lW$Xx+_zKHWP#mMe?V%;19&O!K{Bg008^Al@TROE{O+ z1?7je%hvcSFnqC@=$G>{6pz_Rw`@r{dbSq6Jw;etYc244iw|u0rh8_BxJ8b^Z{1K{_^vC}b{v}bE*KK5 z{)UMj+kkJ<4tk^~?s*)tSWNGSItB5maH}KJ^<)_!w>qhxrcKP!`={2Kujq8|LDe}E z9ahCL-l&Y?I`elB(q?k~=By4xNhntiq0}pX-@EXOvd*nA%_E2J`MsG5Knr4+aB!3C zi$75LrnPP+IubBcdp3RM27LbWv`Vuq?TxgITyNZJcQHFcnH*JyxI|bFI8smcX$VS9 zM@zk-S=rCEvwv_(j%;;G2E+lVavYF?h`E@ZE5~-Ol#pbDZ$OD3nfqFgX(`k8v-`2( z0F+d^tfTlttfqMCI{_69sYL_%*gU4Y;B^8LT)#B4&~~soBK;$K_`H{)VfszYHSwDa zzx3&WCg7Jgd^AKI40jM4>6jcv%-v~JJAxn)V|hI5xDj|;f!)8}-?W)>r0SZAO=g=u8!sN!@YC%=RLw7TC? zp#N@|b=WIYgr245;#TJ)|kE-{HA1*CBTUblH*}*$HKAy#RoJ=86vC7!g6KP(U!z&et=#0diLweHk-U)9S02y2IMFw$$69cV>ZA(S6QB?4piFa1=HdP? z=AGwy)dua0Om@l~bIZ@ac{yO^7;j~%2W*U=q`m@nIuan0spF29;U-7S+P5|2c;f{X zyws_-+CN;12nm!2#Ngb6q2b z;5cJ+HQ^VxI-p$Jy34O_z0YQ87%VBIIi;#d;Zd;QJQHgdW265|AA>DiWvMq8kK*hr zQYMC{RM`u}cB;wSd6}z<(=GJf+3!hdvT}0LlVRpy2K{PT`{8bWQh@P5i}($V;4qw) zxCdc>UGOluIm);~W&+6tH1*r}+t8f25yO)a8YOTU2clP7&}kBXZrZE^PafmB2FKx= zfq}F6ue3H(zC6P2#mJBi-3SIQAG9{(D?^7}WQN+#mgli*PHH|~P@8)#0{+|@v?g!c*^q$-Cn|&mFWjCB<{fsNekdr&V!#&l9dr;lSBA?4 zfT5V%h0YxX`*$0OK{M_(Ds1!_&}HK@zd`qLLY<@$c>ne?esOGpv$q)3&ij3$R;IIs z-QdRT*`Tgv-vbdbb@ee3JngyuGVt}2CljGAGTjA+c`+pmhHkIEjRQhjt$zz?rC%+* zm;nOac1WdMJb8{315)+b4BfMEvmNudWYX4-L&ZH}ot1Kqsgr!JD1&>@-r&n&V2Ux``!`-x3F~F<%4yV2X$Y3mwB5}_WLrFbdMZXNkIT%hR-mjo|Dg~!Q7M2hnT9tAf09| z<)!k)&h&b$6bRJn;Z>)rj$k4r_7_6Bj{;No6f5zBl?o1Yz+I0N^#@pcJ13pFmkiAA zS+vEZcu+930ZIXC2IymYbIMFR`$-9A<8%=n;E5rezd>iA@1VT8;IFHIY-2S-i^w)B za1?;KJ5xL8H}L44u8MV`nEqOG_KvIR8h9QCBvI+@biR<Cne+UlE`7?6(Z{cq|iitZpyy5V8V(MwU00;87j!F z9WsAPY-`stGlCk2!^G-p0g_O1zOQ;qSLL3K4a+>EGpfn7l~ghm8>DSOgb->0*aA=M zDUEKF=(mcBuOU6zs=+y;g3+fWv~W3weFh55y7T<|{%8vCCa%2J_Xw4$0|V?IKN)Ho z8tKAx<8$k>7DO=~DlchJ#{sfwy=Mw&wLsXY3h-KzFAyKxg@^YCIac+3_yYeix6Air zO*o!1!uEk?P+vrd&m&-{+6Yz+D4%-cr@=Hj)`7~Z8I)Dn`8m_0Tb?4ucZmQScknO|(%`OSa z5$3P_3i2sD4&5JoQ%3V_Q&V>a*;gS~)@UPPxh3jOIuT9Na?{oJIDs%1moN?=B|o_D0f(;{;W;Ay(w62hD+?oq^Bs74NiBmYBbHZDvE~RQ;m3=;+?}f zFK&7#gfB?;tctqsE$kus`40?P2Hj9ThpHgM|$KJ+P z@ZIuAK&(x%xA>Q#eCA1NyXbGwR1hO zWcKGm8d4}>XsG9wv7E&5(EGg0?&pji_I|9pF8Iq)Qj;$BBE|gVX)6=a8J8H3YTk2-&@}Bp(|hE2*DZ1m0Kjl6`Nt^Z#dY;ENaDoUO-dLU*l0Hsm=M+M{P)d zira6{V|%#b$?~P&pzE#}7u_>(+3VzN00`3|yk4vSnw&5cDUvMz)ZQ}%$D*8G=wbkZ zOPh>Y!Fmv;$1;FspK%AfA}MHEP0Sv`?khpV-(#z%_-uRE!)0__)mZ8vDN#&W4_u64 z4c!w&0+rz(P%+BAmbjh>&fYEANZ2aR1Mg3hhXNu!?6Se{-9ab49#)-t538bfj_rIg zAVd)E3M&rPFYi?MhT`AFo!1jAw6QxTrkvOZfb5 z*jRj#N&~D-V1HSDpCuXZ6TSZ>os{+A1TJmwq`1uXAxeLP{-BKSOHJ51-++DF%(?w^)~R zkv@B>CjoQL-%Z)}D`huaUW8}=RAMO1(EXXGrv{+dPrMd;A~>GmMbtbRKg+*pPQT@U z?@>t@o3^O9A;R1pa5z-5uR_`F6%Ar>YO8;#uZmn=fqMPBlQPSztFEZm| zgwQQ_QZ(I7-1J3*qTjt!;^P2;LTRJH|ERa|Pr4ibQ%9&WavJjCCjgxAEoKk{JF>ne zCt5oA%i7)els5^+BKKII2mS>NlYP9+9;FSySTvEnC;_S_LXcel8^jhuL>aVAzlgp( zwVD1`Ul%{a!&r;V+gL>AZ;;$vd_Mub9c6dSp1+8yYaatIP%tv1b2?*++6{qy_cxz= z3w^1wVtR&vkX-=lop@SnVs<_wokCT%EW&Ba=l3m{w>aw>|O#$OmQC;CzBda%fH*o z`xdBU@zwvcq~kQ7RZM4(+4H9F@Nbz$fCP^v-nawJMlqVLY_(H=fhye+u>K_*_IgMn zyNo{T)WxOre)`*3?7F+y>tv)j9NOo~o7uT?zR~~skd&cAE(U=X1{Bld=uOksGn&q% z@yLDz(qXqak^>=9M_{2pDQp>PQh(8Xsu~^c3~J+xH`Inx#a^$pLGVmnmaK?tgq3GF zbJN#TKTgi!Wm7|IlzZ#C(uL27iQZ-Dk7m*ziHPC@Spy#(NDK&ownPINz{+ul`rFz} zRtE&w?0%VCxR*D}Q?aLeXGf}{!A#*^a8Sx4Jb3#AyPBj$LCV3NEpy&7w1c&2S@RC? zxrJ{yXKAu21x1^JXm%f4F^$YXKdg2?V_g<;XlR~Js3=>J|Jnw7_+VGs-rianKF4np zM|1U1Cnw}4RNNUSc9M~aL#6DJvi)i`Yg-LvNxod{pS(J6FrDG*hAUM^>NL13YjU*2s+kA2|F(e|y{pWKbK?k6QUo|vff9mGH!o?J#O0g> zwvP(i$qi&-71kLTx7QnowAOab6>+3TZx9L-R4}}2MA3ie$D*0j8uA{Y=Z%wNu%Oa_ z=Vh^UN5WoLaRQZR8d8tY#sTe(iRXW1(@E@Z)n{DaJbCU=GR`zSa(oePbR*TgnW&hg zpz3HM``Q%2{+^C}e_IgxinA?BtjuQtnpjui!fV)IW+rle=KLM~{Zd8!?@@#6MV4x> z4M8DPS;9_!YOC5?vb!}9`9;~Rn1yQ&LX}O}u1xP=>4k@HZ&WYO`8WV*B@09`H(cEz zrYwMMCaw2HfzHW#=K`}Rd(&}Odv^8x%`T)^3sJf_b+aaIOI=w<*@Y&V-iA$uN0+dU z65(@Uk9=b6B=HlH&xAxx3(T0%n_l%@{YS1Fec#F(j;r9GdT5gp&U}BACnw!P%T0x* zszR;-4=w1T`oi<)xo54EZKtaDNVTq2gjt+%0+QePdtQ-X4zA>L0CMu1b{X@i63>-4 zPpmXS>&LUN0ClNA!Xhp_vHi49;{r#V+Q{dI7h%#%#W0WT#(|WyvO8xMABu8Wy21vl ze}jPG&BVR8%P!x9UsaXPLzR8E?^fGu-`}gFO0s4>z=d`@u{Uc_OX zcmK=5sP+jZxE+HT16_tl__5uN>si9NI6Rns2>}U(u}k2W{jV8q_?{HEYL9BX)op=`e7rwsTW9+*v^pYdZT>noUajMY zU;9=Yzdd@ut4Zg#r+NFH<^WAi-#vy%J0K2kry=v=z_eSI_`IgJ+4>oq*kjvgV|i+K zVfql7 zOl}gK6RY@8haC+T7z~d>Z?#;n+IjHUlV6G7-GDf8oF^@C=BcpT<#Ua@ZTl^pan` z_=AT(mg4A&;PNeEM(p=*Hg>bqIg!)L5EjiJn_O%SVs^0UVhk-(v4+yqFM?4KbKMzB zVPj>$RtO=U#z7zIo_Eyu#lymRof?i|75DddZ{N{Br_s3;lcg_8mlh2h3kA=kI~VS^ z!0=#hzvkB|yP^B+@4h$cBCN81nz)on-M(kB0+hJ^_S+C%UBRz_k{|2}XOG5db%K-2 zOT1b|>E#^NP*A?H$i2-opokEchN@ScmERl+45w7Oqu8=8U!`Nny`EDZdnv9@zlsHr z!ffRyOm!of3D*&jS{^@P6JA;Q-NBoNg{fNY(}t2?Bt;?|vX_v<;#{=O@~3-lk}}Je zv~>!is&?fw)}AwX1VC1nqPD&*oTqzCfW{&;2*-Mf-LCv~0(n+BJs!0!jVs~q!^-`( zb08%U|CA;DX)9Cz5aD6AA?t|_Pfxd400wJJFPsPH>y;jmTnO-jb$E#D3`N^LV<) z_pSRo&%&*^b8E;sOn*7Nhcn?e4Ixa+pN z(s){Vf`dBF@W@X7KrxeH-A8$^D}9;~_W8}9Vx0%A#HK#V8E9My9{DwEtOw?yxD!S* z*T4(m9aQ02+@(3!Z?=_5LRplG(C-Ep0HJ4Q$&YB`nv;AAU-QVT3cS3#BT?>VBx7U! z^zK8#Ic@&AQ@{oRtd)Uc!SpFzQ0d}p%}$Qh!q8Bf zrqv2EMf43!FFWY=P8D5Cz%<9y&jmV)%Nj zcxz|DcK0NBb;ql{v^4s1sv^@@VTX9m*dLK@ZZBEG8ys=icP7ktPK`RQ=k!YzS->=K zv(c)739bThT{qcgdFK~fb7Yr3=#&rjW?YwG^C0feh8Th#B=&+IQiaJiFBiTY{dlx6 z;-*;WAL7?}E_v*h0GrZ173DcDeg3KvGLzXPf~TF~bAW=5!JWa2e;qwa-gxR6ysmxQ zgMH2wOEYQ&YXU+?SJb?17k9Fp>KG4f0hxa3l5QdUSAqVZ(ean3`Za$wf+Kzwvy(4% z>+%wokUZm+l8uvPMh2%ux)YDTT%A{NhA3Yi3FEV3`eX@2JRI7K36qCW-hWN*i4Lki zgFX20c@PpV!s_2IC}*L??V3$Ez;VV=`CYe~Gvm_;z16CUaw^SrO7fVeb%yb5r++4D z&u3ahY<`A8>?gOw+7u3qrq>QC_i#C<9!Ujc*j?2BFu#Nd zHl@JO?UiBL1g?MijyVO{|s0RYhz6T#dHFWQL-$zms;SZh2vrBv>U#w)T) za-Gto&$`qb0T}nco00!&SRROu*&2anWPCBD@b6~i=N}CJ&RGsYSoA|EhpyPNc@*}9 zN1OHOtowCW)AG3o7sb8qgR)d;Kx+j*{)Hvxe}HD<{QuH(s!_c+@3&R>Xxv04sAX=8 zS-2G0?afba4J~{hIeyD*zf^+0a;d1a8E{;it7HeLIOHoDwLv@f8KBu?c_d;U8b{#~1wZ1%G_OA7Ajt7yR)Be|*6oU+~8l z{P6{Ue8K$l&DqHLY5DK}!s1EbMor1DwW;X)tM2Av)tGQcKfnw$n z>_Aca_w;{TW9(f`1as`Zs39gaip780YP)kvYVxiQ;{}!$kcp-gCu(qPXp>=bSx{^uWmf9{n>k^M@7FHq*bJmY4-fkEs|AT?rR+ckD@k_VJU` zk=JZYJmxgzcsz~H@TlmW@h+fP)z-F52jJenkKa`Y5OdGaoUU!D>9C{AG5sq0Z8f@x zp9qP8_?$3x7s=G-o3|Ub97v$>Mzv2L$DU#!vWxp$Ub&Sj{o&}N*n5#(UjMDvin zkIeugP|l9a7$0B`G@VngE`=1CR4$H~=)JR>5yYjPJZC~1t(#H$aW!#8hp-9I2Qz#V z*Cjkg>tulw>CZ@|Xd4TP6W?RrOeZ82~A8E+8VV8z{kO*C+vdiB{qbWnZ`1{bq2(%_Xk6T5T|dag~o zd!Yw;c3gD?Dfr3yyZ!jmJLN%Z6(&&jk~bes%|9ch+!d2rb&nR|o?+Uyl~S&ePNyF2 zNGLEK2!!vfBj$SV{v%s3pVj{jrh>C$`j;ivNU2RHg+{rW)Bbfql_EP^>w?1t`ThCdcq(Ana^VrxtT7q0GyEH5 zf<@1)2W*?-sxQCOG0-4f1yJ9T%Id}bza}rk9eZKsOvg$(* z{@X@Q!FAb%{aqAf>oYRw(D`e?ojRrI^!^-ot6w0f^$z#93ZYvS(EQII4|5kMC3n&K6? zP3uozfcZWAWk2ElP=BeYvOiCEj;`S(?>7io6>sV)>PJQd=p-5tQJ_0;pq|&y_x%Pf-&cu;#C&PQV4vjYaU;jVR|F)tk7nYk zR}J1eLxxpt)w^C!@2iyg=)N>TdV@a;t@Q`M7! zxv#@ZVPfAmw*dhQ=~`0pZ&1xB^yETY38x;J{n(IVKZ8_(p=!r=7;wwlhD2oH9W`z3 zD(Q8Y?C`C3g%y$}9h7;daLwCx*AJL_iWeLJXc7@X@M2}Q`*BqeS0ZB!%Se@9b5G8> zr?2xtmDjEg?NXxe&9i6mH`KU_0JB|>i+u$5u!+Ax6Z!m5l1e0s6FE^_R~#4Cw~sIm z!z=Ennj)QK-siy*Qw;xMq%A{PVl7TjJehC+5N4MFaNSe@>w#Vf>As0dJY40zHtADP zohjR%kDSOXko)V?w{u2}&eBOXWx;uyR4wXf)aiCGb;2G>t1CoSnC+d77){F0Qx<@~ zI`+u9-ty}4KuuMvFL(G*++xJ0WWh;8HGW1{i#n>yHtuMS_n}ydmvUduVby#UJt)aB zaP962Uj2uf_f6n)6n%w|q#)F^k0qRDMB|;)L@+kIlU0$Q@_=xX+uw7=+sYsE*MzEu zyV%lx@zNQu2T$BkGfMKKAo}-ABdwOl`p7v9gDMP_j1fgJ;1-MNe$}uWiLu~A-hRur z=;SX~86cY2J)x_;XeshoG((%pqWsD)7=DZuZ)n~kL(Lmz75s}Gxso55H-C&YipAv) zT(!EktHB*z83p@5+}kcBK7tZlt6oyBoxI$WBJ_^5L}7Lc3@hzd^nNAh7~Fet!~OC= ze-#KB?{n1t49_mn?>bQ}uxB}4lJaH6W5XlClT!p;dx73z6jIl?Bdp6f?<35VP-jeZ z$Y}KrBVFM&HQw*}K>Ok)FpZR;s!(=P8WGPnybWg25yfxr z&K%2FQ5M#;8UZ-5-=^o4ezxdD!UAf_5amapr2%aGP%L!L2N3( z^NT->TZmQw_?r?>s*tqv^W;rZ62T}2g?-=!@WW=Zn_q63qW{D#e51(NK_m&RIAlr( zm-3#4_@Ndc_GD#3cP9W^))hB!tLnxV23o#1Rk)T(H^1sUrde=_w=ejb-;ErjQ){nO ze}nu-BCn1^Jkd!iTl*5mk+18Yvur+9s1}sESoZezx`r5S4xE0rbG5GMS)~v3y-ddq zB2S&IZvi@Syyc09)9=nYE@nKD6{S%0o=crj1r?>K=#h8a;f@X(#zB%$|QOK+9L(C7eZ)!`Oeo^ zeJZ~l@`V@bik#KTDE2;9LR@cZY041v8b?bQhqF{37}==VSe&|l`yG>w9{+v~J;1;# zuenLRpm_3DPpo*mErgg8IPcg?WE?YUBz$t^PV!Ol-P3k8Wtx{7Y~>M)A#L}KUTRtR z2%VSdb6DksU<)0PT2uZdtc0-{iPrn9a1Wwy)8V`OC0486Xqj##^39B zi@QBS?#ZHZ5B&FklJk5hew;;=I{x6cjKJ~hRGB$biLf{$l-4UU@vY7{+07}nNX9$F%5T~2GFPlq-uG)?&(0HcKj4gM zr(N92TgdDkn}e?L_a}s!N9_A!Dw?8i2-|25b4*mv- zBVa^dL?AXZc}5_wU5Om@_BY7$JMXvN^uh`GiAYdTk}MNY+B!fw9l4ltq)h@iARk4LU;GB$X=wZnD!=6LWfvF~ zIiH+KZLLE6pggA8-dx+;L9am)2^435O>|BF17T=RdVlfz!Czsr`d!qp7YGTvw1PsvKK@C>Uqvo8g{(*K|cYus0Uh0N2A;n=Rp4MC2&0dL2 z1zUoQrY=vUtj}nQtY_z*mSQV15mF^In#bmZ+n@B zOHdCuz%42E8NpOWt}tmz@(I@pJ$-NFBXizdkEpVH2TjE0tIH*vqs0rgYv9E8E2vIg zHOu67;pCmF!^?FWc?$+jJG&B?Wa;E2kopW0drZ^4IiX(HI`Y`DuK@wz;*Gcr#PpABpKF#ej+^h}prdgo&Of!wFzorj!VQ6W38?adkVw4WzyxdpPj zK%L|)13jf5TFON{yl}zsWr_sSEq3vqiKGlaslJi&Yzaz|MN_lJnPjnJ3ufJ}vM0~w zF+Sa|mgVCdhlA9M@4ou7eoJ3cU%FL~VdbwOJHdE*u700U-4KW{KoUCQ+!;)9n?Q<@ z?{s$!-d%VYVYt&fwPowLb3n&vY?4c<$*O#}bkcQCb`;{Zml_VGgJ7R!60dfR1n#N& zD#o{}cn@D%`hIRj;`qTF`iz&55GY#zay5xF9wVgVOz>=1KvpMLr$p%OhLi%cjzM zHmyvhqQWg>w;KymlRih&IX?1yC{&xQv*PCIC@(#1(vUX2WjePxb%%SRui}Zeq!jr4hqqSn*NJadmEasP+@>x0qpx5Bp{FS#?0?^-!D2xoUc25^zi6$x_IF$XHox4 z5AIffGPuy;7Co5W`(7C#2BZ=&0;G406Crx?Mn|pu+W~w)%co!kP!Dmt(xVp;YqBYS zCfLa+_@f-=g8fpfL#KUWH-=!9nZBUjy5d@f_H7RP@_Z8UD$)_32v3DX2H06ZKCUin z9!vJOYm?1zW2ed5UF|w~^R0$>*4O5Yr{q?4-GcB_Zs| zL}EJiqA9_Q6iK)x+}HAY;pNK93Yfo~u=zU^uF?x>6@QVl&FJr@VP)M?@CXIqzA5@* zM5{2~feSzt%k>#v%&X7r1Z*>2o?98n!+$^Y!ewv!MbSj5xZJ)^{O&{W#Qs%5q&9L& zSAon9Z1z4Vp}O@lR>aH~9Z=<8`fM?}AJnr*C+J5-=MDp z6$?l~M8GVrT3`ybA=8XL zNoEC@OeM;&U`p}k4#RJdYXp><91D;-3mrQmCouvD=e4oKVipQhxvO6h@67Htre3kU zVAZVgLxMw!B7{dZXUSy1VvZDD@?3nS_E?8Hf#O9t5TU(b05#d4$?5nNsX^E&^;0%~ z*;1do@l0^~!ztmb%m-F+t?L2|Rxg-FA-5sF0K?-0J{H?U{qUG}L=9PnDWd2Vsc9+E zT2N8s6tHhRiO-zLGRN{{@9Q%Pt!MaP1$~WDv0W?Pzd?ED-#4pooXy(*>M zPRN$updY+?TXNB;`a*{Q1$J)%f#l{MinF zjKd#$_C$<5a5)z6#w(d znrRwB0*jl9367kPrV>JvIb7 z0?X60E1VOO@0$wmtG-+>Fi*^R@x}Z0DY`RU96;h}@^`_$o63%$C4_fu6$`k)E%vE9 z6!*7-?^amnC9;fl;*%W_Z_2t5{WQQl8dkCEJ+Yy0~<9MP?Lx11JqVC>st4oDn|Q6GTXpor#QiZHyS%z=X&~z)Cg8r=3(iwtfoFc?eF=>^X%F!L7(Oo5DnrTo!U0q-bw#MY{d}(?}f=;bN~9Z zq%zHo+=vSM>FybUBiBN|7Pp4_#`JU{kqzh}(p`N=?(TD_c~X3H zcDH>IUXvYL)9{&v8T(wIf4Ie0qQ&eVFT?krFGK9#xs3hU25UW0nV#Z=A{KHJj&+rm zqBq5BH?bGJ%9~SJ-CCL5oHg+M0b4n^BO5C1i0Xt;MI;t{#ifauG=wEC>dG1{ipge% zt>8}OHLi3Xqo?{Y(J$}G?y;Pwtx8YN%Xbfa<=a49a9S|R>Snvnqq&yGaLw=*gB=LO z1PXG65#zA^-K!bxT#q`T+;w}6j`=*@63Q9)g)n-#@qx?6Kuey>E&;hCBUM#D-kVn z+3WIV1R^rtNj1J_i{~+s;HKv>eBPgBCF;o@Ws2L2ZD$$RWj7#5&qP^vbNhbqwNP08 zYk$$CvZ`&(v2R&U;;Q0+qhkU8;iZknWQ6IMsDoUnn|s87BSgAlXVoZ|em4kSC70d=h9n}77UU1VAwUFGqy!0{%~$I~{k zo8iwq)3o>`Nz_C5=v)eDf0c;>CU`&8<#j>`>_eRhBZaBHBjPVUI20}K3wh2Zm1p&r zru_1UV2i=*RT3Y^>jK5#Q`ZUUQzrZHh*_r#uc(9^Teq|t3zM@?pZxSYUI4BdNEHEn zEz53BVFVA*FF>ZOpg7^Ch95Jo$d{$~P{%HgGCwGLED(OHMWt3*>e0t9ZM6oI@^$RI zmg9HlS5_?=Qx?>-VmHp}r8q{@NO9Z)k*cnrr^D@_NNi-^AoTZ}d)<5d@M2sS%;t34 z2f~+^B<%~~Q51uXhUJy$Z86%IXx$jHJmC^h~n}5a}K9El6?omH|GpxJA&dV_X@8eP?SNWo= z=Hm72qLK{Ro@$IICh;e#_Q}bU@@X}FoM(z=`=oT-vBBN&J3o*IgZs5#*ucwCD_RFo zQ^4Gh#^78a-9;Qz2yS;NwEN3BgV)#g`Bk2jBwelSId8LJWXi%A$&uSZIYkEJAoF`2 z%81JZ&*h!nXZ$_066RM;%8L;|o>`h9-O&$Lh#HBM=KFgvdBGv6E_IpCr7q*xV%jZd z<;#jEr|A2kmo7q<^h#3PP|jkq%viTz>Bn)^RbhC)CBbv5Wf0dA<(Yq78!#nueUANu z19Fas@&(7EDIfwW2~wNi9hd}W80#>gUlAu1R5kYsJa$ho$y=~{dR4V%2y?0W)`y&V zsltA>@781o%goMUY>9B zG;dsdD#XuMLzlzH>9_r+gfmxH9XTj|iafCvrHHpCGByJAAbj5doSD0DLfrj^N2Y;I zPb6!M&ZK{$yTHW6{t?79o!5Y$nGvztIUmoj(KOsJkyf*78Dcx32fmE(Kgrh>LWE9- z?s0GJpk9#kO>FX-N7I8ySFURQ(sjy*C0IFT37OhhOlqh<$c%^9tak4utY~K9@7Gqt z+I5Iwt^Q&SWq8@uhC#n-mc|_QH}tVa&Ot)jCWGjz^Cg@`?dsG|s~oyoF5e>dlB6&b zI;N#T?o~cNA@0-uQ&QVb_utFrNzsWgEGAR>Sr0K~nLcJ8%ezo2-Tn6eT<3E#lw6iU~PzFW+i$68=y@db<4S zedRV^w>xuJt(aRYX(NhYJHSwC50=b{V?|~B1lL`3g@U=4z8RRnLZqN~xESr1c!v%& zHjH(>uDl^60{#}R2r$~EA=P}t^C!2E>KQAtb6z>;>eKym?GTS0bEmW7vSDI$Z|~yz!*!6iGGyq}r}w{tjOiq~rF`ByC(QWoqH1id8*`NF zUVa*Ke{57Xyu!kp-V~xd-TJ`w!855Dr~1!k{r@s=r8h3Sr3S#9jTumw)ur;ry2he~R)= zq#z+PUWfN(buxBYU|4@DH#HT;3g2;Ki)a|gb&wH#ZpCEq&yci#!kYc(NBT=qQ;tVK zRWv1bIub*T{#>K6td{!fH)x<2coX1E92bthBJp1=4$=NO|G!jB``>lFIAnWY^Uk3b zp(lBzaVL&%bpY)RumG4e3p5Dd09MrgJl*fvf1Tq1S-}6&dHVl}|3@p-{;kqQgK8ve zNd^NuqW(YJ5waVZi>KERU^-o)5bACaP9ez zlD}nMGXNREu?25jUd!?2uK+urV*2gYkA#Nwa&$8KoVRyF`^R|h)`Ldr6_<`Ro!M88 z$f;G(Arw4MQJ6x&+{5sxheolZ5araTyr1k2Xn(m4g}i?sXk7sMnIgNnmkh4A+&&;= zI-)ocSBWDD`R}gr4f?R<^Y&wQe`PP0r#%&=bOenVxcuw}B!Jl46@B($tL!?4-gpr~ z=y2|i$I`@QAyx@dEuPCX(F5yNOfM#&7m#B|-qCemK=<_f zn5!551rPimfmr{%|G)nl$%;UUjM=#Ki~g`%ADx{|Kueo4fC3sD0F4KswaNZOP8;G( zBi=;eZhCxC!3@ST$hzYW5QzpesB?C|3t+*IXm+~6}KLMt?Ix}@m* z?)v!=hIcJAiEhy@%?x9ro3A)q{t1Zv10&MP%bA`SC%i^|TEAkPRl7+mfx5a(vo@XS z_OFaMhfMD(k?01XiBP1Q8JvfA%8B6=gyTP0zcRr3m2mzZ|3BB(wHu!_*DY%ulWvd` zgCb%4Jq-~sQ@Vka9p0y00ax`e2K5=M9rnU8JN$$_J8~^SyRVod^nzrn&1~rW?S@M{ zHNTcr-R9_`0ZMKrA)p3>gQf_pXTAiIX<2;pl-%vQg5xe-;XE8J+1HRFW);N;(w&DlN*sn1U5u29Z73(i+itZw0_N+n)E;eg&A>@=$!iz9gq$IEgZUr|vont-P^#aV z1Fd88lug{7g4f#|YONmnfAUKdcoJnM8{IDUx+g*r?T9`BRD<2Uh=TrJWPxBSH1@h# zG?=vt>QcKh7^+<}arw0m<$x|9e|pba!%ASlO7xJc;he`_yZ5KYD{Z~?ltCmNL_CP5 zYq9Re8_Ya-hnu1ftVF5nAU-KLeB47bM@|59fwwpPJgMf?gM97w-g`s_xD0N;y=rWHp@fBVBjJ6^RXh^8 z%FSp91LV60hDmZS!yro~o#fiIn@jZkuOBu4WWIm*9*TUbuWvRt`E=XKfc4GTMdp@R zzC8{812(GPpuSrKpm}VkJVZV-zBAtA3Sp;o;mqPmCU~?4g@dlaOJ-npm7K4L4E3IG zzVhSJzP6x3FO!W8jQUDL%=yE50b{U3%-pWjfS|u*)8*azo}j}HC^pYk^ePcHRjsKM zoDX|RDT{XWT-zSKf@xL<$*h?rfw?cs!EXNs#du(6{RhdTyq@ysG^0tkdgiOY9e?9= z$uA1Cv`tM}f-w>hHfS7c=KbvBne?xf8X-@GI4`YLj~%%FCH~}W3Z; zgKZ{7KM;62P$CmMntJ6{;@$XM-?RCY$X)}+VA~_jM9_BW-}o9~;D2A>OPlc^db{Ju z01|Kdks5S75%hcb|J)?_=UDqMDx&|Y$JYNpulpaU3jfZa3uvz+F}*uUSdA5z#JCiQ z$fl)bZfRrc?AO>|>+|r^BrbvT{n#lvSjYAL~w^JR|oY;Z2^3R8{C&MC{isw$ta4auRh-nBhM zkIC=Ji@W8knNKypJ`lbaRjK`@KZ@#Tjq3p2X*P4boYQ5-r-R-54Kkh7WXFU~eQgK{ zOz=A_F%5{8ya?a>@I;Yr$>Ai<-{heQ9dD@h>%HVe-a;lNDhBP)R3Dms2n^^4Cq{aMrMVOnwcP5bbf1(76t zW$e7B#>%CoG3+5d7CW3ayKr>9bJyMD^}5_AJ7RTuvgJ5RwS}OzSfAFZrm5rg9vttp zX`SRcbT;dwi6du^_mk${p?X{SGpp(ONp&>a>N2ZnehR=yJ`HDD*=L7e9IL2L+UE=k zt~g_5WtG>opTEbho^&-OAuRo)BF|}_cgCVx;C8c_yycu;EtfmkchTI11%y099l43L z3TtvI`TSCFzWv7Sf|}^qFJHcZk|V-I^|^4zX;6=}7(m22ynN}kr~!^GZA#f^0_z~1&G^tKWHqjw> z#t+nT$6&6oROI;7869aIsJc725#7GcVt6<0>UWYUCgrSN)-r0+FGXtJjeDPGao&4f z-aKTF@ouV)NVGQC^On5@M2Q4#CK!j}n)Sw@4e$$2CAM46p_f`01g^SdnG3$*kng#j zQWwkn7;uydlEiNSl^D2-TUZDypM_!Zv>CWkxtc#~#^&|t`>fV*W4?~$FT550eZiB5 z=M-6=P~kl!3<5UtKUMWF#x6D(#{WdlH1=4_7r!B(5R3~pZqM94E4iJrb@T0iCYw3m z#Y%7pkR?vP@VM_i7-s11i}go8E^js}7z0qsP9-tX#ut3adE}QHRaAk$K@|&z0xB!W z)w+FW`wf6gH z(~OQlPKA#k0UWxkjh%T54}Mn2L^oX&J?gIp84P{=@LR_ZzOh*8PH8~We+9U8RQmAf zm8w~rWy5H=e<@sJRQWHq{=t%Q_v=1=%D2)iwX&;RyF0hwO!#u)e9Z-mQ@+XH>A%-m&?jjcmk?At zW`>e}J`lC~5n#zP8Xw`tY%yF{d4_|^G! z5uCk)br(;axuHYz7Kh$dTKa-I-HCH=yZ8ttJ9@b*m&w>qL!vo%<^gBi*H7`guGtdT zuRJ(tpaT#FCC<45uJ)Ma38}ZYzwPM0f1jLTyaZCu?ZGZ4YAUVBod=%@(zC3cRBn#y z&}bHj-}9I&gD%2OFCV7;9i7hmC5VZFX#+fs9I zG+S?zz@fJM*;Xv$ykgnSLID9{Jq7{%ixXft2oPg{lLnLBCSVYC*+Z+IDQP>5-c0pf z51*Hl)Uo`ap!eRLPt?;r3?TMrY*UFcU&Gnq_E^=5;YW!SANR zs@>Ay&)ezwMj}x~Wv58ED#|<%YmO@1Ek^5v*t$^iqje0zLWWk~k7|R9XIc<&t=<=} z6avhp4Moy{9bT^QZ4V+OEe$0Kv)AAzBmIr)?KsFLhU!jF?k;DIM?CZwW&S`#Us{zd zK}Jy+RO%5fg_{fKCY80M9cee>m3J@f)rAJwdb+$82=AuX$W{a#UQdhBY@=C&nmfWT zHGXc!rq(-E>Nq{Y=Wtlfb^AP^-r@#BVbWjyeMz(Q886`-m8T##P6b6izja5)HD~cv08tBdW|0 z$djpN9-!B%sJ(XMv>1C7!zq5yzAI3S%B-@@EaD?@Ic+8|Dzhb`22&62Ml)-UKc2Am zOiYnEJE=cg7W&4cbEc&Pv){9yiUfwzu1v`361+327E4+Xs<&xHz?x$`X`7xD>G2z- zzNEe?c){66TdMSHxGA|(84Z~5H7+v26>!bAqb)9ljR_T=vX`4p{A-%dUbyS-9)CvK zrzAu%dN|Ut4FvvlWcdYA@3PB)fj+$%1Z#sa`%TrS#UsziQ@V?EAn18v5ElL`ar&4YvkV-muJeV5yX-v5GG1jKP}xJBSBEj3l=a2rppu2~omz z@h?UzXlej6&qDnY{ZkG4Pd)hDp9(*07G-`Rw)zUNi8xN-IvFaS51&JTe`>*C!~U)) zx*+?JwTd0ts-vRjae0f2e9kXD&r-pj0PWt>(p(?!eFn}FE-T1x(@>1zUmf$b)GD{I z6)<&tiRrceOU!{28M~b=kGRaM#U3-hkE)6)D3hf+HU;`lJ2+e>Mj;J=`KN-}NO#^d z4KlSiKG83pjX#&%mL~STt&d*@rJIHtG~5;X>y;5u*R;rdYV5mG*&`_5;t*XqRTQ6T zmzykOfoA--7p!xI7dAGvg}~jg8JZq@I&ukx2hKI}ct_$p$%CaUr`AlcB9JfMxUtVa z>z3gcrOzt+_8TjRWEd+%kdW7bV#DQEk&)$unY9Y#c*u|@+U+&Jm<&p4UbORiw4!2a z2-X`rxqB8sQIUyI#AL2yfjlAE-S=>}xpTWx7hEUu+un=G@#EC<@n+{1bYSBTG z2%=af3g==n1PH0bKq8F7E#D|#yGqX40Z<}`Vp*gL%Hw=b%{k8wKVl1xrx!6jdV|mx zRuI&Ab#kwTZw1NFjLNFCdfO*vBKhmfyM)Aqw-1odK5iKV)Ru$bKLPI5IU5hvWn`2l ztV9aN`C~zLG|woDRr1i*Xpo=%@Ik``kwEjJak=FQ^iuIJb97=uA8W!=$>(9w%&fC5 zI~nR1R>CLp@^5g4Fo*EOpq;$t`5coBd2TJbhP%x_S? zKIAlFiXI?Kj)eB_!R3fDK?8un&|8uL`7K4%7$?8o@%7>G4&7_kya(>y=}9@U+%jT? zpfZvBk$r8w?+CR^3M6rGon;gR0xHq3IWhF6gk{t_iLsVLKHg(y8#uleZC8~LPSuAon>pKt6n=-J&ytKf z+HU6lNR-&G&-?T)xo?BPFsE_w+oD`;@evsdKTvV~ zOlu#Sj<&!Cb~P`iEb?f2K)cL@ZD+`_3wwYwLd6FUvxU@!1r?{;!2+?vv0gc{T@hja zu)KU;V7#`mWSOYx1|*j0YLXUCsXhPV_`7jW<=G^3?W2o!MkdCbs*kSA(mm1a5%WHX z{?xcB1eft9io*@D(D=*qhIk)By)S^wHNXjtN);)A8q%3$`~*S6kQ%pj#U zfsY`N@a6aLYglL!;^9ylM1)|zd<9OsayoZUfzvG+5yrKXR+>TsPrK;!i48$ammF5`x7 zV$l(M51;qG8+xY`=Ii=u{k?GD=0$?3@_^jOntAn)*x?YojM%w4P!_&>3gmioFN#U(tnYf z`aPW%6EOd(BHTqV!7pPciF!m!3j3^b9nmstHF#F(FgSE|_SMl`cpViEFt>u880GXT z5UGBkf0??Olpy-A|1=|>;Q@>oF;xb*oSnWEKAPM#b_(o$#;9dB?|T|>MGPlP+7qkDq2Vl^!93sKnG z;-K-QCK$a9!H;c@=CiTEX({dGqLw>EU4&;aAQhLf$^FX`?=>B8B5asR9I6X035O$_ z*D969R-UoFhAzEqnDC(t{Yr~3?)tn}UA4urRjaC`)&Cr1)GK+;=&|sQna_1tZOrT9 zwB076p^g)^O%I}aRfa#t?x4GMmWD=l=U$qRv1abNEJeWe0uJ1pJGs+M)M>{a2n=3Z z3d+>a{WfzT6MEzznR|L~0V>O_saR(J=Hbj6?zkO>8Bv0Azyh+r(pJ;ne{Jd?*b6_o^G_|i()NjP~E>l40L(HB_~A9A<5k* zbr9#lfo9L0-8rS?hyImjOR$c8!tF(dmMd5^(TwOqVJ8TK>rbFAhmHLQF8zxVbDlk_ z`7aCDfC0(Ui_h{|??`8H`nxGLdRjw?3YhdU_@E%W?)lq@&ylJ0I+N8;ttO;-y_+hDuly}j zm=6Df;D87`3Mp>f*~Eo!x7l+|c=E5?@voRPjz6f+FhfnGX(&9`prujkx%W|K6C49O z?GY{z#J$lzligfOs9Cw2jaj*c9+wPuYRcbFxK`#+eNOPyb*izo1P7{v;~k~nAPMoe zzd@6gtw3FpV;)Fhs4AJdN9kJ-`drtT%2!o&`GThqr;N$@d1GD|(d={Vch62yeYZ5j z(%|poCff#~wpb~`vtJvX#60&Yzd;EP6{)qF+yiZjuWg)yM0*RTz^$b0j(nL3gfl<} zUBnKdPah#;5VO$P#c3ZqWBgDpIp-_!`o4;0m7=Z3Vj9Xf88;FERS)M_c=U^BXPRi; zyX@nEG{e!ylOlNrR(n^Ure}ivKA5x%UAzPfqV)VSj5|u3Ml5%u;~Y~wD*Zx3Up7$1 ze*OegTx~25{<^d=DpYNOqNfb0AnBI~(JqK`$@pD(Sf%~eLAQ88p|({;uLFITR+ ztRR{Q+VBL+J~TDj>4ZW}lEjsCvM~`l0`j^hXJUS4zmwB7E4e&h7l4CT$Ry;SJrlP2 z?vNxi(K@+|h=J=brBS&qaN%4MByr&tC089kjeFi1$lXk;rE{7i%$&ZJ?Vx=s!^f4T z^`B~E0sNEC4OqnPEHIg05JiYu;LbGy7-rQqrP&LLHC<*bPhSQJE(;JV05@RlZ4XP+ zdk~f%v0Il)jWI9exfh=tUcXS!Z`v9VdtYf(lIth57dfrBoVH69_KpzOr3iBk7ayK< zdRJX^U9RdDm27pb;-`le;;{y6V!IErhtdHH>*l5UUVM0ArF*{E1YF0l)Mks`sui7z z`E|!5BYS5=_lw9}AeTAo4aE)W2u&J17TY-oXl~K|tZ?bdjf?ivy3JQtUEy|+7yQ8! z-g7S7J59FnRb+zb>YQ(c(N6`w@U9)r3-6&_Y%?-FyUW!y0>44Ka<=v9jT8m7>kIeH z?B*+p^qwXgPQb*@?E1^6@rt*tByQ>sGAG3vb6#o}(gAvM(Fd1HNM~!8 zlQX8Lf(m;U6mY7set12)gwjXNJKDOZb*p>g?)OqLffBwP~%Rrz17U!YHWAWPQ2pR^Nq~9 zX?Id(hEXDbuVq@7SxiV=`|Nm*F?N?+WTB(HFV7Qo`r31=z{0mNQy|*Yy^j_Mp-Ud3 z%Z-tK3UM=nwYF|3aBXZ4eUnpgzSQ0GFBzpTIMA!z)GDPpgyDm+=q~WANykiDk68kO zt406KwC$FvN_H+=tU_#_hYP2;k#S-!cW?YhbDzd=AvQ^jLVO>{qV*9`GTB(3imo@CK|@4uTKF&hE$73L*Jgfj+NcB8m( zyZOacc0oxU$GtaICj+bJOiN9lODQmQ1d4@Tj8agDeVlPan?8p!3_UmF!b#HjW#r!gQt+d6Bt9njS>$Z%{NGFhK_L8h^d}0KtFgRIhBn&2}*p zH<6rx{VV(u0njTw(>(RJ$#0-QK&PS%Q-nwBV7!BYCG)P6bXy&i4*kTf8C{S^@%UKz zgFtHaQjsXC0=5FFDouhV_7^k;MLXKU#@$#J6z@iHBL-AJXA~rqu8klSugmLZ&lCzu z(Y$hQW&Ky!iYp3y1SJF9w-kx=^4_iCYlRyzL`CE}VEZP7U0ed=g=^)jFo!~m(_-gG z#81-kM_KJ;&f{&E*)45a&VT=#yy_)k^*rz-wa75}M;uX=ac%|`qbu{G+*V{HIYW{U{>*Wwn{*o`0tt9G;t#kfgrFBuCfuwV>IH}fCDFR@cv}#c2Rhe zy~QEYR(-4fE>(oOrT4$Y{`_}r5Y;l`48j=^f#$l5Bp9`XExm@CkF(AnUr*s2dCYrD zwCl?VNR6s*B{gd+oEwXb)nt2PO>k-dW!I2We6MoKqr$=Jt@1LEN^{>~9d^Gw2s%;-44&*Ij}Ql1z#F%#CRG5(3hSvM`2O zTwfr+mV?iK`-=XYoMrc37dOebZKx-9{z=(05Flhb_5-9^Fq0nDv z3Ro$Bm_Yyb>iu83KCDuJa7yT1U;+;QMQiuokfJwJLjDNsAn9EY;_uym{s#f|?*r_A z`tKi70x3FoeuLV84!rfx^h)64Nh1hsl&4ZO=zstE-@M=YY#;N^ZrgfoRXyPKgR?{4 zzTW3MnJ*bIX=WI$O23aVcF?BD)aLo|p_RvG7;ao_y^Rt=H;yaMu5!%ITlZ(SE4vgL ziajj9u%mOn{$J&v|Mcbg_m|-Rz2?%bko>&2ZjB+|Y6Dqfkq?-2cvMw6WN!jOclO3& z>0j<%y=ihu1jHc8MHg53BV2OC!yiE1ZmIlgco~xakjElZ414`$&yUsQnvY+$7z$Xd zr7C^ccnJpEUB{5q4U@9YIIdUJTVa>w9o(lXrb;3-wr$b z?BlI$HK8lVfb-n{S66IY9^xe!cJk!fBQbV9y~dv>Hn0rVlcKqi zI5gt#A!9!(=ikbR7c#toN|0!5YDjmEbuAJ|UrA@Mu%JH|DUzD?<{uhT{Ud{{|C!ea z-uexCh6j?(W9SiM^W8L3T)i3b!EaE3`3&gqHADY-_5SyLI{uL$QuPX=Oc{f*hM(TH zc6sV=V|S+uB|RbPR^ew#{dI9mfpL|OS?8H_2CM2tx#)<1S7Bh-Yj6#w^8;lHSvM5# z|6{nGC@?S4Go6={X2#zC{!D(rmycamtXp+dM@WEs=d#?`Ycg-CgT*}JS4dD6sX)AS zx2SL|vnjk!?`nMCQjFR>b;MxD)Ak7P?@$X;U5UwwuObO>&D|n6T7DY#L{0Mdn2wiP z_zF+L-wED-1%8xDTwd&OCajgfuYU^ipq#6Gj=%VR{>0bpLDS1d8#-ZE3&{;>?+B%X zWcXh|kBt+!7A{p{c@o1~ouy^i$99WwlYSzcov^Vq#7{E8p*t#sWx6NeAi31HBVFY* zVP(t-tDk&qBEOkG33J~-&>`COAarouagu5QLBSM9S5X?2x+kN@%pZcr2jzNGn}&4U zPVchX>#x0N_M%Eu6Sf23iX|(EW{4&LShT5h7mYPBRqq(ZThz7IH>AoW?OKyJ((lq| zrFUHLx$;;F%~1XoZ_TTVd8dNIa~1tu50Il(RC?Hzc|!Z{Rx=nTL6 z&iATB_7lCNa4pUR{}OOpiR*+2Pqy3=egk=f=}z7Zk;^J;8n1l${mRt_3&Dc>g4v9` zF@0h*2m3qGC0547d*Nfbb=kQ3T~t+xj^LV3`uP|0%JCIrGIg$2tPJKuI`(Fm*~L^I z1ol=H{~Sqnq)yAr<~v-;Nf^l!FOe%%@G(~-Am7AGCSCHiOc``OKnNf{jbnZBr37zW zURNN6zPfh2MGiBs!tSk-fDsTaSv;Yq`D!(wUK3^TWSXy!cl}cpdpLkVqixZ^RGlTT z?D|)jbWwOwc3O3<5P%1mQ}gWWpqrqli9uv|U|0l4VTN{tZa74PNhl}-u}AN8v)|lm{`Q#0!R6rSl2H^ zVw%|)Y_Fe7>YC)2O?cbnGR1pCu&)tPy)k?F0Nj(8R0qJ+T;M{ZBz)tDzs4Zrthtkh z;gr$K*}#ukqR^W%fLvZaAJ+Cah|#}_;n73`F6ge?*_i=WP?UyPL?v*9J@X{={+XZx z`k+osRp&l6TnmXQNvLlQ^>Vwlt_=pYAaFGc53$O0LrY^z?22jO#LBj1wchtacw*k;F~_q*L?&6&p3p zE7+ADl&Bf2soc-6A(wxyMmr9;2VaO%tQ=UET#T2_&`mBq(XQOUAiK~>U{=yu`dGBb zV$5e~GDRRS$t3i3Xy-BW@YSt0J>?0--RB%reP(jL2c=gSwIOSSN!)TT=9HAo(!xH;e5@ZyntzuB>oHGmQx`iHFQIm6DRyK5Z|g* zPd%Jc+H`nde!;4ruaXA{P>?qnE%NCg4N(|mRwaFVSQ6TwfGKD(1Ui>8nJhPz~+G)IX3?eSIdn}!kC_xf&IuG zFSpvV_%TH_!(qkM2g-7+OK%p)fvSWLB(c^K9RTx`bCZPzlRI91c>>O!_F9^T(xtk= zv*$b6?`yoDoKRh4rJPmmmD9jDb#{8;So}(=V=VRuazk0Z4t}1xc&J+qy>cg+3ia}oZ1DMjbcjAcn{X~* z9c@QqZw^Hw&(tGWen6N;;X37yghiVZzjx9Bl(u|{SK68OADAf@c3{!yENRAV&V6odfD$Y5aHw9hG>^l z-exkJwjRMx%&1i=)Vm+&>)#Z%Fg#uSW$j}DReppyY?}VY$4q<;x$LqVV9Y!l2D=e; zr9b+-{-w`u>^7yAuHQcja+K#sML^yXiiy_&(Neu-B%Aw!RN77CO`JoIUn_V2*xc6S z^&f(7ZDd$)USu_)W!7EIpc*1vA>M_n5e6|DfX3z2^wt75&NJ`yr@X5;mh{nmQgr0H ztkdgzI-+WilUI{{0Uefx1k}=!A zYP@l6%k0=L&ey&O$$)l83x^M*nNf;x!(q70Tm7Y+&r$1k!UxMGpK4lk(*zRk?lELb zHPOgV7qB_vcy@=dHk|+)6dJKGX+73rww$y!I%_hz1DOmCy}sf@?+$;@Yi8dpx~qwI0!Gd^-mDycbN`<*uLbGf&JzV}S_B!Q6CtMySr1Kvt5*ET zSNM#bC{{^#FLFyOaxmrFxNGc$($=%-`Wm?J?3J@#HQrC!AS0oxWT+e4ujMz0O-0Ke z-5_rSJ3R0$=^|FX3CGmsrnH$4uFp7F^BB{7rxt7DP%9<*odZ&jTl)P-BN5x;ljEC|1whYLe(Gj=}bHCr~y7U>$H+&szab+MLbXhAG<>^r^92T~ZpqZrQh z2~fayg8v$dUaI4H8_LNhmhQ#=taR<-)`q(bm$VN3DRz4pwh{lGV1~==xfrC^#?3!b zP-B%heOc7u?2X85b|ss>E1BewsJrjW?3N6us%}6?6>m%Mbt#?|uQOXVRfNflF*4Uqi`A&mTLkmH zOuOBc*;l}tw-nb?bsCN!AUtrQZ8*E`V!J`*(~Z?1?s!GwRm_EYlt$M-KNz4VL5-VKJbK;aN&E1*R}6VHLMK15o1FZz&kBvsbZsNnh$(Wo=P;yy=Bt z55#4M_}9%*3B0?~MtlGvsUs~3LQfc9GGU8fWZOs07CxeCbGaj!Zp=3;=2PgC`t&~V z6?Hc%9AG&_J#m}h1|*RqOKKZZ8N6oiSFTB_X2(?vOt z#V07uE(RrC8?oNjZ_H4NoA+%sW+JcMF7hdgVhIE85R>KdB1tBVV1u3o%9XgVv|YfX z*(NO$4-ebRt3~EGt9qVzHz$&|$jG7Hmk+Qs?ICQPzd>|C)cYMg{BqR|MAP}H3=GqY z=;pcer2NOl)9-kD4mw(7R@gca2}o+h{V=6vSXbHZa+H_AtlXEHrr?4kqb3FKx>{N_ z9vU@G7nWDw>4#welfx7J6m}+@94Mu=rnvmB%vbJ4BAqY(1-x^b;zub?Z?^qmlzx%8K5iLrX^14@IM0lc4Mw zHWDY%EU6<=6mlM`VN&_QaLr9L_ZDYHWDiN7IvEd`;fDMMb$QLYEL%pv-OJ#Rkt44& zgbkbT5$Z+u{G+D7Hc1bY;uH-EO26$EH$2mO5J@uyIRg(SIN>_mZxasNLJBf()CMO0 zl)pPW!DJg>`l8u-jM#FEYwZy=~uM@Xezi3)^dgCX}H3aLX5z2PW-+^(?{Ikt5N9| zOVtaS&c5DR$r#Vcep<2HfCb|{54~NR5mq{)EFaLBasLlu9M%&A&uf>T&K|~!{5%!K0t)T^Use9WmvJTd zw*+6Tq4lVy==+(id~uvbwScMnTUTxXC4DQ%Qsj#RbL8D;-c6B@1!?M5e*7Z+Am?xk zx!}9V@8Eh>1}Rlj3!5q`U*}pvf>O6S=$cSBiEw-f$WV3k*?3FP0JaTUP_{)Mu{2eQN$ccq?Md>- zf0|JAtkA7}fp1>BxN>TVcsKALEeIV~7YUNkf#v;VZ&gbo8U!T@BOg`?!(Fx9O5Yyh z6@S@7uPe~on)m)FaFDC*8ja_0pWt|&dcHTn{ACopJD3k(gSu3mXnDA+NGO*NYO0*N6SJ| zCe{)^fKx~SY_)|6UN*y;Fd$@H^Ef|gGzvEA%2oPSnlS*6h;2-GvA^(jQDM+NW?9Sx zMv$gYv6#Q=aM*TnK!r}T*CtQd!s|9Ch^<%y4+pBRBXSXRJ2nRB81!j_3%QWxk3M_G z6fd}k42}g@ha-XZ_t~%QqKUlEYn!W=q)WvzzgxrvH}-&;2);c)4u;=Z8sgqEDct1W zLuw@-r=0JS>+U~sm)JPx!xu{hUEX_V+-UNBBT68d_yL%fbZ^0TTd4 ztmo*wG%uMzBA~g*HnZJ@2K_`t@FNn!W&29qh3UukkG7NU`Cagq9Wj0>H*cj*Bfz|o zzm)cZ8WLoTaATV;U#iQeBg8C)(XY_AlLbsN7TgDgS7+=~1k2`*xeuahcH=YP1>vw4 zeeEF~K(6f&LuAK_=k_nS-PlC(*tTKmOW`R zUdyl@a%B)qhmh&*4@7i^+Y18jFP0DABwM{#I{D&R#c5c8`-LfzEfPROAht^BPf4m1W4x@U{}g`7CV z^Ov@!nFS<8eipF3sq>oMU#dS&@~dJZL|?;!pB79OvAfN@Ybw5+`a0n26O`?TwW!R< zru4fHJw<$NC3k@uO&o!D0QSe$`!%S>XQX&F!1oUT! z%lmD1=R0rVJ~XQG3Cq&kwa@2%ahl+Hy70N(ImHhh zkArpDg;&mHin`ivR9(+K$?SwS?zgwPcZC?nk&NFE^4i&hqV$tY$24_{$%!4ibGv5p zP8rwfH0kbh2A)T`6P&Q0!iNDC(v7OPa9V&^>407M*mxV2M)x@LMG8olJ=Hn2L)DH_D#pH@Ms^hn3rDz^>#g~+=p80?ZEFsx6 zHxst8$L-Mi$&NfTwxIFVPXKZIP8TL6_Ug0jj+Dx8K9OE&l>y{ z8i0N9Hc~uQ+_>n_)z+{~K+u^ph5c$6!b>VS1N4))n zrXlVk+>>x=B`+;@K_hqkan+t^T_R=ZR%0?K$=CgYxywyX`o1>v6CWzD{Vi><9Gst= z=cP%%(rU88*0snx55G!?DVt$|60czvG-N_e&a@AIAcETCONeA7jY9or9{qmIg$NQKWI>xG0aOiSgy;>szAoUvE%HDILc%7Vd{g(X1e+ zSMbJQP>VO9%h{NsplPX)T3COWLIT-|VdIoc#W1?!ihZINg{>d++BPTw!rwcs*s~PA z#^uW;9s9eB1S8adzppPIT|>?coB;x;e@;*!mk00OHW#odRp3L zJ86bD)wPWcrr_(*yEKoJ@e@N00n#A48&n&oVq`{0U;+T>V=Tm^b;A+>?TZ;Cw$`gc zA$$N26WxJ19a7Ac1`(=(fUbz6mtH<3`L9nh3IQQp!pGDbcg4xA>xf4|!G7Q!bj`}Z ztg_rZ0R=fm zgM*6)$>}{jPwe_l6dN_%cN#O)`elwlsKP|<&dSg(?XrRZ`~smITZZCvLy3H^di-YmTfJKa zg|Z2cKsz_Gt1|AK$WzTWiac}QN!0ECG^gFyHdFzE@4lrBQqBT)jZa2_Vs%mQI$W>N zdb4_S)K(VZ9L#L!W6HmmIpFt%FWkAErWAt;{6tPcUD*ofGkZF-u}fQK4_9VSa+9UY z@(?bTi(j%v6$pO0Jh=zZKF+)U1+;s&0hN?F|LbPL!%4BUG zPVjKIohbG6YP{n1=&phJas&#EA5H)W4w4%o-N|DqEoD(yp~Ys2S3J?U?u&9#C{x~A zOjmpkk9V^MsQ=2jgtiCu5u!&bfca$F3UX(#kOT~ZpW4`AzqJM~neh;K(RhjKlG5~w z!gO6@rjBk4tCyTC#j#iTX1-Reqxi|0kaO-x_7)`^4?clV?v3LjaY1K}JRm}fEYc$a z^*yr*SDE3xwB`|Qoop?Z(&6iG|CtDXp>mU7IpgA0ZDTP&T{gp=r!!vyyO``}gJ1`` zt96sgd4m-FwPf7ftjSta3f|Y)?+U zOiu5unUfw0NIoM-@8!D;?{ge4AEv)xlT%@qx8bX*2KH7y)M-Bgi5)~UR6uwTo&bJA zA8%8hhphigNeK0w*l=S1cxrZ3Y?W3VPNolz&n+rD=RS4o!T_i_U0$3H?8!#f#Z2x} zVQVp>1XH+ytRs%~#VLyq?Tyibf$*&iwT?ZeCSzxd4UbzbghS%}_CQS)WUatl zyFg8^0!c#sQE)sfSHlY*#VhTbFi#+Ny@K*VB;dig#+^f)?+qX%`J;r~J_b49t-Nlu zy1hRcO?iLrpknx*PjOzb+KKmA;u4B6d=M?vVnVzUBy$7H*6X3yDa?UWiE+F9El)nj zD^#m-W|>Xh+I~OLwb0qIo=Sk;KuhLmyBqL9JWkQPQW|8^9gu)1dW5a7A8dL0etJ?)fwe=r~D>$QrlN2xb5t8ZwL1CLJn> z8ZwaFI+yZ}hy7+}RT?SZbINGgbd;5(M$aA5s`Uo5etD!Av+oy+{V4dQU@3vcd7IohP zj50V1zkE9En-kEyLL>YWU(88`og^vj&q`8K{DGTUGYTaPhBQQCY3mR(?p@%z=qeRa#}>eUwu^oQZ2>#wH|yc_$4!jA>XV4(20wE~^i>uvyZ{5R^}JRIu({~MhW zvWM);D5C7imSu*@mP9K160$Xo31MUyJ6XdAMIlQn%VaOx*mv2zs3c&bhw7>pFk9#&sF4@qR7O?Xd)mB!X)&U5Tl&FM1Qc98DG3-2AjY zGl4Ooy-X|>giP>G(4sPpX(O(X{qjAB~PqBO<;E$FL{>4Rez9~-@iGLui&HkGed2Jr_A zydhq`XB3r^-C@Zr04{b)F~ya2`iMzz*b*uI9qYvxu4#}Z;fw# zu$A&)7&MDw%j$t)PqJH2qdV1U(zR5_6JRpJA~5Sy=@T|{;9TuTlCodR6kmIqGi^2ugT>x<@Ue$Q*|p->37#f0yjch7B%e z`S*?i4#D4Fc(!qE?%?9APKvP&%_CdpT#8XA*#{26As(^L+FKCsk)!gbyKxTNB<7KT ze;g-6vv!Sy4ALUsa88{e$q2Rm);u$@#z52_Nrl$43;1_>A75=5;%<`n1F{vj52dfe z3HIJ`l9F%5Gsmd@_W{{|^W6d&R=^}f}_vmROstD*# zqVzLX&SmQv)tF4*uMrs$`F0@M2I9vYY2N8Pl;tVb%h4cAw3(^zM}N^Aq@bP=g{`hQ zd<{lq?|OnxwcYV`Dy8@z=#y8wxMuM;PTrv|l{dufr*e&WDzpq`0hoh8{-&3UtxvB9 z3k_cLVUQyDwjHwqY4!*a)5b-$#R`V;(*RNxeHg$P!gTim#jIl>?ZDT|o2F_Eo6C7- z&$zAXRt&UF`E`vk4hU=R6g$Rt68ki5M| z`Ehlrg5KP}g7VI{m>HR-8wRL1_-49i2C78rTZBDL(XwFfX zk)$@9!hoZ6vUmpKLmjF2UF0?rd>H_1S>Ubrwx6sSzbu> z#IHx~O|+tkm!gDUp;xO%M;61+M}Fl4eh@MG2%1lx7E(eppw3c)?op<;qchmpL`;qM z9kpcMpz70Y;LZXZHJbRg{-uN?VgiWj+X|u2o{D##4gSFZWG>NzcUA!w+jJ4M7m71& zgJ@9iQJe`GD9P&S{Z9%Nzi*b?!J0~5$H*Oo__Byt|LlLpb!TQ4xB^Bvjx~|m_9t)m zF|)nBe)SE--bZG6IfwTO>NE5D&l~xllQzH7De*|9?YVY@1Z^kySQPb)A|Aoq@i$16 zhMufTjitmBbIr|U7x4flBq_<>>?4b?l><9Jd@hdBH{pccThh=Vkab-2vJ7Ji%YQ#O ztB3?(D6IroErMB8@zM8;WMVgUv<~*xvW5MF{k^l*6RcQV=@g4DWw)&U>m0FQ=o4VN z)6o4l=w+e{0g;LoMd)}h!6C#ec#|mOh6(o;<@#m&+}H2kh3kkHpDaE6Vzl<~{HZ;z z-qhI-BgE|r!iXqfYtRsOnlcl8D{VbzI2xvreWvs2D^s6tE8{*HhG#v-%RF59ROo>P z<>z~fb|;<>?S)9-r5R1?h>x2T&UTmTm)P2kNU(Z5k4Sy}H%Rsi;aS_`FM^ncGG<$f zb*zqdz=v66lugZqqS~A8iliWG?v6uij=luOKqeC|cZd8ZpvC_^ptV;6Z&yG%I=I{g>E{@-#`{{QEL3i^+2 zy#Fhr8)gU{2Np$pB|9GsP?TcSFsGgUdu71(Zn- zvZB_|br}U?%v2M)z~-J4zg} zGJ=928Q?o~<=V8~e3_^XVH?!F2ur->Or@t2e_fY)Yu5ho4kIJF`Pvfx?0#^oWEe0( zt?fxBA06yCW$kt6v}_emtY9z-ZuZ&{QsJL#Om*(58;;t(Jf%+zhcq;vd_~!p0-!h& ze;L`-xsLi`dc@zQOr*>z{kh~Jmu_K(TJtrr0@FA>K@GX#P`Of?(~MettKbu1nMllU zF(o2j&fb;O#@SFHC2g(W~PP^WZLoUu*TtvSPT)$x&HD4 zZdkT>2F{qIAQ8hd$ajV}_1>?`3^O$MGGyEh@-V%EV8%i-E3=n?g{o`ER|u~(fF77# z=l*a?M&7qaYo-la<=KNzYP~W_zu-W-NnSbmnO(f`VzcOB7 zj1g!09R1=1O%cn#O4X<=?Z{VVz_d&i&L;+D57?(}jxPrNWpak6Gvx^wC}OlQ5df(< zPfICjUk<#bYkFfeplk!?z`ysiK5zS0QcUN$p@MVGX%la)d!2b`0R(eldS7E)SE_i} zC@!yo1m^X&PmSF~lvh+RF1Ptq!`08;nY$qKG->k`L$HK$AM;vVZ@vIcKBEiGfT5WX z4k-FFG*hqVY<-o*4W@E$itv214<&~ki^Wouk$vSr=1-q~k@tp@(TzTh;wF@Lr7H;2R-%y{ z(Kc!u|Eh+TX=pF-F(cItkkIZ1g@8AYk8j$ee4TmMaM^K+G%#_5aiQ0<+*=@BYQ$$x zPW5hj0qYvUo|z$2hdUB=3M0EXO7Gp`0TzShbB3{@bS5Fvzvz6%xwh%J zZCuB|fQh~f;vg~VrJ{P$AFtwr?>8A&)c{yy`?PaSZ#+b2mKyeZm7G_QbGPEXL-m}~ zEdHcWJNQD{myKN^hSSWFfW@2>!+|#i9}Vbh-=^-7cZeUpc@h@FNIiCO@$W%-!eTk>Y}W#`~S zd3Jb6aWg_dA3)Gv^C}@HS1S;0l3?&tKQ=h-jDZ|<2wjFk`HZghJU@i zEZDBLX9Vca0{a%90`Ut5Ip6!Z6VS&p*3cv0^l()j__52zA`>MwhAy-~lb;Fz3sXxwDH~V4)isnyq%;)d zdJfW!&r8^J`RqUf!y(&dfpqngk^m*9AINK+FPC7$*2L^&c`{ocPr8Rq)UK9Ht%~#?lEEwa)E+Qtpmq4%DhC$ zoV+Q4)gut+T@M!51{IPr93+cfrkaP|x+i?pNz((nq9jiWmj1p>!|DD>NbA;0PW;bFg z%iI&X7P&T$bIS3JU5EW($m&^apt;s;&b4O$q~#z?zsRw6(+*X#qS4N9Le-+UTp(RC zwM+X=YeGTB8e>=|@@BWV2qi9V^EaOLNmgUeI~B#jM3Z#nOB|=h1xZPrl)`W|HOwCd zJbH!&z0csBhpj5@8aay+2k?O?q$yD;#fRQo?9KgTbZMidEw`w(}hx8OKZyxB?ZEmJSS zC7wUHyIe2!F-iHgcd(qzp6^j|WTj|?PACa`Zq5YF^aiG*?rY%Ad)|9nvGk9mW!q=6 zpZ<4<<08xPV_=A4OA$HnnxqMk4;`y0u~-V*rj?y^Zt2}@4eM)z{^t`GZ;P$3#K?Dd zY&|+pNV$?Pw2z76plG*=Gf)l_>Qdvt+(1ouWn9Mv=FgE_GH_wny~g{4hl%4`0p(`5 z3txSknLvlYzYSyrDftwrCB-^g$DS}Za*Ep9e)@`9s4U(1C%3+eKnk+*;-5=NANwCM zsc1#AkPlB9)%s})zLXFHDzxGklm%XrX5Zx<^dzJCT%T(?LG>rZH(n+Gb#)R~u4shD z<_@KyrxBaddYL9ix!VG#AdJZh$t?>F@v3DumZl?+i^DFmH}D?mzL8~B;g{@#nJ%f^f6)VRwmJz z73Oq$^3nL5+4wZcD8r#Ti$j>j{bA@WxW8!AUs+eX-*DaWXbq>NZmAKYDL=8~!=>I% zF;k-2VRZCOo3AOm_&YTX=?@hzq}b>G-i=|FO>8G!|E@FA*|?ZUKvOTHQx%#`(1?;j zq7th&+lPwDPb|&8%88G)`xaN)Ey}W_L^0!F_w?Zf7zTdiRcACS)%jmqcn zmwDF@`PmwY0&K5_PB8}Vw*U3$s;tiGjROM3=iL$;Hpvdtt`Igu*s#|fjxG18urnD= zo!2=?xHC;23XD(`13j9@sO=hc=!zra@qix3*3}{uq@`)5q?zJ@3Gr6_1KSw1nf|_H zA|YEoK62hHt(M{t(j#JuEZIQ8`aTl(b{)F&7Fi~hOY`one zIB-j^is$2Q3&#tMcIceQJ?a$9lh_1Ar4PwV?zPQdC!22u`<|<G`Uq9)XXPVO6C?H7j}u==+rDm6 z#%YmSP5?wH43Q#73>C2?ui&F-wo6T4H-9Hz+`aW?Yth}~<@keBj-qmpmi$U1mCaL7 zV4%TDIVnJI3J_9HB>6JQsrKWvtH)0Csl1ER9`=sfw`<9D3quu$Zm}GSCQ|$Da9K>KV&dIG;D$ z_vj0~S6I3^h9z0oq;4Q*frSp;S?X|Qt9${PAw2X8?ZX5M~S%~ z>mZ{@!})PLPH2{2EXBWYe1i&gTkn8`P_iLiNQ4u6;gzz4pQh&1m;*G9m*kk;29B+v z8DYnZr!PeQ8lRl&z8v$iDbtnXTV~u~lV-0Pd1uRvlz)cidy?3yge)TPN9fpv=3#s? z@F(e)-KKJSj^;-~e8#S+l4JV!-d?$2pZeDX8ljX5@VaQe$Evg5dAyXs+_EPS$w}2W zTTi!q-110+Cezj`Mk3b?&s|}60Ew&|&s#8$|2GqXC%r&p4R(jBR)iF|Psxr~8wMX+ z_9$rG;*ifnr3pG{c0DN*F&@u|VGxx5h7}3>@!~L|*^PjX4MMZ2IU4n}ULn(74B43_ z4_)i$7`#nsDNbiR%YRj*Y2;l2S5t=V(SOP*|8tHBdI3$wdOMo7*Nd~0iAdqShtv|1 zJ)%~hAP4)ZGv}VncK1*tIo>T+h4oLg%WG88sf1s&f0|$h9p73r|3IHR;)+K(leUZq zl`o0TEC7_m+ioj&u3NCBp;2YfO={-sx?lj}nv9U^XD%u0Qj|C~l-!6X{lpN5)7S$M%%o_UuRYS{G9)Esk>ZG&{sCg24_Z<+Uiw3W!{NzKqu-1kcUtYsUhV1 z_A}45&u&Xil^(L_j%Z(H?nZG{})M$X+k$&jFylC{ky z6`q@O<7^%Cx%ERC-yYpm+q|-@BSKpjH(X#7Y>-X&yS?Wav?uNf?T-oO0X6g31Zz?yada|BR-xO%LE+nC<+ zRNLutls&pE`luHk92BgoIiiXDp*iAFz|~ThimT~{a8m(?4S*=eRtIJSD4Ox|wf1r% zgMDEW@*17XjR6KO{L_72xeC2@zjsow!sT@4wN{a7#Pe@z?i7BkOTlNO%kf=@$CN9V z>@TSMHS^}x?EwJ!2d3U+R=^%aAZw%(5N`-z>lDtqI@f*O>M7U3J0p^l+woJ{6RIgm zmzSUsixMGjU2~q?LHWLSE*1`*47#J>>S>ip`|v2S(SRY4jmP(`ItUU)jX8ce0(d(o z3!qQys4a<^chjYX9|*rI4piwDzH`a1yt-fg#}+Axy|(g))aQtFxk4(}v5`hMo*P+4 zI830)1TUG3AXVq;-vi`avC&08(+X#V{uZW_U;DWiIUyP71LOnOc$lP z+7eqSd=*>8Y!bCUiKcS4##z5K+n`#nI5~j;DbtIsLv-R3K;36MBlshgP)|`E2fz?y1)I z-o9)dWyfV@uGT|04@vHf`uD#hPh2oVKnS=7;3?1|VPp83{5TyUqCH30(O{RQql04p zvusm#%dRq^G@|@xhp0M}eg@#`KXTCnGt=?lT{+qXK#(PqFtXz7^on!PdioA-G3R~&`Sg65oghxB}o52 z4pwT|m(R_k^Bc00cIeW0gCZN$4f1x5D%TEIRf= zCR&JJq9_6VgXkP=vJMM@tt{8f$T9X{dd7ct@7Z*HdWJ%W#HyswK412GSdK>b$lN$s zfOfsn%>E`B9DQk`&e3u#%igG^`I{2ot2ZV_?v{_WP1@PHBofFiP=Hlby5^(-h$m>% zT$-tBpJx|;3D{wxQ8y0!?9g7F`O50=uJub968^f3blmGGTnPkm)zVy!ZE4U^2oF#$ zW8$7-P-egcnRajh)J{(lvV*c4uCp;9wB0TbSQ$#iNHX|JUI%IutoMSaEw#j*>`6TA zoZE?|%T8B#w2Y_2215;}kK_WBWJ|j1&)<$`J1?=skj^=}$U#CRfZ1D>!-nx^Q4f!A zuLG($;2w=pGw*V5vjgLO1xxjSh87JK2%MCupRVqncH?cs0*7^&2+EpL8j1D>pQUkD zQ{TZTr5*GuU$EaolC#F2*h$pJ5Ax1C+KHw2YOu10UzJHhHMK#Fz5Y$|XmWdCLT@h^ zCux0B-E2xQgSFWu3EO)&+Rc0I!m=#ei}d9`V?HKVx@G&N_h{B!SWlW+?+y#41_4@l zNApog-ew!wB|XFCM}u6UV@w^X>J&K>~Hdu1c}Ou`id( zPK->cF&+&K)?fOC;tWg_ioT|9NiEgy9-k6ixvTQ5l#~zjQqXN9=fXO)a{lG8YqeoR z9(oa4|2msHcdAjj*l-@OI%7)y>eg_|aC562Q{Ku(#kbSNsV3y0PT}num`fQowX@OH zyT0c3C$%|!H#Xx7H~hNw&yXfxB!GScso#^u3!}}~&}GQ`H_*IKru&5PW+_BD!*B35 zKHj^xyu(cfMSWB1$XuJYlm?Pg!gFZ6HS`?qI^bZsuC>^7dCI5GG$gg!1ir#IZgsB5 z@*-n=rcve^Y(E{V27Ro6FsK4pI}>dU*#~55%0xF^F30iuihf|#`6L#dw>wwU&vG)& zM_vPV1md5Ta(h6AQ4a)SKy-->)BTOkvGIvqIt1F-3GiyTxKOg2MLJe(c(=l%f94A* zC+;7Va(tATe_nB^oI1KdSp_N>HYp0Y4e)#7L54P40QnCOPVDrq+$V+cfuCN7-GL8{ z_pH5{-IByru+MYrZv zd+@OR!7Mo-dy_)sN>vX7w*=EMPf^;^WUJqULl?7{Zk|K zw~q@rxc+gy=7{~L)!3gNK&t(U1Q~mhQI>zA+3%$E|DxEkVE&PR7Ffmi6y~nQi&S3{ zl+BSU(wisHkaDnL3@{7cUn<515^+CITiP6+E7ub-VZeeR)R2Q6if2lzP-+5}D~uKP zVmr=7tSbk&7~u5{38VH+J2vc5LrrqR<;g2XF$=Zt27Y`ahY$`sHqfa#*hE!KQcix? zj>x)`NczCKM@Yy=iHX%*zv68EFXb*>pKaZ|HI_0OL?`pVcySz`p+4t+1y`d(x9Gd! z*TKL!_={3sbS$k;F<%PI{b;wACs|wF1i2V_#CCl@7WU(QA5!DbxK45-|5cxdQp=e? zTMww-L?uCLQAZ#So6@Wx6IEyJVf7_zckxe$(n7}fdXJja@_f_w8f)*$tnp zwA$hGSL7I%<4F|GpNPVbQb(tpzbu}k7{$_D2*4j1p492PEPT{G1c@sS%SLB=0u{u%J11Krt#5y^&#hWg;xU1R@O;1Idcg8z+o!_G z>dZ$Y*RJh_xq&q@iTIyE06t{7hJP2hr^ljEqs03+#n!gdAcN%4-)ul zTUPl)a}i14WgV9=a_0a_iLC%k-$N>&NdNc#_m7*BUH&6|HYv}L%M9It_!>OG@!8U+ z1F0)|AFlvEP4x%baH>}tyr^Gk;&(n&gZ>7CO^Ph{ne81AVr^) zl*fGFddnG?eUL9(!25Ykm-$5N4*Cs-IsDXn`a(E6yJP;yoxE_}O7pppZj+@@DCsQq z3XzG5P-_;@crk*jxHW1d;jC`FX|h)msd86Zo9{1#;QbuT@jYsk5ydkR?fy6DJThzr zq1xiyV?N7^~lc zk}Q~Ucz~kgHHc?N=_L}lI#gA=M1e-{K-w$3X1axm-~Cu_wyIbf7Xx?NvmZLX&iRd{v3K7b z!%#L8&EAyz52+6Wh(A={-85*b&zi<@Sbb+hUn1`L6uB?>kX#ES>IIjUgpUvh)?{cW z2bjp!l`4o7A;JM+ongd1IAp@U`J3_}C=*@BpT+bq#UPS#&DyKO`^4d3!Cz z%;_AN8VU6;xx_DzG8cg7z}>+W)SgB%rb9UN5{rjz?@w9KJ z0X7rE&pz=>OS&S2!zk>dP(?1BcIQ!f*m;UigTiRScioz`FTVs3mon$l*TxKA+iGyE z9wx9VyIyAamCp|JaHZ?0aRh8!zDTRUFDS<A%CeD@tFALYgvY;Cuk*| zHrD*}H^`J)N<>FP1tCIIu6AcZ7wRklEr!n(nAlzf>QSqSii22zj-%1obX6gF4-3{Z z_?hJoaS0PkBLJEJr8lV~M7j2W%1hBUfqB(=HsyR5gW9BA?vLEr4DuOUFszE$I~}7x zcv;>AbS$@I-m}9_Ly-Y~NS%SVLUvYc1eo=D(5d4y!Yg^$RS|mUR3kOEwH|jPd z0gC9W2LPo4d~();8T~&#rGrs?m@uFk;ejotmUW8@0YDAdjpykd2fxOzRKlG!_PF0? zPFC~zetU*Qy!Ias1mLD`k@-tI9nQiVS#)ViYUDoL-@2CdVpGe+*0ziu_Bo4`;nOS2 z?Brn4=~UWCT)gPdM0B_bINgPho84(mdlbNBuGswo6X} za3~xof#UVn^Hmupb*;JWR#iA}4Qpj1h2cUexl_u0gN^FX^+ewbG33cL13W>31l5r| zJ=1MXzg|zS%od2Rv$0#;v9Q-PnCiLDzT0mu!B?#vS>XQjdk zc-Hm6M)gb&p+~jV5ei#G_U_kG_hJkfU&??Ur^CW)sT|}RKqts@O1>403=?T_z3$SS zsoG?5$bFY}C>^Zy)ArPbN%EgRHBMCqAU_52aho~p{SR#Heos9aKt^S7*F=AKOW~-h zE8Og2WM*~4=`C5se zxV-$+{U(vnYjS73?aKKs4j!Z6}_L&(~^A@ zJ6T@P;q4keP?@La*r;H;`fD9g7@_nyugK)4CFUzK{@E7eNwtd>aIN z`APu_)It8SF6zv91xCIWXfKI2G)&$Wc|BoL%c7A*yv5n-bk^Epfa_GaBvT8gIV;uG zna6;1<#~w4%u>C&?=J;}P_wBA{3qXA%xg*cpz|CJIwZ9g1JI*>f-}1*(YV=NG1cRU z#^3W@*Pw!=@$%_#b4N|)-Wm54VnqGwTaP!y)U(w4Jx8~4n3MKlx*>P|iVd+pMs7dy zaEk7gtzHh)?5m7T04?iW02U>mYrl64#YR}r?wsa%+bkSfZZ-C;M`eTOqg%?VG{y_U zLoSVZewKEH5XwJjGp-f||EntSk&4EdbJAI45$j&BuS z&q4VlmLj2+aGn99AjR?iGc0u_QcK@fL7Y?j3ji|OJOb(wqI8=2d9WKwu0{nAFL+#H z>a?SrnK?ocZSpQkHw3T}Zf-m5KPs&Y8HsCFAmpE2dxrGv$T7%EEW*W|o3*v?g~J8Y z5?lL<9vaMV4*d{KH(AbE0KopeooL1|E~JKYMz<3W8S=7lLxqWHcjRtE!`IFvb*W?q z4t#gi$cwdRmnG;Kl>P)$GHHbC2Qs1E{zdyoyY2wt)~nyF9zGIb{`j*I;M-)Avm!@^12`Z8!nj>HNL_+Am-63si6NPInGHgNQkg^fa5Qt6&*7Wt%HA z{_w31i61hIwYKbg*t613XmuF@&~)f1EDH)&X;01Vfv}C$4LglqPS`#)Z}|G|^vS4H zn$VZhSJ&E?{hd+`K)z6N?0TNq1nQQ0gAPC=BK*JO6_@!GtGDNA5Wuv-2-43Tbh$|0j5Av1cKrFDZxOx z{55bh_A|h&VJ@Pnk)1L=P%V4Rq^FM%jCQ@=#S0GYq7CZtugcGKN-D0+j)DB5A9Err zcM~o%L`M;VI_;hLoT)$OIvO3EE>5|?LR$2n5(KQdHBY76-q?QDuJ@>G8N{XUC=f@x zPS|1bI#?6~?o(Kix0Op4o9Shq#YKrB+eib!2NpKz&zMZ4>j55-%>Vi&>3AOnA@kza z0+k3YS1X#8DgDJ=PYG8_b#6X|do^7>rQh-T;FTud%so9p^)Nni&pXT$aoLIe5N$Zj z$|n8bNDevvYehoD)%L?fE(rs|P2SpPR~@XATiJnN_wyyOLg4IMx5SN}N8^E8?5v>*Q8SA8|gVS_T_b&;}Atg{|K_u|HME;it`3 zzNRr?vLpE>RifG+oo&as6oxB5CSzhRQKXzyvW;p&!`R$B#S5naVf}iKMHkzFKancE zzIt>1->==u<0E&qN+tGKt{osS?N6VQChZolA+v_&hUBh0D;wHuUl#ER?tuXf($(gB z1TZTSZ=9h=$xHUcc&EP_fuRv%L+pd2uUD)bQhJ5>j5@$v%`Ct~=1!S_KT1cRs1)1ikuOrXzQ{G1d#gHaVX(GZ|1{j&L{l-ijiK)G zI<)#1^=?QqKzGg{VYwF!yo1l%h1C$RVB}>&5`!rWTp zk5>b*Srkw=1Be@r;@`yrkIpxIo#a_mYswM)Zh2jX?Wf9>wUT(T7{on2aZe&N63V8d zfa@IDy3L#gX@8U6P7yEqjJlbia6{J;=PcmJ;R7@MdjB!v=^@K;;oD4E6}GF)8BWPW zfQ*1tA~I#3#E~kPdb_j&w}^O)`jlT1OXmdTKes8LA;d}6{cC?+BS0Cc&^EgpRCL?X zrR=!|M1V^P~HsKIKM=i|I@e~V|Ok@%+YHHkJg2M85phN)1mP$r0wrHC+gvSJJI z`LT2jDN(M>mTznM>4FIyF3N@!e|EOBf#UM|WaUI~j4n)7GMPYKag)w~g=$1rxYdsP zR6a_v&%I@Ab=oB@`}xzg@as2Jn02g!4M2Py`s5};1ttMQ^M?~06_vQ3e5q?$EU15$ zSQh^l)Gz-&Nb^yFva276xu9}80n`KBd2u?r1O@TA&Q`wht>^C{ZcchyvZU_)(+s+#k8HRfgT zWgu~a@ZY(EXH4xk)*f~AWH|H`8nH$VioERLDFV$D+t|#m2dW==ykiCWiiWQ`dDvx* z1X({9eF1@Ffh@BfIYDslREpn$$#j8P)*+Wdlzwx3vege9d%PU2Qp9&AMQ`b_^7GuS zvGVW2(39)vA(=QI*TwyG`)=kmin>HFXIrt}WH&Ps!P|161HdHrQB(nRJHhJ&K9S+*4VP+p!UMC9T|sZFR-7XA>7JdT%;d+vP0Dz|n zH_Lx1n=C^!1peaZ(k_O{yq>FRUwZrb3)Fum5oLym64#n~+RB8w&3_bIgy}d8%t7@_ zq)lJSQ0Y_&mkf9icC*Y)H&sHMfyDI#-{B-cZl+6~yc<4@rTLc1&i^oHD!NuOU-hoz z$~4h1zC22jYAxj{_bSxinR-}?x=4gtu*}XZX&}={yVOgomI869p9W@lc4bqAoSWZ%21x)=KZ0NK3tdA)i zv2mF7qt#f20PEe&-@0z*>UOBrDqF{0fzqJ?(1;ZKl%x5<4Ar!F-oQWL12UW%=-vF+ zOi#AVfvMPBkbHMk;~}HP@^hwdfZT9EGjG`U2~$~Sq?)NrQrbY(VsWZ@X8e)xZz#_5 z=C}gdeMiVE@ZghJOl%VWS-n4!{7vi^jNmZ>V0}Gtd?m88Q(Od??j_e%+!ldzPUE-C z1MRN{OPww+_6*!ie(z`Viphj;Pa$rrg8Z2X(@u6arf^2Q|D_^tg_|z=psMw_o4#(_ zS%rJC{IP9$>G8~Wtbpxr=q&aCF;@2tO3a@^(6kYsvNWv|H`^Hfd#pWn{X!V#)DDAs z6=XV(;@Y9Y>y4ZDlCP@5rXRI2ruN6=so!g36uPPOp5go{9I~h%KH~R{vUFYoa#a)}}f{J(42ma6*mj1dtt)Y>Wo;#H)bJ>BlU!69-RYVZ-0M87WMKh{BmNREUbFlUs6B2oZ?MXR`Iz) zb2nw$dYhE=AnNu6Fa~>EI=;2#mM}p-*DG#W@ZeAF?rTTzK(67CWlDYza7hDT^L=w5ymXi;bbok5RFa+k?s01RjLm4=w;)u^ zLrH&*UX0z|iN?VQ)jWg<7;L4$tQo`n}>_T zD?hnOnJs|*6<78(rNUHC?PI!~6Z4p<*9iI4+o*sS@QOTrcJ5oPj}Lh@E@~EmkxP$v zL89q$?UhhCKKjTp{wqZ>jR@vi7@(B%+oj|lgd|3%)&F@F3y29iMB*WO5=(n2Zlqfl z8{m58wk^3X36tVp#3#oa>EMOjAB3QrQ9VnPE~8dPCE9uZX`FM^3x&xqMZ?VTPjC4= z;MZHf3@`oo{d%t*=(zIVQK$TqzghkN%_6#ISLf#6Am}uf%KA5m7WNIp`GhVtBRyS| z6=EFw27M{V6?endEk%9aA8KiH!UU;3g}-Mjeiwj-K(p{vKy8i|mMR0lQ}PE9G$u41 zswcc}Wt4#d)Qwjf=nT-mLExf(0N1s;I|sN}=l=~#9c60a+q3S;PjphDItP^4ulKvW zP&pDNr~HH%jmj(69QOv#~7D`F5Y2f``^VskbL<`3~sI$=_>V+i8ow6fvLTi z{qf2<&nO!ar+}18Cuii9BpTxaVmFx|ys3OOP)F}YuSZVkjBvW@6(&rW&=IJQn0oVsxshS1teL z-+=&Op3{o-C}=NgG-RdpjPPN-Qp%H_+r#K=drO|2bEq_ZTXbLJ9D>F?5oQ&we#T;_ z_SI|u@;m=0zsVPkb$oQ@(JA%oiCh<2veuLMb-X|?^3P9=&M$yVW4JPvvaXBeRI@ssHsqKmrtoDifJ z?-idzFeFJF2yl-{HGXAI& z$!49^c2-fFr~NERzx=jE%Dh;SBR4op_-l-E$oY_u$>~O%X1idsjS2tGRf|r%QTPw{np!op%k-keKHn-KnH7mB6*29> zGw&EY`F3p&^_5;sG1c~2Qj#3IpbXMa-(b2#A}@!XafpJ+cUwJ3A{bll{IV60+s~}Z zRWhZ@G|47*>SViEr|4ZW3`Jjkq==%Ap7kmduYbi~)uhkgWotomp2$0AM3p<(Bym55 zNZD2%Twa}#qI*Z3iG8G8@b z^xYVA-o~A)_J$TbS`hUirRho|K3P5D!yF=FvSzV8xj86)P1AzIGb1zX7k;|IYkM|# zp{iS_8fATxd}hk{i}VXBK;Ij$7NSXA+!?guA{-67fLB2nmcf|E-gTTA?ol2PF zX?5E(Q`DY2Yy=8#tHu|R%2?ri+2F-{XEmO(zB~G4IQgm%^1_n8ZcWG7lj5UIJ%+`a|8N zkilxz=bwFTti(1WY$2%OpSmd}-P!z*O<`6k4!Zv7$xhcaQ=YQ$6*HOtA+K>+LpbN* z$NQxQuf<>`d^$5dA#wR=epn99>UE{P8E%iLXWAo{^=f?hrrQVCE;V4k)`CBSi+9&d z>xd~y34|Um9Y5YHDuSAPWdWP7jW|@c@D?7N98O)*M|S#;*W3Z!j@~eo2#s;i0~0$V z!Fq;zl4hdXuBnF`_Q4wa3sUaAFA}u+(Chy49eur<8B5MT6`_2^U)lAUX+7p^<;?@m zooRn_XM{rMrDvUW-jLDMrVvo>zU|QS#Aw_==kt+l?40na5rC)IT&$-)fmH z^kd8dmZ%5M%s1)Cl;DL!hvd*ij1M46%rLEc zU9byDBVYBP;NDMl7DK5O@Sj|}myjbUy{Dt7-{vE^^aJI765+`o>PPSLN?8@3G?h&P zQA4wH6!)|*Dqq^Hcd)j^uhnG3l)0QqI3T59{m=M6!l<5L+g{n@5Suc>3?oR{}P_bO#8}+%dGfrAWA=dw0hZt5qwrj3j4?S z?Cm7L$ma+7tzORCi1+*U%GO+2X+W8!|3~~e~O&n zQtyv)Sjh4F?`zfDk^EurZpoM32sZ6edKIhbuf+3Tg|y#2&Kt60Q0{~#D7ROKOj%3% zrajthERQ1obLkzFHzw7{**IASgDHGP$y0zNM_rb#YC~U1_LXrHZay6`SS@@}9mII6 zcmt);sn=!}Mf!ijU4M=s7rD=q@9<~fn#}ja5I)@AdU|DEIvGWv!w?%WIIz(w$3*goyj?Z+YX~D*=uR~9#MN3E%TdvBKX!I9)2!v z_kT7Uz_pBWHGNjb^(tjJ+)=;NyD&W4I5%jY@3(JM0EbO^~!v zXjf(9s5(WbC~rlT{mQESWG}sd)Zy$m%}?{I^?9dMvPwJ;-~@{8Dvm3`KA`;%J>^LQ z>5z3@9c5(X{3T{@8O=|2E7wjzDP3PghsPPknSDtkN)qpAQSDgT1vDE7)e&y#vj=Bw z%ZJ@-dz{*IDFJOp-K{vyydgs-*Rl=_9>I|8MBWL*ej@XpXbSdv9@})`{)fkh_k@&l zpXuCuGD5${&j0P_owjQ#1C89vP&|sKW14BGPO-v+tqz|zrnp=AnU6IawifWMP*v*u zcswsq|82EqQ1}aSJwg*}NBfXBgmzqyHJy?aZrs9yy^?QzWvXjq7V8S}!$n6{ ziwdtHtjuCQaTllMHq_>#xpSo296`e{fm0(ETo!MtGwDs~4+Ie?= zw{*U0;(KLM4^*4p$RJBOFxG)SwH&sY!Y4Chg*8mUt?;TOl7n#YF!dB`2^b?b?WxR5 z3)dKX6&jFr0nqH#e74p}cMMsv$;BJU0(1)G!MG5o(q6gF zF}}0Z^*H5fVnUNj3>c8wF!lEE_2zr&XY#J=ySRNy>bbG5OT-2hpR;2O@8xP4UD^Ri z!{)c>h2iomhex6>7;kf8UC*{!X?6FlB;+UjGIutHbRGV}7}|4JO7`JAsU~hc^Z{fBE(QWLE9pJtN`Z7%Rh-mJH`Ppy=EAsV63g2&f25 zWB4X#18T&VB9TXbew3wCs;LV;Vk7{j-GPaT9GGep66)rQnnRF0iLj5}40#HKxer1? z-C`)X622Eq70Y|dV<>Vcm`#T|bXnPJB_R+LLq#mw9-y`I!F`f);GNtmJ?MLL2-$a( zRmXaAD&@oo`(&Nokl&a2<2hn=4se*)#!U-fF`e*9 zb;qEYj@sbbt!=JGSjII@m`bEZPwwg=ZwQkb0lDOFmDd(_QOD5vwJTKxN1#$5);7aAm@sRDl89ng%RS_Aem zP6b#0!QurxjArf>vyW$})nrMyMbd1?o6KYDI=qdG(l5{8dIB0ZPf`iO4U^B&7no6s ziZ^H0dz6efG#JDgXmEWKv?H;8=zP=}%D};37HN_!L_cNeQzK+dW3tG9Qr;>;kvod4 zBN3kW>TL79Y4Sb}HKA;GZ~O$;*xvll%I2pZK@aH;=Y>T;?w$*L{O$}tIRVKIT#+9{ zI?M=gkM*JPF=jmAT_++yhB6dSW1W1R2nNPDegtg)9dih(LTDa(d;(ZbIF*~oz-@vi zdr+|VSx=Z>YA?K>2-3#y^S_-o;1N2c=ogfsxw%j8qw))F=6Vc9?@U}-lCe+p7m1Zz zr}cP+O(|z#)nu3sxE_{9kMTC{9$BQ$SjK%*=Hwirsv%okxs$olt`-)}QM0mc&+9!y ztlL^!EvzG%)#a~Kcdl3hM=5G17>!DMW5I#?C>M=Q2Rhc8{$fAGwS$5+xF zL~ort^X&!~hr35q+Q;;)BU7W<=i0~A&;gPp28b-Wpq=?+WQKDo$QM$%dSqi>>^P(K zEViC2O9Ce9gX>D~5sNWc18+}=Gz3j;44qj!G<5da{C$)X4=6~67Zph#6@%{>y4lVr z#~!rZnL`+Mg9YqhA^a5k<*9uPClG%aG|%X6x(jumuLbh`6uc$idK*;qYz{G5P7J=} zClPFPcNY4p1W~Q*CyUj9@Yj*2kXWXGq1*S4a_%pVL$Ip; z=u3#c*D^6lqX;{`lk5hLxBC7eAw9PO`dWLL`F&LQaobe4;$Be1X&Pm^r;l@!NLxB^ zG?)YpO%fkOvgQ0NRI#b#N&Crunh{`|I|NVB1*GcF)b$9^Q zkyN#TUuaKjk4Lguf5whWe1$T}biSmR4C|FpG2s$C1odCd3`2b1t*xMSp$#IXro-M! zmJLVI<0#NFNU*eKvN;1-^Y;dhrGl{xJhbvIa2Bmau&-A(N7m>k6lJ2ba&|)t)c?e? z`o}#Tgs1L%pf3hAOAyO|ojO*8)x?w$dC}cT^CNKxz;Z+zMmSDht!F zt^^u=qCH744JLM)i?EHS57u{L-{Tm5bR5%N6BxaAg2f~F2=Ic2+OVg=g>`_jN*9YS zpnxk5o?+q=W1wP@Q*#I#B9-%e7>03|il%qt>3oUV((qw^@7aJkL}XZg4Y3;)1x%wT zkd$&aRz`+e;T)puBr(bJ!x#RNhoD{tM!;H7`6qz1Jpd#gk_|9pq{#Yq`qQhBX%*P> zZ;HtefKEK@9EG#M>Z^tbNk_=A+I=E@95;thNy_<6C;&BoYc_ed8tN}d%E>?qf&0h_ z4lor<-)r`}6eE&Ua*04vW>;%4Jb^4aVt^3MAug?f7RwHLT%Za)l!d3SoPU!ytm4ed zba(`!@ZWwe(P<7LYzLlK7S_2MZ}htcAQq|sNz{H;P95_0>AWd%b6 ziQPImbRYxKJeeu8_o>_^3$!F&`vdXtA-3Oq(u1)g24qQ`q?5m^xfi-?7SPTJ4e*gn z`VYZT5L9U1|!!_nBo40Znb8v zC#nC`j-&<35y=+wQ_>H_aSZ%~;wFWbqBP)6)IR*4drHDGwpP`VhVoYUnTn9 z-W1W^pI5lHNo`!2w7;M(F^|?*7u?#Lb3$(meSM8S!-|qcku3_8)O;Kmb=MzNk!?J6 zi?aP=Cm>vfSlyb^AbI9SJz!QDi{8OX(Lz)ob?r_nt+GfLpqJ|FMy%t&c7yhYs(9*7Jm(~u7qC=d&T4bQv_QXk9ceazDl?Sa z8`3A!L722-V(9fQ(2VeI@R<@clP$r~l4&r9(1BBp7N}vu9%7!F5t>#rFHGqNa)ETP zsT4uu60LKH>`k!0-q6qk3Yp!?D*1n_kAGPFwIn>H z56!f*?m2I?0oV&;ZLndcbBH4{9~bDVXFGX@Xa5|6-V0Gh4hhNprfv72aV22$E&{Jz zoeJE=0*^)gp|jh%fJ#{B0FwpmOIn7M5L6uRMg=D5m?Xm9e0o#AcwHUPQm(yABRQbrvxk3}^+*!%&%Wqx@kn4*s=?t{5`MItzxYin&!ut~(S6aa=#ty=#D%V@ zCyRUD_Ib}H0^=DrX!hN`CHxD;oG+>UxtW~}K%-)pb0F$O(N`M2QG->QVwv+|vU3;# z&DAZAng4~a`~LI)H~jtoP(YBw0J`$fArCvXE(MIqxk*^gOhR{W>mbs5fX9g1=0eV6 z!66+lIuC^01%@poM;6Zr7=*mU1}2H!VNtL+{vcz84#ee1-NVXLE9~Z4m?SY}k)@CF zd9O8iejT3|>&I=k0C**Kn^8uY&Xqu)__KzlT)~fyYtzR=fy|$vl}RUmmSF&pR)eeC zTYV6Ur07c-r((XGv$cB`kHK!qzos2U zkk_3<9JU8tAU@DB!q>famRfxiTKej;-P4)^FyBw)pXA)z>zE_L8*Te$>a7RsjVtm= z-oK;&QX}Zg2{51C3{5pTY8WRq6qnE$u^lr^7uxOTU_p``OoAvKhW0je_KH|}p?V^9 zi=n5LCr!(sh=vyhWXPM(Sf4V0#ykY8&LPtINw36Ou#p65+8e0zdGlZiBzA(DKkvIM zwP`^!LaD^N$fp&Kq0FP`QB;q91CcU>oW0M8XZ-{Jx1|F7e;#VV_8dj~ZmEY?D4jQi ziW0~pa^nfwz^)Y~5J$nP^&PO3A-UNAMm%R8L>ZC-b%LtX=+2|yS?)-kJ2M8nqcxlh zrm=6C$MnEFV1aG;dIZ}oK}fpm^IzcqzX<%Z-b{2>Jdrubl*T3osr%t?yyDYn7SfjrPK6ncn3|Y^GXxdeO)$OeLpX~NGszx zmyeNsfT;C0`tjQb1V}xif}H+UtWl>8SbcN0nnwsXTpY^|_Ut7RS`J+g)v67WeQI@= zYvPsDORv=r-tLT1$2hInKAh+a2BNa6Q34vGUlv%|zp_+N3!Q3}g|~3mt&<(ysT~|4 zzx9#jwGc7KFS@6igKr~7`hg$+6{4e__~99QvHqxWj57ZZH&^|RRI4)~RfwFCz;E~w zu6^xi8F({eJQV}<>KGZ$*$kf{)lE(yhY$I+tuI|49I{*~un$#Cx>K;0wl>GR(b0YQ z*(Ff1@~>=E10j(bOxzB<_zlccM?!uLmf>8Rm@*r;f+zHb{_#6c&lLv=wG$BIHW8Ce z?j`)g57X-y`&gKXdK$Sl<4)JM_}e|}6Nyp4+_T(dd5KceNvYtFw6?RshgSB1y8!Yp z&)Al2y>qb0=s`jf={iJ6y2-O3+o+lMeS#N^ih;fk9GfquF z!Txex$N-(^qI|uv80>;%t_30E2pK*C5C{n@HwyDYm!|aeOV}vt&sDM4H2ndkmS-T>&PNZR#dnL;j}RLRLJQ`-ZSawW=w%-3TOR^2YaGk~ znUAqiGaLF*lWp3}bFg+qstO~%5UeH6=ME4GKM@a`N*;>BOahxY1bsJ-yujcNJAgC^ zNQ(RpIR7?zfEb#W^?3v8*Tp=FAHnw+kUfaq7=V-PiBxIBZ9gGe#sY+PTOsUcEVFM5 z5DjJ@(?uWZw`<70q5=%`Y$bha-nI56I)DG(`S0HG;@`3&2&B#n5DcxO#i~#ud|7>5 z586$IOLz>f(FJ#Oe0}*)L%`FqJD@H)`gZZBc*J3DoE=8#GQgwv(X>oBN6f7i45|t$yV+fa`Y4LgVbWg1AWBo?W*F`M)Ehj%c zzF(0N(XsU*Wn8pS)^7auE*nhhsIRQE(MKuI6eZMYigL0gmz_bgc^sQ$bhz?`FqgQ* z;cC3O1Yc;GNN;n>-9y4SRU0D?N?UN}0D;<{IC8UGUw!ZWR%JSRXYaB-#+(mc1s(5h zY8LfvligXu`+kp5B7(cm9RW@Z&jWqE4j|0YJR@Hthjt`whGNR}8bMP}?qAmb_@`vJ zmQwD*C;IZR^I@;70xf(|8gv-H#1Q4TDFZ;nKvt)^9Q({MfWdD>7)RTTsWJQ5 zVeXlTSqhuO;ldC8gBndtH?tuo>! z8QWOa+3W2#{nyX))3vAcOhE2Mqd+6#sE-YK)pkNimd!zZhAL%ciCcfFe!5_LkEwA( zP8c&7Bzcx~Azz~CD^e?SCqSa+N><;DbTV!LDwG$8ZsB{0Wz*T5uoAU$dTZi3?93V< zfQA?V7!#O(^?wO(t_+t5E$p%JzF}kQJkiefTLS1r6>E_0D{$AT?Ia=S_LV zyg45~X0mMt;RDN(fN_zr`YjIw^h*BkjnyA!@4q?D|GgY1iywMixE?+14d2070Ppl* zDh^!EhNiClcnCFD|NpD7k;NY|0 z%@^-o2NxFA;q|66mw=|o4c0W(Z6I^n7Cnc6ylJOF$?qkRDo_WFjscS5(CbJxHJDu; z>X&;xhah*rTu`pK6m!)iYL+!z_R6M58dlP$#^|X%1$|GsKbMVWCGjcTP#~qb z%8yY)`BC127NPH)!tCQOZTn=<9QUc~wgS>(9KX1xmE|h6B}=KpHdbU-w#r?%6f1z7 z5vZF(tknbXTlfozsxH0zH3UCNS! z^p(M7V|iCX_eRGr+j8#Vp)V{bBZ+KgIMJtH$^x5 z>kQ6I``)tN4Z4S^9Ag?8lyIasS&n8Zm->$}X1FaJNDtAOWsf?yKWiXv%o4G<9`odj z)n(T&0(MD5-YD@R-a(}yuc{aNxcix|>}BohN%=2M*=*lXwd@W&?Mf#>6c5UNkRqn( z!q^O}OC}O-9G%CMl!JR=jx5id9sNxnCQ%CISj2kNo7_ON9O2a?{_H8xGqud|m+&zR*s zq*MEd{EWnoB^YlH*F}f#%hXg&Z}{9nUav0Jv^N2pBzyA-5^K@BJ4>RVJY5`{JNg$}!f=UUz&Hk9WsYQR3qbml>}9 zw%bu{!u0?iE)c$`u|5U_(r@T8jJbY?GRq@rSiE@AJ8+968)JB?+-Bv6RSv6>C!b+7 zOqXl+~mSv@;2Y~&xo6?z$f`=L|)=pF#F)_$VBK0G$Ixtjw` z-A~Y(PlFp!!PT}>51X+xE*N^+YqVu7|6$%5g@RcV+=V;C)Wmq}wIx4#k0uQiqXJF=uY zob=o*VKD`dVd7((J$-k+bY_wlZ3Gx$`g;u`DJPSS6IjrDXR!1QgXn1*AHA;*$3LB15~;0lHBI^&oys>2|J*kO*jP~|;U5$T}CcPtD`)7z@D_7gB$D6 zU7xjFJMfdw6)>Qzxo}y>0rqhYhd&R2bQnmT{=!Z!-elVd*|C(k4wqi=3j)iP*XN;z zHZ^5mU7Y0It0=y^8j!%U)j)q;X1pTCiQOZ9(K_@YvEXkdN&b(lGY(qivPi)4g{w#)nRD=L*S|ns@jIBW+ zK(PjEA?7h8k^tsfyvTFXv@_VP+O+2uC=kJxUK*e#)%>2eKE)ae#@^Quy{_rcnw~O< zzphDdsRfDe*B0rv=gl2LtHqv7jeXflnD!Ij?0z7gdT-Ur^H+pd?_bt~TE6S@ zTbfk6PZuQ9SN$#NYG@W|)9GCkWo?POk}8TaE?aot*=SDYhD8T>vh-TiUs-snwd#k* z8(Z&Il9NohrWVgLSoltHi&Q`kXa6ZzCF|9L!rN>2pPsVXI(;&vnDwoeWYqKBBax0h zUg%b7*Xj?4W`psihWHI;?s8*&-ijt&F85bUxmtC~hMu~wseZ^y&ZGX)ySJS0bT+Rh zl|AY?yDMlr@uPlJ7WDA8u4arD15%VFvY6G!dCL$1IgB4x_+MJ4zm#x&nsT%+XLzf z3mP=rqt>sk9xDv(#B;hHMsv8HeuefPP8n9PlG1n^asJguai{%f4J&V0wH-g}ly>KG zfT%6^g=ZeK8(F>%0a)jGw*xsdkLilf0J_O`nNE58- z^F;T7wC)yl;QcTaWNufD7g*QRG&eU(yBqJ9s@Gz%7WB@XB}Gz%3zi`k z!3Bs;qYO{6IRv}W77Fq7Sl@luy-vo~uepMFU8(un6IlnJo_($77XU21RGpLiyCf4T9lYM1i4J^qMj(bL-6hqC)2qFS8z(R$Z_knG z%4;b&!4__sH=#v+;#&P~Tzj${>H(o>FK3#L3d0s1yXPh${{T_phH?NiJWJ)JSsOQH zmqKt+(0K8yV};l52p~IFW#dm!z~z$#;Co}-ij7&PZExAZUJ~lwhGg1fK`xFzo552?V~`x_g5HaFT5a)5LpSawC$XTz zUAT2PenX8Z^5e$H%1|*-o*y9riCA>ZrxymxBJ!czC=8S=-jd7wZ!~1Ipm_%gZ{<^ zdyz>80s7JpJsKkhD(-$Xq)M>KgwLQI#7-ioECCu>P>wjnx_Ep3v|CG;1VWsJ3nT%( zQzy{6aw`-1Po`(AIrwNc1Dyd6+b>c$;rTq#%4zrrNXq&|FM5>( zS~$DW9g1VI3=nm7i=%0N1gEabwF2s&as~tD5W8Lign8?u^P&W1Z_Y+arDyb5tB2Ca zY88CzmqDm8>CvL^jsH_p^G}KmbYYPMV6b3p(IRtTYXcSw#s`O&N;SJ<-|hgigOz4B zO=fnB0NEY1o@FJ#0g)_=e)yNmTfEpI&nr=zhe(F^D0--=8B`@AQh%hnL7r{7SG>bE zuTUp94kw)T`XCTvJH7~F?ed$jzQeQkyz8d1qc%h^fM5zEt8hV{`AegLphm>hnt~MRSFLqO!M8=TG1g&n zt5LSeGZxFz+mG@)6hOWTssx^N4s*Sz)6=@WgI}9Vahj!;0_=kYnePZH4T_SlEW;W+ zlLnWyb>$_O+nGh)_Arlhxq6GY+o#7~(EqwjpAml*!@8rA&9Ss_nrEXCtvTOncN9-p z(4FcXqXMQ?3<+7Y&+?CP=AB|;8K?VngYXwjWg97D>Rs}+Sht&y4?gw<);F(;a+$*MP! zD~;WR3ZAXu{&-OTeq5)o;=t6(4R4WqNt|Ln4D}bTx)e`h9u|qdpJ<&VY0@BAF68g* z9xe3j^y-k#%SU^LYgU;>AMSpZEWgi+zim%=A(or^)s@|<1B7hI0aB>$dkZ;wLWtix z{bsF4SvSNQ!?)Tzu6rpcUk>C3z_va_SvfoSz`&efJ$3(r zX`oDmfY0^qo==>7Kpjyk@0y9>oo6iQj*xe8mJlaH^EHip658PSMpL1%O>&!G*!j>W;2&!MqntSzswfUc)o7x()&BB*ZW3C_AA3tO|A1Fz%=Au9z8d&33z{^N@kE!`*$(h@C z-m!6dTbiJ;6CP1{?6_s8qfuve(&C0lGZxH0wP_FGowU%SIkf08I=!LKRW?+#q*>4k z^ifBh($=batZYf~=rFKhok4Ias*(k_U%6Iv|rxnyyIT#v=@$7G!hQQWQx=oCWq> z*o}rq9Moi^ZVft{=;?|OWJa;c1mA?42T|;LZ3a&3Y`TMFd0LhqSGa9s#NJhd49q&F z##})_Q|&N0a2atz8p^N&$;V0~RyOs7CfQM`FL-QWmA1DS)Z^X9*>+PKAS8YB{5H4z z`@oW=GghTYx|()GL7d?giuLHh?tzcYCXlD(qm|}6vj(`LUJW>uTbbSr@w&41rR+dj zx+v$gV`ok~7Iz*@d@^r)^#45Xqif z6A!*>rSj}dl;6Df)2H&Y<`VBkHi=zvf&v&e9Zs$?H+@|~woA>JiZQ6zDa(W<@M|Y7 z{~#@qAti3IYx~|V;fk1#pKO;+zI++}wny`2xPF9Pf~!i^=0IuF*B68AW@8oI#kJ~O z-8w42wN_iKG}+?BUW~jpk(P-BT_dl(nO%%Gn&kzX;22^MAZ44chHL!-8n!n9{1VokWe*~( z#OWF|lQ(4iXYw;?J=*TAH&sEws2v#A&~xliLKhk^Z% zss!}WzZmpyv#=xd%pWyBGnToAzpL_5z%^CiV$8dWDh?fqXKLvB@n|WRc~1WX`xRcv zoe}=hMS2qxyGY2p9`U|m{`#q8Xh7IYKlt+AV4+?61z!e#`v{$!%9Q9s(ab^1BN*c4 z1+>UL3NY~Mpge6FE9@$Ro}EMj<4@ScbS#1Tw~D31(N{qeB2uJ(@PG2T=pP^3^50L* z@8W-VW1&p_pWTgdYhU!Hn<#IKkf3AMlj@Mk#6+`@eIQ+Wc%Wsvl4gPAb&G|jbi5lUruXh$coH<74>4?UH2pTv zQ;4@D&UT_<`9s2Szxt+AtZz;mf##QWE_v9AVbBTZss{RNfxq|6?jX!e0hpHqQY>AM zmkmC~fXoI<2MOR3bceO0CF}+o$wqy^w2>af|3{wbBa;)8MoKBRx0Jc* z<<@-k_Ntllam6bFqZ@9^n1R+aiEgL~-=!Alx=hL=hnFGh*O=u*5jkZkux6AiJNxlQ zPq|H6kwR_f?Kf{cCMPR!y>_y|-Nn^W3S0-(JbREC4J+l8Cw;4Iolq{_ZZkTKK7lBv z#0Eg-xJGTv7{Omz!jWFviM^fLs<>hiJaDt~80y5iF!RL; zojmrw(rU$op_=W~n--^vC%*a~8CNR~CEYF9!_Z~o;2jKwdQloK5fC{)xy#qTs7lqq zZmifHo0^e1hrlRuvX8%>3BSm)Z0+;rJExGf&P&-*OJ5h0pl)t7gB5gYaT69@Fn0SO zXqfs7npKiTUSR}Vy^Jro(B3j_lXGG^z4`d{7eSrq=}z6Ya#uG9_*5J9Li`Zl)aTy!5N^7+eo%HE4r8??a zlSbz$jr!y%gLLsYzmEI0=C%q!4&Oqxr5&a&f~-RksN4)ahUkL9<+QDx>#PIaMMHX3 zeOa6LRUE-h#iGX1)dGCksCM zvOUL}rowl((<_Qyes5z*2A#LS@xr>Ceb0zl#Ir*OauhSH%;pfa8SY-%MK91DqSbCG zK4|K-)NVbYx)A<0yxZzk8ZJIB4X7mtL=Mm~iLOAX9)J`Plq3X~sGrg=fIYb}baSseyG!EA3GIM*IkpBvaAQmcHK>oA4m>!&-S(kprEqgFo#6NKP%%>dR|#nQM%$`R><2Gx);}RFc3(Cu*~y5+8g*E z1lJBtT-6!VF*lAqmY^4;m-JDDu%KbY=t4W0i{^5U!rki|OIA~wd}-_SHtyh~(ub#$ zi>_|Q#F31azckUJI4FMWy*JR%^GxrSem{2^?J+#KY^1(s#9+8mIk=1#$}~Xf!of*m zv}?VeAbv3s%0XDwC&w2++E>nD^E~45N?-E$E>)^FbAv?Pe5df4SCrCjb6mAC-#o2#vd13Li3tGl-LK+5~{_ zy#SG6kAz03cs)dtgHthiSKPu*_XGaTGxr_N>}ZeJIPHb;ZI9m3ASoZ2mAadS4!>X4 zHM(|EuisN6?W3TM+4$?64plB)uPp}NfWsvh36Y%8R6y8X)YS1iQIwm2?y7+3z2uU& zS>QRg7J`ws24L=h<*mQ%$WXBOX)@$jFKU`dI)a`N22HGQK^)!R^jkCZLOQmg)uLu= z#P@FNvvEHHyr10GKj_%=G~D%2(YzaBhnQ%X5di}=B zj%Tkkk6}o%VHA-!EreRPi4&dJhr_zxP~&i9mlbECz7-zyG4} zqDJDsFlI)+Kb~w!V=>+0PxV(f83R;)zUkPp#PnzS+JLwD4AL6~WcxmbpK~bF1|+l& z$&?0SqGh@()x?9}ox@Lu_PXlK?f~6&N9&+(+;T9^BTHpbzjE^&Vo|3t#>a6MRwDp3 z^7RDBv8u_qZ@CXlgAS6SK}E-pkIw9Phn*Zjg)1j8L!N8jV_vA|xLLTKuDwAU=IAG4 znc{d)O1*^5s%_5hdwpM7DXx3-ne%v5lMWbxW7`%y5Ip7$9-+I#RUGsc?}9*Hv2}9b z$`de8W#_xp3=`|Kd7W8Nm!iDl${}-Mt2ENZ7h&C=qPMv3zrY$8y^QI&drgFwcgUzY zOj9jm{Nc+MgIPY%%{Q(HP!|bUaUsB51}9xi0btDt^AD#`YxY^NYNDQ?f z@VzbH0tQwb(1pf1gd0SaNf@jQ*?sJ*v%$nErG0)D-v&*JSpg} zSP_SWG=FLs-$rnJfAGT>>1IRr-`+3y51Rh_V0oIUA?292scB@BOiQ)ljQ^XG4~G8I z6YMRd!;AO<)C?eI>M#Yi?=p2fPI@t*-d`Elp;+j6@@LYR?@s8p(K(S*lfWkvFj>%h3)&5or zT4Cj8L&)^?ljE+kqhX(CO2-UGSGcA zZzzQkuAcSKK1tnLOgl_|_gl;ZXUZt{@UHp}%H$cn6EC#kWt0?}^MMvTTlcq)W17z7 zB!4EdPDae5|H6uuSB$b}%56ceQuzJ%{JgQzE!-K5Q55ThjE9g1?<~g}IzVXwZt&J< zTXTHi#n-PiWg3c3|DKTVA)vcoLCq&5WAIv%JnYm1K;@qHH{mO9VCvgygz^%kNJ)1c_Az$O5|voz_Eh!6)wywDDje> z`3a1|+q#x0tyTe3&6iNIEHVndni4N{!bEeonDOm5>h$FnJ*lnPotx6#24h)0kUP8rB~W&0a~)*rS*;*aSS3x`)U4FBE07sSoF z-V_gDjdBCywi50bMrs>{)=s&i*ua3mYHjR|YWk zZ}hA`xL4f^>-O{ST7oBj1Vq0vw9eFFP=nwdpf_t8R8`57VN0UH@(dRby`s>}q`uzt5O5<#DE+oDPI@Yw@ zKUrK1fl!Kn`Qm-z5$;Wr$78bU%gs7>o^`G3t{uMx`f@Ni%YfguJPE?f9T3eA-C+c> zR-TEVxPSnu(N564lR-@0L!KB!+bR%K_hKieurB|jkKw+R$ua1-mb52}34EJpwLm;SqZ|Em%~sI>Nt zSeG@Xc3*1$m&+375-DjoTyTeRFoI5KRqm?grrUtIU}I@_I-6NyXkK5WHdmaJTkjon z+KW&7!`2)P{{ve?jJ;1qJ{H?_b2v8a7W{DX8LDvd{Wq28xBAwY)sHZB=&)&zj59J% zYL$^AxggE?+G(B?-v<5IqLK%3U%pIUvl8yB;-`~fzW9#5CRZ@C4RqJp^*DV=9>7aE zz)D6%9hK|MggM^wTYD!2Rjlvv7_V?P`P%bjHpp_~+5KEsOoNExwIB}%^MmOIV!yx8~>-W}<(3E{dLm2P_dXqX+h0jkT*u3O$++b*M2!gh7TiQZtrG!!aam{Sg0mK7>WH zUh{?Av<4f|jWM@gT5uoJ(n53*WR!l~2!JUjrfPr22|3cWS1`0`$EE=2dCrin@Qg^9}RO9XV|Jw+Jx3)Ut$JiG3=#IcDKH+VZSF!`(&%w@L-3pZ4Ec#!Lnx8 z-4q>CpgX_g*+NWpf}CD&u!)PJ<+O2_qfYl>c!&eNw$Ix;1zPKfFTnCBx;R)gH}NK1 zzba^+dF4ia4EI&kmY|{X;1cF=ae`Jso-0Z@U@ZDD&2G%)Y=+aMJhi(v@yfYvS08Dl z=sx`lrhfLE$x~8)qhMX&CYJ|`k+MW6YG+fcthZ!4?T?cZlH8J0`|_w`(;Cs8L0KpD zK%)a4pYOB-Q9lkUUZ)*A?LT?Qy&e<)DMvD+bsb2lYIY_yM>-tv-PW)S0;&)U;)1 z@!N&oJ zC5k$FNSLGwh#7Zj4zQE@&v&CcSjR>gs@I7TqJ`2X?P>x;Ib?0g zHox|?KGXg$xGYU4=<}?;3*Ikt<*VFKziaV`0`2l^twlCf3;w>x$9JsXR&rHU)|rS+Ez-&iBk zZwU_WDw9*YJ{8BK=Io!l0(YHcQmChWM7kye!ey=N#_|-^Yk!u=vdhk| zv!J7HT+eQZ2L#|TIPC7i58uv2a(z+4 z{gxjLKGA=S3LYudJ-=6aX)) z^7gF3tccBcS=rma_i&QD_!3Mfj6ii(1Wiik>y zic|$e1f)g5f)MGQNN>_h1f};*s0k_FiQAI6&))l-`#twQ z-|zXIKh8dem9o}cbG`GO;~no9(pv`;pmrnK<_ru`iPq)QrS>Uv@}643@bFYXw^G;UlDsaVi)_}~5-qN}IHW+?Rg1C%-DgL6vc-i1Ol9ST<130oJ-98D z-LF3r??`x`V%qIkJxpxwDiD!k@11-?etVjoWFbhzYl6+IfIA0;8{{#^@QR@w%2F*O*evM^xzBF>1iDq_>AARz##h)Hdl%b?_2(~N zk8`Ghq0d=(#jf{H-wYd%PgbC9hBkrwcn~o?4~nM@ci!!^-e2*+>hfn0ULMf{a)Mvb z;D*GtrAP_%VA-pB^H^?0SIK^bi5M@)%rlK$(FQYT9l{G5eui0ERwG0){M;EnTdqom zfw={eLwE^pu3C*u6(lySQlOVyb=d35+JNYlc#aTK!Clj!Wh|30y5rD)7^ZDo=|o%2 z7=pZbwf*NJL%4tOQeA-3Q*h?bu`9H(E;ng|Z!MD9^WduQ`FAQcKs9EF@(6}Yd$lRo z8#Eu%uNe^CEx*N_%+MMM)+D-dNmeT{WNqMI(qrSqi1ag&@$>xiOymF4(WBGK%w3*U z+hQ+VGWvGp`i&qq>$BAHuQ0oD1N3iqd`vii;TAM9oF=%&pWTt#A`FM(O~Kea+zJ;` z%Y_EhBQd(LEIKJp4MWYI?ltzVY8yZ0$?aFp0?*(y+;J47>Gb3E&ugr0D?=|pD#PtN zp1c1^+(^gNtIT!P`!X{`a@v=rh~!dWNuc)~KmJx_`R(5<5yanm+#I^y5dFhtm_?p! z0MQ-6zxhMb2YHMFAUis}&agj8-kyQ0Xz%X-p!RU#5LCYdG0$w^Vmw}reQkcAhYMwP zpXKRU*{e5<9IXv6j(@l^h+TgVH7B#48a#B=Suy9LXk^CN@ zXC4$V!Qw0_11<_=A*d2j-n~%W4BD}?vBvxY4L0#M1(KI>o*iRD+?4oe%Ve@&V>*A1 zl(sLe8Z*mI$f8<<-C4sO;SeHVzk>(iKKNx{4DH!YvU48aA(;#E(V7cc-TPhDBTSBz z6SVgGbcP5yI%qsObcy=v5K=w7X3@gdqECc!G^Xmx*$J`x$e3W#-I&TWXWR5( zd-guXPo~LiJ>tv02KsCg`yQc5e(c-c9=NbJFl(yX4sWGXG^7%KwBzX5S)d*VC1S>K zZSQ~%xReGhmwmt->#0=sVR9DIC9l7*R0R($4!df00Y`cYDoyk?7W{5Borp`T?_mey zw!^^@-fYCx;p{K&B|mOSyO2|+{3I}F!UhlrZ{6S;g18U^dWdN(%`;;(nEsWgF!A=p zN^$m^Qr;iXpO$=@i?okbI(N<1P4Xxu6?uu9oTI1egL%Z){7|=czG%4jZmS#OMeGkX zyf~4^50(NR_u_~PAyk0J3WfXQf30Vq5s2cW=&Y1Z^?2`t3CzMMSv9a?Ji5Dp znU+@twz=BKKWcsSC7U$OpVd~kBaG$N*kn_NzV}@7ZyO|Gx_V9SGDFYSFR^Nd66-g? z>p>}7fKvX!ug&;|lD=a3?q|nTN#6;@5DkA+oap6!!H?2|rN0!0x9CIhLDXhS;!8RN zG#Wgsf06)hVA1&gw6A4QyvF~gG5W`N{j2)wpJD+p)MN-3Fvor0KL`IUJ`@~=7G|M* zD8eC*K_qUx32YCLg5(Mkeyyaj(<%8`@#`7ecdm`$c`Imi?0(BrFRq!84|=F4h(~tY zvaF;y4K@r->51bXs-?+K1`A$Sm`1 zH)CQM+GXTToI3dMpjL6)%TzVm`ejqER-wKl?VdZs?+kL{6Pg#KFlCt8R8&}RIvFx=xqZdvvXiq{ecKCwy zq`q=C+ld{pif#7fHc*;OM-h*pehQS02z@8i+JcD&Yy3uZC=DXbz=ZU6fH%#xg$ z(w1V#ciMqNw`n%TXA`WRd>N;ReT9i4X4uE|{odT8$OROL>_z~$E#xZNo>~Gc+fB1~ zS~U|YtK3qzcjlj(isCLqnJ=|Hr6@-3h*caw-O|*2v8ptSnonz-l;wkjO^ zr$SDC+MiPBf|c|*WXreoSSI z?6+F~{~!NOmeeFZo&m@TE=0MFq7rMi^xk4_S9ZaJgBsSJ29;CXCW-a$$)7e#!shF`2MQ zLMr2L?`gHU@X2<=!6k>gpDR;K$yXCCZ1$;gZ??4G<=ya(Dy-O)F1#*1X=^o~Of_4V z{x-rp_5DMKKC8pjpYILdq_jW%z|U#^pK5c*w{c(A4pl!N2;1St^g{GvWE@|D@&Q=< z9g(D|z9RNXv`Q%9z<%nDq>loa`MYl1gj8!W=^eIm&T2 z;|c7KZq|39G%Lt&G^nVWe61mbPjhA58e2t0>7!n{3+Q$%O)6u0c@#~F%zxRp=aH*O zw*yji*_f?pN^?CT!nxMv*hevte$?j`JGD=p7JY|M48?G&I_n)Af3XYf-frYi6|nT7 zDZVOzcm+Vnk1o)-M0I%C3$n4NgT=qh1tLjYprtZC$vxTl{w3DRT*p> zYfcYYA<#JQjRMFkHm)DgACGw5!j%sf*_*jT>Tfmj`@tr%ofe13q{1n+n8n_&u#EA# zLK)@V#4^rOV|kI6aE1K^EQ{xZ&vyyAQHS;R!{x-7DP<0ElU#y91oa^S$=x0)iBky#X!-Kn7f~x9A+@y0<|OZvN3uWZFb5+B+<~^jZ~K%xL`X7HmTYrJov+e4iL_ z_wW~ow$!K62p0}WnaYcO+oObBOVuHLg*|yc_^}U{>&FauYEQ$vLG5Jqm0W*{JJ<*= zA(yY5Ze-DT37-IyX-{kmG{6@J@I5d)3P#b0@wc-_Izkp2iilxj&L;!h=Y+}?7Z1On z$rtFXz$VUvg!=M({G83VMW%NJ-fR)LZ??Af&z z9;L;a2^wmeD6D~>ehHIS>x`p$F{rV^PXW|Cme)`FAecW45qyg$F z^dc$mXO@#9MVGfF5?|dd=}8*9>!?QY(FT<1&45*{{Z2c`OpqodBK!H$x0h}6t_K~% zUtk|#Ua9LmHAhL$Ki_j~qXQ`o&joIbRiS65PqVhdq0#I-#Y@%QHbXgZur>chYyoA^ zvEgOxLIQ0!81tX1A3|noVM1~@4qB^xSYhiml&%!J=p;Ne86=4Hvi`z8;y6?D&yAM? zJ&tA#GC(@=%lCYSUW<4f6-YF>U%5WC+tHjB|y}T zIFjIYBHmOEW*_(9ICiU&ZJ#?j*0) zw_a+dI@-*FYrra5iS^BSnDGGqBW+JWEMaY&70wD*v zq?7L&rs|%y>F)km9}V4nX`wZSFKZkP^ZIng^9ye3xdD;GXxgSiS~#5Oi;kD=$E^L_ ziKNEf^K3{YDZgY^JwsApSpOS5OFIYE z#Fue~6<)Iir)#03bB4!?$}_i#Mk6$8#i;^i8*X|~&gMI~`=`oF!X1-(I|WK77yU(% z?Mq^`Cp+qqjuRC5@x}8OeY>umoc7t7^ph4yR-o9p8o-wnsJ2p57a%S@`-MqugEcW3 zkG7Eyx$Hc5`ew$wuW)I7!k63l#c%||#AkSjFYU4+FGYD-?<`nxg`*mBS_K8jNnnz( zs#>%tEge4I^rmCXgKi==Qhoby%iOrx%LoFH37VGDcuIE^|B|BGM7h9!pFC+-d%32f z-3ZyQCTlotc%7h41G#P~G-9St#P+6TiPDZk$)2;fx`aRB2$_h?C(bI12_yj05>#BAvRAIVr6}1BSp&sP&=OX3*o@09^7OvszOl<#HF#`-6RKuU>y3?=|JcIF2#|r$51W1N4WF(?u%OH*qY+!LYU(~kUTOV8LmvDHh{LxwJih>y35(xY4JN{o>#daC+RPRwE#^WA>$1d6xqdT0>6Ge>4 z7rUWR%A<|C<`O98|ALH<{}&MHtyS><=#mr|$8<$EEwJ&>!UYu&86#gW0#nR9wPz_7 zCH0s*ym%X~#->t&`LT4KJKfQAbb6QoC*AYMkL%E=>$bXa9&E5(0dML6@Q){`jraDy>n2_dKYBx}G?-CL{ z9Bn0);F*L?e40gVQC(XpzR=E?2I6D$pf993=ry~nujAFHwT(v#~|+d(QGL6rCQR!QDCD7M5Cnh<&FA$_6%V85h@Af2R~>{8_jTQqA@p-kw!No`C_okjysY=!C9+(x z42BGHI1SV3RGG6y6-d|M2G^QnRhuyVJNDdohe&&`fGm)<0{+8?xhbYickb{;#FmBn z3$9X|N!HeF-_Hj*eTCf#=;W1<;4<1rt$_wl@>Kw{ZX&+4dus+_qsqJ0nJH#IDvP($ z$cYbqtzmpGTkx$AW2htM^_-7I`{aK7c*I!GbC*-+dDQ%>LYzB+Bhm9&>=Y@Kl=ayD zw9V0mIOT(f6(_qYg1^Fy?J=~i(l}5IlUOROfm7z@Yx(yyNvYyY=fpw934)KL_}DBZ zM~Bv7Nxd2|AaAdf+A?S-B&4bVJ%}R?>hku1=K+R(sCumI^lG)wltePS_gd)#YfZz4& z)-A$;Z$0d4aaYh7nw^9%u@pEWzB4%VsU$XQ#Zx2!cp&RQnX}q}qmEq>95W;m8PMk* zwHCSgx^q_lL;!iI_R+y>!^RUH|n&F9*fKQJ2yI@ONxq-$cI4AD^YkxOAE!!|K zsLF-@GN&S}Ob}dLSoXCWB0FuQK9|Y!I$7)J#Te#tEEvfg=76c8dRx0lt>JH0)mPyg z!vIf%VL^z*McYTCPH#qUHV*A(nE+lRE~wUBxmh`@KlCmxlvy+G`3?%Yc=k1?{nMfo zEAV%DN}uR)G#53{nfQdrLo*PVUv7gxJPvVYc}}p}yWcusDO0thLa?hq=x1BkZ!o%> zFkP(=M1f!=<7B6Q>%Usp zbC5h{{0EU6y0;V7E|faJjeVdRxh1R)qx0^TSl_;aD4NjW5-vLX6c>!Z=E3PaO`cTgBBYQgG zWg@*Vw?^*j$kSt<2ncts=*mSio`DKI;{!ApGLE}HD2uv#DJo zJD^Yoan)Bddxz^g$HT@^^ybAd)c11Gx31kE2uB&PH#WyCmH=9KVg^V7f7wC4V*lFs zzoUThpA<>{Q#9z)!4xj8DAfDE^B6Zp#R_R=sRfG=@{FrT-{r2Lw9=K9}wu^NFnmG4Es<1U~r(|AWtFm)G zO91Q)+{hTQWw8?bZcoC6x;WZX|2GJeX%ZW;f|di)$JbSc*nXl<*IL3VzoXYJyO%y1 ze}yT}!rQ%vI@r)`)U9E%7_;1K5Zx~j zi|VKx%^3{v9K*MrgsPAcjw&K( zXbiIr>%AAq9ErM48J|Etlv&jv#T7HQO~ahf_40A)ecDxT-&bn{oJ`{kWr6|pXZ}(b{3nUbt;$5x1n8totZZGNQ-)`fS3M_;)S1k7Prx+FSi6-M61UmalK)1c|y_ ztOree;u2IRu!@hQ7(yADQx$w-pT)cF&2BH?cqKk7*IVq$w;J+!d&N^;m1pGWh?RcI z<>N;v*OL`_=ZZHHi1@00e8$_;<2MJFq0r?Mm#7{)uPWplv6lnpmxW9!^_7%dtZZ*K zy#urnp!86i-Ti zZp!D+vwz7U^r8wGi@|p}IIn-;!JCQ^?Z|V@${KM&lZ6ai{KiX{B<&_V@jm^D+2e** zQfVy9A&1T1z2q(CV+Emw>FLl&VG6yZ5jo_A2C zG89ic7e|L|r)IcZP(AHaYH(g&fq;Fr^wyS5P4$Gx94!~a#rGC6 zU=N4Pn_r?ga!{+FKz#dYZOQ}i)Q$T<%j>X@A#;Os8)9cZo@k2p4yeyVtsPjG(j2FO zynd@|SL43scYS&zvZRBvA(!mcrw!%R#i?-sILtx_j1ZrP(?8-U z5<+pSpx%YXuSW9zYAMY!N*MJz&pc`AKA2tcqL8-W+uNz8qK6l;+e~dA8wVyaDU>Vq z#}I4K8gQkPgiUYMEY9LyV0E$En#DyV$xfFM=1q+%k!!c&&RjaW92-i_v+hPcfg`~7Uwzs@Y%JHX{44CD z*K5GHd_p^Zowj|fgojYkxWf?i>}Vrep(JfOQxKglzSei!#uRG@&a!g&AGLa_P=F}i z8T-OEY4wwAMUE;=m*|pA`i%dVvoFe9b-MmKZsl&;O67{QwK%Zx<)<^CbOnCeUAyV3 z3Hq8qha9+yRb@~9iv3#PkT41vVW`L83l@Oty1jY)s^O;A<7G3oZW2%32Qj2uxGunc z?O^!v)(ewPDj6A04|6@*ttws=RPv`sPb61A)DKhac$0G5cJjnn>tX&e&YSi8Hx~Cg zw&Er&L}}1&fN6m`9JR^`#C73-R z>(_n}kVsnj)dr^q7MFjxh`&EPOaUDoKcun^Tr3R*T920`+&XViY-MwoTt=I2%blhl zYbICk*Yt~@KjXol1XxWZ-2w59%gG@FO0-aRV0YB$B9qP(LZ1u)m}5cZsPoZlYy-O{ z(ck9BLJMfg`Wcp9ZJ5_&r~AI+bOZQm17nfn9o?~aWZ^_qk|Xi+#XyhDjJk3*|47&D zS8V@UTImnDC8_BTaKc<)J(&%dB)>p43fEDX0AWey5YSKQte`#iEZF)-p@kIE&q4{V z_rxvCL-V(7=UacTQ!Jam0yRK?9o!f6phnB~@5a7wFWh)d5a_9g0Wi*v0H+p0Bdyr5 zg%BNEr_s%Wh7O7{Wxh^z?tRkRi@SG9DZCxWBTO>8ZC);>QEK-*J>nKwkztM5L}5Y` zHqGf|I`!YotmnNSy!vIf;mtVaqpN|9U-jSOp^hT)dh+; zRo|UEC1ewlHt6Di5g4!HWsHSC$^dptC64G1WQCf?>0e>EC@1a>c9;7vG|aDTn7Dqr zKO5{^+2sEc4R}Q3*~m@^C_n)S{{2^2Hm>>!mbmbeT28y7yFw5{Oo`6E8K+6#1Qk(S zu-T$2lr?vRasq0MK`PXbUZMQbDAB}z@4ra!_lS9XI3wDW>E)#mSftvjv=W>ZPk+dC zj8Pgm^SeGP6}{g{uZ*HYf2hPK;C~cHwakAxx?Wnl-5naZ+62^@{Vu4QLiFgP<9;EjJyE9Vcw;hi(zQ ztPn%l6j%d%gGzrF6w>#4j2Jlh3V?)-;8brhW4~wmt9~d8F$xRWg!~i?VAKnMjotd* z`S^GB(_iBHKlv^{b{4l4I&=g=$OwbJ*j=A^@n}?Zr}E z;8Y$yM}2S$27u_-SJ>&EirfCC2f7qq3;sk%#bm^ZHw;+HIgt5@d(0=6SQFMAtXp?# zOlb&rv6*(Z1`txbXzoaE%Dox+sN{~g$CEdGuQiE;SNi5Tv9*&d*Y!Atk^D{+<$jW@ zPHw{u6X8z1hBx20KhX$|5`m8rKM|)<79xIYnWR%ui7$B6w(Gu^^sWJE?+f{g@vBEm zRYRrPdvAI`(=i~mQ67~t*lI=-zSZjY>2k|3Zuv!(IEn4RYzd~;6*mAW-w>?TXt(ph zEq2d|Qb4vE@D?}T)%_JF=RG$zDYL7Fx9d<|;Py?PlDD}|@t)n%qHBk8!|Tx!_LYiQbQimJ`$czL20=8gNYu*kK^z%)DGTWEy;6g@9*m2(sU5F z@gY*&lXWs4<3YsUBcQY3c??a{lgCmmtE+z8FjyBe<9neRNwx{??Tv@Z)gR6I)7n?t zz|+z_Xa~BhE38lsS26lu0BR|NJB7H3Asr9{b2$haA)@B*6xO(P&Zv;m=GeDK-EI_} ziVE0^*?N~cO%04DNKo>i4`3Il_55lNf$6zM23HPZW+(K>M5(a$kg4_i8|;Hf1;Hs5 z2MotL97=`c>N6y>4?M9+?BuuK7*Gi-Y|VW4zO#r=XH5O~vs1Z+TPI>Awgp_2-yo{o z@vwm8?03<4XAFC&$9Dx4ZDE0F>#mGOqFK( z%a?a{jV5i9(Ta={?R~r1p<#IVf(t~L{p)Y0DMjg3+hDSH}t^YG_KRq@06Ut$x228QQStp5ro z{42Jt`A5I7gf9IjlBPMr+51n9w)7=wX=Q7%aa=8 zoDJzR_TvrUUGfj1K0SCVE;PyUU^PL-!~Q^>xp5Avjag=Ip^JCYM}q}6D$ncY`GsY+em-a5c6#AG$L0K zaeGM^=aS(;v8$zn;-LZ71zgu3!ld1>hQF<2uu-U&Yi5|r43eb6l{&a~31)ti7mb(% z>YdEAEyzt!!)Z7fAlq`zfFE{d1e*5l25JYK)P(d}#6{f9?C9d?oxPx+yU3?EXa{x` z(R8CWe6zFU>Jk1^qOY_xDtBe~;2SX+*>zF9?-9ekK;bZW4qbvJ9gyJ$8Fx$w&0m=k zydYU(f&ysix|%X z?`%qa&#bozZCOT0@Fis+0Y1A%1n_C#q{>0^*6H$V1QxUq?_Khy$!cvs;o8t zDQ5D?JyHu6Zu-nL+)kXN%8xk#2PgvN4J6wye87l57iS7q1{EIMj%}^U&XRU!tmh8T zXW!g?+N4UnUmQ=&D}7H4B6FVIAWIF2)w`P6w`s`z-c_TNT5GQQFl|j#b+jRfKY>Ge4we)N6Uq5v@~!&3f^*EP#CJ}c0S5|%aO5S!Fn zu*P;r@a;XYEfksJXxvQKgF_RF?;maT#XNIt2~KhWDf$V-6mAZid>JQ3%R-#A9JDmR zRg!oX!<%m(AElf;bWKsJa=0_y=;T!^OkH8`D?|;dMI6-{zRc#l2!dy?2FEyiCHc_uiCzjD z8I*^c{9)8r(>z6(8`(1xmpantxSWgqcAUS~Kg`0re4~Ivw#mjtL9TgW{kt1;hb{Vo z3)Nd2j?!$0;2EaNz7n)T04vL}E_AP7S?Qx7d@zpoL^nZmXijJbeLgtgq>spf0sGS_ z-k0xa-)mitsP*!clIqw$7gwxj96q@u{&95qu)Bqg|2`i{cD85rJBA84)W8C|ow7=8 z2<3&1N!tHx=Y;d&t*ol=kaq)S7N5NC9Ls%@@o<+*fg6*3xttT6vZoNBXs>J#OJxq| zhq!^=@zAHUpk;GuXZPCAd8Yf#)@@%&>$NLSk`IB~m%~9gp?C))BS})KZFN|=H3bqVI^3SxnEhY08?~zJtbtyxX@^jD3 zj_a;7klDP`L2Pz`By3uOAyEGCsXcm`2n-M$S*UhFAUegk6Inlgg?Znin@(U-uETmE zM)=sN1ig7}^FJ03Arp3 z_X={Y{xmVMVbY(Rw?JhDEP$EecER!p_Alt(#dio6BC@2uH*u|n%vNRodtYJcsJwiD z2!KGjq7%s=6Y2g4^DCOBE?_ng9geq$+7Qq3Kwi|>1`#GGXm%s8=3e2n)n$FCo#`(1 zPdG*8DBJS_4 zooF17gUB*`phw)dfKm-Q#D4~OVHpOY*l}6rx@qTAuMcOTw|*+!JXyJ+XaIkKZoD)_ zADyOV^JPM(#}Xr-v>C?TIq39${s@BZ8-oZqFvlXd7aBMBu3kq1aJ|a5g2N0Byf(aY zr88Ks6QYz0uSXCMz(V!whJ3dexli{u;9C20j-eaRaKq`Yxys%Y3<#?8_6f}T&-Ndj z3*HMi3)43rt@Wb&H2o)Wa!;)Dk?hl)NizRpom?ok1JNv8b*`l zbLp4aM1}+;o1wP35Mi z4t`0&CL7oBVLd1suZi{N-AiToa>y#dMYgLGf>r3r>z?G!mvpun z4{cQR7{1yhKZ}fh#C*TN_43)GleM<|yJm*&_tgqnY$)4lQJ56_h3i;H=lQg=`^z`J z?zaRd$h%pAYY?|F$>Lq%E@!p0)W;%|JuUN}`u)V$J$e$YXj`3# z?o}*$3cc6wq(y&5rg!I+Q>XR`a4e>P9{A5BaYFP4T}QL z7rm9K@G#oDa9^Wn;Q@+FVMb0jTTIol%nBZb&>bUJIGQtHE>=HS@`-@t!|+j|gA}`7 z*^O>UTa6#CpFyAEr(CU_4Q7U7UK6XAXQ@{d5>!M=CH3~czC-isC{SB8lN0R|6c}37 zz&zpn1t?68yB-;n;c9c{h*x1Pvx+ZGS!aAp?ZuPf-Pe+9f~$v53WyiG);lIgL@QvG z;?)KnhCUtWsEvrc={kbw;v!NHkZ z8_1AqO1m#~Q@JzBo$Uhl5w>|(TCWX)QO>8`%ygFM6XB63)+U?p^r-afnE9f>F=_NX zxaHM{9Xm#wvT;vzlb$Cfn>AgC(27^SrI-hh5SQ6kGnd!;jL0+Fxi=1~$}83vdfI6|QQa z7$o27;%L=SYcY9Xh8BSp?S0W9?0>{~`0aX3j1z7mrH@baJd$JoTdRS0u^(S`)CNjT zcH6+smyY@Ev-dG!we)@V3_GjZ1_=zF47qRgpysINm|AF2k$lIUekUy1K~JO2wtoAF zE_<1JjJZa-w(f%xUZE{7wgnm_R&w+DEz}Y>^chC=As;Qrc_fSU`ZcSVb)4e3x45}3 zC3zz76kE{!wtT0)=|{)OINzOyNhU{ci5ND7u$J)bKdicr#i!YIGs@96uFT@DMqW_$ z1DhNnwjONsgE|!M@39 z`|cl5j;N{ZONAKhXG_{BF1-gWbh{)e;4~e8v#*4`Ew!F6Q38Tt z#@vG`%`@T^6m>|f6EW;b`4Uc{ z+m1o^`V+BaARcwuqAYH}@!74I%f~M0RW5x)sB3|0?Kd5LO(}<7l_#`_Pf*Q@CGK#( zY(KQ|0x!S8l%*a`r)rW`YYVh;>k-TpV@Crl#j%72FnOWFxc)Zu5)Vn6$qpp*n`C2N zXW>5Yy28A4*~v8F%2RuC$>q`WwNU5|9z<>s7BSI!7wc8&e}g6yl9JbpUEqN75y4hr zbnQY{X#Buv_l@6`K=MY#T@KbZV26ncXv#P3O~E27(J~sY6~OyU0yqu71Vkkq?S{$d zGAL8XG1}-d_@~)L9>q86doczd1;&?WSQy0jtHeQHtG&)b#QNt;y9I%Ltl$#jcaN}F zK+)lE9oAsuW53{YZL#0%y)>PEbCpYHit6bmZ95M7?S`Ov0RwRd(}z+kHenLi{GA?b z9$vC`#4Mtse+yzOokKZ)=W++6f9Z87elrVJMey1qX&eqFWOU8qnjNzP<3mv18UC!l z8&9AKuD$!Ws{#2kgR}7^W;)BbKO`-%2mU?N;RY6?|M9Og(bLlNgGvm68=>L23D$Ba z;nBZzBm6gRCIc;d;^fgx*&W;7R^fGVNePXPxG?aZoJru3W40#p92|x|Ur| zthN^!JhP!HK4giPS!CMYx}EZQ(y9 zd<`sCy%6V!t61YM0=w58*Jcw-gR?@7xk@~ofd`E&vAhAf&ulDraC2Q-x92s%^hvtf zno-2(odbRX4wu*6@RW#jdOUd2kNbv?Sd{?3cT}uW@dE{n{g$;61Z1CT0IDphG)yc0 zic4!wZsv*E{$}fr@>0i?JMT~p3;r9-bcZrtqjKHsz)0p zIZ(0BHzxSOo>d<|TEch(l_rFTM7?G95DusHv2hf2?Vp+N+Bx$C)nS~3y>MK6gVdr# z;}@L|0eWc}*}0Jig6c5WE`5g0@mX~eQ#A8id<6o#E#@4xd)d zcDmou&8o0`?mcWmiV4JPHp<1BCCWY-j^|4AZ84`0On_5#YP@{}_8TVH2&x?Jz3ilwTG%C_>OK&E2;IdI#}sS(3f!Is3If_1~HkkY;`0 zcn_7N{YBkdD2F3`^LvqEOLg|PNS$cNlzth$k9_k-7h zB<&bczx+doR`GPdQ1EfMZ2OruX%AX}oz;8C>siN7SanjG+x6uU_o%ujm=nl&el$-b zs*0u%S8_RX-vx5boQ%?p_I+C{?}$WOuJE~3J^kY8p%nE8eru1g!MDvDC@2%0XAN}o zFG*=t22Nb(^yN!eCUd*-u!Z;IbQdsgr=Lee-3klu@QGB)jkw46&Th4#>xk4UC_-Y< zt?hWNo_tebkwn;yY?p4~!_0!@GmhGW%~ieB;K_xdMk?sf5L;T-zo1RMX0O64?me6P z1oFA#k3b*?aRmOYOiC}^9k~^UMQo0 zAqw=e0#!!<-87)ebFGD23XhF#>%jO=CtK{DyV1VtZ4sYdMhk&LLch-{*@tzS z6RX!zNUfL20@fkx%`9>pwo%cLKpHg+8hBL?y3#?dFsc{Zm`iD-Vr)@-@Cvk5JR*ei zxzEe@s_ZX0Y%ZzI;Jd9UXJ($Mr~5i98rplS(e~pMdiEE-q3RI5e6~OxRgf*Gtv_kK z+5)`tO0E~h;VP63>PkHE_WdQ>B?8s>+B*h}j#fotrdnSDYLmU$yNO`*F-;y|G>fNp zEWN}@(z0Y^4#ub;#?PrehR-s?7s^-aWic+g_LJ8$+>gGfIK-~+e^I*zo(?^^o}bYm z&|tfsbG6k0pDON?pM3@nTRLi|=n^YRhT7oA=0sA?f|cnEd;-5(jQ_TZ?iMYp%)*4> z6z#nOA9@Mab{a#elhqTBo02aW-npZ6UXVuO3ei{al{sXOD4Av}D^0CSTXSa5qk+w@ z=5QAFa6i}6_TKe6xy{aEi?%5nHk3%7FTX8yYgqC{++~H%Q4;>JS;Y03I@PFF>F3Lx z*vk2nP97;QMvByZgS1sf(#;UL!<~*0#SPc96%x}=1|t9mN&wT|1Fk@bs5t}frE80S zBznKa#2YQK<;adXUiR!fN23^(IATo+%7s~c2nO!?AyB67Lv{F?T;@zrxEVyQe;&E> zih19Gdi$!lS&N;z$epxkU`fqG@?NMW&0Qpfc$Vk9Ra&fn)}6vq6BMgJ79YdVd&}B2 z5M##dKj9u=UW4I+k~NZB))$fDGT*{qIWw3V0wZ3} zCQH%WD3G&DBRbz9*4eM50e6b{msantpBh*Iwz6kS9T{Lkz<6U=;f2Mng5)K-Ire-( zl_WG6j_HC=wAmJ}MSIfrx~R8~AZOY9$GYo)+`>E(m>x?@Z<)3{iw3ULSaJqHJsGmT z%TY7C`*2T0FVR;WXBB~+H3WJbLMCydpEylxg0P{qNn3^+YkoK4pRoXXb)i_Gfa%by zMgqOGm0-WbRZaX{!M`n}c&RaF>gpo`$bg35fJ*)u{^W1+L;BW#>skEotv5&pFJgZN z7wjAan|*)rs8iDOf)A*2fhB9351V>PK&s$w2Sv8a(SSSQC^6 z<>?bK9*z<1P&swS*gt34+GmS@Z(zv(H{6Z?jbwN12UErxNB;o+$<_GJf|V=A*c(4Z zb@e{TT-i4Hlf|1EoQeUe^G^>A0EQTgy6+%{!H#%?3Ny?~9%GMH8(x=`Baq5`tZMYQCwvcuUw(Rk9 z+sNwj({_+602FI0g&-n3P@p$V74Igk0lWr)Z!~zo?0~ZI&o^>@`?sNOGDUla-Ai0& zUgd&@sMAf(PB|1>x;I_>QZ;^m*QwR4R~~GPkzaSJh?8j~UCT8%(adow6!mT-#0vSz zuspTMiuql@1t2^m@OJlozIEbj3$P=R+Zw0x&-Ps)+SDE^ld77!!>hoF3O5OTzg7w6 z7zi5z(oEQs4X%fX?4`OHg&syO;+d}Jd`%)mYHPv;j>%n<0atq5eKw;KUaiXirw#}=>8;;oE7%`H zC?sjU{aspiCv+Se=E|R2(Lc?mIBcFAcT-8)EmU|E-GQf_GX$iNWC>)g@Zj`XsIvDp zgr{O@zPc2D-_A54wJO)!w7{8IiAutTI^?plOf>a5VnDKQdI@094FlNN<1c(>>sc4< z_1WK+3l9P72byf*VZbGVQ#UCuyO(qmXwVnLNA2dBH-)U;b0M^6_4f&EpT$dY2cXdd zBgat(CtleP!?xY_;U##_Abkqg`IkTEIJbl}c3N-Pt5-88ah<)U)hcXRKEI;WXc)&r z#sUfp9v60f_hqltuw2uM8JdZfpXm7U&gcwUvNO|+Hck1C?^hTazO1&-g^PSfzP@Ky zO_OByRqBxvQrN8M+6rDI-|eW!=MJ#DMKQnJKOqe8l~QA(olQ@FP1xFH3K>drMTcCc zM(RC|D+mwy5igD(WDQ7)*8@}@VVrstENVuKX#=431l`+f~Nzwi8XqlGqegd(aj1af!E2h;HR)Y zxW7FdS_yb>Cf*Ckw@c1CK&Sf<3(;g5ev;OjdO+o{z*3mq_4%tU@L?7Iel;&=LG<jVbBNj*cnIhEWVJ*lEXn7J&%LUNYD+Qzkz{P0tYGmtX+`Qk;8L^4# zWF8*F7{R1I+}is0y&L`6{j1+Uvs#~FR=}4rbYTIZBDEdcPK^4nn7i;miLc;`b?Tv4 zEKCL=Ph-5)$FIsg+Z;XX$R1-#0cb{`#2W~|!Jo>v&pCiVE(?<{i2r1ETC*a2S>sEE zbspcn3pv_%uEMO|&mNN=AD=tad+9?{shU<_dO>n07X6X(S)e!MNo@+ZkV^d9Dnop# zN(UK@y5pkVD@s~t4qs9AdtOkU&+!4{+E2r=%<3-b+_C62Q8K=Ab7XxvYUKUaB|X#9 z;myvYqGHdR!PrS7$W4F@^qteIS@JmAOZVQRp3qP=*9~_K%D&e)GD;Bl6ieW6SJ5h7 zFE2ztK$6EpX*Xi%`WV$gNcmT&4d)ERbP7XIE%ZK+991W_fF!Y~t#Gsu1N?|C^Nbbs zfYZ>DdGdY0(jw62v?Xam^O8oiDJmo@|BJ9)pU);W)X~Htz+GzLQB|SQQLy&c!8+&l zv;;gYCu{CaUYpFDV&S*T{NW$DE*$;98+5hf-OLv=jr7C3Y_Fc)rJPHj`cV~yTYz8W z!7brQNkVg3c>~qO>Z;8XjkwpFZ3bYo8%*btyod6|X>ZU=UbD&1v74q>oBH}DRMHph zfMxGGFoupVX+`4z^h9Bk{4|0JK&!+9c!D;t6q=^^Y^Gk$tsZ`zU!Xc+TB6nyA0`b9xW*HTsAh>Z@sOqPia+`<9b{7Cy&nGTXg>5ZvQY9{{P1FdpFa4tsN{6 z*hDy)=B6L+BV-l3u(Eq4_43Jru(BolQj)T_HIrJcSCJW3wQY?IpZ&ff53 zkYjCpO{@~H@6!_pQVMe|J834GAI8I3r$)ZDPq9nRuuz~p8uoE z>UY<37xCZSJ=|o%elYhnMdXf%Y;?2JpuCm#;;Vn@dgxp;_Kc*aI=c5`{F|_>!b=-B5r0Iq>E@S2Kw4O!C1N@>Tsva zKtS=(Jw_X+LLaFmBG^0pWwK-d&3uL`2{^7$kwWnc`%k2r-%SB+0;*|-&P%lrMzmgiUTXK zBULI@@%YKH_T7UMZg-vAn8#=|B!OOJ(^%dJdIL8@tCB?cse`JOQoEw1H6>4ks z#reaJp#@SwR++EaXh%xFX(5QIS6Y5BypEbUv0y1uU(^qDeW$wsJyT$IVlQ6RNpiKa zb9wzr(aUoNoC-gE)7>1UFdJVuMTKc#Omwhh#y zef0Y+h?g{QQ~qZ})A`!@{z-5lD$EnbdT0LCg505)P44SQB<)%yTKo8aOX*+r`%6R@ z(mGM};yv5Z+OgDR%xW?U4^^t`pD{dW=IRAYFkM8JPTr7K9m*3-HqX>99n&yt+r45K zJ^Q6N=}Np$G_vpW*;43nt83fPhsr~Ohc>4n2I;`(chbK#0%168x@dxOSZ@Ogu%bd; zhJ>exvNmJX_FM@4QhtCi-SlBA{gN%c6tkJretm7&Kvp{D(*0YBHvUI?3xG6*8@P8} zXb$xmO3pR!EYi@_2C=`xR%zv4Nu;#DXipG@w0Gq2TU&j6z@_Dg0KcfbnT1+q>)K2T za{kTthPY57SJv~UdIPJXzXlvmdPC`>Yxdp7#x5UKlf?6`=?H&O2>D<=YMs_FwH+rc zpf;P?E`DxXNd9d&%2sW!tRkQ>y9el;SC!NBlLRX9H!Q_8X^%$k?R~$9l|0o$) z|AP@U;a{{=z@QAZcB%_OmMYaRHl--&Ws6{(R!Wc*{$Ls-b(76!x2YCPR zLo#qDoMOVQa09(d!Aj5i@8*9wnQwaf5VkFL#zZRSLQS-ZPeEg21L83Ja#xOKf9AsR zcFv5Cxiz|5;r{b@Igt};hi(wtPd$_qKJb9G9P4^_$>SDtXW)gm4N#7aK{&h9)X^Ud zYzx=~vjRt@ot*fve9MuWu@$bmPHyTkhG*Nu2F-iV8Z6xw_Kbm_A__w%78RGXhxD_o zT76xM{kC%LPLAGi7pB~@PIENh9DWbUF}fspjHo|#i|o%@FX8Y;{dKmp0n_%7SZRoP~0FSon64_j-U~ zYZy>$kY>1O9KHsX;z70=3KZ(B=V*5EZ< zg;39TsAEp@E}HUtMx=wv%&3pv>VAbY4HPb)#EHILG^U7qVID?IzIqUU^$_d6hHDLm zj2Mm`lk9g6z4+dpFsg4sIm)o&8NI*Qz)7eja=z&bBPERXV&JCyQ1b@kx)4TxlEoxm zut;VGyNaQJUidU8e71ERd_kAT()VL2DFGH`KstXlU*LruWxy0#7A;a#bm+0s9mD(% zL~g(%n+WRS$L&;_ho&_xbJmG2@Vs~kIsFP)1=9@!B-IzF1F`aSttuB<<~10KmA&^t zBmL>YfHT0PN6-vliL`@Xq2E=j&6=Zkz1K8HRYe(eUyK# ziSP4)C%))=94BTI64yEGRU}-82A|&mcFE<=0Ylo6GJoXe2}Btp?s1auNM48D_p$Q` z)rA!Mhb6CTg_`!iTRy$Dm0gSRq_&Z|;R_VuxWB^R?W&)9?cb{H{|omv0`#5d-|9IN z;_iq1L2gbn6`pCkbgOAc@5JX{|2&_n4@4FfC4GX?Scy;36 zSrz<);o7&Duzt#83@ZP}3tD z;an5r*4vYc0X}iFgvv-s(uY#G0ey(42nk%A&U8ZJVs9#Ubrq?!AF_8+o>%+}t;Q?G z=y!v$e|fQ)Jap%mIn8q}jjk2k5Vs-n@vD6KcHbFa>-`*$+%|Gm0lL12>2+u6`W3)$ z=r6T;N%SvrG`BBfv7%Z0u=u(ypm3Zw9l-MBUAxyW{F>#pDL7i8JT~&LaQG9$pY=BphX0oUZ}52h2#~7S)Y1lkW>LGR z{O@Q)v;SaVwwKKUe8Qh0=oDgFBMULYM%$m)FAUHemQCn`0Q1u83HV3f|F5aR`{UF7 z%AAuuKdZ)Dnx_~u8?)=%ZlKSTkEZ9Lh~YKcu}%THmIWbZeGeJmhSRhUwzuW@CHJT_ z=+!()eIhw96Qz>s^T4`kGMlmMc{Rs+$&(6Ti#?KEQw%x=cl)%6aUj;|XWSXPvyf00 z!5b7;_M#|JwD8$2QG2b>djWX`lCO9*^1chO)%pl`X`X>!AqIrK+?Xya7?7SadhxXA znz@zGzVB=r6|vDYiq&==ww}&{tly^dB&|$Jr##Kf%aHzhcC*aqeUdUq)FeZg!wF9j zrOp$Rj7{snQGkxmm+3}TohMj`_BvV&5C{f>vQ-=A3tJzgox;u|dd@zG(RwawaaT~! zA_Ui2$9BR&laq+EH{&f%28W4WTR}hC2(Prj>lR;-?`aws_;M4D)2bERI9H&##q!zc z$ysuVFW~5FMx3Y(5atRLwub5raO(y;skRzcMB!kW0DfL7zhR3sm&mL2{^DUpptJP? z4#`=>a8!xdkae}f>r`pQsjP2$$F2D!HNXa6OTt@=9R?&rF6uwV%2C|3MlU8_XJijd z{jMobHk>55eIP@%N$Mt!av`NmIn5V%v`>E7 z5O*_Qv+^S*MBR<`t<$kSN!3o6e@0G^LS@~FAVU?~S51i?QP*0JH*w$}Zj})M7^xlt zj}FrV+~5>fZ*+MW2odxPV|K>%RaR#)-%65v4_36qirfmn*3Wy@cxmPKqkX8+bBu3* zi_lv$vB4g`39QhR`Y~MO{&7ORkbmEmv-$Gj@7B;PqJ}qdmC#+yT6zEg=BeVMN_N5- z?d-9dl4JY%VJ1)G1P)x7^mxyjK|jgH0zCFwU@Y|<@V2!8q5U}W=i3hN9z#-hfl?37 zMoe8&1h21xXhd+$6IsO61!5=dfG`bFgCyTjmjl|sK$i_l%timNzz@c$)!_uQUQg>Q zwaXc!QaEMjD^8sUUf08$r%LIqIB5yO8Ti}JAhGu&rxiAJm!U2!76~C&*|Q!V{pj^X z^%VMb;dzBh^&bqqb8;!E6f3(&yv#EKwbOQ;T+lw^>e(1wOiG~VcppY)MP`pM3V-;(j{4CI`C2eZsLb zTt3;vh!j=o_P{}^+aSPG=z`90Q&0W@evi|=bgnE+-LYdtfU|je_X?!0N4(UY|K02=baenq(|KV@3uNde@ zC=wMbVqwAOJ#aaNR3z-l?BYyc_HC!uafnL=X7V(8Ai%DEoX<)1wk5gQsR}dB+=DqQ zx!GCol%gD%gjuf3+`y-+ZIW{#xdgtz0E`w;_sDO&Uy^7iQddcEBfR_d0 zvLSC4u12nGC-rrc17ypy^BZft>%P%9k&orv=7}dkkyV@dpcz@X=2TbnrdHYJjcoN4NSmxp_NEffTr_ruj@-~h&^{?RsokbG>>jM?%ugFDQGcjRm>Jm(E-PIh zEBPwot2xEcI28fun<(O~XW2R6+VHvH@ZjV=;DZe_(Xo;qrFP`4x2l9G19axJ$#HiW8(-jS9bi$5 zxxglz>(6guOlIM&s>oi`Z6gO<&cVyYSh!O3sCqi@g=DnCHITYT6lIT2IBbcjx-v(q_paeRDda_zZGkcrdueYDPW^`qN&R0_6x z*91I){f5!2k81fSRC*|jBw(ckc@Uj{&bVn}S{v(C+iMyO%>W3O=uBDp&cuhczV@JeNqV?JR(5qj_2 z&udKs@+XE3&Hkq0p`jlPwl*Lk^%cV@U|TP!gG0}{Jhho?542TR!*#OivQ)3)dF~Px zHUL4+zeAx!M(|Qjgxte%xplTFVA&P|_g=0?O5dX%AxLSyb%D}Ad3JVPZ;QM|d#dA0 zm<^BwTDg0oXiHQ>EB$0<$Pm79kc-Sr-PnOZxA;IndzelJOuzHB`gVV@CBl`BNl)o4D)QpG2TZ+z_N&)DghDa*sQw)u!|{5<#xB1rB3 zG$r8LbGapBh}dwra!J_`jQkS=jK0CZ6l5K1O&r|VBfx<0P2un z#QYf$fezT821*hMFvX}6QkVSk_GFR0deWkY2WC>*XhoDytCzsNZsi`~Us z2ZSH~%8Uvs^la}d7#!sU<|~ce!1LO>Y>gOvgqRx#?@l6u81DV@c!Ks^Ways5KD8&C zxWy#1C-b;1N5<4@UPX~E02>SNGmeIznk40xzKem(~~x0nx4b^_0^y4*c9 zg8PhR6aol98Nb}8A12HBgs1%XJK@2S2nIT7D=*{kbN_#r^X*+=32EV<>7Vy4C)FuC z_7RKVk(7(;_*|Dgx1*U;+VNVn%}ZPJbtk9f_bk-)om)&Z|C2}dAO9!#-@2nVq5qnc z{zFD|$MK)jIEDZP1AxGC<>n^<{_x>Lq_c0Dk*>Ou?2QlIAAf79J@fSR@ym;>Fz@QJ zhi^aKF%*0YAs*@$xp-ys$VOBtl!|6p1N1so4?^3 zq$gn<=4TCh3+Cb1B^jDVmq}X0k}nJBO@*!kRN1gs+U}><1@ffhzAzm5(iHxEgBxu` zOP>;wVQU&P(DsD2O<}jc9a8#nL%$uD3Sq)qgx{t5pD%rN4>bUHaJlCY(q_@Rqg^+{tJ6eoq zWY^oVrYa>R!6jW7hO-Mqr)650{5DlIH;dhkl_lcN8JQV}K_$!Ijmlj2O+TVspTrcXWhyPvmu5?W-oRQc`YyJQEv!|-dEgeV|#z%IATL$a!OqE zjj>DTnjIRap(K^S&VdCapRb2j)4QfoVZ6cqB8Y)f^b3m(O#u=;&R10pn8n$hk$u;` zCO*f%|FS|ZY1l!G9QRp1LqwMuY+en{RG1wXyjMm~s(p6;E|^Tt$=@866Yq;!L}Yfq z{ta&;oT1~d>JUHz-;?Yt`h(g82@nZA!6Tw(>KAL{pQX=_U28t+Y$kjl;9wvjDsT}T zz4n?ivL?LoNFx@Ly<^&wOSVDu9C*`l+#K1&F7?KIQ$>@a(`TQ~&bUUrYQ$4#vrTCp6b{ zCW!kwCmcGGJ@*{@JS^pbKUWjQw~SNr7YV_=$ds(p1On5_&aAWr-__rQ14_9!HH#{( z(n%c{hK)fUW)hBBq>)-UF85HRFfMc$_c9y7}DY3OY+Rq=|c=u($5EUIMwgTd6B4kXj_3YD1-O6)f~oes;&QnY4X;!;Tlg|`+`?w!T@Ot`Jw9N(;) zFHGJM%yBb%99f29yF+%l^MxFh)HYz8hw^sv=_!uy}BALbb*s-UpmJ= z@p_AbV+NiKbfpu3D$qgftOL8+Wd%}jat7_|luo`Ev!ZHT;fvji@`(2Oz`LjiVXrCo zR@*yen&}~{??5PhcKwu(C1M33c?Ej(vof_KM>R|u%5+eluYMgdREFrx*D^o9LzgKk z$tOK5df>nvcGQ354gPYsk?3#l{TU?Pr(CG71&-3InAa4Bs18Ja8?L^FN&&#Iw$#&I zHsZaJpt&3pQ*}t6}m<2nkdz4brL^vbJm$ZRXk+e6sCL6oZ&l-k~jpH3eA8 z)EwIbacO>GvZy=jiH6?K`KR!b-}+Fa)5E(p7)#NlI~yA7DqidPG48_{5yGa)o&Kdl z5YwpRPVA}TqF8tY6OQ@j^V|TXeFob1`22ult0rUgI?R>ptxH5OYbJBxmmogp3Spc+ zigj?oLCxaPgh zixI^E)YIy>!~n$tss{Jk?awGV?=fv9MB|lU+9yw*8Iodj+C_A%)4i{m3@g$q!{6WG z8LocAW6|6@b~R;%s;eO$G>b7_<(0c;rf_F&{5!YdDtU5ifMNz>DnM(hE!N5hVd%@f zdA_<;Fkl^O2A2kqiv9rT_Z$6T_Sw49S@WoNv2BZqM0JCwky1#e+#E}2z$!Vl72|m2 z&Bh6h`@kOI25E-eEU-2Z;?FKW9uSJoKse;y>j}`+>9%FDd$vzak*U34A2-9C3X@!l zY%j_%K6{#+_c2~Gy#gb=0cAV)?ZGLP1YL4*U>LpQq9pa)kqkAFZwLGY+7Qtk-{B_Y zqJT*!$5y?b+z+L8FWbRMd@K1i2PJgaTHIgxn<1B~GY!EniZoQ-_<9wz1_`*%=f;xf zIM$I3xn8Z-r2g7F%=_y#8_$xE<60|x{O^!gpb%ucsQ$SR62^r&31g*{aidydo!i;- z+%Mg=m1Ly7FyjFqCNo*EO785>w(IjZhhe=6t7Li+6`bsgc&bTjqy@4~7cK0pQeQ|{ zJqI8d=rW)a>Ot3?ev?k`s(Ju(p&T6pz)vm&vke(Csr(0nMz+R-=8d{?Rn_qf*W+@f1x>< z5|_C*6n!xq&96>s=s)kdVa$2U^%_%rS?7h$Kjv^DQ)rCa29#i#gr+_R@^Y^)RE6yo z&2SnlCWG|R&jUP`L7?mE0@dN)!#74?uY&#DtK@8udwwxn#h~`S5WEClS1{d1X2u`$ zFVbMs0|sA6KIG0U5G%1D7=Wv^tfu!MWUT`LUv(cqxl@U*Q+gCi*+2e+p#s%r=7S8M zZYP3kDd@I-B6cSfGY)DgvzUd;BRrJ>5EMdUA^Xn~_#fWRHV?9_4A_u;S7#W_Ke_tR zw<;A2BUXl&Y*o+8MwGwPw+?T2?8IG+oak6uV{*qtQxBA zUUrn)mwIo&%Ygd4bGH?GoND5Qh)g*?>i#t2>Bg=lHmL5->n|I#T4>;0Hk4T$J0J2mjI%O3hVk%xLP)H$Jk)%4V@Wn#+gc|IludP%aAx zHoHA$6kJQ4V!=br_ndF)O>Zbn^y@(MqLtU08^0LcPNI)F*GG%BYmRz!T9g~;X#|~9 z_;wrem%e-aM<91I65Gw7n27wU^}Yy;Cqk%UTcDK^dN3^WUy#Rw|AY*Bf6VU>0MU^> z-<4DMObs*lIBXT}`_cc;3E~?1C5l^JhS+IC3<-f>J~)ciROsU}D1dqBUQHZ2&J(W6 zaidA)RwMeklJpj(|=mXQ{p!uO!u|JTnprHY%S&0cU!ya9oi4QE;U|v4HA{OH_}*ep=GI>Eba9!zQ@!OmIu0*?dSj~w1!qj!_#5W+6C(Hj)t96& zkl?<${eyuBq8&1;ybm}Q9YOs|&y!(w{=Xc^|6oG$D@(_ovHQ;xA*XHuoupT&5Cma- z6>z&mM{k633Rtv2b~DumN0WTcq{~PeUGWi`SKH-&`D*Zgi#<{z*)W84i5z2y^kS- zq`?twSM0TK(d)6X<5i6>97+;;0D%U8mD5Jm%SwHB?<9fy2g8Ewc?b_))aW;JhQoRn z{b|EXgeVe<`uwK~#k|7CAYiUN2|zFLSEtq4fCxZf+ra^HSw~G>I2?Kx16c;7?^-1~ z19hLtd~!@f{9-?L>s{ljG(9JZc7Y#=PhlE z?nzQ-r!Rv^fgy`klHjdj=p^mVI&xV6*uV4MAQmVd^^@>;?OufUa)?pV1{JJ%&33z)|8Z;drhHrJ*Z zZbUn2T;$cJ`=Zj!nGS zWaozv6dyiG3tP;H%xzlq3nNEWj^IS6A>bo{fqrtoyYO{N>G>W-gvHL+ zsTVQ&lgEzD+*@*L&{NScfJ@_A6IVR>jN}rI4?2+Ei=DY`(csE_H^#6JxM8ql?W~28 zab^I=vyz&GYdLTSuX8X}%5vY6g?IdRJ(YC6x=j3H^CmP_x6!TPU{5oVczpp}D*uC_ zqcN7=r?j<3;1x{DF{rONaUIXRPn122e`C5~A`bo)dMU8yW;=;h)b>?8=cGSNDg9v; z0{s?e5oNxUyGT3Tr)jkHVHNeE!i=$x?wZ<1t0Ilb>6n!A+_suoHPtjOt5c*}Hw^VMP z!lORQW{nys{)3_O8`YQ0Whb zP=$%z1u8+xuzt->=V0LKY$Ws;A{(`c87T9cRac;>nk^sj;(h%AwrDOfBn|=15<=MD@uw0KB|UOt8Pfr)ouV;t;f* z)r}N(?au7Qlo^j2Kh&6%Yokk%4glLLQ*<^S@Nf;f^*GAfa zuj*dhHDEBY$E)4*Urmv{>DE_;x1rJ=t)aMv85_AylQ84BxRw|ENnwlkfAw|XB5v8C zYIxx>p0!aL$=QYJJl58%Y5ezbI|ol0-M;ky9LV)AKo>^?;+_wN9t@?%hEwes)&61P zRjUk#U}w_0CeVzjWkh6J^}YBBI=^z7^=>qMVDZMC#}~83RQ*Q`~sU;<~c7jpdfg? zv9Rz)eW(?YL1dQ&8x5E1KHFO*=P<08>^NQ-n7g zcCQ9P@oqN9Bw#DukGqa{k@O zt}${HGnxF(4ryAj*!QujPHgS*2vHdo1zs~&0Br-c+lDbV!C{VTs}9i#Z0cP>o*;IS zir9voF3dM2Io-#kL~h$Ntlf>jF)*<9H$Wm2O=s(Y+Qp87GK#+C`rYW$j+P`*cgZWy z?NcZ6GL_H?yK<+8YAEpHnISu4IfKlQEzh(+HYwmEe%El9i+-XNu?(VYtxc_00rTul z|B`@-w;4D6Agkt-bWtw^=zd?i#N!zVVQs}Pa+~Ti2zZnNfw4-0mVucL1a{^n4i|u0 zD2mQP%VSQ^;yYDWbQ-?zO2OSX(XTM}FIR;~MKMsDfa1#;SJo*oL57&iB}b#HnUg3t z5{fWG6o%nH73I22WBPh zIZ|O4O~D=+Oa|pQnc?)JRf{wIo@|HDmps6|88Ng^cas@PJ)D2kSu~WJky&H|Z6a3R zlQi?9*8!HbrilnGSkhK;@5>N*SLh_4yjtS^j0CM`JM42A7k=xT(ur#wLDI$7y9<3- z_?`KDmFLFGWSE)?#%)?qo~z+3I|83DYE*sjaD{HxD5ceXcC3T$h=24pP0Y2CTSG?~ zU0QMSlh4y}e9HhFvw@*(U+LiL!*)*dw2?fsGDTiHxinIdU-7Km=?Ql~Uq4oApf)kn zI~xGD0iLEcZkL@E#D;2gZpF2dXLaPzirOCxH{EPM1;Adty-qsUHw_xWTzIp-WxumR z7^uQ>i7Y$M&GQ@nV*s-I5(VCPMf|#blh+>Uxr@+wv9?gdU-Q2B!WO~^w0}};=CO%`09Ah zdIN6Ta5&0hxGbb@pu}hx^OC#pO*LnZU)f6P| z!(CaI<12iA0Wq>Phuh>kw>6b_k_7>{*iZVdS>#`eXb{E^z^B3(OzQY_@Qp*l;DN$_ z+PU|2k=cDPxYoY6)!qZ0AIyic12=6|xHU5>y%eSd5BF4D?;3u&ITq6D9N5Z3FPe{a z(3FZ2d*u_Nb_zw_Kz`ZjNjEAS5flmwzpi1dRD_{08T=sn6%9n+3 z4xW?y3nGr1Q>MM!ZMx;rCo98VQni?R9%CLv;8D2`cruI>2IrHNcv)rAFju;>_l#El zVf6jjZ9(uSlyMvY)x3BY%M1}taKad$C8wIBcEO@8@&Krnu9 zcCz<(3eDvI!44h}X!}1qg}+GczxY%c2k_Mg0=Nbv#Ni(d3D;>%AQr${nZ}+iQ=yfs>c?hDVYv=i;O0Hn&TG3SckQE z>2JR4+kUfU!~bhf|RGS#sd|n^?7g% z@5RIdbfMy?xX2{Y^Q6ri;&*;tzUzmNd~qz;!0+b=-P=F?!mIV+q3=c%5*s=eDP^tBXU;K^{CD zlvHn+o2{op+%Wg-(;)#v(+A4J>*(%mYq7fU2fWF!ha5~3Uv(P=RI#-ncBd7tM&89v znPd^U6RRc8t7URezY}awytioW;#<9|P$C<1&64d-hY(69b{803KXQ%RQ6NQ`A}6Z~ z=U&-x;Axq+*yf<`indVom-8Blz7NX#D+n;u*I4G5(AY9XM%TG0Opxz8SxCY$pax`{ z@AO96tKiorb|^db^pN_v&5&hDp>QzU)xdQSC-gxahQ8#yotL8>BbI6h|?$afN z5fr+2o>j$0V)~#!tN+;Sw)JvZYWV~56T6wHC=d%iWz1Ttti?vCpdPb5%qSRP--93m1niH5Z3R1%)x)h2{AU9LxDJTDEsoUOu_NmocG2m8{>Ki!tyv ziZL5qc9*q{xSyb+jMl$UCqDDWCoc+$#rgOuvff@FeKg(^!MmGFN!?KT-UzeGU-m|I z!@7$c1U4fh9v^=%`27j6E(jtTCoo==)toxEn4RA0CQ)%?;pE}_UbhduU#1}MLiHB? zprIya0XtjQ#N&6RY`<)y`_`zQ5*sQD*B9sM0UeK%$TuJbby^*r8%rK?7nHn1KI#AD zCjr6il{*4!1mrDA(|Pip?DA_Du-^OY=9aqGJ8cSqbQpai00|y>=bNnA4$1(WuCvuQ z1-P2lWr8v6fLaqY)B+V;bnvN@6iYi#&xpYh>9$n65^n5La$xvZaCZZ5^GIt<)7Ph-CiKlVOA`=KU zc1LhxCP!6ndwbiJ;y793SZZFMs~2tM0)qe5rESg!QP(7aj0=wa2nn%Vs1i{+&bB=2 z^nhnq9u6Pc%EE`t8zj|gp@H~$%)qzsWU;!BbfSQ_h~#}*yr!RjecNbP% zq&L(IIcjO|L(k41kd+UQh>h#;GkcR29r3}t{6gz|hiew@Z9(~(W@sM*UxNpnkWHH{ zLFAZmpdZ#|#)SDqPCtJ_4!3{k=B45Y@0ER7t`0V$^yX4w>za@o$b!aB!X2~ZS?v9e z2x&nJy=~aK+Uj2|oi%I0);h(3_3QN0YuLv~@;p7KejxR)Jc7P3c#NLU)`T7jVqYVdf__i9`6a|6HDi+>k9+ad!yvS zHh>$(g#)N}fG%@PG+^a!4T{n+M^E)q-fg@H<&IKOK);XiJZ=bz=35kINwn~ zkhl>fDURUq6J!C8%`Eyq z`w6TG`^!%Hab~Iy+8%)Ig#-Qiglhn6;A3ew^|o}nSRH_kLKf&QZ0647x3YyEU*LT{ ztT&y-bn@aeL+6Qs@)lyNwPaoJ>~V!Q0Yc|c>gEnMaM&f(`0`1g#49ObPWPIra#lX# zySZ|`?c#YJJ)_H#%~DYl?2b+gzNmq+&R;K-uj!{aoZ<4A0is0nuGizhb`Urib^}s# zHvw^Sy3-L0pi2Z1;}&E?TwJYK1NvA~tq|kgr@c>I+i9>x{apmtjcvtVpXZNt8X1?O zStFzR*Fbe6FmjdD{Fl*j-}REgt=5^W<+eu9xyi|OC|`qujE*!y9w7&;7|Xme&W;Bm zf)NP}wY;90Q7a$f__8l{;d{=Q1&%LSLhrN_p;RYeO)~oV4U$zEpPJ9}ARViK7|rnS zJnawa(|z8}n}C*AUteRIC}^a2!m-#6bh(T0x}&*AMtlGD2HNTb5le%kreRedrv}iXEGVL;*z8- zUVpw3`X_HP`{mbSk^m6YgzN4vHMrYx6Ue(jd;KkL#OWf=gTy6)Szq9CM_MenTU9ho z-cIj$zwxkfuYiJqFY^>rW0xHLLPlCGEtWNn>r2=)16QrOm@WUBQvTngZ4Uf8bVYZf zN4#x%DKGI-xm=xmVrsTsNog>Fvu?k z(Lhd@1fWzUneQ@`oqjM8;tm*KK@#cpnqu9X0Wf?*#PY1{0;Ea=bxTjwgvZiX(D#aJ ziG&vaD_@Qs9%u^U>&-B@UM{_7TmZ7xk%=?7*{67IWUjKuPB=yG5?|mklHD?BSGP8R zbaDz4=An>cQYgC#(VVD6wT;-z=fCLH+PBxUJobB~MvO4TUwbnes`B$2q zOE*ip@KNJcMm)Ue!SHT@?um3iwbj@W@>3rHXb&>UOH;&%c*dSo1sUqVa6(g(+cv^9(pQ}H?V7h_bpDT z4$bMHD)UtRs0y-n_~wUzgB1xSwg)B*Sn_fUU2m}?%P_U4U2JH%XL~hesD}|J;F5ih zsOIVdITntwXOmqMuGjRNUk_=Kx|n%>Eh9NqdEwIW2?v0$g$bbw%hZjaxfWC-X@yKX zh@*%H4X<9l@&)!cBi%{a`piDCvTmFb+Gq+l-9syNIzJDyo#2vq25y8vv z*__EiS@gkiexfM0p&Azex>?ej+wcJ#Y+Qf#+VUmV?#dBH;il~5-6P^E!();sxAOYo zFp}8T9V9L)6s?<1h8p+Cw`S8ga!)vjCC$zGv-oY%>G*wc=JOXS|y;z(%QmNEw%p#1E+3#Qs=TZ z8g330Xp^f$J2*8cC%Nuu?)5vumOk<(xO~Ga85JchD^Sis} z2B}bz$GFp0qzf(CecvMkx6Bdx!o4~2Y)z97#WbA#32e4AD8A|bc z3jgh*Zg%0VLx_Bt^Pgbho?>L-2dB%3OnhE99Ir8BLC32^e6TeWg@!eV;4-g|d^#l2 zs`s88KKCwKX%w%IX zHddpF`vu~Ot!o!dU)e>>n8$P1L&z;x0)g&4kmO!288{?NuN=>&TKD~6u!DlA78o*& zguQ0SLg3Y|II*C*T;H9)Yi0ao;#;Zl^X*4&#agvqODxYX38oWWAhh=$oje36TcDg} zZSLT(W$pk9r+@8Wp_0mjV+Zo{PF+*F?_?C28#BaI^#%=~0Ef!!cXBft5vSndL}cXo zDj~RfZ{fGq(|zYZaiJf0bc9LQPwu;afUD+tgbcR?g~sqNeAFN6P5b}9`2LGPFKj>u z06U(}{J}5@-Z}oBmHyw3;8$P_OqYH#my7;051$D6Nk6BuED`2_W zIOB6hSd;$oiALT}mp8@k9}IKJjWn{~-hTtKQ+ZeDC!IjRJ0T`Q0NT}`|9sCD2QVxF zL?GFlUJ1-(wyu-chP8un;+C%>a@?>CvsJmY49{#0s8N?q|9zuVT$d;>O;(*6mM2_xR|@9k z2c`l9P1qg2j$O_P>7%p>uu9|2qRmF?*(y~8nq^!_JaYaQ_|(u5U0S? zzAvFbRO}J*+-xMbA?gZITfm!iIUx3&?hM=-f9$NaLYy&v1!HcGI2OPvHfh)J8#wZW z%t5z|1z_A%joDzd=D<*&=3yob=d8dlJeMy@pQu@)kU7cUgF1oj1f&ZH2m%q2 zULw5{ktV%^^b&dxHIU-Fe9nH(*?XU-eD4_V_s$e%S2&CzgEMFTkrx5~S-Q0WYo;fl{gD$nq%*dYno=KT=PgIUpPlM&^z>CdMRT?041uu0| z(v-HSWZ{enw7Mdna&LQ0`vxo^fsykmIKd&B1q@3FxR-5ADna^u=u@?6ot*>N4knJI z%sA>Lv$>YI*>bu8E**C1hyTlGytd*BEH#rj=-bs@5&xUcC7)sp@;{r^(E8RID8RUo z?cez66|X`M_8~v^i-H6{v{M^mY;-mw@g1Ht;j#p)DSDJ6s3ChIha!ej2(i1*Hz9s6 z;#?2KC!-@@$CY={V5y$h>d**j=VB|dJDKQ|`FFLvv_!D(mMRkUQ#^*YzpRVu3q0MK zVH*@X6!ysYR*?>S>+=epVtUHw-e0SKJy0X$HItZaCRDy- zL%AXAW?d{`{^4GiMw8G2k`g!op_Za4CkWvAySqGkKD-J`gj80?96RHU+%Bryhd#yI2DyhGunzV*b&bcs)c3n0 z4%hP2@Q}O7XP4+tF`i$1@QW&!eQNI;#^_Ql$9@vVjd;T|4>Y6>lKT@!k9!E`{6DIn*n$)t}?Qd-3lU( zB37~OWsx`DaBm#tunarqv-=iT;uAPV7*YofCl8u?8GI}+>F?9rwolkChn-!3GtC}S zat|iYeSMe_*oMHI;%Csb)fO{0jx+fMpRLO}?9Ht9{Fsy<92Xn=eElTV=9lo2 zUFo-PFATndL9o#vX&v*XmSzP|Mc+I6ORCC;kIzb|6^a^XMa@Vza|Pyko8v^-ti;Ey zIJ9v=h}>(?fSL)0hE_K<-RN43xtGky4WKQhAa>g z%P3|q-ZX2L5zW=;X6KizPY;AU_XMHhLOsPWh?V8}#!qO=G1rqH(54?Sb~ND`f>FU9 zyIU4*&)xN`ou0@lr`cT3a~|jR2fCl%Ucty1U>Z~IzmfVuG$`f8zS zA17QsS)>nV(-Y#a`1#6->>BemMicr?H!Dw8*8aVfF2Xj*l`7l9T-!8yep)_U9JVc0 zO}Wv_asNeH-4({JC@^G)i6T+sp(Y_SS9{mAbvnRF)2#rLe-;C+g3t zrU~`XzbS*_&j}9a=|hf%L}SfIH7U4tlYc6ioPR5slmT9_fhJuB#8Yv+(=f{QB^^Hr zZ@b%8H?t;S<*HWb6$?9fI#DtZ4F-voln7s5}QT*#KhIm&Bs>j}Z)J|R`Pq1tkfGALnE%O%_ES-8`J zxMm@TAjPJy5$r)ll&{G#%c#wdfMVM2#8F5JCNvvA+F0_m(Cza%$Xj3;1_K4czyeUc z##*7r&=j&1a?wWwcw5o8DBehZSK?l*DCt>_o!02ywfX0}{i`hAv|rxpi`4sN-trdE zh`wgjTMkS!4&fvoc+U&;Q+r06A6p#q$)0dw+F;2(3zzF5N$!`)mLAFu+-!isTXUsgjSgu4j2dQ1 zQaE8A5$8f}+oMlO4}d$XH#rNys4d)@No7`IZ#%gVqo$$-M7yRPg4 zsJRdiJuaFikJ|>!mh0l2JN^+*j4G42f|@T#r_bBCMN@7kFp#|=W*F7899iHooMp3YJBVvhX1GjkZv(ivOrTu4Y#X7SLq z zKkUZ-?w(wm@NXirTM$h!DtM&r5$cD>0)>Xu=q%O#C^gJwy7gTKN(PsipIos?gzA)W zlIIp9MDV)j)a;$MK0SXUk-{5e#%>(povYeC)WUs&%BOpeir!}q2$MzzcK3r(LkQH6 z#Bbq7f*a(UDK};EE55S@S|T5};66-ANg-m_eM+|3^H1ny_s$Zfn9!dUiGf9iZZ8y2 zNR&_Lr@ZkryU0A;M}{@%i4~Vy=3k0F+7VmdV3v=2@os`KpAKpY_6fL2{yvd4#g2<* z##F1fPwdd2_Hp#O&Q9=4v7q?C4=xp{Ta4@}NboVOX)_|G`4oxw#!8AMOc>p_XaF8C7&T3w$4WruP_o10a6 zT_S!G3FW{X1}~BzW!93v#5Yo~z26{hAn%x8$(dt%Uuo|eGAz$qf63V!u08O8ZP5K$ z;?>#Le2EP8&gRVKw2vN({@7=4215DpD_#SaFAtOv&^3>(wjTORSE~8YtaJsT`tN{W zM(9y7}Edxz7?HoU)}OZ2i{Z2u;NX!otgHFn^~jFwUbhPDn|-;es@$ zc-%88@=N5ls&+r7#Bjh8DH>5iMO{O<+sVo(H|%#r1rDiN7RrI%s{M2*6O{D#KhU62y@&a@Qn z>hWJx0zyMG>>2OXFzXb{gIc;TL+8JX8#(%DD?Y5boUW~fay?5*H6))Q0N{UR0=mkY z`HM<6XFCCV+lRHCkNoj@C_P)Q)GhC=ZS)F)w~i4EfaNWm^06fIHd>=rAK3{OkiPhM z!WVy0)wg&NCZ3AT8|Rqy38rIZgqBgm;RX=>frju$N09vsgye#$G1sC7B8L~sKfe>h zxP`WReb^;cleAMx>JW+CyOJW&m>++3D~*bziwb+*0#s70z$p}fuiH758Ki)ZK#qL7 zdcPrAadSVwO9z=^9dT(K(OUc9?zIdr{}V7<jH`+FW%`BjWW1^0A3`pSd|&4%5fIXvOg_)63Q$ZiUSSGn2l^}kGyj=v$y8a< zS!Anl1GT6U{U;f>w4Ur>3U`3m;owh9;Jl;xFDi~53p0bmKZf{xQ)%xIo zB8x*e0k{00xe7hi;Cut1VdW!Au=9_sg+D)+068kcN;2?Bwwsb5<{c6_uE-vHA+J@cPQxpW&Ee1 zW(;5jFB(nmM;`?DZ}yP;apV+I4lRV!8`{1M_Khb{!KvD>b62rHrn`tIut|xN8W;AA};ffX`VmF zD*)LO&!3~pU;W1<+3as?iGR(z{Vhx3w-WjHaFf8Nlm993|BLvj*BF%(<==# zz$f~L;tHh#O7XR}IAi`?PcZ+ywrZ@K$~d%e?a4b^I%;n3+tH+D^jpb=fgLUlwNQW0 zQL9hKgQ(cgX0}izIxZPbFdZAYCHH~j_q6&|>Q?=A1ME8tOAX)5%wGRt^_5Ex6B5P^ z&>UVo1aF)I?~J~GrnSf2ZX+cLRu^h(-%#0w{-NPA6n9ONX9|M(cF(`0kEsO@pyAT)a_m;A!+j4UK&}OcwY&Y6DV3WJ_ub(*!wb zS=xd@?sbb+TTaDHdY+5Fa@VXl(E~p4_DI>gLuN|JHt5X44p>RbfD*{E>jb&Ff`fiv!KkHl4cT=M!;TvYJc!T)H(K=X& zzqcMf+y(Z=!%6ksnk9{b;_YFw&_dPEYOMcFqqXMv?6FdBl=Fp!okpi2*8Q7--#YeP z8D8R%N9BJs2ES&x?dwV9#(Xe;;&Yl-316)MKLJ$-weGM~uX!Qm4)x7JgvE`UH6WJ| zBZ4Ziurr{7bf5Zw&e~d;&)#L=TrL~I`g?IXXR7zL+T<6myNU_DY>9#?Fu~(}^J6d6 zC@zbv$B`JvJ4C}DtzLF6HXfQ}rKraBwMQ|Xi@nVHPU0Ks>U(&UH*9GlEjd&g9r=gAYA5SoVgr2A)#G#HhyTb9(0q+@mJ zYFXS5VL>IJ4cP8Dyc9Re=FGq?s#<=jI@sV@GCcYDIA5qi(6XBBm70t@rLlc8(jmh0 zFyW|CecaESNWwWFF0F32wog6j;y`9_6|#0qi1BC@Mfu%CQ>$OZV*(d=6-`qtf9^K4 zBG(y0pyx9{4%IbDw_KJ*Ya{Xb+y3Hu9yYQo2?W^p-UIF{y;pb@P3Ef1hR$Ebq+uzuE?3(!dQNp`rFlZ-0Eomyn>w$tFzT`1@L#-^K05>k(IFg){v5c_rZJK|C(+if^3ix)qLd!mUcTcBEZF(WH(*W0^ zICt865Rs_Up!MZy#`*>deE^BFFqS|Z*9tT2kkZjZ9@u5XZ%uw%nyw%;=NpffM_tvP z>v=ybUef>UtoKO_hU5UQF^=L1KDt6PcEcGIb$UP>&Q2aGULaaDBk=fC0M?#UZWm2e zY(F-?zLwyYrg=v?<9M%_ip@XAWkO~|B80sdQ@R|F4Q^S#Qh8Dm)3l)fGOhF|TP8;Z zS0gy0XsRDN(wKL!uNA$?yC(h{|!}=a(GRXd<#<~nK5jYd9MHVJ<<_OaP8@&Nr;)*IEI0u zbR79 z^2ZS0+-0rF8$U*BY|7%yh@Ch9F1mIwp|w}r=DjR?8QHc*8hx-e<7DKj>@%xp?uZT9 zYd#qXgvl8cDYiy}QuzqA__UTa}#q2)jK=#_Q*9@D&QgmV`B~sZ{lgX=C35|N_7C2IsDUBfd|t5uCzg%upsq4|`AyAQeLHPz#T_>Wg*0T&s&bzF+s*ti9U1=A5K>J9lsBtExSn zIA@mF4|X-0$mSEJnV;y%chq`IY{b1;b&CSFVHKhE!o@wg?^+n?nJj-YoAdgfY!c+K zp}t+z1Fv^Ua4x=Z`_7(o(Y=Q@Z)8U2-oEDE?O8%upxBa6^c-_XEL+(`joW%+Ho>h$PK(9f{*VkNw5A>C< zo#m9_6I~o$kXqb%Km&F+TH-h8W(XjZ&w(_A!Fcf3ydyb&PsXP)iDK*1Rn!JN2DZtQ zHC?92>;N{Kl<3{I_k4Pg#aHKA`6}m?8Z~vd{w~u4nxm7niY(O-b2}Q#CP>Lgav-d1bnm;Ut=Jr{ zcTJZelv` z8QtNe`|uxDg8zdZAsDv&fBTw08yXDdhPhIdbbe8NL+$X~;v)ZVKkIL{SwQjnH)f5> ztveOR`Nk1<14*G?|t7$|Hq(ojRk^22fDyXbZFoPJ#8^VtE7 z{DqF#NqWiv-n4CoW-oT~Y9E=SrPHNx4m5=(3FMJojkIxB5669j2#O-jZmI@A3U44A zI_pr}!h^#yN+|Qx^w;-GDq6H-9c3$Kt75}p*1H$8TS<`*jVz-_VM99ds2saYN=X~I z0K7I`WRDDrzsUngr5OcqHN3_(l?#n=7R1uGC_G6oehe=%qDr}y044u6WBH_yZEE2~ z_+4PS|4nY}mY&+|YX%wa0oYztk@;q6h!K%!m~%aZNDR5ooKo6E`13Ri zmgS-MGn#kuL|91Bzue#=oE00&SQ_rfg!~u9GffjNE*` zOU8O4B&^$_0^3#YBz@tgw#akrho8O~_2&J|>R?8WX{g|B-TC~4BGdaj=-|WG1Jy?4 z>{}$w13!XAI{iZu- zR#c?4RyrZ7x{z?Vz2^l^)-$j}GJs7=>W8cT#24SwD7|qrKK6|(%fnLKh_BOjk+aYf zyGTnX5w^v`QvQms?w^gh^<`$ys4a?Bpl)shFd5~_K=uc#u;-Ksyxy%{kq0lS?DD?1 zgt<9gRLpq#^~dR`v4|?}uRq+@_ol*8Q&J1+4zU-`c{9b$#IJo%d=j1`wU-7?oY_2| z09OCBvu@~nA41=52#-S4n1;H4Q5E+uIZJ(iC&}* z-s>7U;cw!`Z3r3qR^@8C=0aOP@qd)BYXg?8p6&1;zbmqkhgJS=(&r3ws z9QRG|&PN*{d*=WE@999KVkDscyl-E>#mSI#qz;N^<}L81#p@0oo(^gXEs;4pri7SH zllrT5EbB^;-5*1}FGI2JL8!r1Ld3bHiSVP)!5+o?wonLeQt@ItCY_tP_>sb&CkOp+gGuW*|!5>zW$Xvkemd9+3|A zCkfN%s)iaeIdqE~`-`3=c!1w@q5>S^Zh**3@bFR11D*TrlB;L~n*!9hTP2 zdS|SvKik@WF&;{9TW_`{)j^;G^w(yho4eFTp?8Q=1SD?W9sP25yu^bnv|>Gzl}%}) zY!6Prp)&R6yUFLCzV~kEt^GK%2Rtoaujc?#4Ztr`U^;c#$K6X#mM{)ZSKEzWRLUhj z{(yhDs05tC7$|;-B%dRdzC$?ErBTZ$;d_KEqO@P!7%_WYarX0PtNY+qUwTIUW< zwIJkwtaau=jmpgmS(o0*XD|Q@7ffD}!A9KOTO<2hXMgOs&feqZUv>8Czug%=gHh`U z1K`O!XVhjE3pp>_1&ZIcE)w-&51>7HR_q8g+mgVY+7vYf8T{x8cHU!9-TdPufy3km z>;(VPA#d(q(soR||I%38VL01|r#_M~wf~Z4JOASOFXD;keq5@8o{QX*8LD^(`Eh{| z2k>=l*GXU#7f6Yz@#^%BIIPqu+11!k-1gf-<(FxEj8oGOu zGTEScW{<`+Y(8sBo-S#q=A`Zy$us#^{j;#vLOc$7&-tj%8pCK;*pGul$(S_mV2$l- z){biS(7-G2*Ig=P733i5BfulVqw%+)u+(eMeUcO?arh4|O#bWH8dV}DvQf$CocZ{B z;F0rgpo|y<4jJ7EZs?M<(2Cyyqh1{44hwd@89W;tr{omAJpJWyDwkTb$3i-Jce(R$ zW#6vH#{%(a0^IR;QbWYlgW!*v{X-A8Z#3z&HR~iUKtXU_Am!E&Vm)GBsL(;%&vI)? z$3Bm(%`=sH8cPbQM1gg=b0dGIdU86ZKVo{+NGB{Z_1sEZ!cMU0R`pr4#{&k&puqV# z_kUC1U<9iPMd_!t05iWiACgV@V8tj^vvc_3^_w#Ep|oF*ox8<#oVkO5X{==R zE%wWE^K}WP=6iWJvT0jWep~cjN`OGUe3Trgpa38Fh*XOen%Svk_ahfw!_6BR(;-O) z?JhJr0g}OM-{ZD=AMqtfz0WSXrY)|1yOh!I6~n{N)i_f^v_;xVA;IWQ&n|><AiG}+E#nEgveH*D; z`HvY-h{@6#@P%vU72Zq2JA1qRb@xu}WXnw@rQ$v6APIb!t+pg;+*b`qAWEx=~B zK{$p!Jnc|R#a7Y_4hQ)ziZ8M)hh$fNNG}J#n@{LQN&nrx510JgM3jO_q^PD*P4cJD zg+*l=mfGu=BAZgqp71yFH$IX$kG(B5d2pg-7qNh#nQVGaPRpCo@g36lLm%#`A0&(U zI(!1sJ4ipVi$1%A-mo^g)o#6DckL(C8a+IEFw_AO-mJZ()9(9hjmJP#!h$GzJ2gKW z$Gf?busebWqkF-c6SRddD}*vtJ8x1mVTYK#)HtS=4}H?T804P2xh@<+_!mpBx@N2t ziC)gT7OU5c8ToVvZIh4G(d%AYt7c6hykcdXXmh4Js?zc9tKh9B$JmS8KkZHSznKU& zKf+(nGRyj2#!|Snj^S{6vT(N^8a!lxC*JFGQedqgv5mAXE$wM@(3f1Kd%?dgbf?GJf={NmT z|HrU*_derUl;9TP?3mO<`Vg#5gNfb~gn%G1c;Qjl=474Ay5pWj|A2}+ZEw@KoP(x& zebv1>bS9M#tteE<^a>J@V1;slyg9d<2=$*$AEkI;!&CqXv1m8E0cKP4#|bXa+&Oe~ z#Y#6~#cT^DUV6K_WKeoX@){RHbDg%DSN8_3F7Z$YUE!1M(s>sw>c)o4>f+gyrl>0d zFk!p^R-UFIqL`@(@%Fx)s+Dc(79ZuyP+@_%yF>`ec2XD|?DO0w5@wtq3Y*B0D3BIz z9>hy8>Hd6P0-ZvxaDh)-X3onqO5)tO@RCikcl_pvGRjvvxW>$v3>!Lwip+IWw}EHb z&7x9zW~8L9UFOKkRt173n~V{Le6+AVfwuVsRDK*#DnqZrW zS6<29V}DUG?yLwa&fi5If7AJP^;J>KV}@Hk#tuN)>Vw~DAb_bE z*K>1f{Tpl^5t9YXJ6(EDy?Qj>E zysV&JCBC;@==|}u3WYDQc_D{BowDN1Rafn!X78jgI??c7sn+<0Vs?UX5R!HjuAJ&0 z^%IS~mBMvi`Tpyqdb!u^)ZrzE+!9<`5*h4m6R~X-gnWbdLoqjZuhx~R^<26MbLQN>Z=xa+9)Ph%DGC!R~JA&@~ZB;+^{cGP)&21 zFK>0adUe1BuF7y`uB5B}3B^EgD}*=Ol-G@xH+;$44G=@n(h zF~bkh!^mKijwP`%{-RR+4d5IfXG=SEt`3IkrN`!OO`g8LH{$(e5vaFAcb1U@cE+{w zLdkS@!3AuQWK6m;$sPDR+INNZ*=e2@(8fZvYk#6hnu-hU-=E9tBB`x66Rk)q@Os4Z z!XaiiH{_qVEe}DTe4VxL3B{QSxO1T5^M!que|KLs?R_nu;>SK4Fw;G{E`F*CP*T z*OC5z5b~~i`3zIC+prl@xJL8e4a4W;oxzAi-hMBX%vSObSMj0_E@fV>Y@_by6_X6( z#>PPnTdc!n>}!er>YueAuG`QGf7e-ca(bw25cJ>28k%@hF~KFWIs}n7B2x=<+})|( zoiW}6)!!T71#6=pAsF&I6!A|Qs>8qZZ4KFyyqW3u&l(B1k6q-aLL0_Bx*4GJPDLrr z$8XN1yjWl5W0n*tn)jj-_f5_P**uR22~m~T)1+zIgxix`)OdPrp9KYfedv30OW6_h zwUcHyU6MOBIHCOHy}U`#Nc?eOh1^4GeC+inYBuZ_m8TtVnzFYk1(h=PP?@>g{x_q; z{`=)V+mxC-?2hb;>EBFEshIvaJOATf{pTflR*IaxB2sG}>vE^%U_lGz$zwf99wD(3~>Gqx}tS=D`L z3c?xQ?%ljmKEcu}nRln5a_nVV`=G}Qd&Xnp=L@R)j5Ur|d$yJ~THX1oFZku?E#Qo9 ztm;SdPIYUwxdSWnUC|sU_WKPcK5yCPHGKP!7+NUrU}F>2(8a<0n62mSs)#O0pR$2*l!F$HB~Kye0j~r_P#~y zjfMHOeftOK-BXlp6psls05z*q1YQWi8Mx?+jXyHJzpCWtNbh>!*fnSBaw)vj#MZO? znr7eoN&{wK`Eycy{Nyhx5g$ZW z>Sx1xmu+R4v2vPx*$%HYd~DBbpS&iLHZd=+g8EqEGaX?gM8x>Ph#~$^Y3#Q-@{fx+s)%8^`nd7z;fO z<+X9nibMV8`ec2N%~i*&F=LH)t_;`NSmZfzG>#VG-)mFIsh_S+aozA*=;(iM*7rgj z5Y%3Fd_P9aYs!f&lAS)G1+RZ!E>eWT=gom}KFW7Q+&t>uX6#k7&Bz9I=+>5YnT-*N zLjdkLW1zH|@rI{a2ajnc(2+E+3% zQ_0ArngNU2g1*%&dZ*Hx@G7d>7+!V9jfyVC4?7r0tgid?LxWRICE2`UoJ(b}J6G<5b(GzDu8u9SA|WlTBBqRj;e zV$UAgX1ASxOk;X9)F)=wi}PcX7kbgheFmv+eYP}mIXDI z1?CKs7D2&~HZD@k-|NwZsqZ5;lSuQ}4EeO2jG0yA+n zUKpIKfxHz4h~cbR??TLk;Gb4dTX#qREI$Fo`d*$*9L`!hnSZ#-a&r7K#o~rzOizQ( zH=^aFZmrfCz)ne_EUrAmH{!DDOcZZ3lqXy`1T|oP!i@JNx&V=@ zBnyIzuG4h3V-Hb+vdkVf^XW)-Do6rt(W0;#nXWwVqCC_OBY@-b0n77g=;uq*%ia&a zi|CG1EarU#rVx9|9fKezkgQS_(&GHYaoVBAEn!$5 zlg90;MGttGPj@dXzN0ddlHYV0vcGkQY~8gAzWZgZQA$ldbmzk6*AErios;|{G3Zpw z!`0^N83=!PF;UdPG)vCMq94VwZeF|@7-arBo^Q+egto3kP)jiaJ2y30_L>I5I z#zJRBWC-UoSAMBJFTm18N;jZ)!C||@V~PZZ8hfmY=i^j@?~f0w_o{#3rMoJ3Pk~`v z;U;|GsgT3`S2$5*8>&$Eq=Mj)I#-h|bm;K&@WfTF^UnGg<;8tzQ16Bs0Zy}R#55mq z6u5R?W4akpJ_R4n+6QYnA5bxiCqOukjrZ}4CZ?$SF$wd_Qr{?7sb3THWyxBG<#1FP zMHPOcXCdW-%4jbtYY!DF?>g-bK@|1wYIGHiFJC?9E5R97)ozoseJKVa&6%*g26#e7 zSIQ2mvsQ}G=gd{sh&nSi-crbRvDoFgykK7Y8}z56wF!aErfjI=D~gv2vBC45jql`g z&Q`T>=qA2=^zNo>K8eL{O^kW^VOpD~hv8bXqE)?CIKwGz-}i}YK11_rmooqpTmFmc zg4^i5&Fe{LGICWV3%W%8^^fb@G#V7%5f7!RtxeRIu1CE^Y<|tfh$zkK5$*z|GC^!?xWqo^Qy6_ z^DyTC;p{l9hcG_!y4#ueNsE_%BI%p4$IMs96_SX%;PdU4hTm@$4$Ei6UJs%kYaeB^_(KQ ztX(2oKrV>dcF1l$xg7K84}{}|BFck98rd+p2YOb(O8MI#UvR6I3r+G6&Czk0obSrU zy(o-jV1zTMPOJHd-7lL#>}~=jJu2zRiiO>4sS^tEFp6SK$qbBe7((t^MxyZ?A1K_* zS~W_D!Mm6g?GgitnK`=YPt{$w`bth>W`r-q48L;R$=V2Ffgf`|huxC$2WqtaxMgf> znH^|fH{mE1;K@EMTix|e+WvkahrE3Ee#i>$7nK|8=Hdow_Hp2HOh<#g;JIa?FEwCc z`$DqeNgC4!sbdXZDE>n(bGKF0)`dan%3RLRoaReBB}tF7eh@XB@>eZDn;wDJ*u}#e zZL|hQah!0S^95F02JBiAud1H$9cekx^v*$0ds&0+<0GIxN3oRXW<)YROqEm8kYauy zv4^J5%>Hl~8Mg)P6d+&hS)amgMj?!DEZXvj^YF66*=cBg4hmyQ9wy!Z8Q;z%(OvvF zZ5t-a#_)vNY&qDfJ4)CV4454Z4QFlNKEDC!CROTI2oCrz$$7VT;bA0qf+M9ig5o>S zH_5pe>$Xb?sEA+!+btS6WP{fCqOi&3X`0)`&HE9%d1*sEoL`C18B|tMI(P{gOPomv zbYc}=)}@r_Y4xF>LdA3%Q1+sGhy~*UZw0M`W}`!&=f9`|YRSz#s61W#LvBmNy@+io zyFct$vVa&s5xN$!YrTT{rizB_mn|JK)v}Kl5o~+AGmlD|p^YmrDi{&M6F_HJqE_O; zvDM+>@BGAMD&Kwhjdu1bKyyx!I3EHAZ4aCiFQ7ZLRl0v}%~SpA z{%bq_bMV17PGpT?8uJD&y2N04SX*%Olurkz@8xjudneRR+N2eHqt%j?Vf#VT?F0w) z9Z6c1tOh**o6@>9bh;9{toDnl6?8rvAJ8lwXFm8v1;T^@4$_8Ebg{&Q2!d_IVH9dB z?HAQp|B=HZ%02)p)jOg3qkmDwh*0KJDIBSkFKAL0YF~=V=hZ_5Kx>=hkX^oCRCAyk zR*?xVFt}4${J*FK5#)|rkWWni|Njy>H3WC!V*;{~!cz)9tZyYPx-aYw3h)zem4~!D z`k&{De18m|crLk|^A*F_|4JPJNXP%yy6yk<*Zwc<=&Kz>f=sXAnHyKJ4D!O{vM)(&m%N2)&9bLp*pEu~rHk~w9Hh=O2S+vF) z=Vnv7t&*?)V+E)4TovqP)@H zm%h$h7QSXm=Ni%ZF4wXrNT=Og-~5LZ%9P-j7s1>$g*swMhgpt2t8DKr)$z07Ymt%f z#C?bR1zo(F7q!33Oxt7yIHW|o(PZ}ksuVz)NW6^c78-217w5$-RzCUG$~< zmm8#6`Z5vX{FYx3@ec>y1>B;I-BG9dH^pqt`M-*4|76yGK(`lBV58H-eWp)`BY$w7 zCo(cl06SRAiK+FGLA7WSK%2k>EOr%D1T}n4k+;rMU3?3bLS~rqR#}{R(j$Ydfo>_z zk)xoqr)D6B1V1o)XMMO{d-va-X?p{mHPKi&DB6ac$IiM2?II4GwiWcnjhz-j#-$r}2sZ?Ln5eLz<9 zE^qXYkgXfV=UrnussHFv>z5`6ACC8y(|>8HoAq99=>>p-hlV)t)5r>G<0-uOt{~c# zZK-^|dR3A`eZ0$+OHaA@Fg55s)fMVjd=R3PLo7IzOac9yfLg3GdG?9&i)uuPjEpV; zCs6i2r5gNf71}w^q!qK)UsP7gK#Rd}=Y>;0r^N~fynQ5s6m7#w`7Vt;ynf-ZXfLEp zoFwu!F7Zy5kdOV3YKrsL<%Y)U(`h3DV{2KaU#X*@j@DGo6;cOr%vK>;3&#&q_t6Z9 z8%8SQf*M#eXs3G)ydLAAq za(7SD$Xzk-%0EkcAo`<8>MePIFCwc66$3)`xbGXM4vVxvl)lAq-LGVpYlUsR81vKg zCy99H$|$~Ijs42g#bF-qVjp4rdGNSP?Btf#E;YF1Mo{uON8`P$FQ%4TM$H~f964_J zI#yll^ADELUfjxVyIF=}R!&137QNF2tXbOP9+R-Bh=cf;eUWdE$^ro?;A}L!i{~o_L=MLFNJnozbkTutx?5SMIBzL{VnA#Vl{YH5w-Zuf|x=WodVj^Nr z;%9>nOe&SXHwi`-jGt$9)tNNW^mH1XFL#L?r7};PO{6~(1x7{S;TZk`f=Slc$d9Kn z!U2O^D!SAlDx$P@fzX_1WQC$j=rn&Gts`2;J@#1i=xw7>TGWL-!t=7r^tz)&=)4!2 zYuYYZ6VaV7rE{CJ^BQp9?2KT-k%5-U-dUyBB9_4%n}0n(+Q;A%`Q^R zM-d_&YRxcNID2vcI2bDv>;po5s%$~OvYmW1A3md4kwtKw4#(y!_IW374d^(!iJiC=k{^M9M~gVHJw+qmGLMHDS8+;+ZEliq~bOvecia)A)hSb05t;$j6Wj2@p|7W>w|7(GQRcL*;UYDX76fj$|6%b&U^a~Oz8-EgKm zeS=3a)sZ>(#poUN<={uY_OTZ&sPLvfN^jq9>Mm%*J(kG^tM~v~rXuzJ%%;}m7LoP* zw=LV;2-XV6q*8&@Maa>bf$#zT-ddA^YMbc@20|^%z((vzOs;PI6*`q~wMt{3UY(p} zY_9Q^$;G_jI%)pw2k$#2=A8wrzNVn)a^55>%MH*NP52T$`ff4Ia3zA zS}l^EiXa!@BG%4{{tRD*ALDkA8>;#vqyQ^LItL12{(HYad#FQkthyw_!AT4;4mPve9D~WD}qx7&S@1OpCbt}-VOOwAn01hv}ZF2AP^v&^ExPiVG zBf#@N{?pJ7VLiD4PdU(q6G1VZQtd0`DCDHvLOTBR(PS6O!le&4s}oA^e4@k-Itzi4 z9qh;_D#jTp@^+|kN>xC_de6n0f|Z)-b>Y^D*fptL(a%&dRH%R z@cAXbJR^%y{XI5ps*2=~!7LX>CdEC&Etv*fvQb&Yztnx7(K_Ua)Gkjpsz{A&6TDZa zB!32p!vy!;Q)wNrX)EQxILH^kg#fBB z-(ixXaJ@F@cSghD08ue;fk_%o!0N~~(zjSmfB)jQm)b`j%+Jw%I->AVtVIsBb74+i z2Z$JaH~}G;3doZZ=L@g-?-q4sQdF{Gs*7)Z2StcoyF5X?nx?~%epI}i9{%%1W-ms^ z82i_tdC)(G&K@*UKHGf;_3y%bB{u74iiI7leX&{s5R}kOhLwke&&$uk@^mTd_gZ+l zz~;ZALlRf9r6HmRv)yvYj{zg0dN`!HcIS>?jJy$cHzj1Vo^q=j>{Z}9q$KRpZUM{= z**XzOsX&wIhce+Rj8m%I=A`Gp%bCO@{~ybl2LraVUwb!1KEuS;rlXV*8%)}_NcQlh zlk^~+SpKXfb*tjBM~zQ55q0O9D9EnVVdD-F_x$UCDu}N!G%!huaA+7{IC`|ty4=x2n!@Ics8x7!E~FaGXN1doYx{qmaWmZ_^F%IopM z?-(8@8r`{RGJXr+Mvtm|evta|^}Y}^wo36rH_2|g>x^4l(>}{tQ^IJ2&Gi(Pgs9i_ zcE_{p%^veV&AxRG6VKn8Zov1%|I#b2nfDhWnqgE+t|*0)*yx$lhI|Ug=c<|1sp z*;yb6T#(}1f?kaHfc!)g%NysBETM{96IPFEO3n8fZ+{?VBiJkzulKWlN1ZL!V8c!e zY%ZGFXB21WI3hjI=dC4ZCIwFq%9=fwW1rT0n%y`=*U;RAVzCh>F^jy}s1{2ZH@~-+vpei9u>b9+2;<=MkFch$P2M$cxb~&MA%OA=1}mLzFeD` z-bcAVBuB0C*{_>bg@w8*J}dhA^5}<`!L4*O(YNd54K|w_eE9CkOS#TyQa7+}WGmUH z*I-A4&68yY-E4Wv6Ih1|rxyT&pg%pzuLY85=(PN12I&_SYW6CW5PUE?g9&od*a(I)VDfBT>#ZE_ z=M^;d%&T}XAKO_&82_6Z0@!Z@zI6LLdYcT&cCak7_5a#nMSq1`3{wjD|lLxdg z%B6_EJgPhtpl$&!KL~owjB{M;_Y9zOiAh7`k$Dm$4syEQF5SKCqf3JXu>x9-8pgnY zS+?&ec^J^5^Q1GO!aM18&F`uKJ#%y08Z$xaS)ZOH%I%nwkG}h z(j5ey+seJ0yvykwEY2Tpe%si;LvKbsSo^9EIxFY63vSfyd>=ck2oVrmdLb#a?-Q@& zC#|a(g=XD3o89a3{zVTSTPJ6+nGr0-?xhs+V9*mz*fP8et*iQ~fJE&14(#L^Y4mi2 zP9Pg4vESSL%^?%i<(QN;6)(8w+H>@wC_IYN8@mLw=+A?kNzZrt7H^1=k#mmM3fa60 z(*#`$;X?zElv(#RV_R=kMuW>&*3u{G)`#1?p5gV;nFKnuzP!oKWGkM?PWIdHLl-g^ zE>~SRh_X@Ko0yQRn6{k$$iW^!{mdf5@DPI2sUir~-EW8*JWyO{4N3Tj;|fi%sK?M) zZBrGFs&d-0M%Fxj@Dpl0)VAPeXQN+Y0QD6kz3|?k#nt8n);|bF)Ff_f)lNFd>%gmm z5vSaIdJy79HPhzJx7Ob{Zul98S;JoTTJF7+U=i(5JhI6A#)2B(k`@TFl9a1?z&vVa zZ+8mP7$5bVZ4Wq3`FqtoUl5iPsT0{!1Yj;k(C*^_nxe^1v4CxMHRvlc$&-a!8CyHGcwow^P`8$RUfpDRsV6!aXK%;q-`E065X)iZ z&JxP2F0uF!bkQzA^UTKNm3)XT!x+-bLg8kpb1uXGA?`iHnp*pQK@=4Wf`CX9f&!u- zpdh_ORJwqG^b!>jG16OrK#<-;M-hlh6Dg726MFA0bO^nZPy>YQSw8PG=Y95`x#r9{ zA7(yjxFA_8>%Q-Q`xRO<6`MKDJ8PC9`07uiQ844{$$ODpcNa)J=<1(BH!`o>+mniP zK4KK?bn(FS=8pC#Zwte|a#g&7E)p3Q6WH90b>Mrm zAFE&y*C!!a*hWC|!*Sp&C$S+1C=1X_9H)(*G;Ur>zk%bblzv5VSuajn1B@%A| zVSSf0GllrR4{kt(J3!MoOHfU^jkXRg=IgY&wj5@Gtut%#>t(naFTC33z(sVUC`NT> zVIF8`7H?bVoIL1Uf++#ivn^qk^@eglv3E|m+!5c+Z-f)FG`1>!vW8C0SfF$Cw`eg( zBF6f1b!Ieu=9h}|ehA9u27k-neqU=IgZhi=5;7RPybV-;?Es7hXO|{J_}vef{&n5r zIem$m0u;w<8u8`fB5)gg08RuSb;-n-XX1t%o*o2?v)H#&>jOHhX+q>k_eJq}VBRkA zVHw#3lg2Nt43U0p^!9?zLXWM=TdqWE&6y_-9<3B8t{Fx>Omkd?yn9Ndc~bGaDrk;$ z?dMq}5K4aF8YrBukcfl4R;o2#=QXvn*vI#mR5wK!K3{Ns@$GOq`6!CHU_1|t<=KWS zHZ;rK2i9Sd1DWmvJv;2Zjk~J7^tXF4s9}{SNHTk3A9ZB`r54 zVVFoB1<-N0W2fVL+p_!L;uYR@vbE`Ef8YNFIIzGG?(U@AHn@LUMC8m}q?ZXcY0e zAXML0Eb2NBn^IVoIyHAfHwVA(QG(~SXpgcUKh zMH4k)vBFEOeFl4wiZ88U$n9kF9ySFYlJ3^HQz6BMh5TnuG$ti z%R?uiFIj9Z5enR*Bz6m!W9j~%z~;*DlJ^v`VpaKk7xg3HQ3i)C)2z03(@M2R%9AKz zZ9Pi23zm=u)1x4;JJ(TUB}u)6S0nSu+ErgK9UH#7$jvZa_U&}lGo_*eqi)}}1Gx&D zZ&>xpPpqwjw(uM}xzznd^{y&Z;TBBc8f{fjA7V8gOIA|*Z)t@bvbv9k>;8R2f{h8h z&`FuqMBQ1(n&&5JzF}wmMiJkZJ{$wD1lMRELOpj|r&%?OO$5nOddMj?9EFr6-sR37 z82|0G|C-$>_wGp5lwv~3iC+doVI@etAs*e+KM|U3ttE0iG*^NOQMj2nR66p#vXO!E zODn~b&yo}lUjUP_|D>CrNhCB#M+!pA}ROxDkQ6|KXx%uQ+`v<#XG3 zUhD(7;ym(7r!9NtmMWl*#Qz<*$B<9&;1QQI8n)ap04OmfRlZ3FRmTk>ij5%ja9eyh;^L^1&=&5+5tmsosrLsi6Yz`t zH?}&I^s386^?KYVwz|#3?q``FI{OAvM1p*tulzmL{rf9GNVB--f10Mo_KSCmnBrV% z2Kjdb3`V|V49u*STc@c9s>c7$I{hyrLO?p{4~F>6Jcza5evtqQw(}?&CTqVf7aU9M zVB2R+sEs%%be#D#nm)6_-X!S!I!95cv`Mq0{d$Ma&6dN3SorX|lr+C|cVXwacGYWOJ`Z^%nrK`h0~HIdgND{A_RlIpZ>NLb*q`@x(8K1)0xLGtcNmfOzM#$>ph@|0Ff!xuFv5E5~;{pD64qgPmm%5p9c_H7L z*xQJ{wobr&@|r`_LnYBa$9i82Rht<PT1fD;T_-vA?k98`^3pm@uY&53l1%OI24ZDS7@&FpUbmdX@ODrDB_t(jf*WB95oci9>gCZT2M=3shUtIO zjsBGqOSL9z*fYF8v&jQY*fl{E3@`BJf(z<|iXWGy+Lc^1c2q_P=MCk`n7qn?2pT75 zCoj`qBdw4(@Vio@;v3L04QL^@b+tx)`}L{U*(hNKxGLVNZAzmAK%3VVcR=0@E%!{A z6cv3K?G?83@cO`ua{c4tRR{gddoOjp)H}DH3IZvuuV>GfOqfIyM0*ic(y@Gb^HcTB zS&%xZYJi8YcAc71aIKX>`9>){mOPQ9CzTxApSo27!1R2GX`RcyIZzTR-YnFR$Tsb5>sQNyVU&VHo$NonWJ=Qp9WFW+}(G<%~&0$f!{ru ze%$&kMtj`?P0Te7e@R?7H&(;Q8r^#Kkzr7zzjhQCKl5tb)*~oV1p@j*4#k1{PRLmlnR=+{HJ_ZfK-s{<C=ak?rNN_hBd^APZ z!?8*(rwpYMOTR8YcqUdSUyn}CUF{9fD@FS&44^zYl{Z`Hq)m;0@|=z_p)hI3x)p|G zxnJ6hc%;QMAC+#MOPaDCT-enS;?Z?5Ob~9}rt)O3Jds;-Gtq9!_3D12Ex=$eCFD1m zFdujh_to@+Om>L4huywuc6G%G-^km9aEZ~)1*!Jl*Ug^v8>hdTGvT@e5#N#HhZ6;T z^xfQbJzg(o3w3h7=rw&z`HSkM9PW$^8}mo5tw*Q@;JiB4nClI5;Q%TLlHc^eIy|tGSyfM&Sh0?ujK!u3u5AF949QFzGE_wqf zHEz{fQflXcJ+Z!CYzQ#%_vArqK8}qYWbv?l$cfFS82l9>LgBVOErY|L=!4w=%l7j> zgnCR7S_;b$JLrP;Wyn)o;++OEW>onbYZ?8dNH3?S;ddyzpAzvHI>vV~MvM9Dw=+Tw zF~^qEHGZvTw*1WxNnukr-j!9IJjvk|e0_Lk-QVJZeR5SSvIKEn#a_vB0|pJAf5P^IqKv_=IhEz4FQA};SuF*5Rwo>x+eeeJIk`TBYpA^1e{Uo zCq5jlJ$V8HiTy9iyFu~99&K%ULwXC2Ay@kyZ-rfa?Z`kwR;q{AI4evjYmsrB^kv!(TkQYd0&)aVrupC}avb~)+!C2WT zEPd{*w@<=4;?lmL5ab}Jc4%+x*#Yj7Q-8&|3mi03QsQ$CTh8n*D}6y*Za$IaF5VCh zho*Id<~P60mGqW0-PHueCKLoogGyf;Wxfw=kJx}-iKMyK$RONN+y2#y5r+*yYl(0@H*C!zz+fqZ59Xm}yN6(;K; znikq8#)7NuJhKEB+;R0iQ-}@I)FXOMn8o-wez8W5%7G`2|4%o-Bo?kjNSID_T7~Id zoWIJ1#2*0q;g*2n&&2*75}aIWdHStB>qB1C5!A||$(u>KHQ$Sv-+8X)3vs6*XDDvn z^%x`$ze!5{1=an4B79l1aF4YmBPV~&yMMu+>e4H%libnYUBR5Nckm8BFpY_g*{wh&?Fv_CViU6WYin=Pr@nl`JPUIJ>(~tq-FdTR)q)6eq{$C@D*r z@L3UBe+VJDNH2VSrf`%mz!A`77;1xgL})HuDtB~-7b+c;+jd1!-E>m35L~eu-qdv| zNxD0sc;DfgmAh*w~v-JoH z)?=A#Nx!_~ zI4yV1o(*QF0inNb@iw!a|Io=nqWna>8~&Q;hIA#w?hIM_qyiBY_l@+>i;d>adp+XrFgC$+tg-AyISI4?-==LLkwyQa$}S}L zfMyRv_j18!j1_-QG04=6byVMEwa@=+NPA-;iug8^k)ChaT(g+W7g6QXh0px_#6@ zhAD50o8CAM&knp?m@wbwFQ4>g%Oou~mW>?RB~?Q{Y$})qr=Sj?k~+kmKzG(y=+yRN z4C-F1^GfNs&BAFS2{HJPyncJJ{ISa63*-`XHIu(1X}^*X?ROBH0Kg!8ov~MItns4K$wc1tUX>f5>~ch5klaCM2( z#t*i%OSWGK7Y$Ck21L3po(a7?_&R>1!I*2G2o=e1s#@g6z%Y|&diMptK=)o?f zqs#g-kvlr`OX%{WP3`#%dZ;zOUyX>3C7UXsE2z=(@Yw_UB+!Aty>YCo`(1fa85{_K zw~u%L*B_DdpbrKPOQ{zsw%^sf3(UtfvW-8#(8@h)C8BZ z`Z@7J;hZgGE;#JbL4W43t!>_;QyJZSM_bDeI~}|gce{RgOuy-7k7c+a&hK-)zlN9N z{|VQssRrzNzP&$I`jz#@Th_2*=&g;Yf`J-9$jFUItsh;EtX;4EMkhvL%zsMpWtlky zFo2N`l#9gk3oI6QFQrduV2u)bhC%n<-{RAjS9xlI`#bae(V1LezXNskDHR&L4NNcD z!FYZwfR|uiKHb62n|ohB(}rl!MgW*f!QPWaL<^Fn2E>?^^B#aLCz^zVC{e*Zh;fFM z)uhl_LJLlq0LYbDt`CZxS0R>h_H@#~Z-GDk?*IazX*y zf!EW4%oUPIiYiUtJ*8g(QO^LAIXn;CtX3YL3HtNszA|1pWFJ*a?h5wj!^g? zWME3o2$cNCXDkpif)sJ>7ew)+-<6(JH?v^<-6>LFfcrgooQzLy^qJ>l6 zxve+f4Mm8X0m9bhUCI%J_{08)jpn2s_@cws*OO@TyONyyTH7(oCp;bE^14u{b*wz& z{_`kNBVk$G8M<0GIXz$yH702+p%d}SdOOM!*H@MLb;i_#;kRDJP1FaxuuvEK;s*cu z*oJGW2hj%NzmLnx3Woz*yb84q#umUVQRq2H5PkO83QlqE%e%=Zn+*?z(!6MKO*dG? zdJfpFjpi#wg=u6nM}a$d%!(eCo-?j_xG3S!rP%u7(rv7u3p04)*h#AXG`deXTHd)nw_y)%3AVo@uO;0wzV)8N8X+wsMiJGZW#9ordwYdwI$nY-=pM&4fc)mEG! zqN5bfD{m8TbjQEt)r;GZU;qME==hi^!cu&xcjkKBQu{tnw)&wIS6t!^3-pdk2dvQi z>4RXBF8Qs`Yq;u{0y)>bE@H|3-B=yW4nazM@cp4Ke|@cI-{TD=;tHF}q?rdCWPiY~2|_Kio>fE`(v zk8rXDHf6^m5!)_bMO&^2;sPaamJP*G^GCr7;^sM zc}?*UEneSvRhQmXilQrDs`Ik7KbJFA8jE>3wxqwsB}lxGQvko+_f>z_I>8>O=~Usi zug-;C2fcboGK{(+A>%l7=Ey*YfooBNegW#ZpRcuDKmV)(<(d#UWJ~My($>a6cbWU` z!f{;GaN1>Tail7?6f_L(|N8E;xQWXLWn94EZfSDSQHird`$G54I%~@dtp#+@&8swX zk0>JL{iQmY?MtoBkDTsLQ31|0ieMAcFH9N<>|iU$RsJ+jCSig|>CRKB^{U|v zi+@PRa%h5EfTKb`Uwm1woNKZ7Huh-y*t=Dbc*n*S5n3)Fl<|qucst&4m)2C4>C*(w zqu#Nl9aB)K0gk~`&81=5vFKLfqUWxW4_z%z&Yj4}CeGH}^WEFtogw-c_uPQ~&Aa-u zmj_AIzaly2vf}w73adp;=D+LRp6B&q(D6sJ&F8#v)BOox`Db=TMC*GJ^+M{(Qqq-d zSH%zCd-D^mO{*#isF2%n8Wp~Bnu!h%YNBdz{y+g9aKK(raH}kFMeUr2QJqxdO(s;n z7k2Y}fkY%;b|sNP=Hi*buMNr80^e?h`xTB!pUMW9P{}C)TNF;E_j)F?@Q*( z-LQKukR`Z+;uDP^M^UKqFvcu<5JHHtumy5X%PGV!$3BDu(u5@PZ1~n7hT9a%g5`Di zz~%B~)|a=v+b6uR$qA(jJn247pVw<$pqDkSD}$;kb1mr+YZb~7Ny;+|P`nS@%bYO( zv1Sp>avzA6a*LFv0cJA)AZaKQkT~-0eYoW3-`D9hURShgcjVE7NdC<;=gAhybK+g0 z^0A@)%si3smA0!TbZRsp6GAObvkWR&s8rTmPFc^O%*(huhriQ7QPA&!$01si5ZHmP zrw5s0ISL+iV-~t&!F*_%Gs!(J!z}z60+8M*ThQwihkXHk3qC#d{r5X#LM)H;_GCa~ z>pF>b>+8%J<;vPcDavOi%DdN5`F3!h5&jy&bgGR|%82jOHrFdEiP2HFElC5&#^{>( zRX0(ecVTB~ep~nMGS`zWiCUGvNI(Yo1XHnl-(@{$hiBB^y5r1|@1*mVrrW90qaCB> zZNbdLG7rdDTz)fC4oij*&}8+^%^KpxEHb zeOLwx43(?nHLa?@|(u9}>cLf!4(pxU_d@f@62{5h}TO_||a9aFqe2{{ho;aoN0nt3~hd&-zHX2tp z-Fp>&i0tk2!cADPk zXTXWzFRJhyTi)$X(@EOGHT7d$HvwZWBaROaBV`qk%M_N86;5gVm-N_l$k(Tm(hh-} z_O_o#-L$ZgaT$*v7at89VM$2%{mX5)eHmnC$d&=uNNNUHQ?6H)C!SWSHcTsC z_nL{IpBp+7ftszEBM*7(0JJQ2Kz_0JNKAPfa;-n56k4x=7NC1yEbkPn_G7C63^!Rc z%~;*|98t(?&Cw8NO#(fef+dqjpVI}y#5kNw5!L&7S9zL}a;$(nhY?zIdFM%#l*0a) z(e$Lns=KX2eOUULPe{D@C#YQX1f&R54HdL3d$QnYQEQUZA-b<%Paph&G3&5Ojs;)d zoT}a8Y{9nYObP0n-&iO=ZkX|MKdJE%U$0sYL0gdH(#_EwQUc%>jP2>?s5`_n;OQcY zvmp*Zk&y#TPlwSd_{rz%uXuFRs}W1R74v3S?@Kt!JmPtukymTUa+G+RsB{!bf55!V zjkpXXe;&^*xNBrE1v8_EPbGy~`eTl9eE_$6jpr{a7lZ&&NI)e?Z)9BgzS4XYbhR+{ zkWt~v{cO)D^Dux;7Hxs(<67SV%|q^zl`+Zla|fxed?hiWYE~hr?Gz~~kj3!9OT!Vp zo`XI!72f`3CqchXfB$GYWqM{}dGiU$WjoyMIKsEe_%zl~zh&bVt;$dsNlP&KxWt?? z00_O*$Yt_J=5RpUMfy%mCYs7~zX47AH&fC2Ep8D`KBa*)A@=9<+bFG73nB#Mwl~!6 z#5_iU=>ylD!JzgFsp1Q_>ZBBmW8CRd9Fup!{#+s>a@Mn_+Ancz?69qxgwg5?%|@e+ z>xtrf`-KF3oBs8?Wk|n9D`pF>GC?h)Zl%NOciW*gim(F=mtOjm@Ip_F=6^ci$Rqq*9_oAoKi@S&2E*;~NT$gdn`34|RdPSI z`gcwT$uq<8)V*2IF#YdbM(r6>n@gUB4a3O^s!A0tIJbT%uCO0K2f6>;omZ;<7`T!? zGo1_j+T>oOdZxdiF9fO2McE&RLANKpS&oNAt^+fM*3&{tzWFxk;tmhx;1vASmt;T3 z(dGhsvT~^Fv$69FjPYI{K>RI#UT&lY7CDUOR1|#7hSn`e;jfFL?Id*paLxZH1ysnq z!2aAj3&0;ShGZ!EN0rvU27hUPQBiyiQU5ADWSMu^-$U$ubln^NhvEg*@rQqVN6`ISt8^1+qU7J7 z?FZ%|o4qGl-v~~d1WYm}`ez;w%_fJ;n!Z+qB~I^0jC(r2W*xGp!pqCiCSS#l!$IXH zg{@$#;>eeahxK1h^|#ZRL4ikKU)?z+9f`S3Jdc{*(WqNY)|_K_;`%(D7j3@enmgJ81Z#Lc& z!=B}Mj(UCVBl}~NeS=Y3-XJ|6dDVMuCWjatO7_Tup*05fB9&Qrf>pJJc^H4cZ(h*yK>$;dtRG<<#^lwPMkp*%T-J3`+c!SAETc+G@kQxNB0IO$BaB;d&hldP17f7 z{f4$=u!zg(&c;FPGl?j*hY74)Ss4DnOaWq(3}$~|X4!Jz{u&Q0FX2OQwT;?p)|TYS zmd|^ZH-4f%`n;dE98rrF{j(zUPv=B}r|*;{oQahUSL1fO&uqmp2o^B|{OHqaWemdr zH3%BSf)Ww1?=t<1>Y)YA)xJnGv)AG-gXAy!Wascy+A3EYSv4u(>3-sAmb1dk&YWMj z{XWxubLR)oR=BJK8Ci3tAbS=~j-TS7OoR8*)K?Mjp%3lnpB`GMaL%8?_m{8o*xuc2}xLCXL4Pl=`JeiW?ic8+~(3 zEgRNpQWAGAbFy`DHRw&_-fc{7gtdTAONl`@Dh+0>#hk`}RWc)%%mD;uY(ID*m=sQF zc!1ax*x{fjQGi`k5DHKTE*~f1?E}(@sCs$BCz2pbuT*m^oT;ldiMiU>P%vi5m9@t^ zcj;tP?1ZIRk{$eI+1qb_c2{C^W>MA9P1G^Pnhk7nB-eRKYOrw_I_0Rb?AF`wbiW~f zeg47f$(rp%N5A1_m27X+pnPkyt(!5OngoD5ta)V^gZo9cX`_bf=-{+MfzFptlqe3qMbAbLJEL4Oi7b2^uc?)Tc%1 zN693fr`%8z54kgrER-O@b!=qSaCdnlr@|HOL?K-x-RDv8LeOWLase?28;|G%7bz=l z|1H6(Ir0$y2`n6rD=k9eOg=T&)6qlg63@u8OQp(^PXE0q`b_-&5e&B_9OLdyFDtazh$C zK1x1U!F-SuEN72L+~{sDH(5KL+WFa`ciGo~V^_tbZ>W3oDBCCDbCS%u)IoqyP@rC z5;}xoe81RZF=qX#wx-ey0ODXAAzv9MCym#N%uUn@n%BG9Z$C6JxLSYim7lzlqvN}` z(;S!d(o0b7BD?b1bVX8mZdNKb$^?tZcAt<_gMWi9UD%I!w52En^cRIoWb-oOjhrl) z!P@w159X{f;kVIJF>{ZvF#5bP+wFB8!VM@0)s_spsBLd&UYn@jtu9y6UxY@-z;vZoHgBv7Z;reO9E_=uVFc?~WCN)=_Yp9vi36yct9{7{e0QsJx#)!BzlX-Z z%Nra^`D2UsPiW?k{);EfYs;y%U}`h^2K&%75V7jRfKAM#sZj7(QxIOqdFkk@+lD&@ z-{Ux5ksccGK{p`vYc0k!t%M-{OQ%Ucz%;_c?gXNASX_()h}^elw({LC9kQ$l28=AI z+1Ir7ELxc{m^aZpP}V+F*(s=zJzVgp7(V%BE%@K~%xuzU^2mIjpP{RzKp4UdCKC5H<*m*lCm0S7;G;yMfvfQKPK?X_d;7&~q=RIfy;T*}dP! zgD+1M-ib2bTkh(2sMd0f*hhbRT@l7f{Wkc*i?te8Y4L^Ghw*D#8E^KZB95QW3+1aP zu=U7j>o0||YvPMEz8q}VS%^iE)_e0M&1<7fQS|1WAqiVZy}?C!(yzq7sEDhsdjJ!! z7_Rp>DxLd_Uvn5p&cYm_KZQI|KI?|LW#x6(oK8RGnu??U%NLs?w#X9(WcR)C(5W9| zCY|f^x%`02siKK~c+69|#51?jj2$N>xD zgUO%11Ch}hz~YLh&qLScWo}dt3G7F-%kA)9)W>T8nD4 zMFy>miY`b4{twK$ztV0gAIH$jrTaOTA<-+I&DV7-V(|Z~s&fA^UHbn)udSVsL=HDW zoH{XMZ*NCNJ613pjnMoLje>$O-{*Lt{NzF7u(7DQ2x+c)@ku!wf%urGIcpc&3XFjz zZ3O}bgWn8|%_`n!cGdJg1xZZu9u~3P)M=`I^Vrvo`fSbMi{ejmhB;<;WxG>U(0M;< zWoov86%rE%{2}@-V()!ERLMz$_cVc$~#QRX%<}m7X)UlKD6-LlF{Oug$RGIU2r>^|K4t$PM^<6S6Kv6{O_*;Saf0(z zRt~)h+Z5i0FQ{FUdJ<7+z0@lw$y3;x?+hxDyJkWPS1l`xd>zDb{@JrL`ucBQF+ZLK zY&i_GNf*Sm{mP!pwCy}RKFtZD%Mwr z^HVQ3sHNs)-lT}EOT75ne>rEluO-8<#-~Ro2G6Pdk$~&M9(L&f?v$v-5+1!rxmr-= zCDfcxd+OZs(_q4W4%Sm3&LZayJT=#5hvW{a8sv3+l4qe9I=j!$E2 z#;c0u4gh)MYGmEG#yHg?FqN4=w{xFb|k~v&{h;j{Td=k%daZe4NSPztc*_H&axo2<5zs9AhCnw7Q zc@X8?EVGrv9UI(Rwl|m6%w~gMG_q0c@sxV<2NYz4>P+?MufBF9=ets)=;TRM|y z!JLsvE`}WOB6$=_pFg_kYPN@>y!4l@eK-UInTMP%RS&iupRefM%@dN>&leX}R-ee> zYV8d{nDiBz{AEpbe@b~lTI zST#xq-0dj&C9B*Ikj;n4tx-Fk70ECooyf&*-OiD1bgTleyA5ZCbE1=)!cy(w0TA?J zvuy!a17H(e1(GKV{E0lG{dw^pRXl|H5!}z9RR=K*`Eoe<`pTjaa7s8qVh$iJs3Rs< z^a)KjITF4}2%2*xJQ;Md>RlI8d36g`|RZ$D+?ddmx{ zKKf*Fix5sxQvg2UHGrDw<3zVRQ{75)?w2{Qcl|c0%SiH?|@;9(g`iUVr6nQY@B{x zqAKW&q$ZtNp#Ca0stN=}emF=LAihhT#WHl~uf=rUzw(io5Fn^vK6q2?ZtJ}945Mo` zf`Qa=VpgIY#nDveYS=O>7t$)lYseUT-S|4~QF!@2&9TmXv=~kFLZCkR`gzbdbWdi~ z*Ty{Xm=@p&qCIgde$c2j@nW--%Iig`NuieL~Vl-HKO4Cq< ziIGIvIJ&e*j(0{^4-x7Z+T@j-z?_H$qy+*u1U_aThEOBg1qk$~?q30oBMlumJjuSn z`pM`N;uq<}vM4P~_o(>ii=h5|E1!<%{r8L`X})uHIxZd&I5atpC)d*FVptJvx9N`;5PM=4nHB(s z&oEqj&@YxcNLOAN=rBMa{yvMyu3gpf{FA`zI_B2N%zyKuD_#FjD_kll;7af>$AN#d zhW^DIY6i@5KcEhPGW&6EDFEqj^LhygsDh5qn}4JKYv3PPtyVyVDvqL29*W%rYI@aK z^ypJeg;BpCavP7{KcSjQHt*xVN*4R8gLb@17YiOA_2C~c{G`sY?1mjSBhSHb`2C)w zZ-7&dOiNhbm%Z>1Yn!|bzvnmJpW(d&o224Ia{E1&4}M0HCV-l_m~k~$*|0`KTgk>W z?eVcZ9W(FmVOhgHrZK&l9vNwuNqPma)`m0|0!xVNmHm9iLYbTB@9pRz96CR?2{5@$CbW)0}Dk z!`rntc6vVX!N}<)RZ(Z;o2i_ykvI+gA?eFwouvyGx+>8#_;+*0sqK zO$J374D;|j)B~knj?WRahvNQ3`nlkU8hd7JgA|!hGH@IIhHG55)^!9knHi2!7f+J%RLnevnM7zTHkcvzsM{ZlluKL?ff_Kbs3BD zlgAo4<4Ug%Yb$!B|I6E`i)+mlsR8Pxe2>-(#-X^x-$H& z!N@|)vbNy4XRON@E$KROQ~Wa0niNRbp}({e0!|RyD$kt=&<++rzG`0;)EF2kxq`JW{HInFCjNWJs) z)bFbzOAiv}XrYNhwxSyBU4Kyt4S76y{Kn^1ZG!W(?=4{B*cqxuVcpObp$%&qa+z+=sRPmo^@Eak}6FP}cK@*7g#TG*#$7@c)& z_Ja7sHEiZZF54*FE22{RWvZgk=UkZT^Ok?uRKU9##mI7O%`Nll>X^J!Rt>fu!J;5RRvZ992-twV%8v#f+*gIx zTPA);f$wqckN9+BxM~K;RMUbZ%B~1jv(8+z!kaPDU6w0`%b;y`0&vUQ`lLqd*?X@E z%7IxQ*zG5_LK4?j=i<6}?KmFuP)D#+9Z%GaBb(%(2sX~2-tJY8VN>W45~ju1(|COR z6alF7;XKUB{kLi$9_gVq_LjEUYZjK&%o$lm{Cb(uGz$V$~wluOLvy6y?7alyExl>NI2R^FN zN0MOgS5hstEi9)g*Py$h7&JYlpIXPMb*Tk5i|b%Anw)_v3?tqfS-!(?-9sAte%dWd zg^G+N3VWS@$3JRVy0WV{4;i;OF)Vi!M+THg-3m=g*xd#cTUbG`)L{H$t^^`tPLM1{ z_=LN^c_IDh!)@q_x!)8f-w(DiHClBNS5)tN<-faU1)Ag z8!pC(%|QLcHZ5r8>6u_&kkjuawk+nuLLVQYsZDD4P#;onsfl0lJ*O%e6UWZ%7!@oI zu1j@R>928}w-1uhB(qexxbqtfa!u(7tiL3@RAmyRRtyG;5o^RCdANIqJkoDLd7sNI21{)A+e&EsBfIbf8D*Rau2{n5_ z6xvU5(-oRLpqPC$c4T0cmRne=eS@()jm0h`r~!$;Z0}O(-s1l@xcYwy82wKa_y1CF zt&~{GM&T#RJ>&S2?Lf5!Z(GM}nuzeftO>Yz}e$YMYXE3ifRBI zT`Pec6M&+qU(h1e!h_Iop^Olf`fm=#Fz`S* ztUg`k%iD!CsXFEz=>t=N&nb`6IewNS>P`g%#5!CYq)1wTbz^x!c7b}-*UlWXa>spi zgq#&&!^VU@91h7*Df8IpH+=-Izf`$MDWIb%mtWtuXZ(xmXYlDfId>Tx1=-OOr$1h$ zp+uT2&wV_{r|gtilT*5GiT2$M``t7cD9@~>1XdElSjUsOB7=R!kC0Q)ikIGoHsGFO z#QW_dc9-7c9nkT^b!$8gpp|M*!KPDmhP46BvmDZ{;-G#RM+Hxaq!}9^WQtEdfux0W!#GO1(#NaYFW&8(2PQh#c3+4`t7k_yc`S=6&( z%@pg!;O1M%5R(<~?hYavoq@I1K4P>-Z19Kx+;5Brsgb>zCnkRgXzi|lSs=cFN4oF+Y+gef@bb<9YXcqNj z^UilCz<*LO)?3Ze9NffBgzZ`^(~4xjQmKv=j0pmZcPz6S*q-`V@U<=pZ8U;qP1SV>ma zob#E_oSV`j)D$CqY7cFr%#@XM#`Lr*ss;gByLInimeh-Av{mf5Rb@B+z4#9p**=q1Lqs@U{&MTu|3)=DM8rYUpnG`~jvcXZ&^EfdrtNVCOzrmtOV@mh}%>s|B zVF99|IB5YZ$(6L!lH1e-4?^{l$aR9b@6)|sCagrm4uKbC@S=6aq9X@c2>McXZxsaTyxbDX6D{>$vm_6Tafxg_Z<>Of@XGv?< zSlOKcA5H{0S;M~?3%R22G1Z?4B|NwX`2X!Go*lY)^{x*yY$--DFT-hXR3`KY)tzlV zSYkEQ#v+&%?%on0BZQPW4}h{xc3S=5`FK=bGH(y z4ZrQznO2|~5J=Ugt|Ps0T!re2m|}7Ar$oTdB!NJ}Ukw`egTovh1=4LTO5S=h0?V|@ z3S6YDi{(g^P!Nv`aN1{Q5z5}D94pzl80%8VnAPCGrs~*H`2l%>Y%lAXzsqkp(h1Yl z7FEYCo8p;Bg8a0F?p32YQIdQ0H9!P_5SfJ>S4ouu7$x| z7ONwRv}5{*TlZ(B39jE#4tn-g+D1^z>_Na<^I!dcaTE4|F2cv&o1b(57szG5@`S!r zZ#}Ee3sG-w#D3j&BJwXXeU#2LXtx%tpLE-mrggbsG@P^|tWq_g0=AbmM``qJqGcqyJ~5|9O)At*xnl31t>_6|%mc&PEilzOx< z%X>vO%Q{z0X<)T_bi1w*M;DXa!&}MlrND20F^_QijZ=%;wu+W|4!EdELPe8bUCT*+ z>b@lsv~R!8xvOWup8e9;Dwokkv9R+r|B4Q{M!n#*Vr^V{-Q1kDS}wd-*8;ex^&lSA z2yEfke!^V&mS z4Vzlw09&bxtqTXg&?$ZV{X6$ct!5K`>T3%X2~F0JRf`A9oa~?0Ki_YbS?{Egk7d=L zg^NCWz~!Isq=58hJd22p`>X4dQcKn}(7JJT#66eA7@!V%bLl4BeaJ)Eq1T`D1_jzHG|C&bV8^a>Emfu^ifqY=YDL1KsFXcO! zo3tit1cF*IL?JFX#XB1i|7!f8+b2q8*nw`L z)Y4a&9?Jo4xCE3tx+~E9w{RH@0uRvKA)mfjM0Y1+E9gdo*0Qf838PHYB|EQ?}5$WdOXdwWA z>iri;rCBfelup%9rS8MmZ#9yZ?YNtD=uUOu0&Ig3(<%jTX#Gi&<)t8IhTnG&-Qj8$ z6A_3K<&PKXd9ejpa|nSm`3m(|NIoY-$>0X~U=c(OH%E$9hCUtT>)!>-%)5{D7vyLf ze&-1YjT{_ljPfkIuWIyTQf|jf^>7Sn09Drn}`` znG1{+8l$W+f|qsA)6w|eMrm3-6w`dwb5QC6>ImRVEtbj2cf`%i&B^Ww#L)Q8&ctdh zFCSRd*1T}D=aQmX)sXKS2DWTMZ~5tX`L!joqveL{PHAP&nMC zXw#jNIw`^l@(`&$$8yeSCdAZ+4+*GI|1dK);JdU5ZiVq=aaRKhSY;^Pz=J%w3^)kFKiJArug zC`vN;F1$|_4c}S-kO8#tDGf|X6$spOAzg!}27||TrQIR}l9L?cRQ8E4;yhY0H;pZ` zI&LFV;_gzH5+diQ(FB8wBJi3eLD_ExyNxP|luXoe8UV08P3Bk>Z5=+K~b= z^@sf#fkxjdL~-TVCp<}${AxX_$8Y3^$0Joh4%d+n@ClCWo=VJfd2O>_W(x*hD#SXY z$wmu`h;!r{uW=CgR_=LWC;PCcFt~Jze<;d?1k;|@n)iOhN@0+((E=t+? zcfW_U^UM2(-*A%bO+tlh+6l=_=*Tg%?h?$% z$G;T0(dMu!ctO)$b&G&*4jvoI?QqvH_&J^SN%_4Yu!hEo&2vCGFkMoJ3Db>{TK1?9 z`d&meEk_&!Q1p{%0p&dtH7bB4{46MzmR(($AIM?7g%sK{>$ZwasLGHg%D~K^Us0ER zLQ7!^W96o+x`B?OGg~yaoTNfJJYu-`2tx~d)`ZKC@LEi%joB_YS+35OP2)3-EaZ?o zZGD5tUW^{)FyA%ii!nVFQ|cAA!8x%Sudu_DGKcR^fPRjb&*B>$nO3{2$>1ejAn_?i z{dGvv_|EOB&-aJ>?kvY@T;vb6#LF+Et+0bE=#DJ$@k5InCi9b46=908qd8%KFI39% z2*K_Q9s;f;!vz_Rf%Q6(ClI0rW;GVoA(>22#adY6c_+URfC$oA_;w9sVHM?*L)h%* z(D_(%`+8-FWRh1mlTH2QZXBi6;5?yxJj1`lmN0CUXj4GG4Cqa^>WgX%7`A({>Id{G zSEBhXhE(=yXZf7R^GqCnm4JREQRFeiEJl7Ny&kskY;sGYw*ToJ)JT=?QcsjL1==~C z{~_p&?T?CP4+gm>SM#6nyuE}~HV9ZH8+qK$VSIa)lYz_9@m=3pc|jqSR#`+ z9}tJL?D2z3WxHZ|escj2s#q@O`@Q6pR(_=@Z#B#9ocA>2(>a!_xyth8uhsp5HrKjf zd$%67y7TWpFbLi)EZo;Pu0al=RMzI8Rabg4aX_#rcx12qT5NMMSmjR|*%fG3%7c4o zmi}hCS1+>ag``j7wsSq2TBir|HFx)>@R+qQiNAK6+Sv9yuIRr;o!<~|?^A9 zF!xD#r4pzNoVZFUOGEFu%&e~Xddv3|z|D^do^`j4j9U=my+nPv_Qqz-aBZt;8Nf@E z+6!X@W~IUP1>XbG?~q*Tz>~-y6H|$gS^xuc-KP-6t6|9d1kpebrgp1WQy2F!PhV;% zZrAM*Ad_@vTv-uC-7~&|)xW2DiO)hV^)2&>Y!coyQit>SnhVxoE}<#M6aF*$HY=F7 zS)%ysIr)UJt!Kz&l9q|Inj-r&r)7ri>E$)lD}P?2&Wg|v@Cp`XkF}TeZXw9qp<4Bm z-q&M(lDe8WzFWGexiF0KPOW|N>j~IGy{diTi}(t@tLv#*|Gn_e>j z#^fR(ga@B#;porlNFD1w1mJrZTALK-*RJ!$z0F-083XjBaZTzv)fi_GCmpyREW2*< zY4Fw`f)z2Jg`cM#yu5pex$9n;dY7@Y@J@eD!pso`F(~4 z+^#8*wkM^{D}BYs_#>0wLfq3Xm=zx0Jslo#MDz}<=Z_R++}ynw9axF=dyLfPY-jwk z>D<3mIycf$$L^@ij-hG-TQ(i>omj)qM)@@fM6+Na13aKx?0{I6qzkR?$CzJHF%ejN z!l}z=@bH22)=nU_t_H_5O-&Vm6xm-H-+bD^FC}c@x-#q$3eBP_PuNT%s+$N=YzjW5 zaEZRTt9LMj{WOwuyP-d{%5q!7%dJ?Pa%c%RCz2B-gE7}Qj`Sgib;NADE`Ju&$LH3L zZFAfJ7Lyc?@DMRQgHGIekE-{03bO&=g$^kllhqBSJZfVxv0P=K!Xdc?z{H0SD-7JI)yG%FTQTF56iNgEobP`Uad% z2RqcAYJv!L)89yf6+oWLXTT>~WkCLDk`4F+4?T(ib3cZ7RT@fK16X|CT{BD-y#!7n zM4SFzYv7fI$x%429@R$5_!UQf=?PFt)Txuomv50GJ8}Jg00V7?kMLzM;6=5b#E^K7 z5dVhL|BqWumk&_?mY0uE#LFNk!Ti{?=0Uz1Fge}myc($EM0;l%ncZ947i z*}rD%|IABlcy{zRFi!JdX#C$KV^MK(bOZq~O^>7h_N6jY&E^}dO!&9U3+3Mu;A~FJ zk7N~1d*1y?<08^h{s(Z5M4?Fgz}!WxY`MJD^AF21ZExxA8_5_QiN{TrbiJh2X0=BI zyM-C1>uNyAxJ)c@)n0V97venyQPgGbwK*JSJLKu`7RzmnCqZbxw;o{Bf7lP)9R9^@ zZ6bR~-!ngtb*F{|QcHSeT-09me)d^%?VCol28g+uLrHZp6$TcFIR|lic`sK`lUMQ( zA&oofPzLGXEUBz5b48AWx7Q9-{`hHS$@D3-%@SQjNTECe?}1iKD*)RBmf9%zBC5?W z)7S86Qw`E8_Gg`Rr(wHBjC>8Bop|O6@cxRT}FP|L=erG*!-+`{23HNO=cSXE)4i-dY&j)-Ga3Bj`-IAmEyu$dz zU5BLFedXvTxw5cw{8it?MqhLg)%N3q%0TfWwz>7}eTH!Ad0>BsYtiPf)muP?ue!Wj zZ9u#LtjV1Pm{d%loOskS4Czb;1O2=1)L)ZuJ3>*h90-s5(KiNaW&e}L4q-#d1TSm1 z{8d(i0$x8mD#{y^i3CMxeV4EBWFGH|% z($M` z6KJWLtdmlibH(50WDBthtGlo^(y62`jA+c0Zkm}JgC<3(i2gd919u0+#@ld)P0rt1 ze;4(1A_5Gedh?jP{hpeEhMzw4?q9 zSs~v(x|6bDH%d0YCL!jUtm%%}mmpXrI9_@kd=>10-)i$j^fl@!dH_nL?sK@#41spx ziVUy9a8HScM03{s4g*v?nZ`FRR&vB^@<*^Jb#Dn3nxD85tz2$;NObjt3!idvn2q!= ztk$I^9KHcR=z7J9m#3JH1sZ(@6|Xidf!VadEY14b!{o=V@WU77XPfpSv-b^Q03UzB zolW6zeus$q37;2+M^NoI<;71kLAaiMRdp?!7c2u)5l^?0X#S)ze}4X>F{jYm^PZ-M z5; z(l+jq_da;AAU+$e`8s)=COsPIMn7F7G!F`OA~O8VVJ^TUrdJEyvTS){%^tHdLN>i> zTb>*p976q-M*Jz{(*y5m0zrJ5+}1T%T|^C821kD_IYydexVG;T(-_-I zGQPBu&Gj8Gip>z%z$k%CmWji77k9X1MZNO-0&K0mpEsM9(GTERL?)GA{}!WA)PknY zUH4jL<=I+XPU*Of$$gJr9qED82M(aAZO^Dn; zYiJ%-z!U!xY4#C1o3i{SyGAK?Uam4aqj>tPN?tFx*Ld%J>9xh%wQJun7374=D{zOB zrocE3vksmTfgk2~KGP;d5B;PKV7Vx#v=I6$awC}S2Q-v)AWLK206yttJF0g|87dw? z?}O!fo~|sylt3qNnE<-?!Lkn4chEc4^tvKZj;^7_+P4Rr`6cMiVQ*YjKqnKjyM40+ z3F#hr3_UivGn=*e9?AyGs9l2Eu2$Ola%!0$6cQLAv_2%g{pKt>CRCNkKE0yHdf~9 zpQ-Q9CIr#Q-_E6raXVL}>o6okoJ)J<|I8+>zxjGqo+7!yBs)N5toixaF7M?6YKX|e zQ}2n%W>={&WGX}=)*YX=^AR3cb&V)mw|kbhKB;4bPv^G zOgZqrr(=K8Y)i6JKVK&CiF*{TOBcIP!@K=lxIXho>6t%`#Nc(Q$X3tj$}d!&IfP=2#>xZ}_ovfo_!pCF*WB z#hzaGDs|k5SW|Qq7Xn{?=L|j?Iw3-arbX`jYX0LAO-(aRV0N{J^E&8K?+B3V(IIp$ zDv8YT7On=B-dLu@8Sw%viKHTY;$=3?0Vkc1B$w+0GB>f{t&`<3m6T>YU&%M}Mcl+> z5;#%1Nj|34M7Y}**i+5npKd^0BV!W^sjg?F)!}@iSTlv+vdiq<$#EH(1cQ(Ou&n&xn~Vkz+YacK$wW@+Qlc7_;PB zG{Pv3;9r=Zgtg*-@@er_hAZj*4dC*yLA^M@17&{Ar*Zk5FK^JdmvB{U{pjAdUyGkZ znCcOl1mI*Aq}2!1L|-^bPZYKau{sU7zav0BM)jV?-e@oQ&^f-RhiK5I78Fx|ro^k2 z`h`tuPTN=hQ4$(rqDrnZ1M|6#jxDJqzR|ulJvukzHqULJRDaS1p=2z-z;Y9sq(du? z`Ug0uHJZnvV7xVOf7TxGbpva133%IqsC)F9u>K;2G;BzDK@teE!=;$jpsTNlJgI_c z4Ja%>e9U*qP`y?KMtD?ZCpWIK`q4wneV z&~w!_1(J@yXR)xl&D8hng0|^(d52^GupzvOS4;md@Nu2i`d6;sOI=AKe&SyS{~ z@wGwIZkS;%`V5DGE~&`=27FGK9XOMk{rw=dZT9PgWYZ}ail0e{4kRNn2!wsoXiL&1 zP5GaX0A*_ams|D>lfswODKjr;8ILFI1G zw`J00Gy#nxZ7tD-n)9+XXJ2(>DLYQNf0NVo#qu4Sx3aPFH)ulwAPC?HM9EKo1>f=O z=m0Yv4M4B>lc?cE7MtwHAg8*I7wT3*`YrR(FZLHIXriM-MDl=RjKHI*5$se5Fj6H% zo&e2OtDp!j_Rb{0Zpj&3F;M7?PWd68%9Fmcq7|ZI^NLo`fbN`4ymb(McmWgD zY=_lTG{X5!y(n+vOGiBD<8rBqT#`?IEPp=sg31$pR)*(kr0GbSkjcsU?n!PDtpsO5p=imcvtVDe335y9P%#V@jM59 zIJDUVm;JD=&Y3=?ZhU74mi|I8seQ-s;>*Zx6?e%5Eo`hx;-2VgOTFh+gr0bc-&-5-El#&IQmpxNP6B_SbfQcyH?;!Q&c z@4#3FpXB^o1ArPF zo(F`5vk0q$|wl}jR%v!9N6-MxA7_;>v*%dlH&7g z8ZjyT9AL+%6pbJV*^O9fD`XB4J?#T&}b|4PnuW=#M(?OcL_nzx|+$SO!heb zK;0`B$@}N;uXRumMs=Lk%FJlUO*eqsc3~*~08Gdk1sb&pDC%24M*5Fu(9_+?8XyGk zN7rMxlrc8?DnLRiX{nKsf3hS7N1%QBneiP!3;tjeI?j#)2EFtk>Nx}fBKmEU?=Jhn zmqWikn&_Y$O+1s+Y8k$BuN^fveyAwU@}F+DKA6v9k&@K{1)};6=w;@re%%rGK%Q+J z^sw`JR_j_;(r>?TL%dlSNK!K<7&un(%Yi{6e4=VypeZk`Rvh2UgEipZvg2~N%Ac)N zL34I&2@Stu6vnA{`lL5gvDVjcuSyo;!W)xQOTVNoId^^G`10d7ERQCEPTt5mW@$KV zn(3p@j;u8*!Kd+umbJL(xOo5k^K)$S%h(O#q-K;jS;L$AEy?`OPZ)6%9urhjff`qO zV$D1K*ycn6AOh`&GlvM7WD5-`whIV0mtkM#te19Cet9DLke0X%GL0mDrZcO34;bg? z1r2_O_sTHkP+v9kjrQ-39%GGbaDy4LRd)}BG-B*reqNeg^ySHFklqL|$y*vJXuqZr z9nBX!n%i((Rv&m4xB1ld#D%iGjD0*U*Jo73F_WYL_5>^bWuobO3IxUYLMhj44$;7s z@%1N|7)z1BtP_OM>-R^jKKMN13f`yotBSJ7{PYpO_sGr}$ z9G!?C6jRURmXY`Hg2-@&Z?x`}cD+NaH)lOIFK8V}p|0}zE4h@l{Cs&=Z#?}oZ!6E? zaKLu(mET+U{i+Pm_qGQ+8PGDDT_HvBFJ}8KDyCcGpsncc_K=bfHh`iykO)M`% zs&Q0xlCrYDYY-h?tz*(0+|E76t@S>yC8YJ~Guuy9H($N>QA)NslE%#izxA8kgQmAl zFZ(K(3utt&dHJ{QoRyL9mIsRG0AUjvbLS2^A6qAiTpdowHL(eQG4FTX%Q`vICCBMc z9G391_oU^g=JzbmsM-5e_+IbcqO%!b2}qo8?rt&a+_$RhR5n|AFVH>e-q&=4>BkQt zus&R8zmd4zepg@HWe~{tdYtB&bVo;jryW#uoV|@%6#wDK9xReNs#g4P=_>o$$3I59 zi9)A6bdjsnn}r5(e3w7Pd>zP72so>b^(DP*8V-S^9fFyv3ESSw_3zbu!zIrNaM})^ zdn3JqxmGT%t#a2-U$;a(%qP$)K7tllLe#=$O({VzOLw}nAZkfIQN-GG@^Fv4U(w_e z1GG9wSaUg17pHj%Uu8Av27{Mx$of_e+BpZT-Z6`Kbi-zam(Ivq)e6%|+!YJHNr-f= ze=o^B75Y^%B0*@bhbIY9B&bWIa%H%JRfun)cC;F(o6nLh~?|O%Kl7eZ8|^ zRNu|KvkqL`xaz0Y<#Fvm%H%a&2WsW^6nm}-T6>c)=DKLk_3t1 zY|MLa8x%bA8XRBXly!CDFosqiFSrnG?v%R!rWP8v$KXMnKwaj9SCXFU;ZdH022(og5|UAJtlv^xd*~TZ=Xf zmT2RT8;77VRh>xtcjZ72AR-Tn*u}bSL2vt+Ssi(POdpuBD&p+9;&wynlp&48zhI3h z+?Zz>H<_Q7&#|b3HJ5PCt=xV%Y)ylJzn_l1C61C^j*G>!$uAVJjm1-TwTLDy9mhaE zUP7upa83A5OzAJ*R)*J|7`7g5WJK#=}bH_F+|t8j;bmOs|I zOguJCzD+)g{NrrByP8&DoQFAxe%Hqo$-ar7GAqTUFX^V9{q&}zn`sg1#-mo0+4n{d z4)rJMJ?6@7|KUhBG&rO%RqXIw;Rcd$L|+qlhnPzayjr6I{xZc})AJJUu zs{Xx6mSKVXyF2g@Veb};6mao18GTVxb9ze^7N7m4Uc#a;w8`CsJ4J4$cL6dmvEm`) zqhGo{6oukXAU{1+y3~Gz8^=rGh37}q%WFr0HSMwNDkwP^cy}7Y5L=%8wJI-&1x9y> z7)9V+SXA{?#h^&lzbsQ>!ZjlHN^AXNPr}w@11f4dxx@~CQ`(55s-;HP-=>Y_yaUzF zRfCCNj-NGS+%xS#kRYj|JF%$K?~CZ1B7C_*1)u!N;&Z0K&u@~14{ZD3E;w(@edjEO zQ}(p!690U-q(j_Eri_asa_uDC{1|k>ub^}O2Ch7{IGwf8@wu@-Y{GC4)cX^{!_x05 zvceR#Kd*Ye%7pM1yY_%n;3?CNof4wmE|+7!yGESqrs*$FO`b4}4nPgmS^p${;Tyv_KI;a#VY_3e4Ga@X<2X$=mrxIlM{<{!Z|Ybrv> z@nwSO;(Ebbdfzu^ZEn2s6z~dZ35Xjlbn-ocfM}!w4XWO`xRhxjT^AE?#%j*gmEWIp zo*tBpOp;#=R3Fu!cR~Cjz77p5d~$5e)r|e+RKP5|ODR!~>|^4Y;oPNg>k@-q?6 z_Bz!T4SJtma<-qRzZHFePBcL|fH?@YtJE0{WLHw=O0m?Zn-H_hI-L@7y(^l*6Nhq- zy&Y{ubrBCvXP`xR|70UW6nl|so~yPqWjvi6XG=*(aKhMZSfZH}{AQ4zE2y#{RVZNk ztqQ9^XqL}y9G^kl(wdl@Tb<%!nXN49?OR|uA5=prSOVbc*&ZNaPK3&KlyNiS_zQ>1 zbkH;?(=X-SF=vZ0dHw`bgnnun_jl;`^q+rsxYnBK^Uo~A(tm*byb&=U=NR873T0_i zNwqbFVgT9-@&i%vK!JBzSRGRrDE!x)6Ure?H{n)F@1y;= zRFI--N3ir!*S|RoX^v6bZfpRf0T6#eJdjoTZ@0I#4%F41a(BTGCE&~7;aJUM*+>9I zzXULy)}J)pe|2}iAD4o&MbRg~)Z$mb#`vWV0OW^a4nRbo=r<(vH&q)K7s+9k#pNM=1Fy6pw2Q0Qw+NRpe-nYj!*cLT ztX}RoY24oW!4yb^aMF=Bhwi{!HSjN?|8ld|*CD1+Z10Z$*Il}3wn0Bk`;M$}`l3ul zpkrZB8q?2UOWYB(b}Lc|^Tn_EhOMT+hzOJz(BQ zPbFC!idZ*8D~S|t>bW-N=U!RR`h>`|qu=R|O|OE=!6bjV?NQ~n%wZ6WU6v5?&6h`I z_5PWkaFQ9Pkg0(08j`OqLVo3|ZeHXlmu&+xN%4YuWV7y9JH{Ms>aJdJx1nrZc+EgN z?lGuulic__?Lj2L&Uf88(LTQ8xrfv$26Z8K3*NdD7e;=PENTz>=qZy#Gp&iiNaN@} z-g6U0%H>5_x;Aj-eU*F-f>qnUna{YE;f)w8CX?{x5%B!rUnY=~yHP><)JxAjMnaM_ zKuF>n&t-~iqY)+Il}D>R^%Y84N%L#Z5-1unO|SFZVrKP^ z1y_-J+YDR#1B7_YhK_fUhGK2XeB4)DE;E;3G2N9YvP z6Of$hY9hgwu`P3}c*ps==h@f`sy*1OnpP+CSGl=!HDM(#f_eQk=JKp1G#&gmNRcG3 zZtxYDnipboUhZ9vGt+k8mvZbXIr+kavsk+LtQz7jG_*)gH(A!>v)gjZR(92zd?uIi z^EIhlIXzKtx7oTan-%j@?qe9vwao3C-jc59Qx|YeeCyMTpWnaj&Dia}B)>7kh93-$ zqYlOUqv%VY;G6XuNW;V#}41U_yCclQzSw+ecdLZLRr+XJSnbvM6`SlF8mmkwa?-+Az z`t_Bc9w+>OGZL~p%*XZ3pE0VPiS}oVZT22AMTe1%;t?vh{bD|aix)mFiG5EeTu+uT zU;7>TIqO!)=za`^$AZEbe6`K6^WJq~9dh-WeR+(Yc;Atn;zzxf&f2xo#{<1Rb8;I` z&>rP}!6#4f@Q5<0-i{a<=Z)n)`j5}!>zK1ue|uRf!-4|DSRBg$*_`&US?sp-jkJI% zaLY9d4Z-@z+v~f<@N1DL;J8}$+Ta^w-;s^%E}|737bhk68#dVy)-}s%R^cT5B(J;u zR#yIj98YynYQ0IBEK)>OMu?b)WFMN+oJmr11Cd+oy_qWHcN zS6ll=T3YHK*R^oroG`6sMzYURi&~EfJ7B?Xt3$J^IK7U48TZ`kU8~S4PsR?fbN5oD zcAnNt!@ZlAzxoK^&Id{2-E`CS36MpRz`|+y9oWZZlg-%Yf}c-ABnXBB-&VrQcfJAF zk6<6cI|f-Tv+00;xX@K zs~1Dr3mtdb738xduTJDXHFP`~)g=-6l5oO`MaWxb9d%j{k*$<>-ILh7K%pHP z^sx#kVc^MRZew!bk4G2euy2g->Aq#$rj5qOI9cM+>_U}D%SjsiD5J9zN+jQ06ed$U1e3xvv z0_em`YDf_$*LkJn-C?6s%i|Dfc4nr#38rzXb-6crH3LY2i{%x?ekp&9VqNkHzp4+seDipV!SQP zbh_RTSF_GH$m-`n8&2zOHd*cZ@#wWnZLU1c9dtB|s@ijCQRY0XZSq-iyRsg9cXFpD zdcaYUnHmSWr&G9<;d6+HLhQNvC~n1Q9FYg2iula9XRd?J%RizSvO?us$$CN3-!>C0$n?6h6G~rp z;nlZH8H82c%Wh-tXOR0Dl?jOSZI|TDR}twBrg@2RrdWGgeAgTEsq}}+e`)3Sbkk!6(Mq`tC4RsXbHU(_iah3<|XBWsCQa&=Dj?`ybtK?2P|P!>z*sM zN1!&Ld?)1b+uj*>%;UA`mN=|>PXr9SIHO27DtO@E_N0fUv)^Byp*c2bA5e;$H4FVr zs93Yn+Nke!-WneD;$Mt)Z5n9sS(0dFk#QxnAemg8W{0e-Vf9%0X()e#U}~E7?*s$u znw=HSZ?PB3`M_-7_fEveiTVPICq}Y4(U}cAx!k^Q5>KmjSG~DTB9t70FF5esa0y=4 z7R$Q_=I2;xJ%}3={`6CB-EU)Eq?izhKj+-|CguCQ-{-cHpGxtxGpU);HfPcdC6Z~u z5`98N(krs)67g+f{&a$@`qHImHoq*A$dtsqV8Kvj?|>D zN$jme{f+iZm=#3PgT==xLgzn5=h|~6an+}oNi3U2QW=f$HsulwGbC>M2Q&M1Tx z>Ji$GgheL6cE4+MGo7t!$HbAXdd>QSEBBg08facL25k~uk@=L*x9?(=U*7mVS0L$f zG;_}e90oUES$u6oTII{O7w>|%Dl=VuutG_XF3s-Z*r0M+13PF70YW>>d<_MDD&D9G z>J;yQDzib=cYor~(4;=BtMd^2RVDD!p+Yl?R_^6`-8bR))d%)X%3P?U zRN>g&0R!CO*QkPcu?X5rD~0hiLjlW&U7p#M2Zi{R(Q60Xe>^;9xed<@Vie~d%r5_T6LbOcdv#38;mjCQ8f_io}81Hw5Ez|3+l2g4jqo_IzYWyqNX$fPc+F&aE36@pncU`EP*>Y5y(z`}DUzpHNoQ#GC$T5bBiY zGXjPW-{EsAM{^T@UsQX@=f9{2>JsXxJ+Gu%%@Ia7#V<#3{`&VD!Ti0o%ipq56VM13 zZ>Uo-#QJ~HK`Zfp62C?p``51LM^snimMFr5b~=o*C=+i;X@0J2WY5jNJ^$tUSNOX? zxHKzqvH(c-_gHqN<{13(s}|P=dN`2}z*<}VUhxdn1q*XMPmw5TR=owUA zp-`*f{Vk52qf)>+rz55D$+`Z@jo#CHFf2xT-tS+2SYR}HAP%3=X-D7%*WbLr?7rCa zLq}B)UoC%kH^|wYf&sl!1%h_yTD(D1!2o>wilOQPGJ#MueAzmkIt8%tHmktt;^A{Q zQRsWN*?LD^16~uWx?ZVqs(97PWSn#E^P3k>AK1%ZXxtIez-x~!-d(!#HMD1nUwOFY z_5AM`ux1)t&}oK@R*xawVVNF3RrN;@(}DE*i;exEdCjPpvSEW%7x*t?0ro=t691y8 zpI&i#YhhmQ_4N;}K0=>7mHo<%7HZSWW#28ykls}qs6gkwu*E$cvH-5dV3TuBlA)Bn zRpFcv;}7;Cr?mpKSxe2Ar-Ja=X2`x0*q+z6`_!q(ERwHG(3QKMA8hmrw(fWNi)_6 zPgGeW2o`K5l2yi*D<_Zx%wc=X#EXQ^O8ZqIr7Y(@s+8bvCh%6os8{5~_vR-!?d3D( z;{Ho=*7aM~d$?;9iA7#>-!_t(^=lM$rsiUnZG^W}vbZVZ}(bp6=^Jq6}LvS3D>jap7V5GHZD@C`2u84FZ#NvcunwqUWJcFQn-apuGd# zu)P0aHKkxZU_v-g$8ZPUzP7Ho!FOBZ4r#=MSGDG$<`$H}X=K)7b6%#(vF^+F#Gfu- z{nN#5tTQY1{Ya4Y)l!*F$++VxVkCJLe-8u6ubUBPAY!}X}&YRRTWw(Eucua@oiRKgO-9iy; z5QB>tF2@u1&MNM3i3dE|d3XT5dG2f*9lq>DCPHc5T|6yB zw~spZp!vo3Y4$g_9=pqN@R-P?U=3K-r=@$s4`+r&44S+ID*QcaS8u8+rs^`XdMlS$ zEG4xz5I&UmGJ3chQ7Lb;^F`hkcP|Ms7;BFu8l>)QYCu8@jKp*Ft65X-c{`m7&dw+V z)`g*s=~iNQa`<`Y?{{PqE|uyWQe-ZEC+LH{RN=CM2AHMc{2v z4|L7+W#N9I%I2RmX%xoh3;dbmQKt`_aRx`{fd1QXhVNH{x_mLL`Dxgm#_BqDkdx;~ z==Q@`sW^pG{N;nG0snuRCd5XHAR0>m>b*Wxm>}M zdekpc_JUP;k;lT+`Pf2HRx8~vwPEpmyc<~C7IHB znUt0gFi}9|4b6e0Td5k~aIdZ1I$e0ZqMh99C%&>jQ zV?rJ=%qCh(I6U~h6|A7@?}jzSA{$OEvu!`_{pQDd;okv}G~9d2=S;C4r&XkkRr=h(Y_cB+G-ogvd&pjbz8*C2z1n;kj){_hPhS%;#4B*VW+{=06WlpRVK$_6Z*zGayomRt4o9K2@HT z_^~nDZBF2Yu{!B)ww*e0-W0!Jl-%+sQn_m}N^Jo;Q<>a{p;Qfm$%%(|R{Bcp^GKyr z*eo_Z7-$VN!wl8Sq10#b8C7R2k@;@HbG8AJ5dU^``S^h{mrvWJ2bS#1%E{3Kc8c&1 z<|KYZXmZ>Z$r8S2 z)@^(-4li<>$&yhen2s#rijpzWtLDOaVt8jxFmjt*j1b{pedmP7EU~GT+((1w1+Pi^&I!oX=BF2>c z9Z-&ztIBAO8jW4e$@QwQs3;y?;We;3T-!bFcdq!uOVvkY{eLoxS&F~~RH$E=0GLT@ zDigjelMkP!z+WmJBKbY`uQ|F{*OLf*va~8Xf+~&Kjat#!)tSkDc1rnP+ct*3JxUbD zIYWAl-KlFp2?S1lq%#EG?}alX*8w=!Jq;nqDB^_XJ8jQt;Rh~*Kt=RdK=%J{{*TdV zD`FD7#PS1hzlXV7vP5P~ca^GC^|r$E1JfWBxz!^B>vY zR^xx?8h*d|Z+5x=?5gVjRWh274Ap`*G5sHuLT;Upqe}E`knzyaM=x7TsNi&P!xHrP77zvok5w)pdtdT-J*}rYzg^d51 zx1YPwMlC5-umEe&dMeB2WwWOqZ1jY#<@r{;QK-sS_2(`kW2yj~G6a6UeoPCHR_|W_ z4Vue)2h2X@k$5MMg)s5eT6lAo;<#a zPb(jF;jTg?F(Dy5QpgA)J)@@R>E7J0n64kgpBPvR-clxBsH=`X!t;+9yx<<;5vR5} zSPV0JtV))W4?=ue?5n9a|0sW?rqa31WSycQ@s9CwTxlo=oadudo4)Ny&X-~*tk&hK z62r$g!i6@B&UU_c(r@4vY7??K zs(iSucAmLLRvv*4a<`vN*0ReH_V{nDXTke#GJO?o)DIB7!tbXFQKu{C;2}c^sl^R3 zAW~q1@ElKzPkW;$OFh}6*vD3*C7Y1=5fub;u>FahZb>YF`3mLOZC&KWBA=s>Fi;gG zjA@jUJWYasKt^=tIZrFZlIO*_agbNUvG5Bp9;XGl;P4Jelq*b(Hpp!!C@$4Qv}8FE zbY&^PV0#0aw+O$Dw)Ez7i(g{P|3mEP%4>~eR!$jb?SQaClqulx(wSimqt1s@Xjq+Xrm_Zmc7^SYILCi!z`cm`6;nw`xV*$=8fjmdKt#q;tquJ~oefeIOu4w?tizqdxS4cOQ|xsY zO@40`pyWaWL~b~f)w* z`Xvc_tb+In;)gSjq4!`nRahE^&sWWHvluQph);wp6D?7kZ-EKD##zrRMFh*=%+>&R z)J_r072X3Up?WRdf&=Z*;eN=WtUK2FgAL5g_M=YpmA&53DDClAp`K`(9#tXHNNXD~ zZ}h;57jfZZaHl1%U=}Bnen0x+ac}0cP7}{@O@)V?6Ji=B1&*=s0qW&O?;)dzvc4%73&xTTfSNFR0qI_WDOScaOkU0b zKj!K26&!l*11oSKl3lq8i7Zklo}b&8;~OzANm+2`Hwg=7q6*(+P%mM+GsvHtXUg^3 z>(0BgS(3nJo%GslByl75thm1rC*Sbt7cI{frHtw1ubUj!ZK+BYY(C<|TRH?L?7ZR# zjsd=K@k!vs4?n6(sde>X78_Hg9Lcgw15Yp(A*}@}K?I*UICa5N9!FkxQn-DVe^= z&t%}fJLb-j@0Z}0)r+}^ifp5(C^R`jQz10P*0J)LHY9yU&9Y7WkQ6s(F7Yhy{I&6s zHU;afS{Hc2URqfrf;m30f>~6H$d<1y)x`P-EN;VBV&FGOBh(ckg-u(c#Q>q^d(ZEI z?}q&(o&J1HH}~U$^OhmI!ytV$G8sKl?p=d zUpJ@UHPd{2uJe;H1V4}kHhyD`zl3f&i$Izh3VFQd2%ytDYkeE?W*_AEBod~-E?MyA zhbharYv&9RfDJSn_6u0@Xn?}5f0jJ}<)1qfOcT0=RG;gZA8EMRy2LxB(|T5avy$f? zlsD8u=$4M>Z;&e!+P28uT-ze#3tCjJ-c z2cc%iKnQ@jH!FXGO2XlIw&2>8pwMYM*9!JVOP*9`X>t{f;M^jQUDe^5M}+r>vo~H6 zy2g+2I*5tldr8^wU)45geo{}zZH~z0omGh|ox^;J7bDnF%cgWcw;oo14NYTf5iFLt?ykG#Fxt0g_Jcfw* zL9jhV&p_eaJ6U~})o)pZ!M`CHtH$EG81dksL7*UyoUqimU=lKWOKoxtrzAX^V3~S4ln$JgUi;a1qN3TKPe5;Kb*4 z8sv0?+?c=yONdslGx%ocwc_+K+>72$@wnZS++Kk_4f;OQsAs?dLMnI^vy^_}uO1l7 z)no5c%Krkm(kXB7ISq`_-i8(b-9)O&z3K-c;j-BgV(22|K~9;d>P5MDht!+ZMyJw! zNC1DmQSR0$t-rNe$GLp4^}(e}M`y&=GU^(};9KU=h&zUrr}}XFV;Do)NXr0H29NEw zT(g39o$TG5dLy#CTl1p&oM|-$K)ZHc2IpF{R~UF7Hu>r%vv5OhJucM?3X%xK0E!$| zo!Z5VB(~*PgQen&iL>~Qd6I3I#gd!!Q;xk!ElR5QUkq@wUlyWk>|Qo+< zlWPu8hMkCMGc@Oc!hmNQG-!Zl%uie^ z7OlQ$ghU7*Y%+778;N=uKAC^(^F`>$XgMN=WsT&kq zfkkg7U6nah;$Ga~`1(mTFm0JL^E8aUGIcNWsxL0&~a44&dFWiYOPodhwW~4;t zSZD%N^Cqj61FXT=TYEEUJd9YC^mFEt6MwEkL3c?D~GlNWNM zIC;z!wy~J*6LHyLN-~8>vpA7gorag(q;>l2S(tzIH-vhRMl3`$F2LTtw_6*}IRSc@ zbnw1YYn{c^fb{mILF&mtJ5*v(_;+?al}GPwh$4eQ`3l#XdRchT*Li82u8GL{?V5kN z@Kc)bHbCm*<$cFRt80P?=6)P(Gr6qrFxVlcuIaop=#h~uoI*b{?TvftYfYL!^XJw~ zZ}+W+m-tyUGdG6~pwFtTi_~xXx{YYGanmPQ-fvjlj?!>}Mv^q~F#^q}kB*z74`(Ey z$}8n11z#VgIsFu0cG=Sco>x5Laii+BN@LWuZl@qIlW10p(a-HJ5aOd=JZH10_IW7Z zkxL6ZVJ7oaDtS?Fnb$-yxA1`DJ^B$CI(=2s{TfudDq$;kBg+NX}NGW zJQ1R>X@hkWL%7Q`(@ETur3}nKv}-yRsa;b4rKx3AM@t6n!uy3~A&+MK{eB&e1#>xH z{Kb>8v}{4VRaBUOw^L>#4JQ^L4e+Yn<$hgqE52g`5e+XIl(16G$ZU7JXGW63y$9O4 z^344*aEAx1a*I3P2z+#rA|iSwrhdCoQ6j`;28?Cs_qwqVB2WRlih1ThMp6Fa!62Ml z92(JXwc0bt;OAtfW{J~4u3N{wyK!4`EDO9iFjx180PIvebJFUzR*Ul?Q>??R;r))L zE*-%I3D?M6K81d@YTjBG2t5$>!*I$sd~^35AL)QmERk+cFr1rKHu1bnQ5*3`x_2il z15WV47ccX2;9m~ns~R!3p$e&-7IB2_R>q$8%7CwEi9uq3N75+bc4tp6MCF&MV2b3_ z)U=x6?NtU!qY#>ohmScWwhY0ord80}^|G_H)+)DmiIsRZ2d6k&&;8c?Mn;uC?KXD0 zM0bapo9eR9bmkC?Y?))P7!U^*{#t=!IS1DC>VMW$QjjKUSkABgz^DZ=%JgkZDq1;o zZmc{cRKy+7>5hfETU8x%o}*kBq!>NBaH0P6ZQ(_Lu+}g}2f=)jj{%$@Ro@XcwdMr-fe9CNOd|L6S#S@oQ~d&L7xu_a?^32XNC5~az{;I)twf2u)rnCP&?tQpHNVYL4N z!MFZ&QD$&bmIM3F5bcbk{0*A#ygt2U{22|ty?$q}y@5`D)Mk9CWzxF0*Pk@?RauFY z59Gm`$dZRb7CxkS^25GB(#ef)Yg$11*eicxz}93dc)ia`-pu%}%2wJzWe%;fe;#kq zo2;v8Wn0;$tJ-7u;Pm>c#rSqF;w9(MQp@#m*OO}DsjjC&pNC`p#ke+-FHyV{AI#I) ztv1w>Jo=@tSr2`6M)$CHYtc%@{~o`vK{f!voS5D{gB*Z?fgZxj7Jt(N^$&Iy|86Y& zKQSe~;y&*KL9ex^oiGBM?(joY`af!=K^`3c@j9^3Kf0YE{;``0eerL$#6*8Mo*DiT zq2VL+H_+*%HSk{n8o8kK;0J(r@!u@XiW~o6W%x**FXMz{tv;{Qoul-WXmoZ%wH1q3FTDmjqekgvaW@O`Z9#$}?FH(CeD{%Q}VKvRnX1p?Ty5{ot4n?&J zo8Th}Ix)c|Po3WGp~D?B!jo5c#RKXM`0`%?+eR!&g~a$QNC~*>74S~~-X=VNwhL_K z9Ytcfa7ebruOR&QpI0GErtqUX7cX=GM*VOa3OwG(2R^Wao&qa%(rPEknXGKB+#N3u zizgTGq$EgyGwl0{$QXZvay0k=EX!JFdpduvZEj@<@47#V-+-jNMU9Qk??NCy%8ua) z7$|_qmw#eE6YBZlN8`WHC!*c=6Y>myUT!J5BGYV1azu078RkuI48_LlVtS0_nZT1M zx%4U5WIbuUL3Kif8%1lx4R{OmW>z}>yT1G0$3~mMhlqLGy4V3gAh9C8U0l#!Ya`h| zW;OVNV*O)|Y%6cA$m^q=JKP2vmcYMi`M{!RRj)Ml!;iJ;&0wH}VeDke1j1{4y#BEkZu zkTP1BJP!jN%|xbtwHL*b4nd~}B~kFR`pOGwW`}G7n}HpC-O3QvnsVQKztS&j#=k0o_s7lROreLNR9R}*Soh_M z3wjTx4n&skp%0vC_Le8EE$(ANX%Gd+l{uXlS$pYsq7J#cqoqeLCNKy}H!HvFI0LoYR1zJxp@pF!Fu$*LTMBsY>L?fn`@Y^{gTrr3vFdzHm#n$V_Wv& zM2vy0CzZ2n>)@xfCu#y*E>cM4CN#B!5^8q&bW@$^SC#!}GAXt}FYhp;UgBR{``0J6xIeyHp(Yw?s}rmcSy|#S)S<5Fmg+fB`c}|)m_%`OlsVoPo;(gb{PSVc7xxNuvLyG4Lj#!-tH(>WY zAM}@g>X2Ae^GSdsQTMvJH6~u%HYVBYCPjs`@eO=9y&tnd?EIjd`6L3z^{ZLML*ZfR zsHf5NidoLE2^EY1LmfcS`#_t65eIhgA^-sQ2%K+l9;lq~kut$Z0j7#KA?xEep!;Jo z2oTeOX7>sm-YNQw!*P|eoD2R5zVUpgJIA9s;v5U`NdjX5$=!ka4=IX(^VqmZv<&B~ zv8-DL1_O5#S-)nGB$Fi$VNJb9fs}54le+$~H?w=Sa*2oMdjuMIx5|De{tF-SNYFSO zs#RN_cciJl2Mqi9?E`?lv~;%l-`uQ5k>$&lbbdwXP;0HR3=qXRnhyx!ohQgO3Ie8% zj{v9u9h!MwJ;5PrW9bz?*z~@us(qnbfkJ)QTj$Z@ zyH3)+J&Q_fbB(<+E&Z!O(+Y{|%Y%Jxs$gY1Tt>gUd^XC!k*ygRU*q=i-+6{`3-j0e z8)a8386kTH0R7M$Azh47#~Cmv;uQe$--t``%x&PSXc-h1);WPM1vQUv&oPgE-#5aq<9Zdo}0>@{P5_`Io`(%03c(Iy6w2M582%_jCgl_Z8*r99KWQ z62xUfH*^%QCIQ^77$D#<(w|c(DxAd*kU&JKE-%fmyT`>A8NJ7BQW|hS94ks3f3c9v zCRxMd`*2bY#fG4XRCfa;zV~9;uQt;!`pe}lBKJ)bpl(wtT>;)=DfuJ@i|h+Boa?1J(p1!aBIzM!_eyhU`4 zXEv(F&%=^Frj9?06KxWAx*M~w?@n`0vkyTV&fl7{rNH{K2k2IY_8{n|nJ~UYQtFC^ zp`V(WLf-`AiG74sUeEj-b%{w$TlN+~b5bl>wWLZ$)w3qXJG(a2xyptBO}gXHXJb5n zk72bM{{u{cP=Dc)fA~in3^`!=1=$H8eA4+43GCP>K*RAPw3o<#wkC4EC zxJB^CCVPK|f&P@7-QlE@|I)6ipI6Hw+@qv*g+}4!Ti)y6DZJ9_rxyVW^Du8_3pvbd zyY?mEBwoJQ>8rY&v(&D)3|orzV(z^)?gvj2-M(ASt1KM+%4^7kLc0GDiJL&Tc4p3aq`%-=N(p!WW)sO6nErLB%U#^snDhTA3{C zg#log9ScXj1*#GMTD~WcM;jOm#>G&JEe+VSQoVW(X+cwwaHVe`@rwl~+=r74x@gaw z7ByH!nB&&x;EtmEagB;Vt2OgzZYRZ1ke!lOBa}OoG%A>9SP8?}Sa`VeMB@Ih`{5ij zeVrh!I;`=$a(S>^#b-g*i?Q4s0+kSCn)5e^x!>?VFTQ?2-Niu$=&GGpW|;+c^D2kOZe8ron>wX$6E%x2eGI6y;G>v$^GQ|z~UNA3)qO5j*q6-}^ zrL)$a75&1*%Y*U8#}Z$`u{s|}Bw}Re$?!3`z{{VZ=Ec_hro`eaOXf@jk%NFQjg*_( z`-AD&9=h^~SE)3w$V$p+z6di>*56bjxz|V+=#ZdO^x{=`TwH^IH9uY*v?~~JZL~)ZrlQM zwA6gz#;_&3q$rCVemp}g4t(}p0tbn%qApva?j30Mhbms~9UWW#6nu$aW5oix_}YyXolf7k|oXM(8C; z1??gjo(h$`y>3G?%kJFo#d#sSA*nU*$JSU|=5O`QcK$8mPFeQQ7i&|Yhwk)ydf;;Z zG-vzlb=Uot1H})g_Y|}^UMk&dHkEu^eAVhFKJ78fmC8=f8sMR(2W3Anx$FV$hE=w6Q!%sTh;GeQfLo%{x?K9(~w+5!}7 zn-%%lqK0y9fSfu|lKWEWSiE0%f6)ICm`u-j`n{;71ygR}qOw$=GXqDENG_2CTg6`r z`d7{XalDh@UvV6>F8AonS)~iNTP#!%n!dbqKS^h}?%06E(Q+|`ypcl{!KiU^Txxl_ z=Hu>!>40!V&A}1gY5_^D9|ES@9Dn@AV0MrNG7}kKo8=v(8;Tux16gO8?=M(PW75yT z@xpJlt=S#&FXLd@_Gfpn#X>FZn%PT5?vMrmWD*}k$OMmB?v;N~u9MT>JyEXXSSfD1 z<5jumFWEjmG%;3m)xPHwQ_b4EU9n?JraB6rqNPrixZ<#i6E)tb)_Bt{=P9MF ztNdY|#rn7B1FLvd$VzeW|>6&lT`f#&p&9(m{&pOP|wQU+t*SD=BZJ0lc- z8I=UVxyu7>2X=z8aZY;kD(nu1x~4NWS3eCQ7cu^m7sm)}L9KOrcQ{50JO>blgH}SJjCRIcXrD?p`Ri&C-I11oEC;^|0te9~Z%y-q-510&v<=QAL^Mnps8t zTi0%0rZ^Bfbf<98U&wu)8Yc0Oem?K<#8ICQs!Am(M?aOtGmqc`_`$8%)5Lxh`DhLC ze|>c32pKz#m{M%L=r zgl>D^WPgjxFm@pMn4~p$K(p$+PA~<3F=h2OM_p%|n8M(~T;D)D3AFBd`2@f(WYfK! zTEmdD(2M9o{3nzNdK+*boD5jweNyHP>0JeCvNzz(9TRA=s@QQJ+@g2dK%MyoP1~nC zn?1TqB-mt0B}M!}sQXd=9rKmJKySe(0-1FA@zDx?`%SG34jD2{%rYfVJ+%++&b*Ry zud0qu(^9lEH!PenVV0yi$Md*^Zk$j|JBG1kH-mlC1t2SWJukpp9j7DjN`twl+*+Td zzB?Faym)7ZA1ck`+()2l&*yKNUnG_sLL<*Q!-jybzO}a=3Ka5}AEXE4nQ+qoy8!z~ z{%p5N)*SR%s@ZfadoXhDLFOmyl8^N2;nKi25h=V_Ya)Nal1b+C@zc?bpKQ&x+#q9! zkmbGGl5a_rq&u~HO8qb2M70pOhHw_?a!cTwGx+SDg<4;&45uhjE_Th1MPmsz=LhQx zLJzr{Zpi~R&Lv*cpICu~Kr~!rZFsx*U3@)?PJ{QbZ?1^9c%ZY!^h3a>TbA=Q1ojQu zdi@#YtPg&Qvj#~8X14q?Vtj=k6$@xLJA7m)+P(RP>S;2Wt=evvzhq$ZWe*oiv|LN@ zl_wKaU|5cvSk#c?T)$qd`W#!3tzI~>ZhUvR0zziEIbilvGFdQF$A%gi7uC!@T$rVm zU>X^vN%J$A=`A+vY6fPsNz!P3(Aqo``&m1PLeAah+$-6^OPg^LMIce!xfPst8D~$inRp7fufjz!FE=e6bEtLEZ@I_ct=t(>Zlu%4oj+s^yes4&LIvsDAd>yham_ zp(a8N8wTJwQ>lN0()MtE07#=${PXZJv)4cP+MFP<>e5`lL6G?~Qgy<;Mi*?sk`dL# zQbg<4mO;w?lwmZ_NBeU3$W%W$8L5{fC~6&ms&ReI53+UvP&NUwL7T=u9KR%i$~%Df zQ{7$>;K&2&G@Se&2Oy{a6}!Uh1B=xDkDoP{E~0^3xt9x43!f}RlDe(63}aoISxzrM z6TYZ{U=u+{=R98ldGcXP+9Cemk^Rzxe>^5eq}KrHGk^PN5q}){{9Vaddw!rKZNF%b zR?YLCK1h1RaIn0E#7}KO7Nz%daKAzDWiAL7@G)91eoDAKND3c3B|M4#4GPToM2!~$^8NEZt>G^6(-C6(iJ@Z+&wxWXsmlQfpI}f3R2H`)hC?DNneB*Xf zq>J<@&iHlB5n;z%J58iMVfaa@Z|r#ibAvnBA82O3`-R-*5ICR#8a97lPc$pu%$NxO z(Kh|UB6mOQC)y;v(K+Dlaa}~&rqdhQ`3%|5g(-Fx&tuN50gntnvp*Ii4S8o64D5JY zPGnc3Byy!kk_q4-Rpl|ZXM~5Iy~{kOSsYpR6sX!`yj0WaEyR_mOS0oVegIeDm$q2E z_Nd#sXwU8kBXMs$Qf9mcFWf%IplVe!cX}n=Nn@pZKoQbn99Q6a|ZaG7L!8S!=mvdl;VJYt27AHYl&0ZJ3NzwV0n$k|mc3P#8*;yIMrO_`e8 z*pK*U@MEAMF@0TO?$j|}sO=D9VV$s$R$(gLNIQZKQ>}wO0}-=FGtCg`0)HArO9Zu> z=WBJwElcK#ij~X~j=DHH(&v$-W&AUsC#l!F&}@mP>bJ0+EyB&z>FbYj(U;rrKW=sE zxfqX2K&O#*nvgQi+(I4>r~}#2%653kgpl_B&7!wnhT=UHOp)g0rcu_jr|RiDo2sCD zr!Ps$9kuyTINxQU&ke9ML^K8NxS8!FhAL@?zvXGFyclBk>`p4JU&fGA5M$WZ==y4{k-6>%VXBAM>2ZEgI}!&Erb9TFvFnM{kl{=lZB@_IOkz$ zs2KxoRfv;C(<1+sx?=UR&uO``6+5+|Y6AQ-dDF>jOb+sICfs^k_N_SSkoV$1XAU0l zI)hq7SiPXV3Gc%>8<*)%E8UZy2s;jMO4ncIyoxSNI-3+2I&w7J&GHyCf|gxjJi9-2 zfiTmijJT<7<#>awxDEhQV>){d@&{WZ@uS%4b9OjnyA zydNX`D5l+oO5o?GHcg^5;ZbXkm<<`mZ8sXsjk(}G$CdeDaPJ#B&s^Vn4M&okzTxO<^DAF}~&NfmKV!mYvYirgyk&Je3OFP+=1 zWhWyIIrK&4+Uhwti9y1!F-vWXWg}~8**w{k*R_d{7&Kkn-#wA&ynM-xSt@khyLXQM z*DLKYk*yG}^v^w?Ok>B=mBRb_Uj8I1`ROzOxpPtj(_0LpOpjWZlw6EyTfU{p>f6AR zv5yFXai&i(Dp?p>21o=ja763&m;Ac&<7?y`d|r$R*T%1VicXsR;~47Z%Sbxj57xI zNI&d5F|*xp1U~5q@ttrFoAemDpIu>2KE75R7OBw5z&o+2T@8#-&eZ}kVbbF;L9|%- z`@KD?S5!d^dNeOXV>LpBR;wME(zWcvUq-ZFzkhB#fVvnoU|&)i>ZRr$)YZyiudQ&J z(mBGZLdnYYoNa+=q8&BhKRU1AcR*v*tB@6VdEwU0$2gu9FrC{gjs16}Q}^@OH3xEw z^{M*4c`Nef11*$DXlZ_u;-F;V^J^X9=b5=xOAJckmT?0FIr$A~b(qh8xWEJ1Vx zIwXIqExX;$*dc;C?qg?q-Nb4;P;dr0=wvBFL-JA}G`-ue(Qow_8wLs8kV=jB*W^v3q6PUAPe^f+}6BMCdA z0W(?O+MhLyshS%P^fM~YW##~Q{~6pKnbT(9rjij$Kg%OkWEz(u;*rLP;|3Irv5OHG z*8obTDN?DCO|M%|5%~bvjK#**%m8o-HoE*;2h|Pn|XL!mFn%NCti|KZnXJU>E z^M15To73$lYuF=$WG^?itI3W_A3`W@$q*?*4W zR3BjFa-`lvd`^${ zK}=l@Y5rAZP9G3rxX9ua*3-SHGDli;)aaRFyzo?9Xb9Amurp6kyw&wH z{y4#)DI0WidF(Uz`6G4;pop}^kTi?Vk?p5`xmSJhIx)0#MmXl$>qi#yhA&uO3IE3X_onlLk9jDLHW zDo-RDb)6?=qN@{>Sg~N>nMRINN>Yn5S=HswA`8vjHoJWL?cpznnzgCMuh>_J_a2Hq zk3txeQlj#ewnYV3l^M=Ih5I+Y!xqUPaJ+olhl=;aEe>YrnY`dz>l=flJFEPaY2fu5%7CemdsR~ieWC0P>^<@Me)l6!o`kGDk@hiF=|iysRi3hv zUEM^`uC%~X3PpzpJImeuox!P(#W!*}^u?6fOrzv^ZD`O@Vw&OM#B7%ei7vmE6~K7= zEm+@#-opCpgincuvn!dpuO(V0D^fpfr?-ewX7}M=t(5-Yd+Dav zQ8fI_@s~|nk}<&`^tLzeB;3k!3K9&16u4;|C%{6kJ3sjT(D1pOc}Ran>Y{06lN6O1 zqr};3+NDbA7ns}6_fN`5{I&}PMDX5A#7?y?$&ux~AL(;ac>38NKJ)veC8Sc~Gz)hb?%VwWRL>1KQsQMF9b zjJ=_aK?cd*!)dZySLNYaK#h+dYK<+#ZHkYN^vsF#ey<8PaEbn~TkT>Q`Eo!s=tgxY z@jMKSAyLGHjRW*PYv(L0NyYe+Sy94Wnz0+Ad)(@=sS@|1tv<$ax^-*6oR0O`-Rh9$ z$e)i0M0t=zJen4-*P;E!Q9ic)Ea+-Rc9754Ys0JW$URgPU1|_GhzD!B^T>Tlq_P3odGIYAk%oTl5SDoz(sxsJf(I z{h`}D`UQ)?=(N>xoB~FysPQVP#qyf_+7syahio#Wz8)@!bod`7)4wjk8vhh+uKiiU z<>b2UwfMIz-`IVw{u{)bN!S1s5I&-Ru&scM2F##e#{Q14%7i&F|iL|ab(@>%in3IdoT<_6%=xB?FpLB`}a z$Q~Z+UKI$lU3543qUf0{_!|`WZ%|u%PHI1Y`dRkH*O0$H=@C$V)hyzanSa4QQ6xNF zipK?O0(#T;K=S->EkOUyhNHxhDRxLIoYi7d8fh<+SW;4dAiDP}oUV3!%Ppr>2V#=^ zfi#^S-6g1+DnVE=)BcMe;cSwfp~7!$k^8SY4@RBgR$KO+_!Oxf#ztfN!8w|~|5Se- zC02z?^YEt$@tr6W6ZE<#KI5lWRxoM);jx#z$|Bjo$AA3^)|TSy@sL+jz4O(`_H(eL zz$cuXBPI}=-Z9jj-Rhxl`8+bL@G3=*uHeqlo%2!UD!NzQHU{tAx(1Dx>eG}kt^x9^ zKRfpP8@@)&BA-*7Sk(+kTP_Cf8D-Wt9SlHRIRW=$tFGSCO{Wj4(L18^edmUIl10Hk z)Taz)f`hBfD{usK;W(dsY)VD7^`hQchvcD_h}5-@8C7l}w;!_avQ2zQQG6caU7yUz06a;`nm^qgB0 zE0|XHOy7*N^L?Azvnc|48#PqTq|=8afr%8kdGC2wnDK?o7Y{Ty&hq(IiS@U+?~Hb?xOsl0-;+ibsinRK|#k8 z&sUxHV}C{(@2F|d8!-=l2l}>WMDQkd!n;OYd^&PD8K9U==!5Z;BA&AW;nMuI0&%En z>nQt`?wPa`t7ZZpBbX5@G_N#_Fx6WcV@p0O;4W?tmY4nla5}8psLfxIJdiW0POfJN z%25>dQz;&P^$WjdyZ_{ubv;Rjf)Dx0NKI#<-}E`k%COU?g75a@{kceU&s@b-13g9R zPi8TTZEqfe@86{GK0@w)b8{qk1D!FZG~5eg+Rr%wIoEmg1iC4|nWCYL#+l5c(Qr?T zM3~k7eI#1xk?YWi7%ljN(4O75!^^rI9{5q?i2crIqen6~o|lXcWva6}+@R%1sF8SN z1O!$2_dMpmfZ)oaNAXh`3wv~KF3eJ@hy0kZOv0|3Ed11pl)?FdDMz5{Tr86O5nioc z+toOml>R=x%ec2;+T*mcQpUB-Cn*rL`5LGYaZ?y+O-Os)1hx6J_7q7?63yvu9om*m z)!9b#l$GLajT3QEu`k=;NQ^VBZU{Sla$eg*Pkj2kxy)C1bgG_TA`G8~jbHY%`0|Ty z2MK8Tgpbsp0KKgkxB@{7=!O-I$p!#Y_#($zvCw+4o9WrAeb!_YyUTyw)<573SgXE9 z$j6H&H1M<N@MuGl09lIGbew`M!{n?Y zbNY>VRi&kg+x>nSO{|fhs#ATdws8iFn5DC9F(~H*y+vj1jMvW?wO&6<2J!4B0F9a-Y<_hMwr6aj5zTHyFc=uy4yR?hY3ZvoXwr>F*1?uMA=@LN zY#8bdi(JHyL;5Wz#Y<#@_8*~e&@!i;sQ@6r@OMsW``$zk3R#*M?Ub_#)YI(i+@vHNz z_G<}AKiWzQ$k7k*16+_M4ERQV+GKCXGar>r(M(U+KI!2!_zJ)&^{*m`(Ir93`1|ft z2uPit5xd}L%-}C_a`cZQ{PnXVNSX!iYoMKJO z%%DasxRKMaE3unq?zEH~uV3#(t5A~NT*fuWQ2Zc)bm?@mdw*Mwqybt+q{!tgH#0&D z?ff}?v6N;Bp#6%PY@RA?rVq8zGOY_JVMQ72x4bWEyo+^Nb(y3V``NYQYnR)UdR1M{ zXyTeka1L$7Dl}o`-LV<(70dRGZh`K!ti*HfdE9sMp6i{C{qFa}_P3typXB9Rc^Evm zIer%VmP(u}!N9$V^xE1{@u)Ed;cwnR7N5XK0^EyZe>HG``yTsW_r3RGXlskJ0EbIz zs^WpSURZ%j0569*J7R6q`{S^22u8E@MF}Q$`u*3=t0|IgL%dcck16x3dOt-Ziyb1Q zNOcz#)GNFxLt0hB^fco(CJ42?B|~kCi-GJ-epEcG)*4^rX1AjdnPv2~1{4w(Ir8@Z15#29VTfP&ngXkG#|u2O?<7 zS9bX_tE8U%1aBs&Dbu6341vsda*4*PSPPY^Xt7(f@lLJw)iYEM3X)Z3P_4kGe36pt zR&LO}_qYp^BxyCq%Tr}qui^Xj-QJwxR$$C281Fd^U;k2E*5aZgMxe#K6^*PP3FOmd z|Dcyd$CXMVO_tVD)8NHwP+&$(H*5L9K-#1}_OdrUBD3+v^;!fjx6{Ct(Z8fvf_n~*fF6Op^_j(^x!hS$iy?Y7F zV+fqb9BUZEN2e}!D&>fnZtm^ZgO0ui>$0(PmXZ`L%r;Y(G3pGq-_~5? z%Wp^Mp9?5pC|U<8bQQDKt!ctC^`@C#DiuVJC3CW$az;a;U-u0+mdx}sQM^XKs=OK? zw-1d4Fe`;{p+`)}t3*%+RG>fnDdTb(eeeLiT>>zG)Pr6Tn1rc|WW2a9fUM^LW_- z#ue9bPFQR9=J_SfWtFea@vA>2O1IbgBq)62U69R=g`jlUq4*QZZ)Gua?YWl_pC*i9 zhS=y_14nBcr!bkWjgwry#3i|7N~5P~y}2UWzt$>|)LM$$nqBuh{NM zLJkXiz2F(nubVEE^FthElT?8{nwHp4-X}&z1meztz}%~qlS~DYe&w1L`=^Zvvet#M z<3l$ZAUFOuZc{0PN9A|ByBPtrvnw@Wa_5oSF55=WQ zyaVpa;)lq$+NdFXf;Q4v8vup7Ny%;lI#gBl?ry19yXh-*idHSUA4hwOjtg~5e86i> z5Rp$jUJ2?w$sn$fexRUgg$o2<$IRo2Iz+{`ZO8)P=~)vlI$ zCp)MsXv?j~S1z0*j~>LLcjvPL^}36i;2XWy2pRd6;gDU<$!z>{e_g!4P*UH>%h;(G z^-C0=5v#K)wQ-9>qWIqZD{#Rj5Q9%?A*uyBYyTz-gh~Ca@VR65eQaI388&R}F^OGA z#mi@vbl+j@dGTG`7-6#)%?HbWQ+e<lk`&BJa6knb8r9p%wLu@_2UjoT9i|yKxMLL@sz_p zsqs}tqZB6Z*T(Oz151D0rTvdB!dR$5?s*TnM7n&@bGQppjV{(QIwiz@mj9?Y;T(B- zP0sr_G2u0B?oQ`w5L~!QpE1(O9F(NClJo63?#{i29F{^K%>EUOe>`ig`AscEjRh}w zCR0XL-yrveh)gAYElvmX3pH7agf42FEy|&`u{M{mIqEK!@Z1sso5x>{9y6+%2-DZTW>hELIFmprvyf+zmybUf(grg z`Sj@^=X8zKrG9v+CoHq9_fBP5O|&7%pFfQV=;u6blt1Tyg}SA zN9CsplMEKO8E!X@D6{$;OvbXq5JxrnV;nd#?LwU+ta-iev3$)!IG;L=z7>0r*zg>? zczux}P~Ojn=_P}){>q2ghYvAISGHI(-U4-};HLdZLQZ}jT_4iDy?!$3p5gG8yyI7Y z>lro$#7JNd+yldxi+ZY5A|BakgOtEM{ecRj(t;1zpgsL09mT=!% z(oi+uRkzskr{}k@M15KC!tr=-jr5+hdgEmoHW;vu1Y4bD-h*f50SV!=)+5qtVimHZ znnN(-{&S-p%wz!k2|k2{>7BwCKBe-#C}a6s$BXJ`jFIQ#I=GewvXpB@*7NUeORqQj z#O^OzJbtg~{?7d4k+z~*iffTc_l+&gR6U$I`D-o3hR*<4dSx$|!mFqE#<+#SotGHZ z*1JXxV{0$kOP0#MLF~}FJR%!`CAhJO?=r*3Xf5_tRF(D&i4}vS)S`%u-i(0#bkj?t z=gKQ(Ml1cSdsm|^?Xb{bQ=PDt%>JbVgA0wLK5Za&+bs}RZl>`3`l>!f*u7EFFHAJM zJX~0R=NhQ#D5Dfn6Utg!h;bfdeKfxqyY&FaHy86E`u{Eg2$G$Qsz z#ZL@$lOMRRW0FWc48xo3XE&e=`Ihnw?W>VWPv8;17=5&AQgUfcuRW_KSTX%M8Ovmup3HICj>Q;{b4jEL- zRUOhMsd~NCM6Bxbamkxe15<{bkC#id5J4PVCtd_LUU-!(Y|S%e%sSn~9t`v0Y248x zrVwt~WSO;iy7Htsl}bOWDrGr!FZ-x8Bxws3OrjGZ2ys9$cT>7iLsRDJ3)auVsJpKD zQru)xtgHT9M$R9T8Q&30Fpplf^Q}HtR~7r^_SxnBR$=e5)}}fSqYQuJwb9rKE5?>| z`vRuCc&20l5%-D^`BH7)E`)IL;C^lnV&|b!Ppts$lQ{mE48i`81s}G%B6MfIM=CkH zDFwLKRtUNwxc!UIPwI~Qnhci-M%64P;T^VPamVW4e145w-rkp*bza6YH%aHOVpj=~ zvPQ(9*MK7)gJE3DAzd+gqO^1dP&HD+w+pTeB8k;#asOKLRY0W~kqNxIU)7i-6zW;J z&rozY!RMtae;aLc_P|Z1!xRVaY*^6|E|gYPAsQIAYfX%L>^8_(rpYL#hey$u<<%S? z8gB5qo)H;l;;E$ra>{(TkT8!BIl3VoveVY&?Z#u5W5ru>gal^Q%X z>fxtZah~lo>}FoM{or^ayAkw8_!ut!c@h)^L1%@tOJjU*@^Rf#UhyR?cFp4Du-X(2 z33Dlv`@o0xmC-MdYm-Fv$!W)@zH(c=maP=g&AF*0APRX=Vv zQn~q=Ir#g(#wyNKiu00A>>9}*!*|ZoVyM}NUv{UIN=i+w$ME8;9+wCm`})RQm9IGQ ztbIPEaFi;S&Rs+(I1~`PjOl%~9t5iT-i5n9HYbW)JdTjTZOELPzruhuS4$iD{ELPPKuN zl+|X1?p(`LQL=0lyht4o16sBUDy~SMId71+oFP~6s`p}y)YW^(*!+xg{khkQd{A0A z^@P***oqQ0&G1)G#htlyk#CvVT`j%LC}Z+iM83L9EV^u5-ga*-KJz4n2Y&pyKDxd= zdq17D(Qx#rK-cy62cdPfmBi%QQ?N>oEZ#_hj^pSJzPSzQ1nO-SAZyXvPQ+i=D9bj$geLCE&ehGMaduYELZ z@0<2BU6Q1ZK;4RVVMP;pG>s&-wb_Vj7uQ>dRwJv`dlpVQygPTIam_SWQu2Toe>xe< zZj67nO^iB|p1_d(YZ`XklSh-wD*n8S;>a@wQJqrHjfqaTiP|88rK2iN^i65V*R-(%VhEB^bxs!fjQFq|{$BHAhII8q1 zXKRo8^FwKF_2+{gsAB7?sxKuWVHZmL+U86iKpT(cDsc|Ic^hM89`ZAoE(Vm#-#-mv6yxJfQV%dA)rrE-Se9A(PFMi$XlY5# za=ktJg*95m*yxd0nS0Su1C1miLKXDXxkm)e(7L*CL6%3;?Rq6%#H-h9Ry?Q$SL<+% z$}uyr#CZ>0dxVb;nGulr=+ZYDXPC#4AU~)ybq3IKFn!3q^G0tmFAjS2c_=R!5&5sB zLDxSUW%W&v4g&Igp|HFL$aDf`rv=5jL`K$eW3z9Na&&7}$|&peqp}p&-5)LLojq^= z_7UZ4hyo<@-Tum@)7u2)jQqZ-5w)A&7o+^EI1 zrK$_$46{E+{iuY|1k!eRE3rb5Wh>x`n z`D?^p(madwNYKJjc2R-kH5bBD7C6yeY42oyQk4$>EVms1IS0p#@=FT@H(MsZaMgNQ zL{>p)Q{a@tx#!fm+cZPQp0S%G%yNe~2XNf94mKz_V`_Dmk? zQ>=qq&fA-mF*KZS^BvEfDO(4>SgbH+A!$v)SSi0E#tN(_M*iOHWDp1Upc zJk7V_V?Ns@g;^oafe&AZL&2o(1kby-!jFbEC%RE&3JY|fwTM2u2NB>OZR}rf>>J4> zz7=U<*C@yf$IxiSuU522e%i?3$exYzS1pLgsMUJy#%t$ACm(N?mgqBsHy_55Dup)R z1#+H={g@PkGFON5(e-(y$sTIi1SE+;SW?CXlPWK}05E}9+sdQ6$Nsfs!V%8XD}W9V z2}-1lA2injB#v`@scxn#3m$9?Pr9M zIO(DwX#vCIlu&>2;>nly#W&X#IM_NI#|u2|GVdl$}$<^77T?*X|!b zKD*oNuK+t3-69{({%|-*RODlRmSZxT$NaKIxxc~$>PBm$wXY*LM@Ul&6z`tfg8ND^ zM9~Zg3UO~htY;1zQxbb1r^{*rv?i7CIGw#YrRaRtx)Gl$z??)4!Z#^b{C_%;EkPf8 zQ?H?l_Vt|hsyO0dFAWwGv{vS$5j{6SMM8~mlAw#l$cyF zGSZV=Y${ZBJO14tBD4AAmc7NhTVJWyVfN|eTEWLS&nA!s+K`2IQbf=BqOQh|m@i;l zokoVuK>%q_EfC(H#}i`F<+Q?REzyHpQs?LIsj4zH$IygMhwFNIDJtWE92~n?BVVri zix9(5ML*vHNE8Ku;+B%>7EtM8G59=oTznw!mfs`#hRVorlUm~!v=37nYqIpaHWeI? z#G1Yj+Rd|wxb4%g{IZdi>>=AS^B(20rxufwNAsb|&@aZPWEPB|J+)qJGTvd1rut=V z=*fT*ODWf=!xtFi7>FGGgp<>?`sx=#DdZ>AcGwW^2Zgn24~ zsf8)}4%Bm$>;o1BYlBVbEBm)QuaqSUqV8}#d%BJ*9=}~drPnyRH#yVj=UE=E$;6aM zo&y=7$~uh+w2p4fXVyiw_4n(L$grSBc!LHJ&mW%FP=B5tEncyt5c~i!LWhcE3iBnL)6Y7UGWriS&*YN@* zH9wl~SsYb~RU2caTTimo7VjXQO`h*uG~OF{k>jpe^bXj|Zh7JFJ(K^?g^YQ9;19QO z@U$?0J&-^^A*;^_Mb^^JUH4aTIkLdP(<`pK?R0kMYJ4-r`)h)!uP&^-_L+X7^i|Bi zyIgrc4%ue)`t55PhPD0>ES%cMz0p#t{C&_)<0z+t=aJ&R=SQr41PxhP$>iwuAd&tp ziUy)BoX7d-6LGfKN$1oF7ixzjnvgAlxGxzhlZU#UmtNPFl$;Q2p{g$PHf&`vDg9&< zr9IEFeGydQnzm<5 zUQCS_%v@fQ*IDar4CMY&>_XEysYK&bIxQi?Pzd`ymNOL5E&L8SF7KhM# zq2C~1p{?K95dKj^=no;Er>fy}Zg|tV%c&3SLpxC~N(T6|&s}0NV1jrZs1O2q%x07~ zY13*?1cu660@C&$)#z@Ve3`S6*|`brsNnk{Pg8}t6is=VT)>ke<3@sZVUk2}8W~;8 z0bI+Tk%S}lwZ($5;0E=hynwa&5Sw)RD?C>yZ}ZV6{gP}zkg?iy5PM;uQ@ON4PK$W( z(VWFAHHXvAK{Na;tRIKZ#L5=&BVvx;ZMH278Yh##wY}Nge)X_4ovuZ{yk8AMF@bc^ z$jSsAYebVrgHM_x^@22&#ZK(Z0%Lo!?`^NWEogsvo+ZS2Gbe&LIOw`%inCz=RHlUq zU7-(KHQaRby6Pi2hdU>=Z5HF-n+~eS&%b9DRo?oT7w}`qk?N-h#27Ks*rvpszSf{W zf*}IgFBq8H*ZbtW;Kp~=^bZyw+159dgkj+L+@$(YP3zw_H&*#dJ zPW$BIM5=E`*B)@V=JWysU#EemBlnJETT+2ix8>7KDGmdzzw4ZC4~%Vd zauE|~W2q9lsP1ZoOn9~nP%6PS0+;<^1pYpJtf!hJp+_&h3)Kvw8HEp%IB zBhRRTJVj?DXW^y}sa9tPv4KiL+Y3=UoU}hWDd%@dTOHP{=g>6faZ>gHd^L`cGvBUa z;F4v)MAk{)>VvHBx8&=-P^>k6(3?^@AT`2iFhQZghq>|CaMvpVSa07x2TOb4yQdkU zPU%TCRi5(bZDT!o?#Ym*!}`q2pEL-t9(g+!?TMFj8LZJ?3?8l=d!TmLUuexD>!CmQ z+qrG`lGdsgKJV*!@ya4inF0b0W9t`20{{;})ik&0D8IC zKYFI-Jp2P1<;Ifo0<-pv{NFW-MBL8J^;tv7^u=jvR#*jAUC3&L(-EiwD4sg67Ar6f zJDe~PD}JVY>JZL|*4jB}CS+7g}rjTUM*IB5;>0L+Th|`M?SFK)u7fOld{RG{n z2Jv;bpryjTw0c_o{oMLhiGdrN#(jucs}x)2gWQq&X#(zL~CoM~q$Y zG0lIqdZ%PdV|hgCm4{i}p5*M7EXH=(5YrVnp;P2G3F3#8R%H7_ohK?Nyyd z7?3FK(9z@R+>ltU^bC$ln^;yHyIvNxAas_tP@lTui157+5noWDmgK$FCga+2^(l4( zf^Xuut+?sPO3WVAKz4IBBh+1{yl{D=Es~ zg1OUGT6`{$RaHdVa=N7pEIkC0wbWuAvPN?T$uNfKgPL{elxcH}^fQ*6c$F+4gj!dz z$=1^1pc&z^mS=j7 zTwvb?9Ga9d{31LUI(*-buni8h89~4>9MbxKFZF*n@c$l*|3lxavq3)jw@p%=ng>6A zWS3sm>;8*WL;sVUl7MBQ#Mtg3K))aZ`Hx_T_vU}Xz5$A!e=>6DX9Iys2Lj0eA8^Kk zF0XpauwMm*#L4EK|MB~O?|*KmbNZ5R)NKyO)>Z5Lnst0^DEa={jtI{nR|hu6JvR2E z^rDG==CDpd@5RCh8_*x8MrW8M{Qq`9^F5@cR&9c>3w-?ssbM5Vm@Y0p1Dt{I9r!Ac zr$}R{+=G0NjlO^3zCmczWAIOnqg3LYu9zcQYl5l~_3(8u$Fc2G3A>bAtGgp%V~LF<1Fy7rt|Xpd(qa%z7v@r2uLhNjsb(@y+&732a%+`$GhfcQjLJKG z#O}}qCTk3zWf()ELafzgqZ2P|3}jlGTzbl#)cJT`YDICMZ=8Nlu^%_R&-(^ev1$1g zcCO|E;>Mrdxca}PfOJ7j2*8W|d~^#{gh*?ZiAuAQX(tX!@W~>2K$|g6Iurf~9za{S z-%VmM=G1urT?wDQdC+nJvK{5PD87nc14UQBdv{xV;8LPwKTGm_Wst7D);ao)=qa;U z(wMTkzy*y#;IdPw^(qZR;nmq-^C)Y4L8Mvq4dMjfq|d~=aR=3343j<5yA>DddQl;O zMlVCBXu*>MM2?6mc-9sQDF_t0)3$>Z#rTnfO7^}|F032dZfI}l7T^d7D|aEM5^CPs4E^{2uS3q(aqFK?yKv0`?h5DqHZ8!i>L6qyo%I* z545990=|0~#Ezd01rrDnv;mf-C4c;fQJN0kZ5K$1$kl00c$4_!`~lZ^tB@LVgVJ1?@&+m&y~n_eO;;j`dy=BpKa#b-CHygS5{6(>~P$ismz|hZu#^^h4)Jf<6GqKE^3b& z>3mu7&04}`I?8%K@O-hwYU0F>`V!O{@faa^Sox+j%2sIsvjES4{h2?_PEf)vI*_pN zIVbZM@~2WS(%q;@0CNdb>~TJo<%#MYdc#xbW>71fEF`Erd;bAue|d+Ab2oewJy>|& zp%Q4|tk;43I)4F3xq@UMvm*0i?b*H+e=|FU%E5twn7Pj+cP8=*%G4t=bBc~V#8-J4 zkLvbf+fGK*OfcUHQjSyz5VN9s{IT!Jy=m?{yGi?q{*Ch;ah5Fl^DUCkY1h52raYC{ zJzm}dt)_ZaRV7v>D%)OB>bf;5uNWU;m~5GLP;rVU>Ui<5CenGCYQoT~kE-^J<`36= zgFwT89uQ^^({TUp@v-8XYKIn5m(k|9S~)3h>{H@x<5ua%q&{nyZ<$|;61iQ6iL(F& zcQLQrSc`z^m;y=T>*!Ch!iI%Zz3kycckgV6{u9F%outCcT>^^REh8GdyvV%!Mj-!z zZBSZ#O;WGO@$)t9pJ;uSN_H@a;<&&{fyeV~i#$Y&+t7OOF+LM1R7z%m4-%nxvTYJg zPeZHbU-0n2nTKgDx3e{ z2^Y39omg|1-Nhj!F1^=cvuCz7g%WR~7<9G-7yPQG-s0;yH91?dizn_!9MB{l@*mkK zDUJ0`>k#3$P{;46MO$q(IRfAdpW;jdM&l!J2e5GjbO3NJe1^@7$tlZXMGVH>F7zUR?-a#pMs$;BUP#r+suhet@GY z^_Gzbl$?gupSi1;yJ{aojbSCArKDrbBb+TWijw=_%;>4-4pv(|z26{epsBV#bW{l9 z<7<`9zV4svVR3GfCwf#@L(A!Jo) zaR`8U)s+8ynB)&%w1I}en>wE=ElIh4gS@$gx`dKjqagU>i%9j}&^;cO%`3mTEP?-+ zC^Xm*YNXi}KvKaM=YWwLm?NLhKKFYCk#9liKbycpp9j{}thh+3P$C7vc2f?yrhQ~0l#0( zbb}-hEPK8aA-mnM8n_)9FiYnpoF7v6OKq$Ic^E9MAPXFSG4Kpn=6^_lEA4V*0{?gxOk%^7M}O^c|^g&#IU@0nf0;FOt;ER@31i z{h6avj#jyGj<@3ayK_ser3~$M4(wIJw-cxQVqh=&C&e72KfRzjZNNMq5&J0G#4y9>ZCVyp*!0_3AyGk|R+@v1AZ$R)v$m0<{Xk+N+z7W{6rqJglF^8CmSYbG0UFTXr^AugpHi=EaDb7|%Uao^57>e{d8eGdX|xp8l!=s3pEB)+bJVW`}z= z4p$}O=HF}hU|QQ}B;AfO&90#oaiL95{iOq(Z|o**v1}k220Mkq{N>yp+%=1%KwgdA z(>Kgo84`F`%20;(t$kE-+5hQoY=Kk#+z3Vsl){@ult!(l97}7g`oc`(5&gn%R;6D_xC+ES->2XF3(8yve+^cHu_f77#Ti_uSBK0Ua(&i_ z!-r4{UyZ7tb5k?$8IJQ!9G`{T6B02a@VBwVUP8{xH6ud!HWVmZK7pl|1MkbbpS5cw zHICY)UsC$&0!-(v8FuYI4khymaFM_bQcWIuQ)bcfZ$p>+e7R8|Zq^vDO9; ztTwm%1d12{C9<*yNc^s|f1?jR1qG!~bL2OOiiF#*U7L_UM=Q~H8_7)w?J|J0&)$^4 zo8Rj@6~AitObG&ds}D%@$G2 zPV%r}ZO9owN2UkW)lq;lK_PVOZ?C@S#XssuG#qOLEvZmXpj%#=Aql^-&9W{;1#fNx zy+Q+5T;?sS%wHJb?x)1xCebJZgPFT=?aqAfaLZeZW}`a=Z4n%S&T76Zc{`h(d%G^g z12y1fCCNK0ExrU5Y)cc@yJg;$+EwhS`DW3i>3jMa4(IYeZ`f>Qin3P3X(oFZMgBL!2{IK~kZ$WKe&X>)@wgxSw zmBawWVO_xH_U=+nFTgGbEd{Z_k zS?_XVZ^7N@M_NA9$CNkZYL_*<`Xt=sMIQ#aeq3oARM}Nb1A_szAc>dgu0v${(r`^y zYVLlC?Mm9h{b2w2F&c?&;v2?Bz5^GcwnR6+82yMel$47h zrzCYEms%tiBJkdlIghtVt)m2P#=y?9M2QeW$TU!vvsyIULlYj^?a|B#9H=L$u787| zg|_mDRwR`oBE<%MuNHvH$xKMB>eM}1DMAFJ)i$*Yk=cUrNyw_5iyHZcM5jG+UBnzd z02rckt#brX>cvM(=>AN(4;5&yGStHD1uGl~BuKm2mu_O`;IjSwv0ndRqS?8c7yK0tCiwh$LN$<#K(^ZE-P)`%x|iIsU%*tC|VQ*FJIa#%BTbi1*zI z6+AC3VJ^a`=fy6&G6PDBOJIhl2r38GO5k4GYh0bICdtH16Bf;~4zN)@-sL9iv=mF4 zY!~(vkmXMra&_tE$X7}1u4|%OlylchHp(eMys5qhG5%@bLY3j9G#BKabmR@h{!DEb?m*#J`(7^sMAFd`Q$cZUWQNO-UX^5ouF6Z84s{I zmN@v_fvP+$l5#104N5#*9?2mYa=9Rkw61ar`S!X%^%o=Cr;ZgJl<+y8+NvXHFCtC* z2IKw~iN~#kZEWEVXd3{5M9fQRsIB$VSZk4_B#9-}io!UydnX!50iGlZROL%eZ0gWF z+8h|U6@*^oWLe4?NzbE0M1N}C;=Bi|n6UkRb4HUH?b-9sLIp@V_C-Zqg?loOm3)%L~*Y1ty#r zt!(=(>rMChWFsbBHt)bC4Jb3O>9jqgXq9GqO=Iz&oLpAziJi+-62Q-S#9W({h)tBvE0Xm~@*vKyyRgJ+o`Y!R~&U5CdW zqo+(kz_R!(KfJoB{=QMzinst@p66=OMO;t;FY}hyb4C7oFTd;v*{{tj z7)&f4eHF)z3R$tM#VBGZUu-ZArXjp6ly@ICf1Ihb`elBac*+M^#jUr$y+f4Q)= z`pd`_&)Yfg)`#B|40fB_{c>`CeUf+GUTK4=k3J)Cj1)<|x?gFR(td0UIJ><=E}F0L zrTX?D0bq3S@mFP_4Vr1^oiz*vB6Pu+*E}YYBW3R;0z7;q+lMh=Wf&azFe?( za8=_0g3zb^T^iM0*ZYM+mh-uKHyzR<3b`C&D9BHVNQkVBh26b7d(tOlYUp#@_T^n{ zUu<~VM5Mqc3=38AcGAvZsz+Yz}el<*;~JtqW^!-5>ta=rNC* zC#tz=PPm+Ey5}7G$vCbkcKh=F7S@ZuZBFrgG^OUGaO+owDbL7~= zo#(h9chf76w*5ovQUv)sax}C-=zp3Ou{VeA$u@9QWl%`YJd@jSKY z^0kD}{A793nUPjW#O5V7nlvNYiB&RYf0eh3US-!+-VKfv!yP4cn0 z#=P`dWcN$n4bnwgkhH%n&y1of$Bl2Oz1`MTc9rnh1HOk>+N1J!$MrNV!A}bDG#;9|-m0|vAY@C`%5gD}u5TH&`6BB} z^=yjU+tG`t_+j@dvRaqaTmY~L6g>npKaSMNW8WYp=4wn~R;~7H58}y>@<-*uK1@<` zWM{z)W-Ud=rNxqEMspuo>M@@Subf3!~x1jBEoEBZ(ait>H zX+==mV=(+LRm8f5CQ8@yfTY>_^XQ353LcGW%lfX0#N2`~kJPdm~Kl zzd|v}9IyV(d(i|@GH1gW(N_bQG>ny{4t3(B>SWtf1B{4U>Dxw^B8?5cyN3pok8`+N z2>s&3AIQlHMErzgx_`zQ0NGX0{KOk{{Pn!JTMG5cYr-b>6IV+*OX@P^8@BDSUlXn? z-gFc8eNAKh?;=wG*5n09Xfu!=e2i|^{~Okusg}$1iSaGD1wSX7b2yAGcLkd@jK{vJ>BR|ib_{nA49@BQl zSDJzYRolyD%qzId6rPitTcJp1Im~ zHow9JiqF*QG!otUVUxr(4zA^I5JeE)Be(5`Ki{Q!d60ls`5(V4r{K*u03P^0(C2mE-_L8wSFume7FguHS!Pq# zN9ZGK+gWy<%2sbeU(VV8H;!w*lhfzPsYdy3f(23F`G57wruXT$AmGL=O6EWBxF3Lw7gYztm*o-fW_QD`R+fb z2!8wdS78R2#xH@7IqILi$^!+je?0&H%0A%`f9HN1Or)G!GZ;;sL{g0|3}p?i=1o5v zYM(knc%lrIqYtXR%qDT4;(ugcf6W-r{2bg!81Ad%Zkw3Pr<@(lVU0oG3B3bx6H5_O zW(_+`|I<-BAhQTv2bS;aE!3s7l(_+f3QgW@Nrh0Y?1EYcpv4!YX0m@v(%X+ihYW& zg;u2Oo9~OfI4jV0ePkQv9_UQ{y{D&KwTdsty9(%h_oGvQ!gTd9Jl4Oo55Bfh#w z&JvIay?OjHGnXCGv|%EpbX_(r$`B#>y{NwE*-#IF%Te%=fOGOIFnLFF6d&oV*e|cT zKN5>9)acuHZh)D-jLYkqsW-Z=ht{_Tcd^d#5=^#D%v=7EOmE2w_Sa&sx8MzR6|m!7 zir!glT+-k0@oVrtSuBc=b(7KsG3dI4aC)D)UxE=Wvwj3di&^%_*GZGUo%ftz?yr%? z1acQR=Lc}p$vh8*`(q8F4qY%a9^eDVY*n9&THE#j7l=`Fnt;=qJYg6;JHU$C8 z@-~to+3Hp&Ekm}@%$?CQz{TK5<0RXHJNxx>N^$n)K3boQ+6>KXiU8J4@2y$?ivU|Bk-!adHWemI|;TH!pW8h2<(SI;3!eXl7)ma&~Cq`pY-_;;t)gZ=u36O!62Dey1 zq|;}G{Z&3$c@s}c&w>~@-(kY6TQS7*e4T*Fg)+i@Iq~_Wswe<5U*HtAeN22Y$!X{t~b`a zg=2u=7dKtnt{dnCwmxMfkKu3_e0OrLpy_(kpNM$X<%1--PEIarCXWkK$b#=}d^DP~Ul12;ATsGQ(+jQLB0 zMD+Ts^gcg0B(*sR$(H^Lr!si&EH7ptbV8)#roT8~GEwKo0!k`ZsuCqbyNBvof2)ao z$j6v~tcgO|7o|;F^Buw&*;yh5cv%dT@Y!Wj!31qr%Hj4&!R}09#TKSHVOa?;y}BzR z_8YFu5xmkZylG-9oW_XYw{ZD`1x5u8J;gD<9kZl!Ed*e_M#d`dCKHn!J zLhLFsJFlq=|HoyB?!T8Ic7eJ{Q=aO0N$2C@cFwxa$vlllMwGRaa{|oXe`5Oa^>(%g zs}7J^dyZnVO(%Ms^Sgb0|4VBMsO_-^mq#p3As$SFmv;Kgo;eB}DEQ!eNCYcioY7E+iQ=Y@^c0J~eZ`_cf`fjbZnSET`CN@c1EUU{oEPFUb4w2-R89 zh|LzJ9J=d=)gNZHXV~{AdDaVP9yoM6J&t$yG_Nn4%W6y(QVO}OFr zM4d6g?b=ICX`@E6W;WcglIh$Z#2?DptpI*dI??Mvow;+=T0rgzFPA8G`h!wF_@v)O z1Hf;dsff%hi{~D%Ed0KPTr~B3@4T#F<}0_>o)1;Vsoa_^@^{Gh48Q!eF`04qYR|=* zolj9UJ#t`bbp-0=8ez==CRY5pGdF#ZTIK!40(qie;<%BvJyXHE<{}+r3KSVG-O5MO zC@C@oe=vhcVc=<$$tL#}ijIyHw-m=*e0G>xd3HW&?%bbHof>Y^lP_^>`N9l?ni=2c ze5T~+2>)^lMHdS^1aDQ%0HV&4%AEux;n@o zQFb`_+%EHS!}Tul-Qql!-fQJH4yToK^<&cazwdd?PkY{wVXiUhuocrXrFHlH!&tOi zshhe8q7i)P9|VtY5IP4VmeDZ9myXRbt=EJ8+H#LSw)!_B7AY8Pcg%vzz2>_070q4; zUP}B28S)f4yT6_NJ%a{PMuz3Z=2MM^-p8+9wrOg605>;M`9r!vUrMC7fBjQ6qoHm} zA}75B9of@~ud)HWYQ??zCZ@)WFgh?Lgg4>={L0WDPe=f%%|GF<%TsD1|D^cDp_Z|7 z#U(6LqXKA(`RPd?r=gwqX6X<1%rq@bFIpNhe})!#BJ+Xw^=95Q({%+t`=8FlrJj88 z;aArj2AN%)P45h8p1Dj;aoT;e{m)bAzSj}w`qs+9Tey&a@S4rNPg4=r(;wuCu0OwS zY}uB)DqkK#U*y86=40pfxbt6nfTMrp_T#weFzjg#Qe4Gfta-a8hqI1ljGqt}>~BR- z??$HcDvn7vVNHKng#YZj`ySvNWWXH)Qlo1J{Uqw{a!qjk_NsPT<^A+%D3r1tj2MS% zW8t)?qdrdCfj##ek{pL(=)3-);nQ2LGn5yn96T)^+A7w|bvsvAnVIDh)B$rm zkLw71c#J;sF4%K0S^9UPAoA&za>S?JMef+<_`zG`yn)kq!5%;(f4S~tBtO=yo441= z^UzN z?itOznf~v*L2zU#6q5|<@Ncy57Wz`zpFS3gz85eE`W^j9f9=b|$$$%3)2{sDHbedR zfR|b^jMMj;qz`U##G%z7XCuyWJTKP$okHcl#{>8uBMa`=n%n%`dU<*@DV3RI$dQ>~ z5r0}l8KT{l@;;&Bsuw2oce6yxZ-X~&eVilk`#OGqh2Ku=;bf_Ej}RI()gZUjVkAFIzXEl*fS1wd-F^OO_8FBD3bRBG?T&)aC z6XLvf!by!gO04%)+Ie9`K4bhh$Tpl9SIg*Q+GdK>?q8OuQS8|^lSw>k_RwAYnTwv( z*%-#I<-;|Y?NYo=j!)AE!%>+v!#RJuHIi%2yj!Cf=aGB87Ztp-pBy7x4X-(hqQ~z~ zRw0ItZELT6gER;o98_4Uyp~~;Jvyj-tc7;y=-CQ{ecMS^U@eGNojnf}pAmDaRq}m| zw*7HY=!vS`;`S>YM_5>2dW*C%!(Kd(v-!w;5Rr;Xsa3#lOxi}T<%|P-5acv#_HhAN_N1Yso|Qm z@CA83A=i6WmeIxH`!l+kBkAjpZK4JPFDINoEPR=Qzf~EfN6f1|hcBudqIh`TUo2PY z@umyYg6^>6{e{fg-evPMteX1~t)GX>Id6SrORCxjOt_XpJc3j!{tZIpk18bz6Vz() z-V-hqBa5{s5e-n9k|TOFCzPVMvx;;C_W&B6XSSEWQ_8V=>dt z)F`Mhp#5C0CQX1{fCPO7#L!72J6D0F+s&- zspjmm#oD6Rt5aw=wBsE!+ZM{ko))yOcqB@~mv#DvMS%aJ!faG3{$;ig(sYd>FL!wU z!JxfuEB4C1yjMqd??|+Jz`+#%6-oAH$xGomLF>*Ct0yTJ`u!!=R)JIqJ}&ovG56MC zQN8`v_#i5s(kYUXO2-f)Fob|8ohl6i0#Y(Ciqa(@ARr*!jC6x^gLKCXAPqC*00X|; zr_VXhd7tm|{^GoUa1Gb?a$oE{_r315KCwP>aPGh=ffptXAKy_uumMjQX0dvTGJbGS zV^g^Cc{P<87qJwE)%ZH+9o{X+Gby@%Ym0&?C|t8ad@|XxPr@&a@fEQM{sVmTHBP|? zFL`7jksIaS6-ctS;Yd0BDpXWXuIY9xCJmpfKA|<^LbHQQwC1^smuSkV1P5C*{n@$g zb_;=bx;i>9NEA-9;TBpf7H(IX@iM?uJ>#w9M!V$m=>R|%9sW#Rak;g1KGojDZB~D9 zoGR35oEu&g?)0P8b23>O`>oHimytWVr;Z}FmHzGZ_?t~Y?5XPwDX`zZR?)Ja^O(P-_25j;J;;6{f93X=dk93C^shKvI*|pU_L>)H*7S1ggLsDd0IF2 zVhmAUA4}w6RUXR~tCnH7M>k|4XUB8hSJ~3U8$juEgAd0GM32T3fvh81+S^ufd{4*f z@WUo9#8n z$(z0tMysS;b-P1LqJmGP;yzcM3>=dwW<8SG=nKb(ilEISP`$3)pC@T%b0QL4oH$Aj z!Jcp*rnyzlSZ>D8A8x*tO))=o82Ae&wQf1kKc(ufCNg{ysjFZsfWPsz92 zfi|t0C%w1Wp0!KPzCdY5(srd&?s^RJ{>g@|1@Y6hYvtt^2EGr{}6PS8GTDf0TS;TAHvjSbnXra^I7kek_pi1NU%ognl zbB~cKb^@zV>UTdc(p;Z5b`e-Vp$`bsT+s7Q{N7~T3KkH0$2q$ELktL7^$3WW$zAJ` zmpa(?xg{mUA@MdT)_{LUgFnLB=>BfypZsEgor`k_>C)6d9_I|j6xb(Lr!k`zqC=rl zxR=Q)b+sqmEyAM+joc#p%{cWsf}K~RGd-~?l+ zNYsJtGM>2z91E3a#FXSzx}SJ=c8$KBgWO`jcX!NwTWl`Uz1xC95Q>}UkSNU+`qgmG z+`(<;UewkW@+`;DN&n?7N|le5v7b?^jStKeUQJZ~0ug(+AkB``L;ShEi~13c8txTC z#UsiaG$_c3XS0hw#i~x@X?TZ*N#G>?d=@Bxaj-i|k75xWDeY177?jtT5RhQ8j-ueq zY0Gn8EmyH%B9oGDQlO?5g9WijqV_%|H~lO@z-TlDzkJ-fB{qLTi*6Md7h)brN7TuX-bRCIEO151XuR8XDcWzQuP-bv4z?90@E@qL?a0JwV=_eF1V=P8 zdK=HYU+gcD%iX|VU`|W7m5O~Ky?dA)bjQycw+ULaJjmMZ$x4hRKpS-WYtR(5`j(8r*VG^EFa3?TLtDiw7KV4#w?KXV$ zc*Pu8x3NvArA?sFW>fg_HSA;b&9ci1#6HN5A^n7@9=T{?v2~m4LX%sQ!Y>CYd^6J8 zhdb!@9=@{YDv5R(U zlsph!4}HZC@q7kImq328rCB0GnsmW!UDV?AP{$k(HSpw=>SnS86F^g1z?c!Z^`#)?4LC^g%QoYj56n{v^N`zx8T=^=)ea> zt5;6=3v3%-73FVBoVx@bl_>i0X;+RpbCfwIy`*dRa+mGkwB;}@{+MmI_3p)dYd8XW zh*W7oF5!AhAidrRc{CJF)}X3>R%JM7Et(zNVefP$>L=G625JbiK7rP(^usn76aeGbwE_ z(&L`eZ`0SD>@#EWry~&6_In+cyrp?`t{=$?u_i0bxUJHL6VL~*EY+-3Z-6a`Vc6*r)D=Z_qWB)W>q{v66`Mwm<8s!==!vEB*>pZ52wf z-jN|^At%Csk1UNp_NTR4>CC-8Xh?WZMps*3MlDwpVOzzAN_6@|u#4Uu9 z!`|liEIi_SU3K=H2`T`FSbu&i{R1x;bJ~_1-L`bo40F+z#ywf;;Y9KTX;{UqXo3+_ z)_;65>~jCU(qVM0c#WZYkM2R;!%_28TO_1|#HF3wVd@Pe2+CD`a{8j}p^Lk3ftgZ` zz%5gu{X1bqK5Q-v!VQ&;-v|IAKnS46%{)tFdYDdbQttGlL}RbuJ09_t$s+Z9K;LqF zn)eX>o#@khwet+Nx-qsGnCKf@pki>5x#3x2a#Y6i((l<3O`p!KOWwYcPOX^k%Ntue zrovGnx`ZHXqxiwHMUE2E>4D$}j$WL__3xARUJAV=fAr(>YqYRoZVja>cyuU|2)e^` zJ!iRv*l~%nQRF!quk-5m#L?V>KUGUb`PklKABpfH;$bU6lzf7ZzA$*Q_E5LS#G!Ck zhpfv<6f@VwLXLsu_S`ssSmIFrar&TSD9JS^OHK;sXNy_O;&g}Y^X{)$mMN!32&$s3 zzN+@1e{*tVke2al?Yrwwzc0V`w$_iFs<_LXkK0Wa0oFfdVL_!e<%8>ogTeMUV|+dE z4{l~SaiwxDG?uG;wpdi=_jh<#F=#OR6@>T0csSySnapIgQOtu+v&fOZZFcgyz4iUP z^WFsmIn(>A&M~O3@urE^PY>|Ie@<1wsFpA(-+uvfypV_8=laC6#QG=d60xd#o6n*gHxGYrGD`OO#Q{=4dy}lb~xeYI<4_gyq`4%x80%OZSSeH`ENf{QR^L(z(7ITV609q#j1n;J$2fQ z$S@ML^pvjPnEAuW`0>xXwR#MQ+8G zJn3!iVG(bZ1OXZ38H7XcB+Hr@VIq+Y;*|EY!=X&Zs@c@>qT58W(l&I<$e*)T!>=-% za{m?s0bFt3N|7;om&@bZwH@%e!}1r^6O#R6tEUu*IRVx#@g~xrbUOO}vQos_HOCJKLh9M?sn#_IXv=vqlCCn|QPykqk??U2 z=MI;<4a6zNmcuRZw~Vca>V!g<0~`xpGO$(x{>U1hG^g%rE$(iLbm;W-Tm{)U1&kN6 zg>plD@eX=8Oe|8SmzP5USnAXWY+oGR17aUXWjSY{X=co7#QhV&U0TlB73#}lA_PfBO zn4(ppk?9~!d!|JmOlUz>;Kbslz$lSCsj2;w`H+>?PtFxFE06oiv3YNboWGKl?PG1f zZeDSZVHrl7x+?Os9##KfS06FSbqB9`hY7ShRD7hr(wm>5o-i0(!g~P&dFPp7o+Hml zmO|mHi3T;3xgbsWjL3~cj_2RghTU2@RIhRPKUE2dqL7RETjcawvnMJD2nSKg*HrD( zI?iUA&1_4jrL&~xIe1CGG=_>BJbB<;obW_xR85qRxt+d$eeeb-+76!SxcUneEKi8N zg?4E}VKp%^FVELQ8kOP%W_&}lqDG$BYJXFv5)27qQNsJrHC; z{(K{8^nx6R3~h6|WwC?1+40S`Q4{;@DbNq)N-OTOdf|0S(gU{PL9XAp{%APFt14+P^tZ4l8KK+Jtcc=#|;pUU7g`Xb_UCcRb z0u~{yhAb|1>(nv>&l2?5KJ!>DprZpD@z656i*y**Oh*C)<8(9A#z}JS*U)wISvAIi zAQ{zZpXUQfHzD#3dS|*7H?=dbZfs2wAnc(8f{8cZ!7(Zi>@6!?qsg{lL>ki#$PVr7T|n_2w^x z9la`M4d+zZL`uq<>Y^6*7VnX+R;w|1;t?v6t5UBnkwe)Fph6hGVyqHsIPAC~)vI0K zZA<7f3nsAEoO7ZpELK?^Nw@Ca$2&{o5h@cTyC!$t~ zh1foe>1MJcQsao~i^&WaTICH;rCxgj3j$l91)9{;Xfi2pZfz0)6(3y@BNtZ}e4YOZZ^faS!@{#Zgm zWn<^)LTVt6V#+0zqhCEZL&trCNW8sia>_Bg+XBazcnjoPTa zp##xmg3M>=eyT8fX1Z5;@2*p{QsFrNC7(@}EMMs_B-91NQ$xu!%BPQ6)XtiQW zV-!lR<(z%J^GRg*}gBb37`GG!2l8Rq(C)s0vsv#zK9Zd$jUGQyi&5 zQk!5uc+^)+}BH}5|vLQFhJK#qaEqV;F$*IM(#?BzVA3OzKXVfF8YjrwAIg- zine-&sZrBzv>}hT*)Wklgu-|E)5y$b^Ig=Z9a^0qSdCJ#4j;wZ zws=al4Jy68$&IgA9F$^gmmMI^F6or)TSnJcRM5p_n;A9NX+b8(y$K`tzh`FRDBjE- zDOy{Kh)kGk_Zm}G@a&84I0yRT%*vmCI-{VCc^DZO|L~Qjk4rT4 zZa!83WxW*Lco*sCu&)1OLXb0DXPmR;y{?`$*~^o|TU6STXY7t1ABqFD`98@Zkoz=t ziv&Lep%A$0wxhGoHKj=YF-?&WL=fGS07rQb02~n&RX(XsfeCbYr~%rdtoa&pRH}vM1oYaE(Zi9ZYE7|Jid0cEl(+T+tHuyy?(^K1puWXfXP1nC`A|JJ=_T^tj9TJ zupVa`c7N{sEWjps_qtDPr&w+5gP#en?M&96EpzS%A3}10AzrD7xoQHR{4E-W9YWRq z8SPx-X*}G5IoXIpKRF8>)6+jl;lE#BBtmsH0o+oYl^kk;B(czJnnGw?$Oy>o?tC2+ z*rNiIe_lDWjkjbUbc_~z6j)tYI>A(j4h~yfXA(@nQj_EDu%(}(KWP+LRg6E{XQ_Q0S6>RPwd`47M?B!I6YR??geNuAnV4ZE zFrh^l?(Q0k+4Iw|a0dohh|L40`r>;L{Z`A+7&R-y&l%AUbac`M9;ro}$$!f|{ zi@>@dWi1~~9)~~Mz*^~|?8tx2dVa#Ze>8^c@t*69!8}xf>qb8P|oyYWVmqxktVabx=O+R^4==8D3~Qx4Om2i)m0k-=^xF?R}= z8&lrWBJ#qk(n+8)XkZ<2I+7LjBAhe~6ZTebE0Y^Vi%s%$FGSu%FvO}p#<7v&hWxuc ze&r2+8b9Ves$xYuKmcGP_KDJ~&D;S5r8(d1{Ks4gA_b;R76!R42Yc}?yZrr=&iC8Q zxknrAffk3(dZtWAk{7dl`6jnhWAmSpO5Yoz^Su3j5ClET^PcNCZ>~b^t4%pKHrZ;R z^$Rk$kaMr^nN{V+!1#qfAE@-M90)2|eN%ac!}$RIrUeM_=3&ycQI6`T3`3cb!gC6d ze9F#A1BdZQo*I(EAqCqr19vNfZrXjMlY1wFd>9l*EZf~X3(1u`No(MnU8Bm6m2%C| zaAS|YsZ&1$i3DZ~9i(oR)Hl6f-qBtrquOpE23xQ-pvADUTPNq^-wCQNxSlmzMM&P7 z*)!w_1t~@M*^N-N-NqaJ$L|)piy$Gm;(@t1E(%UnDK%6~ddSMj%4?19y`HYbye!QQ zSTN$x4$Y3W)0YVnNgPfl#9y%fj9I46${v07>LKdNLt8G#Yf3J02! z!7utIZ|pr$5T_?8BYs%skcC=DSOoYBV9Z%VTPOw6P0_R=vDOVjuA;89ypJ|d@G}EP zD4bO9--_Jn8`v38E7jV*;)yhRZ3k!BZUcgw23cTXQ?;LBKs3^lGi6SV4?4_Tv_^~Ipf885y*n-ylWb@igeP`a-L0Ab$1v$RNXEXZt)URBr z$MD{e_~*u3E6V7+js`P#^}5d+$y4tuoAhUCZ18Wab?m__BuO9VbJKsvt$MQ`23E_) z$hW^q4?CU-Y);KK(eRLoy~!S`Auck%0`dYIbTreTSUdJ>ye$!3p~c)ETwY+@vt}KV z1GrOt?Kxb8>2}$luLTQR(KB7`8mv7Tpa;nfjWybS-Dcq&hw|m9PkALP7BqRsBoq0t z-gxHIy#7UZ|9*gxZqT!gI4?bOPLz=TH*F$~WwqC}q9m|3VIoZf?YCvoL(uv~r6G-8 zl4UNd_h02vl zzn{iVzLlfB6Vqkdeb>sb;v*eB1B9<<-wb-Ln4+gTG}1wvvpGYdnRYbdQ^{YdAFLGRSwp{(uqsw2$P}@XWbv1dc06ns zFW_C{h!w{fUA*%m^1F@CXEqIghGbdpm0-a4xt!9>U3>IcL$f={@`IZFtAhD6{+PsW z+#&7#$2XKUEfBz9cbOb6xKEL{MQ5Bw9*6XYK`+@r>?;iqqHcFC8}HmGdWuO~Hzr6` zkRc=MzaaBtS3H_587o3Ti!{~GPRqBo3+k^)Ky347hO9nEH9$AnLeh;C_ALZ3A`tXK zaJ&39s2*Fy&OogQFcu-rBWEh1hd}93{Lfle@`S#si8~T zwZ^(XV9oMWK6JSR_X1-9{PMYeypGXl)N;#tH8_NlDe;bJ)EO1>4{_YBjNNr zcMf=5zm?}Z6TlTqT{;)}M;dQ^HN`}m$4pMnXI>9wFv?A~rD!L2DwB<`=gNbOid9rT}uA8L+5oa{8!rSBU@G^T>Zw zC>UL_4zeVnAxN)a=V{KLS^0*J-|n&tS<59kK4T)ydwC*UaCGLU4AKbzJjO6aE8%2h z4zq(pc?fNCe~@F{ zEXVQxxNTgVEeDQkzrH-wJIM#(9MM#)$l^VmJc5&$sgF6#_TQ_nddFb&$o{I_OK&)l*oQG2%EM{E zIirhx_wPvrvd5U`PF|=bBo}-+on-;$xV^|LkZ*s*mUP{1>sy<>&vAWAr~h zgpYFu%Dt5#@bc%^_c~)Ol{xmMh_Pdn3nIr*i~WXHC$Q_jd;hn`L;hzkzW@_@vNw}S z6YVoMs>`OQ4V6Bcv=V-Z;>lvvqkw=w_dx%o5E7SB1$7mJK)VmU4zxb#if4V)tpJPhDTDs&i{>y>7hU%`93)mHxopRJk z=O^W(lnQgWGRR#MsN^7y`}%ThB@irlWs}4t6(;la#7E#>OorYpTl|l^Hx#5m5BepcVC?Gp*W8?2Jbs#`SCNwTb zj*&amZJgLi*BL1lyGvYSsWHCKf2a22@2jjGz@g$D-@7+tvqc=CHiqw_1J2Z zqod4%{?$PrJ=?R9@;Pmpjz=%4v*xo^6z=bKdI33qHt<94u!JRs?8ySFF=uC|70;Nb z&YrLKx+RAtx#_4YfVUBg(QZG$6p}Gkl`c3hj_fG6$z(X%y0S$n=@Z;GV2pcHrf6cb zu$vb3_pBK%qdDp%2I)vwx!|_6D~?()TFl(H>g){c@bcXM#Otq?snYdoy8T-A`N!j|>72i^v^5a&;C8hjl)Txm z0{rj(0=2bt$YZ~ij&(F&Lv@Ec@jH3s+t`lI`Q;2q7|=|8yl)m^Ss$d7sv$&nh+WyY zX<*>N4Mo+oFA-xD_qLyouGH3!!(-UfPR{p`=Ub{PoC)&MKh@|hi3gPI^l&%==}VT# zynR(srJ=Tng=J)dVcr{A~v3CLfod7MmWP8L>?g&)PPHsRhaC3o$sKPQLcdtAPGyB6vUJZV?rs9~h{G88)y|s5=7+jB2!)yBm!kxoB z)I;%Zb{<7~*qS}Cs!cQ6yBbY1EfMSQnWB{8qV-iZ1%YnkM4g9$dZvy@7(1moe}RIq z5AB(+k_y^RPQ?!*vddHn%hU5^>zE;8?%HMVAAPGHb9xlSs5;6a+;ga^N5!+m5izv; zu3vd3dB1edYI;{9nSoMd{ajHpRSK?pxzZLf;=s+n?R+6iSBM17HoYeb+K0NzNP#4rvJaj=Kp|k*UErNYMU&qv_;t_mtPzto-7U@ zR_xb)dC!jRSmAh0hx_d}xcDgl8`x|X=PlpPvB4;cL1+`TB;`U#Uqp>xrfEFlu+#Gl zv3^bX;4!Jf^=s0=6qGMhjob)e2$$)mMExK<)jh+kRJNBbSwB>hOf#xht_?(NNiy?P3(He~|5O38TtT zitI96HptgeV9E_&b|J81VM=T>wn-u%bymtm(!0ygYIZ??*4l?R4hm|29;LAz9?*2t zx3~emVJ-V5k|o2}DX*7la?-25Ih zcZe1G$=r%=@&tdf)8JsE^C+*=&=B+VI{o_fm1~9lz$)oS=7XIl8$23(A!W4qdx0gX z9YInQh8*uqoi(>Q37k^As^3Wez23Um$+vwL5?cIX%%dMxIoIxC`zE|8j+Ta&VqvcN zalsi^C?kElB{5+v=^zzOH)ig88UrRct2@IxmoLUZkO27&Ze%)^4-4r2e!>&ri-od1S{X1#Ago@7`PyJj~p_JHea1^Aq38? zD}oVVpgnEX1!*c%oCec1h&NKCXH>Y0ImPPv=iSIxA#%EUk8iyGdlSh`vHjqqNQmK> z$96vquqISJZ0_eeVX>H{;z;LVGjV$>_sn&U+Tsr)^J~$^c&T>jL}2!e?N^wOL<`za zD-AU4k7X`OluPsdjDdq5dBzW?%76ekjWtMCCYPWcIL~u1O!csJhB?p<7SGvu5D-g)puy z4Hq@_p5efchCXdyhsU{eZ_l==A-(_lZQt<3*e9WJWfCtRrGB-VdJZ5zlMZ zpCay)s1)%W?iBB2=OS2MyE_+1h(o`kE3ikWt}XcG^Uo1s8*$&$D8n^>9QpbQ%lNXA zXfmC@zLyXt0Y0 zlp%7wrUJWjvB!P8=NIU{WH0kg(9Vrl5=%E4@qFEVgmbejNwfIl%#5eEYEJasd+!Rc zS5-BlJBTAR)T$t#)r6e~-!xI-3`)Gm3TwTq(2k-)HJl||z529Vd zkWCF*1u3x2o_@f}BOOz@PK~M#OJDO*}szC}Tv~_tHdV zl*EThGExnMfXJf|dx<+buEcfeyId=B?xv!hi#Cbe1NKBB>p!{b>v?c{cbW;k3HyAd z7-SVs=Cos_N{aVAOfLF)tODPw^gDs~DeiyMk9Lcp2X)s_FTsP4KA?%ycuk-VamG1u znss$Gk+aI0bu;Y^Oe^5!FVFM==p6dJNd6c$|Lxc${=?V|iEJc5z(d}E-_VA=s7)AU z{4B(NA+|OcQQ-RW%IXRKWUMmpr&n~(>)-kk?f`lNNR|guEJvRcV1Q<5!e1bAAVAyJ zpiPM(YG)CsEUTK#NJ`gmI!U~@U}|En(CK+UYKxxIgTGOJ_%OODc?Uom?zSZ|s0f0n zp+s9F3c2#6St-hQZo46TyMR3+!4SKh#9;ak4Uzcwi{oVK5}7zLRpu&&-4uMO@1{p{A6Y9IQlGG8Y7i!S|v(+EKBIRF506NFUpV3w0m2$|)gs zQ`Mc%X1i>!K%lobWwv;5*)4vEx=?K-UNNFcvc1l2@`&c<=lXqXdG^QNS;F{Vf7q~JBbPG9MoB=s!7gI@;9!NUfexIIt^XM@iN7{9uO2augYi=`UZW^ejSs%;zP^mN5ODJlm_xSCr z(sq)M?y(|+#%Q76=(6oA<3etHne+OhR;VNiKo`aon+2wr-dlVHOz+&I?AQ1|4gUOxuCTw>TK(Vh`Nuom&epr-E8*|_ zSuwe%ZPJ5LIbv(UiA}r6z)ZQ9jy$vGRJ9UaE0PZd=4B3)2#wn*ZcZ^$?7`zAgdL1_rQ#z$$g)!Mt2z?1PO>(PV4C26#AG$Ga{*Hw!}2Bf`eBzm zGg%w##x^`7nL1$?IbsBo)g|@w7xiHn+Q%5N3y8hG1VdU5n$4-i z$G~ZjL)n2dPv|5C7?czE3l#4JXp_jN{*;&$B zK#MLy^lR?;Q_qDRzu(KGJ>C2t_s%-ToJWp-3_VSRtDra$)4O25>Qj0fDI zhSahEaaAL&`Lx_?)7H`N6cVQ)+oshrk)b2CP~=g+v{we~v^bzR3Ao#bmYH-^%$&I? z&>elYXfQ)5F}M11pb*VB?zbO1q5YgZ7^3M8ZLeL)$3!T6VULb-L$6a`9vD+lHF#9_OLe zvYSlG?{&VgRQXh-ZI`>0Iet{gcW83O$cps3#1(~Ipk0BwLaP6|k58}IY@D!Vzd+xk zBYuHG!V_R37ty21f8_s0V*-v;d+ z@7IyuBruV{wx-ST(9Y9+}z=fENy9X(8o-EE~(XO6_@hxJbM;$(=X7k&!X^fIi~VHTDf&k z#EXeh)76QesJWb_d3t<-+|Uwnb#sdi(FJn1hrnRTxFi1)E_gom zYuET^(9#dU;Q0Rl6rr!uNMyjkecHepVGi5BZ{r0#AeIv?(eVp}2L+r-19eu+JdQh( zn5~61Is5o76L(cddhd}+o~eFkN+ZRAH}bU&A4omp?Mg|G{iGXbB6@<2w&WaKiOjp3 zUt3#4lde(&NQp*~>ED9lf940C2F#FF$It@dsW2BPd-3$vQ*%Qf-^E29WA>k2<&il* zAC7dnWqrAH{L$}LtRn23z(D#${lu&#*VXcnfzFWkb7+3_^CC5nW6j1?~`R<`jkSkjEqHTnbJ)2<%fo+uIW-ihc6jffpUDPuK=}(8AOZ>)k8DhXQ5YRSD=s&k zi-E`Yxx@IdA}vHD0ldCxF6i6py;C0s57mfou-h+5xg;~-MmDlY7i9%IZ(8U%iCPqF z>gp&%la4Z!q0DHS;TQb^>Aw2XxMUvKFst+nR6Lz>aH-{W3|Z1OXp=`}{sP&z zoby5t1}CL|C-q0QXDpvZJAZ#2U;yp*G7`);hgppL~(i-Ss(+FRG|J{8a^>J^C+ zM*2&EZb49|V%7)KWP_%JHAt8(!zV;YyE>thvV_xO_ad-Tpi;-c0sHqD?f>BKcvOE+ z{Z6n0X3V=!uciT$K+m`Gzb5{(6j&<-+apr6X=X7DKaU`#=HH#4*YNxTf%zKvS)GT( zB^WEZ)j?Z!r3Rf>;;5Rpm6el?X8!dapJI!N*!h~RUa6YiDci_C>Zbl-Za!;&@^s4} z7Qdt<%ihYp$MA=<&B`}iQ5=V`@2HkAQ6c&b&qWT@#PFQxcGHz^1B>h&UDk6eJ!g8= zO-LA6I|C!3*nk!a74+}UGrA$%r^M%GeT-+v*%ZSlOpg6TF(WM09L08rg z$r)zj6w=E>94rv^HNvLy{NV2aD3 zw;uM75B7L)fHNCJr7zb4x0jENC^>8|0=1`q`+_o?fXH zi8M3ODyv4pLj-oJ?4K~2oD~Pm)_q*Q|7x9J$W5eMM}bvw#FREl?#@OUPKx4t!vLfW z`o@tlltnS?M#9s#R`Z*I$L!p*4Mxtp&tiFOlh|WlC^-z^xJd`%!4V2He@*NEfBnh# zkB)QtO;(OA2ZZTYxPWOUA!vW&AL1D0`YW2D6@L zEmk#AWFj`6dj}Klk=^v0aoZ|22p3}S#9pAAI$E`2!hf)#@4~Fn{md8L;PSa7Vwb+# zluUgwvoX3vZ~f%Ua7xqh#jy6)(V1*I@7k`D+zOIJ1bD3Q3nNIu#v)U4d;6<(tH~25 z>M-Xq!qn0bcKVMWU-`U%{IREUF@UGrOq4lC6sZi|i_kyMf%%%Bf(KFra)R{q?wYAP zYdbMkCwu-r;BU=HYKJD%27xt}BQ=!J2^p91S@UOde#5CuEats2W2dn}4-)mvAyVI0 z!_Gvms4RAXM^pwzI`GcIfbm4|0y!a!#8b>W)V4kxj$3yb4LtmU1SrZ6=;aG zT7R#b$T#6JVSkQ|RtF|@9pWCbBUsOxh7SSl8g`MbN%UxP`%aRM-V6?7KpmSJvjV3U z8I)BWqk`a!I~k9aGf^&SfOI7GF zzE0_|6-Zm-DCp9i3*orc0zxQk1%QjqE>dL^z`pdY_a<+DEcqnk+V{kXKNFpT!D^e>RXIb33I z-`D5#Sg!v+6+q-ss_8b4z%rm zRrclXf?xMO>|(ixlI?z1I;WBN$Z?aqWr${mwFj3EYyJHh`^J1w+iEmTjhf|U6?>Eq zc#~Se_Pg1DXp|NnF83FM1yllIsx)0#S$;|yOUcTlxV__6{Kc8^AlBv9i*@PoyL}`) zlXd=x5hiDgUeq?Og5+ z#?&8OuD3x}sPCjzC#0djVajm`SfuW%6{F2#HG>zNjjw!CL~6VeioQ5Y4nImamj9Qq z>0-Dr{h7NO`CH(eJ+nDIc4=FrLkj^@ zHV9Z^Yzr?_yy8z9yw?v-S3H;Qo6ygSMAfpyqfe5SnYEQKi?ha$Tu3#aBWm5t9eu^a zIwkA_wWgFPB6?IM!DZ|vel3Rbs81Nn3zsX3^j%;1-RQG}3niHSMVjq-$=ucY;Mm^b zXhfn+$MX#1RM9nhMIA>Ij??(YmdAtT$<%A=?hb+?5%8Qj6QQIgAQCamvTQ(r7EwfX zhe9a|?9znC<;yF5)47N14zu3(&j?HVA3W|6-YV)@mt=}DTkwH*0y=FAMlQ$fut+hI z6T}Gkf3sW8i|{|}OOw(GB87;E* z#%s(at-Bxgn+V!y|A`2W0a2g{*8hCrRXVDh72w|7zq$7XFjY6}if#o^-q(*{7r?Mc zAh>#b;WhPY=}0mWaRph0{r710e~6Wkzux@8@gvA3;S|zu0W}P~sh~!R+>JDnarvs( zzAiF7z?(`RSQ2TtI=G$xYN!AE3LW_U?>E&(2D~V;D`tB+ib?Ar$9^(--?<7=$Y7e9 z%IH>bR>T@a>U=8404Y@&*Qw{4F8wpH2|L^5@0(s~T~I70y3%*1pISM$_6;D-hra3S zSMR7#J)rcjVtX$f)R-uJgoQuC-_99sGBL4`l1ah~mOZY3*J94=A3Wp=!?wYuHvIDz z+U$&NYK*A*!b*Ga=c#)MRkv_>U->Sd8xWys5qs%qAtejat@Fs;$k2~!cZ?cc&pLf{ z*u0l`^FH3Z`AX~6{h#SKpcE^^c$oHXS{A7Wrzd8b#a02oKvE{3w82b5D5@I|qBBvn zd;$Jp@U+hMEy=25heFfF5;ZS-S)rK*udyuI0nz5VAedq9if(4d-Wpzoomg3MQV zf%gRGdE6ZPWDY#lU(B1mJMA7i_}w0u#BdOuh}`a`@#Xj zbl@|}`*!$k!)JS@FQhlMCE<dKlO!^IOiPd4AFtiQhhMHNI3zK8qAz@up=SiIzeM~;ew z+lS245gFfE&$1O;DJOpQ?83f2ktZe&6Gv0Ec4;)fcfvjfA{s8sw<`l(DmUrIORA9Z z@^Pr^S^->tkF!irH2|qv1)@ZdY@@9fn}s#M(nTq36K0`iK)=f~Vx<&t?$t4eOcowq zXxR=fo?ou#9lwF54ad8kRHwV5Z~ngsi8V!2!xawGH7a^pwsb0deotDO;J#@#nhB%r zyoH}ib0oN{Bj+?_ZUAvYf}Jni9fLr&lM|(qlj+_6_^NMY51KBwJ`_bJWYx2I(fjar ze<7Ux%d#7w%Kr%u8&9XjGX6>U7kNWe5`GK+$mmxCfbjQTlL_PdYup3K05Cy+X8_WG z3;=8om}JwUya>D7g@l`te;Ua&AzTxNEXA^pTReoY`{98mC{v_^& zVSm>4>43J+|1VdHKKxg{$@-rV6%Ie&W4sEc#LD-aMHjx1-_UQl^2B61c z%aIaFDCdmJqO3WeMP+~BsD-uatRs&15-H_Wuqo?J<1I(JbEb|BJo%j%w;n*M@_rC@7&wFHvbK zB1o5#V4;hkH0dZJ9i#^dgen~b1QdiQNH39&w9pYjkltH@ASIAc10?a?zj@Dh&di+o z&Nt^<=gj;wYq6F@lD(5>?>x`_T=!MF_~q+}-=`gu)?91^IS#+gWcw{FB7;6$I~Twa z^u?U@$e4YhaO7h|@wF)VC#$B)wmxIPWPwR38*d@>-*{1K`QCae#W5{_EoE3Oc}vJxi%S!)yiai}GvS;_Bh-89=HcMD{b0#T4U=1K*~`<0F(U1*JYsu}hn; zM*p0M{7&CjgTEYQVKRD_C;B5zB%j8qSGvv#rX(|}n@{NG;r0=2jP>4NmSC z-6UzBdXRZe_3?<#0N9AQ&LjQD6izOEPL+!)_sn#>@bmN=C;iG_o9O(!*~WoJV53SN zIbQgiJOBUS{Qjewp8uuo?bu&}FzD1ks5n^w3MJVQg`oE;Mj10bfN-ZSf(^>{p~#Ar zzHvBp$dV7xIB*lsx!1Mt2dHRI2VRRwYS}}^*$d< z9&n)Es|}j2P^V%c&p(Y*1vBO+!BiGjVK+A~(r*6BvO`RaiRY$)RDh`2z89vB6pqzp zBU@qgGY0K)r=Qt4PTz@xD!==NDsyVzJ1_e)P^9}9_6BC{#&M%2Dj{t{7o0U1N*T47 zj|2ImqYw_IWEz58d%}79DJ{pO*+W@lc9NUS(wXoMMGcxcHTkasA|FfLm`OM+>J^oh zl+%e}mOd}6m~oe;y5iAr$DsB#wRQKM*B!1nhZ6pWQ$m)$xhc#V|`EMladIn67 zu8DIyHPey&0##WqgDsGU>8=aK}v?v!XL%|LF>r0o9kIIROu!-GnYwzCq!uRl8o z9-l7|Eua*79hZ78IFG-4;CiVlb!<~tba@uE@VN`l6UzEaYAEQM1>!#GHz=3nMOGUj zi{yeC1)r{IpUkqt%a=llJ~qcCVTX~^iL+{q&_u#ACLf-Edzh^L3HHveP$juTm(D46 zA-=;)pd9_yhU8*{hF>S8TE)}@X(Ej^&y!4Cv$qjxR+-PfwzPb!$h!SiKxPZcu&d0` z9taxra$|H9D_VVVJHsjO+Wqpcog+V0cwJj?wWS~%vKDb5O6jI>%14sa?Ep`+u}@wx zK1)@#j|ryQQR=gtuq$8t!AjPLKNC82qncTW{o#w?m~^!aK1|e-y))XYQ0Qg}jx%2M zZIXwHF-+V%Hg+*k65ea28wq75S>>*)x9wYAG|8u<+sUD6JN1yurrQ}8+zAReuo7KIaZ<0$ zARYp2LI@4XV?e)dD0OnNE^|sgTl2EI@B1yDW;+_T7rG+vJ%R3MEn$cHzzrKg@dq&6 z%L<^?Mnj;@%YAQ-pTHQNYYJ=(y#4RheF^m!3QrHGY?MvT=)xiA~*_5Vgs1xri$}^MP@t6T*tHDkOz}`R7b8xxi zv9Av=aX62W$7LV_=t(H)l)y8LNN%|Avh(UhqfFNu*-kzVE;mXi@GM=NuiDDF9!;+Q z2%#H*vJq_OJ`yVdN6fOuCxs+k{f1__F`~)0@_?E-BVQw);5i#3K1J4F?ivV68ae-L3o7E5x`k?q~IhlGA2w5RYlOqng(Z zyp`aLAbEk<_uzP{oEq?%0xzyXg$X_}jHUUATdyt^ebuHZ;9O=={ycpUrd6UF$b6ts%AQuYtG$`s#jBYx zSIP8&QBqxo4zE%&f1*+E!ZUm|ET|zY9osD9kQR7}lC~HCkvlQh{hx zibwaP{fSXbhwjv60OsRY{!B^5&&dZr-9-I0^QY>?8(G)D6ECdU3Y4l_x-?T5Zuas`+niw^( z1uH3_-k>dJywR5_Z=@-Ah8LaF4n9{mD%9k9`#JHPV|EknGzZ@?jNGFW2vy-kRxeR+ zCM=dNnZ7gu$}we;$xd8@#*L1U#dA+(fH4W5%?c3JHzBTcpu*=*i4fHq@*N;f!_D32 z-^k`Y58W8}*`@To^RdLqg{$3Km}B4=J>W6qdNqogpw&ysgDuULA!mm%yUt!T`ukyZ z{xu{+--|SPM7Ge^3MVE+*gD7_8GtHI>^BJHQZw`-J#KzqWI%3!nh$fr{iO0bAMBT5 za06ujKNg$M)J$?}6;x5kZAXEE!TtAz6tKm;!CE zVgHDowZK(X4{y9{{IUfB6-mXl$$agBwilMsA7j%x2O^9Wp63%iRfp2rr&8L?G`(|s zuljSOZNBUH^y*<~QltM0XveCW@jVdql7f&Q-;&{pgJ{+?)zzT%YA>07-&CsP4UW7N zh-1-J$;&0ywwBWCf-jQ}u+3#hG8>7fID8*UOEySs6I0qfLlb{^z_~s!wqt#PZ6P;F z7i*bJIB4TOr?7O?l{MTE+ZQh~e>r5J?n=}hSLaxqfIz`M?4&ME8wI{A*TLYVt@Et% zFmaOI{*J4ngvr|#SE&J00^dtm zXy<1m^hfd6ibM+3o#3tbW~*}YiP#Y1sK_2pbdvr+qL^ev<2uxNdwj zA`88LT0Pw5xtjIl&e zFkD{QvA#kJSZ|WL0CoceEQ8{7c-`slo!8Ct$&VP6xFCk@VN7SX>!vq5X0jz_V-A8(K+E`Q!k|&T~2pt2sBnC)|Gw4XW5Ac<2 ziQ$XByPyRB6+QG>V|85>e9}Yv zl5pm0&j6gHZb7EE*nE0>Al&Cl0Zlj=}A(pfl8?>4@TT_SRG8mz)RAs_oB_5 zSHf+!7np7f`Qt#0Lr?*PDPb$tD}jM*K%gX6jlb`7LFQDUFFcBU^4x?$;^b>dpLm>r z{E=+4I?4>_5_n?i!WzYp*kl|ZYbOWn8PnWlhGUDqBB7lNn~7>0Ni){WK`)%v|0CSu zf2#5H_cxe!lDnaj!%zry%((f+*wb_;MX>lwv!bf}@5Af>C+y-H?$(p8z?367^VIwW z$cc61A;*^k5z^DgJ%8X1^0q1aK{=lFEH6N`B0(T@{-X0*JbkHgWat|^y3@&(2c<^> zE2NmK#WDFh50hd)$_UOiSdsg67k#^6=Mbh;01gcEnA)B`7VX&y981jGf(6btnMw}X!h7$Nf{#s1{^0xBJ8zM{N4+Sj zND#emAHNWbE8=~UzZ((U)QY-z#$FW|z>_x}bT01sxhw<09{N>dQo z4_{8%&qS@tvifn{-0(&w-ItZBb-<+v`l5yZc;#=9Aww60B>NkLH{5T$g$ghV%^zU% z4S3SGZxLt)xtg^!|F|T-M#?xOM*#OxV z^BYtYjsuf3ILN?;FvWGMFB-U|2K|{E}?S$vJ@eVhe*$|c9}-(mzBG$ z_#HsKk`FVMeRb!w*!r1M{CI6>gG59B>U@UlP_7@CMeaedDqTFocjFCUS2N_cytDai z4p0Uwy*0=A40h9l@RoGUEAyrw3U@ju=QTWu@A=?9hbKh=t`aZuIi-i2aQ;-g6%swE@zni$(9)}g25*veO#8>6 zpGd_e@Zj<`j>;w$&BGzsCR-3g6%7wx^b zBH3$FIgBTg7@X{v;&Q?g-t;bO{02#RPu;AtQ38@~Fn`ImO8Co%ImfItdsf91*OaY5 zC++lEOzBGw3OeqnyY{GB+jOSZ2;;fA24F$|hC=++s~b>Zn973~#7*2h=`pZlnAmgC zIy3w9W{+5VLu23rAe`DTeuAne;qL)zZLB>X+*|n6F~C*IGyjTUa??+zzC;m*g@{j%<+^@?^0wx0Fv5Dm$d>@CB4W*`4*PFAmemNSrkvU1n4AzaDL9i`baB# zU;q85(dzMmO?lmD7(fssq#a99SrnamQFtrGZ^xI_juCMgwSGzJv4^3Y1A!^Tt>8x6 zJcJ6Rq@>XzM^#$Zm-@V;&d@|XB~M!Yg_`>)qeORXfauoj-_!NFd!GOf_A>jmr_9ry z?h?LqYf`4^{Y|ZmVEyVr8m^O}pIpKcO%L(8)yi4)^%kRj@70l}lHpMc`=hev%m>c4H(pAGqrJxBXBBt%pyL1e+xvgv z*uP|7%>s0G)!(4eJH!a;dCallpBi3kAK~plM{S-`gwX`{kI2iTw!9UPkfb0y-SW3SaR{jBNMVe#)9b)8_P`iR00YtW0oY z>(Elp8PhRM$E{(El^**4-@Le5C(d;z({z`}x@36u+u5dMpZQHFGMZ3bq@Sa}wEUuw zOS-nAxqGa78N=qC0D=3r4Li+)@h;+Z;H6%^k1J%+v*4cm)x8R+8$68D)cUIb2Jxn^XPUG z>O2pwHYNF3-VPsdf3w;JKKn?=+r{EA;@@HJSt}ug5Czp}wwM5ha_;h+e6IiC$hMpmTP)q9O{UHQiPNN2qo;Gr&2IS@94?oKowMp73;i(x7n zQF3?(Zr7ay{PfQrVfUZ!SNQ4&iRw7WD<})lsYbS`yZCuhidX!so@vaxC06=uNoo*L z5;k00Lk&%=K!=|&4sd(UEGOFK2IRYC>}$gpc4X~bx5SL)XWX;iL93hsfsR=e0!7l; zjZp?_jYf3!bT=#w!&7q@J<&h--n(VDOxG=q;)|N>1z4wI6nw+c`Y*3zUboR(%Gvu( z_&iFJ^d9sCK`ypxAQDoYk_M32^IicwhCYDBUVLS{sw;ByOuxd%E>2=XZeZNFP#&DF zrCx0JWmeqcciFaW-i`FRe$HH-mcprDL9rpXTbEcwOUs(C_7~tBhv0DeE+#$A28<7o zjD4%l_iAxi{zIdrH~DhHV;R%r5MX_M3k_O?3fo(X5XD1e#8_MFBleBTJ(mtSpkA;2 z{Dsw#bN9Hq9-6>;SLi~Vmi~&b-TdS4nrye-fJ$Reukp!mP^cSq)XJ&--NC560)lAl zJH|U;JN~nL``xC_Ui3Fa(n~YV?Y#h)ej992l{$>v)x&I;79fd5fA#PH`4>l2CttxF z6d zIbRe|Xv5w$+j&uzKAJCb87Fp0$e*Z#{l)1lD;F~vsLZ7-)Ay)98Uaomh9XRLgD&fA z4(#-V_shBPzWBg6y5PKnmY~}Ki>P07Pd-sU%!z5 zV_5%8W`oh``#izhmCBtzhNGi*57kkelfx_rB7@K|Ee0KT*M`zlG&QA1OYk;py3zGO z(Wt!WP`wqgUz2uVbSuf<>FrYDr!er_&YxdVl6zC0b#}Z#oFqe_3hjWhM_7`AodBdW zAnIaGO}xgGyvE`E=4S~zauzu}YEnIy0?q!e^|A<7E&(w-!i=q29xPXVDV+vLCCDk# ze2G7@*r2+!*W2j;y7}+x%9a0o%;GO3?PD0cww+%Y_+tc-&LY8RC-TDsbc=XiQwOz_sH3IpRIr{EC+2#<_vm(-xm;t+i>xLAu`MnHJlsG@10 z)E?+>eWm=j&0m38_J8Fy|Ap7o_+PfSifU_&>f=(nQ2v~$b8B+Z2FhH#Y}W2;jMqjl z(DaeiTjzKP!ciO%N>`s>mnNCa+AwY-hWPrlHTt?mMo_mqjnY~ySu(m_kZ-5y>SNDb zl(y{JGnbBrXV-M@MOpQyQ zV?Ta2c+ryJSc^rTvk#Q~QPwDf)po0KsQq|v&<%cY(t3SI?OXe~4+K>=Kq5Mp4_GUm zUx#wIopwwW^~{kz>*_L7&TnpglBcpkw2EnR=1jlm_va+3#qlsDnhyhusB@1GxxM=m zu9s_U$@?T4@<1J2=?gel@jYd8(I7(mu?(42;oT@~kt@buc84QVZt}qH);3~5oUaBy z9Y2aEs$O6gWV#qH-(qOUY51@O#0Wa^DLY6W%|lisxOKQdm8H5&pE@$Nb!H1zX!7vq z3f?LC((;v#j(53kR&L*)1pCx(%N7xlBr?4DXqzi+II4q4>5`0+`%L%x6MjI^X%et@ z`r@Jp4A2llw+P;=Te)2w=U}7#h1|o__t@at`Pfbl5Y2YMz{@p2My=h&f0B@pusisP z)P6taQH|P6@OD|;T70;L=2lf~{CUs0vm`mf{30H=X2~{~G_^PKu`WqdB+X%sR;$mz zJt2*2%Q90Hgs3LgSVj0+UV7Dw#e|?;cMGhIs>e#%HBWpI>5KaQh2=8(|$Y~;=&f=grhEN_1AmYzQs z0eR?x7(?DIf=`f)Ta0`akYOe9C?{VjbEP{b&l37v&#ETAy(rIngRa_pMDMf@`jbOC zHO)JFU3T5q#-B%3`hjFwTq_+}U7*?{@HEMG0TaRPJ(sd5%If5Vpg#M$-6LTWhuI#? zh;>isZTAGF!EFlc{AH2&&=bnAMqNigsk9jtM*4obh@$JK5^Y$ z?W3UqsOW<=hz>;a!{Ch`og8-l4BWR{fp?CFsFOeKwf;&C-q@WWEPD?n$3C6B8c}wfj383MuZOtMUmt=Y&RdMuK z#bWWfuUQ5^C|!qKzBYV37Z2Fz;sg2Ncd^hk{sbJKl9gE-^%+u8jo5rM#p_yoe2%5{H=lRewt7e-D}U+K#-#wBXt0C1|e45US>~) zs5aIF5UPfC6gt{pKeYAHyS>EwjQ=+JIUtyKLZ5;^-|4IgRQOUkU-npK?6bLF()riT z8)YrijK2I@$pfZq_Q3O8bbQ|-c}aeaA5(CW&4SRSCD%WYlS02zXrZWPsRRj`Efwvwy8Q3>}vgGNPSFrP+V!U3- zO3s*TdFAv4G8(dY9>CxNpbNGt=-eGaq%~ zjgPm39=0;(^FCv8@AL-VRx(Ly0oew0G}e&QmRy5Id@3i3l`47_vlC-4rDVAA+P=_O zNC&)m{@tC2NRi}t%bD5?xgQ?U!h%@Cw)U7B>^rWx$0zFV~Qs7@Dch$!3Qu3iMxhb(+1>c zDG*|c)t=CA&`+)AknA960AN&C^cz$NoZ^}Rvda)Tk_L#auN~1xgANa`gonwXDA+-(z{> z7p_=KFYWp$>Mg)TR(tYK)I8Drt#g}jZ;?ate}jSmSB5*)>wm5()uTiH`zzQ8R`Jc( z5%#2RVz)Wz)bXu#pDb`SL7J%-3uCb`=`K+6mAfl3Wop)QT;uFIXT@6#GPd3UL3gv8 zxkz$EC!9!Fpz<8G4^wqI#}2EbwWUR!{!zUy`owO3+S`-2T-OzKStHjxuIa@4`$(IcswVH@|TPQqjaD8~DO>#@6 z#oD&KitUM|{vr%Q9iFd_M`Oh~$ZW*Yw03aBOCj>q^Nj0L2Hd5a#bPx|lb=}0CnPx3rYB4u99sHLUh->fIyQ5~*UeVpZndtO zBU`MZF@y#xGu#b}-nZHpCksj_&D-oyhnF@|u-zFI%9E4tH~8GGuKxmv=Tvc0{(^2d zG6KW9`_K-HrnR5St#tJHtjUNfp4PFs6}>!D2de#wXlXp(+m>X2_zgO~L9*E&Afh(9 z7kMz@S<-`zG})^)mFl_EI#2j0UAYz5J3}+yuYIjq>o+mD%uaJ#*;57b`HgWeq`bRZ z*?&KXBxH+&u}`ZyVR?8OGNm3pK0T4XJd%)UEaJNK?7CFYWH%?Lvj5Nism}mZfB2J_mjWAbg(;y8@&Gr5HICIQF`Epz=_=(k(M?&c=@KH2g>sbBhYN zQhw76ltoT_`GowHp&;!R1*tN|8)N4$%p*<^<52C$h%9DIvS#Q*)cpv9hS9Ck7KgX$ zHz)b(W*Pt#?DkSp0nO1(RlwRqlI%(f#Ws0%&!47BwK`LwBnNkT^O@COkIt+0ubxXi zVD56So>aMxxza99??IPox(XxcfunsCBH~KPhPmX+{9c&JX6`hd+L2uQ97TP6n%lCo zcWz&<8HQUqsyMfJ^`~1TNDvUb=j61wO5BE8k+37z!R|*CpB{!&^<5 zQiOt>gS$lS_Dx1%P6N@h)bpfvBCM0!Fd}46Oy-#aU+H@n@0a{DCBB{y1SrFr!otj5 zt#@93qD@qR1Q>XHy%e=<-N&rW^r4dX+yYGxGz24zu|r29Yr**2goB4;(hwd8{`1vC z-aHNWMstTRHZ(M+$X0ZS?VouPebIp9T7PHM3Mcv#@);8NwzB{p&JZT@d7{_U0M6=_ zY1Nd)?9cQwiw$~^-VAHd2~)XWMf7;Z+uuO}X9;fD{$yftD-&t!-jIwk%b4FvkeTCt z+t@E@M!j;Ce%U1XO9>yBm@Nyf(&oV!2V^YwZxESvy#vF)r)+r@;gBCt`L%)l^EL5g zkMjJ@J7@YYB!1g@*j_}Pw^Ser<&(V!8{IyV%x}an7Bhy`JIm=JPnh}TnQT5&WA62! zdt2U7RK$Dr+U*bE5mV7cdtiNwypDMqBR-X8H2l;>Yk8p8>`d=<$+f48CDEa&wvm0Y zB*D+IHiFv-tp0D%DOy5)+woRh`c|BmnA$SnZ}$M#mgL~xlKS~fY`9IMo}g3*N_J9C zcyG!iLXdX1E&8x_|1K#CI}ih@M1?{x!Lh4u-rt|{q7+~V@blIh!6i>F*SuJ0Zo%yrM*j;R-a@sUSlV!RCqSjs$BLoiZu z5L^Uf94h*H^Ch@HR@&G`XLNBGJA9+Gtdzq~JsSMF@LM6B`{8Kl`I7tZr<--7e zm2jY`8?|4Iw5~m)qR2H_?9Zdg1Cdz?Lg8w zQfYxt&93QG2!!4AS#2{nMvA4RloqjiBLzjX%xvD`K=C|fY-Zp0#3QY$@;5q(bS*XW zyh^qOaYI#N3~( zLU;t9g>pAnIxFc58a^ zr^r3b1`qP;^}r$A2IkrP=b8){nM1$LkCq=?*o}lO0JrKypvqu?P9u=LD=ph2vEHdZj%FYT`57kl0E6FCiMW z*k5lkzD-|m9hRq40tu8dw9SXr8q15#7Pq zC`e{~1vWX117-yAbugFZZ;`SZ4)pU`#>hdUzdJSJ?~?iYuBY!shn&RykeeeCZm*f>8sS^}>7)z-ZeOX6 z1Y`o9%;f2oE$?=bO?go!hHXq;GP~YK%fTdj@1w6>B0rftLUX{zc9?PZ^3YuxpMNk4_TR-||7+B2=s)dRXm=`cdwn}({n6}qy}9*fg4Vl+ z5(2Mb3-@(#mNJ%2L>yJC`O9yR>->=po2SB7JBkzT*1pd#GA4Zx(*E!`_GMn+uf&&m zi{DlX&@1n;ahkZBEYU7kwWwMv`>1nI3x;|n^xh|Vgt3@eZ0F~DM&yE$KwQE>XD`wu zIc=cBk@UMr(61G2LMKa2*W|sLv7y*QljD-!IeqfkY06eZykUw#=g88lD97o^8yjm_ z{!2M&5BF1Kl80`Z_3T>}|MG*h0Awr(b~7W-LS5*6S?x_-)ed%Eni`}c9>;4-5V+zZ;~+1@WWRGe~n@nym)4@NbSv+?7WTr zp+VSt^5a4y_!=vkQM3M`i-)L8n|76kkEQlgruXOZE^@FPe6`g2cu#jNAb{f@B8+!Zj0VgT+F%#F zT0gGo^WIfh+1O5caaz+_5SX^v&A=v@cFfzzoi+hyoB9X;D4pehd7eMVhfMbp%K#wy8j=L2hFzh>{As5T?Do+e7RW`pmRazA zACSEE8#MI_i@WogSoL}x$~|iKFyle_J1h4uuQUe-X-~TpFEepG;W`}l#gBRfN)R9+ z&`S=XYQ&X@T|i9pSk&#Y48!BH9OEvzYu`%PKEI5GyabUH7dYCH6UFMJLvAwDh`7>S z92mU1V=jYlqKp;hPq#x|rmo)b)dVGm-ZT)KR#svAIsX~c4Gx<-F-nz!i+2ZJ{Y;q4 zmy44zBdO)O^S#$s=}^1x(Pzl@rpxz=V$en0X^b8M0&u~}(BWxK0wkj9H;8fa)~S!f z_@~us))HK)LVDdZOXIEc708V}TR-44M!$r0>w;Q+-3s71gAp%Mbus5M*iEvF9j^4# znXZfWhEz1Doy1VII?s=B!9-heBe?Ot3vnXtKS-8*YD%Ac(;g{FYCLs5zlg|RvP%%F zmmG83kRfWfplBMIvEa#Woa~n;3j^$M0pq5b`rHnmhnA*&bI*L4v~S)^G+XXEt0-H0 z*D8;n_YWaI->2KaTRO8Un=YK@g1R>L?F{}yf_u+2?-PS(+b?d(^pfTmIodE$9=Bl~ z;=txzFn!`ww9cJxJnsbbJgY5LuNP;1;N{_AYyde>fpOe){ltr+lF34f(1l(v? zU&TiIxQWm!dglSdbz7cgF5WKBuO(I!X4w3oKsto-IMNOk3&`CU=GjRck1?7bM{fvJ zXlM9Yebh0~Y?d7RYxEBXT*EIMy1zcXkvSdlK-CNR*@~Pl((ZOd z1XTB0dFTQg!g$2xA3SY|4n|Sg(&rtW+0S$q)_PwXrbIQi3oM!~w8`lpGUYcY#0J@OmI$N7bJtR1YSDw~ zD@r#Pw$6}*+bqOe)A{0qZIuSx1-S#NjLMaUtd(LDuNC^V^fD(0CwF>z{XKyJ|LEt? z91aljsgJaYmDJAf0>44`CdWHrl`E}~$%s2+WG6!Ce&)zBAwP~RS8UG7zhjc2Rnb$M2BCq_57H2g% z+-_Cr7Cz9i#TNMO=DgF*StpgitLZB0$wlMv3=PaettIQ0iH z21KU?A_KbX5Nl*ly(4m8$`SRAJNNh^APn@P6k|4?B>=v@zc|9le2b8i4&QS~;`hpb zT<15=EZse}FMT4}EO$#I#h?-Z(URjG>S3Wuki~V1nZXZf(esiAmTz;P_`680MW=~7ZUlb>2TXk6-Fat`B5%l0Zpq$u#4rK>4>L_|S7?a}-7TgE37CZ>g$q!g8mHb_#^oI&iHp~V z?m89E`b|7MsOjok@@43jM^^+2Ql}u5wb8KW2~6-;y{&ZjV6&9=bYlK75 zqMzm@K@(Z80tFEFkU(S)3UTW<2<;|t*D~N#rsH7C&(Sefcy!_HMxF+ zHcW<)yM~$WsA^2iEq$-snF{!@h1sFLs47%|}=@Q}?pQL|Em=?)o&I)GMVvUw*;5OkT)5r3NQ z#5)`;f_jP+lUo*p+X7F(qg#1iwicMFdqk4DezJ(VI3t-_4?ci9Gw%%S>;Hhcz=G@5 zH-4nz*=jy|$vE{Y7RO9$=A~RP>q=wQzV>x>hke3*r%Y1`T!6F}XZgiX1B_HKf|ZyM z=Y2+6`D~S#>BzjL#%D9ZJjE3H?Jx8g=>ka&t|wX52exUgm@re9<~XN(lEB^Ci1x~i z(wjEJE^i*s+)XkESW_Nf62oGllJe1!oA3Qe@;A?_09GO^lbCQ1`FBVsmR z9>0%e*S%@xAP}M^DX*kQ_j3i0S(pb|P=5fXGEgam{;}$KTc*{07xsfk!(~DS)hX!? zcoE6z%sKN{FMO|FN;E42DZmK%$c0RwEGfdPh>uPO{MW{-|OX*sxBWt%h9y9^xM-eC_8o3LpFAf1Dudd zR6+B+6u_#%^i*k6e5KUWex7xq5UFz&s8op_My%%a&Lt79%V~IJT*A4)Aq-nH6Pbx5 zUNj9a+R-4g)keu=%t}7{QZwxA{9tIvT-ojZs(ln#NjwkkDw=ezUAHQF`n$kK zPwtQje1rT4a0tv1!epOyic*WQL-D}7j`$$AUrn2ZXs_opn{^9JH(D#jg+HtZ!dF1E z$(?~o4)v+15{I$+=0tC{y|`UX<)#kC3RUSx1<#V?9sv%ZAth%|@~4aPQUvf*>FjI)FA;^+shj!hkp*BK;dA zJT*73Mp&+r#tZlJJZ3qm?PMEP;#5s@k=Io)*6sq$UT5I_z;9XSnn*DOy8R3S``cDG z#BFQ>>%?@bOOZ&{D2sA;#g+hc5?w@QJr79yesID;viQxynVtq3FBWT zzLDUK<+HJ*d_X$JiQlPj_|fPBuOPh~a7%_3o_?K1W5N9D_WO!#77X&m%IW4sqFXCC z)KYnLxPxDoARC2tbf~3VJAP1G_cokUZ%8BjEzJ)Wm0h%ERaMpJ>Q`ND7iwy1Re?BD zy~xNM`l1{VNFMY8Sd6+9krVz{AR(SA4d?0FB5-u;@NUmePdKhTXsB&cL(cX-Nn9>` z@~xG7i;gyZA4wbdW9~{lC3JT_1_n{$k&AcP?45G_9QjbE|Lt@8)T75HvMDVV9U$;g ziwe5hdn*V@+o)fLVkH5ddcj4G8_Q}9UxvkU-_7iNnavUE_d8V?y?4pMUJTlDNw?B@ zyRkHW_?k+;8S5*`X^u`{m;MWS( z8yQ!V3@<&d5u{O*#h(m;*^xz%#g5JFgVj2L!yDY0KT{gVjb^CJegZ~GP`^XY#DUxw z>EbQ0;5lAEkzEPMrNC9^xYbDc+d|gfy1_rXy8Z8-@4t?S z`Oo>lv^#0IZgNV1Cv_IHn|^I?J`~AWBBM=17{D4$7c=B;i*1Vh^vOCYnLK@*F#3Q- z`OLO>0&LrGZ{QeEatzrQe{~|59_E@0+nVOWmQ%*n%69Gc0XNcL&u{Uw$GLFl)3Mja z{4{bv1v)@Z0f`Vf1|y9~tl6S_-YS0CL|t;FXF#W&YO2ABzYo3}i-TMz8O2=Hg)B&5D2NZ)AT_`wo`prn(WI*_$bajrNx;&A9|mV zmyGXIMdEVg2Cnv_MhP#L@*6D7cD5}z$P7@PMXizI0AHy>yi+ObDrc#*#T@={x*@QM z00A&FKtb#VJJJygj5A&lZ*e=kb_Ag`{syf|S0Bk(1uU){zX43peuJ){;z@vfT)-5$ zhexjJD~az%4hfG!d9~e7Y$~|)$z4C~S-TvWrL5G4a!%i~VsNUN6DDk;%BJMK8pC;U zmADoC&C=Haoz51f({t0>cD|gcX9sf= zLxmjYAQ-voVZ@WpzL@UargZ$<$hQ zN};%x}#>FJe=s`1hfg9>u!=|0o4DZe~)J^$p_oe(6w z#$n%%6nw_kT;FIV*#|#~cpbLDa@u{u7Y&&05E%?enTw~0YUo8}{noV8SU9Y558jyZ z;d9NP@cd@olKXmWEbHosvBOWmeD9x6m-&0BP3}f8W+kpd__{vH!KUz2v+yAPX;VYR znMG{vd#v1bgG=YLf}{4>rpAAAv~7iQyhP$xaLQuMm*Hg_B&6Elo^V~((P@;Pl?=~D zT8UUo(m2WIidr{M!wdT4?uO4%6EMo+^0ekg%sW4MIuoq_X>%q5opnAXQsCVq6ebdMO3 z#AD{p>Ew5r@t#Lc!n@m3G*i$f5^`~i3x3dwIu&@a2WO$KOZ`6RA}*sX-=<_bl9np5 zfuMrk6_`;pcr zmw-Ul*1yhrYvtSsi+8lt%y)zt<)A#XrT~D{<|CNC&s!NT)!7B@eCAG3BTTs66R- z(e70s6pq2h=a)I^PdGBKPKkA7S-x3PBHWt_HgvEpW+5@2aVJvFGADL6h4~_iEy2x`W>7^)GDl2KR$t=42P#IN^EWV%$xo zTg6T}9re);%Pz9hc;?4+j{*nb(`ny=-FtTuejJ@bfQd2K`B+!ojczO-ogPXpqPu?o z%~VQLRk%Rb-1g*id){Z8_iOK(zu(7jKyL(&&$Cn6Cdpc(WTA!c^xb#s?7eDL8;!FD z&1_35KE=3}bz{J!o%rc1*C&zJY)453U<|MZ2}O)ppi4mLdgC~o`7Kr-t=4BAf?w{}mp1eZ zq^r<)1REzOC(YoeP7k!Zks7S3jT4|(dXDkn#oT(5qK5LrYC2KRVy))*ux=WGf|Gi5 zlOx7UzR^HkZADMkCk^8$v{>=3-ubXVPJ-C6{gGj6luu3C%kbGs$ytk;z8~LAB^Iu! zfgUNhbNrOiol3w$%a7T}k4K~jxMXv%RuKx6tTbD!*@;WZdc#Z2wM^!jL7?q>7FO}< zR&UEMLMq=FFZ*`G=n#G+TEf27l%g7`p_G(iPFSC|bW#8`U`f)7(+;+eSE8qebbS+L#EyfU+)xEEZzLtb=DbqW{RvyDJ5$Y!5yJkwb}*P8Mv`U17Q8rv{*df!=-d8z z{H0%ZXIP-rzcWyaI^u)k-Nj7gm=g}VCh>b=^!;Fw((C7awVJk#)$3}Td`whxKiOx4z~A0i zf;{V(ztXM#v$>wUh6x?IryVU7MyTA6*Br#7s{9|LT{lAH;iW6`alLG3?%I8L8ldV^ zeW*%kSm>NX9s7|6wA~`JL>r)x!KwhwGLh!s{1J}tYR5wB3@-$$seIAQcrI=`_NQ^6?YN?bFE!Dl-M#HXiukLt(9$dqF~Fk0^&T+wE4kHN2 zeKav3ZNr@?j!lc)yFxZ8M#zoq0pU0LaPr5(Gvr`fcS%shvQyY>TjiCN6~XstUa}nJ z1Nqdk^g4|1IcXMq72)h9JF>O0WN)VNQ0d9itFak1?m3OnMBW12vYtZyj zIVw-gW6TeJ$_WPi4a1QAa}q8;zOG&_c8^U_VpX}I1SYb!qFAn1^QpHgq&wCCV*exC z;cSI7XtDC;j=M!k#xHufcnLbdcT9&}Fb~XQ9)wElcsx_bHu`15h z>e{Jp@qXCdLoT8mHm*BZiK7TUCO(z4Zcb{=VgC{!s-5DBkP?i!B4(dDIQ^EudW~y_ zul@C{?r|h%vlyvsiP{%<>4*2(U3-K-+Cit#Y+B?0u=nQiP`+*d_%KDvlCnh%A=za~ zgwY~NB0|}wl8`N12E!;J>x59sUR0JTYlI2?>+sX~))4pOfV^<}^Fii^`dx|@Mvb`gDZaUbe*hlUvC zI*D#~ImjCI&io9&+<%g4xAQq{{?(#Ip=)Ek^GybaIWhvzy>HQ@RxWktrv;ZM*Y(_n_V|O-?#E^lhNwKUzNWn9 zMUi|))`b;Pv%ZjRd(-WeH~h`K&F_x!rKr29nm(q0yTgmf#>nN-h?L4U8m~DvurRQj z&%mujwJ>{n`of)kssUQwY#)dC%}{craPoE3nTm({+`)^|vY)kzy>nfzE96{%d&e@s zz_qK5amI#|jMw$_>^e-)r^aoW+g1sVT(5O z7MpJkh3{9>CPIfXHHgr?D6P)Qy(Hb*wc_eX;A>3Ad|Cm^P%0t9n)TL4Vu-@*3;BrPH_vNDe4yk_GUal|57suA+-%9l)TQ&Bu z2Pg=&wrIk#jO!_co-=;dW0L`>_P>L`A9uTj7ues`_pXm#OQ4rl|lXKvVmI zxliLHwUEyYQ7rYm#uy)>RML{3j68{X7ANz4EDS=<2ZcVkYk0-#^`qd03H2DJk~$}c zkAr&&R@s?P99J4HzxPV)FRPC|u5lUaz^w&Sxt>yX1iCVtzuMxUOBf7)g$y@;ZZ^|v zTwY?TASa@|W7o~-MgyJQsmXL2aT_hfZ1bWL($i5%L?oi{<`A27`y(X!G~wk?oafHU z{+!;JaPPz3B8^VZ4{OY^cg7OD!pg6cb!Z=H3f(BgnBGNcqiWk)u8!MgCmL+21*|Ey z!e>kDh{vXH>0CaScIla&NZ|7T|K}v5Ix85lA$Vnon`A=tWPzF9laY4jylQggxy=>E z-8I5pP|KbES18KFxW(@v2f>aXkYZAcZSJgwNSEx2)6wrzN)8X=zkHgUc);FpJM^IA zqt~&YZ{E65ee}I2mW`-C56z;867I$kZL4vGev7qm{;RFF$4W-}K0KM|z>Yrcc1|lO z-YZ!0?fqjLRFI+eaX^Uk(=y2~4>TyaPwOr@?auotq}#%oaY^ac|6EzS(Hwtp3NR%0 zWDzReQ6ov*w)#itJ4KPy>;{UjSS#2~Zs1Jc)8?_-WSL$v%x zphafP>ba_~uxO-K-d6r5d&T^oD3l{BYZzsRmT zW}Q}?g|-zw)q5&*j7{WN@u*zS(wTz~C$wSmfrZdQKVQJUs~MChE(U7srksqeFN;c{ zK1=Vlx*hNO09o`}R_glX%(KZ)8W!x&tv z?hUiVeGbzvL;L-Y^dt9I`(#DlpuUzVYK`DM>X6V;SZXp=BY!Uceov?qe<_AeJ}s%b zUj_Oc;$KSYYrqb-7}fCJK)RdYwL?8QIzRV1mGmBZ%ld+vPArj)qaRC{f~BqOCc(zI z>u@~%Dc>-)s;p@mk!%>xdGR7iU;EwK>J=0DpWBW%^*?nEE>)TLmqA%jv6a#k$pTVb zHc7WSkuxgfPTP!gI1;O#Ufg`uMu%&_GQ2$J&DDnna>@pbGos!he@6SFTFGL_*inIK zB7BUkicNr&JXe1drS_rYWvAzfC4tp>Pjj#5+KK`66ZBXyG11b#LKc@0JAYYP*PsUJ z$#=wQl8d7nUq)F<0{PShi`$(tJ~{`ER!4V1HDMR+(EwV=MdHm#o-b$ZT}R=pS}w=Z zZn4+&HQfC$a-;4P|6F5l13%F+*WZxBLQ?j6N$W>&c&vfU;n|%d4*G1CYIyFp^4`>V z2<>9($AL3%vj(&@AIKh1cW6UFuM=Twk+Fss9YNtomBoj@X_)NoPtVVJ9aBp8@$I-ks_9f_WDe>ises@+##oIFO_uA* zRWl|8FU%e&isT&4PcY3VA8d8_SxMKj;4o_3K%wZI<9Nz>;EwYl$$+@@78*B^Bl%Ew zkqwb0LS8y@WkxzVWxQGLn+8r~F@#6XJV?Q#x{zEp5tZQBh7EB7i5}7io4%VJ z8~m`le!=4T_@l2U-c}0xcMR<-cO?s>lt{P{l?63f+tJ*!y`6Q5l}{a=7ZiPYCgKOz zCU+4gxLY31qqsCLvpop}rpO8iS}ADBdWaYmO)elLj6DQnS3lcvDZAGpF>J_z%hXrr zzPbBBKyUiWa1M~u+y`b< zLX3s(0o_=k>YMi5JK`84jxf&YagM|3Q%Nr!2sPTpq`K^xNo8=Wq1jA8{RK zDvc7V$Pm8U-K#GgXjSpQ56}G-BK*y0vVSVHo&n_C*LzbjSIdqZ03I@BVna2Kk+gov z(%hJ=RWxi6DO&J`+j!=^>igrXRZ6w!nxn`)H1sN*Oa>Uh0jx7fNE+2Zj`2TD0sQ{I zT{<*iYe*6Wo)U(76|I@o#AN%Bm%NIu`!1PFen(;`3fuv>sUHx)@yERdF=~nMktgLK zY7M@bLZt#UI|5$wGXzp}X3 zEh;DAnK1EV)oYa?2&C`1pT+`M)uU{J(X8{+%(< z8RcWV_Cgife-G1^*l%KD@X2jgqed5`a7$e%&Mo{`NvmzCJRn{xf4z~Du$DVK6s`fK z=1-EcM-aP|!zf~7xxyDal=i$@bec!##0(BIJ{^4=60h{Y_LQ@AlN-mzy9VRY*N&jS)k3sm0*RmXO_d1?X`wR#DDB2!O4oK zOIv<3g6PG4V!98zEnl27L7v|6SmfDpL4lN)anEw5 zHxUm#!snsp#kvD04DahYJP|pdbN%UgRk;-#5nc!-g@hk-=`FcHjaAblEH(fW^}AxD zV&wkssv%*pt0nQDW_tzSNyYcPgQ_HsL*C>T8TU!wTf!9i9{|pbeJd~sg%w!NA?0EM z*=d8*$1XXPELmy09DXeHHl#@sit0Q7b9mfTsdp(lVD@UQWxK3!6c(P#JW<|%p0g|9kc_?oi;ry$I@n6N4eo%5fbI8m zQf?&R!yFZlc=X{^xt7NZirK%Yy>ypwbe5k^U2DYV9!X^ja8_)E0Xs5-e9^y+yO%kE z*Q6+G+R#owSNqOpPfW`pe6bmifXnZ_y5kxTXGpcxw3;anNmyz-rJDumbQMvH}hQP?Ea?(Wn}+*HcBQ zHpoRzY#fiem)!`weXi8tZi(eshlVsth!itPPEHFI)F5wA5ofwhbiMl~*C(hNt`7~saS5eS3 z)QO@B%<1;=^82)<*=0k_r{r0$lEcT_t#`A_o8MuKa=$kMYP!|NJLm)4(QojLzT-c1L(!T@B!B&mrIvi6*mLYA5p z;3dZPkzyOl>T&jO4kz8}iSOC83A1PrEQs=2;&Kt zafKmOh)6hxQY#}xu80(80{}wRr1*R#2TN54Mb~@wo>Re(YIc#+803yOIm*s(g2kK# z5Pmy=T~ioIGlXYjK`-+{UjY_ut!ifWPHIkji>>gW$i)QqHy_SD${_?Z9nm>8y(CANl;=qulEXilH@AYTNFHpYd)ACE^A>>T=D znoi;ER1}w53F*x2xmaMo=CBf~-QwaR0Gn6#_4dAc$s+zy#E4(XpOKLM0iYbXC`FlM zGv)_fDtLYyOmzI`za{ES~TGUK?A^7O5LpC+d zp9^IjTR*f5)RI!J;q}5e`wP^031p1aCZ5qlc_G}eIV zQbj3@==a*z4;A6=S01OFBCueC%8%M)%&J~33AtQSQTo=@6gfB(7yH@!g`h-L9{ah$ zlRNP5{F{cD%UuXxVBt(-e#&G;QbI%7U{Mxr^FTROE`i*1@>$roaU*dm2 z3UKvDz;i=UIOd@m5>Ymy-001$nO;`!Ji5b4#CdhtMo7*L%ANPVq=H^*g8C%3nVxXldunw%0 ziKg>P1lTB7U?w%YhQ{o|&TlsLB}k)XjeLBxgxXiuba>qmem6HN_g;&iy%xkDWnX!a zv_Vz{Rcu1N2qjg~s0#qnlQu5u4r4LF>y>pXnLlQj5F_ie9L!>Vd9wF|dn;oT`j%{j znd!$!>#^2aS&Zo&JM3#@&k3eJz59!oO7t{ON?}}J-$r6g*x$s-7>EnMYGnM5f%!KJ zFcC|_i`YP^Sz)5=z13X%856VDMxuU4g5*lz#*NWS5c zz+r-R4H;^ar7hele4?Ntuy_A3Np0D%ymq5GQGu9PL8uv0Vl=r~`^ zfxbQmm1=;-+TLr;H7BtXn(JxIq&XtK-b%F7zafLvY}m1&_>?kzlYGIXXGvSqG4k73 z?`H-PW3jh~Dht7CQYwdyxb)htp-In(>L_7CW$P$Eu9n4T>gqeQ0 z2Dj{T8-%QRzRW)zSXxCCz^icKyo^;az$Vz^FZH$hg@eCL_v(7^WUt+1aWcQjw?o`V zcX-`AeR%1$py_Ug35H5gEIj~ztXzlyAI+-6e4Qn)*^)9wtcTSRH9dFl`dFo}zqLB; zi-j&7y;kX8oV^A0G`X8h*q13aH%A>6y=Gw8e zQi5dqk!2LeM4CI-Sr~b0_C%krX~)fy_jHrUmqHVf*ZBy6Qf{f^5dRS{Hp(9G-LwuE zoxdXRNU6aHAMvBc3AL66vvpN-Z6axRU=`P z{}5TTXFg{AVpnQR?<5n0;gws4Q4#6~bQl8QV!#{Lx0c%pzuCGUb%F>0b6*yD7U45w z*0j-(9;b=?vf_tQx}rrnnqO{}LL{=CjPEvILuJf^+z`0h3dSlpv>ZON$kgy=+N+G^ zvDrK&A-b>WXDX#H`W$@sm_SZh~+#H%W303?;svAU4fN50-=!HcuAbv08S&H+|Hd($f3HqZXC-1$heY3XkoAT0I3s@`T70 zwqE=gW9@*s-jbTq=rfkx{Kpo*1eI?{%|CF*h+BMm#urg`L_J?}?)(wxFyEh7glGI) z;LLfI`zV}*p$A2G<>xSa5Vw~~WM#5Y^`kRSaSuZezJd>2vYH2ZYT?|0{!$b*YCcMf z$gN2TL?%>&gP^10Gu@U`U?sc1XuhVFU>j^G5{G(++a>QF;2v#(dq7blpY(4wn~G3V zAS{o}f?9m=hpP@r4(ENH4DlDsIx{N-o9*u;Tw&pHeeCgtQM#!Ym&Reh*f0e3haz!W z1m#*w{n4P$>`46RM9aYBtN0W(UZ2)ICwJz!fF$fYb~Q{c=s+GRD?dR&NYAqUn}H?Z zfrZjKY5dQsiYs|y)Fo;RgPxg&o5Y-b|D^EF46#*8NLZ`sF{JLj;QfphKxEk9)2Ei< zYz7QzeHY#lZMqVa-Ci6OZ&&hUUXnBBOZ-e1;}Zi}v_RKjzRlH^EYaF&yrj@;y|+S% zx>r6p?0FD>z)JVT6JA+$6XhEbFYa?`W(mLa>jWysZCV*uoN}*{qmM%WmgHoceY@A@ zVLRZ0a5dzZYiI5`H?}17A^sbn0iTK1E)g%0snJN?6enpD3x?sm05f zK7N3WvgZfnlpTl3}Y*hNZZV(?V zG;c)vik{}Y8vy_Ds}jw5P!Of{==l%hv>e)?5;+nA5^(ieU!;P#i_dNr`+-EgF@jLmr-v&6?v4Z+3fH z6)k;=Kk~a*vPN>;0&jfi%h_JFO_tSso$IdbLFv86zSl#bsSS=EV;pp8{I&ZvOAH#Z zdtP~Z3cfa!=V7nj;hspGYsCoA6_>sbZpm9cp39Si=enD(x5LIKR**-y|EB3f`@zHA z9z)*}5XJl7rRbI=oR7Ld88m9)5r!|E`=rhQ!MCF?Qr)ov(F#mmUd*%bQ$j ze`(Vs8bxQTBGwb23TH!N&Ja!FSW8C8g%$ZJCvIL?BI04^K5Qv zv&d#}7SDiaQG~Rsh&8pMTzpMTqdR37q#3ts8uc@+O3c4~w7=0M&uJu4Md<#i9Dgm% z-Hg@4vFpeg3;cI7XHr6qGBNY@q1>K#N(1w&T@UC{rIxxU1j5qx_%O%zw=up#!W)2( zxJv;SvJl>4*JpeGvR&rG=VLx{J8b#A1vawHTp|=LO>4`r555n4Nz0JZyR&<5>V>of z2e08}i(|2}AF1x6ivH4sHQbKP?>)J<3Fa1Onq7G*bx9DZ ~@4<1>`^R2=yS!sa6 zu6SuUh+@x$_|b!eXC+z(;6X^wl6j_K$;Jn3cV>-h-JlIN!r~29nujXN#;9jf7E&e< z0w@;JYk=kLrl=LP7wu6$WuFm$SK7sHSv$n->;dFv8)K)&USe_dn7KoE?qSI%G2cGQ zmEKDeiB5)ka^ia!gHhTxqq&FGZS}*D*Q1}F>OXTm_Q)N#+}4hqx!#tn5D(q?yOhzv zj|arihU_|3>*|aZkn2}Svfq>qjcbRHMIP z((^hp-nm=N5tq&+3@_lMF7OMSaC>p^J~MGRP*b4P+M6}a#s%%D6FjLHCogyRJ+r(< z@LQ!J9=l5Z?m{PSLof5ByC%V3CAfBWI(rEWA&pa*KvIzG3^YCtM+-M20#4B~Elt?Tb}bINQQ@@!!u}?CErB$NYT>hs{LiEAJsMt0UBR9qq(qh<9Si5i5$cu( zwf8$k-)RzY^i`E(E^Rn)a1Hysau@0hK|e@MCC7g7P>15Bz`>i4`e_ePC@BIeKAX=K1T<@|3k(fO>8UnA-?b z_$I&T4+!OXJs|AdSVm88*g-oL5f80+@a3Q!gvQ4aT#@4QjdspIml@@Li- z^1X?BRVPem&uWfG_{i1)uZVrJBFSd~zI%cR58sIlG%9Hi?4Kzqc}iCE>REJ| zLcjlf?D^V(rtgvoq}+_3LGJ=Qr05s+7_7M+y~q3f@TaFfzu#9|zHt5du=^wRMw^CE z8W+m75cAQ?wBYNk+~)$(^6)i>uP?*J#$|F|zIz@dG8&*x3p-iWU}pA*OQuL}u}PGw zOwsA+MtK;CoA$sCYO*!O)HZ+HA4;8{f3r9 zV7eWaxTo%tMv8M22d|yhQK8+tKVM`x-ZWXk2-xCK+P&hT^$GaLQdvMFXwo8jQSWrZ zSD3Qfk3faD{v89Nd)PJ^`p`pF6C za(#|KCtN}|t)4R0>xZlndQTrT{NSP$d8*`#YJeP?cT;!kyDL#7kO4q6dt`{};X-9L z?YvAE4f4Fu01rU3&xO2{GNMfp-ejagop6hv&_NFEng?KjI0su{DF)01aT*8uBRmz| zcC3#^$s(ClaL`Q+Dzzz!5-OHuk{owk9zT^bgt$GEPM`%&h*5N%DyM%y;>BlH-UulF zfJB@9fKa)_-Eeu+x;;W&YD(bCkE$aT`aB&W@ZyT z+D0}_#|vcv*)@@k#`D0VQ(3mp<+gd0v!?Z*Fo*%{p_Vt=krG2^f%upc1N^K6Qujr$710|1@ID zScYlU{q}g#!Z~YN0{n^*dFcEcZtMP#gjJ)LSjs?pi_#yE6^Xg&|8#eK$z9>!L)&j& zMEg?Q@OA9?k`W9&W9hYZy-?aTiPP5cK0Q^7{VWzf!+ZT_djI)>8uInUHHk;lHWyNGF=M|+Ts zBe`ofncBbWbUKB~X@Ts>M&yiMEGdDeHLpvn7zg`F?@Bv*oaH+BZX$Y`J^(0)t&-_? zR`ai>wm^1O*MYq#YU9KaT4kJk%=zzS9us1A=}d~UthK>9igU)s)`n7CO#s6AKm$&e8r<={(Wp9f+({ueP_Ect-gT8rLI zQGZKDJe0ZrdG^Aw_K~cFgQ)BoT6>j1;pHTWz3$KQ&5QMTjftMd5|`X!zX?x)geSJO zdr8*kJIha`nVqzk$(~f})ZUkK{|JFzk}mpvGDteX&x8%`4tr~JT=Shm3bkNy^vFc+I^X(Ro<+OP>2s}R#O_ub!=u4I5>m}^PmOy68n;9&2=eF zh;J;$hoT?qU{Zf0$;UM|utODIGl@37Bg3~Rewd4^aD?~Gg2PkGV*}XfoDZy-}=^PDKUl*_k8d=s7oH?Zp4y*THhoOv08; zLm@!&K1rG|=@KHakIaxdyBW8VVpsar?bZ2!cY#B82U84P>m@scz1MfEjDPrKv4 zA2=wf&2;fNCJNLNu>6izB-&U=U-(oGcqnciecpYa)=Q$XS{$ZtP}7&R{=MheNHk0( z!|v8e5$E0cZyN;(CmCNUZXSX1(b!wr&}}k!8byp`3MJsqwCAaDJ1DMkEuC8)jK3AS zXaCx3pBTyRxi2(MbR)(V^kEUb^2Q7@hu$hK{)IZ-rj1|Cnlaoj^K9pJdV5H0s!4w# z<@U_&C?L}W`*53MVzm2q&?(rrl=B^n{`|U16VDZz);5A?giVe!!43%GD~6EPr%UtS zm*}Z}%TzgvfF=cS`rz#sx z+1T;3lM0}d#;c4|lMP#i-hV-?ojBQ+_-Um$$E3OZqT2@H4;VXE zT)@OgF(1e>KZy9maQz|!C@(g&)js0Hj38Dwho}!kX{Y|QR31kurP}Gea<~U{TX;U> zf)VDi@6Tx%N!i#nhg&86L<9M@jru>CW4{`bUw@x~=Li*gO0}n)HOWykS5&1g zAYT4@H36w5oR^rpTpPh@CkdvUCz4Cj+_c!6h8Xtz9 z{sh5qmnVNtWrsngJm&WB@%xaj^l3W_^y0nZ!=Bc(Mhx$s&OQX&U>ii^i%MuAAkyI9 zfr|{v6yao0zBttqOC{+O@))ZU4@WBD!y<~g77MN>nR=YNlg%th7v08szrw0~u?j|9 z9&IJF^rNSr*g#HWD&E+a-D9s~{^K%LdG69rn#dw^Ozf zHoI@|Tg{!Z%WcwxHaSNP3Q^vIs>dL-|?G1@n6 zgmlf&(oy{nuQy?-#LQA*j9nY$X_qAlJ5u>{>`A$@@7m?A%p{?`ekpf8*GkY=?-o~K zj5z^2zuKZ*v$?Bt;asN^t{YBy_fxS^x=2|*?I@ET7gLEgD{&sRNx2NcO(Noz4~7!B zfSZd2df;$nR)|9A=4M)6$V(q-H!hF_Pz{*N_{Zo?{{QgA=Z%j7Iz>&vqdChQS|)P_7~SgJy=9Br zPB;z3AxhXfdJ)*j8C-u_>i^0P4udRX_t>fK*Uw$(%#D(rfXZb-q5sq(0lP_cp?w4c zC}ipPXV1xkZHhv5hv;$TR5acj)Z>AC-9A%)S}ZK8!BeR* zXi|KiA4tRV69LG9GfwmFN8YY}cClC_Z5&t_8r|Op(GAakKvr9Dao3T`G_6QDtu%$Y zt8M2G$hPtQrxE&{V>6p7=BMO4{uEu1TTH!oQSEU(s`Z~lD~y6MLVs09uQ++G{5cSW zvxzizvMu%_{&Shfgbg2Y7c)Bi? zwgcnstbRml)~^tZZL0mB@DEer5uo=R2JRdK<`c3C=qys69-vcpWhrd zv6-6RKkV(1$qOmP;cX1x&ZPGu;H|KO&ccDWuG3NYt;+OIqN$XmYYtoB1jkUBc1+iSFI>VNzeFtzn)gCuP2V!J2~fO zio~5Tu5^sK!X9i2=vb7o_@NM^E_P6CMvCM(Eg(&5d4Y@_8GN$;WQWvL)aLa*Pl8P#9#4BJ{JtCBA2v-3V&giL=I!Hw^0v--@seN=zM} zx3VU7O!~PhvIX^`f@IW;km+tXZ>jl4Z}ScIxx%pZSy11S461g`qhIKzH;?JvH8%}! z-m7~jmz4vRnpw3wf>rX1tlCM>9RJUold*45WmzF+>hd%C20iQS4_YZusCuK)nmPyF z@>EDA93}FuO|e-r*Y{@6+S~fMHS^fuGmwafZ|`W|o~)E1;c7Avr0Zh}UVvTWQrK_d zO|wXueaGEW0OH(goc*$A!;U~3KrJGQ2=2rE`cA85u~N?{C9-s5|H7S&!-5Pti&k~O zA+xMp#ScB*P2tZ&-@bD4F2b}T;;FTXQ;<kY!qMjxKyzgRxuQo zQ*0P>9@?Kq^}1`e$pwfpMkBb0LjAya2DdY_1C~JZa}9)g`*|0!uNh6zY~;a>id?kr$3Zvle~g$As>VL*%W{6Ht)Fzv?t+(LH9IJLACZ(-?S{&iJI?Pu1R<^0 z!MOnYE~ql;1?sp-H|4TSny>E@>G84E)&O5F=CTc#B1Ikko|J}bDx!9VOGZp zd2zrQ=z>A*Q%qs0-I3sKep|=aJFCinW&Q5l?z6%V-No;FhrN=5LiBtFOGl_;wegR> z&KdZn@A1tx@3Q2jN7N*g&=0MYi%1E+S$Z$*RJD>?GBG>Hoj51@tjhJ2zP~Nz-LYr) zi=*%N?d?zuD>ZAchbxHID)m#J(Vzj6o}FKO528X2aHt_Lg}?idfgl2m-gQYUtb zf7h^SJB;wi0k;BqMx%*BJXcJC;$&j@0K5)NxW!SRt^e`Efbo}li+{g+>w~&u=nYPA ziXPCXvN&XhPz}*#@HlVVNF}OMxdJLT^>gZ*#ip0ac0C-XLRfLLNGW81JxeoXhR z-6Dh{!n;U?90uN~Q0*<(n*4s8jp&aj^T3XvcOjD=Z@s3&11!sBgL@1dH44zvw=kr0 zN<*6vy&n+tzEb#3hu&Vhqd{$%*C%ybpQH>a(#?7N7eCE&wb33-pUUl_)q1vZJ77fz zf{9GsF|k?_jg9mXYZMO`p)iy#K4h5+1zX*Ef_cw>nh@;xcyghraDKcRe zI%#HN>smni{_&=NtJe?$f)ePAn_ECrcb9E0=`+st_Rg#ed5AxwN}xCFQU>!1{?XO{ z+V-IDqiOrLZZ2a$fziCm@|YnjV0?}!esK9Np6@uuv%H_&-LrkA{pPIw`{{-N(WJ2g z-=twSc<1Jt)by|pPM+xtZJ8EUX|VE(NHx=|e{_HS_Qqt$3Pq1w0gz%4_~W`GxoMg` zu+>FwT7YCIQ2>{B?B1(Dumh{bG;X*qCBDl*23qiHJ>=a{yX(-@XZ?jR$O`Yx7?E=~ zqwl|()w|>6uh2Cwk4?L9leURoq|$J^XI z?+?OnOM7Hw#>q1AFJdWIrU)yW-PtUx$Im4bHu6^+4A@q49Gpyz3oy-rH3df&P_<2) zi3M9f`*MBckm9P%{GR%aUX@}!H3(+uZYO4@hWUO)>Q*Nq*) z1@BGLH;B8xCpGvCv+t=VpF{d1EAM41fTRtB*;oXI8)Z=zGdX!{pWma#t9@#;jp||a z+pMUk{F3c70VKCR$=#Vxm+%>{8py2Utr-2(T>aFV)*38SX$GX(|GNZry}l&zKJ?AT z;ws1by5{D#!gkgmq1Zp<>HI$08riaoWj^rhoy=<%=Vxl!t5R)xmX8JLP2_mbMQL z96E-0+7Cz;2@gCv=Iyx@ZU^S|HDY7zV6aIX6bd{#z~9n6odX1vBHJdDIeG~AA$ftS z1N;b4fXbp{DRxk+SF^wa4D?c-JbLwKQlD631c%PC*j{xj61$p^6RopjLYe)ZA5(&F+MI1J1%0{h9f+( zgRSf<1{$b|rbtIz-k<8C1jmlcV<^xHxrQZ${?_CD6!(R3}QKzQW+{vE~_zPEvfEvX3S%+V_ zV&47S6F~Hzlyo6eQpPMW_gAS*9#2o-K2}A|bS@BEuPj;>n`%gJzY6s3mdN8|#mdh5&S zuv&nxv1O>+x4&GIceN3Dnra0!|8+jhjQO^%DK(O`t=i(%&F^M7D!l4ve;S(wss~LM znr!RRwnr3yw?Q5n2-HCUnE2}{#JD6Dt4b)YNR!$eoIm^Z7R^vW-c3$@4%iD10)Mh; z;!3DAePSw$kEAeA8c^DOlcDtuc>VT;2Epf@u7NMH1H1(=pr#KMZ(Df|(S#mcnoL?A zSC>;T$T(Yh(f<-yUE#ksE-SzJm6(ym4VF0olMTYD^ktjNShl&%y*)*vo_yDsiDTjHm@U(W7fze&-AjZ$sp^bnfvpz!Gj)0z6Y-aC}DdX z*dE*V_G8txpDk3>fEJv{yO}qdFsmtMkd^bly-t7UGwHXwXYePo5MPHJ@b&eRCj#Po zK7e$2^;Rg`9eAyOpWw%PC;&b-RhP2H2YTKvhfg<6B0*^yC}5>r|A3+H2ga@FUx73j zI}Hz8Vkb;dT>&GfEd%REVv=Yk#MXRHqQ8I(YZZjQ!qQXo;N!w0xXlT7aEKN4&HBFj z10wR1=2n8GO2k*1@Y?FrD$w%{)5c08c9q@0pIbATyOV2uOcuj?otD>ZJnSJz2ZP_k z>QG~nFu_G%n>MuiZ?b+nx=abWZ~J80Wv>+fBOmwe%& z=3dRY04tEmbgtxeFmsGa@RfWW_F9GM;M>;_ANuOnuEff{*%DWpvWhOcR7%q#T5EIE zjvla4(H>1}OnEViyNpRkiT*UjQ-CQREgxTEm z-zxY)A8qqAxCv~NsHyZn*_5ShK$0}NfVs+B0z`dq)*6LvuuQ{NfrUx`j}|4Z{w3Pq zGe1S~ZUvvyw_(B^h#tSv3Hti-{uUAO8f@a*DQh+00G?e)_oMAs5k=dhBTOpn zZa46PrrYTsaxLrXxawgN@}U$_RdaR!XcG5X z`aRlGl>dvV&9VY1ltSEkC=UIHt2EA z3CerZBGxNYterU^@>zJPgV8CMXL%Q-(RDhRuUkXb^7gx$QIOSR+U!=MaY?Ino+aWU z58LBiHXZYq?;}M%$*EM&SVe0=enxhA4~6`IjHi=6(B9jpZ0_iuj+^oaq|cb}(Z$V? zO!V}A)AYZ`@i2bIal*%oP$$5)j08(CZ1Z-d&9AVYzn4d$`$QU|X*mG(Yy=LvW;5WV ze>r6VGhxvR9vwHq0p>S+MJX>4wyHy$2Hu*#iT0gt0t`SQ`olNpl)<_NVq`nr)^p{r zJD77TF7op}eehVb3Fw;~S{PtAfAhKg3f3@1{Nk5X0)9!b%BMN#@PPaOOU@8e6)erj5tuXq8i7%VmRZ{favJ z4>1V#Wc1)Nt$YmJIET$UwG9{8@AWO zUr*Pf+iyu*vwHuXZ8djR@(aa}c=}@sYfkyR^O#@X{ToH$EQ%o|u#dR;V~kxyEwjS6 zEB%ulVY4-CA+^?$5|`5_#PYS#Y6P>7>)i5Em(K1{DBb;r(NJf2_l9=EP~n*p(-cA& z>PU%`{-x&|H)agV7Sz)&pr)u9OJw$zJ3~ZIoLzJ4COz^hY^Z4iEv!I0&$jL`ive}E z30XDSvSMdShVBkCIlrvY6IyH~V9ApAR@LyI%bg|V*6-jxyre98(Kt6#?&jy<4qI#l zWt-0E!SBMy{Lmk991BRqJ-DrJ?%8$PK3ur6>?_2@Gr?_*=96F0euw>C)-R70pBt?V zB&b(Zyv)pUu-{k1Z1!OGVSvP&$oC-W6DZ3eG{Uy(gQGrhhQIVJbv8Q91d{fce(=Sa zpyGXNk>&O2$%oCP(+eX^Z{jk4u4o_a2S0ZsXEYlnWDyRaHn3-si?*GeAr60AsQ-@I z`2TBKBRB+a2@2pGLFBJfV(rm$G&Gf7byGe-L!SCtO@TFu;ohj<;ojr>26coInzb&) z_Yz|Z*%sbTb(wZ*@aXR$rPTBS@1bJfC9X933yhCL?5{72Z}_}!)EIp)+bdDMSjlTnf>%@w~!)G(v|a{={B>)DGn2NaAhsCo#q!M z5Q^XZLxA~XH&U){TOkn>@=5m@;_2-6PB)TdfyY_rF<>L=_uiV@jrRS825}!f&Mr$u z0gEKfv>TTEpYrX7=Yh{&vIBUv>&C~P7~y|S9N&7z{+pyNCO-M?ZB2BCogwYE250Gf+TBdqkI#-(5An!)IC;h%OHP`Pbd|1o_hWQAqdoEbsY^Bp~sNPavv z{x3MEfPnUB5a>9|pTb=d`|Gfa-=m-8X`gQM5FUl^74OV_$mViyJPY>W7y4`(y0K{# zwmuJtlvn?AZbpX(U{+$SjIzx4a^ay(7d1}L|0)P!2DWqz00j$xel}h3r)fc|{()&Q zp4*>1<{Woj1FTAZwx7$1*NYORile%g(9Q-N+cpwEJqW*wL;7w0mfCEmCZOdg-~b`f zanF_apb7t9kx$iAC{bl0RG&J^+@GI@^ZbmbK%WLjexdyYu>wNnX6s8IQ8ziDz$gjOgVrGQ6a_uR29lHgMCG>p#2Zg!t!qzc z$EOUy;vgs?vY~_7tpfDOTBq3lhw1NGf^6!q9$gQ>+L72tbD?@9qqJ{U0}nGX^CuFl zpWm7^RDnbVmuIh)7+Y%B_@0G({_Xg__RxQ?-HrhLO;cK&Eh^uS3p2YM0yRqh!km&-g`@^frR*NoHH}v zoNwmLobz7qpZCXHCd~8X+0Wi1`Q5{!L9qUK|5=@-=1 z=09JTsw+PlAu-!ttx?dIwjmO-wa9s9YYpalbltTjuw6((`ewA7#rSFT7plzC%V$5`a-pRiRdh@FlRZ9DRZ~a1v1^ zoI*(4R$0)6{unsdB!pB9lE5tm&Ybbs?|?i7r0eB=`mU?-<68C~#zFEKME9*m^f$jl zpc`OsjTLUJ9Mh+wW!9BnAIa@WXsMq=ZTGI3@xhDXXUF-z!be0)HsuHoa(&yykVqn3 zJiubL;=V(6iq3o!fBK%4-lV)3{5kl?&5#Mi$R!LCFwz9$fltQ)>f=3MXG?R5(}dLJ z4t(Nr0__=WeW0@dw|7o8QTK(s5u^fL0hps*#l~;_TMlXjf!*9~An`HbnH2eqOIcyDmhL4}%H@mNk_+?UA zqV_cYZ_HIRHxfLPI%E1~L4{Xr<0cdzzlcu%F|RdFjM6JFl#xw9=f)R^wGd!GyQlS^ z&ho#}AQp$Q;6p6(6)30$%6%I?kO8Er`l_(FXO-+Y(!*2Z2sL7}Bj333`X!<&F|))8 zgm6LtvtI}Jmaz^tU~S}xfBN!&`U(y=U4=3AgwxB7>mvp82QOxCZGx`R+}jDAs_zh0 zHlkoS*vn8enKs0KZ}~!qy%VA9M+d9^@QasOf1Xu3$UuV0h%b!a+1L5d5hx3k=*W*3 zo{ii2?K_S3Gjzlw+DosjgA2BWv>>;!^1OE@KG;Ev&AL>z*6y7+mZ z2>)ksn$D&Ff$?T3SL+F5s`Ta_BHgI%Yn+2S$hM)y-S{rPo}KZ+b*VwA86w5%cZlt;-oSfQXqjtMZa}gum8==Y z0=%2tNHo3EFs1sK!FPx~D5lee|I|8#;UwT#=A~3Ov zM(JdJ-p+!m=j<{p$@|MYW%%3NK`+?`AQE?K=Gyi*{~8I|_3xjePrvSs`u#DEmw9R- zM)X2YGtAh+qU3YnYl^UYO{Qv%72{*c5Y0ba{^nfv6&A&dFUGD!JQ9P@j7O!V%tk%R zJW$xaI8yq8^_aIzh}BFum&(3bUO*=1%AWUTVKGnP?!$nQkSUpmQqSo^6wF5XX$Noe zEeK0!PTaeG?G4ilGHy5dVX7#|w|D7UdObr`=bst{O)_#(1bU`%A~{0jZCoB2Ohos+ zSRiy$#u+^CuYVU161H%eG4;dO!JcL-z8JN-mC9g2;d{EGmZChDm=>Z;yWie_@b%Qf z7xe`R(}3%?PJK*cUF5WI zSy52`mp+y2v(Z29`IvCXT7z#C#e!3==$22#72Ni@>7jnsTYWRuhwEfOjzKv|?PE=x z&o$In4(zGCdAXZKfscYlq75I3;xi-aNLn2sJ?~Oyw%$jlYj3CDDDhsW8v6oz=_2SQ z)zl|ZkHRQ=ZL|xYj{n;6F*i9=syLGS zvT7^uT|-iu8tk5jSS3qw+x`jqhVh`=N3T_WE!wAag}p{?CphmhI@{ZRQMf;i%D|(g zQch~pjr*OSf2(IJ^*TlYZREMCch`ZMIPtA*HzRXwX7gtwYt@^< z9Q8UMLLZYo_(6xzT-I$JD9|l=W9yKdMflOrsn^#?UZnIqDSni6d z`(x5=2I|`8vKU$`yR%oQ49+c9-jPe=cE2M8ILNzC8n60)<2O+AXuJL?RqA=#NjIU# z?{m(cu>;jE@-WCg850{5C(;4UFJV=iy;xBFDkjFsv2vO2>z%mu7Kzk(!#LrhSB%&F zA@dZsj$Pt#Qd>N4IzQqRU381mfXT@A;0MXZVHNUDHO|mWiJA}YO%Drxit(W3@DTS@ z=^fP4zk2?hk-`gK4%kyRc~*8=IeL;ue$8Q;Uq|TPAQP@Ux$|h?Z1A~)0ot(Tus+Er zk%UvOj9?F58&4z$yR2H3{o4%T+a>Q@!tm1gt3&f8`3~hz9Gp-kC1S1lPmurSYQ65o z_N{%z*>BFDa2^@EH*0QgE-SY`S-| z1w*6GwD>ssXftD(5YPOVRL-mPbT1R)m0ppAfbG>LPx*NeRb@6Jx(hp=vVZUM&V+Vs ziLR!&GZm{nQ`s;3=4FTgdC)gj4RBJSBm$7cl}*%A@4xVR!E>lpXf$y<8I&UY`te^t zzu&J_-Fu^-GiY#V7iNiA3}V2%GaoiNN7WyQT2 z0|U^{r8#v``}510${(TM-(#=`YsQ=|mmU=DoCNl$gaHadj`9u2{0o z`yo;H^${x8P2<-6Zj#S-AfLarbJxdTAfLaz+}`*6C1Ym-GImICqG;RG@4R8?RW?(b zKuR{-2RD~ICsLrN{)JQgtt9|k`TJ)Ok_Vv7-}_^C-+;irRUAKvoSUcso0Qx1OWAEY z8j_>UWZL>P`Z*>Y7mmcia;HKrv-_qX8l?J&MWC?y?1XtS!e_1e{!>}hh3#HMlMrz_ z9r2$V<_`qwZ)UvPqS+6DX=aMvcKH$I)7O(9-lottl^V|~5d))^(49SRn%jB8sleIA>0sN(uN?U`IUvbUx(dBf z>r=H=R!ljJek8*la!PUcxT`XYuPCdFzs|gLCY`@FDh?#HL!52_TLh(NbP6fD+M4!` z?TgE|Jt#91jY4G)ejufhhmq*9U^Qyj*(tv+W(dRSnxzu0=+4=O$a+hc-eV+=;^Qlr zF|icYLyYC(a;7xN4ySy)P4jsiNRAx#^mrk%nYn1so|tUE8RT+MZJpkrRbx{_S)*-B zC(Ltp^o)xc!>Jn{c+0s7hSK=AGkOx+Zx`2ihj7=3M|;cPx?B?Q&t@`<7Tf+}=sdJk zz=n|%=66FFIXpumMtE^Z^WM0(msY6LEawV0-mI*|ugf(p$}!B8)8Ir?aE$wdtGE2C zh1?Es3!DhY&ep$_^kD$`g5Rwryve5_Me9kq!FgkAKb%HW`WiurV0KKks@?u(1bYj= zoG3xwnfQ5G-@E<}`C+Z9i1n`-s-dnZ|*%sI?bp0VS*e74H; ziB+P|_T5gj2!3E%yzg{+f$kfNR2|wiS*;}vka}~=nqO-=NMOq`j@lgwCjT(m7gn!RLTAH zpsG3Rw89x*j_;79KF$mci0Q~94b?_67@7qy{vo9TMRpRxU>&Uj8HOBIp~$)T`+miN z12=~$Oei6CzcluFpZYYbE>uo&s8<%64T26PEqzPc2*%uKU8fQK0BebX8+^<|Dqoit z#>-*}g{Gm}cCrz2V!bbDCNCX_Ec4c-fLO^P+>s^P>N4}InB$aTN|AV;Gm7N~@;S|T zT7>94z7!e#naB*;dUd$suqCad)@D=n#@1rU$P&aEigQZW($|$4nbTYpZoEBEkZS)tgCuZaNh|+#ci@)-qzA9Sh-3)<}SlFWY$zsXTr5m6O%&{nX>krVW-$1 z&yT)&rw$1h+(%9-cVm@#Ia)yxJ`9AX^AxY%6{dD{QNCT)`An;>MWLi9x$Mn^j6fBe zUo~UH;W%=OgfaKaY-k^t0@#=wGd?^wc)asO=jvwD*=8~y-G*9S8mT!>(u@?U3`&l} zGS4*J2nWi8Yl6p=q1xAka4`7{U)iWynXc&GUhUq(vOuc|-RlnqvE$*=?D9|1WliVI z>pE_ZinTF5GH`>8hbmv@Az7P397K)F<$%>@1GROB&jx`B!8s8K5SBUC}qwYXKm+|=XHJcUlW%j|y*UX%l=@t3AU+ZH=XWnt~z=r+I+LiFghX=4u0 zKFvMW0!pbJA=uy^C!%~VY7LzCiSBGJO!@-EFI!oW4>7ZM&G=7vC>MOOZ9JJ)d403; zWH~^1?S`%;P?jntgTaJ$T1O`jP}#lB?DtJ}>qz)K<-B zD&cb7ytFA-265#3b6l^w|~5;oRMz0)96D8-o4| zydkyYxDfv)6HJ=84JaLjvq7M4H-I_F2l-gK#+Ab$wGW~qSy+t;OX5eI^HN44p2+4~)|evf}6e7^YkyAbk+ei1_ww41mO{^?dd0 zhG}t|Sfg$4MfmDXA(H7=|GLjv4zKND1En4#+V_JzsJz|;9wrX2qCRkY{B1GCx9#_8 zgg|@5f(gLU@$b-hASc7;Q8_+@P}7UV;+;Dtk>P#f-ywwvyu}tffo}$6s-D6$hObkM z)(4SH-!BE+9^uGP)6*ur(?(Adubj+Y+oyz(;+>0ks|E}oPiYCuzg%(x=YyFcNV<$3 zD!jtjy{xTs*)jO!M69OOY!+n8Y7iy<#XOwbBl*^$kyqa#5=5VcqVoo_aV9U`s?Les zDi|=^E@E%S^DY!&x&6(2`@dv6SHEYMKd?TBPhm+bNJHx|AY@Z9fzgi=;h%yp$Wt0*HsU#}{|AbVfO3q=S%hFF_` z5`6%?34~B~ac%qt_p)RMu(J@orR{0d>H`4S^a%nmG7S)HaOB@U{wrb=L^mi-?50Z^ zhw%dZ<#))(=con^{a^B@zW~Ap-9dJK{6}^^OS%TW%+dy*D~;dU2VxWd8g_Ni>A{x( zbTk@qwhCiL@MlJBTfw)EWcIbLBqWOja=fpf5mrR{QOYc%TzDCMW`$jDLjL3qGO5e z`Ih$h+HN8fM%aJIWRU%1+}!ObCMrWsanE&G-i;`o35_B56%mMJ{Y_k_Fx8RON z+QT@YmhgQs_Beq>zk3cI&7LH)}3&QHV(F1y2RbVP~J4II9XcGJ|U7J`>`>1OmI+r z^&Bnrz+6B*FJ!B5#|)f`^+A5P774+xb=Vy&elq*W%kkAxoE`VIBl03{7s|Q6U}?V%VP#|* zLy@oeiMQ6V0U6FG?5sGO5ryDmjqg4xUbWA9kmqs0sRnueS|~r3XWD))#wwD7f5GU< z<+`PVF?bM|l5~)(RHnjh+)<`CX4VKE#|$PLWa#ECxP2^${SqX%)TFR-7+0}S3voGQ zicOb6;fyn!r9WlAmQUCW88luwLXYXcdf7f^*iF`6Y~~xQdyOvft9Xqo)TgdnilpirPvKcoJ!&95ywoXgeQfG-~p`J?{{Y z{N2+u5I<#pF{*$8xq`ITJZD*)NfM68Cbj6)xEiP-^Q!e>s@6x%ga=ahLn(Q_vAEB< zH#>hRrw`V_O^irN!RdAEJ8rkJD@v=ChQHMAsbwU;P|vD=QlhbmMtc#|?Q9un;6yPc zdYxIqfdPSrmm`~z$}yXff`sLKCqj*z3nYP-1IO%BQ7(yXt=x1~aMBJ*H%{-HkGxBjKK};H zC?^}y+7J+=s@Z-=kQ}*>Z@yU3U4N&9tsH$u|CDkuaiKncHV2zFhJgmdj$wme<>tM~ zja7By-aa~gG^oPrMW2<6ZP!vyt8!6iuBqKnds(z4> z+$E-Ib`o3<)~<%g72lAPSI{BGS{iguH~SrcuLum&)(n%%8DP+ry&FNLGSIqU-;WL& zH&Ixw4whFP3(?^r?sFG|EM?{NXxicopjAMVU2!|n^aFCU2iaYVnwiUlNGclM^K(4iN%Vv zj}0%?NazP@ox5o*c_*}~Mn(1jM1WKxA+b=bwVVPQrix=`k`i-Pi5`8Q;8n^yV&tTF zkemG6286ggN6|KnWU~!QHBlG~DoVw5Xmt=t?Qh)Fn{zViN>MG;GM<)Z-MTpTcEO!+ zuglDnId-X9nNgoWianQ{STN*}&|gl`dv5%}0`zGyTm3AnHL05zKXgv&Y$xh**->5; zxyT1=f8HDegHr?LeMu`^s`iHnzAs6@#^nNs9cc-zzew@36s%l@xLwWl+IY`X`EN>n z^&h6Jzr3Y>Ce}QLxGE*Q(%jIpjUMxNNVX%`Rm(Qrc0D0uI%XCW#a|6)z9Hwm zKce@!Nc0zrA-#+McimSWee{e6;(DfcI4X+uy{(mdPuI1M2>1%6 z4lte}zZd2<;mi4r#!Z73!#jd-gPEC(DJs-DaxsT@>T#mHxFT&Q0w}-na0Cg~5QV3u zprL$r>xo##lzUYsE{Cr8D92s9bf+-l;j0r;Z>0@wIXoMLC*b?bkJ>wp?kfls=F{3x z&5PncJk*z(Q({l|n6l!P-l2!V-1++;c8EjsnXn_+xte$yXEs>yi2M8G`H1(j5fZ#r z(;{ac+*nlS3K$kRI)m+j0L9AJth>)F>%eh1Yae z!)S~St{Y5bi?N^+gJNZz`!L8V z+qgr(MXA2+8PwJP@c8BRj=|4HadM&8n=TfSU6SgWj?X!^o>Xodjts6NpCnR#-O!a`llA9*9$+(woaoihC1y-KEtAdgv`k7rGTgc z^Yhg8?vJ@4Jm#p!mKRwc_m3A*kzPebc#5|%qvZI;ar}^awc{aB|SU)2d zjfJ~TaP9W9rJ5*;Dn(afWo~00^vFzDw2|q;;jnWIL0wbiM(&=q*6=(YDCYTN#Q_^U zDec?s3CKO?TbN0oQB<#JMKAP4Mmw>}MR2uL(8z>dQ`0KInuECoH-V|Hh^H<+0VlWz=@*c*TEzl0T9GqNUtda;6#JtIP ze?h5)JaD;zgj1|?jfr7K%WQsaSVC6`t=`SuSt=tO=F7pHJl9{(67>8;Zwi*&NoAke zQD^aPsWUjws@@cTt<#v;SigRqnK(ng<=LZD!2)+|a~}%?lv{-a(8V3n#N>Mp>jx9} z=i*GB6hEv>qaNwXG%dfY)gsKo6)I3g==*5V#IuC%cRlSan}c&ns``4lH=-a7)w4Pv zb^c*-3}!#YiDMikw(R3%ftZ{TkjwUt+Nfg}EMyK2u$dQLsF*mRbg)l)hWQ?i_)7^A zHxR<>rLD1xf>+O2^BteMi@l9CW_&_>YOw_kv#NT=G=Duydze+to>m*HTq!Zrd;jgt zf?g-oeg)o-2kP-bu4e16(Pze@gIRF+u~Km0rX9{M)%Q`Mc=?3(WS`5YHD$swdBmMV zxwa4CHyW^}=vrlRf*b}}i+W*Ote6>KnH6ZeWf&%6n&p4EQwu)eKqk^u&em< zx$1eH`xhZWTLZck>vFefi;yRRswT=$4-!ZzU8}7|Yr6+5mbcxc*CT>Ej7j9sgJrbq z0bS3eSaF#!i$@3LMinVkzMgtGo%3NcViD1M?i)p$qssoFir|COsKYR((ReysRxSU6 zvr(pwo`a=l-dpdR*Ce1J8>1hILO9wH<>$HS%GCA>Rg890=drxok&4gvpE_i|@{teF zm%mhI8pp7(kBnSQ0-D|zljV_(>)N_>ia`otL(MirCsP-jIS(fmacIzC^sBW{zH9!i zL3(07OLzab!tmLP7jGzITe4nU^7{{`_68|?} ziuz9VK!w~4?@_%>{4@p|M|(Lx#yMeHZxEYtcEkJ!d*A^ur9s^ylHeY``mLtO@nHx} zM^Md50mbYPE3Q=*hZq4-D~35_es7~7Jv2&__5d>?x88lG6LSE`gF7|c{BRtF%Nwit z4)Kex63)z4sd0gWxf;RF7Oct_7$$l3$&4k#qZf_X9KyZM^@B zgttam;%wSnl%ft?^tss0uu^TvO!`LY{3 zqh4lcf?7TpPb32|kAskYfMJjv#@jmUOF!_&1 z3g596G1bY)jUb@*YIPTNHK*-@%${Bqu1)?zr#X1_laYbEsb$DbKfF@6A<0K`C2EMR z5Qtm-zjDq$Ddz;Sx?d$mOlF3bY4br&4@Afk{EBX`E^SiuA<$ePS|y|NOMDK%*l2z# z-UMmy-Skku8PYf(YTOP)3z1@=E=L?>$H6b}D((E&@9~0du~tFu7Yo;Q{VQ!QiZ>_U zWi_9<1&TBS^p7d99%JVEl;5R2-UydgKi!EGzC@iLm8@1`Ew;5zwE1kZE?Vrh=x>pl zecR+)g`N*NyJwzbz+q|A<9-YW6Wc#-q>Q8JkQ>&hGX6>?TV^4~*Ae=IO|y=)3Bs(!o&JZr{G&fb17 zCHfVi#b2=+J%{HLY|)igjpfTsxYaJtxlcy5Kdb^73NyoBHj(1$q~Z01 zksjqxoOT5>;u}oj@ySQf(paK#vQ!orhkGIaGK_?)Vlt%n)zfm9fgp%L9V#%lOEOIJeJ45wechWY5)S`)L3zsfL7$MNhGV`}V53Ie{iU6eGt>X83t z43%Mgh*W4mP7WtJyuh!X7x8%+eX}6wF}G~c-2sho>w8Ss7b%+^o9RhLq(l@Eqbb#J z`Wa?dA4%`G2C1bY-37xH#V^{k+VAv!5g6B_%uioEyV=5fo$w4Z^{r8nu7>)ZMPen|V>Oy|%gK zZKh%Z;!zXJ@c@-8_u4C|0y%RcLa$NzGYircUkOfooIoqpE_KxTY#Z#Vvu*w}iETut5*K}*rxmzvG>j^X@WhSQmxd9g};cf_+|VKa^xOQn5h#N`tOz()5jEOn|-_&W>ab>;ylT|dgT^{_55#ojmLK~8>2 zU_#NGa}Nwbq|q71cgEZ=h-}m1gv~A6;71WWlYwvFkcSkW{anrASxL^G&;jI=%y4or zNx+U!qL0VKkLb`=ZPUpXVstY)zG*jdwmND?ah+4{O`g4u)Idt?#1KzP#&6q|fdo(5 z-({CiOnc0Ha0Y$NIKSdWeEcbRu}gaL2l047PX|&G0r6W001+Nk_ zyKj;uDe&$vm$RX#J}6t;C99?3<=Ql~te6a3)FyCIlylR6P)$Pmmvb0Lu(NI_I!j@- z>!~vd5l?52IDa~AbqH)P43Ja$)*5y3d3x~{4_!v$8GEy+__)pu26-6n z2!67ot#yKk)BqW&!1~qyzJe1RJl+Zc`YG{mqePX92hA@?rUT!(yNEIa0gm;^Zm`cj z!N1J#_jIHHKX@kDRa0plYETfY`x0&$8u(J*PI>EpErgP^Bi^Q03dGwu__8QMW1M~dK(hZ6e&BCvY>iuQ1DTQw zT|lPf+u=(0@%Q4;9g7lL+GMpsY>k-_0{`3M=C1edZxK~SuBe{wZ;aU%;0fEjKO2-2 ze3Q+Fc5BW5AL!}45~uM3X|%Gc?+_zxz&*w4a0YNs-B`Ubz8gWHvL6NDzZAbp1_GSz z?{a~EAhB27JJvwNk3u5l?%i*5=&HJ233GrAEX}>$!tQJY(U8$R4EBJoE5x!))as=< zz28VScDP(FEivE{D&VkdPbccl`a)Ym(pQyDUJOyH^<{xW;l+=gvFIzLBqur5ctC6M*S*x80^swBt5 z=8s7Yt-~3we-jtmcl{DM`*Y^yZTZHl$VivL5;K@oBZrCA$ z*O%Egz*mIp90QZWr=`oCD5E~#MIyPgL{U-5`%q!Bdz+522rc(pt=J)&!y7!hgiqG+RNW$|{kcZPbkBZXx4`m>j_<>U1qEMg9hzkh0kP92h=Iuux`~_i+5w^^i;y#Pthui?W}!RDDm!jG3)1g z&!|n{@HxTis|QHN##X^8q46A8!~xdgo5N5A2W{;VGg(_bu&QJQy>P6;O;foM)C-aG zH&xQ(ljNGWMKoR#sJlk9amdlWG`^HKj}$^a@id7we~y1`6pTftBJ*4XRyP+Ga-)X^ z4LB;ku&@Qm&nuc6-i4b+erD@O5B4lrxIX$=)OmL8Mi@TS=moJ+1aA#LJTOdLo}_Oz z3pwx|(vuz;zX8v?Gb^{8p0r6;ZrxtLHd!8oLA}TwkotVx)H=QJWw`}g{(f@_DvhS% zmTXpdZc@}|?biHj_Z~cMh@AW=z{=kdRe#K5E1z!Pjg5&kHrJVS0d1=cd-|2lZJ-^xBx%Wf~b12%f?Gj zQ%jeJJ^}Hzx*4u{%e3`@_k~?}D6j#3UG2ltyO)#q_A!!Gj4ZhK8sXUCe#Nu%qTzS8GWSz)3(dcx3WMt znHA!-{KgR@0sz;QU<1a5EFKGLcUkneC_o%*7#IRGeg&J8wTB1|jOX(@G|?RmDv)Uq zbwy$@?WZtf8}4<{s*2?f#7AcHXe5jYx7At(@42o0yzP@>_VYu^7o&(2WG=hI4y!&r9u5Ns~b&- z=kraCuA0r`NAm6Fo;hqvSY2i>_}|QC1&8ktb})0%BE&mh!$X^lAcnRziCE>-M8{po z7YPfVpIa&%T3&RK%JUt89@FOboI3AP&;C46uYJlecYK+N)iE?SEd3My6 z2q2ZdPS6FonFFPt5Jp1=^9KdSfnUm4ZMH%E?1pszgZJ;+g&Q*7Yg$=qb%W>g!>TqK zgzzIF(9``pmNHG&)v-fammaKqymtYLJ&&4(PxJ%FD6!W{A%FQ&G{s;ah^SvL0n7`T z1h1b46lYSK-{_PXmNLYK4zMqf@!djErW=Jth;pGg4Y34D8-2@*_*WB!@Z zYHh9SE|@+`{71`|vYAX<*eEcua*)O3;zaL?#(O`m4o-!WnsuJ$>CLzWpKOp*Ow@I_w zD86^>gU3BKs!(aXs|B*l01F8pm19ezdrYKlc!SPXCGNa zJ%>l(uZImsu0RZjt#oyPt_O&dB7KUXK2mhJ$&N}b!w|(xnNu|%ueN{Yifw2NI{76P zNLB8KjJXH2Qo6RH4hbt$*_+)(JBuZ%h%nyLxn>x}(`#|?nejcERP9Rx;p0=bCNttZ zWP(7`rd?R~tl}ZnPu^2XQSdVh10*jJrb{dWsHhxmkpgW}B!ttLmdyPHNz-QwYy->5 z&UGk#Ecn2c^_j+jf{7v|i1*v$1e(Y8@M_Q6l=_n$MD9T2bvlbl5G7>CZ_M|U7!6KI zSS|SpLn&}p>R86`(dMfi&0W@4&h|7@wsa)+4ua^3YVF{J(Z9^FTeBjB+IF-&h6Sbcy&ZMZ!xx%G1?}*$tEvi?90nV zuj?ml8NXe)ZFx*;twZ%PYX@jn@MtHtNy|tlsAHZNG^@6J887mzXT$eWU%a!%nbGeM zPbhV%Ra&4)jiX>wswrhh^|}0I@t_Nn#}o+`D~4n4dT(VQ7KfchF)}I7xiJjy!`RBb zEHgs;WLu)EW3r9ZY2+(53ycQy^{bg#e(?ChN zRkImrQb0KwCT5TZN;}!nwMs(Tcj%Gby^}W?SenclI1cxHqHM9cWiMLk%7&rwi(|)> zm1y5h%ynm$d$st~aPejKyjaUVo7*ZB@X_EykEMjZ8ex`vGi4TIEmf0Vlq49i_~@yI z7e1swO5#!p0>eASF6S#|;utzOI{4lc2v(KYiQdeft!HZ#ZU;@ND;pl7#>K0&+bG($ zF4!>xr)r#YTAslmI^dGby&&2BqGUneCW{?;5~yhDaQauUaXwZfsH44y2E^m+_Xt?7}MbdB*IiscP?ekviaqsPugP-fTJdSsU9tEuDSxWL1Z1Sy{k9VFI3!XP4 z#nt7ko6@tXm7jZ3yQs90rTK8`?M)`X^N{+_`<=QF^!Bt>aTlYNim+_Kz0ZVvy=s%v z&b?Hi@?^~j5fr^=5p$HJ9xyQTkcE7O9T?H4nJX^v&Gm=!-tPV;JW$=lg*<(LgN#$8 zmIRQhuq#eEQavvx1o~U3jy=Vdf11 zI?{w>=(8VmQta)DEzy#(n7l4hRTEl5_Vm&@2i|zcR_mxWpwmm5trX04kpng?Pse*?UY<3Mw++k(=jaY0NTIto}+6R0l_f z39Hc}44&%BC6CyeAT;_l3t4>LY?%sVQweo|IoF3ptpmpqj*0hh<{Q>lD4{*@?R>1M zi{*x8UBLW%W72d6Dy6Spkh3GQVx)z-YU>D`bjFcIsBGpqPILr%aj+J z0XHl@GA3q9c>a|NxybtZdIVNW>)OZ4gCFFWd2U}*b^Aa8A*moO!kbaxEIVp7Tty0U zdouMe=XdBhgtB4Y?$fw$>d7K^;>t!=lpmy}Yc-Hxuj84J%ad-2KHlUr=LAy?A4(+< z@2$i9mUVg8ZE=?ppQt{Gh+~tB*gtf%Rj8wuY28o5aITeOcok^^EFk^}@G5b9royFH zjMo#Bbn9JA@Qs|yizo5H%YuvbiH7W>k+o8^n5n?0_d1*vjn|t@OkOM|Dhpm;Z)UU= zF?h($1f%op*HDmwSRIR_#tgn4jfK}m+PXfDUaVJ%8y6jmUmV%j3(p<8RDN z13pxA9g0$ljc~1#Z%z65j4W}s9e$eo7+rOqp!RXUYZRdloWe~#uFXUHh42l+hv2li zR(;n$)eV`JO6bx|h+RBvE+<{<6i!dlgtM!RW5y)6q&Xz?Y)82~naEmVZRE;4*Bp~A zc=7NhF;y?7;XA?f_S%=R?{_9LofeqLOJ;A5Ji0l5Z&vz*I{5_(4!>y-$v%%Yt3-I0 zUvfGC(@0|3=(xw&J&zWpR6o&nH1v$YVh?wB_A0bzYnmY3qPMTE7-Yb% z$M26Eo!i&RXPTn@`fF^*(^N|8*FEwhFF0aYmw7XG2IiQ(M0K3RS3}H(--6UY%IE2Y zV{f=L4KJw`i>MFd(V+l`eBK#sGTs^ic>yyxZq0=RWcDDqLzo!*JEinOa9n8KRNoiq zlOh064}pfG5Yebbivrt0T&r;FA8vNq269tc@t)RDh_gs1OmevOK%j{-SILa*Bg7@c z`;#Ln=N0ymf7sw4X?6}r+EY~O@pZI%meUhpUAdY1w~tNznm19J!eCF~R(fnCn3Q0N-5);Ubi|lykfnX{P0yh5;%TQo zOnWOVW4ZYRgmhU#075l!_nbH2RIq|I;@w<@v`yaCP1V>r4*D|$ue-jX(|sah(WleI-Ru04G&8qOH?(1y zzm@*}otA}cd$ZaIS4rH^L&WJ=yODGx?c?6SVy)SadCBVukzqYK3gm|Lq1|Mp-Zqw1 zaw>QBQ$>pwd~&{oOBto5pNleG==L9tmXzc~eael(^|s!h;?-_tD3(8at?CMwrRkZ% z2DVoV19Y=;qjoh2$|vQl9cK62VVuGf3oXeG#*%(tHW)*-FOR3>!xUuC`ob=H+`S@qI3x_Ubgf%Wf|2K?|)u? zto&g|ru)H|7i##q1zjOMXq)p6X@zOQ0Cc%lQtA~6H({4F$InfI1AP;bJ|ZCnr@D7Q zE>rvuZnKo41cjJwoIB;7lg-_R#Lb;L#$5sbA}Bc!tbKeWq9fCHWrs=C8LvX23l3Gg zO8kD(@z@n3rU!HhzqMb5{F8175uBI z@vZ;cJw4@bvwfgcH1+JPdd=%rQo}X&HPBVN8+~bXfb6P1^(+-o%RR~2;3X-S2Qir; zal37-(mGtajo zv6k9O`S2O5H7|JCPr0#w{P2JJvX@v~oIv92dG`V;*ED-wtPaeDp)K@Jvr{Lq+!Lgd zOl)^jNqANx_8SIEf5;^zknOH^eH&dvWG#^9Dvtd~)^5l}WXkZzIg6UT&dhDnd^kf) zzCb(Cd@B$0EZ!NM?NHh7^vm?u+~TWQ4^G9JI^0i}I5T2B^=>p!kY{ zI)jClrJA}N=O}n%lQOzziBBz@pN|@jc`oS!NqK_6L$gx@gGgg~0n#tc;t&YIq(+J# z*EF)>!J-_bAfa7$W6Dm@N@F5m1%1G#o%FnsIyfkC)x%#hS~73dL7F1 z{62uiG>e@_%p_uG;=;!!ndXgKXL0r`o%8wnqb{_`jfGcj(p;=q;18uY?1;KG7&!T{jd{BO64c9M5IaP98Hxn z-;{>8=v>4(D_)}#i+M+MslCZafrm49%IWB&85L8gqr9qNg!JrUMcC@J1c~mk-Rqr! zS3~kGlOw^Aur2-^s0_L*R@3L$)j)2~#*z87vWfMn*)9H4-yx7EbAwPB%Ba8(9|9n! zFGHXpwzx~T2k?cg(r)eVLkxi;)q{XL2Iy?DfPkngkaGudc>$O_a$0d@`x^C)Z((NK z?3M^?JkXS@VW#{#f2A|i{jdyR<_yCp4yNP6F1JNiZ}~Etemps-_v!2e(@3Tsm z_$>?w#NYEL-UU3gc@SgZJ9gFh16Q(l71_1H*pq3bhy3iV|IvMD&Oh=3>`a#3yBwY1 z=Rgxav8%wiyPi3)g2_wVV&O8I-Tz>v-Lo9*a6kRFF8=Ade!BxXZFLlx5$Gi+RZ2{ec4z=*I5-o8hN=DpR2r&Onk!JEVZ}T&tvRP(3?L zI`Ntld}|4Wo04y#LDuX5)CVrwt;W%9{@pPCcp@EXOD}05|oUb=1b8fNz*?1z!f%5oB^Uv$G0O<!&jj5tsiD=^=Xp5mzR%o;hpQBhMtX~4~{v?JFUNG-GHExhqvFZ8g&sF zXU&d!ht|F;Vj|9p(FWrV2+!>q5aH+yJ)E|7J5tDL4zJ0#^uXRwP0%vyf)kqEz-xW= ztSZs)>&$qOPnWjyAhKQ-2fTR2hzr52>+Mi0N2`|O7|0N5ACw#1KgNQQvK|EkD0o2((uZjS5 zHit`-hHHOVK|o#cpQ@N4b$p(~yJkU3^++KxJbUl71~bZL;ebI`0T7oH53FQ*BX04( z&`r}{+Zx(_erl%#dj)yOr1xlR<|0sLGn?SN$jwPM_zyTNU|_sXumP1ITA~N_Q_ATBo_f*+rn_+JRv{X|Uhu(2=M`8UBO#e z( zntCS>F!wIE7#0>p8kmn5$_zetn<+X#66mm=V}M1~ZgCM0cbWQ&-q zZ9J~ztLNrNv|RY~>DfWUS1(?B*L>r^XVQ~Y%O~3V%XgNZd2bKv>9NPs zFNiwA*VK#Ssk!L0`^&41qlFEGOV~pd+pjrLd{S|| z;iFQR%o*hi>6=f3^r>b)r!Blvp$Nk%R}JzNVZ+fi%Iwh+V2R2`Jo@(S?)8nsz1l(^ znr{|Q1u;fw4j2r)_^KlNL6kLnZHF1Pbg2T0T^Vh~Q}-jLUz>ad!a3l83Y7wYkdFiP zL|<<=$)i6QR6*#{7iFTqdBEO^Sg4B-LK<$zz{kUkUzC+}M?WYzDFQ(zlp_?Oi zTThYN+fnb0r5H5d9zUX^le)pOfzaC2G>!$}<{uRKYG|+V5_SEpE8{~s8-xfr6QT&< z9Fi6EO*CrZKmFy7xeq>myNv&f-f2WbE3w~)82W-E-><*n!SNLq3-nTP&jW0$;qmeWSVAFH^0kXz%eo}AD)99nqtfFRtlQ0j7%;{I#)82vK%n(~OE@>mc{i%> zlVEAsMt+vZTSPOJ9(Bsd=JopQ#8pg}y1%e+b$)qLTxjZP4_P|{alC@t$^`f62{ zP%;T6Zca$^EuCq7)=};wnQB&`)SjgM33G#%>RYdcjkXB~ z<$w*KZm|vqA{MGm-D?p~d{o{nOSG5uafKY^B``#izj(-;F+wmYo64ACHi15S9T%B}XC) zwtQtsDtvGBrZp#8sH4Q-YZD+c78 z`{diXQHaxT3UR5;9Z#4&D|~F-uZ~U{#c`i~fBVoQp@5s(2Oljp@VGaROiruLoHV%` zY1VFHYUR zF}coT$DI{m_HlBY+~VQY3t#7KRkUK>NR&3sZg|QD74W}A-2XVM0RRZ(p%h~YcyCwe ze$Z}h6crw^#I!l=3aUKIU4?*{onj(%f56&SUmpVG0OI=0nS!anRD79T%Gi=q@bbNG zxjpI^hMmL@x`veJxUP)AVqgd8EC<}dLV>%&NstoJ-*I97u&2J#W+tkMPs0Hw454nCh9Nf>wn zzK4%3K!gn+Ar-3GN^O*G)p{H;3(p74eJfiMqbm@@`yqo~z4N?K-FE$C{fFD9KC+rz z3R3L3&2+|M{>*x5B40EJ_+t#wB)rB<8eOmF{Wt^X{jC-g`v!uC3q)$RY!{q7x4^HH zf{h{yfwcz2;ogFdhtczyG@xhT@8JSKvtGZdHgmQ)lVP31_%q4Qi^nzU9-=-ev^2Kx zXpj|ipoKCG(%O2uT119IXnYq@Cd{h(kXgv&97E&Dm)Eu^h%2IFp;a)@dyL8S^6o~E ztrI~$vDR_k*d1dGe*E0 zEM%{ZOdBe*rga;6sj<+GloS)9hJ$;%t8IE(#-mEp5bGFx66#T6ogXe9PjQ=O+T5wU z_LYH}NbsAHs+u6BW}T`l+vD$-UJ9#HQ6+cjAddA5V~b4i!}&2i-A{S4=46|Z zhBte1r}T-5>E}ryj)s>qb-63z9XAszE@doh@IA?<3v^vTb!2fA@>)!+KROm{)_LHA zaHWF40q7o*k=wysL}U0y%fx3`)-DNmN$*EjcE{2v6kqXLj>L%x&CkKZ2RuyI&9+v` z!^w|avBf$%Kr_wWhmWwo!ftHm5lC@iJ^#!QsQ$~i2jmKF=Rd=;;VbJ5lc`lMVbJH6|Nh)D zO=Z{!yYo#LdnU(=tQPkY1zIooEnpU>(=rv!B7WX!@-;;znEYNjx`r) zNb~p$!7%S1(W%i&@*fgabqT?Hge< z;LL$qwN&>CObyh5wg6D}>t_meb|CP+-ivrd=`_tpwfIWtXP|@+ z%vWfnNkkQ#hA(PXuaE~sl_x>ITk0Wogxd#H*rrLa-N1F_K7~&5*JneIZ8X4tLlGp> z^=oVicwkTR^>Zi|aAME)t3mp<^xeYd5_^R5m+zAn(jJO3TNR#zN^W2EBI^mp>;6x6 zGB*!vP~H6Z5Aqvhkc(Jn)zM)t=3W2d(vWK1V^p`)xMPJg7x@Gpx@EfzvOX8(AgnfPZs3sTci_H0KZ zFpUXct}+IURKC6!IRHuQ=tQp&P+)UT%YSNbzIurW6_opC>7Vab@DXS1E;x;fNp&OH zAG&|{gUf4&x-(1*$=P}Gr2{C-=C(OX?=)O-(qJ~#shUBzQ)AP^FDward;F7 z**g4Cf4qa@I%7lfKB)H8nD;W|nzexDt?vD#pr&Kk*iqIfTEL&N>p}EQ*N`?wHK0JE z;jc$`Q|>IZD8-o;w|M2w2`>NOij9bZOHYmS0z9LvYBDwMqP8d;v=noW>r?v}c_Z|l zFY{aW*hzY&J9LD}hfv43(RNvcN6i-)cZV@2XLm` z^A{dCH3*E^oL!#y})Z1`;lQu2lRrvYAt9s#FxTH()kz*Vd1hlMEDS>j55 z)3N>;qnEhX+NaC+%d0HQ{QDTLS!~~7zLj~0A4&|7C;S2@O7!apC%ILEO*oPwms!DC zOfKjHJK#FlB?^V&?>57Kyo-ib;x&v~Tciy0w$(%(1q~f9)tR_;QmHK!C=AbkvW|bc z0mvT56cl0|Y27X^3hYb0;d8O`_1rM~d|LKvuorwSQ~|FCDnrC|;4|yg)TAUq89hCu zYGrlB!X)hEhTt-<>k?u`n_!iKH-%-33wJiS>&ISjMio2VPqff`zSqvmj-?8;bJ)v- zw0Q+z@Z67g>$T;TkHo+3GCbN?>9_gdn+TB?{9fb`>wt6vnjG&4LnVItIRo2?v1HU5 zBKR-?@kTA)UAgtzVj?+%=XkNaNM2V}nHolQ;W@^hLdAEf5gzI|)QeIRD6cH$ar zN${qcjfULzS3q)SS24Yi;in2Mt+Nk5XuxPHcNZdXD;oV2PvxYfvdYU{0yB?z;WKYR;53}PL{Fv)EdD0mU)Vzock#pC)` z4W6i{>cNQ@jtNFcPL�d!GrId}|pusQ-U{F#jEXFbu9gE$Q}=MV!Su%Xt~|8ckrQ zQTewj2oIU^JR^Dtruj~bnz(&kvAs3=#nN&wf<^3)m(0B$_SE6^tMgJQGk5E8?0xd9uyy>>EH2QI{W~ zZ7~1YBm9zQAx5NjGca=OwQ)PDyHL>K|Eh-g-1N%8MYrE)Jowy&2Kd0%j_n0I{fQqe zZBn75hQ|!{0mGb64=LP`xPR|I0x#~K<-&q6H6L18M=|ArW9N6KT9p9e21kGSbO4o=ug}6=SCe^v`GTd! zgPJe{cA!!)?vI)M(bMKf3i=V=e`}EhvVXp`Sb+GXLqNK&HdTtWnY>kYk=&5|F3S&f zSPc97Z{q+69&iAoGTIAJ3}Gb5DQL;JwY!&=5oNi194Eowce@Mze&^ygj=2yZLY)b` zR(?+PIsp{m|DMBnWG!{E1e@lhj!7@XqW%3pIMM7&bpp0}5AyM~8`FHhVQJL38u)yu z4u5I^)Ed~pXkd~I5VxfYE_0!`ZOGI$?S83JzGY8G&02o6d^64n`@Ky}3Uz+*)3}5x zwxJ{@?-7|(u^t|YBMdGV7f0C`?V(B_6!XrM6qOqV&V?2OTXA zemk#G1LP#Y=5`_&!Fny|CRZ9I5kS*(%j7gUlnZdqoc0k!8R` z?uYAoTBEi?12%}6V&GMv1Qr?9uYhOZjZZ4aDe!1bcB|(Gw0`QRKgN>&gdw>6EHW;u z2il&`+B?m{aO`Ze^%ggRn>4eE>iTo{K16M(pug@_ie&8(y!YN-^M*yjl0kKGeqY#T zvBj)7=QnRii&O4T67i|-NDzVgH0dc42wEAp;^-N#u-TMnyP2i0FZ3ln%i0GRPQ4|; zA4kRYj)$KS3y!+Lc#4)bS{dz6<^ltAd?ttFZ4J%4>U3c;9{{%G#)a)h!8O+9qeF>g z(?Z1V5>=M)e#^QWSuUOxkLl_x#R`{K%zF+`Ou`hI7v558!4$l_-#;I(JW!wz|ZxHaMmUN}z1 zf!P4vmvJg#fpz(#WRzKn>)iyG=k1jauQEfb8ups_E1Onc_1eMh+V4A3Cn3~)4Zrtn zS7p27Ydu$P1A*X6=@&kWR41@f!NeVLTv|4jIrLOcxO`F+50wzFD^v;IEYo_gp~GgYl?O?gQcYm0 zkB;*?b^||60tTX`u~|7lvLxGEoW8U&7%}_|t4lqP&h}%LS*Je3x2>&RXD{AdWB?5e!m5*BSVjW4jc%g$2;0v9dbS#@;cjh6L~M$oX#UeQ<{$r_iU zQYtZDBZ*u8HuF!gOn}lK@F!H2mZi`~S|{@x`L1xtOlv`E>QwegZwXV}eNvxYQD4E; zRoW!~aaQRfcrObeQ}Xmtn@n~M^tjxwt6rU@e+NbXLoCP_Vz!Dtj>7GDGKHeN@{aR5 z@G*YCk@*Z0g3W%0tsh_Mz?+upC`R1W9{!*)%{hwsh73JPGZB5J5#t@5!2c=wrqsyq$%(Z77en5OWA0aeGn`VH|csSaXT2CQm!3zE3{E z{7G_7<^7?H_nkk9ycT~ZdgG#SC3*9h(V={-=MG&w;8joRWJjhZonrP{=WAe*ok0rQ z0{a*&dB>Jd{Gh?Dx8|^=WDv*5O@l*v_N+TV(hVf*NNX(TRpv0?gYAE{V$cpDzlX$; z;zUr|ptfHBT7eq;8}=4xH~k;kUa*IMRZVi4XrT%WoaFd~Vlalq?(DH$X6pxqrwmY- z`Bv!JL~T<}`7r>CRt%;dzZtS7z2-$9D9^&Wb2K+1W3z)H{Aw##d;f zHIzfGV^8@^)?`6gsKZ<~qH(y0N28`q%jZqAV~&`9;C@D*CY(Guw9Xl>jmq=!o6uKCY6CtU&zJ*>Xa1T$u3)jTCO_9$D0x(ASk9c&CuHUCrW3 zw@64}nCXM3S0AbeCA>PUP>>8~bl(ju*}OCZOHo^=cV{$Jey}^xDSYzdq_0A-U@nLJ zHCJp~7uQUoS1*saJYR9Eg_#ZeNZinM(y(#0*n*(cTzHwO{3r-1DhJ7`E6oJc*B=7i zE{h3&hOK)CCUOm=ZN5C+>-3I?>pkyoQyg9ROUrAWFWOS;JEXs)z$IOqeqp7yQ0VG8V*bN+(Id}%g)?X0gIULTtQ3J;0yzzK z9;%>J48H?Y|Ej2q$*F>p;35!mGn$1)cE2>o%&vb0Fjr^HJdGCd@Bi%xr^BBffvnVG z&lJ8Qn43B&pJ8n$`#~Em67a^pxc9B9D~=oe0*74ZB3up$P9PA7Yf*)FL;R|Fi=i8&rdJIMX*H1 z^40c{7nhaWVkLI$VTUJ;#4v^CyUlbTYs$BZ#W$6c;kNfVkB&aH(8-Z@H7rc2W({pR=!o`NHNV`FaK4r_z=PmUb|CNg7HUnugZK0x^wa zH;gVG_aElJeEN;?o2lj3*Pmk7^+NY@x;CYUzsGEw%+gFAZK@5nWU;_HwRU(Xh1O&y zNj(=&3@XhsV(Z*-txel7t&4O5=A2W^WvF@{XZFBzm&OYN^A4wbk!8ZNR&P2sM?5Us zee!K2UGe+I@C}JQi100JT7u?_7MBCeJ2--rxo*FFDZW9DS#gX(!50R0y*t;7yXQA& zfs}D0nYD;0+8}8axUQo{-aG%H+){kRv!f-XWTWT43$Ak*Z{#drJD8Ae-6Wx~n=;)$ zrHOqn3!5wxb}J?=%HiY_>kappXEg~49eH|jAla)pI8A=@1+a%LmoJ(%X1?wAIalLi zMnCIOcq#8xOTZeXvQtv?qC~15M`xXh0>Jm^K(2g-f#TAUfWikc^gnCLH1hrae$aOc z#hh))EJO7~D6aFl73;@Wr&wW%{O33>SM}3f`^e}jT^72XeQG|#uZf+vC?^Q-T+|O|=3pwB+y}|5kk?FSaX8gg20k}Oyck_X+d!MD-uO!R7l+C@Y6dA?-jIr3HH6Eui&i+(*f`5PFd3L%25ZnN@5@6F zGaL|wfml?bSb3NHGGSK=|A#P92F|zgkrS;$?pnceZ+Yb8r+NAarGPD>aUOsSf7>6a zH?Q;3eH1uI*E{TfU9b-bA{(h6_LlbyUDZynjiBmrt6$BK{DIiC>T*D9nTPJEIk1m* zOfQSOw*C^ZkJfp%sx|_^xu4wLcYMqruI7UE&g9vc(7FVAOxb2{J=VN+YjU-Kfia%i zZ2G47s@_~m#81`#p6a@Xb`_H`7tUbLxz&=h2M{d#CT?*|;@fxa>yZcUU-+s#EhuZkT36Ysk3J9P11%U!1m zD%J=3E)2>zrqkG5lY7)ZpwO~<7AXV+Am=KEJc$O**1xOSoj{SHzWWl$P#4{$XA%&r zCs5`QmB+L>tueKF!Rj%(k(~7uV+c53R`)TIFr4_Xn?uV)F3)mrfL(maO0oaEwcD14&yM zTm@XZZzB4KOHqqlfQ-~|03ghN{tgXJWAdUzva5w+n*=-gK|F#@!8mt-{sb z2xKTG7%fgp!3+rZcCRevlT{7NCvx35@f#;@L`{0gn&o`dm_a`(Q5NY#J%tZ>-uLnf zr`OI=0Vlvf0s{dM7Ah5S5CNb+aWx((NLJ4p#OxV&B!&1ihqP>&aUwkMFj>w(e1q{6 z4d{HC3rZ?U`1Ix5!kA>vNWKdOVP{U=#|4us(Q&D&Q5WzJ_Vcw@KPDE=obc@+UTm;_ zetLtW^8p!;_FKqNq%+YuhoJMm=BkL9wIxY{oazN=fy;Hp5F=4huq;&&sV_l#I2&Jq z8Z=qI$BFsK;k9pyDwfd63c~t+K>Kl?j>GCOF1ozh8eUttnZG(-b>o8s0h=q~@qt)8 zS~dYGRxNx~Ig4G|%uRYRy8uqS7JJ?%*s@;9I<@_ZfN#hGAN-s`Tq(k2i*c)*zAc=kuQPB-?+jvxnjJDr zfS}wz?zT=&zoU%E>vR}V!PNR-y@d#`vU;YIdiSyfiB`H|XM~-hwUmZ;sq`Qr%WSEQ zME5Ko+Y9+y6Op|s>vY|#307Q6Z#ioDLoc{S$8Q`e5Gwo(W4PyUX_rJ7hu|vjNxqb# z;;m(Iq785cKMovb0!*iGvcYMqd4?ljbHJcgc$hseD@D!~1h7Yu1HykfRe#C2Y(YUK z4TQ<8AT#TOFEfn57rHb^Yrf@g!bVW}Q?T{cU_viBOOsFdBgnZZ(uUT_zeTScn*R)= z?ODLRE=ZHSKw0P{2-sH>4eS4Kzzz?<3}_XuD1Bi@b`0b@g4jKEsS!b9&!lPjVwJ6_ zNs+cq{PJvR$BN-N5~Z*;4gsLgo|henrA>l9YyD-a9PvAgY8x zUs8p?v|-ORQ{gm2L@-p}_=HwrGde`yMfg+>`L0#cAEd-paeiyTY6qAY2G{S35R_{j ziUYXhq$ary4D^j8_%f3vYOxqZX(^dBpRxbFNDG=1suPH7Qwe52AbpF<)@eVql zXG)Ep@9bXOdtMl|sm4+KQ&(pNFTXLpa0aadI!xq=n8Fr38rkeAbp#gu3(r+3n0(Y3>YSf6w&>z*l#Ei_!y~0!j z3Y?ALDEs=ELVpaR1ZDv5i6M3`O@28C|N4f)oK!dkH1@~NxEsyzmP~{FscEk`uGP0^6RKj!o@!a=uh6{*S9ndU zK|aM726j(_i{z6)m+9M!`n{L7>QQs$A;nhgh5B1qQ~p}Kf}Y(hg)&9ifz9CyV6LpJ z1Mqg>IA$LHgiMO3`uHdLaRd1+*vghM(;Z(z>c77ElQsGE7q54TP~O@g0lYODj{a(O z7xbC5>!sbZ`z+xL6`x^ZM>w|g!m$kA>72ykC5z86ktV`3@TgQ1qha(j4A*M|9v_g- z6R6Pz=teVC;r!`AU~y|)X2X>z za#K>esoZ4U9HL%FF6H1PgL>Pg^E&c6CY>pVN*m?0UZ(p?B|W5{RoRb#Bjc~?Im>Eh zEXQF@yGhe~B|gJi(!_(`0lEd=&Rqq#Ii(pb?=HVy{N2=-mDdl} zAHFSLbL&mOh0n0pnER}?&SpKs0SZEv%Aic?r<0-IQOGH5`&MG-;f}4(_u{WAc%?eC ze;_aL!q>qqG{FiBOd&C)SK$wz^er38SUm3v$>qnvTUI8nheylfV{x{cPm1A819^s$ z4_}{dyZ^xKLDhK4MNk>oqJdFa)L28V)w8|R!J|HYmY|*dbU9ij(=LZ3p(7Yf>XR7t zU7FlnydrwnVPrmPeXXqUsXQh2qRKJlG36;`{*~5JIVPrf&2jchHx< zC8NLJnpVy7?-xj=oLbDdmE*H1tf_TZ4E(M7AvsFP@Y_r}+lRsv>OmoF1ux8nV0-oVG2Ldw6;xsQ>I&>+3w3LFKhi{Ccc?>F3kyi=<; zh=^J{emY{*U4V15GK_ED|50waOx`MWRU5=S4f-$K^6&p!X%fKRVewRnbIpej*I?Wa zl}FaW6!NEgb%^&Cp@ps`kPXDAp~+hiOi=wfEcr7m0gt7f!RR-n64|WimlkmqAzgR1 zLuqLjlU#3`)3%KKf*6%J2F)^ny(>)|0JSNr$RI5|s>phM^*q{h+59C$mr7C-Q2G%g zh(eVI!^xOoaD?%p(pbYMN%Ze_B{-9;5iyOJ`MTC2QZFlC$ zL!Tx^F5Z?8O!c@QJGL0Tqruv-dW&aWq9XTbbnKNgMR2>PaM#&D|Gl8NvCUnk#d>}0 z)uy8;bH6RI3Rezay18+D$@4BqH`~!=H?Rn%PUV*AOTKDCgMIguccri!Y-F3}p7~I8 zqlh-kJ&eUeM>2|arp2I|s0v}O&%2lAdn&53iU!=t<(j6OA=MrCIQ;WRN9uWAGsa7h zg0#N6q1~Z81Y(alv&G*wfcD#cb2fxP-czSF(6oI8rz~AQIG8*Rt zmnlLCOE%EFZdzgfxH5&QSKva?iqba#T2aeP1^>Z>vgn>Y1cbpB%DR_1+g88IG$jpi ziGBv%fH)g3dUgeS<5(c~cH-X&AeCnTmOl0jC>>qEL;*^pb|l#w@LQq?ka_3Dj#8kr zH2(WX3cpt*`)wY3sJamX!_3-V6gaYB@S(M($>7u_umD*z#Y#oIOMOrtc$}ccMxZ z*7&J3<;!KL@;u6Ws9|^TMNs<52yM0p`bVd@s;&=9>|i{z9_F?_Eyp(ADE?My)3F`* zR!*ItBw5ue&=!Ocnepx&=9{f3NBko`TKsvvGMK_02fp+ehB}yYEt1^=TbTrcE)y$P zfzHvDC{)DIOSV(@_skBwUk+B?mR`Lz^xRpjeah9dOeIJTBwk6aDW=|}L{BOj=hJsO?9D8zwh+c8(lXO?Poh*RJ2t&j0AtXBe1dH$J>93Vamgy!)mc z`V3B^WSYjnLeD1v#D|ViOy>$So2GU*U5u`B+lU_0w=&x+xwP4iCHlD1#Rl>9-B~wx zSdMiaz+KJl7yb;Rb(iyPT;FindqH0FOn~y_lS&a`qo&4Hshpvf(!LH*HqO>#p z@AJ_9`8K$-jwJOZwMG$lO_8IWYH!2Jta8c4rqyyw2Wz-{&wYHy4V*U=bj-Tw-C*+Q zcjP@dFMRmf8Kppx9jXLb=jnZpn`w5nBlS;B_}?24saM;2lXurF!^;P0_Epg+BZ1!V zR(Dq10rx|1TJ}3ev0Z3PlQNV-$4j$m+g_4mHYjS!z7i%#I4{Z)8hK*Z5yNnE4VxeI zu#_gFWk4?Lm_Q=i&ZNQLu_=h?qqliPR<*xt6%frkdSGaS6U_vZ3T!b?IEaAshlbx{ z$>c1!6%I|r>wSiGw^FDA0}9xFuY5k=<^rz^91U+Uno1ELd6AkwER1>E2qa-h^= zs=FXTUNuz-5HJr$tkqP}%74elO1o4Z%j1GZdw#o9L+(HMmi;(s@=5SF=n?ae92Yb| z-WKG6CD}uywb;Xp7*3#^;*BPeEE5t9@l6wKhhJrtgVcGp1(Rnr##C2b$ef;K$fS|K zkubf-y&xR$GdfSxHJ7yn_Nrny0vX%ZzRD6Fq&L-mj?*FTcIsgG z6xXg76S3)viu1>Vz^r<(w5ee4^PJQ3;$ti%a;2ep#%V|HF9i)0y;7zgePHn*&>lyj+xwLJ`3BK{X=`WgQ-G?Wm zclCaf!)(=jK7LYpNj(xfeYyu@$~S;-Ds4^2ZEq+iJbcAX+L)%)ul}fT4}MCLJ9wyg z?eT;zc%EqY__B7+#N5Fwwg5RTQn>B5*P*f{56&}n^WJeBib(WFo+Typ`%m@}T(|9( zAFh@a$stW)96a5%SzaWYTyWK}(8}p*ZDtC!I8rhP+y)fHFf(-lshKHu zkEKR;WQCJLhSsn%60agJmMpNYAdycCMP~9bOobxOdYe_R_#zcruhBjfxNfg3^63`g z4ed?gDcMwwvasxFS&PNzJL*sHSL|q z(10%Ao(U<4qg!!!Kg#kjAIa)Q3g+9yM`u3>(NP-I^zyU2w3<%|4BVviaKSSz^$*LM z_mAXmK>aKc~58$U=WdN)( zs3YI8l9`<@!(2ff+>Wus45qkh!|&`Tn#xIE-15FTt@uGz*NdC->i@|F`9MrCW=o($ zll?(t^H3Je0ll}~OJD%8lN@%G*EgdV7ok$Ni#M|X|!)C7wGoC5jelu2=pUo$^Nr9N)Zk)CL^-C0Y`RV z&$m}?@5z~m3o2+`ywDM{SXOIt_02Y#aR;59RJZcao`t4QV|kM~VqqSd=Nuoelsf@# zKWW76rHvv%xqMnbt>Yh$$yoy^F@)4aObPvVnI)h}2~xOcVp;hw0Ryoz|8VV+Zm;Qk&0t3&Z$ipu^x zr9Q$DEe$pb5W0|>0aVeNXAZRp65^J1*7X}gM{PHE3ymEt2~}TPNm>BTP@>G{3P59^ zs&guE**Y2o{GGpu=zpEK7In{xmVbsBnu4qUkn{EUyB#%La{xdY3f7-YHKtMjTJV;^ zbrsDsef41ePc9q(y2S|yMc;!4q+%c}sHBaset*fpds34d#$`5g`lqjPs+$f83>p7l zxFofvKq`uP1a@WYMhw%6HLBl$&fneYS!GPOURzhuL__e~8l+LNRFsOZu%@M?z}XH# z4^~oEz!U2x=?|Ar?-}zvtGjMtpPg^aZ0a1|Q{hGWyspFHeKu2$Jm=aUm-(1QM@Q8{ zD&ttAnb|$(n{Bj|j_lUJy>u}__#Jn2Q*kI{FzHUH$S`@)c>z6DHat}hGD*EmVQ#$7 z5$-5;3Nj~%$JSv$JB0bCt;L0={yhbeWr_BMmQk$@(&}*}W4FEKNs|})Gd8iSxG%(; zEU)ZOP}h9eFCj#(W?ciD@fin%$eD2BzRM!u>(@P_ z)LdQda_TX$S6~75M1k6$K-=q>^jgy^qvZuK*M0x;f~(8_lop|{U6IKa*hT4%1zdV= zK#xFH+Vs8_TsVs3og_`e_pGDuztpx!OB>jXB3+^+C<=3tS?U1tz8^~_ z3gUGD@P51@O^(eWgF|q*Qez5bIcgb@K+!J2tn>?#0kl61Z%3&#PJ8mGN)0}hs9apWJIbeh zXd|3s2NnSD)a1Qh_9|W2{Y0f)87fW@v6VAr{8y( zyrVinnAQvb1Q`fvoNL@dY4?taOBAs#6gZh2 zek$gmb-`7teL*R>M3=$fEvtx>p=&}zLIx0MiE@@ztdFdk&yKRTb2&7$ZDBON{)eO&K21 zfU}nRS16S$Iufl7X1WUUwrV)S@$(A~8*K+{Ta^AMz_vwJ0ZQz8x~v}vpH!9%zf1RY zmx?i0OSHO`Id1Itnr=!n@X;TXTh#J~am)f9nrAt}_47OYZb-g=K3>OO675l$)_le& z>m$)}o`b-D>wChm4`x6boXJ>cO7{xrv!k#wzB)^RvM7~V&JWy7;o$mto(=fGl1&cd zmuWkDT=tevkwXFN)*Lx2?(9P%K&0*ep6!_#1z54d*hizSCF`<^W;FwY2tQ1v6(FXY z2b7(&F)Jci3n@0ZO?*Zh;HX z4Jgh*Ka$ibRp77V2)MpXGMV*~qWII(SU@XX)4CsA1^EMg80>@Jn&E%*T|D$rtjF|9 zzd=rn+h8so1FTqb#;DPKg9nINIRgmm3yoC40KbizMg9Ffe#b}t;Sq-&#y4~ExZ*E# zDqZF6z&eHbMikdQ=-D)qvPwMo55gR!cx$xJ3{<4bqgkv9WP9uGO^9R)r`*4KIiq=g z6zw1KQS6~s&O_nq%a``+h6&PLy3AT)nI32X+Q^Dk@s8JHo#BRngq5ZNzbAJjNAshV z{SKGfI7hay{rRu7j@%R4@NBJ6Ip>Z;94|%hAsOV5{XD|&wK<*ci#!V{4i_^rWY9Pr zawfak?Dm#m$@)?!OQrC=TNAsAP9?Tgc^S2X_7#RSKiuU0ShsF=ys5sUqxf@A&I12- zn&sQMtxxyz-Bfu+_mnr(SbIkBzKFXP)DQQ`YbgxB{_?r;xs1Xrx3{uM7ZtfR69de7 zIqq2UnbB8`ByKXlRR2_3yQJDFR{u5PI)#LVLwM+*7 z+P5F4Lc4S<*hPgB?-hmug70q;dNGqS5`BoxOe@At0O@ZahyX$p?+ny52aUuhcs8NH z^3x~3`2-}NGwa&TZ^d`wBx+Z(3WkEYsU*LcyDH0lp!DC3R zpwu%|*n;-On7kXt=al9Il3yh|jLuaJfX+~Yc23Xq(gw_7PwPIM?CCD6Oz>x($0WN3 z61}D{;~^lo*`1~(uS8Z{LM~K--DE5BoTuj%yed#L2UN^WIRJR7IH-~Q877m|E4sWr z6L5zHgcdWlg#x?jB7iQT=t~$R(pJBJoVM9F|j2L|(T3tM(~TP4;38oe7+XHWwB0TY zlDh+vx;Ib+yDE?!M;_rrdzOsYMcf1O4eXF=?y6XmnY!CjYF9V3TI@4sOf0#1>e9PX z_ez51R=BUbppbjo_R|M?aJ!6AmWxowuB8d&CPxEqxoLXU=|V-Y#kmg^Jl1wd9un7o zjv}kng{OFHjao)u0e4?Ob~`tV1G z-FwXV0*0_bmFP{u3|BRdeBsRgp|)Q@OMVNcVvG)G!UH+DG|0hS5Y$bYsV;vbJF1`a zH%g_FG);fe>MNAMcHP{)$b=+u0g5lnDiZM@1xkA(8Gx>66Gq|W2YTaH4{=zcZ%#L; z+|6IEr_)B+dESmv(|ChXRR7ZofQcR|Uj(CpO#S1*zJZ}%-3MrcaDqf{N%+bNVugY7 z@FC=*kxV2HIB7lTIH3r(3JC`;Y=QEiP}RRc%0${H)b!vq{Z&oeUj}Ryzz9qn{#8?U z!+$@P_%cG0VUq=?MP{d2aHXlK_eTC%Wqx}#E3o@(#Ee7=6OZD~Yb12OKCwQM<@Bke zw==fh3J`ZeE3Q6)(^LT_w7OyXH>1cO4(Vn6KF{zMinx$oQvSi9@@;9I4aR3u=K^Br z;bnynS6?2DQV@bFE%D$GF$6o-G1mE(UrF8nqN&Vf$=`utqr0ekb%OpH*2ByROW2M) z^L#|ivhruxmQLF|#GtRMhmemySoYeO&oD>Cyi{5wd_kE!LOxFelB9nTuYVub1>TnA z*_gSZ+M=~&Im6K#QRtCENywy$emD^D>SE?4{$+fHzaAtEE=2i!0$?=Qy)2NRL2*JjJ?YUptCB8PO@A>H^ zpPOo}dwi4boH?x^b-bErG)iVqHvsQ&yZpAJ5s3JGL)ndp_RN<3lVuaoPLhT7ym;MW z)QS9Zxyf2cqAPb$H^-c(zIacXWo?(C`Ze=|kxhe=Fi#dJ5 z6ZFGpbB(u{B`)?Au%6$uDGuhaCKtng07-8tQVu(=ESR$^dE2?A=$F245(Fw)P1R8P5T)7?>^bq+6c4Z=ZX!tKu=8XU+JCA9K*z3KyH$4lTyxKZfF zimH0v1*hS4S%6Fx(CB9h)uKK4v(M8bN=Tv7nVug>Y$pKyyyvOOBD&u;VCBgOd40pP zhRi)}A{lGX7cE4ZaPS8%o1y|-v5jyhg>}xdMXl9RA2z0pJ&5xteZO8yzT}iWv)gPJ zDFy3LBdXS5eYr{UiTUKdi(Kb`rAp!oQ}a^FLN+q!tKEFs9= zK$PYejgY_E0b2fn^X^MG+hrag`caD@F@jSI5WN3QUNVQ?IwG^>)5sHt@k>tz=$I4a z;Fdu8j5?oLyqeqgz9iV>5Q)JrI@jzQ_yc%Ke};z?W#sS6*re~?fEo0K z9Ne_PmIs1@f+Pr=k6PIX>|u$wULwZ3L`kqSYoXr(JHEez?{^-&5&&BOPZ5AEi63bm zYbN~B{Q8a9_`kc;X{x_UPsgIY{WEX$P6CB9^Q5P8_*z(+@2ZxgA0q`I%h60eh1+Hj z<%soaF&!1QQaqwa;Omz(lwa^>Pb=5)9BC&0mq{V%sy3&CVrYF~E0(^qb?o@?k_Xcu6K)9e8VKESvgy=<3w7(^(6@+YF?ie3 zjz>m28i<>^NZ7>lUyLUEdFVoPP7~TdWk3J?}g-tW9 zDZ9NroaAfc*YJnvU(o$zDe;awV$YnC*|wT3uhu37?b)$KMkfaQs6!7N26RNJWXmn@ z60MrXT;r$W>{}1o8-A+UH&G|FH-LFUn?3u;of$rVI_^!dsrQQJeSVhd7;AI(8r=a;BH1dm|7BNe zV^>S*gcV9??AZ3DK>lR*Syj`5DP7~c2P4f`%5G*zaKq^MAG*>UVM5E|!F1gvk0Lo0 z0K2;+s}!&%1A)N6wk!b0jgqB+z(DJ~6|}CSSe}$`&%YEK8h@k+2n@tT!WMp7c7S1% zWl%(~tAH&30IY)K?o||9rE}Q%?M-Pnn^W(LA->ox*V8@ltnQvwiwCIbGKiEnjE__; z2QHemIJ_10IaloBLchhM&^TL}Reh6E?MBvlQEKe`ABZ*bZUls;_Z3Cf zhbunQ>cD%^o4gDQ{rL)fm2%#&4D5U;$9>cFsp#u~coX3WHb$o~M9(~$3EyCxf2&aOqeYFi!c3cWhFFJGGKYz=~4QT(8 z*&c}vEqa2} z9s2(lL_DeUKY-JMdI`X3LA|8A>ZGaqNxFjnfVE2{74UsMh7&t4&@A;H)S+wI`2e}O z>VHE_RUWij7C_{;$=ivtpIG4h0Nj|IUO`b>r~i^qIJ4+ z>Gz4?vScM5B>7Dz$gjU&p`$C{?_Wv*fB)4bFS=R~W6#t4)EfHN4HfK*(5%cDd|{kz z4lt$0nveiTmB83ea#?d-$ejz}@%;VwK&2G`w$D|{ybpbhzTmXr?*p7x5lEbpGr%r$ z{}()spSw}Rj)BQUs9RYCwUhecR9Ar161XN^@ogW=uWg|qwPW~dfYe5g{L!8KjS#cv zuTE?juwSHh^Pwnrt@8V!*!2G40!Q1A>bMT5`kNsvR~;$r87UkZU;rs35&-aOLQsc^ zE>W%bykeO7)gszlS*sezOoK};9v@x(uIv*h!@0o$@rO=)8vFB#hi0UZP5Ic z(8YMCYqmtc-f>aV-e?_P_^eFID4Zm)PO}*rfI9+)9bnlhO4%okA!eCCeIQK(thLA$ zuepA~_sH3O;4ogO&^ci+eFtcBHG%Uf2xzqKitO6_Xm*NE_Y8B}HTEmqUS1!0A1!E| z*l<%Q0LD$D{wy>vV5${?Q1ab@H^D8P$1YUs7M5O{O*1yw5(0C&*+x$@?uG5j!_Wmg z5i~95ax4AQeESnv*23d$$@Q*uiwSqqMEj05SxEDm7gf%-P1wy3h=J$JM=#pt=C&&A zgUmlw=3Pk-jR+K|N>&P)a*ZTDvvjbVp9#E>=W8|WhH7y5B%*y*U)(pdh2va!)kpLY z8p6}r5DbXno~)q?I-M?aj7;md99~49Bqi~L_dZuipjp^Jv&6*@l*ppUSmUKaR$HwC zc<|k~fgM@x3cRCZ1#r!K@PrUC^5p`NcO5R_&k^$>?<(V5UD)sdupKq8!Me*n?*i?A zHYZ~4K0|NBz*+2_%3#tn5Qz&^&GC%QraHcoNaxKkKY^DwcS4#EeC zdWR}WYRUxgkd^!W$c0mg$7w`X0q!dQa3Xjod=)jqRW7{35uZek<$UZObFaZ2D_{A9 zDJ?iv4qeN$No}^Yxs$e2m^0D3JStaEFhP&R8E`yIDr`4n0PSAx+e|xri7~hig`|Gu zcFT<+JXr%18fTPHgQ5z{73e(jdUv%3cVusYY_VNN8v1$r$&fhO+d+Lhdh>WbZXHpV zuS31pEx zxiFCIMUg@He_7z5@xK)D)c?ocn}hkbO5wk)%oX zbz1C`Eo3)EmaHK})`S!$+1JUIrR>XCvJ(pxP1bFL>dkMv(a&U z$S2AWS_Q7a3>kl9FR8BrS{68UlB##=SutX;5-EksgTWfi*1ik2fS^lQztPrsdh< z8UM(k>hU+qKG>9fx3k(He0~2L@XWWmDsSvH05Rm7@z#cNLU~r2HQa`P(g{|R_ZL!XRzr+e>&+uyV|XC z2%=?5S;wP2cABVl#;7w+dY-_A|3I&yKgXS?d{i)QH7^q3ZjMRiDMmID&j&vE8@}t_= zdd1ELeRx@PW6X1QDtvAT_PL&LxDG-u!-pkQ&sH2O-Ob^3vk(?*agqF);1)+Y1BW2U zy!aGcvvhvjY`-!#zuRfKGTGZ=bT2Xv%6 zj0j6lUPkvZTdT|sXTkSZ^y2|RucQPmzVJnSB(w^)8^5!P_C1p)CgVMmfXkVkuqVbZ z_xV>;M|_N3ZX9v=3#ToDJ3mAGu2T*v#-)%QPfSH|j)v}+KRzd*XRuNP-bb5f?#a#n z@FtD^>H8-suyBBfpl9ZgB$^V!0*rd5uWx$wm;=vK(kbMcdE($^A?!a38yhVHzaMM1 z^^=v{S_q4wpsKUc{xM_(iu)rXr8Zx9-rODoa@*@&-J4wH{%Xzs@${!~-ptUo96kTZ(n77FvkRR-keW-#kOzsi;QU-Jr zPEuw4?mmtS$Nuz13ofFP<@upE4Q3W;h_EK96sBh`v32jYZN)M?g7XeU0Tjpd| zyuh=qyRze(J>~+&>#k>v@>~kmbZAlHX+jL=(ASJiSSS&47d&7OD=6{nRVz7wu7t+L z2Z<2^E{`(Laa@*^ak+MATtbH5b!i6kMPLg_ajmX4ar;t>e~BUaRk~0wqRsG122V@0 zP{s%+71WzDZ0Ptaatpf*j<5UGOQJZ=IL$4dnW<)ht^KGk?9|0xwEs56!C|u(@8sdvM&V8gD~ohsfwKN{+YoJn3PByOUXZXs;~ z4`~m(wVbf6rS1Q*SARq3{kg{VKlbY1<3azAy$Yh&Z`E{vLyY;y?A4#-hp3rXplrxi z#SSvD5Rai74zxKWiV2Wwa=U#G!QeN*ENQo0+JA*R)y?o{CDxs&UK!Bn3|&F5PQO^^ z2C4yM44EYQO(AUs#sV7ZNhlJ*Qmz501_0>?Fzl)uN_IGSUAm#DP`+(X&3n#JUuO%L7sAKdEC+928%P~~K1nSN%!{dxS*w#g=vy@TJNfzSt z=Y}to)u-NH8)B;UaM^T@QfNTS33lpdP3fL>@pZF`M_ zcl7kF`PUy+mkmnkz4c07iY*MzFDNx5^G{UdA7N+nxan*roqdk!Q3f;pq~!8w9@@Ou zu^Y?o>nL5spBy{O#MJ4#=a^?YdapJV;s2z^+1h4A*ZLq~B9d+iwHr5I%Y|$4$y&0w zacK62d;586t$VdOPRP7nlCz6@98iy1aM)D;ckL0~mWS3jma^#@PmP{&(K=3pOFb36 zPvm(gw_*2$XRZR-WHM*ONR3tMz}Y&1vn)qPFL&Gf0RQq=ij|HD*zPH( zij$?dk1Cv(7P2`9p9huq#i5e-oWi_ahtiuP4@aEyI&0z?=x8NX^buWp?ZfkZDR;KX z-NT&4ddI8Rot9(@Ok8(m?5W^GBl5ywk;c~0VLdxF0;XZ)jGcz9b1Kg5iuM$sAfsXk zVz**6oa5dLxK{Vr3veYxk0IaolTx{HvY}pCSOZTzZsXY5Q+iV4i zk4C|F$*~8%7(qA>Wa*v%n0kP}@rwN;A*I(Vldnjb&pE1HWpqcft8Za$TSK$lpS<}D z-AjZc+-@p->{YTkGq_8A<*DOcb+3Ii^-4yO6;838RiDU%P?`x?+YI^k_{bGV@)O&> zDrL6?`^U(dJoIS9g4uN4@2+?o-N8)4WtLFQ@Ey^i=Sqt*10hW2+I`V0QUf*7=@Jhb z6K>RTS~Y(eONi%PrZ%F%xP9@YI&wnUUg90GbnS?l&#b$#S$6?kFLY_m{z<3-&CKSQ zq2z-EYO^c*qt|F=X4IdIZ+Y*>`Mmn(X#VT!(&k3>;oq;Dog475geBlfZRCmX_Sb_T zW*CAS-+Geb$7fn^0^h?}Na5~7eIx+JMX65fc<#@`#b5{StE^uw-6wp& zHFoJ$3w|+?TwJ)2DX92qT@FKle@G^v9*& zsSwM@J;Ijb!U>((9QE|##3jAu*sW5)EPBoubD-v*52QvT)GIcJbW4v~0*&LJ2CyT3 zo5m>u+l2}Wo-2Gq=!6uA*&lFxi_+hsD=bg_sw&3H zF>!NHtw~%s*ovH)0+fCr|4x1E?~cQ>gpPL==VQ?M>*-`&@zlGKjCF*E&;e-UJsEUm zL4uGRbB^GUo=;!TPA)Po`{|FH-X-pZ9@l;i=%9dRozuDc zFlLWaZN*#SK9zjl!e$07^1lc@AKdh6r$kohODBPYww{{323=t9g3kJp?`xY=eeL)^ zTL_@Y-iE0v34$J67GTW_V`Xh(sk?H1N7;yIep6?2jaC2a%tp#RSI*WSIWPpCrmCX? zH26XlTnh0V-jk1iE}vH7wxg^R!7ShRleAgb>|}`ZNPJ2x$LWBx!BOY0?=}puP11L;yI{7PwM&exPgh9*iQ;<)5*B09XQzL-Q z(-x%y0ImNrml!t_d^bJ^YRqBpL6($#2FU^7?t|)V?VLgvD>Yps+V5TTFoRp(?dI6$ zq})pluWNt9zZ{coyw3mRc?3~s61zl*O{h~iFXw?yB(ug14kE7Ham z4s<5vSW1ogJzmpV8|%Qbk@K~+athycEp5MVdwD)e7S>?cO0j1jMfTx$hgb9rSAt(y z3a94hm%Wh*zgVfJ!Oa=*Hl4-Jubg}Fc&baM<$gT}V2?ESs#yk!gH^%zWyv&Ij(sSz zy{r{>Bh5ZU2*WxH;L02+2q5bkU>{xFlRQ!~bbYv{KA+1YUa$4Xasq z=ib#?*`bp;=l<|>1}X$D@TMF|T5}~-T*87U+E!ZXT^U+#fG*K%|H`4Ddb9=D9hU*W zaI_yqSO;~8bKlk6fy%lh?Pm7t_)ogE-x(R8(E~>Gm_ja-APXCQWLp(yxBX&owg>h{ z6jM#)AH8#RXZ(*@zW;}1)JB`k185hx?JM*}@cchT2(wZ)(1yqr$SG)D2FQr5?%NM4y?B43TJA*<7(TO%Y>eWUu=VlDiT{RPEj!@L4x+(UU<~9+yeObpOGYQ} zMHq3b6;QOWZ`iUY>iu)4@mBP>F^+D11+;;l?%t|@A;n&vOh=Euzahiw<_=_7NkiTD zWDothjM86GR3R6^bGZ->o+FqIlXk#-&5H+oNM6hnpeW0jzo97mX#JlX*l+C;{|)7m z1P*bg3ml^7Or`fEmlWhbRxsM*23=!wq$4Z`Zm1?WTA-%j*qEWhmEeNCnFebOA?SZT z7GZv5rNB@E4n|_u`%1Bep535;GKXWIeiR44KlC~YNXdSs&|Ti(0=@x-Ezkjf!U`z( zMG(Q3XESH`uM~>`E8!7olH{h3oU-pOw=Z!d1+SF5-?qM0wA%_$`yJg)Cg0hrI{cc& zeb458ef=lh(hJdbsb+Y;;qz)yrcWrO!yDiSop0XF$MHwpDZg2Izpl$z9d1XoV8*9n zztFACfe9%jg+Km@!^-hA$I+TEWmr%30Ah8G;+WKdS!63|weozr3bXR^1W_Y22Wn#E zt}xm$OGozGU?(0V*XLLl<5VkciV6^+hVO0xN1C^YSxbU7(7Fh76ek0oVj}fuC^2M) zzY$6}e{L%?u~=gnvPM0C#BhU^rs|=*2DX{t$f?Ug_io@jF>5ogj+bWYnn|auUjkEA$wienbq>r z(s^sYQV5g*MZk<^cm0AxaO(V0!O)L7neGm2mqK_$NwFu+Vyh&S*ZMw9+*{d!d{?eA zYW1kolN}%Mo;6m5c_nnkivj(D5jR7 zR)WYQvT%~?+_T!L?;$(23Y%Tsu?HoHP`yt`GrT`GI+&|v55W?kjzSlaSq_N|JE0bd zdm$UU);8a=Q#8Gbhu}*%C^1$Eb8`S6amEjJ6GK+sHZ*HvS0J!YZ<2JQ2)RB6)=BS6 zx;~{UF&6FzO9Og+^lVK{17%M8t{BEjt1|=8da}s~L5Jt3a)=K{(=gZQ}zWJ-fS~90ev_ zc-dn{29ytsSVzBtg{LTbna&}{J_-^Mi;7k-xW_q|KUQVM&HIJ5l2L}_hZN>93Onyk zyBW9(-V(=6if0&kvO|@@fjf!az&dVny2*VWiy2wNtyVUfcAwE2#awgA>N!FuYSnrJ z-78tS)Qpfjo~q8c?CLOB7NXK!-|v-k@DA7nb!V3kStiZ*79w?N4F9FXZ{GTxedW{fmuN-pu>QzMrs?K1{>tC?PyyACjb4Awdc+ z>v4n~)XRCh+o}ke7kreQmiv{$625<+AOHS*ZpR$^podg%ns#u2P+n7z*O{|b+QqPk zsYh~dy;Rcg8(LXLbBxc}o{)-JdJ`+0r(uI!w}Y0(km?3#Dm#UiIF3hEn2_Q5TU`Sy+{hadx$hH}B)C zjrTaHcKIGC>Tk>cBhxDD!QXYcYNb~FrCLFK3_=W5NA}1fm#jcMPy*2ZuPy4|cG`@a zjL;ET*(!DLg;PTtn0;HIJe%G(h)VQdCg{I3(hGj2fb;s#Naumv#!Ta`^Jn^WKRGcU5Coyz2CvL@dJJMuHwDr_NXpUvhj^kSPG$a>U` z4bwQ&lGty3?7Ul4MY^}_K|rGPFEaY~-O|P@RlpAz#}X7=AFX73rML|m{!8T{UpxQr zf;vBL#UB|!QPhKsocI`i7>U^n(ntB&a)vT>g*H(?-0H(oLrV?ktfgU4Sg*hM#7pV+ zKHjp8GXDlE@0+R9zqm2K?RANJ{NH_#C^m!`SjiXg zpot)Jh`_u-jQDQd1gq5BS3uZf;@JDxB^)=AzS>0LI1w2IdpEJ&h?Df~hW9<%RzI?uHqaX*0sPB3Hf0dfOD~CYCclYpQ)53}ILA zx1P+0TSj$rAlMe}s$p#A`z_f>mAyM7UgH2-cMa!}+on0>%Jt)u20JLc(ToOh|W^!slv~}W_v^DKhh1PLr%ERo~{{7 z6q8)2UONEVA`_?3RMRIChlL(bPw@6>&S= z=q8bLQsL4H6V>bN!D6~UaO0{-^VxOH%SY6>gkzCt#GmCFl4T*OJU5v6joSC?B-%8e zvaZ;pfHM}z`@UOaX>6F0e)n&iEYnRwAV%J=Sh<>R6RwMjTR2J8CNjT9-T&QTKKn2X zOD?+2=a2aWDTRR^^#Mj;ru==n{z5nW_4hy-{rgY+Me$0-3Nx?<3H zc4|Xe`L|A2B`uro!C%FV4I%AK9}nI1_q{FX*@B7(HAs3Nw9g_7P3v=mCf~#6Hzm*@9~af!j53 z1A#ux0WA^u8kG)c8ix}ovbk$qcNUfGe0BvGqiv6Kv1n1@IvYkG}iCCn#9vI38 zEjdZmY7Us7zdJM=egvBaanqk0y60FwLOGmVB@WqQ>{krkJtfu8^{UV?#_ojy-Lgf) z80hjX`O`@MXCwAo{puf~Q~Lbt@*IpKrJLt!?QGWF=&Pl+i<3;tN&5tNdvv2?*AL;e+v?NOiGJI|Z?93NtG)`!X<8kE$% zLlwe>cxO+YquQ1!M0d-`)GA8puu`wMnL~#mOV4VNy=ut0ax#LES?;Bx%b3hqy9`Itlq!2x>6HSBPHTz}JRPBy zi;4bUDRNQ1Wj90!@4UHJ`UF0e|pn_YiMS9s+Bxaq9O$)a+ z4iw69bm4f5JauC+<1N#^ebmy3D&1!89dn=dmmw3R-7yL?X4Oo4y4EkpJ_{NmmOKTU zlS1QFa&~tem=Axb^`Mz0MBr0b&Fci7FPcsV?fLh4>K;5ZIb}KP;cgsT0u_rpq8~W0 zb*O(<#%z$8sp{MJpku~uLrHQ?{sYh3w=1COc9Q^M4v`mF&jSAP19ZLiKh{SK_08M6 zJAbdO7?78^D^F^)TL`sW(%S+={<+M|-!b(D8bRhtdYJniS^fv=LO-TMxZ43Zx>L+CCOUX zg$)Pyf+FDHzVKTmlAmzMe<<@9*0CYES*@j>u76!&Cw2Ifsc))A9q4hpE|7c%qZdT} zQ`u9k*qtCc?T#mWr7+XKu2)+Z{$yt9r}^{+RBIhOp)=hAz=rLgDaPEnIwf_O$hAsK zinl8Ka?)-*4cZN)a(OqQkn${lvSfZo3-;&Gcw{dA#;sc@I4!IBY<0ih+(MhWgbOZ? z(vWeh3UZd~w~gk8Xd0qKF_`ngsTZ6Pt(2XZ7P(!m%b|gCEP42_QTTq#?%ksY5BF(C z#XFsNpd}_u(Oufjm*x_C{Ef=wy!;opJCx950r%T9w~QR3A}AY85Yv1`Av%f8B|i>`bokQ!fi>Y9wWh zo16GGFFvZOM<3I^_F&9{?-7enC!xm1I6`%rTXR*|DD? z{C6L2&!9Pd?Uc(ye$odD6(csRc7XfY>|XrLXl5&G|4s{$CYH9KD5+Z|kA=!IUt+}F zn4_Cc`3HQ+NqR5cgR==KQ-Jz&;ytj<&8+kL>ljqrT%1bRr7s2CUbB*`Vz-Ff*4D46 z`DAZ1?wHnf8l#;ah-x%_*;(vo53Nc4c0Pxhz1P}RFpJrlO&>95%F7}W6Lk8&tdClV zUxYd{H$I^1L-=CF?`O#{Bb-i0M_3Y?vrZOuI`YO>czmopmwk>R9`Q(H|&T*)$rY_Zg9OF#cchYfr;3D{?6<>ac+a< zmJzrzQL{atN$#Y(Twzh?cu_dQi$U?iPUxUgJyn(pc75a!I>r5bmy9x3cG!v5;*)Ls zCLiATu6a|G3&*jSOQ0T&x@U~O92p&R9vJFSSgN2L-4Q*4%Rjhg)>!P`hmL@|Ob_j< zI5IUvH`a-{6z-z8Ucbn4C!I?`nDyMtvH8UI@~A(|`75HSANbH*iWD^t^==YARaCN2 zqwVy#&@*5t8MB2;YB0GeA=g-@qnq)n8jtt9wGL=yQ5K{QLnX@GO^`b33^wh4R-xZ* zX2n`{vV{Y>KiBMwxjoqT z6RuLfFzT&!qx8BWTrXRh!Ivzyf;)HR(#h?l9RMo*krHnL7GS5!-%q-imG2Gq+)>5wJrz2Go^IW~r zK>l>G7;(8WU+(Hb3(+qg8pKcM(rd}~kVW);Jv>Q_s zI)M~#9;dr_#gJF7r9Gs&!N-O8s);DLvIcO6tSg`Pssgd9VxO#)?2jU){w+r5V~B_n zQrwcN;^au<$6m`fy3Y&kBUqljuJ0>uSf;I9_3Q-f8uRUdwLNP%w6L#bV7enl7?HNyu4RVwmqQ~DE8d0ik*^BlyiR~VAN#ZMSV0EA8 zKhIZrs!{H`2dqGD7}t#2=RKmTqFbaG&MOg=`D*nEw*>QYt-IU$MGH)cz{4^@4}163 z6H({hNWFZpFs>yl+s{?rGPW-=Q->(gJ;&8;QgoJbr&zH3^^f#nM$LLrEIOYZmeg`p z@<)3YnZG1RNvgW`HKvXZXF{8=Nt)VL3miS_l9G@ajRbhdY3Ja2C4<*oN&a}|YBpN4@xfdZ}=!w`3Uow%sR%PPX-=i)kB?E(Y zNfFA=|1Mfe52}>Q;v{cSiB|+gY-RRuvC`%g(z+?}HBU=a;$latW8#m1_)8=EAK^4V zJCtn+#Jpox)cP=5fcoZ2Y?IpqG+$&7-IF!^R^lic$C69)et~~erGEe@e~N+u)TEAx zM4*=mL-kt#BW>=HENXqk!0o+vO=`!e*7GLs{|8tv^+V`54_KKP@L*|0paA@I$U7@L zJ=r_^L(RuH2%p3nBJ=NO15$*5G1B?Oxx@(r)TSq#8f=7QUmpR6*`wnt#p#ardwh6a z@&z}#Rj^xB5beZ7I*ZE#F#xgCB}c-wi~Y9@;lcPORlbzkp(N0~WKLsY=}1H-o!fKB^=iU_J#gesJdd6<5fC}!Sdre%PdfDwK<@my}PK*GGkpsQj`PFht#*G{&d zDe22b^+PicVQ%HF|LVfFS+%|BPOK}&%ILx<-4VRm+k~TC%TNX*T8A@lIfC5V8+&Aq z?}K5I<}UC^Io$lLad)ows%DOGFpHp7`LnQ@{Kd(d&_3E~&;1r6p>#XCm*Zx8nLh~T zjP0A*ND<0bIJLuu@R-r`y1)hz_;!6c0OwYk#^o5GkYZ{L^&5#Qe%#(@9@{_c_kaxH z)c)Xot($^1J5_B#9LXcOdwHUKUAy_cNX;0Wi}Fr&W*;a_Xv4KUGM$lBkpdP^-zTJB zQbQ69q6P!iYnB8ap!ZkseWrI#A1z-h(4yYU_+riYq)T7UlE2#{omIA(8Jm61x6Wdt zp0Y=b%$)p^mGZ?q3OaaKCTEp1cJNubLGFnZ-!e?Ocv(VLA~XS~#S5{#lV?#k^1vZr z&C&BV?>FzvJURj%*yfagPI(1DPs^erR{;99Jj!4#1@*l=xwUFx$Y!4)08nqlRlpKZ zS_`1S01N{H*hB^qp!mjkCu4eOfvy0&rwrD4K=H4p&qw?~8E1UMwHs~^kWx)hlO9F` z&=raK2!@=|ap~6?Td$9-pLYR$yLRAnj^K0P$2rO;NiKC@b^(`=Z)&7q9USosK#t9W z(?(5!^U1C#raK5sD+c|*5kTnW0QkoDX?@cbsMVv<^y&_ELW=E;_2R^XXjarKUSlE~XOxCNWj_T+RH-&Zan{ z&b`Ko`&qYtrBLZQrV@zoUworu%o>J~kT^Q7RqKrTf&=nq!f(+lyYj%;?{GW#rf!dw zNHa$;rKza|#Xe=urET{hKUEYk+5kp__@SqGK>q=j{yIkS{GSr|-h|rTQ)`FmbNC*WC_36?ZT+j)8Xy{?A#c2LQMIA#&yu&@d_3 zLIp&GVU5tq{1GPs3l);r)NM>&{V%XL|EX%}Uoj9ds6?VZ76%Nqlyuv&K7|Y}uxYtu zjXz>Yen+`I=+Q1tV>1kOsvB(cnEakGoBinP*I8FIUh<}=NDC1?4EUUHo-Dn@Kec95 zR|{t&)ArY<&A7r>LLORfQ!19}38R+Jt!-SVZfP%1seh=8DbU1cxJKLpB@2?c95ky2a5p+I`Km$bj-v`p=SoEQyW)F6^KbLO^3Z6lp8Mqu1gc4t;=R4Xs9)VPV01rfw%s*dJ_*Z`nE4cpK_M zYjR)x-4)6fh@m<1;07a4)I$KCXKEeR0HhPvM^&>3+9{85J0Xla=1<7wxTPKKe9~^my|x7@ zw}tOsBrVu4g#PC=TqGT`-UiT#vq4~(0?ep*XC%DUHy?H2ML&gED5TSUhxo#+B=yPO!KS;<*L4x%t5yZPP6>)`*+NPwluQCmj>`bI%n$Sv`4JQ}+!bG*$ zK8t8Kf6y?8|F?k7-mI?n3Xq^Vki zc3^{eYa3NR|NLM%p0tZ8_YJBoTn?UnpmyB&e>J)G%=G1HaA?IK zOF+(gZd;l57&S4n#TUI{6Moxq$(Y5Xy!=D7-K2X=GYlfwkZGxIRl6efLQ5@Zln@P6 zK6@r;(qDkllbazIL&;v26If~Z~!Ufs%4@ddhG32iV9?^K^JnI z4Z5BWs~*A)BZr{v#X~sl-Dz{i^U@JA@FSw?4qs?XXhKzf zoE!PLkY*U5^V01(W-#_Eg_nB+c76Dod4gOyF111Ib?;|!u6MnP`}^ONWE>3|K2$Xv za#+TTx4+`pbjvHev_^XhMfo_xx#F=;VKgCbcMy4*Syr^uZ+yms)0Y_&AUr94$x<4T z5v54s6y_S|OjDsKZhGEEF_i{B^ao;I9Nv3mVkKw9LxXYMF{Uu`;M_UR{!>De6eEwf zDXpKY7$?OhIae_1)|0r3iX92myu?mSjJrMFDgz(#UInwHub8lGqOV9Ba_i1TzA5_B zC5mvw6w=4&NHI409@w!BO3AEbaIFM|1Br-niPMy1fwOS7YTXM4aCLxnOCvNwuj8#kIf2FW~+XIk>Czu9uJ^ZQ9;$Q*oQQKR^EA%kI@ib);E8rUdKxkZSUfYxzf&B|#-_NMP)o zjClGJv}sx{geUrZsa+z$G8}5N`l<Do+oE}3H)fJ`hv3qOekJFO+=9kXN-12cB`I?5)^Rh9%}rue#H8gZfDJewb5wR;LxBNJVh=P zpFDbz$TNZnw$@)}XcgTq>0$rmxw7rab)T}y7KnV{Yp<}k_9^E~cmBlce zJ`k4vFm14R(pG|bwN9S5bM$sK^N|Y|d?Ozpf4W#aY|E}11>139EMxIf<^tVmPNAdywtPCuCHq|^4? z-sh^aTvcg@R`FO--k0V}&^30jNME~xgspF8=3#EJ*G%1kGP?2$UO}O7p;F5O6>mp7 za>3^MiC3`z8?yeD$5JR5H$KPwO`Lgis#o6}+tSiM!jsbf)s$}F=L;QM2apr0n|y9# zV=wqCXa3-?U9^;2t=h(A0I=~x8u?$`{O^JZso(kX8k9|2@mRiL)EeU#=ztAr4@Ui8 z5XM%7Aj|bX6F=zwKMki+O&F@z9(ID~8Rgtav4CL15G6YUO?E!D;RI`<*2YNJCObVY zGvVqJzD$)j+>r%hT_;cQ_E`{OEujZRE_xnwc6WBGEhun8_d1D1C*IDf$d71T*PxXQ z*iDus(y`Vmy+2`HIKYd0>Wo2w zI^TO|q2VO_^uY^SrhKod24}cbk>`9BeXTC_a-P$SOp`x*1hnYV>vctN-A`8KP!)T6 zdqMK_ht_1y1ZB<;@UR>yKixns^RwLAH`PI#-HdL5~^A1ku8*kptAY!6;}Af z)agg=#(1R7-sgg~yah`-oxJX2FQ}+-sORYa%IW;OmroUm%SI=Q`^-l=vfjLuzRj3S zSjVUy_7{QhKu1liR+WeaFj$tooM@}a32gZyfBfxT8sp~@hv1ge&#+|g32vf_RRw@8 zqj;v5k^tiM$1Rx~bnRaaiMNgj#ed~{0QXTWPv*bRw1NAGS?6-Ewv;;_Ft1=O40gz3)q3}Ad@d2EJ=0V@QwuNB9_a3`Mz5@XyzSiGh(YP0VwtE$*53!sGl zq2nC|`|~K&ar{G^z=&V*M7fb-CXM!a=ngrnXwiR8F5T}D5ENu zk8{!wS-@SBj2)W2=yZ7Mjf5-(`Je=wa9>3d)~_m&%2}xL_EM7&PwLU#{QUGLu?nAc z{iu3vWiuQ|z&ra)7x2#R*IhhisfjXft7%e=_h(#@WrkipDZ7iNLF%>$~ zgpd?6X9d#4Md0m}14*GuPcu&MZj7S|0@7oFWT zBek0)IA@zv%X~y=#EhX6sAy7uH+H`xceK^vSmxAGi|)6~L(pOX#H3REgv4ywnYV2{ z>n&tLia*)W|NZZYHai!>Tc1|DdS$2HRt~~)!PMprA)8OyK7nszg%2jZ2bOXiGXCnF zZdZ3!B#!RguA@(tp^tnCTf3b&z`kt1Ik06+lOe6DJ~>iWg@orDF@V%ItNq^zum5&D z3&UE2AP|2C_(?&g?U!Sq<;z?Be~KYHr=$*8U`_zCGC3NmqrN_LX{%Y^zx_=3_fNX; z8=zyoffzw<1pz?E<158an2~6G|2N)hAp-DLLv!8%-l`lZy#2&c{_liR^PeB4e;m5U z-C-zmRi|fGomL?NkK!)J+5QHv%T}rZ#-xR=Tmw=)rfZZnJT26H|F2KG2FC>+dw$FN zMR{+kmc_X=kUHi;+sl4)4y7FGYX#fS9tID-xlmWss4CStINfJXi-Na-E|LrC>*;+! z;}Oz&CG_T%r?-{81lDeXbFWYe3D{ad!8u zu0l((YN_B#<`^lv`JaOleX9dksMbKb{t-)Y3e$trb@~6VgOd6cbWQ~Z6&(D*iINYv zI9j8E%J9v*Um-VlD=_|n*JIMYfZyrMp%cfWhh025-Gy<7y|Q9ON;t@z=2IEMUOpJX^f;49<0V>B|;(Io4T~O7GiE?y~bT zyx?XCWKfS2AssRt;tz-_IPuEv3c*|{!`i4JEP^Z$oPxvvk_OaohczkLbfKX^PZUszQVCZH4+tN)Ni*Kj;q)rBp-Cs#P< zCXB1RDgdbd_m)-}9S0&>@V$M`5t^NIbNZ@z>LxF$ttfMh6-q>OtcD+%9X<1;bZ^Rb z`$VUxlb?cRq{=J*i_VX^#zFNk$fw)=S~**u)V7x+>BE#UiX>W=D^H z14oOiZY<|%SYPVrizw}%Egj9NvbLs3TOPE(RnSezFlV8e#+(OB&p$Eo1hmBCRq7Ti z$0WKqq^Sq5N;%%~G$zWng}eU~>+@ipVNQ2rZ}>To=KRPEw0(Z#>G%&_U&q}KCz-20 zZ`8k1z&$Gr?3~M&P#3ZlDxnLzqRACz0K$6lr&h;*s|SS*^a`1{fQS@v*+|L8FPg)1>Hf-g&(_5=Gb8?T^+=3=26)E~EGjbbmWtsB&DQB%E znL(ybYmo*R?~$T}(<_jwvCle~E)DJIkOtc)iDT-jt23)vu%t%kSAuVWS4Ov;VIkbU zghLk0T%$!4q^_3JW!IT9A3~`2kD5p^>S zoSW8L(MweRvow>Y!4XHb;)L@Eu+;AOOgIu46>J2{tw%l5pl-~36Yq#SUn~n_B}7ey=6!i%7YfLU%_~TRyGZ5|yNv#^YP$ZK^l%X5@2HV9G#9+qGM( zEHB}llxjOuPs-`Of(qaI<4qan*$Z!uKYRP0%uF?5=L5$qXEv#x`!dcOeOVsAov>TW zl&yc^-n_;8S(_tyN{Su_@jp1-T|-^?c0oOKfW!+t)HiE4weGCjoX7(1`j2NV(7ykx z-XLn=7ztd{wNbz|-Jce5e>@#>?C0jSrO>qvzR#)^a!D7+=m~9EXeYNWnJm|nR9kLa zF&p%3YiN+DiIii(_A}RActZU8Hu}Ccc})I?=E&@h7D+%PDP{--7%Mw64ZtDLKc%$1 z>@8#^Lj`jAO|m2ApN+wZecs8)DjQpLZ+E|El^^@1TO_mga)x3iB?RpFPjVDD0QW{FbrpC*i+uY9GJb zk6SkX7J^$GZCT&@5jph)(Z<_M1gRB7?liz#2kOWVeecOK{E{!w#dw}UHW=VR#TgDN z3W{}bRqYRGCg^y!yC}Gx6rjm_3zhcbo3O`A*HEJO9C`*&!lt-@>Y@_)>?(Y@0W_H( zkY|=-*r1QWK$ZDi$*Di@xGm|v(os=bas7sFC(o+zR1~$)itmcdG##~d?YziY=YRQ&M zgQR$w-5JieMyUf7UzD1(AJ>-|^IMF&$mF4Q#xodx7O$-uy>6phe!4>M^MN?~1ORBT zEFCQHk{mC-SoQuSWzU{Nv>kFqsK?1wOSgbt`SZtd<<{sH4bQ+p92KL?!PgAAfe&tG zbwtj8?hc*k!G(sp?ihOLPR}2cUS(Ex8buRi{H8u;=AN9vVqe9sPGHqqJ;GCQ7Qi&I*Sj=Lu-mvyt6fq7G&nlUAB=k@vX2@U~A-QQeh#ene_kKvwtZ_n` z@0w(W^F-EDua_#$zd#jR-BwI|R-N|B=1}*{{@ac{PwffXpN2vl;QO7a>Iw%>v>y}f zSF)c};4!EpKcYRz(GjJ9cc{YZR^LoveZiHO0#BaAI-YWgKXP}Q{e&aSWBs|* zXKMreley#uyQfcE7(d%@Gl7~I_KdCinp%4%IlU-zB>SyjH5V5Dz=e>Z?6(CFzZ^n+ zMKSo@(^1`vYkK56oPiM7T=4aYa6=k9!Vw&&i>ra!W8W$qR7kVJ9C7QlR?SHawB zh*{E{XsYSKJEej-8-5wWQOiFSxX-!_=_Cjw{p> zeHqMkMo(WoB)y!EQ0sCRP1e2m=F<7h#&k|*n$s>++zEDd?d6MOCmzsZ6LsIH`De{~ zo&~VUXt_VN`yc za@)|k;-xSRvWtDd49EzNH?EJEOk^ihBn-~0MfM(X&&oee#oh~VF6$jRBsub;A}?_$ zpS;+g{-i8Tf~eto^HAW+QY()Si$EOf+$RKt-Z}bg(GE`Yn|Bp@lfoX6FH9%qEfxNM z?0sigli9j<5Rp!72uO(v2uc$K=|ogS1dJ7ws-PesO+Z48Ql(2*KmsDtiAb*z=^{-^ zr1vh8Py;0KTg;3Oxc8npbIx`Coa_7HmAn{uS5|wU^*ncpDGZc^GWEVCPpxTpPTW)B zV{wmj?hNtKut)SDs^y&4-1}EL>o&d?0=>fi6-p#qEyO{rhx$X$1GwAW940quy}hVI zCvVATqGxbDDPj~&yX0F0F-%rMtBfd%#1x70&Kqo^D35lnwV~G=nkW#0pa5eEOA%fQo{zO|It}@KSCC z#vYGs(0rqoVwkf01$9z91r)0txZ{{5Ulvh!#N-=NyV?^oviguQC}yJ0y32jZE>xZZ zQ+<=*+J*W`2G61ywI$Xz(kWtPHgd3*LW6H=e3u;IsQG9T($~jm_93SZpFT9YQB8P{ zIiq0ci%juR8<-YtDl}a+#49)v&4O;yjmBz?N5ss81%z$PS;r_fm&WPps>xCy7uHwE z7U6`zM%$s;V&&%Xr)VR6=7(rb1#5vP3S3g^ch#&+ih>2hH0F#R>dI6OD=RwVqKRxH zc}j4+RdI^JnSzp+3I19^YJ!9ptd7}E>@qv|y4W?-W+{vW+FHhvI2 zuoi!*iQF%5QM~-Sl)Lx^I{L3dlL(*h5bi?+(;zVg?sV00jdneeh^M<6t{(YB?Wp~w z_#1u8B^vvY+=G&)44}{YqthxMH;BumOIfrAjcOVgsAGrRxK;LhJ(fUbk2V}AhBBe* zRyMJ0StRdGSs`NzJb!SVY(2JVQuc=f^D}|~$y*Y(z7Ah!-uz5B;@z6Dp~&x zA9P$OWO*;D1Qn@%MGuu4PHrcqQpe)8K^h$Pt&rklk(RByRekGsMDe!5q0V;l#$%}; zAm&ZB^pAf)$?#7q#%t7=Vmhi9O5Ty(O8Xu6f4YwNLw1b1nFRmUDO4W+tJc;}z_!w_ zI^{S1g;w!1#$wS1AgKi*05L2$Ue1D>j z*4Ow#gy7;mlDnP03MKJQ@~y^HbH_bBy{DxW!rq>{EFXr%1kR)G3pMJplLYW1mHEt< z%=j4P58FmPJnetDG0stjx)sMFFHvgM)ajcn7h)hWxMfzFxKC^Rn|&ZiFQOUl8Y(8jK&D4Kt3%eVCHI!o)fFB=A~v;=FKmeXT&=`}+p}oR_zJm$RO4p(YRVbz zYfoItV)-mDbfMy_t!P_#2l}c@cmsyFdo%~E>FKL;D}xy0X~lwTAOL0f5;L1Y;W7-{ zK;V}Ygobfn=ZCS3a^?eq_i$3K^Nk*sqYoyJ2Q|W0VrL+jSEnJjZD>STmt>eqf)DaE zdp*Er4PF~rbe(4n$(YNz_7LT*;qAH?GDXJs|ft3s%MGLa#gUhwhF^jx~ z;2#QmJBp#ZLG)<>qCt!=NOcH-__G?maIh6VI}Czh1|ZG!gJKXzK#>A;e8~@i*|)0R zWB?#b=m4RHP>RuGsimztQ0dNF1kL{x*Wb>j0DFrVMY+&$c@<9a-ok`ZcgtHU*{Zvo&zSY9gStEureS5HQ%|K^*k&<&V-S|M9;PnB{&@Aht3o zHu)QLsEkgajmc|E{&UZSY*Q^?#@yQM&>wiqDQbg%lg-?I(m58^$R(cFE)y(C^Wyms zL>4GTFhJa#20vid_5eAI+46Zxq`JN#iO(GbAPLDBGH?MQWYZ427bfWvMY}lW=SD*X z%9*l`@1mW%y1cWMHlyut@S;~bHkmGt;7D^9jJc*3IO%RFv8T&LVG{~o+V27-B|qX~ z_(saWV+G?C0}#wuijh$*pb@Hu=e0kfa+NZYyItj?I2-G4htDl?zq^obb5+Ua0wxi) ztzHcJpnm%;wpH;TO94I>g5&`P1oL83=4&%qwWU4Ts+Gtl#d+Rc+ zfH=1Ur^iVN`_X~_x#sjg1G`I7iKZn@%k^j6oIQ1UP`B{rWYAs1h9|6J&MjXFc#L-A z(Wo9*U*g<7d%#b+`3^OjzezoQwndD7TU`^!BP5z=NbKk+zpALP%OUUsIM!-KC~_J; zA(#fEOv|m`PqbHKnXR&y83S%PuDM)qIWTaeBU}36iSvW+7cb!_=S0$w&LiK%Sdqx` zrcbfs_s34WD`JP2I>6U=*SLX$`nJOCsi0KkRm!Zs=au4jMc&yBy+XA!?pM0~ScjAj zzT_w#I=&(QEz5!U)h%qa1ety%IqpuUDR;6X2Op{HJLJjxX>>kYl|_GsYjlYV&zqZP zQ%S}qeC{xBmx+)}v-m^pBFjMshpoP5N4q{0I1dY^JA_x}p!IBWg8I92y260xsiWn7 zITN&_!UNu-Lh;(RdR1sXgQsT)PyS*xoJr{zQR!$%B}E#L{Jy+X-u0^0M<7fC@${W- zvrh7;oq}YRK}-e@esBU?pASHv>y96^?2pEGa*yD`4KvnYE^UVNu&1azZJ$fMi@Q<- z+5;>CL$Y#rk=C+%D-H9RU$5?+9^8}Z6E5jQCYM}jDB68WQq?7GwPBrMAewc!I6~;y zHS4uVEO)RRg~vEj`WmC4mO90)@dddDUbN0zHMqbDa-_3msC62LMu2Z<$m(0+I{Tgj0 zQ9+lkENi-(o>c+KJY`Uuzt{`HbpawMS8p9JecC-3Js)oXau;}t8Y~1^B$HoxJgv*X zej}us32*@zdpEfN-za%hQ+a<)RQl_?o@QTHB?Xm$;H3b&F1#sV{(81&H^HE6lhu%w z4Ok5%-~Wy%K2SL1azL$*T9bl)RrXKWyn#<9xIv{46HHdTbpt^4|8DDL4c*>b+vx&rxg9Y}&W->O<@~x%OWV43D`af5a4}>o$=%H}_sHu-&E^OOb zF9LG{@S}lYaDxJ}_bxzbSZ~RKIP?E?GyjG-Lj@)I&ckPWdI8frK7aOQJLao({_Lo+ z9rt@+Gd!T1F)F5X(!=7_&xmyU^H4B>e3 z@drML3-t2X6JE!+syN#-`!5MWDeN!{h+-3hzC$QGVfZAh>SOw|zdl#2>Q8JCg0moKgex?Fw+TJ5*eJDEH$n+-hh@#Hjze!1muJzyo0m+DZq1 zC)Bs);M%WKfydnkZjgsm<^&Nt9eh}obhrzArIyP{X4@NvvdBj`Fd>;)nP>1t^v;;vHf>M%{ z9gqf^eU|P2Sa#e0ZRWD&a1;;e?~u(NEtwId2J>+1kBVg!_#e2MIlxxDS7>6m$O)74 zWTCE%XJGw?j?7caz^2?fMTOedQ~#!Cc@`@#H(K!7nzPkBzfEmpOG@XKuH4hx^K=qh zuGJu;>lrexzs$-N5c|CChHKnQC?5gqWtoKUY!5m^9IB~|3#H_^Wo%Xxx3o5VTS9c^ zQ}f>r-`L-^6Uv7N_lm3yo{y#>M;0rlZ_ET}|IaCrS(HJhT_t3rK0AcN!IdzSF$EFqZ4eBc#S@@d-RWY=c#t6Y!KMDZ(`c|?g^Pbey6jM*Tzfq=N+w!NFNFbF+jZ~ z`on>TeT=caZ^CYKJw3;6Z~`j)X_w1w7aZ8yawom#=G{4N8I;d;h)CjgW?qd@_G7?9 zRxlhyZ~W1gs(2ap)Ykh5ArBptqzd_ze3uT>jLwtfcY!?AjhL; z_+Rw7ducc03!wWDum-;k$*-k4vlRN3ibE4fhBf)LCn#=<#_~0Nu*Xg48*6wflV1zp z37c^6|8||n?AcmT)G6Qj(GKGjM03`ukOOHOh&O>&rbtUP+4{Lg(#JdrxA|3K7CP9F ziKIRf!qq)f!G(})Dd=vVt56p@Nly=4-%}fQI)~~Bd|fYjxmtSzealIpJl|uo$?0aU z&$)#PBPTvn*18)VKH_)o?MgcSQEkZ$H|G2t*_IKde!~0QB-hC|xEtLRW^Tv1v%^Hm zB1!qzMe)NgULyw8CT{PhuT|xyvBz@f5zO%6h>n8`Gfvj`@&?{%=u5Ok3()r2HtCTw zZwq`_SE*yyeP2PUom|&`NkS0bgpYvD#{Hn2okWAN3r%KQg*5%4d(0mt@C zn3&WKUUK5d28WD1&l|b8rPzRsSToCeK9@r;8t=ToaWgbHA)%#1yNI|$Lb89az||DW zo`6H=PuS(y8L;B-)dmn2t98!ZRb~FHpPpjuci;x~R53}YqgwUL$O$PnE5iyoo?Q%v zw<;%=3|1!%SH8xgo;13O;v+F33J#+;#SFB(lZx1(zSa`s2);>$VGg+?1FIJFW9d3ktI&s`!q^exbFbtj~=-v^fp45wU zt`&+6@Gf!jW7j-p!=wy(U$`nzE+uM@W(0|8#@sFJ-nUEZuEaT_@ay1&5-vBY zIfK!SYa#O0lzoL8Yide>*mr~_o`_K+CocslR=^A_&eKSHsNXE>bL~BF{`ko$>B31} zQ+77pJ%JcDvuk5!>ni=O&F&mcN#2#?MRK7N-VRqK^hGDatlRXK*q4F;!Lo|C`XHeF z3GdT!dvkO{x${2MmqmsHk~$!u(T7&Q5Uz zYmy(Y0RLT&yXU!VD<2#zCMY%Is!@49zOQD)g}V}N)e?$6mhAW($y+4LiyqgP$sDNQ zJ?<4g9{2hc*~(@}#)95sI*${t2e@_G_+T}r$-_%J6P8|V<)sEgn|wCt92o21J+bJr>R za61OUNx0U(rUR08i=Y;o=ky_NvTR*iW|6VKaX>x;n!92#C8SI#O1CK> zVO|XzgGHzh^sdFO*9N?rlDpTOlRmTBs&lLM^M#4wVAk^p$X)blxz2h8k(B;3(WiQ{ zzUEw@7D{fociVaQ=}JsA${Y$CtaGp)(p0*}T3R#F{}|qNuR^6I=PF`Y7nN&T!x?nz zFwrqBrRk0(U({kWhl4h(v^Vczdf!7ND~?C-^&+i@do*R$dNyt1`ViB_hNPIRRf|0N z^}!SqwA~j4m0`|%9ZTZmL{~?>&`*;ar2tlB-GcI0S=l|pHeGQ{_?l7m4hXR0Dq{hUGjr!-;JPTkcQZ2cN? z4C(8EB-Spqk6Y)0lTzyTh$m!gTql&B>2{o`FI*WL8Jg3Ak5^sJ#!1GgMjy1SX|R{r z9_gf-f;iISEIG+-6+tS4A(9{6_NeXQtZ1N~B?eyuMD0D`5U8)8^II$OoM7nbHIx5MG=PV1)04*8*5PZ$5VdeHiRzj}~;anb;HxG0J}kK@fb zlN+{2An&<$LDrYE3;ZDjO{i|yzPbT{Nc&s*Rw-F*?Q28%?b2=o-? zApPwV=WHdegeEHy)fV8U`Y;m`DvxsxBw(} zQ`XKvY}7_dEe%?_wX*_*`)?WQtrbzy>@~ziLX%{3R zTPDx16Y0M^10x$|*H;R9YP~LbpE%<9esLcdNl~g#X?m+P40L;t3{LATKM{ zna2^q-lAk=TaUTCT?Xqk$3fInsmePKc0Ca%_n;cjQKkR@`)7~YZ!6Ml@3AY}x(_&? z|3z1aA>jsHI7BgHqQ>C#weJvcpJkx1`13wIj6gxf5=V9eWrIYZ5dUB9&tC_EN6i{z zfx^e=7q0B@K4lx|QDG%*&I4zIqozR@#Cgq2-FDr`D`{5hqr zY^y#cu$=vK=iIDp{IQAtW@5jkg=M;!_ATT2S1aHz$Md24 zK3smjrnZ&$Nz=dUp{D+os8ol?04NOQ1Au-W{>R1xN(Wmdk9J7C?{Ubf>84&iR`*dc1E87l zEhI08!F3XRoi2g-Z<+`)##>3I+3E&FP5f+i=4bOB=F*E*yDn6MipqTI$M6v6rHeub ziOvqjjaQFhr0i1YgR*?ld5O{-j+!Y%m_X&mDH6PMncJY}nDv9G&o^$CrQLYUR6P0V zya3vHv`lqF8j1OUd$e9xr908V|8%TQJXn^BE6}RiDy8TF)JSMnY&)H}ktqfYFGs@jD&0B?3}zp}?07KNtrS%bwDSE4$So>INQ{vEPk zdtWx3c>{HEk(C@nVosZp>-5g58Be*O!eAAS1wbyp2-6ea2w?m&AH>>kQa^)Wx{&XQlPgj;M*>I}%8hmtK8+ zfqn`q)vfE^)x#LN(gN+PmTifaP+OYWn;yb(Z^2+`&CulKcZiTXpONc{=G%k({89rF z0f^6Y@JgQP*PBp%gR+!M1&2hkX~W`NeScj$Yx9+ZunL};7)pi<%J%3p(a(E5rccoO zJ$l>c?dyIrSO_<#bj(${oHK!ND(++;ZE>)K-Pf^$%c-X9I_dZfwnf#ZQ5jS5M1g}v zPGTDlz4KqDQGfAG-`pfdSH-;c;68aHuQarXk%>^RMyny09nyTFr={9?JUG%8`KXe+|I=lKOqV0|?{n5o z9?M-k#=1qy^D)EmGU=%3eF*EF)4~Ord#`ITwdD%haIsvEynR~nsUouxGQ5^&FWw>3 zu|llttG&FFuDP0<9qYT3&buNb_RXE<=4O4FY`%V}Jf$k*zT;c-Q_|GpEtW@`_65Jt zK69P2EAKu8Baf#gxe{|JRgSRrqX(BRE@C05x|RU=S7}=yN_S_|Z2Aw_h%F_!;z4>;eLjGx;|}Yy z{;tBZ-9>p*Io5h}-2U)+Xr(p$32z$3uCW5mngM{a0x~fGEVV^;Yb#ocXIqsDP`ojjXrOEp8>M{c?wv}S6UwwX`vch z8H|*0JzB1vY> zDZ0IS^#;;0GFSM+U1%+vAp(3m9y%%D953u=8ZuUhX})=1_L#1v=5R;``$hVx&I(lz zN^3kvr-Jsdx`>c!KvS_r-XT+lU39|LOP5&?z);w1RLrYF0Pj7X6|Xt&z4*aT^0xP) z;bU&6%0}*#B1OAr@JpBnx3Ph$r_Q%OQ@wIMd+BLPp3fjv%#PQ4<>u-P8<#LInu%BA zC0?f&bmSE_Hh$U|*iT@o&^^zNd_5FfvOJ3tTU;oVw=iQ)@gJzADiE~K3sQZRG2|Xe zu&#wM&L}YxqzbF`8K&F5S)KJ@dv2=Icf!So;S$|fyHG=GRWRm%Sm$al4>C{dOnBQR z8t4i~E!IvKCelQt9WL?oIIrAtw47<<=*<#^&_@>6H%`_JQ3Rc=j&+7j1m=Z(i{Btm zeuTZdj*zPJ&*zZoU%4@wD16U&&33H|R{K&-frBurM!3E6RdU{EZ17 zak7``p$nCv`1A36n{I`}Wy5?+-^RZ~c5d`VD^`sXW3u$qIy{3*?k8YSq@4@s{lsN= zLFDX*Fx~g9k`Jq>SRAEa@HaaLC()p|a5!wMFi-LA1j6-swv*MsdjO$(6D4CV0>7t8X zsN+teP*#PbqBoyzENVJ;OCnfJkPNqi^_liZ=e}-Vx}-)?ZNajTOs&)0Ta@_l3!xd3|~pWNJ^ly;$AIZ>*B> z@3f7$qaz}7&B7pkrsZ4flFU4fPhymj`LOx1^I3P7x|*9B+*VO1>nVFL#-Qe)qvGPj zwMi{+tddaE{)PZT_{P06=E1GraY?S_=EjHLA!PGK=mBE<#_(XR_5PVhmoszKg+H=3 z^io0#PG6EeWMf;`Ran7lfm2c}m6! zoD50@Lb<)w<5EWH6N7h#b+G8;1-?!j`sbLG*y5J8%B~Qf`cZBa6YMTq6Sa7b+CAa# z{K_ro+Qgygma9q55@947o~@W;9<2k}K3;+Ekh_=&hZdOOtd3RUo2F3K-tMp0>d)n? z(eik|cp}it+8b|TQ)OfuF?hX;L~ zcje%B;S4%Nv@FPQ%}&!NInAplS(}~bK7>x4SH)i25R`-Y_C(w!HKt#@uyW#T$e6`v zkv#e(F^`tvm7RUMQ(?73-ytnpPXeGDQ!#6?EAGUkVL0`a-TfYNTVd?w4(T1&BQAWp zSlaCiu>x`6termITbM_TVMa%3JhoSt54n7NoZN-XgR>_vAsug!d z@*7h;C3BuX;DYo{#l4({($3%Bt9;@f&jHw@>Z|TfcUU zwv>d16rG!13pL;$1<-@NBt?9@?`Y8R6=X{BhD4DTq)zy4v~Y0Iy>Z)MXHY>{i69t6 zkCZchaE=NPcxmxyHPg3v;iWPBk}~E7`PI@WZRM%aC#nlS+wCfpiO57Us^LUoZZ!ihnLbNv)VJ*z zB(7L2s@4F(Cz8sDi2DrLdcwcgKBsbT1?yqg>^@N{{!8>TZqC~M@iaKzLL#$Sx-ZiM z$uvgeKqrPO>_+~gW2Knax@k3>b~e}i!k!pgCN>B+L4CiA57frBvdEF!srJ+eJFZrh z`HWY7gdVzqH`*iFtI;&1)@b>iYNg=&)fM+1S5d`3;js&(%w1b4<+TZmn8 z^ryaeOyO~^X-`_Y)^U!6#o9eY)2ispz5T*+K^;_wG{4xDugD87-;k}J-jwU^Hx>0J z7*`K)ij`Z6WE&ZiCkl&$PoS=;Z>TGia|y7j7|6pn$fOX(=$2>`TjG_oFs*pil{k19hy@T4{h^HiPTc^=afPQcLKVE_`O#?MA(Rcg!Wbd6hJ28E#C*?8axS0}5`*J|Pwlaa` zMXEkv;1{`ce>G_z%fBM5eKP9ht7m$a_K|}uyWh0x?S|Te@b!O%Cvv>JspCNCT6jTQ zZwHy?=v)SuQSUopEmle)bAeF-`D?l26H0!p$Q$TxX0f`Wsxne8 zx3XFlU0C9|`&1f(^TKH1o$F~s(MZ!<}Xk&iZ(WsAo`Ao z)nEcv3NI5M)#PEJDSPvr7z6QqaZidSiI<@^7Tsmk8>ficr}xYNJCV8GMW3G01N8{F zT?Sd#liX%Nqk!AvxB7718KsbC$X`myb_=rNsT{f*<)AyDdo?8kxU9xEuIDOzM6i-9 zh@(O}NQ1s99mbbg$8&-hRtJ@&WtVT(b!fju(ylzhJ9?FslJBFe%+-AcB`Cw3o(n4V zGK#%$Wdrhlaoy^-rvtd|LvCE&@mLN*3%MRS=r#P&+`V_7z*V<_#4MhfH%4sBSkkD=G;z)9%R=1HiK{18FjdE?_xCfAq(8~9W& zW_<&+>&zyJ0t5i{r)DV?;LOmPf%ypnWjWJJQFUoqUkm}|A@I3vZroe`94IotoMB`X zfctJXJB6aO1Yf6wtR0-4`bnPJSB+f|GdEnG7XYfswcrqS$O5{)3SLjM8ve)s+E118 zBj>E=>jm!w82FVvMU+TzOdEzSJ7dXSR&XHrwRA=0_tV@z-4!9)&(I+-25L~4vNEYX zrXiNbXc3+){G@=@QuX5o`(dB<7exvV*}b*vG?V@$&TOI~qJ#57ahI^wgVC}e_3~W3 z^KE|Dcb&N;(~VtQ40&#J?Qql4pU=4enhbGr1&PeWhrg8mZU5|5=OJC z4GCGJt716oj+P$Uh?d?-f4kicXLc)F6(L-=hbdR)GE@V<`c%@IbqWNb__-d#Y3J9+ zY%Ohh9;q{VK9rlJ-O1(u1#^%b4FZIfcl27P?OUZnq5Hwt6nFZXZ{@h1I$^V zR3#lAZ-V%5N#k#+I_ZLn)}I6vsd}QSL2Hzv?jvBeklMHBMR*F&(*Kw;K-rc(0W_e3 z@;92_whF$m2Zx#k`|8w@Ux^PuVLMSYG5Y}E z0$DX|S~*;qKg|>HH_V2=8TOFj`8DQQkaHP;Z~7!HVxY6@ySws70J(yRFuK~%((sE= z-hYA3wB5Y=SK>qBCh-CGlzryZd$vMgw1Dm>-=}W7|JM|$n~#-u-aS}djIZ;Vw_9mC z`q_0<>F%DwOE0yOw|DVo1@))G>nrW*=>T^ciecVZIXkbVlwPLVQm&Tp!j_bO{_Fb?9= zdRKhbi=Fv#|2m;A*)3!HcA*!H#@s(!doY(9xK*&90Mw4N^VcX@9lm(+lLlZ~zerhM zKA&D7Mt#x1uWmsJ#2h|xgz{(28N;l4Ogtt56Ws0#h2#2Gu^Tsu61w+$wLe?WWg`Y| zp1XElDY%6LzS`t=#Nj&x`b@ba>jHeK8(v#98^Ypxq}0UarS#d*!3_4SLay(S{6I<` z%EZLRF~$GtX4z#5*4xc|!~6>^Oh%J}INOp)PE}(6NRKL132uT+Qr7JBtwIUC!zQ_9 zD%q+JjB8as=5e$n+E9_0Rivv1J@!2_)coutv!=ffuq*E98KmKbhTN0jXd2uKvhJ9* zb1CYWG`hasw)sxhbmRQ#J0kbII3*{yrF=RJHz@FMq1ZBOV^6ul?3+0tB z)kdWd!>da222DTbtLrw{zI@gkc1qFab#Jc z@jottSh2x>IxkBSAH6v@p@ClV2`ZlR4T`{wp{kLjF1a!V!`IkPR~m@U&Z?Y(*m>pN zN#nV6pc&|fQ_zxI-1?W#EnGJ`dMA7V1`J3s7S?1IXI!@9zZWC~G?^dTIzk5h zRd-Akmn*Ba2zkM%ZsqWG27;JJ{Jzig72IYNF#@+2TMxif0h)r8(`3*+r@IwF_cmlr zcRs&;^=)8YAzFlj$q3*7mv?~oex;lsBAEq%2vQ+;+lR%plwY9J`4uL~uY zylu^#lB>d0vVxOX=5voat8!OgTQirr&~1E~6@s}+I1q`-GTS#DV}-FxM4W0B6FEQ| ze}etM7g63A;n2{ipq#j18;e>67)G~S*Yyl>F1t!G-WQpn_=f2JXbCzJ1v!}Mw#7xs4id3x5B-j2EsHucmSvkV{1cp54x%r1#=^+_w=0d zdsuJ9rzuBwK-Tigk-h44x}n0EpHjn|(y2_&Gxm|>0ns>|aE%lt*Pkpd&OML&USp~ljl27UI$Zr zTy;kWIf#?Pz2_CgP|V(WK8KFz2I+KvWA2z>;gMd5r|TY5F?%F-{hZ1A@d_7wK}sUB zWz?-DkVGx)*dtZodla2+e(JSi^4xTB(A*KV=&=`byt+Z->U5-`*B*f7EJzi97c~HT z=`{J6o2Df|qKb0Xk^4JjI#3FnzUZJ{*eM?P{1}>a!-hhO_w^OcF1u`+%VN~UALQ=H ze$HcelINGpw7CGtG`!X}A+YK@2^>@p4S<*gP4yk(2FPNps=r*H=GK0W26yB!LLuE9 zKvM7n^RVF>9=6*19YVN{+IB@X{2jv9ji|x;RPVO6|IKg)_uDVKMiaEFVnm{A#C!|1 z3Mvnm6dzIPtFu|;V+Cf96^nu zzM)rAsSzJ8q47bWa&+Q5LQ}CqO$X|N%m&+Yu+@LXHpfTla zsP+5#IDk|gUc*HYM=K53L#nhdMIG90H0c5|+A7~6<|Qrvw!i;Nclrmzcm$u3g3cmU zQO@jLU19y%Bty_7r{%1!Lvz>&=*#unsC<3XU;6GJ&Ckx?WPp z7al_*LO13B@iHC0?FuqewmCzlew$a>+LYba3}Bca0|%zT?+|~$t=x8=FW`u zb5y+|m=Enbv!vf9*hgc$whw$*ItuZs)8Eo-R>Lf(3OnneN`P6qdZ5-tctH1fN5CeF5pfZBrc=z zPDVkv?_Sfo+5cUKD5t_Yz7TSFW>S(mHf?U?*Q%NYjxpL%@az9?bx??LTAu{tmLm}BNxGCt|u24N!fv(7q{sng9 zPLs2Gbx++{kXBGXIoQ)zg;(y09&xeJ-F3I&7Jak(w_PVYw1LIzaG~$AA@hu!L*(NP zm1>sIm-AJ79y>j-)hkA@bXE=LTz~Uz+`fG8J5n=+cl3dhH0*KpHLdGC5oNJq4@rfK zO8fAuwVX|`_iTgEiFZ&-I}6(#&B15Wp64=Q&PY`AsIo4beyy}Bz%rG6KD}{cX~Ef) zY99gA#XMcv;>A3gM!N%o&K{XkX+Mzc;rf1g&F)aG%<|mwcgSNUj{66tJ(k?ymvf}VVw2p{Tlp+nLsJl+rw=u>@aNPe9DRYO9d;v8 z_&V|ttkZU9E$%#+Ws3{h?etoP2EqCGa=*-&cBndIK$OUxNOE{8pvG%E5@Bk7pCqR4 zVYud&z0Vw9#SI39ZIH4nOuH)>kRY$e_IAzZku;DZ`#o;Kh~lk1%ESeH9_W8n=GT2W zPadUof~|DhR^?uc5`Xy*LaW~&U0o`rLKdr~mr!k4o7Rdzv|!tooutwxSuA#kE`2=+kNWHKX6As^3|KK)c)`Vo=*6jFXf50c_>JwuYiU80KUXfdK&iq z^Z)6R8TBWU-Dqz=^l~oV*GsZ3o0xk&Xu_lO^@3&AT&uT5MuSI=KdaGx8m+MNYsP1N zg#v9W%f}81pJlq7C|pqRY`iSi!ARH4#L>7!)h_TlL7);# zO-!k7C)Qn&%gV;sS<=Min)OViy>W_KY2vL8j?%uIwA1YMp~&W&i%q)t!4Fs_r1p)F zTkT>F-QRe%E*B0Q?vi3ZF-m*FK);a@cINUkCw57yY1Rh4DY-G+-lhZPl7e49rShKH zzwcQB`ch;UpdTGeoQkO;0^vc?{NoB*>o$&i-rVKWQ06V5WgP&D(;)rQ>TbQU5*Z1C z^fkc>QBc-7Xi9k^<;=fHwg{QUF2Q0c3AFSG$m!!?v3 zq~4K0EFY40%2zuYrEM7&aoKpS?zD{5(k2peo=iIdcvNNvz@12|eM!upva@sf*0}>w z@pr?Dch23g4Sy1bVwiN*-Oy-E`(%%%G7IZhJKPg;BrfMW#6)r}b`f@Q;iT{)v)AX$ z<+;+0kk|D^>agdf!_yT8@0?X!2Yj*Ss2(l4UWKzl`^TTZ+G&`ELAx3fEY7!fcfoff zvJIJ@aF?V+KsT-~B%n&zyI4a$zyq=S?&@MV4Y`k02D=uQJV6f^YCuqfhj6X#j6_pxcwhU9?YsyD^n@Z5bIV;VN^B*4peJRa zy?0V+AO?MEI%%}7hNzMs5jO42701Tcd;0ol2874({hlR^HXir2czBLa)h@$aB(-0Z zo|$ezdF4!I*R*9h56B3s~82DIPwC5>=N6 z*P%_{A?}|6F8$~GO|sszt+)NP4Uc8F{%Wy2Gp(+YXyKcZpxNYCGD1p zy>*iRK@j}8(g2<4=le~9AmFC_xgzvd10+nn(w!EFqEq*x?6inD5)Jg2D1cJ*Kn{tSs!lisejhV3gMOOECoGo4f>DneTQ5` zZ{95T?4p8xixH zG^al|l`G)u@Rm=Ko$3I03&`qi>$88-**t0^mbALMS!|C=Pns-4$UU}|6X1mM=fn%y z4}`}Ce1YOh8@MF^xc;{{L8ro}OU$i9evRI z{db5~T)ZqscbhN6meJbs_OAzS>u5jnyW?VdW~(ubywsVDyE|7 z)ng@X!m2OGI}+QKiMb<&H0T3F^Nov^awGEPd(|`M6uYeDdQfi>^d)q?<7#JACXOw9 z$!_`-f-m%~;4vhrv|1~%x}N%+*CegZCN1Zaf9eT2R``LDxH65W`c4VS;AS1Q-4eq0 z>n~03Qc8aHSYhJ=Q4ODQnm|u@UoCPxma+2caH`vXqOAK+WC=o3tU(vV}z z_(#O(UHS3-DSlG=m%feX^{SlG2|U%xdElIUl8Vm&{jP`9^Z{>sPN>SRIvEAu@8Ni+ z#_5rnDM6>OTX}|A$sbY+W|m#m;1Xj#cn;H3Sl^Sx+9WMPXnL9yz01NBR+C@6I#W$f zH+b#8?7WEeHeBEZdSE!0LGtma$8v-J2-Ocn9G@g)zqCIe`+&}G?KgX)5}hOvQr4fE zrxZ1Si5>%4t6MSsoL4&Nd$q~Qlxe}O;!eqh&fp!YwHKt8#K87Q`@{JRsr|M5JA|*t za2@0(!67`jBJgiHSM{5#=rp*`@2+{9B96PfdNh*rW{rS|<8SUdah zHnPX3z!KQv;JEz{{EPnr*i*NdKjOgX&sO%Zcy(!mTNnp4)SXDd{A!|Ahc_i5L0JI@ zWHo`l+OIw;n;Ya$C&PcGdyKjA0v8vatWx0O(qRSpbgVtZeDfM2O2j1Z<>NA#S z^_YQMl>RgL!#LNZP0tD_6POpOAn?F|y)VK4@7!Yko^A)7I=H7Rf5?8;XeTH^AFFFv z0$y!fIyZ-;2zYkoC`u`Kv?KspFvm7~kNx0bo$$l^Z_`I2bn~(F1kAi%>x~6qP895$ zFr`}Ue&||??rxqz>V`{j$KOnQ0Qn{~yc#~=3ySZ#|6})!zh^$xeTI)&f^e=MEa)}> z%KE1>>fhKXf*G|%a}WY*?|*g$2z0l0gU-_Bzj0l0>U{GAn0%PJ1Cx(02XcV5oZP=! z54phCxE-MD%K>ZP8UN1)!cY=;Rzj1Jy|voVgytBLh61&h=X4W83m-Ih3gwM*@X9VV z9IeR8J`Y>t5rxmzz?bgX8ZXSqq`QACvkK7`XeQipv$pm^3l}Ww$2gz-pagUORD$qX zzL}5NKaicnU);}HQda<}N>Un8A5CSaR1ES!+XMe`7nk z>nF4fk~-$YaZxfg+gLv08B15Kgdz*G+8=;0nx?G$vGa^YVaV8YIiG7pp5xAfRO#vd zdmuNiJf?*(H~_AVf#VRAxJED~`+S2hX-G{)PP#+BLqgC8P+zP7xbIxmX&%2HETDfi zFUj~K)2f*Otg=0BPk^?p|>!R>?*dpoBW2W6WXToH;c(WGe}Vlz^*TBBJmPf zerQFr{NhaVi~!tES5fmIdP+I~=Kc@u_Avz^3Km3`Z@sYf+J<_*-YDEGYGrry8p9t5 zz279JIf2A<8Tu#m(3zI(2!gTGf8LG(hbBzI$#&M`Hw~DKd`=i7sQ3W# zl^=zRl&^h`L=}GlJqJ**kqV%KkLDukna{MS&xy@H9!4>(0)*Fgjeu#{>L{+f#>)+I z2A`2`-5no44?7{+`THKhx-NXIn!bEn`gKuPc)5KaeEWM)C4@)d)BOy%q545KlMWEW z1civN-AxB(Iszp8Rwe0}2_z(%8rJ%1`zVadu{R-QtLIICKA$7|MzlXkfLpme5mlo!F$ zW%`{*9?HxF7T#-Au&dsnf8-mV@fWSRz$1476gZ?6H$saALe|pUKhvG$GRdoB$ zE~kCAh+Xoplz5jT{WA7{8ln0-P}dV{IGhokk*r!Nj~{0WegiPV<13*kqcQvj#er<3 zK_k$fzz;);4>_9s52!?>=~?9>jkTqGyKkGmm?lxvg*3>;NfRj7 zz`|;!IsOH(2ye21yH_Cn4hUGm`}*TAx`FB$I1$p!h@SZc8z=gYHr$tz2|xs&c}YV0 zT72`z|DgkJ&)SvIfPbQLVx=MUREF`t%C|jdKR~QayHt*@(o#DZN+FD!7YSiD)m5;? zM4YtS+^7qzJ`D_}QBbSRrr@NFl#dr0Fqw*6-q`8w!A_`ppkisn^2HLn$Kmy1iw*Nt zP8;W<9bGGPeNJV7C+%DPNMlDf4Oz3zTU_V?+WCs7j?4&yO_pH)Tv2xJiSeG zpmAY{?@7xbVfPi}(|6Yg5enM>BcmYy{)E$U2d}rUq4{Y+#;sxkuz|?X_%m0tyC^X! zId>iX*P@nB((!bDcQ=3!4_&>k4m7Y&^>n|0!~t{=)9{ksSNOrzI&Qw7M}N`Rc}W)4 zK_0id5REb-nRlsH&GfLFhS?JI(l^IKS@uM zs|c~wCO}>s!Aoj7x@ZnY57lE~4Nr^l6}9~R<^7RaUadaY@VPF4LQX=+X6soJ&VEJ3 zN{luG9iS5OL+5Xiwvw)}5on?s15Or0*3^z-$ta!|#Z46^L9?e`@{_hoP^A!e{kvn^x1CBhtFB0O6Th3Eq^wmiy_dxMxrnk)3WM=T@q$%m$Vm5J%KMC$ANXThvsFi&(FX3*X=^6&+4HR z;1;(4KVIBepV0o}uuKfTIXyFNC$%_uqim9in{t@dCS`#yK|)j+=IFA``2Kv6Xk)Iq<$>wjSL z03`EYU3LKDcl}gui0}Y%n#sxHuPpg=i+?>D<%t;wPlH6kW8j zRVmSGqTjCFy?yT?5z;>IA9IltR7wD~h|iZ!;3diw-_K6vUTVe<8=zq1NaWg)2(?Fx z<7CY5f+?dV{)S2RtJSI3)JYQ-fwGPy8oSsn_YZYKJ|$xJ#hqqGDdd@w$VvznU`Y zFLQmd*NGg~W|<@EdF%?Tcm`Zlg+rn<@7%R-=D?s~a$356h=I7_kJ!)MYs9PLKUspB z2nN7qz0#j7HmhSLZfi9*CaQj(|0*iB^siFMpHP$|eI8)Ip!TGJ`6^SRYLZLc+u3c_ zk`y$|#{Zq{t^g{sKbr&Acct*AKx4Q}F67KaidXrgKpY2#vSU!v zgTA|W0OZakOqIZR&pKSIid-HM=yrYtVl_5Ah67C?z5OZr{eyAoRl?Rd((AfiY{p4} zF!j_Ih1b#FEYoWHs=EBibM?So8LjXSBocqXno$7vh5uwb+BK+Q$I>0+8!)?ERd_c* z>2lgTQ2N$(iL0ysCZxFXpVijaeZ>b~%;W2wS&B<83B>VyhE=Bkf{|Rr6xwoxv3z+o zoJZfF1-*KbIiCCI-eGtLK5mUjK5clUU^<@qhtj62FS=;1dKO=Pi$Xhu+z-t0p@QN0 zW!fsxP4X(4QaWB*<8P>A`<93~Y`)Nh(%^l;5mT?J_*}y-QhFZ(r8Y6==zEh1`a<6r z3_wv)0G>}&0pR)Dm%K=2(#AfmvMYz*Fo6~eOD6!us`%AefvhL4GR>-8o3dIH9M|_$ zfF&uVfJbH-iJL$kZPAx%h4U*(IRJ00v3@9~Y5jH(WVdpD+kD(+HJSk_E?c7c+}{y^ zTK?Q>agx^9f7I630TCIHVxBwz(B!M@;OlKE+RE${5Rkc?mz)gFxEP<+5|?lcKUYS} z7|~j~IvCLE^~$zNQ1EI9T;E#YHP01E0v&Py_-;)T7}=fk`)vT2NrjB*j@Z@q}Y7c!1657_+y#7h)I7zG&*FJ3oaN{&VOtE<@>kB-7_43Vq86}3r+LAxJuz{Doca3M8rFCBHp}6TTZ{XX#>PS8hmZ3KP?j}Dm{Mh=!Z{hUr;uDg=_zU4UW2$ z&2QSHZu-U#yhZE&(m(_m)Xnsu0#lWX7LaUbK z0CjyanR{I`iVZv67sH%a3Z~gBu-vlk->Jos0Z$D;l*4G=X%620DOzDraXcFuIB&%^ zP%1KZ0lucgzi%FZ10%u(JUgBR*w(9#hyhgXsxYkLU3CBAd0D?_m$1{W1BwQ2P#inyWYrOqz{^GH(6zF-ICPb&{FY@etB*gc#P8ao zaZQ>XMu7S87tp=F>%YuLbV(2U8Ysc1p!zuNt95Fc4k&*aKuh<;7l6m+Q_u8BUpo&F z%(cZU>+O)hn;Z?d2Rw5!sT>K%POaVrm~R+TsR6`VzvMeY30Ck1eCjfl zAOP#C^cYCBGD%s~nyn}F^V`<6K&H67H3T?jj;YHFW(1JBpl<^9mlUwS-h>ydX1}@N z1}^f10B?v2Je^v44UjL*0JpT8+9OUD^@=O!kV5v+yAeb0{24LNjZtiGx3p|0?fR@x z0`IbK?E4(H`UL2uDV5dKRW_!Ij?El$@ypHC8&#vBF-#T@-xRnuW#A(9U(?bUAVLD{ z<<+SHUhdk(dzm!I?J&Ygkz}fkbc&?btDgHq&XioVmg%^?#DG&N#IdcDky)XFG4B^p z{N3%7B8qb?#iyeR6l=${2}lWy=4cFW9`4+m^-X8l5=_Si`t6ATv4upP;W;b~;Ci~c znMFw~e|FP|dmzlxBgH4^8?~RRlBC!&9oxh0DgAn=3PwK01*Vy2&tWOtzR2@4!-|3_x{wih8#5yWGoN9K0zS0D5+PkN*eJb2VKRE)6Ua;&JKc(sZ^W-dbsqj%{~s5qb&1yo%&hhLWKa zSddYxAWiiUkWp18NQg$~#4_74cV zwZD7v2Xqw(D*xl6s2!IV2xV%=Y#D1lRVBvnIHN*k=%nn0s8F4ot}eG&Yx~sQS`p{1 zw~lc6Ym@h7l(9XcmiwUJ`~2V8G_?0~TeUop$%FA}-FGB|{I|Z!f8z2q!8R8A0z{VA zr(Rv_m1NkMQwrU9bgOLb$af9foXYFtUoZtkn)FD~!HKkg>mnEX4nsi5&2?Isy%2S0 zEl@afMaY^QC+EVORWUoF`&P8^JTd&$TM!2EC}J1iKc$|(J>_BWJFhZY(k*!&=IYAQ zKcv&pQpT5jk1w(OYyBB^t!Vh??aIMJ6zO-ZuVd6qVk0X?7j#-YY}m7V4+BkEc4N46 z{kiXhhaVD`-v^=7jZp`95>!kc8x$f>NEc^Zqv)gBdk3iS85de)UoRU+S$h^K%Q|Iwi3 z-+ldK&A}7=%Js<)TPHP1T}X}}Uz+JdWsE~^W5@7}Hjw9ozj>v`nNtP4_Ex)PZn;(| zK+qM>_=Oiq>3=N)k20@(W9x7CKKGjcSdnetB4R5{m ztozqY(h?)(*OEZMHN@{Y&fhZ@fHC?`C{y#f!6O;VO`DTh*?{o{5jCiLOBnMf?_LT_ zV0taISDf9{B|1>Puj$!;VU21yF)V8NdM!Jlt&4`7%GJ_K1pMQAMl!M>7vimn$9+m= zW9_Yj;%s`4$HW;oWY4=|9Ns%oD*vxt=5ZPpFmC(oaP_CEx6(fM?rgo;SH4MPVvm*% zum)WTN7Va7#Rgk)t>jKA63MlNm>sL0^$32$z4vErUJWM>5X5~8td5=5?Jm`ilNh1H ztdm7%@cZ@17N?n;vbRr;?0MuKGXjG+4KYWx{;%C3+IxhL^hnO6^K{Oqs%W~D8?29B zv)p~p!VuN^vsq;?KNl7{;ILVpVVXmGyOz>_6RAaz(p`P_=c>ow4GdoD+zHE6fQou! zShFt*8qCv|%c%><#+Kjmi=Fx*(j}!bfMTe@U$jg=j=NHRUfBr3jEp?`7c6!gIG)U* zUv}Nt*k@U!i`%*A)aPHPWZotdb@=&+zS%x?IZGga&9$F;_&o0O(?hnu&EijormQhz zCyHbjh6TE!0d-Q8_lczlh*clB09Ft)eSkgO6ZqNjJSpi+tMfx~-j56r*h zl!8BO4Xm=BYkaD@pI!bW?}cJ|EHG-%Ps-aqwzt|Ct11U~?S=UVOboln)$W(c& zmJ43dkpYax9x52stLPX0ZgAdM@W&z(l*)8ZCs2I0m2L+&D$Sz_b~;Qo@<&b3zIxfFz$NXrO`d?dNY zHRj)SoS%2~LDtppuZQrOOe>J?!{j$H;g>Q@r}*R6Y`ctXNKw;5)bl1fbV_VRbAsfP(#*T9(!U&aE|S zu}F4PrA1Zr;qoxrtrbPAM>VjGG||p9JGLl%wgCbj!UMUU;-IURP$M8Er83QmD(-`u zeJlURwfA`WJq5h~jaN$_;Wb{>EqG>R`&^gTh~{lFGHItoC(qxNN#SgKQbe!YpWic8 znbI$OwWQ)N*m!tnSf&odYzI2ox=5zh5d2RwE)BEu*09v-ho8)mGnr)q!SBpY$NWZw0&&%y^rz!K5|+opyhhor>?Pc^YgNe z3NoCxdMLEQ34QG$U;ej2;@~30wQUzH%#mn`E<+pqjCOTfD;0nEL?I1lTFEhYu8%cG zJC)IlnK_?nd{=9)aI7zoEr~#J{V`^nFtv~e2q2{HgAZ_GB zw7hgotd*f3&?f5-Xj6pn;_lMb5WP~Ox%tDNOKkKK0?QT2Jr-=30gwNp-t* ziS`Ld{WX%jQo<7mX1hv@-}lh<2){pDgCqWmSn8HVEiZX$5G6keyl{$1k4+94r{OkI z_qQqFa^q?|dT)I)&D?{p{*3F;Lvx9&heG%E5Q;hczlx0meQ*(;59!GD^XQZ=N|ltd z?MAtSRExS9Zu*Jl~>AdZl{;X26UizcEet6Wsi zYJX-~9sfwpIJP(AMJAOXo6iGK(vKR>pZ?$iFxS62?S7QKl$y=`{t38?0Afv_Unc(V zBcRuxRj)SELn3|9F2v#DVK!eE5Mq` zA;o$Foxz_q5@EigtAq+clY~uVEZVR$9A}F(bo50W57AxTy%U!tYuv{};66nKb3s1= zN;Ihv^gWN<&4WB8dW;Cd+4#>+HIl+)GSk&kca%c~)n-8;OweiV&ows906(Iq@Nyd7 z5~QFb>9vkVbWYcXp5EY`al}gy<9Q4SIYtgSR`(0Ms!LmZ#q5Fwkp_KsjL>bPrk;Y1 z$-)dRxn`7zrqcGMZC$0&%wqdvQP?;oX!LiIa*{uQr*t~YoQWd>-rDdU4?+b$MBLO~ znD+e$>S0{g#LKwQM)*LfG?pHXRY)dO)irh}@E^~571AddTTb2TtOBvHJ(2E25=APu zf;!#WUF1i&E$^-qfTtb=4CXpxVP~yAsWSJYq(;|Gq<@EZ_^6>w$A_}qWKxU}l9Ji( zI%^s}8}WM0b8_mG#g;wiStsV>PzVD9Lk2I>;`So_sI$JN`UsmDw}C~BH1VvB(R%#| zTQ?Hz&EESWgpol>b0zU2!T^E!CDD`^En=+wjDB!QrEWNW5g%*F!O=BG_bdM-94|>e z%edl}jQlA|ls)XpqEJ&+?WBkeHo28SZAEb+E~5tvXvLNprLe3@b0&D~fX*9?I zFYy<#;w7*mB?-&dmd?^)P|DNuDy3t0Fu#$1OpV9F8gC*Us>#$cTZpvPNC{A`9uira zON|xhR!yKNa?dQyEP$t!Fy9n0+8an?M}}(MbXHA&gR!OpRSHSPgym%M>KlAdR^?*b z7@b`7n0p%#azmT0GEx&jDZkE0!rYdpA`KDHv92W|5sXZIi*G>z*&`*n;~B}?l~yv1 zjz*5QHr7gmF+G)LF{T`a43g@)@loXnEEHD*jO0!p7BAV3q=3s@XD%Ay5DrEKC+1?> zO1TUP_j|Zs({Z6n!QR-0XbP)a&RWIg=<12ZMLj8EvPaV|L{fUcYUr8D(fU3t9&F@9 zw4r#@)FbjUr&&V#42GQ+jG!t3qR%Lq08b?0n1(zaB;q2Cb%)A=l*?(7jOP zZV|xT8?uUH2{h;a1@siw_-Qi4!GSlva+E22I*C=rr>wAznGpS=W%xd>K<&nmDXc zjL6RQa(*EE2$b>1@W#t%GVxwk=Y*!2m;kfKHYP)9~-zg`Pn7cLnp=^mOsxP#2 zk+3a6r3!6btxrfz<2lka;#IplUI1^t^ALeyTh6qkiVk?5-Zt8w-rv2f6!W9d3kgsx z05%6UGC?>yYQwf>MbB~M>3P0K?ksw+W}?$QDM0O&W{{0nX;3R+NR{!_Y%w4mXY}mXMHYVs%U8ERmeTQ~fZmh#_tE-8U0Mufu+GF)Iifq+D_(|A z>MGyir5MK_uzFOt?1mI&$8Za;Q4*kn+yjH4?9A;B28vlv%oHeU_HvlCU52bTgfush z1)s~vqLuPh8L}f}p&UWWq6~T!wKzUSyWZ?mMJGiZ^Ch!~2PEtBS&us;)RcAPZ@-T4 zHh<2#B*H)%7As2#S+vsT(+vXaure9hSx1~bm-R{2fAxwLb@62jK7J_DxXM2}(W{K_ zF--98rtpKA0)_Kt??9L!ca&FAvZCXhjAq{Yzdfx6@V3EL${#>z$ajd^Nn(}@p zxxAk5({>H(8xFpTkP_c|(v`Fu7cb$saVs8lVC+VU(VWDYMet_gpJ00Kvt_6slZXISnyMEVc*}Z)fMff^!g~r;8caGg>QPd1+BV-XHrCi_mHT#0;r3` zEe?S0qsdIzme&;7Ud|Zj>3m>J(Aeb)#UP80xl4YRUaA?=7Z7Wgc|i`lFrYvft`vYg z=kg0PYU|UGy$1G2LyB!FBTLwd2fN0+pVUy=_So#-Ch2E@~mOy z#*GN-W%VO;yShln&>dFjcNz2(b~{y=eoXMmf&lbpfU7xXIODAyl{N=nOEeHgEUz!N zu{T$^W(BwjpN})Rk@lRl?XJyITaxscRQ~Q$k*%Dho9*!$_t0^7ytQ+MB6;4bFeYE5 zzY4Y)T(1q?t;?>d zfm8JplA2806xrCk@E5l4NM|I_FPexMh4njTJ>o{@Nz^w<(ET4#tUvgXK^YZs9E@W zE)xg4G$c|06JGR^_bCUCDeTCJo0g}_b!FPvq7xyVT(@yV1=W*}I@F%j{ z7Hxj6jv=NcGY;JlOqFXfI!z9i6JM5_YG72RKEpdHNI{FX{Zq-V*x7o861fX^xu*8NJjvf5lKtvc{?qqHkf`=cAkU2d zHWgBV?li|}5;c%m>3>9(7y6*AR8U=6UIZy=p@gW;UmV}lyP~Q|g^?P6SmWC)KbeYa z!MuX4IS9TXi<8+|D1^oPDpMH^yb=|>hzs{RnbipKTmoK`Z;Gaj6Kd-hYGX2;i}p%$ z#><1m7IGOoTf5q*1uB)3(t<7?W-)U69ebszH3U7a3s229zzP(0Csh;d z4foK$Lf0k@GTEIyh*QseBGZ$>Wigu^+M5Zbd!Vd`Q&m=MW91lLYdV5uDzd~106 z(`KFk143v1z@-Ysed86(szvi7;}~C9vUbS_qqnJ4!iw=ng40e6`&_NS8i_Z0!p z5drsD*=YKxv-QopsB+^XuIya8{&ZxKShZ{Kvl6+iLjcOqMenq1N9^lT|GF4WL{>hx`i1O!!^jw({rH@Tcf^*07FITpRFsM-p#FeZ|qgE zUYlrB%)ormgZ=R{-JY6eD8doy6L?}r3`v9fL~yRPS^*1~f}krIkIFg26nO{A%fS`Wsy zdlMt;^G}7Mb}YQmf{6uvoZoU>&|6&6+9y>3Pm(6Z>Le)daENtbw-TrY-jK}7N6WZb zTSU{oNDaRkLKEVoD1)_dPuhi)tGfvQR?KJ9+PgBUl$6A@ul#X5XO^RaPg7B1JE(7! z!cL2z6{VBTVV&bn!cuN_Yzb9m`l9tQRq;)hRbXsxnSj9anP+lK>xudHuPu?O--Fo+S>cnhGaCripvXM|bWP+|d^2>+WX?%HxM@4lM zZYD+wQC!)zYj$=5WJ>G(i67Hyd>k9&nzg zfrn2J`Eb%_ub%iT1wEk7$k*?06;zcsOUxD$E0$B{)_MbC3{0i!R)QQi|-TLM=SX7|b zWo~RrI3RK?7%BK6oW#||8dV(>Bi~HyQk_QZDFVxlIM^!|(X_dfiQ~6Km6bqRg65WDiNLJuY<%dSB`6^Sy3Fi-j#LJr6 zZjoUw!9usH8BLH+Z?5T1}zU$bU0dS^n564njL^AiVd9F)Lk432#Yi{@$aotZW8xc z_>g!a{mAa5e`I$F|314bK^d;+bEArC!vpP}qA8V*S2WnQ^MGRl+WFxFj$X#5^!Y4e z{b&69p9qMTk-Sw5#^6J;iiRr!A&VliJ~>(g`C3n92W)&^1VU#`MMM8|k|+8j$s?Td zzLMlk{v*ko`^rE3NDOQi9XT#d#ym|R+SbAY*|?nD)Pk4a2N${TOxBtRM(?E~2UNc& z_&m5^XvUya9nhNi_arYbHo#c>r+)kWnit=$VI$Agx5v$XZ!J zyUiLuSC|=_pC)FOS1&}x@1QRj*F215NYob+X9rm!tlDew!ks8M2@wY0*p=kRjgM`L zjf6Afx$Vi6>>XN)4Gj?m%XTP_c=(<0YJHflM=#lnsze z+uLKbKb=@{Wo|WK9`h7h(I;>9;%@5{6G#+!n6(hFWaDgZ>iAJ_ww8g{gE@6jz`k&# zKbFV8hgn?KQwkOr63b`@ z>U&eT%zxV-luC+-SpEfc+h887!@gUbZAIS49=A|_DiA_Q$(A#_Bbam4vn)1{1%E%G zgzu>aV-9rIb`pK2YTi_(9s^;w*{qM56(G~BCK>5j%b*O5wOoADL=rHmp)8SbT)cc;jfPS5fG_bBu;bHDrnjg=MJ9)1&xuh#%eq^8vbqMn4?)G_)cO z=EHC}R&Xnzt$hZVeWY2 zPrkr0{-|4z7fB|77dH$%lIWavGV1+=f-k{~WaBBF)fBuAOX`1`kCWwEL^9dzFHg>7 zGd`m?>2&4DGIfF$L$9TcRPUdgM#8H2;r~WXWCl zMoFHPPIgu6FCc6^(aMbcj?Enb`^u@Ylg}OG(ne?I)z3$CPiRA!$c`Sr!Ch)=ybC)F zv$3`F#&ecclZh9NvVX^wV?y5kWBFW{U(l}o>4v@wjS*)a5 z=?WMWI6D&ApkND!=*yMd9^vCeXc(e2a#~RnRYd|#nsw7|4H;qP-IB3RXU9YZx)UHp zJ^m)p;!QAmfwRKS+42p5c0Pbzw$Y8xvY+!Xex@#BELNzaj5lMGXG{W>xj|DfJ9nkZ zQnz>zk!JqQ>D-Kil}GxN-2UlqT*5aQk;9Q=khWN%1NcC2(wRyQGqL^z_Qr!B6j@o| z64RE~z!nB`qH+5uZ@@ZzpVQOlHS3(=VL}cy5h(>NMG&ftsOI=0+6S>E@4nTCyru6p zL8+8Qq;&2af&&#%kM-~V^b=1XXEr+mql&Z(@*u|N#nIaDh{WD zX818#_X&uZRLY<{3t1xn3ITLc=%~Cyde!hovVvqn& zTS+bR0vj0>75G@xc9@I286eJHi#k)C-I1Hbm_EHF*M3?&@q0)1W9lC{`+P@BZA&3Z zP6F=c9ZJEq@5xOFJcWk8;5I%-CRrK-U*45QzP+e;L1g?MOAmG8;WUzYwYSdTgJL{@l z8@t&-=j4!vlcK70XzvDebZ2((NL~!8t3O1bJ2-T4L+ExDY@-g32-4LihE*1L-ah-< zg=@Z>&QF(>Po}!lcbQ%-o*x$0_Ibn`o2)mE2#3BE?uIIhzwNcyhtBrol#Oj0+(g#cN7Zg7;Z( zg$KWNA4c<7v7?4O>$Gvk3o5DKEucyRSH$0CFW$fC-#aeKRZ9|`)3rSzYw42GNu`97 z#6kyf0WNW(&Q+yx27)#^4asiE;3QJeQ)qn?RPENb*aL*?Di+_Z#mbN0MBDt16^(tY zgm1r#%&?{RS9!l`^yN?=XtYPPE@`SJ_noY?MC9E_GV(u~K_PTUbYP60qgA$bWhZ6rNHoUx zNz;G3d8^rNB7s=8wI}(!I^<)apZua_R$WoXA@bFzYHp)Df6DRzl$t159N1`5jc{90 zLcK79QC)`FFzh^dGjTcLy+p~^h9AblVP;sq=Tjp2xIx` zDNkuH+?3o|U(-}q^spEr?9{z>R4RE|ko=7rMPz82zldyMZjiN|S>*&%M0Vih<>0PT z5@hVK9>bN0otr~uBm6!YnMml7VK_B|KTX&Y?gf|V>)NqG@|+^`oOU&el&9l!4=o1= zA6Ze}QgY-GhGi4K<@F=lvzqPrzEQE_CBBLlure}Iq3c$jqV!~3@dV#ezOFaWHD+;A zPO&xN^(}z9?f90nxkNK{3e+dK5f|pQMs8N+*`{0*-ba;EPH3QCFR{;!a*i5rWE2mP|8MW!_I4-$HHm7nx>r&l|*%qP1dck<}jwOBBWKZlz~ za;S;r79)Z*tKucajO;QE26A-T@sz`;D!(@%_b?c-wVq!3BZ;_Fra;!Q(oS3B<{rS#k z{kre0Hx*tXQS2s6KJZG-XJr2<7}={^CtM)x(ZxV# zW6MDwNj4L7kKTpRB$723t`sT4zFfVZ-0r$-!*tB8D-nD=KmP#xa~sOzqh5)%*OJn1 zxm{rc;ySQ8PmZ#VsgQcj&!S=r_eba6t;;+JL%u8)=0;gGer~5G-5(x0d4YxEKGfk8 ze97FX!n7FXYPxbRDf2`FFljp+q~cC}e7km?0_dxQF2x{Gtky+-SW^UTb#j~LuFK?8 z`kLT^4$=D8wivtw0r7WeH?8sn3vvEO>{;O&_H45JYF;zVQ*ZoPvzQR9soe@aVP7DA zwldl!-KVSlKKP8hklIYhm?-`U0M>+LAyW)`481a-32XDgy5M_`m+Z{|mm7BrHGkI4W`o9Yg|HEUY z$^U;{)fc9CNSDIz1mz)IpQ)a@XO~Zw`Px5umLUr|T~3+=nuz+q655l~&=6*AtvGXBCsseST|jaB31sOVMvV>&I2rsg0ym3JcAth5vg)MYC&0Y%f`SwQoZtWI%^ zJIUWC0p*9zZ)P(BDHJTU6V|zzit6e0vq}=1tnRnbr~$8#f4Xt!x{%5QFfAMf36Huk z)bXO6N&%9nb-mU*l7Vu1uPUVn)j!!w2$sBSWV@h*OGh=^kpaZf)Vy$<3wwx_O8Mu} zjPg;;ev&~sXocMVNynQY=>5`9-u79P_Ziq0%q#%S;vwx!c;g_?Xu)wsaOxgd#kLw269Jhom(OUFU`{6)jK5!R5el^brgSw z3&eh$mXAg$VWB{sMZN1t{cs*79VJfUOYh%swSU~Sl;FeD=0up_;&O;su|(&O7F~~P zMr{o`7Zv)#NMj(mkt%kWA#I5UMZH3oFr4lbmj5Iy&sjE!XfZu<5prkP#62wMQPna`MJaf8Wj`lNZmj!2;y_r-@*;$y1VOdEPS~mhbbAqgiX_ zr%^tZA~T~AA9wj(XfLzh_+D=atrmF(O5W z0APJQVyg#?lhhB!43*#{VLaLfwe8#HkS!$LIssd1vNu5>f)_7FgysVqNJ9VFo$!Y_ zKJuUvP*#d`Uo1ZTTom3&Il8RiA5IT8u5UCC7voP*cLthgVPhWBc2 zI9Hz!kS4`XrFE&TvkE_;gu%m#qJs!8?MWpmsx|`gfuz;rQch8*%OU)aF7RYVayfyhooxhTt z_B!=`B|O2`DL1z5!I-YLRjo`Kk{#V!Hj>)7qI!w1|&vDPewn2`ZjL& z(Av+M``7Ej);&4l2qxGkZXUa0>pxF7+U}DFm*7y6Ei(T(Y4Vk*+ac7v&M*j+njd!c(zDI zf`-Sc3AJJg&y#(^lXE4C8WFP3c9NXzq03k`vebWr1dFCG=WFC3gxTek@%@2NUYDPU z`3K+%=4q3;TOK4$ALe;ln>4dnR7hBlV{+=Mp@ROrO0ktJ%i?u?qPqH;O7C}Jz_5Q= zHKZ9@UBa>yLqoXc>@3aIUJz(u9P$?IVtEPK5@jXu^tqyd-zg`^wQ#c9VDbZJ01UHg zeNfU#gFgiQabzvPB;KSWyvfKeQB1iBO{$^m_~HjiDp((cLHa7EYizFt#(0&{!aX^u z@4-^OAGQt=GA(9;Gds&HX8VJW*eBn2Sw3_W!~Fhv`f|UDYE1UCN<8=L(<{~PdY=6_D(5MQ+^PH;N=#%p;=x=}vZobXDda6p{k1cg;>xi4={vD*bQe!IJErC5rQ_istc;;} zK{`N7ZHz;kgOYkaY~9 zovKPhf<{)b7)u>omovr6ZS4MyquyadO(x63jLB_mGC zX#MVXEk0<71SoTRa?@>y)57dhmvOF+{L9a7FNUWjW4^yFosForaeL9WPDJwcrYQ1? z&0%=?;XYo_;sb@`-9y~3FQ&91$vqR+j6>0^k3kPz{2uFdy>UfiL1I!Cefcj(cBJJ0 zK>Z*EmMcA4Q6O)_am!Xii&-Y3*=4B<;L5H$g$;3Ki-l61CCK3kH4U^ zqfWFE%}H>}dzh!Ai^nkwF)8zx=9VYOk6{QQQCtNpQ#HN=LT*(t9C&3$j%6A#JFjac z92t>AL%99z(J1{>8=DhEd62(MPlPRpb-bCfF8UF@Fit>_Av7)I%zy>9ei+zD34Txm z7c0V=olD-o^R08SRsL;8b#2lHIurvWX4dgybbs!)P4iR}U=`Rc(va(^I6K&1%o)JG zj+g5*x9Z@A$`-tL^NT$HkJIM`tA`Y{x}G&*i9AVgkI0E4 zAIP-gX6UZcaiq#e?MJYl_8Fg7&bum7_sD4V63%S>4XtUkLkz6qJmw?>d*6uW`_9ug z6(3_@sgoL%WvTP}$d5Ks__q1j{>C{^&vE_dbrUsAhS^aJ!fi|z9z6$kyh?O6#EUAG>J7fA;~hv+z$1w#@DmexA`KXQ zuW~GBg)L_VUR_Y-f?Ws}I!AR9pT}Fsn_9OXqnG+Oe`c+`Pq|tC=a2TU&8YuXhe!0g zKWjW~_%Mx#l^pjJ&d>J{Rh_lcZWg?AcR176!~RF>3L0^Q0#tqI)di2yig@rhgJKc+ z%!gF(hfK4U40U&KZoJk_1{bBraF_Wqvh-84#l6AFj?#u@N>1+aE?BiMGFP16EQNl5 zs;C1UuC1w%!sf3WF5~PqQP71vs*^Y5@t!ma6wALJ-Vm<_tuXCZHWfpHq~L>fSqf0B zXlF-CjO^azNxl$b@{LxoBt0!=b~j~okI?A@KzN2hOXyz;fNLF38Q4^Ky>#IsxRJ4O za?fD#6pKaGf>TY#_H7#&B_L70Me|mwePc?zAKYwLYS^aSC%Bg5Q!4AH;D^psk+EK0 zsB3VX94p;mS_>Va(t2BGuD~10ugTAV7-cmKXd_T{$rRiC+4JFu!SRvgN7vDN*_k?1@WrF2iu~e7hv(v zd|+4<}@%5u#lx&7ht!eAt?4DM_ImjmK2^X5e^gzmC~qYqV7ScSgFGjwMhgR@ox+W1p= ztRU{{QzQmer37rF7g3D#*zZpt6isaUiP7srF2GY&m*jAbG%)H+Sw-r`M(rX#Ne~yE+Uh5oVimV}~~NJS)N>&hjOkMdqdF`n$RV zR(2@2^tvcxWWsp*kM=R*g6TXPm-rA~FlzKmtS!<9I2o{^cC0((*^L>_l!5DfATp7< z5}van8^%Y?ur(x%cB97`hNUd&xU$WSj7>ltHalO6u58UPAITE&Lc}O{Jt=-^q`C$z zhVgt>*S(9FJ5&(9slVqG#A}xvV-;9wq28cz_9D~x$!^O<&+hP%jt$)pIj6{2)vy`{ z=4|&4A|$Bzt&`*hnyyAdM2%67QbAYxoaF=c6Oz)z;+Q@)D=9g`E!#e_s_C@UTzkp0lwKI=X zXd{W73B0%8<>6MGG7%vi=Z1=r<+cTvFVT2Xj^41OrVL20AEP~=9eg)vnnYqI6!ZVe zyV9tpvUD4V5>WDh#hr4<~tC4oz>+lsr4^f$j0J{}Gu)fp1 z`YT~THJWJi>TxVwXae``7JY1Bl&HTl*3u)#@iNiHypNmcW8ae%`|e?#;_aKOOgi?a zm&;;7!Ej`uBjL@wji{soU#0?XivI}Pp=(K(wpqn)Mns9Nrd-`$80BHuZ0 zE0i046{5_;TC-$pmG$io*55rBgK>y3Ru_C9dg4L<;iK23${rp|$&OqTt6DnjIT-ao z=@;L}gCx}H4Dvjmb2)8)>pt|*i~0DnoZuXkp>A5rnl#SCBR(IL5a;hp*Uasy!rm)R zA8(rz8V9pd>a}=@)qxk{c$T&7BOV2l30jGI*gJdm`0-@9l{YqQL6zFn+(|*0?eT(s zZzO(cEPnd^K<3}>rBD0+W&i%gHvHu$Sx|?QH$n1N`g;aB6qubiny&n&XUEIUE&7}9 zJKj_~u;bj|aC(s%yzTs8j7W;OtcDjO>O8vMjvSc!V7ur0)C~T!vZ}=xy{hY>*=pVp|TpLURG^2|lfbT$Ga) zu0xLzD&-kSRy%SvG;u+UXfUL}I*#R`A?iWsc_G`LXL&lB@2^1k`Q@=J$<(Wo>X{p4 zw7+}vj%ssyNL$Ho>%sU-*55jwwZ{=mOLS6|&+I-~xP^Xn2YoD1db5p_`Q$$&?mrj$ z|Mb^bqJPs+4%jJ*z(|uK{O--TJ;rA_>_|AAR5$bp^?ddC>PRIeU)|R58rwo$>qw>S zo?Y~hk=|RAeKVcV)n;|4c!^p1mDVo1~-BvsV|0Yd_bC)%L37I?DrlF4Hv`7K}Y76 zhHoD3^8RT@UgUQN8dj}RH9i_B9pIn3(_Y7S`cR!?vb(5x@ft7L?>yDQ!tym9chlym ze9M);SYQ9=e*VEhss*9YUkpQFchFsB`*qpsn|{FAZlMi48l$yb^UY$MC^_k<6bmKI zE&k7A=K$p^CW=;_5+mO1giUdOf9yQUSYAJ4?!xFa51gcj938Wg>0f5P`Ib&p_0Cf+ zE9#DIRaBfpDVZLL&XhP&C|QTfkFaE3>I%Oen!j_MUq;_VR*H~%DB}UD!T9UnV8&mW z+UGqR{uQtPZfNg^|6CX+-3i$hbesopbA81L-gVaS2{9rK?OM9vi5_oNi{r-S2;Q^x zpUK+eTFhkrt%a%a=4^ChEd2SX}i+nXU}W;oXxWt zvkvr`^m6oFGHdbPw|BTMXnX<@RH^$V1*kgtnnce5XfR&)#)+P3sEzv zNaoiVuEVI;`H$HCPEZ}PKw1kIsu_X7<&`6g#fbKt9*7n|WNE&dWjZk(Kc0!Hak<=@ zdnfYzoLQ2oQ7*>GQ)9BIs$iH=uw_@+!IF7xFM3cBCh4b1Psqx6<%8X%c#}w0Oz(O4 zSj>I&EQoZmV2~Q}MCrVwVJ{xmX>AADm7_T;3v4$vu8Lbc<$cmO0s(F-o%s5CB!SjI zWPzTGOS`~6ThQ_PRX&F(bKR&WdX0CDvvEPd0C_&TEII*FgLD=`44YWC-rfmnn`TBE zOGbR!m=2_-cRbSvf(m#$4$h(rpNSD%UF2_%#{`#IOyVm*qo<)q^SJIFj7S3fF9Rrv z6&3OlZU(6~e-6?YBfN-wtggSAP>&)qr!0fcWpkJLR0c7M!f@fKTFobRQ^0psV#NFy zkam@-?iRtEJwARhWTMWSiCZzG=k}xBo0`rx=~tH0 z_c^hnfj$Q&{R-)$V|ZViqak}ko`(?pGvB^ON1;8I$+|eZ_170wrSVgL5h5cVh9zto~)r)`ON>CrQfde z#bpA|ZuhbJRU-VlO2-OWpA|s7Ym3sVcJu6&2^iMkgDE4k08dWe%MVu3X8bwZ*O_>g z!CM?@c{Vg1f|{yo>h0hpQuIp_`d>kV1F$`4>dAj`mtSAPxi(VzDmWF-gm~M7$3acj z5q_%4Uf}N82Ws+CP(Q;A_0vTeooK+%8Jf!1{Thl{OIcZ+VrBq^)$>R?i`Oc z6cZkT^a$Jnx*kdwY~%LI?~CC-OOw@HY0YqG7W6A94PJ&On$GUdE>@(oBE`($H`dijn|!c0d>_qt)JBWu1!m?t@AFlGTf0| zIZ)7r2;d11hg!No&(a6BihfV4AbBhT#C#hR0H@=#3#Uv6S%=2L%@W# zCc1)N(PPk?s`2tlLS=9rA&sH3F?y4vDRQ+zc7uB^DPo}gWkqr2K+d}h%O*QrE)RL@ zy3IEw7||jk`m6{OC8Y^>?*su<4FkvB=*2q;EzWb8yhSnMX$VFb484Jw!@Xox2JzPm zD52j)%Pnax&vNpAxd0V^F(P0k^qP($s0>xjk-&5aQxhW;0OF}}sG^@lMEJ2u6I=l% z^Q~a2SD^M@?%L<^&wqTa#ZznAXZilU(1cUj|2V#&RBBkQmTa}xdRqZ}fNhrCYU41O z<7ro3@qrt-JEKa8&myfAJw=b)Fnv;T)-SAt{z8C*7pa^!u>pCp}8rdh9ChEWwx zjV*Kip4~#Hj4FWp;WRNK7vPRMTJ-)mr~}8r11*#ZM=>Ja88#av{syt+x%kf!3e5oh z+^4L>h{hN(;<}Bf@-g%QbSzhYD-0EE2OUPAfISB}nrve_$4!igRb@~FL1IMPtNi}y zJ3Kdl&Zn!^wg&-p2G~3+YY5OerMNL^NFX?A&_jIvZ7P9=$GuOQHK!^&*{U|F88|v| z@ObRBC#gfkRJ)^kCv%(6*n`L`vmW^E)y|sx>;2 zqFefA%bjyllw_AasxRG`F@7g8K$f33dyC%SSBbY^)rL<=OUz7hkBCl@&`FPwKKY}f zxu`U4ElehD<&YG0{k=%<_{)Z8equN1CBMdY(X4zanT9Ini9@>$oZZ{ICM{z%91R1U zsLeb91N&C9XL=|mES&}(S&>;=1X@HS&^;?OR(Fp;MWYUR`rdPERpP;>b1KUW4x6V#;XJj_rZ;gv*|ANyN9zwLejRiuY3Rkq zXqm#Jgi@6Ny;$sJzEguq;?lX7$0@RhEak1!60FZYvY72_p@q2B4{Tx0#SBt<`A8|= z#Z|6P-zcBZVD4Hq8E&^d>*D!DgAq8MFNc#b+u;NEt>1fGS!s>k$%lv9wAZS>Adx*P zV$?0JSGQbU@O3X6n4QjNAK>Z9MS6dK`=#y0ScjcYtB-ER)*vtZyJUqtE2frupSYdE zz+cA6geYCcvFm5bs;qK^{2dD+3%&p#sxESr!-D&rK)G0lfi>~MeVAFIz*~S89Slws z;giIOm4dW^rKPi=-=K3|S8hb4>zrl0g;4|5+7Ue7lJ;!5WoNaEM|(C#dROIA^U$(= zbUAI2_Q6CFVC5&KMg>RY;{~#xm>R|qHiia~1Qye*j>>bS&jSJz2BtFxc@^B47`__- z*Ls>zb`Cbq@1H0_&Qdt^BEecHH1T6ET^c4Zb9+RO;D#lPCwl%mK|m)2mGvEAAj6?e zCi;LEwudMQ%{nOK<-LVLDVXKOfML;yBs?8dI74_KcnvIq%|xLSlXsTbhvp1(Ail3a zlh4$Fzv3EKPWMk&M10FTZe_2l^(Zp<9ea0u{GRFe*_FDq&eR~?+qEBZ?9z5w37GK> zY<*=7(qzsPARHHRTGQkN$Dszfc{R~1zSs~9h&Tcd zH!qYg7w^D99>d$&Tx;)`z-67!30O!wqs8qQ^7j;oc%7}Cx%Yi*A4}zzOqY=ws)NNTMj_3ZN=fCw;p0dM+yKvBs`Y+`Qd`a2oA>&q-ipR`LYWp zsLz4TM2^&j>mdT>2^ME(*?@YEeJT#0V=L~AP-nhexEHtDy-Hi4z=1e-9}8l1-uH|i zpAQ!0rp{3Nx3(Bh4;q;^v)6TRHB-35u^G2#*V*^lp9$(0#*_{or+j6pu*dm@N!wYH zwVCq1(h4EDWe)J(pJTdY{}ALYCVmNAw;TRD1kC$mKs9yNDzbTxO-+;J3GX|XPWp^` z?!3|2|F~$;5gulP?2W-Z+uFUyRrfTO{<$dbYFD7+l~N;xA<`$TV&KD`(n@7zB(G*Q z-?#ywO4_J<5GrC|^UIy*zEi6qh2x58KO@&4A-Tg!a*k&Am>N}mH|!-}yaNQFvHaB4 z;g1_ksxzv*L|#7!V|xhpX>HswPHbPfrf@sG=?EP}$7=y=gOk>xlu)i&94!Jcy~>LW zq}4?TbRE+Gks71mLn@>cvIe6cSq+sgnkcfG*ECt(JecJ$a9pBYszCo(bN#;S4I|iUOXo&yw!tM3Gs6Dqmen8IFQWu>$U{!97;y!?xUS+!7!0My%wVtULLiL3smc#~*?WLk zNVr6g{b8i^a+_Uo`$8bnpw^#XVnUk9zIyplMBgcY1z9Y3eK{rY8I^lgyd&P$$I05;IP{ojhwybm<>MJ1O=+dO8)k3B30uKw1C9&q zUC?Z3Ghh~4ha_H*@`5(8D&C@HODplVtlEoR(uA2A zPNwJ-Mj|N=d983PcpxHLOG}~&%_b1X;al=tzmU+n(~Zw_!m=fG6~E8%Ywank8Vka# zW4O2YC&xyOZtL7DaNGe*6yUG1hi7B-=gVuDQo=)BiyGM4`IKk-ktN#p#HABt*h%Vf zBPo#^s2A|=9U5h(1UTpV%@%SSz5*fFh&C<6s^@_BFp$fYR8Tp-dmxes^^SML$Xoba z{WkqXQ$~fAXoi=jUekVu)v3DO!Zx>``_er`qQ3*k_n_eC|H$5f|K-oF_-Ea|fLJKF zeIgbhftt<{abf_m!19Tj@G7oI1XROczpMie8kzU`w6dt>J=T_Xg*!cr(P~e;RZI$H z^96exlCxYmore9Oq{{^NgS(@VXeu8%Sv-vN17A4Hg3@*kQzSvE(5LQo2JWKp5Jfu)rbAEjB1v&g zLeGAJ+)p2>A=B&V>&ZIv5)L0qahvL{KKT&xq2V&rX~ACh*T$&+Zq>H+Ub>7EZpjHtz-H~PiItDGFr`#mpM2w9Ie^s>~jC`H;BAEEyOlc1ZZH+ zjS2#zq5leoCggBSPju0K)DM-@cy^Z5+C|RGg;%be_L$@=H5OkuaeGsDn`7Z z%nD|ai<2DEk_k>I!4wFv6||baV9?T8jfB=(kw>4Y@Q0k8`F63TX~Z}H_^S1t=}V9+ z>iVfjr`{YwA;h3oJj-~nd~$;Nj}@#rQr=K%S@ia)sye;#Y)NAW^5eNBVZz;(Yhxyb zPyyi#E+H$UtQB>}9n2sgxq@;RFUsC0>ty)XgXYEK_hzeTIgT@Imx1}Yh5ie;WIemB zt`ebqY~1hcGO#2#d5#_Fq+;}AVB3R%Mk^GzHFawNX}#0grCd{tM+tE(2AumBlL`HW z=U6%LVCpxMs|8wuRshfhOIB~r!)Cv?=>nYjDZDDwsMwN zGO+0`s>Si`mgJivTfX(bx8Q=Zup}7hk?3~5&OT%f$?L7zTPurVaj{?ln%qw29XvLi z$mzb6{r$vXK1%Ag%%j_DjiZ834idu#6nVkG@+5_;>OJ?MySbGjr!xGFntv38dFjk8 zp7p*3_L7MaDZtD@FRmq3{Y3w$C));136#c%Gt25+}U2lE8w&Mq{-CenE zy_0G@u+)hoBCiBcqN)`9-G*e!(=S)4MfM+hB1(_#5F_&9#}`-^=4^u~6|KcBriYg8 z4-5lK_H@Sg_in1^MFLCqCIU#&;m=sagwgiM{Db~^C*!=M)+IQ~_j)niYUk~?q^gDI oZtXo$Xw9sQvt*gf{aVy+f6*@$xa_~{q57|QPiFVu>(kovUjW+IkN^Mx literal 0 HcmV?d00001 diff --git a/docs/source/recipes/Streaming-ASR/librispeech/index.rst b/docs/source/recipes/Streaming-ASR/librispeech/index.rst new file mode 100644 index 000000000..546ce168b --- /dev/null +++ b/docs/source/recipes/Streaming-ASR/librispeech/index.rst @@ -0,0 +1,9 @@ +LibriSpeech +=========== + +.. toctree:: + :maxdepth: 1 + + pruned_transducer_stateless + + lstm_pruned_stateless_transducer diff --git a/docs/source/recipes/librispeech/lstm_pruned_stateless_transducer.rst b/docs/source/recipes/Streaming-ASR/librispeech/lstm_pruned_stateless_transducer.rst similarity index 100% rename from docs/source/recipes/librispeech/lstm_pruned_stateless_transducer.rst rename to docs/source/recipes/Streaming-ASR/librispeech/lstm_pruned_stateless_transducer.rst diff --git a/docs/source/recipes/Streaming-ASR/librispeech/pruned_transducer_stateless.rst b/docs/source/recipes/Streaming-ASR/librispeech/pruned_transducer_stateless.rst new file mode 100644 index 000000000..de7102ba8 --- /dev/null +++ b/docs/source/recipes/Streaming-ASR/librispeech/pruned_transducer_stateless.rst @@ -0,0 +1,735 @@ +Pruned transducer statelessX +============================ + +This tutorial shows you how to run a **streaming** conformer transducer model +with the `LibriSpeech `_ dataset. + +.. Note:: + + The tutorial is suitable for `pruned_transducer_stateless `_, + `pruned_transducer_stateless2 `_, + `pruned_transducer_stateless4 `_, + `pruned_transducer_stateless5 `_, + We will take pruned_transducer_stateless4 as an example in this tutorial. + +.. HINT:: + + We assume you have read the page :ref:`install icefall` and have setup + the environment for ``icefall``. + +.. HINT:: + + We recommend you to use a GPU or several GPUs to run this recipe. + +.. hint:: + + Please scroll down to the bottom of this page to find download links + for pretrained models if you don't want to train a model from scratch. + + +We use pruned RNN-T to compute the loss. + +.. note:: + + You can find the paper about pruned RNN-T at the following address: + + ``_ + +The transducer model consists of 3 parts: + + - Encoder, a.k.a, the transcription network. We use a Conformer model (the reworked version by Daniel Povey) + - Decoder, a.k.a, the prediction network. We use a stateless model consisting of + ``nn.Embedding`` and ``nn.Conv1d`` + - Joiner, a.k.a, the joint network. + +.. caution:: + + Contrary to the conventional RNN-T models, we use a stateless decoder. + That is, it has no recurrent connections. + + +Data preparation +---------------- + +.. hint:: + + The data preparation is the same as other recipes on LibriSpeech dataset, + if you have finished this step, you can skip to ``Training`` directly. + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./prepare.sh + +The script ``./prepare.sh`` handles the data preparation for you, **automagically**. +All you need to do is to run it. + +The data preparation contains several stages, you can use the following two +options: + + - ``--stage`` + - ``--stop-stage`` + +to control which stage(s) should be run. By default, all stages are executed. + + +For example, + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./prepare.sh --stage 0 --stop-stage 0 + +means to run only stage 0. + +To run stage 2 to stage 5, use: + +.. code-block:: bash + + $ ./prepare.sh --stage 2 --stop-stage 5 + +.. HINT:: + + If you have pre-downloaded the `LibriSpeech `_ + dataset and the `musan `_ dataset, say, + they are saved in ``/tmp/LibriSpeech`` and ``/tmp/musan``, you can modify + the ``dl_dir`` variable in ``./prepare.sh`` to point to ``/tmp`` so that + ``./prepare.sh`` won't re-download them. + +.. NOTE:: + + All generated files by ``./prepare.sh``, e.g., features, lexicon, etc, + are saved in ``./data`` directory. + +We provide the following YouTube video showing how to run ``./prepare.sh``. + +.. note:: + + To get the latest news of `next-gen Kaldi `_, please subscribe + the following YouTube channel by `Nadira Povey `_: + + ``_ + +.. youtube:: ofEIoJL-mGM + + +Training +-------- + +.. NOTE:: + + We put the streaming and non-streaming model in one recipe, to train a streaming model you only + need to add **4** extra options comparing with training a non-streaming model. These options are + ``--dynamic-chunk-training``, ``--num-left-chunks``, ``--causal-convolution``, ``--short-chunk-size``. + You can see the configurable options below for their meanings or read https://arxiv.org/pdf/2012.05481.pdf for more details. + +Configurable options +~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./pruned_transducer_stateless4/train.py --help + + +shows you the training options that can be passed from the commandline. +The following options are used quite often: + + - ``--exp-dir`` + + The directory to save checkpoints, training logs and tensorboard. + + - ``--full-libri`` + + If it's True, the training part uses all the training data, i.e., + 960 hours. Otherwise, the training part uses only the subset + ``train-clean-100``, which has 100 hours of training data. + + .. CAUTION:: + The training set is perturbed by speed with two factors: 0.9 and 1.1. + If ``--full-libri`` is True, each epoch actually processes + ``3x960 == 2880`` hours of data. + + - ``--num-epochs`` + + It is the number of epochs to train. For instance, + ``./pruned_transducer_stateless4/train.py --num-epochs 30`` trains for 30 epochs + and generates ``epoch-1.pt``, ``epoch-2.pt``, ..., ``epoch-30.pt`` + in the folder ``./pruned_transducer_stateless4/exp``. + + - ``--start-epoch`` + + It's used to resume training. + ``./pruned_transducer_stateless4/train.py --start-epoch 10`` loads the + checkpoint ``./pruned_transducer_stateless4/exp/epoch-9.pt`` and starts + training from epoch 10, based on the state from epoch 9. + + - ``--world-size`` + + It is used for multi-GPU single-machine DDP training. + + - (a) If it is 1, then no DDP training is used. + + - (b) If it is 2, then GPU 0 and GPU 1 are used for DDP training. + + The following shows some use cases with it. + + **Use case 1**: You have 4 GPUs, but you only want to use GPU 0 and + GPU 2 for training. You can do the following: + + .. code-block:: bash + + $ cd egs/librispeech/ASR + $ export CUDA_VISIBLE_DEVICES="0,2" + $ ./pruned_transducer_stateless4/train.py --world-size 2 + + **Use case 2**: You have 4 GPUs and you want to use all of them + for training. You can do the following: + + .. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./pruned_transducer_stateless4/train.py --world-size 4 + + **Use case 3**: You have 4 GPUs but you only want to use GPU 3 + for training. You can do the following: + + .. code-block:: bash + + $ cd egs/librispeech/ASR + $ export CUDA_VISIBLE_DEVICES="3" + $ ./pruned_transducer_stateless4/train.py --world-size 1 + + .. caution:: + + Only multi-GPU single-machine DDP training is implemented at present. + Multi-GPU multi-machine DDP training will be added later. + + - ``--max-duration`` + + It specifies the number of seconds over all utterances in a + batch, before **padding**. + If you encounter CUDA OOM, please reduce it. + + .. HINT:: + + Due to padding, the number of seconds of all utterances in a + batch will usually be larger than ``--max-duration``. + + A larger value for ``--max-duration`` may cause OOM during training, + while a smaller value may increase the training time. You have to + tune it. + + - ``--use-fp16`` + + If it is True, the model will train with half precision, from our experiment + results, by using half precision you can train with two times larger ``--max-duration`` + so as to get almost 2X speed up. + + - ``--dynamic-chunk-training`` + + The flag that indicates whether to train a streaming model or not, it + **MUST** be True if you want to train a streaming model. + + - ``--short-chunk-size`` + + When training a streaming attention model with chunk masking, the chunk size + would be either max sequence length of current batch or uniformly sampled from + (1, short_chunk_size). The default value is 25, you don't have to change it most of the time. + + - ``--num-left-chunks`` + + It indicates how many left context (in chunks) that can be seen when calculating attention. + The default value is 4, you don't have to change it most of the time. + + + - ``--causal-convolution`` + + Whether to use causal convolution in conformer encoder layer, this requires + to be True when training a streaming model. + + +Pre-configured options +~~~~~~~~~~~~~~~~~~~~~~ + +There are some training options, e.g., number of encoder layers, +encoder dimension, decoder dimension, number of warmup steps etc, +that are not passed from the commandline. +They are pre-configured by the function ``get_params()`` in +`pruned_transducer_stateless4/train.py `_ + +You don't need to change these pre-configured parameters. If you really need to change +them, please modify ``./pruned_transducer_stateless4/train.py`` directly. + + +.. NOTE:: + + The options for `pruned_transducer_stateless5 `_ are a little different from + other recipes. It allows you to configure ``--num-encoder-layers``, ``--dim-feedforward``, ``--nhead``, ``--encoder-dim``, ``--decoder-dim``, ``--joiner-dim`` from commandline, so that you can train models with different size with pruned_transducer_stateless5. + + +Training logs +~~~~~~~~~~~~~ + +Training logs and checkpoints are saved in ``--exp-dir`` (e.g. ``pruned_transducer_stateless4/exp``. +You will find the following files in that directory: + + - ``epoch-1.pt``, ``epoch-2.pt``, ... + + These are checkpoint files saved at the end of each epoch, containing model + ``state_dict`` and optimizer ``state_dict``. + To resume training from some checkpoint, say ``epoch-10.pt``, you can use: + + .. code-block:: bash + + $ ./pruned_transducer_stateless4/train.py --start-epoch 11 + + - ``checkpoint-436000.pt``, ``checkpoint-438000.pt``, ... + + These are checkpoint files saved every ``--save-every-n`` batches, + containing model ``state_dict`` and optimizer ``state_dict``. + To resume training from some checkpoint, say ``checkpoint-436000``, you can use: + + .. code-block:: bash + + $ ./pruned_transducer_stateless4/train.py --start-batch 436000 + + - ``tensorboard/`` + + This folder contains tensorBoard logs. Training loss, validation loss, learning + rate, etc, are recorded in these logs. You can visualize them by: + + .. code-block:: bash + + $ cd pruned_transducer_stateless4/exp/tensorboard + $ tensorboard dev upload --logdir . --description "pruned transducer training for LibriSpeech with icefall" + + It will print something like below: + + .. code-block:: + + TensorFlow installation not found - running with reduced feature set. + Upload started and will continue reading any new data as it's added to the logdir. + + To stop uploading, press Ctrl-C. + + New experiment created. View your TensorBoard at: https://tensorboard.dev/experiment/97VKXf80Ru61CnP2ALWZZg/ + + [2022-11-20T15:50:50] Started scanning logdir. + Uploading 4468 scalars... + [2022-11-20T15:53:02] Total uploaded: 210171 scalars, 0 tensors, 0 binary objects + Listening for new data in logdir... + + Note there is a URL in the above output. Click it and you will see + the following screenshot: + + .. figure:: images/streaming-librispeech-pruned-transducer-tensorboard-log.jpg + :width: 600 + :alt: TensorBoard screenshot + :align: center + :target: https://tensorboard.dev/experiment/97VKXf80Ru61CnP2ALWZZg/ + + TensorBoard screenshot. + + .. hint:: + + If you don't have access to google, you can use the following command + to view the tensorboard log locally: + + .. code-block:: bash + + cd pruned_transducer_stateless4/exp/tensorboard + tensorboard --logdir . --port 6008 + + It will print the following message: + + .. code-block:: + + Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all + TensorBoard 2.8.0 at http://localhost:6008/ (Press CTRL+C to quit) + + Now start your browser and go to ``_ to view the tensorboard + logs. + + + - ``log/log-train-xxxx`` + + It is the detailed training log in text format, same as the one + you saw printed to the console during training. + +Usage example +~~~~~~~~~~~~~ + +You can use the following command to start the training using 4 GPUs: + +.. code-block:: bash + + export CUDA_VISIBLE_DEVICES="0,1,2,3" + ./pruned_transducer_stateless4/train.py \ + --world-size 4 \ + --dynamic-chunk-training 1 \ + --causal-convolution 1 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless4/exp \ + --full-libri 1 \ + --max-duration 300 + +.. NOTE:: + + Comparing with training a non-streaming model, you only need to add two extra options, + ``--dynamic-chunk-training 1`` and ``--causal-convolution 1`` . + + +Decoding +-------- + +The decoding part uses checkpoints saved by the training part, so you have +to run the training part first. + +.. hint:: + + There are two kinds of checkpoints: + + - (1) ``epoch-1.pt``, ``epoch-2.pt``, ..., which are saved at the end + of each epoch. You can pass ``--epoch`` to + ``pruned_transducer_stateless4/decode.py`` to use them. + + - (2) ``checkpoints-436000.pt``, ``epoch-438000.pt``, ..., which are saved + every ``--save-every-n`` batches. You can pass ``--iter`` to + ``pruned_transducer_stateless4/decode.py`` to use them. + + We suggest that you try both types of checkpoints and choose the one + that produces the lowest WERs. + +.. tip:: + + To decode a streaming model, you can use either ``simulate streaming decoding`` in ``decode.py`` or + ``real streaming decoding`` in ``streaming_decode.py``, the difference between ``decode.py`` and + ``streaming_decode.py`` is that, ``decode.py`` processes the whole acoustic frames at one time with masking (i.e. same as training), + but ``streaming_decode.py`` processes the acoustic frames chunk by chunk (so it can only see limited context). + +.. NOTE:: + + ``simulate streaming decoding`` in ``decode.py`` and ``real streaming decoding`` in ``streaming_decode.py`` should + produce almost the same results given the same ``--decode-chunk-size`` and ``--left-context``. + + +Simulate streaming decoding +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./pruned_transducer_stateless4/decode.py --help + +shows the options for decoding. +The following options are important for streaming models: + + ``--simulate-streaming`` + + If you want to decode a streaming model with ``decode.py``, you **MUST** set + ``--simulate-streaming`` to ``True``. ``simulate`` here means the acoustic frames + are not processed frame by frame (or chunk by chunk), instead, the whole sequence + is processed at one time with masking (the same as training). + + ``--causal-convolution`` + + If True, the convolution module in encoder layers will be causal convolution. + This is **MUST** be True when decoding with a streaming model. + + ``--decode-chunk-size`` + + For streaming models, we will calculate the chunk-wise attention, ``--decode-chunk-size`` + indicates the chunk length (in frames after subsampling) for chunk-wise attention. + For ``simulate streaming decoding`` the ``decode-chunk-size`` is used to generate + the attention mask. + + ``--left-context`` + + ``--left-context`` indicates how many left context frames (after subsampling) can be seen + for current chunk when calculating chunk-wise attention. Normally, ``left-context`` should equal + to ``decode-chunk-size * num-left-chunks``, where ``num-left-chunks`` is the option used + to train this model. For ``simulate streaming decoding`` the ``left-context`` is used to generate + the attention mask. + + +The following shows two examples (for the two types of checkpoints): + +.. code-block:: bash + + for m in greedy_search fast_beam_search modified_beam_search; do + for epoch in 25 20; do + for avg in 7 5 3 1; do + ./pruned_transducer_stateless4/decode.py \ + --epoch $epoch \ + --avg $avg \ + --simulate-streaming 1 \ + --causal-convolution 1 \ + --decode-chunk-size 16 \ + --left-context 64 \ + --exp-dir pruned_transducer_stateless4/exp \ + --max-duration 600 \ + --decoding-method $m + done + done + done + + +.. code-block:: bash + + for m in greedy_search fast_beam_search modified_beam_search; do + for iter in 474000; do + for avg in 8 10 12 14 16 18; do + ./pruned_transducer_stateless4/decode.py \ + --iter $iter \ + --avg $avg \ + --simulate-streaming 1 \ + --causal-convolution 1 \ + --decode-chunk-size 16 \ + --left-context 64 \ + --exp-dir pruned_transducer_stateless4/exp \ + --max-duration 600 \ + --decoding-method $m + done + done + done + + +Real streaming decoding +~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./pruned_transducer_stateless4/streaming_decode.py --help + +shows the options for decoding. +The following options are important for streaming models: + + ``--decode-chunk-size`` + + For streaming models, we will calculate the chunk-wise attention, ``--decode-chunk-size`` + indicates the chunk length (in frames after subsampling) for chunk-wise attention. + For ``real streaming decoding``, we will process ``decode-chunk-size`` acoustic frames at each time. + + ``--left-context`` + + ``--left-context`` indicates how many left context frames (after subsampling) can be seen + for current chunk when calculating chunk-wise attention. Normally, ``left-context`` should equal + to ``decode-chunk-size * num-left-chunks``, where ``num-left-chunks`` is the option used + to train this model. + + ``--num-decode-streams`` + + The number of decoding streams that can be run in parallel (very similar to the ``bath size``). + For ``real streaming decoding``, the batches will be packed dynamically, for example, if the + ``num-decode-streams`` equals to 10, then, sequence 1 to 10 will be decoded at first, after a while, + suppose sequence 1 and 2 are done, so, sequence 3 to 12 will be processed parallelly in a batch. + + +.. NOTE:: + + We also try adding ``--right-context`` in the real streaming decoding, but it seems not to benefit + the performance for all the models, the reasons might be the training and decoding mismatch. You + can try decoding with ``--right-context`` to see if it helps. The default value is 0. + + +The following shows two examples (for the two types of checkpoints): + +.. code-block:: bash + + for m in greedy_search fast_beam_search modified_beam_search; do + for epoch in 25 20; do + for avg in 7 5 3 1; do + ./pruned_transducer_stateless4/decode.py \ + --epoch $epoch \ + --avg $avg \ + --decode-chunk-size 16 \ + --left-context 64 \ + --num-decode-streams 100 \ + --exp-dir pruned_transducer_stateless4/exp \ + --max-duration 600 \ + --decoding-method $m + done + done + done + + +.. code-block:: bash + + for m in greedy_search fast_beam_search modified_beam_search; do + for iter in 474000; do + for avg in 8 10 12 14 16 18; do + ./pruned_transducer_stateless4/decode.py \ + --iter $iter \ + --avg $avg \ + --decode-chunk-size 16 \ + --left-context 64 \ + --num-decode-streams 100 \ + --exp-dir pruned_transducer_stateless4/exp \ + --max-duration 600 \ + --decoding-method $m + done + done + done + + +.. tip:: + + Supporting decoding methods are as follows: + + - ``greedy_search`` : It takes the symbol with largest posterior probability + of each frame as the decoding result. + + - ``beam_search`` : It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf and + `espnet/nets/beam_search_transducer.py `_ + is used as a reference. Basicly, it keeps topk states for each frame, and expands the kept states with their own contexts to + next frame. + + - ``modified_beam_search`` : It implements the same algorithm as ``beam_search`` above, but it + runs in batch mode with ``--max-sym-per-frame=1`` being hardcoded. + + - ``fast_beam_search`` : It implements graph composition between the output ``log_probs`` and + given ``FSAs``. It is hard to describe the details in several lines of texts, you can read + our paper in https://arxiv.org/pdf/2211.00484.pdf or our `rnnt decode code in k2 `_. ``fast_beam_search`` can decode with ``FSAs`` on GPU efficiently. + + - ``fast_beam_search_LG`` : The same as ``fast_beam_search`` above, ``fast_beam_search`` uses + an trivial graph that has only one state, while ``fast_beam_search_LG`` uses an LG graph + (with N-gram LM). + + - ``fast_beam_search_nbest`` : It produces the decoding results as follows: + + - (1) Use ``fast_beam_search`` to get a lattice + - (2) Select ``num_paths`` paths from the lattice using ``k2.random_paths()`` + - (3) Unique the selected paths + - (4) Intersect the selected paths with the lattice and compute the + shortest path from the intersection result + - (5) The path with the largest score is used as the decoding output. + + - ``fast_beam_search_nbest_LG`` : It implements same logic as ``fast_beam_search_nbest``, the + only difference is that it uses ``fast_beam_search_LG`` to generate the lattice. + +.. NOTE:: + + The supporting decoding methods in ``streaming_decode.py`` might be less than that in ``decode.py``, if needed, + you can implement them by yourself or file a issue in `icefall `_ . + + +Export Model +------------ + +`pruned_transducer_stateless4/export.py `_ supports exporting checkpoints from ``pruned_transducer_stateless4/exp`` in the following ways. + +Export ``model.state_dict()`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Checkpoints saved by ``pruned_transducer_stateless4/train.py`` also include +``optimizer.state_dict()``. It is useful for resuming training. But after training, +we are interested only in ``model.state_dict()``. You can use the following +command to extract ``model.state_dict()``. + +.. code-block:: bash + + # Assume that --epoch 25 --avg 3 produces the smallest WER + # (You can get such information after running ./pruned_transducer_stateless4/decode.py) + + epoch=25 + avg=3 + + ./pruned_transducer_stateless4/export.py \ + --exp-dir ./pruned_transducer_stateless4/exp \ + --streaming-model 1 \ + --causal-convolution 1 \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch $epoch \ + --avg $avg + +.. caution:: + + ``--streaming-model`` and ``--causal-convolution`` require to be True to export + a streaming mdoel. + +It will generate a file ``./pruned_transducer_stateless4/exp/pretrained.pt``. + +.. hint:: + + To use the generated ``pretrained.pt`` for ``pruned_transducer_stateless4/decode.py``, + you can run: + + .. code-block:: bash + + cd pruned_transducer_stateless4/exp + ln -s pretrained.pt epoch-999.pt + + And then pass ``--epoch 999 --avg 1 --use-averaged-model 0`` to + ``./pruned_transducer_stateless4/decode.py``. + +To use the exported model with ``./pruned_transducer_stateless4/pretrained.py``, you +can run: + +.. code-block:: bash + + ./pruned_transducer_stateless4/pretrained.py \ + --checkpoint ./pruned_transducer_stateless4/exp/pretrained.pt \ + --simulate-streaming 1 \ + --causal-convolution 1 \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + + +Export model using ``torch.jit.script()`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: bash + + ./pruned_transducer_stateless4/export.py \ + --exp-dir ./pruned_transducer_stateless4/exp \ + --streaming-model 1 \ + --causal-convolution 1 \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 25 \ + --avg 3 \ + --jit 1 + +.. caution:: + + ``--streaming-model`` and ``--causal-convolution`` require to be True to export + a streaming mdoel. + +It will generate a file ``cpu_jit.pt`` in the given ``exp_dir``. You can later +load it by ``torch.jit.load("cpu_jit.pt")``. + +Note ``cpu`` in the name ``cpu_jit.pt`` means the parameters when loaded into Python +are on CPU. You can use ``to("cuda")`` to move them to a CUDA device. + +.. NOTE:: + + You will need this ``cpu_jit.pt`` when deploying with Sherpa framework. + + +Download pretrained models +-------------------------- + +If you don't want to train from scratch, you can download the pretrained models +by visiting the following links: + + - `pruned_transducer_stateless `_ + + - `pruned_transducer_stateless2 `_ + + - `pruned_transducer_stateless4 `_ + + - `pruned_transducer_stateless5 `_ + + See ``_ + for the details of the above pretrained models + + +Deploy with Sherpa +------------------ + +Please see ``_ +for how to deploy the models in ``sherpa``. diff --git a/docs/source/recipes/index.rst b/docs/source/recipes/index.rst index 9d1d83d29..63793275c 100644 --- a/docs/source/recipes/index.rst +++ b/docs/source/recipes/index.rst @@ -13,7 +13,5 @@ We may add recipes for other tasks as well in the future. :maxdepth: 2 :caption: Table of Contents - aishell/index - librispeech/index - timit/index - yesno/index + Non-streaming-ASR/index + Streaming-ASR/index From 6d659f423dbb67b20309e00d5885b76c5dfd15e8 Mon Sep 17 00:00:00 2001 From: kobenaxie <572745565@qq.com> Date: Thu, 15 Dec 2022 20:42:07 +0800 Subject: [PATCH 051/174] delete duplicate line for encoder initial state (#765) --- .../ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py | 1 - 1 file changed, 1 deletion(-) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py index 716de5734..64c16141c 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py @@ -152,7 +152,6 @@ def export_encoder_model_jit_trace( x = torch.zeros(1, T, 80, dtype=torch.float32) states = encoder_model.init_states() - states = encoder_model.init_states() traced_model = torch.jit.trace(encoder_model, (x, states)) traced_model.save(encoder_filename) From fbc1d3b194cfb2be4d01de85d9c3a3ea13c961fb Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Sat, 17 Dec 2022 22:03:13 +0800 Subject: [PATCH 052/174] fix src_key_padding_mask in DownsampledZipformerEncoder (#768) --- egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index ed1e2efa2..71f12e44a 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -741,7 +741,7 @@ class DownsampledZipformerEncoder(nn.Module): src, feature_mask=feature_mask, mask=mask, - src_key_padding_mask=mask, + src_key_padding_mask=src_key_padding_mask, ) src = self.upsample(src) # remove any extra frames that are not a multiple of downsample_factor From 65d7192dca03ba21bff4270add3891c9730491a7 Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Mon, 19 Dec 2022 20:10:39 +0800 Subject: [PATCH 053/174] Fix zipformer attn_output_weights (#774) * fix attn_output_weights * remove in-place op --- .../pruned_transducer_stateless7/zipformer.py | 34 +++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index 71f12e44a..ad3b88df0 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -1291,9 +1291,11 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask is not None: if attn_mask.dtype == torch.bool: - attn_output_weights.masked_fill_(attn_mask, float("-inf")) + attn_output_weights = attn_output_weights.masked_fill( + attn_mask, float("-inf") + ) else: - attn_output_weights += attn_mask + attn_output_weights = attn_output_weights + attn_mask if key_padding_mask is not None: attn_output_weights = attn_output_weights.view( @@ -1313,6 +1315,34 @@ class RelPositionMultiheadAttention(nn.Module): # only storing the half-precision output for backprop purposes. attn_output_weights = softmax(attn_output_weights, dim=-1) + # If we are using chunk-wise attention mask and setting a limited + # num_left_chunks, the attention may only see the padding values which + # will also be masked out by `key_padding_mask`. At this circumstances, + # the whole column of `attn_output_weights` will be `-inf` + # (i.e. be `nan` after softmax). So we fill `0.0` at the masking + # positions to avoid invalid loss value below. + if ( + attn_mask is not None + and attn_mask.dtype == torch.bool + and key_padding_mask is not None + ): + if attn_mask.size(0) != 1: + attn_mask = attn_mask.view(bsz, num_heads, seq_len, seq_len) + combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2) + else: + # attn_mask.shape == (1, tgt_len, src_len) + combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze( + 1 + ).unsqueeze(2) + + attn_output_weights = attn_output_weights.view( + bsz, num_heads, seq_len, seq_len + ) + attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, seq_len, seq_len + ) + attn_output_weights = nn.functional.dropout( attn_output_weights, p=dropout_p, training=training ) From 070c77e724d4da91900925c44b237523d97f9f08 Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Wed, 21 Dec 2022 17:41:31 +0800 Subject: [PATCH 054/174] Add Blankskip to Zipformer+CTC (#730) * init files * add ctc as auxiliary loss and ctc_decode.py * tuning the scalar of HLG score for 1best, nbest and nbest-oracle * rename to pruned_transducer_stateless7_ctc * fix doc * fix bug, recover the hlg scores * modify ctc_decode.py, move out the hlg scale * fix hlg_scale * add export.py and pretrained.py, and so on * upload files, update README.md and RESULTS.md * add CI test * update .gitignore * create symlinks * Add Blank Skip to Zipformer+CTC * Add warmup to blank skip * Add warmup to blank skip * Add __init__.py * Add parameters_names to Adam * Add warmup to blank skip * Modify frame_reducer * Modify frame_reducer * Add Blank Skip to decode. * Add ctc_decode.py * Add blank skip to Zipformer+CTC * process conflict * process conflict * modify ctc_guild_decode_bk.py * modify Lconv * produce the conflict * Add export.py * finish export * fix for running black * Add ci test * Add ci-test * chmod * chmod * fix bug for ci-test * fix bug for ci-test * fix bug for ci-test * rename the dirname * rename the dirname * change dirname * change dirname * fix notes * add pretrained.py * add pretrained.py * add pretrained.py * add pretrained.py * add pretrained.py * add pretrained.py * fix * fix * fix * finished * add the Copyright info and notes Co-authored-by: Zengwei Yao Co-authored-by: yifanyang --- ...ed-transducer-stateless7-ctc-2022-12-01.sh | 2 +- ...transducer-stateless7-ctc-bs-2022-12-15.sh | 148 ++ ...brispeech-2022-12-15-stateless7-ctc-bs.yml | 163 +++ .gitignore | 1 + egs/gigaspeech/ASR/.gitignore | 1 + egs/librispeech/ASR/.gitignore | 1 + .../ASR/pruned_transducer_stateless7/optim.py | 2 +- .../jit_pretrained_ctc.py | 8 +- .../__init__.py | 0 .../asr_datamodule.py | 1 + .../beam_search.py | 1 + .../ctc_decode.py | 809 +++++++++++ .../ctc_guild_decode_bs.py | 857 +++++++++++ .../decode.py | 841 +++++++++++ .../decoder.py | 1 + .../encoder_interface.py | 1 + .../export.py | 319 +++++ .../frame_reducer.py | 84 ++ .../jit_pretrained.py | 271 ++++ .../jit_pretrained_ctc.py | 426 ++++++ .../joiner.py | 1 + .../lconv.py | 114 ++ .../model.py | 224 +++ .../optim.py | 1 + .../pretrained.py | 352 +++++ .../pretrained_ctc.py | 440 ++++++ .../scaling.py | 1 + .../scaling_converter.py | 1 + .../test_model.py | 55 + .../train.py | 1251 +++++++++++++++++ .../zipformer.py | 1 + 31 files changed, 6372 insertions(+), 6 deletions(-) create mode 100755 .github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2022-12-15.sh create mode 100644 .github/workflows/run-librispeech-2022-12-15-stateless7-ctc-bs.yml create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/__init__.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/asr_datamodule.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/beam_search.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decode.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decoder.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/encoder_interface.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/joiner.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/lconv.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/optim.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/scaling.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/scaling_converter.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/test_model.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/zipformer.py diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh index e081c9374..3cbb480f6 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh @@ -148,4 +148,4 @@ if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == done rm pruned_transducer_stateless7_ctc/exp/*.pt -fi +fi \ No newline at end of file diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2022-12-15.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2022-12-15.sh new file mode 100755 index 000000000..ed66a728e --- /dev/null +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2022-12-15.sh @@ -0,0 +1,148 @@ +#!/usr/bin/env bash + +set -e + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/yfyeung/icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2022-12-14 + +log "Downloading pre-trained model from $repo_url" +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +log "Display test files" +tree $repo/ +soxi $repo/test_wavs/*.wav +ls -lh $repo/test_wavs/*.wav + +pushd $repo/exp +git lfs pull --include "data/lang_bpe_500/HLG.pt" +git lfs pull --include "data/lang_bpe_500/L.pt" +git lfs pull --include "data/lang_bpe_500/LG.pt" +git lfs pull --include "data/lang_bpe_500/Linv.pt" +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/cpu_jit.pt" +git lfs pull --include "exp/pretrained.pt" +ln -s pretrained.pt epoch-99.pt +ls -lh *.pt +popd + +log "Export to torchscript model" +./pruned_transducer_stateless7_ctc_bs/export.py \ + --exp-dir $repo/exp \ + --use-averaged-model false \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 99 \ + --avg 1 \ + --jit 1 + +ls -lh $repo/exp/*.pt + +log "Decode with models exported by torch.jit.script()" + +./pruned_transducer_stateless7_ctc_bs/jit_pretrained.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --nn-model-filename $repo/exp/cpu_jit.pt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + +for m in ctc-decoding 1best; do + ./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \ + --model-filename $repo/exp/cpu_jit.pt \ + --words-file $repo/data/lang_bpe_500/words.txt \ + --HLG $repo/data/lang_bpe_500/HLG.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --method $m \ + --sample-rate 16000 \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done + +for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./pruned_transducer_stateless7_ctc_bs/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done + +for method in modified_beam_search beam_search fast_beam_search; do + log "$method" + + ./pruned_transducer_stateless7_ctc_bs/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done + +for m in ctc-decoding 1best; do + ./pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py \ + --checkpoint $repo/exp/pretrained.pt \ + --words-file $repo/data/lang_bpe_500/words.txt \ + --HLG $repo/data/lang_bpe_500/HLG.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --method $m \ + --sample-rate 16000 \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done + +echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" +echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" + +if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then + mkdir -p pruned_transducer_stateless7_ctc_bs/exp + ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless7_ctc_bs/exp/epoch-999.pt + ln -s $PWD/$repo/data/lang_bpe_500 data/ + + ls -lh data + ls -lh pruned_transducer_stateless7_ctc_bs/exp + + log "Decoding test-clean and test-other" + + # use a small value for decoding with CPU + max_duration=100 + + for method in greedy_search fast_beam_search modified_beam_search; do + log "Decoding with $method" + + ./pruned_transducer_stateless7_ctc_bs/decode.py \ + --decoding-method $method \ + --epoch 999 \ + --avg 1 \ + --use-averaged-model 0 \ + --max-duration $max_duration \ + --exp-dir pruned_transducer_stateless7_ctc_bs/exp + done + + for m in ctc-decoding 1best; do + ./pruned_transducer_stateless7_ctc_bs/ctc_decode.py \ + --epoch 999 \ + --avg 1 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration $max_duration \ + --use-averaged-model 0 \ + --decoding-method $m \ + --hlg-scale 0.6 + done + + rm pruned_transducer_stateless7_ctc_bs/exp/*.pt +fi diff --git a/.github/workflows/run-librispeech-2022-12-15-stateless7-ctc-bs.yml b/.github/workflows/run-librispeech-2022-12-15-stateless7-ctc-bs.yml new file mode 100644 index 000000000..6e2b40cf3 --- /dev/null +++ b/.github/workflows/run-librispeech-2022-12-15-stateless7-ctc-bs.yml @@ -0,0 +1,163 @@ +# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com) + +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: run-librispeech-2022-12-15-stateless7-ctc-bs +# zipformer + +on: + push: + branches: + - master + pull_request: + types: [labeled] + + schedule: + # minute (0-59) + # hour (0-23) + # day of the month (1-31) + # month (1-12) + # day of the week (0-6) + # nightly build at 15:50 UTC time every day + - cron: "50 15 * * *" + +jobs: + run_librispeech_2022_12_15_zipformer_ctc_bs: + if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event.label.name == 'blank-skip' || github.event_name == 'push' || github.event_name == 'schedule' + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python-version: [3.8] + + fail-fast: false + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + cache-dependency-path: '**/requirements-ci.txt' + + - name: Install Python dependencies + run: | + grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install + pip uninstall -y protobuf + pip install --no-binary protobuf protobuf + + - name: Cache kaldifeat + id: my-cache + uses: actions/cache@v2 + with: + path: | + ~/tmp/kaldifeat + key: cache-tmp-${{ matrix.python-version }}-2022-09-25 + + - name: Install kaldifeat + if: steps.my-cache.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/install-kaldifeat.sh + + - name: Cache LibriSpeech test-clean and test-other datasets + id: libri-test-clean-and-test-other-data + uses: actions/cache@v2 + with: + path: | + ~/tmp/download + key: cache-libri-test-clean-and-test-other + + - name: Download LibriSpeech test-clean and test-other + if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh + + - name: Prepare manifests for LibriSpeech test-clean and test-other + shell: bash + run: | + .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh + + - name: Cache LibriSpeech test-clean and test-other fbank features + id: libri-test-clean-and-test-other-fbank + uses: actions/cache@v2 + with: + path: | + ~/tmp/fbank-libri + key: cache-libri-fbank-test-clean-and-test-other-v2 + + - name: Compute fbank for LibriSpeech test-clean and test-other + if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh + + - name: Inference with pre-trained model + shell: bash + env: + GITHUB_EVENT_NAME: ${{ github.event_name }} + GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} + run: | + mkdir -p egs/librispeech/ASR/data + ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank + ls -lh egs/librispeech/ASR/data/* + + sudo apt-get -qq install git-lfs tree sox + export PYTHONPATH=$PWD:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH + + .github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2022-12-15.sh + + - name: Display decoding results for librispeech pruned_transducer_stateless7_ctc_bs + if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' + shell: bash + run: | + cd egs/librispeech/ASR/ + tree ./pruned_transducer_stateless7_ctc_bs/exp + + cd pruned_transducer_stateless7_ctc_bs + echo "results for pruned_transducer_stateless7_ctc_bs" + echo "===greedy search===" + find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===fast_beam_search===" + find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===modified beam search===" + find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===ctc decoding===" + find exp/ctc-decoding -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/ctc-decoding -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===1best===" + find exp/1best -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/1best -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + - name: Upload decoding results for librispeech pruned_transducer_stateless7_ctc_bs + uses: actions/upload-artifact@v2 + if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' + with: + name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless7-ctc-bs-2022-12-15 + path: egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/exp/ diff --git a/.gitignore b/.gitignore index 583410f45..8af05d884 100644 --- a/.gitignore +++ b/.gitignore @@ -33,3 +33,4 @@ node_modules *.param *.bin +.DS_Store diff --git a/egs/gigaspeech/ASR/.gitignore b/egs/gigaspeech/ASR/.gitignore index 5592679cc..8dec2d86d 100644 --- a/egs/gigaspeech/ASR/.gitignore +++ b/egs/gigaspeech/ASR/.gitignore @@ -1 +1,2 @@ log-* +.DS_Store \ No newline at end of file diff --git a/egs/librispeech/ASR/.gitignore b/egs/librispeech/ASR/.gitignore index 5592679cc..8dec2d86d 100644 --- a/egs/librispeech/ASR/.gitignore +++ b/egs/librispeech/ASR/.gitignore @@ -1 +1,2 @@ log-* +.DS_Store \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py index ff8fbb32c..374b78cb3 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py @@ -1,4 +1,4 @@ -# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) +# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) # # See ../LICENSE for clarification regarding multiple authors # diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py index ad9cf08dc..d50d231d5 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py @@ -31,7 +31,7 @@ Usage of this script: (1) ctc-decoding ./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ - --nn-model-filename ./pruned_transducer_stateless7_ctc/exp/cpu_jit.pt \ + --model-filename ./pruned_transducer_stateless7_ctc/exp/cpu_jit.pt \ --bpe-model data/lang_bpe_500/bpe.model \ --method ctc-decoding \ --sample-rate 16000 \ @@ -40,7 +40,7 @@ Usage of this script: (2) 1best ./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ - --nn-model-filename ./pruned_transducer_stateless7_ctc/exp/cpu_jit.pt \ + --model-filename ./pruned_transducer_stateless7_ctc/exp/cpu_jit.pt \ --HLG data/lang_bpe_500/HLG.pt \ --words-file data/lang_bpe_500/words.txt \ --method 1best \ @@ -51,7 +51,7 @@ Usage of this script: (3) nbest-rescoring ./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ - --nn-model-filename ./pruned_transducer_stateless7_ctc/exp/cpu_jit.pt \ + --model-filename ./pruned_transducer_stateless7_ctc/exp/cpu_jit.pt \ --HLG data/lang_bpe_500/HLG.pt \ --words-file data/lang_bpe_500/words.txt \ --G data/lm/G_4_gram.pt \ @@ -63,7 +63,7 @@ Usage of this script: (4) whole-lattice-rescoring ./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ - --nn-model-filename ./pruned_transducer_stateless7_ctc/exp/cpu_jit.pt \ + --model-filename ./pruned_transducer_stateless7_ctc/exp/cpu_jit.pt \ --HLG data/lang_bpe_500/HLG.pt \ --words-file data/lang_bpe_500/words.txt \ --G data/lm/G_4_gram.pt \ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/__init__.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/__init__.py new file mode 100755 index 000000000..e69de29bb diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/asr_datamodule.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/asr_datamodule.py new file mode 120000 index 000000000..a074d6085 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/asr_datamodule.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/asr_datamodule.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/beam_search.py new file mode 120000 index 000000000..8554e44cc --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/beam_search.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py new file mode 100755 index 000000000..0ef733226 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py @@ -0,0 +1,809 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Liyong Guo, +# Quandong Wang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) ctc-decoding +./pruned_transducer_stateless7_ctc_bs/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --decoding-method ctc-decoding +(2) 1best +./pruned_transducer_stateless7_ctc_bs/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --hlg-scale 0.8 \ + --decoding-method 1best +(3) nbest +./pruned_transducer_stateless7_ctc_bs/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --hlg-scale 0.8 \ + --decoding-method 1best +(4) nbest-rescoring +./pruned_transducer_stateless7_ctc_bs/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --hlg-scale 0.8 \ + --lm-dir data/lm \ + --decoding-method nbest-rescoring +(5) whole-lattice-rescoring +./pruned_transducer_stateless7_ctc_bs/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --hlg-scale 0.8 \ + --lm-dir data/lm \ + --decoding-method whole-lattice-rescoring +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.decode import ( + get_lattice, + nbest_decoding, + nbest_oracle, + one_best_decoding, + rescore_with_n_best_list, + rescore_with_whole_lattice, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + get_texts, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_ctc_bs/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram, 2 means tri-gram", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="ctc-decoding", + help="""Decoding method. + Supported values are: + - (1) ctc-decoding. Use CTC decoding. It uses a sentence piece + model, i.e., lang_dir/bpe.model, to convert word pieces to words. + It needs neither a lexicon nor an n-gram LM. + - (2) 1best. Extract the best path from the decoding lattice as the + decoding result. + - (3) nbest. Extract n paths from the decoding lattice; the path + with the highest score is the decoding result. + - (4) nbest-rescoring. Extract n paths from the decoding lattice, + rescore them with an n-gram LM (e.g., a 4-gram LM), the path with + the highest score is the decoding result. + - (5) whole-lattice-rescoring. Rescore the decoding lattice with an + n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice + is the decoding result. + you have trained an RNN LM using ./rnn_lm/train.py + - (6) nbest-oracle. Its WER is the lower bound of any n-best + rescoring method can achieve. Useful for debugging n-best + rescoring method. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help="""Number of paths for n-best based decoding method. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, and nbest-oracle + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""The scale to be applied to `lattice.scores`. + It's needed if you use any kinds of n-best based rescoring. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, and nbest-oracle + A smaller value results in more unique paths. + """, + ) + + parser.add_argument( + "--hlg-scale", + type=float, + default=0.8, + help="""The scale to be applied to `hlg.scores`. + """, + ) + + parser.add_argument( + "--lm-dir", + type=str, + default="data/lm", + help="""The n-gram LM dir. + It should contain either G_4_gram.pt or G_4_gram.fst.txt + """, + ) + + add_model_arguments(parser) + + return parser + + +def get_decoding_params() -> AttributeDict: + """Parameters for decoding.""" + params = AttributeDict( + { + "frame_shift_ms": 10, + "search_beam": 20, + "output_beam": 8, + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + } + ) + return params + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + batch: dict, + word_table: k2.SymbolTable, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + - key: It indicates the setting used for decoding. For example, + if no rescoring is used, the key is the string `no_rescore`. + If LM rescoring is used, the key is the string `lm_scale_xxx`, + where `xxx` is the value of `lm_scale`. An example key is + `lm_scale_0.7` + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + - params.decoding_method is "1best", it uses 1best decoding without LM rescoring. + - params.decoding_method is "nbest", it uses nbest decoding without LM rescoring. + - params.decoding_method is "nbest-rescoring", it uses nbest LM rescoring. + - params.decoding_method is "whole-lattice-rescoring", it uses whole lattice LM + rescoring. + model: + The neural model. + HLG: + The decoding graph. Used only when params.decoding_method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.decoding_method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.decoding_method is ctc-decoding. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + G: + An LM. It is not None when params.decoding_method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return the decoding result. See above description for the format of + the returned dict. Note: If it decodes to nothing, then return None. + """ + if HLG is not None: + device = HLG.device + else: + device = H.device + feature = batch["inputs"] + assert feature.ndim == 3 + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + encoder_out, encoder_out_lens = model.encoder(feature, feature_lens) + nnet_output = model.ctc_output(encoder_out) + # nnet_output is (N, T, C) + + supervision_segments = torch.stack( + ( + supervisions["sequence_idx"], + supervisions["start_frame"] // params.subsampling_factor, + supervisions["num_frames"] // params.subsampling_factor, + ), + 1, + ).to(torch.int32) + + if H is None: + assert HLG is not None + decoding_graph = HLG + else: + assert HLG is None + assert bpe_model is not None + decoding_graph = H + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=decoding_graph, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.decoding_method == "ctc-decoding": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + # Note: `best_path.aux_labels` contains token IDs, not word IDs + # since we are using H, not HLG here. + # + # token_ids is a lit-of-list of IDs + token_ids = get_texts(best_path) + + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + key = "ctc-decoding" + return {key: hyps} + + if params.decoding_method == "nbest-oracle": + # Note: You can also pass rescored lattices to it. + # We choose the HLG decoded lattice for speed reasons + # as HLG decoding is faster and the oracle WER + # is only slightly worse than that of rescored lattices. + best_path = nbest_oracle( + lattice=lattice, + num_paths=params.num_paths, + ref_texts=supervisions["text"], + word_table=word_table, + nbest_scale=params.nbest_scale, + oov="", + ) + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa + return {key: hyps} + + if params.decoding_method in ["1best", "nbest"]: + if params.decoding_method == "1best": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + key = "no_rescore" + else: + best_path = nbest_decoding( + lattice=lattice, + num_paths=params.num_paths, + use_double_scores=params.use_double_scores, + nbest_scale=params.nbest_scale, + ) + key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa + + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + return {key: hyps} + + assert params.decoding_method in [ + "nbest-rescoring", + "whole-lattice-rescoring", + ] + + lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] + lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3] + lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0] + + if params.decoding_method == "nbest-rescoring": + best_path_dict = rescore_with_n_best_list( + lattice=lattice, + G=G, + num_paths=params.num_paths, + lm_scale_list=lm_scale_list, + nbest_scale=params.nbest_scale, + ) + elif params.decoding_method == "whole-lattice-rescoring": + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=lm_scale_list, + ) + else: + assert False, f"Unsupported decoding method: {params.decoding_method}" + + ans = dict() + if best_path_dict is not None: + for lm_scale_str, best_path in best_path_dict.items(): + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + ans[lm_scale_str] = hyps + else: + ans = None + return ans + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + word_table: k2.SymbolTable, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + HLG: + The decoding graph. Used only when params.decoding_method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.decoding_method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.decoding_method is ctc-decoding. + word_table: + It is the word symbol table. + G: + An LM. It is not None when params.decoding_method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return a dict, whose key may be "no-rescore" if no LM rescoring + is used, or it may be "lm_scale_0.7" if LM rescoring is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + batch=batch, + word_table=word_table, + G=G, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats(f, f"{test_set_name}-{key}", results) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + args.lm_dir = Path(args.lm_dir) + + params = get_params() + # add decoding params + params.update(get_decoding_params()) + params.update(vars(args)) + + assert params.decoding_method in ( + "ctc-decoding", + "1best", + "nbest", + "nbest-rescoring", + "whole-lattice-rescoring", + "nbest-oracle", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + logging.info(params) + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + params.vocab_size = num_classes + # and are defined in local/train_bpe_model.py + params.blank_id = 0 + + if params.decoding_method == "ctc-decoding": + HLG = None + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(str(params.lang_dir / "bpe.model")) + else: + H = None + bpe_model = None + HLG = k2.Fsa.from_dict( + torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + ) + assert HLG.requires_grad is False + + HLG.scores *= params.hlg_scale + if not hasattr(HLG, "lm_scores"): + HLG.lm_scores = HLG.scores.clone() + + if params.decoding_method in ( + "nbest-rescoring", + "whole-lattice-rescoring", + ): + if not (params.lm_dir / "G_4_gram.pt").is_file(): + logging.info("Loading G_4_gram.fst.txt") + logging.warning("It may take 8 minutes.") + with open(params.lm_dir / "G_4_gram.fst.txt") as f: + first_word_disambig_id = lexicon.word_table["#0"] + + G = k2.Fsa.from_openfst(f.read(), acceptor=False) + # G.aux_labels is not needed in later computations, so + # remove it here. + del G.aux_labels + # CAUTION: The following line is crucial. + # Arcs entering the back-off state have label equal to #0. + # We have to change it to 0 here. + G.labels[G.labels >= first_word_disambig_id] = 0 + # See https://github.com/k2-fsa/k2/issues/874 + # for why we need to set G.properties to None + G.__dict__["_properties"] = None + G = k2.Fsa.from_fsas([G]).to(device) + G = k2.arc_sort(G) + # Save a dummy value so that it can be loaded in C++. + # See https://github.com/pytorch/pytorch/issues/67902 + # for why we need to do this. + G.dummy = 1 + + torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") + else: + logging.info("Loading pre-compiled G_4_gram.pt") + d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) + G = k2.Fsa.from_dict(d) + + if params.decoding_method == "whole-lattice-rescoring": + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + G = G.to(device) + + # G.lm_scores is used to replace HLG.lm_scores during + # LM rescoring. + G.lm_scores = G.scores.clone() + else: + G = None + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + word_table=lexicon.word_table, + G=G, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py new file mode 100755 index 000000000..9c2166aaf --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py @@ -0,0 +1,857 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Yifan Yang,) +# +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./pruned_transducer_stateless7_ctc/ctc_guild_decode_bs.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from train import add_model_arguments, get_params, get_transducer_model +from torch.nn.utils.rnn import pad_sequence + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_ctc_bs/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram, 2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--simulate-streaming", + type=str2bool, + default=False, + help="""Whether to simulate streaming in decoding, this is a good way to + test a streaming model. + """, + ) + + parser.add_argument( + "--decode-chunk-size", + type=int, + default=16, + help="The chunk size for decoding (in frames after subsampling)", + ) + + parser.add_argument( + "--left-context", + type=int, + default=64, + help="left context can be seen during decoding (in frames after subsampling)", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.simulate_streaming: + feature_lens += params.left_context + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.left_context), + value=LOG_EPS, + ) + encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( + x=feature, + x_lens=feature_lens, + chunk_size=params.decode_chunk_size, + left_context=params.left_context, + simulate_streaming=True, + ) + else: + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + + # filter out blank frames using ctc outputs + ctc_output = model.ctc_output(encoder_out) + encoder_out = model.lconv( + x=encoder_out, + src_key_padding_mask=make_pad_mask(encoder_out_lens), + ) + encoder_out, encoder_out_lens = model.frame_reducer( + x=encoder_out, + x_lens=encoder_out_lens, + ctc_output=ctc_output, + blank_id=0, + ) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.simulate_streaming: + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}" + params.suffix += f"-left-context-{params.left_context}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + if params.simulate_streaming: + assert ( + params.causal_convolution + ), "Decoding in streaming requires causal convolution" + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decode.py new file mode 100755 index 000000000..ce45a4beb --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decode.py @@ -0,0 +1,841 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./pruned_transducer_stateless7_ctc_bs/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./pruned_transducer_stateless7_ctc_bs/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./pruned_transducer_stateless7_ctc_bs/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./pruned_transducer_stateless7_ctc_bs/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./pruned_transducer_stateless7_ctc_bs/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./pruned_transducer_stateless7_ctc_bs/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./pruned_transducer_stateless7_ctc_bs/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_ctc_bs/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram, 2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--simulate-streaming", + type=str2bool, + default=False, + help="""Whether to simulate streaming in decoding, this is a good way to + test a streaming model. + """, + ) + + parser.add_argument( + "--decode-chunk-size", + type=int, + default=16, + help="The chunk size for decoding (in frames after subsampling)", + ) + + parser.add_argument( + "--left-context", + type=int, + default=64, + help="left context can be seen during decoding (in frames after subsampling)", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.simulate_streaming: + feature_lens += params.left_context + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.left_context), + value=LOG_EPS, + ) + encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( + x=feature, + x_lens=feature_lens, + chunk_size=params.decode_chunk_size, + left_context=params.left_context, + simulate_streaming=True, + ) + else: + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.simulate_streaming: + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}" + params.suffix += f"-left-context-{params.left_context}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + if params.simulate_streaming: + assert ( + params.causal_convolution + ), "Decoding in streaming requires causal convolution" + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decoder.py new file mode 120000 index 000000000..33944d0d2 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decoder.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/decoder.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/encoder_interface.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/encoder_interface.py new file mode 120000 index 000000000..b9aa0ae08 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/encoder_interface.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/encoder_interface.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export.py new file mode 100755 index 000000000..96d316604 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export.py @@ -0,0 +1,319 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" + +Usage: + +(1) Export to torchscript model using torch.jit.script() + +./pruned_transducer_stateless7_ctc_bs/export.py \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 13 \ + --jit 1 + +It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later +load it by `torch.jit.load("cpu_jit.pt")`. + +Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python +are on CPU. You can use `to("cuda")` to move them to a CUDA device. + +Check +https://github.com/k2-fsa/sherpa +for how to use the exported models outside of icefall. + +(2) Export `model.state_dict()` + +./pruned_transducer_stateless7_ctc_bs/export.py \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 13 + +It will generate a file `pretrained.pt` in the given `exp_dir`. You can later +load it by `icefall.checkpoint.load_checkpoint()`. + +To use the generated file with `pruned_transducer_stateless7_ctc_bs/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + ./pruned_transducer_stateless7_ctc_bs/decode.py \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bpe_500/bpe.model + +Check ./pretrained.py for its usage. + +Note: If you don't want to train a model from scratch, we have +provided one for you. You can get it at + +https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11 + +with the following commands: + + sudo apt-get install git-lfs + git lfs install + git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11 + # You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/exp +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +import torch +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + It will generate a file named cpu_jit.pt + + Check ./jit_pretrained.py for how to use it. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram, 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + if params.jit is True: + convert_scaled_to_non_scaled(model, inplace=True) + logging.info("Using torch.jit.script()") + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. + model.__class__.forward = torch.jit.ignore(model.__class__.forward) + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + filename = params.exp_dir / "cpu_jit.pt" + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torchscript. Export model.state_dict()") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py new file mode 100755 index 000000000..3de21a293 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py @@ -0,0 +1,84 @@ +#!/usr/bin/env python3 +# +# Copyright 2022 Xiaomi Corp. (authors: Yifan Yang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from torch.nn.utils.rnn import pad_sequence +from icefall.utils import make_pad_mask + + +class FrameReducer(nn.Module): + """The encoder output is first used to calculate + the CTC posterior probability; then for each output frame, + if its blank posterior is bigger than some thresholds, + it will be simply discarded from the encoder output. + """ + + def __init__( + self, + ): + super().__init__() + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ctc_output: torch.Tensor, + blank_id: int = 0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + The shared encoder output with shape [N, T, C]. + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + ctc_output: + The CTC output with shape [N, T, vocab_size]. + blank_id: + The ID of the blank symbol. + Returns: + x_fr: + The frame reduced encoder output with shape [N, T', C]. + x_lens_fr: + A tensor of shape (batch_size,) containing the number of frames in + `x_fr` before padding. + """ + + padding_mask = make_pad_mask(x_lens) + non_blank_mask = (ctc_output[:, :, blank_id] < math.log(0.9)) * (~padding_mask) + T_range = torch.arange(x.shape[1], device=x.device) + + frames_list: List[torch.Tensor] = [] + lens_list: List[int] = [] + for i in range(x.shape[0]): + indexes = torch.masked_select( + T_range, + non_blank_mask[i], + ) + frames = x[i][indexes] + frames_list.append(frames) + lens_list.append(frames.shape[0]) + x_fr = pad_sequence(frames_list).transpose(0, 1) + x_lens_fr = torch.tensor(lens_list).to(device=x.device) + + return x_fr, x_lens_fr diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained.py new file mode 100755 index 000000000..da2c6a39a --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained.py @@ -0,0 +1,271 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models, exported by `torch.jit.script()` +and uses them to decode waves. +You can use the following command to get the exported models: + +./pruned_transducer_stateless7_ctc_bs/export.py \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 \ + --jit 1 + +Usage of this script: + +./pruned_transducer_stateless7_ctc_bs/jit_pretrained.py \ + --nn-model-filename ./pruned_transducer_stateless7_ctc_bs/exp/cpu_jit.pt \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model-filename", + type=str, + required=True, + help="Path to the torchscript model cpu_jit.pt", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float = 16000 +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"Expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +def greedy_search( + model: torch.jit.ScriptModule, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, +) -> List[List[int]]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + model: + The transducer model. + encoder_out: + A 3-D tensor of shape (N, T, C) + encoder_out_lens: + A 1-D tensor of shape (N,). + Returns: + Return the decoded results for each utterance. + """ + assert encoder_out.ndim == 3 + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + device = encoder_out.device + blank_id = 0 # hard-code to 0 + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + context_size = model.decoder.context_size + hyps = [[blank_id] * context_size for _ in range(N)] + + decoder_input = torch.tensor( + hyps, + device=device, + dtype=torch.int64, + ) # (N, context_size) + + decoder_out = model.decoder( + decoder_input, + need_pad=torch.tensor([False]), + ).squeeze(1) + + offset = 0 + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = packed_encoder_out.data[start:end] + current_encoder_out = current_encoder_out + # current_encoder_out's shape: (batch_size, encoder_out_dim) + offset = end + + decoder_out = decoder_out[:batch_size] + + logits = model.joiner( + current_encoder_out, + decoder_out, + ) + # logits'shape (batch_size, vocab_size) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + hyps[i].append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = [h[-context_size:] for h in hyps[:batch_size]] + decoder_input = torch.tensor( + decoder_input, + device=device, + dtype=torch.int64, + ) + decoder_out = model.decoder( + decoder_input, + need_pad=torch.tensor([False]), + ) + decoder_out = decoder_out.squeeze(1) + + sorted_ans = [h[context_size:] for h in hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + model = torch.jit.load(args.nn_model_filename) + + model.eval() + + model.to(device) + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder( + x=features, + x_lens=feature_lengths, + ) + + hyps = greedy_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = sp.decode(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py new file mode 100755 index 000000000..653c25e06 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py @@ -0,0 +1,426 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models, exported by `torch.jit.script()` +and uses them to decode waves. +You can use the following command to get the exported models: + +./pruned_transducer_stateless7_ctc_bs/export.py \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 13 \ + --jit 1 + +Usage of this script: + +(1) ctc-decoding +./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \ + --model-filename ./pruned_transducer_stateless7_ctc_bs/exp/cpu_jit.pt \ + --bpe-model data/lang_bpe_500/bpe.model \ + --method ctc-decoding \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) 1best +./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \ + --model-filename ./pruned_transducer_stateless7_ctc_bs/exp/cpu_jit.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --method 1best \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + + +(3) nbest-rescoring +./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \ + --model-filename ./pruned_transducer_stateless7_ctc_bs/exp/cpu_jit.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --G data/lm/G_4_gram.pt \ + --method nbest-rescoring \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + + +(4) whole-lattice-rescoring +./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \ + --model-filename ./pruned_transducer_stateless7_ctc_bs/exp/cpu_jit.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --G data/lm/G_4_gram.pt \ + --method whole-lattice-rescoring \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from ctc_decode import get_decoding_params +from torch.nn.utils.rnn import pad_sequence +from train import get_params + +from icefall.decode import ( + get_lattice, + one_best_decoding, + rescore_with_n_best_list, + rescore_with_whole_lattice, +) +from icefall.utils import get_texts + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--model-filename", + type=str, + required=True, + help="Path to the torchscript model.", + ) + + parser.add_argument( + "--words-file", + type=str, + help="""Path to words.txt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--HLG", + type=str, + help="""Path to HLG.pt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model. + Used only when method is ctc-decoding. + """, + ) + + parser.add_argument( + "--method", + type=str, + default="1best", + help="""Decoding method. + Possible values are: + (0) ctc-decoding - Use CTC decoding. It uses a sentence + piece model, i.e., lang_dir/bpe.model, to convert + word pieces to words. It needs neither a lexicon + nor an n-gram LM. + (1) 1best - Use the best path as decoding output. Only + the transformer encoder output is used for decoding. + We call it HLG decoding. + (2) nbest-rescoring. Extract n paths from the decoding lattice, + rescore them with an LM, the path with + the highest score is the decoding result. + We call it HLG decoding + n-gram LM rescoring. + (3) whole-lattice-rescoring - Use an LM to rescore the + decoding lattice and then use 1best to decode the + rescored lattice. + We call it HLG decoding + n-gram LM rescoring. + """, + ) + + parser.add_argument( + "--G", + type=str, + help="""An LM for rescoring. + Used only when method is + whole-lattice-rescoring or nbest-rescoring. + It's usually a 4-gram LM. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help=""" + Used only when method is attention-decoder. + It specifies the size of n-best list.""", + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=1.3, + help=""" + Used only when method is whole-lattice-rescoring and nbest-rescoring. + It specifies the scale for n-gram LM scores. + (Note: You need to tune it on a dataset.) + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help=""" + Used only when method is nbest-rescoring. + It specifies the scale for lattice.scores when + extracting n-best lists. A smaller value results in + more unique number of paths with the risk of missing + the best path. + """, + ) + + parser.add_argument( + "--num-classes", + type=int, + default=500, + help=""" + Vocab size in the BPE model. + """, + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float = 16000 +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"Expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + # add decoding params + params.update(get_decoding_params()) + params.update(vars(args)) + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + model = torch.jit.load(args.model_filename) + model.to(device) + model.eval() + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder( + x=features, + x_lens=feature_lengths, + ) + nnet_output = model.ctc_output(encoder_out) + + batch_size = nnet_output.shape[0] + supervision_segments = torch.tensor( + [ + [i, 0, feature_lengths[i] // params.subsampling_factor] + for i in range(batch_size) + ], + dtype=torch.int32, + ) + + if params.method == "ctc-decoding": + logging.info("Use CTC decoding") + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(params.bpe_model) + max_token_id = params.num_classes - 1 + + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=H, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + token_ids = get_texts(best_path) + hyps = bpe_model.decode(token_ids) + hyps = [s.split() for s in hyps] + elif params.method in [ + "1best", + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + logging.info(f"Loading HLG from {params.HLG}") + HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = HLG.to(device) + if not hasattr(HLG, "lm_scores"): + # For whole-lattice-rescoring and attention-decoder + HLG.lm_scores = HLG.scores.clone() + + if params.method in [ + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + logging.info(f"Loading G from {params.G}") + G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + G = G.to(device) + if params.method == "whole-lattice-rescoring": + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + + # G.lm_scores is used to replace HLG.lm_scores during + # LM rescoring. + G.lm_scores = G.scores.clone() + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=HLG, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.method == "1best": + logging.info("Use HLG decoding") + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + if params.method == "nbest-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_n_best_list( + lattice=lattice, + G=G, + num_paths=params.num_paths, + lm_scale_list=[params.ngram_lm_scale], + nbest_scale=params.nbest_scale, + ) + best_path = next(iter(best_path_dict.values())) + elif params.method == "whole-lattice-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=[params.ngram_lm_scale], + ) + best_path = next(iter(best_path_dict.values())) + + hyps = get_texts(best_path) + word_sym_table = k2.SymbolTable.from_file(params.words_file) + hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + else: + raise ValueError(f"Unsupported decoding method: {params.method}") + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/joiner.py new file mode 120000 index 000000000..ecfb6dd8a --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/joiner.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/joiner.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/lconv.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/lconv.py new file mode 100755 index 000000000..bfd49d533 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/lconv.py @@ -0,0 +1,114 @@ +# Copyright 2022 Xiaomi Corp. (authors: Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from scaling import ( + ActivationBalancer, + ScaledConv1d, +) + + +class LConv(nn.Module): + """A convolution module to prevent information loss.""" + + def __init__( + self, + channels: int, + kernel_size: int = 7, + bias: bool = True, + ): + """ + Args: + channels: + Dimension of the input embedding, and of the lconv output. + """ + super().__init__() + self.pointwise_conv1 = nn.Conv1d( + channels, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + + self.deriv_balancer1 = ActivationBalancer( + 2 * channels, + channel_dim=1, + max_abs=10.0, + min_positive=0.05, + max_positive=1.0, + ) + + self.depthwise_conv = nn.Conv1d( + 2 * channels, + 2 * channels, + kernel_size=kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + groups=channels, + bias=bias, + ) + + self.deriv_balancer2 = ActivationBalancer( + 2 * channels, + channel_dim=1, + min_positive=0.05, + max_positive=1.0, + max_abs=20.0, + ) + + self.pointwise_conv2 = ScaledConv1d( + 2 * channels, + channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + initial_scale=0.05, + ) + + def forward( + self, + x: torch.Tensor, + src_key_padding_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Args: + x: A 3-D tensor of shape (N, T, C). + Returns: + Return a tensor of shape (N, T, C). + """ + # exchange the temporal dimension and the feature dimension + x = x.permute(0, 2, 1) # (#batch, channels, time). + + x = self.pointwise_conv1(x) # (batch, 2*channels, time) + + x = self.deriv_balancer1(x) + + if src_key_padding_mask is not None: + x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) + + x = self.depthwise_conv(x) + + x = self.deriv_balancer2(x) + + x = self.pointwise_conv2(x) # (batch, channels, time) + + return x.permute(0, 2, 1) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py new file mode 100755 index 000000000..86acc5a10 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py @@ -0,0 +1,224 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Tuple + +import k2 +import torch +import torch.nn as nn +from encoder_interface import EncoderInterface + +from icefall.utils import add_sos, make_pad_mask + + +class Transducer(nn.Module): + """It implements https://arxiv.org/pdf/1211.3711.pdf + "Sequence Transduction with Recurrent Neural Networks" + """ + + def __init__( + self, + encoder: EncoderInterface, + decoder: nn.Module, + joiner: nn.Module, + lconv: nn.Module, + frame_reducer: nn.Module, + encoder_dim: int, + decoder_dim: int, + joiner_dim: int, + vocab_size: int, + ): + """ + Args: + encoder: + It is the transcription network in the paper. Its accepts + two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). + It returns two tensors: `logits` of shape (N, T, encoder_dm) and + `logit_lens` of shape (N,). + decoder: + It is the prediction network in the paper. Its input shape + is (N, U) and its output shape is (N, U, decoder_dim). + It should contain one attribute: `blank_id`. + joiner: + It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim). + Its output shape is (N, T, U, vocab_size). Note that its output contains + unnormalized probs, i.e., not processed by log-softmax. + """ + super().__init__() + assert isinstance(encoder, EncoderInterface), type(encoder) + assert hasattr(decoder, "blank_id") + + self.encoder = encoder + self.decoder = decoder + self.joiner = joiner + self.lconv = lconv + self.frame_reducer = frame_reducer + + self.simple_am_proj = nn.Linear( + encoder_dim, + vocab_size, + ) + self.simple_lm_proj = nn.Linear(decoder_dim, vocab_size) + + self.ctc_output = nn.Sequential( + nn.Dropout(p=0.1), + nn.Linear(encoder_dim, vocab_size), + nn.LogSoftmax(dim=-1), + ) + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + y: k2.RaggedTensor, + prune_range: int = 5, + am_scale: float = 0.0, + lm_scale: float = 0.0, + warmup: float = 1.0, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + y: + A ragged tensor with 2 axes [utt][label]. It contains labels of each + utterance. + prune_range: + The prune range for rnnt loss, it means how many symbols(context) + we are considering for each frame to compute the loss. + am_scale: + The scale to smooth the loss with am (output of encoder network) + part + lm_scale: + The scale to smooth the loss with lm (output of predictor network) + part + warmup: + A floating point value which decides whether to do blank skip. + Returns: + Return a tuple containing simple loss, pruned loss, and ctc-output. + Note: + Regarding am_scale & lm_scale, it will make the loss-function one of + the form: + lm_scale * lm_probs + am_scale * am_probs + + (1-lm_scale-am_scale) * combined_probs + """ + assert x.ndim == 3, x.shape + assert x_lens.ndim == 1, x_lens.shape + assert y.num_axes == 2, y.num_axes + + assert x.size(0) == x_lens.size(0) == y.dim0 + + encoder_out, x_lens = self.encoder(x, x_lens) + assert torch.all(x_lens > 0) + + # compute ctc log-probs + ctc_output = self.ctc_output(encoder_out) + + # blank skip + blank_id = self.decoder.blank_id + + if warmup >= 2.0: + # lconv + encoder_out = self.lconv( + x=encoder_out, + src_key_padding_mask=make_pad_mask(x_lens), + ) + + # frame reduce + encoder_out_fr, x_lens_fr = self.frame_reducer( + encoder_out, + x_lens, + ctc_output, + blank_id, + ) + else: + encoder_out_fr = encoder_out + x_lens_fr = x_lens + + # Now for the decoder, i.e., the prediction network + row_splits = y.shape.row_splits(1) + y_lens = row_splits[1:] - row_splits[:-1] + + sos_y = add_sos(y, sos_id=blank_id) + + # sos_y_padded: [B, S + 1], start with SOS. + sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) + + # decoder_out: [B, S + 1, decoder_dim] + decoder_out = self.decoder(sos_y_padded) + + # Note: y does not start with SOS + # y_padded : [B, S] + y_padded = y.pad(mode="constant", padding_value=0) + + y_padded = y_padded.to(torch.int64) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) + boundary[:, 2] = y_lens + boundary[:, 3] = x_lens_fr + + am = self.simple_am_proj(encoder_out_fr) + lm = self.simple_lm_proj(decoder_out) + + with torch.cuda.amp.autocast(enabled=False): + simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( + lm=lm.float(), + am=am.float(), + symbols=y_padded, + termination_symbol=blank_id, + lm_only_scale=lm_scale, + am_only_scale=am_scale, + boundary=boundary, + reduction="sum", + return_grad=True, + ) + + # ranges : [B, T, prune_range] + ranges = k2.get_rnnt_prune_ranges( + px_grad=px_grad, + py_grad=py_grad, + boundary=boundary, + s_range=prune_range, + ) + + # am_pruned : [B, T, prune_range, encoder_dim] + # lm_pruned : [B, T, prune_range, decoder_dim] + am_pruned, lm_pruned = k2.do_rnnt_pruning( + am=self.joiner.encoder_proj(encoder_out_fr), + lm=self.joiner.decoder_proj(decoder_out), + ranges=ranges, + ) + + # logits : [B, T, prune_range, vocab_size] + + # project_input=False since we applied the decoder's input projections + # prior to do_rnnt_pruning (this is an optimization for speed). + logits = self.joiner(am_pruned, lm_pruned, project_input=False) + + with torch.cuda.amp.autocast(enabled=False): + pruned_loss = k2.rnnt_loss_pruned( + logits=logits.float(), + symbols=y_padded, + ranges=ranges, + termination_symbol=blank_id, + boundary=boundary, + reduction="sum", + ) + + return (simple_loss, pruned_loss, ctc_output) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/optim.py new file mode 120000 index 000000000..81ac4a89a --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/optim.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/optim.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py new file mode 100755 index 000000000..ea0fe9164 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py @@ -0,0 +1,352 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads a checkpoint and uses it to decode waves. +You can generate the checkpoint with the following command: + +./pruned_transducer_stateless7_ctc_bs/export.py \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 13 + +Usage of this script: + +(1) greedy search +./pruned_transducer_stateless7_ctc_bs/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) beam search +./pruned_transducer_stateless7_ctc_bs/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) modified beam search +./pruned_transducer_stateless7_ctc_bs/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method modified_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(4) fast beam search +./pruned_transducer_stateless7_ctc_bs/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method fast_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +You can also use `./pruned_transducer_stateless7_ctc_bs/exp/epoch-xx.pt`. + +Note: ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt is generated by +./pruned_transducer_stateless7_ctc_bs/export.py +""" + + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from beam_search import ( + beam_search, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "--method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram, 2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. Used only when + --method is greedy_search. + """, + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + + params.update(vars(args)) + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + model.device = device + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) + + num_waves = encoder_out.size(0) + hyps = [] + msg = f"Using {params.method}" + if params.method == "beam_search": + msg += f" with beam size {params.beam_size}" + logging.info(msg) + + if params.method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + for i in range(num_waves): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError(f"Unsupported method: {params.method}") + + hyps.append(sp.decode(hyp).split()) + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py new file mode 100755 index 000000000..412631ba1 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py @@ -0,0 +1,440 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models, exported by `torch.jit.script()` +and uses them to decode waves. +You can use the following command to get the exported models: + +./pruned_transducer_stateless7_ctc_bs/export.py \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 + +Usage of this script: + +(1) ctc-decoding +./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \ + --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ + --bpe-model data/lang_bpe_500/bpe.model \ + --method ctc-decoding \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) 1best +./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \ + --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --method 1best \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) nbest-rescoring +./bruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ + --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --G data/lm/G_4_gram.pt \ + --method nbest-rescoring \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + + +(4) whole-lattice-rescoring +./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \ + --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --G data/lm/G_4_gram.pt \ + --method whole-lattice-rescoring \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from ctc_decode import get_decoding_params +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.decode import ( + get_lattice, + one_best_decoding, + rescore_with_n_best_list, + rescore_with_whole_lattice, +) +from icefall.utils import get_texts + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram, 2 means tri-gram", + ) + + parser.add_argument( + "--words-file", + type=str, + help="""Path to words.txt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--HLG", + type=str, + help="""Path to HLG.pt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model. + Used only when method is ctc-decoding. + """, + ) + + parser.add_argument( + "--method", + type=str, + default="1best", + help="""Decoding method. + Possible values are: + (0) ctc-decoding - Use CTC decoding. It uses a sentence + piece model, i.e., lang_dir/bpe.model, to convert + word pieces to words. It needs neither a lexicon + nor an n-gram LM. + (1) 1best - Use the best path as decoding output. Only + the transformer encoder output is used for decoding. + We call it HLG decoding. + (2) nbest-rescoring. Extract n paths from the decoding lattice, + rescore them with an LM, the path with + the highest score is the decoding result. + We call it HLG decoding + n-gram LM rescoring. + (3) whole-lattice-rescoring - Use an LM to rescore the + decoding lattice and then use 1best to decode the + rescored lattice. + We call it HLG decoding + n-gram LM rescoring. + """, + ) + + parser.add_argument( + "--G", + type=str, + help="""An LM for rescoring. + Used only when method is + whole-lattice-rescoring or nbest-rescoring. + It's usually a 4-gram LM. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help=""" + Used only when method is attention-decoder. + It specifies the size of n-best list.""", + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=1.3, + help=""" + Used only when method is whole-lattice-rescoring and nbest-rescoring. + It specifies the scale for n-gram LM scores. + (Note: You need to tune it on a dataset.) + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help=""" + Used only when method is nbest-rescoring. + It specifies the scale for lattice.scores when + extracting n-best lists. A smaller value results in + more unique number of paths with the risk of missing + the best path. + """, + ) + + parser.add_argument( + "--num-classes", + type=int, + default=500, + help=""" + Vocab size in the BPE model. + """, + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float = 16000 +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + # add decoding params + params.update(get_decoding_params()) + params.update(vars(args)) + params.vocab_size = params.num_classes + params.blank_id = 0 + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder( + x=features, + x_lens=feature_lengths, + ) + nnet_output = model.ctc_output(encoder_out) + + batch_size = nnet_output.shape[0] + supervision_segments = torch.tensor( + [[i, 0, nnet_output.shape[1]] for i in range(batch_size)], + dtype=torch.int32, + ) + + if params.method == "ctc-decoding": + logging.info("Use CTC decoding") + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(params.bpe_model) + max_token_id = params.num_classes - 1 + + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=H, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + token_ids = get_texts(best_path) + hyps = bpe_model.decode(token_ids) + hyps = [s.split() for s in hyps] + elif params.method in [ + "1best", + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + logging.info(f"Loading HLG from {params.HLG}") + HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = HLG.to(device) + if not hasattr(HLG, "lm_scores"): + # For whole-lattice-rescoring and attention-decoder + HLG.lm_scores = HLG.scores.clone() + + if params.method in [ + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + logging.info(f"Loading G from {params.G}") + G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + G = G.to(device) + if params.method == "whole-lattice-rescoring": + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + + # G.lm_scores is used to replace HLG.lm_scores during + # LM rescoring. + G.lm_scores = G.scores.clone() + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=HLG, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.method == "1best": + logging.info("Use HLG decoding") + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + if params.method == "nbest-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_n_best_list( + lattice=lattice, + G=G, + num_paths=params.num_paths, + lm_scale_list=[params.ngram_lm_scale], + nbest_scale=params.nbest_scale, + ) + best_path = next(iter(best_path_dict.values())) + elif params.method == "whole-lattice-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=[params.ngram_lm_scale], + ) + best_path = next(iter(best_path_dict.values())) + + hyps = get_texts(best_path) + word_sym_table = k2.SymbolTable.from_file(params.words_file) + hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + else: + raise ValueError(f"Unsupported decoding method: {params.method}") + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/scaling.py new file mode 120000 index 000000000..2428b74b9 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/scaling.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/scaling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/scaling_converter.py new file mode 120000 index 000000000..b8b8ba432 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/scaling_converter.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/scaling_converter.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/test_model.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/test_model.py new file mode 100755 index 000000000..7f0893985 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/test_model.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +To run this file, do: + + cd icefall/egs/librispeech/ASR + python ./pruned_transducer_stateless7_ctc_bs/test_model.py +""" + +from train import get_params, get_transducer_model + + +def test_model_1(): + params = get_params() + params.vocab_size = 500 + params.blank_id = 0 + params.context_size = 2 + params.num_encoder_layers = "2,4,3,2,4" + params.feedforward_dims = "1024,1024,2048,2048,1024" + params.nhead = "8,8,8,8,8" + params.encoder_dims = "384,384,384,384,384" + params.attention_dims = "192,192,192,192,192" + params.encoder_unmasked_dims = "256,256,256,256,256" + params.zipformer_downsampling_factors = "1,2,4,8,2" + params.cnn_module_kernels = "31,31,31,31,31" + params.decoder_dim = 512 + params.joiner_dim = 512 + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + +def main(): + test_model_1() + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py new file mode 100755 index 000000000..63e9d6e90 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py @@ -0,0 +1,1251 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +export CUDA_VISIBLE_DEVICES="0,1,2,3" +./pruned_transducer_stateless7_ctc_bs/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless7_ctc_bs/exp \ + --full-libri 1 \ + --max-duration 300 +# For mix precision training: +./pruned_transducer_stateless7_ctc_bs/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7_ctc_bs/exp \ + --full-libri 1 \ + --max-duration 550 +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lconv import LConv +from frame_reducer import FrameReducer +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from optim import Eden, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + encode_supervisions, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,4,3,2,4", + help="Number of zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--feedforward-dims", + type=str, + default="1024,1024,2048,2048,1024", + help="Feedforward dimension of the zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--nhead", + type=str, + default="8,8,8,8,8", + help="Number of attention heads in the zipformer encoder layers.", + ) + + parser.add_argument( + "--encoder-dims", + type=str, + default="384,384,384,384,384", + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", + ) + + parser.add_argument( + "--attention-dims", + type=str, + default="192,192,192,192,192", + help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; + not the same as embedding dimension.""", + ) + + parser.add_argument( + "--encoder-unmasked-dims", + type=str, + default="256,256,256,256,256", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " + " worse.", + ) + + parser.add_argument( + "--zipformer-downsampling-factors", + type=str, + default="1,2,4,8,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--cnn-module-kernels", + type=str, + default="31,31,31,31,31", + help="Sizes of kernels in convolution modules", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_ctc_bs/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.05, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram, 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.5, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=2000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + Explanation of options saved in `params`: + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + - best_train_epoch: It is the epoch that has the best training loss. + - best_valid_epoch: It is the epoch that has the best validation loss. + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + - log_interval: Print training loss if batch_idx % log_interval` is 0 + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + - valid_interval: Run validation if batch_idx % valid_interval is 0 + - feature_dim: The model input dim. It has to match the one used + in computing features. + - subsampling_factor: The subsampling factor for the model. + - encoder_dim: Hidden dim for multi-head attention model. + - num_decoder_layers: Number of decoder layer of transformer decoder. + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + # parameters for ctc loss + "beam_size": 10, + "use_double_scores": True, + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Zipformer and Transformer + def to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + encoder = Zipformer( + num_features=params.feature_dim, + output_downsampling_factor=2, + zipformer_downsampling_factors=to_int_tuple( + params.zipformer_downsampling_factors + ), + encoder_dims=to_int_tuple(params.encoder_dims), + attention_dim=to_int_tuple(params.attention_dims), + encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), + nhead=to_int_tuple(params.nhead), + feedforward_dim=to_int_tuple(params.feedforward_dims), + cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), + num_encoder_layers=to_int_tuple(params.num_encoder_layers), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_lconv(params: AttributeDict) -> nn.Module: + lconv = LConv( + channels=int(params.encoder_dims.split(",")[-1]), + ) + return lconv + + +def get_frame_reducer(params: AttributeDict) -> nn.Module: + frame_reducer = FrameReducer() + return frame_reducer + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + lconv = get_lconv(params) + frame_reducer = get_frame_reducer(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + lconv=lconv, + frame_reducer=frame_reducer, + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + warmup = batch_idx_train / warm_step + + texts = batch["supervisions"]["text"] + token_ids = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(token_ids).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss, ctc_output = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + warmup=warmup, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + # Compute ctc loss + + # NOTE: We need `encode_supervisions` to sort sequences with + # different duration in decreasing order, required by + # `k2.intersect_dense` called in `k2.ctc_loss` + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + supervision_segments, token_ids = encode_supervisions( + supervisions, + subsampling_factor=params.subsampling_factor, + token_ids=token_ids, + ) + + # Works with a BPE model + decoding_graph = k2.ctc_graph(token_ids, modified=False, device=device) + dense_fsa_vec = k2.DenseFsaVec( + ctc_output, + supervision_segments, + allow_truncate=params.subsampling_factor - 1, + ) + + ctc_loss = k2.ctc_loss( + decoding_graph=decoding_graph, + dense_fsa_vec=dense_fsa_vec, + output_beam=params.beam_size, + reduction="sum", + use_double_scores=params.use_double_scores, + ) + assert ctc_loss.requires_grad == is_training + loss += params.ctc_loss_scale * ctc_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + info["ctc_loss"] = ctc_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + if params.full_libri is False: + params.valid_interval = 1600 + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + parameters_names = [] + parameters_names.append( + [name_param_pair[0] for name_param_pair in model.named_parameters()] + ) + + optimizer = ScaledAdam( + model.parameters(), + lr=params.base_lr, + clipping_scale=2.0, + parameters_names=parameters_names, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + librispeech = LibriSpeechAsrDataModule(args) + + train_cuts = librispeech.train_clean_100_cuts() + if params.full_libri: + train_cuts += librispeech.train_clean_360_cuts() + train_cuts += librispeech.train_other_500_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + return 1.0 <= c.duration <= 20.0 + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/zipformer.py new file mode 120000 index 000000000..79b076556 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/zipformer.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/zipformer.py \ No newline at end of file From 7eb2d0edb637e8000e9c1a5eafe157d967d932ec Mon Sep 17 00:00:00 2001 From: BuaaAlban Date: Fri, 23 Dec 2022 11:38:22 +0800 Subject: [PATCH 055/174] Update train.py (#773) Fix transducer lstm egs bug as mentioned in issue 579 --- egs/librispeech/ASR/transducer_lstm/train.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/egs/librispeech/ASR/transducer_lstm/train.py b/egs/librispeech/ASR/transducer_lstm/train.py index 792708bc0..a6f2bd08c 100755 --- a/egs/librispeech/ASR/transducer_lstm/train.py +++ b/egs/librispeech/ASR/transducer_lstm/train.py @@ -629,18 +629,8 @@ def run(rank, world_size, args): # Keep only utterances with duration between 1 second and 20 seconds return 1.0 <= c.duration <= 20.0 - num_in_total = len(train_cuts) - train_cuts = train_cuts.filter(remove_short_and_long_utt) - num_left = len(train_cuts) - num_removed = num_in_total - num_left - removed_percent = num_removed / num_in_total * 100 - - logging.info(f"Before removing short and long utterances: {num_in_total}") - logging.info(f"After removing short and long utterances: {num_left}") - logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)") - train_dl = librispeech.train_dataloaders(train_cuts) valid_cuts = librispeech.dev_clean_cuts() From 59eb465b3cd47a212117b535644f24ed190093e1 Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Fri, 23 Dec 2022 17:55:36 +0800 Subject: [PATCH 056/174] optimize frame_reducer.py (#783) Co-authored-by: yifanyang --- .../pruned_transducer_stateless7_ctc_bs/frame_reducer.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py index 3de21a293..9fe88929d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py @@ -66,19 +66,14 @@ class FrameReducer(nn.Module): padding_mask = make_pad_mask(x_lens) non_blank_mask = (ctc_output[:, :, blank_id] < math.log(0.9)) * (~padding_mask) - T_range = torch.arange(x.shape[1], device=x.device) frames_list: List[torch.Tensor] = [] lens_list: List[int] = [] for i in range(x.shape[0]): - indexes = torch.masked_select( - T_range, - non_blank_mask[i], - ) - frames = x[i][indexes] + frames = x[i][non_blank_mask[i]] frames_list.append(frames) lens_list.append(frames.shape[0]) - x_fr = pad_sequence(frames_list).transpose(0, 1) + x_fr = pad_sequence(frames_list, batch_first=True) x_lens_fr = torch.tensor(lens_list).to(device=x.device) return x_fr, x_lens_fr From 4e249da2c402eb83e6206365c161693d2f5db070 Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Mon, 26 Dec 2022 14:30:20 +0800 Subject: [PATCH 057/174] Add zipformer_ctc_blankskip.rst (#784) * Add zipformer_ctc_blankskip.rst * typo fix for zipformer_mmi.rst * fix warning * Update docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_ctc_blankskip.rst Co-authored-by: yifanyang Co-authored-by: Fangjun Kuang --- .../export-with-torch-jit-script.rst | 2 +- .../aishell/conformer_ctc.rst | 2 +- .../librispeech/conformer_ctc.rst | 8 +- .../Non-streaming-ASR/librispeech/index.rst | 2 +- .../pruned_transducer_stateless.rst | 3 +- .../librispeech/zipformer_ctc_blankskip.rst | 453 ++++++++++++++++++ .../librispeech/zipformer_mmi.rst | 4 +- 7 files changed, 464 insertions(+), 10 deletions(-) create mode 100644 docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_ctc_blankskip.rst diff --git a/docs/source/model-export/export-with-torch-jit-script.rst b/docs/source/model-export/export-with-torch-jit-script.rst index a041dc1d5..efd7dc2e1 100644 --- a/docs/source/model-export/export-with-torch-jit-script.rst +++ b/docs/source/model-export/export-with-torch-jit-script.rst @@ -1,7 +1,7 @@ .. _export-model-with-torch-jit-script: Export model with torch.jit.script() -=================================== +==================================== In this section, we describe how to export a model via ``torch.jit.script()``. diff --git a/docs/source/recipes/Non-streaming-ASR/aishell/conformer_ctc.rst b/docs/source/recipes/Non-streaming-ASR/aishell/conformer_ctc.rst index 72690e102..6e30ce397 100644 --- a/docs/source/recipes/Non-streaming-ASR/aishell/conformer_ctc.rst +++ b/docs/source/recipes/Non-streaming-ASR/aishell/conformer_ctc.rst @@ -703,7 +703,7 @@ It will show you the following message: HLG decoding -^^^^^^^^^^^^ +~~~~~~~~~~~~ .. code-block:: bash diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/conformer_ctc.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/conformer_ctc.rst index 4656acfd6..b7f89c89f 100644 --- a/docs/source/recipes/Non-streaming-ASR/librispeech/conformer_ctc.rst +++ b/docs/source/recipes/Non-streaming-ASR/librispeech/conformer_ctc.rst @@ -888,7 +888,7 @@ It will show you the following message: CTC decoding -^^^^^^^^^^^^ +~~~~~~~~~~~~ .. code-block:: bash @@ -926,7 +926,7 @@ Its output is: YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION HLG decoding -^^^^^^^^^^^^ +~~~~~~~~~~~~ .. code-block:: bash @@ -966,7 +966,7 @@ The output is: HLG decoding + n-gram LM rescoring -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. code-block:: bash @@ -1012,7 +1012,7 @@ The output is: HLG decoding + n-gram LM rescoring + attention decoder rescoring -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. code-block:: bash diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/index.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/index.rst index aa97f325d..3ebb36b25 100644 --- a/docs/source/recipes/Non-streaming-ASR/librispeech/index.rst +++ b/docs/source/recipes/Non-streaming-ASR/librispeech/index.rst @@ -7,5 +7,5 @@ LibriSpeech tdnn_lstm_ctc conformer_ctc pruned_transducer_stateless - lstm_pruned_stateless_transducer zipformer_mmi + zipformer_ctc_blankskip diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/pruned_transducer_stateless.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/pruned_transducer_stateless.rst index d8569bc5c..86d43c8fe 100644 --- a/docs/source/recipes/Non-streaming-ASR/librispeech/pruned_transducer_stateless.rst +++ b/docs/source/recipes/Non-streaming-ASR/librispeech/pruned_transducer_stateless.rst @@ -499,9 +499,10 @@ can run: Export model using ``torch.jit.script()`` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. code-block:: bash + ./pruned_transducer_stateless4/export.py \ --exp-dir ./pruned_transducer_stateless4/exp \ --bpe-model data/lang_bpe_500/bpe.model \ diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_ctc_blankskip.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_ctc_blankskip.rst new file mode 100644 index 000000000..d85a3c67f --- /dev/null +++ b/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_ctc_blankskip.rst @@ -0,0 +1,453 @@ +Zipformer CTC Blank Skip +======================== + +.. hint:: + + Please scroll down to the bottom of this page to find download links + for pretrained models if you don't want to train a model from scratch. + + +This tutorial shows you how to train a Zipformer model based on the guidance from +a co-trained CTC model using `blank skip method `_ +with the `LibriSpeech `_ dataset. + +.. note:: + + We use both CTC and RNN-T loss to train. During the forward pass, the encoder output + is first used to calculate the CTC posterior probability; then for each output frame, + if its blank posterior is bigger than some threshold, it will be simply discarded + from the encoder output. To prevent information loss, we also put a convolution module + similar to the one used in conformer (referred to as “LConv”) before the frame reduction. + + +Data preparation +---------------- + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./prepare.sh + +The script ``./prepare.sh`` handles the data preparation for you, **automagically**. +All you need to do is to run it. + +.. note:: + + We encourage you to read ``./prepare.sh``. + +The data preparation contains several stages. You can use the following two +options: + + - ``--stage`` + - ``--stop-stage`` + +to control which stage(s) should be run. By default, all stages are executed. + + +For example, + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./prepare.sh --stage 0 --stop-stage 0 + +means to run only stage 0. + +To run stage 2 to stage 5, use: + +.. code-block:: bash + + $ ./prepare.sh --stage 2 --stop-stage 5 + +.. hint:: + + If you have pre-downloaded the `LibriSpeech `_ + dataset and the `musan `_ dataset, say, + they are saved in ``/tmp/LibriSpeech`` and ``/tmp/musan``, you can modify + the ``dl_dir`` variable in ``./prepare.sh`` to point to ``/tmp`` so that + ``./prepare.sh`` won't re-download them. + +.. note:: + + All generated files by ``./prepare.sh``, e.g., features, lexicon, etc, + are saved in ``./data`` directory. + +We provide the following YouTube video showing how to run ``./prepare.sh``. + +.. note:: + + To get the latest news of `next-gen Kaldi `_, please subscribe + the following YouTube channel by `Nadira Povey `_: + + ``_ + +.. youtube:: ofEIoJL-mGM + +Training +-------- + +For stability, it doesn`t use blank skip method until model warm-up. + +Configurable options +~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./pruned_transducer_stateless7_ctc_bs/train.py --help + +shows you the training options that can be passed from the commandline. +The following options are used quite often: + + - ``--full-libri`` + + If it's True, the training part uses all the training data, i.e., + 960 hours. Otherwise, the training part uses only the subset + ``train-clean-100``, which has 100 hours of training data. + + .. CAUTION:: + + The training set is perturbed by speed with two factors: 0.9 and 1.1. + If ``--full-libri`` is True, each epoch actually processes + ``3x960 == 2880`` hours of data. + + - ``--num-epochs`` + + It is the number of epochs to train. For instance, + ``./pruned_transducer_stateless7_ctc_bs/train.py --num-epochs 30`` trains for 30 epochs + and generates ``epoch-1.pt``, ``epoch-2.pt``, ..., ``epoch-30.pt`` + in the folder ``./pruned_transducer_stateless7_ctc_bs/exp``. + + - ``--start-epoch`` + + It's used to resume training. + ``./pruned_transducer_stateless7_ctc_bs/train.py --start-epoch 10`` loads the + checkpoint ``./pruned_transducer_stateless7_ctc_bs/exp/epoch-9.pt`` and starts + training from epoch 10, based on the state from epoch 9. + + - ``--world-size`` + + It is used for multi-GPU single-machine DDP training. + + - (a) If it is 1, then no DDP training is used. + + - (b) If it is 2, then GPU 0 and GPU 1 are used for DDP training. + + The following shows some use cases with it. + + **Use case 1**: You have 4 GPUs, but you only want to use GPU 0 and + GPU 2 for training. You can do the following: + + .. code-block:: bash + + $ cd egs/librispeech/ASR + $ export CUDA_VISIBLE_DEVICES="0,2" + $ ./pruned_transducer_stateless7_ctc_bs/train.py --world-size 2 + + **Use case 2**: You have 4 GPUs and you want to use all of them + for training. You can do the following: + + .. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./pruned_transducer_stateless7_ctc_bs/train.py --world-size 4 + + **Use case 3**: You have 4 GPUs but you only want to use GPU 3 + for training. You can do the following: + + .. code-block:: bash + + $ cd egs/librispeech/ASR + $ export CUDA_VISIBLE_DEVICES="3" + $ ./pruned_transducer_stateless7_ctc_bs/train.py --world-size 1 + + .. caution:: + + Only multi-GPU single-machine DDP training is implemented at present. + Multi-GPU multi-machine DDP training will be added later. + + - ``--max-duration`` + + It specifies the number of seconds over all utterances in a + batch, before **padding**. + If you encounter CUDA OOM, please reduce it. + + .. HINT:: + + Due to padding, the number of seconds of all utterances in a + batch will usually be larger than ``--max-duration``. + + A larger value for ``--max-duration`` may cause OOM during training, + while a smaller value may increase the training time. You have to + tune it. + + +Pre-configured options +~~~~~~~~~~~~~~~~~~~~~~ + +There are some training options, e.g., weight decay, +number of warmup steps, results dir, etc, +that are not passed from the commandline. +They are pre-configured by the function ``get_params()`` in +`pruned_transducer_stateless7_ctc_bs/train.py `_ + +You don't need to change these pre-configured parameters. If you really need to change +them, please modify ``./pruned_transducer_stateless7_ctc_bs/train.py`` directly. + +Training logs +~~~~~~~~~~~~~ + +Training logs and checkpoints are saved in ``pruned_transducer_stateless7_ctc_bs/exp``. +You will find the following files in that directory: + + - ``epoch-1.pt``, ``epoch-2.pt``, ... + + These are checkpoint files saved at the end of each epoch, containing model + ``state_dict`` and optimizer ``state_dict``. + To resume training from some checkpoint, say ``epoch-10.pt``, you can use: + + .. code-block:: bash + + $ ./pruned_transducer_stateless7_ctc_bs/train.py --start-epoch 11 + + - ``checkpoint-436000.pt``, ``checkpoint-438000.pt``, ... + + These are checkpoint files saved every ``--save-every-n`` batches, + containing model ``state_dict`` and optimizer ``state_dict``. + To resume training from some checkpoint, say ``checkpoint-436000``, you can use: + + .. code-block:: bash + + $ ./pruned_transducer_stateless7_ctc_bs/train.py --start-batch 436000 + + - ``tensorboard/`` + + This folder contains tensorBoard logs. Training loss, validation loss, learning + rate, etc, are recorded in these logs. You can visualize them by: + + .. code-block:: bash + + $ cd pruned_transducer_stateless7_ctc_bs/exp/tensorboard + $ tensorboard dev upload --logdir . --description "Zipformer MMI training for LibriSpeech with icefall" + + It will print something like below: + + .. code-block:: + + TensorFlow installation not found - running with reduced feature set. + Upload started and will continue reading any new data as it's added to the logdir. + + To stop uploading, press Ctrl-C. + + New experiment created. View your TensorBoard at: https://tensorboard.dev/experiment/xyOZUKpEQm62HBIlUD4uPA/ + + Note there is a URL in the above output. Click it and you will see + tensorboard. + + .. hint:: + + If you don't have access to google, you can use the following command + to view the tensorboard log locally: + + .. code-block:: bash + + cd pruned_transducer_stateless7_ctc_bs/exp/tensorboard + tensorboard --logdir . --port 6008 + + It will print the following message: + + .. code-block:: + + Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all + TensorBoard 2.8.0 at http://localhost:6008/ (Press CTRL+C to quit) + + Now start your browser and go to ``_ to view the tensorboard + logs. + + + - ``log/log-train-xxxx`` + + It is the detailed training log in text format, same as the one + you saw printed to the console during training. + +Usage example +~~~~~~~~~~~~~ + +You can use the following command to start the training using 4 GPUs: + +.. code-block:: bash + + export CUDA_VISIBLE_DEVICES="0,1,2,3" + ./pruned_transducer_stateless7_ctc_bs/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --full-libri 1 \ + --exp-dir pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --use-fp16 1 + +Decoding +-------- + +The decoding part uses checkpoints saved by the training part, so you have +to run the training part first. + +.. hint:: + + There are two kinds of checkpoints: + + - (1) ``epoch-1.pt``, ``epoch-2.pt``, ..., which are saved at the end + of each epoch. You can pass ``--epoch`` to + ``pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py`` to use them. + + - (2) ``checkpoints-436000.pt``, ``epoch-438000.pt``, ..., which are saved + every ``--save-every-n`` batches. You can pass ``--iter`` to + ``pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py`` to use them. + + We suggest that you try both types of checkpoints and choose the one + that produces the lowest WERs. + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py --help + +shows the options for decoding. + +The following shows the example using ``epoch-*.pt``: + +.. code-block:: bash + + for m in greedy_search fast_beam_search modified_beam_search; do + ./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \ + --epoch 30 \ + --avg 13 \ + --exp-dir pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --decoding-method $m + done + +To test CTC branch, you can use the following command: + +.. code-block:: bash + + for m in ctc-decoding 1best; do + ./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \ + --epoch 30 \ + --avg 13 \ + --exp-dir pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --decoding-method $m + done + +Export models +------------- + +`pruned_transducer_stateless7_ctc_bs/export.py `_ supports exporting checkpoints from ``pruned_transducer_stateless7_ctc_bs/exp`` in the following ways. + +Export ``model.state_dict()`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Checkpoints saved by ``pruned_transducer_stateless7_ctc_bs/train.py`` also include +``optimizer.state_dict()``. It is useful for resuming training. But after training, +we are interested only in ``model.state_dict()``. You can use the following +command to extract ``model.state_dict()``. + +.. code-block:: bash + + ./pruned_transducer_stateless7_ctc_bs/export.py \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 13 \ + --jit 0 + +It will generate a file ``./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt``. + +.. hint:: + + To use the generated ``pretrained.pt`` for ``pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py``, + you can run: + + .. code-block:: bash + + cd pruned_transducer_stateless7_ctc_bs/exp + ln -s pretrained epoch-9999.pt + + And then pass ``--epoch 9999 --avg 1 --use-averaged-model 0`` to + ``./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py``. + +To use the exported model with ``./pruned_transducer_stateless7_ctc_bs/pretrained.py``, you +can run: + +.. code-block:: bash + + ./pruned_transducer_stateless7_ctc_bs/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +To test CTC branch using the exported model with ``./pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py``: + +.. code-block:: bash + + ./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \ + --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ + --bpe-model data/lang_bpe_500/bpe.model \ + --method ctc-decoding \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + +Export model using ``torch.jit.script()`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: bash + + ./pruned_transducer_stateless7_ctc_bs/export.py \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 13 \ + --jit 1 + +It will generate a file ``cpu_jit.pt`` in the given ``exp_dir``. You can later +load it by ``torch.jit.load("cpu_jit.pt")``. + +Note ``cpu`` in the name ``cpu_jit.pt`` means the parameters when loaded into Python +are on CPU. You can use ``to("cuda")`` to move them to a CUDA device. + +To use the generated files with ``./pruned_transducer_stateless7_ctc_bs/jit_pretrained.py``: + +.. code-block:: bash + + ./pruned_transducer_stateless7_ctc_bs/jit_pretrained.py \ + --nn-model-filename ./pruned_transducer_stateless7_ctc_bs/exp/cpu_jit.pt \ + /path/to/foo.wav \ + /path/to/bar.wav + +To test CTC branch using the generated files with ``./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py``: + +.. code-block:: bash + + ./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \ + --model-filename ./pruned_transducer_stateless7_ctc_bs/exp/cpu_jit.pt \ + --bpe-model data/lang_bpe_500/bpe.model \ + --method ctc-decoding \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + +Download pretrained models +-------------------------- + +If you don't want to train from scratch, you can download the pretrained models +by visiting the following links: + + - ``_ + + See ``_ + for the details of the above pretrained models diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_mmi.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_mmi.rst index db268dd02..a7b59a992 100644 --- a/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_mmi.rst +++ b/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_mmi.rst @@ -272,7 +272,7 @@ You will find the following files in that directory: Usage example ~~~~~~~~~~~~~ -You can use the following command to start the training using 8 GPUs: +You can use the following command to start the training using 4 GPUs: .. code-block:: bash @@ -382,7 +382,7 @@ can run: /path/to/bar.wav Export model using ``torch.jit.script()`` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. code-block:: bash From dfbcf606e7a7798bc5d9f73da82126914800be0e Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 27 Dec 2022 09:25:42 +0800 Subject: [PATCH 058/174] small fixes to prepare.sh (#789) --- egs/librispeech/ASR/prepare.sh | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh index 59bed8389..b1d207049 100755 --- a/egs/librispeech/ASR/prepare.sh +++ b/egs/librispeech/ASR/prepare.sh @@ -123,10 +123,12 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then touch data/fbank/.librispeech.done fi - cat <(gunzip -c data/fbank/librispeech_cuts_train-clean-100.jsonl.gz) \ - <(gunzip -c data/fbank/librispeech_cuts_train-clean-360.jsonl.gz) \ - <(gunzip -c data/fbank/librispeech_cuts_train-other-500.jsonl.gz) | \ - shuf | gzip -c > data/fbank/librispeech_cuts_train-all-shuf.jsonl.gz + if [ ! -f data/fbank/librispeech_cuts_train-all-shuf.jsonl.gz ]; then + cat <(gunzip -c data/fbank/librispeech_cuts_train-clean-100.jsonl.gz) \ + <(gunzip -c data/fbank/librispeech_cuts_train-clean-360.jsonl.gz) \ + <(gunzip -c data/fbank/librispeech_cuts_train-other-500.jsonl.gz) | \ + shuf | gzip -c > data/fbank/librispeech_cuts_train-all-shuf.jsonl.gz + fi if [ ! -e data/fbank/.librispeech-validated.done ]; then log "Validating data/fbank for LibriSpeech" @@ -244,7 +246,7 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then fi if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then - log "Stage 7: Prepare bigram P" + log "Stage 7: Prepare bigram token-level P for MMI training" for vocab_size in ${vocab_sizes[@]}; do lang_dir=data/lang_bpe_${vocab_size} From 88b7895adf03424497619b54cdd9a230e9216b5c Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 27 Dec 2022 13:59:55 +0800 Subject: [PATCH 059/174] fix librispeech.py in multi-dataset setup (#791) --- .../ASR/pruned_transducer_stateless3/librispeech.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/librispeech.py b/egs/librispeech/ASR/pruned_transducer_stateless3/librispeech.py index 6dba8e9fe..9f2cb6225 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/librispeech.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/librispeech.py @@ -72,3 +72,12 @@ class LibriSpeech: f = self.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz" logging.info(f"About to get dev-other cuts from {f}") return load_manifest_lazy(f) + + def train_all_shuf_cuts(self) -> CutSet: + logging.info( + "About to get the shuffled train-clean-100, \ + train-clean-360 and train-other-500 cuts" + ) + return load_manifest_lazy( + self.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz" + ) From a24a1cbfa9ffd03a629b988ae19e5d35248e72ec Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Tue, 27 Dec 2022 15:06:53 +0800 Subject: [PATCH 060/174] small fix for zipformer_ctc_blankskip.rst (#792) Co-authored-by: yifanyang --- .../Non-streaming-ASR/librispeech/zipformer_ctc_blankskip.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_ctc_blankskip.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_ctc_blankskip.rst index d85a3c67f..56a420605 100644 --- a/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_ctc_blankskip.rst +++ b/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_ctc_blankskip.rst @@ -228,7 +228,7 @@ You will find the following files in that directory: .. code-block:: bash $ cd pruned_transducer_stateless7_ctc_bs/exp/tensorboard - $ tensorboard dev upload --logdir . --description "Zipformer MMI training for LibriSpeech with icefall" + $ tensorboard dev upload --logdir . --description "Zipformer-CTC co-training using blank skip for LibriSpeech with icefall" It will print something like below: From 05dfd5e630d525dcc8828feba4d9daf6624af319 Mon Sep 17 00:00:00 2001 From: marcoyang1998 <45973641+marcoyang1998@users.noreply.github.com> Date: Tue, 27 Dec 2022 15:26:11 +0800 Subject: [PATCH 061/174] Fix distillation with HuBERT (#790) * update vq huggingface url * remove hard lhotse version requirement * resolve ID mismatch * small fixes * Update egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py Co-authored-by: Fangjun Kuang * update version check Co-authored-by: Fangjun Kuang --- .../ASR/distillation_with_hubert.sh | 12 +++++-- .../pruned_transducer_stateless6/vq_utils.py | 34 ++++++++++++++++--- 2 files changed, 39 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/distillation_with_hubert.sh b/egs/librispeech/ASR/distillation_with_hubert.sh index 2a69d3921..d5d3008aa 100755 --- a/egs/librispeech/ASR/distillation_with_hubert.sh +++ b/egs/librispeech/ASR/distillation_with_hubert.sh @@ -35,7 +35,7 @@ stop_stage=4 # export CUDA_VISIBLE_DEVICES="0" # # Suppose GPU 2,3,4,5 are available. -export CUDA_VISIBLE_DEVICES="0,1,2,3" +# export CUDA_VISIBLE_DEVICES="0,1,2,3" exp_dir=./pruned_transducer_stateless6/exp mkdir -p $exp_dir @@ -49,7 +49,7 @@ full_libri=False # "True" -> stage 0 and stage 1 would be skipped, # and directly download the extracted codebook indexes for distillation # "False" -> start from scratch -use_extracted_codebook=False +use_extracted_codebook=True # teacher_model_id can be one of # "hubert_xtralarge_ll60k_finetune_ls960" -> fine-tuned model, it is the one we currently use. @@ -155,8 +155,14 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then fi log "Downloading extracted codebook indexes to $codebook_download_dir" # Make sure you have git-lfs installed (https://git-lfs.github.com) + # The codebook indexes are generated using lhotse 1.11.0, to avoid + # potential issues, we recommend you to use lhotse version >= 1.11.0 + lhotse_version=$(python3 -c "import lhotse; from packaging import version; print(version.parse(lhotse.version.__version__)>=version.parse('1.11.0'))") + if [ "$lhotse_version" == "False" ]; then + log "Expecting lhotse >= 1.11.0. This may lead to potential ID mismatch." + fi git lfs install - git clone https://huggingface.co/Zengwei/pruned_transducer_stateless6_hubert_xtralarge_ll60k_finetune_ls960 $codebook_download_dir + git clone https://huggingface.co/marcoyang/pruned_transducer_stateless6_hubert_xtralarge_ll60k_finetune_ls960 $codebook_download_dir mkdir -p data/vq_fbank mv $codebook_download_dir/*.jsonl.gz data/vq_fbank/ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py index 97a83b974..bf072d865 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py @@ -244,10 +244,36 @@ class CodebookIndexExtractor: ) cuts_vq = load_manifest(vq_manifest_path) cuts_ori = load_manifest(ori_manifest_path) - cuts_vq = cuts_vq.sort_like(cuts_ori) - for cut_idx, (cut_vq, cut_ori) in enumerate(zip(cuts_vq, cuts_ori)): - assert cut_vq.id == cut_ori.id - cut_ori.codebook_indexes = cut_vq.codebook_indexes + assert len(cuts_vq) == len(cuts_ori), "Cuts should have the same length!" + + if set(cuts_vq.ids) == set(cuts_ori.ids): + # IDs match exactly + cuts_vq = cuts_vq.sort_like(cuts_ori) + for cut_idx, (cut_vq, cut_ori) in enumerate(zip(cuts_vq, cuts_ori)): + assert cut_vq.id == cut_ori.id, (cut_vq.id, cut_ori.id) + cut_ori.codebook_indexes = cut_vq.codebook_indexes + else: + # in case of ID mismatch, remap them + # get the mapping between audio and cut ID + logging + ori_id_map = {} + for id in cuts_ori.ids: + # some text normalization + if "sp" in id: + clean_id = "-".join(id.split("-")[:3]) + "_" + id.split("_")[-1] + else: + clean_id = "-".join(id.split("-")[:3]) + ori_id_map[clean_id] = id + + for id in cuts_vq.ids: + if "sp" in id: + clean_id = "-".join(id.split("-")[:3]) + "_" + id.split("_")[-1] + else: + clean_id = "-".join(id.split("-")[:3]) + assert clean_id in ori_id_map, clean_id + cuts_ori[ori_id_map[clean_id]].codebook_indexes = cuts_vq[ + id + ].codebook_indexes CutSet.from_cuts(cuts_ori).to_jsonl(dst_vq_manifest_path) logging.info(f"Processed {subset}.") From 3c54333b06a87bb2efc665c66d3c25370033d182 Mon Sep 17 00:00:00 2001 From: Yuekai Zhang Date: Wed, 28 Dec 2022 11:20:38 +0800 Subject: [PATCH 062/174] fix bug (#796) --- .../pruned_transducer_stateless5/conformer.py | 38 ++++++++++++------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py index 9bb55d07a..23a877b2f 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py @@ -966,20 +966,32 @@ class RelPositionMultiheadAttention(nn.Module): (batch_size, num_heads, time1, n) = x.shape time2 = time1 + left_context - assert ( - n == left_context + 2 * time1 - 1 - ), f"{n} == {left_context} + 2 * {time1} - 1" + if not torch.jit.is_tracing(): + assert ( + n == left_context + 2 * time1 - 1 + ), f"{n} == {left_context} + 2 * {time1} - 1" - # Note: TorchScript requires explicit arg for stride() - batch_stride = x.stride(0) - head_stride = x.stride(1) - time1_stride = x.stride(2) - n_stride = x.stride(3) - return x.as_strided( - (batch_size, num_heads, time1, time2), - (batch_stride, head_stride, time1_stride - n_stride, n_stride), - storage_offset=n_stride * (time1 - 1), - ) + if torch.jit.is_tracing(): + rows = torch.arange(start=time1 - 1, end=-1, step=-1) + cols = torch.arange(time2) + rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) + indexes = rows + cols + + x = x.reshape(-1, n) + x = torch.gather(x, dim=1, index=indexes) + x = x.reshape(batch_size, num_heads, time1, time2) + return x + else: + # Note: TorchScript requires explicit arg for stride() + batch_stride = x.stride(0) + head_stride = x.stride(1) + time1_stride = x.stride(2) + n_stride = x.stride(3) + return x.as_strided( + (batch_size, num_heads, time1, time2), + (batch_stride, head_stride, time1_stride - n_stride, n_stride), + storage_offset=n_stride * (time1 - 1), + ) def multi_head_attention_forward( self, From 1f0408b1031dccfcb13ae3641b576434aec4f983 Mon Sep 17 00:00:00 2001 From: marcoyang1998 <45973641+marcoyang1998@users.noreply.github.com> Date: Thu, 29 Dec 2022 10:53:36 +0800 Subject: [PATCH 063/174] Support Transformer LM (#750) * support transformer LM * show number of parameters during training * update docstring * testing files for ppl calculation * add lm wrampper for rnn and transformer LM * apply lm wrapper in lm shallow fusion * small updates * update decode.py to support LM fusion and LODR * add export.py * update CI and workflow * update decoding results * fix CI * remove transformer LM from CI test --- ...h-lstm-transducer-stateless2-2022-09-03.sh | 24 +- ...-lstm-transducer-stateless2-2022-09-03.yml | 11 +- egs/librispeech/ASR/RESULTS.md | 66 +- .../ASR/lstm_transducer_stateless2/decode.py | 195 +++--- .../beam_search.py | 584 +++++++++-------- .../pruned_transducer_stateless3/decode.py | 186 +++--- .../pruned_transducer_stateless5/decode.py | 254 +++++--- .../pruned_transducer_stateless7/decode.py | 178 ++++- icefall/__init__.py | 2 + icefall/lm_wrapper.py | 254 ++++++++ icefall/rnn_lm/model.py | 21 +- icefall/rnn_lm/train.py | 3 + icefall/transformer_lm/attention.py | 510 +++++++++++++++ icefall/transformer_lm/compute_perplexity.py | 195 ++++++ icefall/transformer_lm/dataset.py | 1 + icefall/transformer_lm/encoder.py | 329 ++++++++++ icefall/transformer_lm/export.py | 186 ++++++ icefall/transformer_lm/model.py | 115 ++++ icefall/transformer_lm/scaling.py | 1 + icefall/transformer_lm/train.py | 609 ++++++++++++++++++ 20 files changed, 3086 insertions(+), 638 deletions(-) create mode 100644 icefall/lm_wrapper.py create mode 100644 icefall/transformer_lm/attention.py create mode 100644 icefall/transformer_lm/compute_perplexity.py create mode 120000 icefall/transformer_lm/dataset.py create mode 100644 icefall/transformer_lm/encoder.py create mode 100644 icefall/transformer_lm/export.py create mode 100644 icefall/transformer_lm/model.py create mode 120000 icefall/transformer_lm/scaling.py create mode 100644 icefall/transformer_lm/train.py diff --git a/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh b/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh index ac5b15979..9b883f889 100755 --- a/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh +++ b/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh @@ -193,7 +193,7 @@ if [[ x"${GITHUB_EVENT_LABEL_NAME}" == x"shallow-fusion" ]]; then ls -lh data ls -lh lstm_transducer_stateless2/exp - log "Decoding test-clean and test-other" + log "Decoding test-clean and test-other with RNN LM" ./lstm_transducer_stateless2/decode.py \ --use-averaged-model 0 \ @@ -201,12 +201,14 @@ if [[ x"${GITHUB_EVENT_LABEL_NAME}" == x"shallow-fusion" ]]; then --avg 1 \ --exp-dir lstm_transducer_stateless2/exp \ --max-duration 600 \ - --decoding-method modified_beam_search_rnnlm_shallow_fusion \ + --decoding-method modified_beam_search_lm_shallow_fusion \ --beam 4 \ - --rnn-lm-scale 0.3 \ - --rnn-lm-exp-dir $lm_repo/exp \ - --rnn-lm-epoch 88 \ - --rnn-lm-avg 1 \ + --use-shallow-fusion 1 \ + --lm-type rnn \ + --lm-exp-dir $lm_repo/exp \ + --lm-epoch 88 \ + --lm-avg 1 \ + --lm-scale 0.3 \ --rnn-lm-num-layers 3 \ --rnn-lm-tie-weights 1 fi @@ -245,11 +247,13 @@ if [[ x"${GITHUB_EVENT_LABEL_NAME}" == x"LODR" ]]; then --avg 1 \ --exp-dir lstm_transducer_stateless2/exp \ --max-duration 600 \ - --decoding-method modified_beam_search_rnnlm_LODR \ + --decoding-method modified_beam_search_LODR \ --beam 4 \ - --rnn-lm-scale 0.3 \ - --rnn-lm-exp-dir $lm_repo/exp \ - --rnn-lm-epoch 88 \ + --use-shallow-fusion 1 \ + --lm-type rnn \ + --lm-exp-dir $lm_repo/exp \ + --lm-scale 0.4 \ + --lm-epoch 88 \ --rnn-lm-avg 1 \ --rnn-lm-num-layers 3 \ --rnn-lm-tie-weights 1 \ diff --git a/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml b/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml index f5ee09e16..3752f67e3 100644 --- a/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml +++ b/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml @@ -139,9 +139,10 @@ jobs: cd egs/librispeech/ASR tree lstm_transducer_stateless2/exp cd lstm_transducer_stateless2/exp - echo "===modified_beam_search_rnnlm_shallow_fusion===" - find modified_beam_search_rnnlm_shallow_fusion -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find modified_beam_search_rnnlm_shallow_fusion -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + echo "===modified_beam_search_lm_shallow_fusion===" + echo "===Using RNNLM===" + find modified_beam_search_lm_shallow_fusion -name "log-*rnn*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find modified_beam_search_lm_shallow_fusion -name "log-*rnn*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - name: Display decoding results for lstm_transducer_stateless2 if: github.event.label.name == 'LODR' @@ -151,8 +152,8 @@ jobs: tree lstm_transducer_stateless2/exp cd lstm_transducer_stateless2/exp echo "===modified_beam_search_rnnlm_LODR===" - find modified_beam_search_rnnlm_LODR -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find modified_beam_search_rnnlm_LODR -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + find modified_beam_search_LODR -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find modified_beam_search_LODR -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - name: Upload decoding results for lstm_transducer_stateless2 uses: actions/upload-artifact@v2 diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index 092f77814..007d34a62 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -320,6 +320,10 @@ Number of model parameters: 70369391, i.e., 70.37 M |----------------------|------------|-------------|----------------------------------------| | greedy search | 2.17 | 5.23 | --epoch 39 --avg 6 --max-duration 600 | | modified beam search | 2.15 | 5.20 | --epoch 39 --avg 6 --max-duration 600 | +| modified beam search + RNNLM shallow fusion | 1.99 | 4.73 | --epoch 39 --avg 6 --max-duration 600 | +| modified beam search + TransformerLM shallow fusion | 1.94 | 4.73 | --epoch 39 --avg 6 --max-duration 600 | +| modified beam search + RNNLM + LODR | 1.91 | 4.57 | --epoch 39 --avg 6 --max-duration 600 | +| modified beam search + TransformerLM + LODR | 1.91 | 4.51 | --epoch 39 --avg 6 --max-duration 600 | | fast beam search | 2.15 | 5.22 | --epoch 39 --avg 6 --max-duration 600 | The training commands are: @@ -458,7 +462,9 @@ The WERs are: | greedy search (max sym per frame 1) | 2.78 | 7.36 | --iter 468000 --avg 16 | | modified_beam_search | 2.73 | 7.15 | --iter 468000 --avg 16 | | modified_beam_search + RNNLM shallow fusion | 2.42 | 6.46 | --iter 468000 --avg 16 | -| modified_beam_search + RNNLM shallow fusion | 2.28 | 5.94 | --iter 468000 --avg 16 | +| modified_beam_search + TransformerLM shallow fusion | 2.37 | 6.48 | --iter 468000 --avg 16 | +| modified_beam_search + RNNLM + LODR | 2.24 | 5.89 | --iter 468000 --avg 16 | +| modified_beam_search + TransformerLM + LODR | 2.19 | 5.90 | --iter 468000 --avg 16 | | fast_beam_search | 2.76 | 7.31 | --iter 468000 --avg 16 | | greedy search (max sym per frame 1) | 2.77 | 7.35 | --iter 472000 --avg 18 | | modified_beam_search | 2.75 | 7.08 | --iter 472000 --avg 18 | @@ -513,9 +519,12 @@ for m in greedy_search fast_beam_search modified_beam_search; do done ``` -To decode with RNNLM shallow fusion, use the following decoding command. A well-trained RNNLM -can be found here: +You may also decode using shallow fusion with external neural network LM. To do so you need to +download a well-trained NN LM: +RNN LM: +Transformer LM: +```bash for iter in 472000; do for avg in 8 10 12 14 16 18; do ./lstm_transducer_stateless2/decode.py \ @@ -523,23 +532,24 @@ for iter in 472000; do --avg $avg \ --exp-dir ./lstm_transducer_stateless2/exp \ --max-duration 600 \ - --decoding-method modified_beam_search_rnnlm_shallow_fusion \ - --beam 4 \ - --rnn-lm-scale 0.3 \ - --rnn-lm-exp-dir /path/to/RNNLM \ - --rnn-lm-epoch 99 \ - --rnn-lm-avg 1 \ - --rnn-lm-num-layers 3 \ - --rnn-lm-tie-weights 1 + --decoding-method modified_beam_search_lm_shallow_fusion \ + --use-shallow-fusion 1 \ + --lm-type rnn \ + --lm-exp-dir /ceph-data4/yangxiaoyu/pretrained_models/LM/icefall-librispeech-rnn-lm/exp \ + --lm-epoch 99 \ + --lm-scale $lm_scale \ + --lm-avg 1 \ done done +``` -You may also decode using LODR + RNNLM shallow fusion. This decoding method is proposed in . +You may also decode using LODR + LM shallow fusion. This decoding method is proposed in . It subtracts the internal language model score during shallow fusion, which is approximated by a bi-gram model. The bi-gram can be generated by `generate-lm.sh`, or you may download it from . The decoding command is as follows: +```bash for iter in 472000; do for avg in 8 10 12 14 16 18; do ./lstm_transducer_stateless2/decode.py \ @@ -547,18 +557,22 @@ for iter in 472000; do --avg $avg \ --exp-dir ./lstm_transducer_stateless2/exp \ --max-duration 600 \ - --decoding-method modified_beam_search_rnnlm_LODR \ + --decoding-method modified_beam_search_LODR \ --beam 4 \ - --rnn-lm-scale 0.4 \ - --rnn-lm-exp-dir /path/to/RNNLM \ - --rnn-lm-epoch 99 \ - --rnn-lm-avg 1 \ - --rnn-lm-num-layers 3 \ - --rnn-lm-tie-weights 1 \ - --token-ngram 2 \ + --max-contexts 4 \ + --use-shallow-fusion 1 \ + --lm-type rnn \ + --lm-exp-dir /ceph-data4/yangxiaoyu/pretrained_models/LM/icefall-librispeech-rnn-lm/exp \ + --lm-epoch 99 \ + --lm-scale 0.4 \ + --lm-avg 1 \ + --tokens-ngram 2 \ --ngram-lm-scale -0.16 done done +``` +Note that you can also set `--lm-type transformer` to use transformer LM during LODR. But it will be slower +because it has not been optimized. The pre-trained transformer LM is available at Pretrained models, training logs, decoding logs, and decoding results are available at @@ -1717,6 +1731,9 @@ layers (24 v.s 12) but a narrower model (1536 feedforward dim and 384 encoder di | greedy search (max sym per frame 1) | 2.54 | 5.72 | --epoch 30 --avg 10 --max-duration 600 | | modified beam search | 2.47 | 5.71 | --epoch 30 --avg 10 --max-duration 600 | | modified beam search + RNNLM shallow fusion | 2.27 | 5.24 | --epoch 30 --avg 10 --max-duration 600 | +| modified beam search + RNNLM + LODR | 2.23 | 5.17 | --epoch 30 --avg 10 --max-duration 600 | +| modified beam search + TransformerLM shallow fusion | 2.27 | 5.26 | --epoch 30 --avg 10 --max-duration 600 | +| modified beam search + TransformerLM + LODR | 2.22 | 5.11 | --epoch 30 --avg 10 --max-duration 600 | | fast beam search | 2.5 | 5.72 | --epoch 30 --avg 10 --max-duration 600 | ```bash @@ -2080,7 +2097,8 @@ subset so that the gigaspeech dataloader never exhausts. | greedy search (max sym per frame 1) | 2.03 | 4.70 | --iter 1224000 --avg 14 --max-duration 600 | | modified beam search | 2.00 | 4.63 | --iter 1224000 --avg 14 --max-duration 600 | | modified beam search + rnnlm shallow fusion | 1.94 | 4.2 | --iter 1224000 --avg 14 --max-duration 600 | -| modified beam search + LODR | 1.83 | 4.03 | --iter 1224000 --avg 14 --max-duration 600 | +| modified beam search + rnnlm + LODR | 1.77 | 3.99 | --iter 1224000 --avg 14 --max-duration 600 | +| modified beam search + TransformerLM + LODR | 1.75 | 3.94 | --iter 1224000 --avg 14 --max-duration 600 | | fast beam search | 2.10 | 4.68 | --iter 1224000 --avg 14 --max-duration 600 | The training commands are: @@ -2126,8 +2144,10 @@ for iter in 1224000; do done done ``` -You may also decode using shallow fusion with external RNNLM. To do so you need to -download a well-trained RNNLM from this link +You may also decode using shallow fusion with external neural network LM. To do so you need to +download a well-trained NN LM: +RNN LM: +Transformer LM: ```bash rnn_lm_scale=0.3 diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py index fa5bf1825..78be9c01f 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py @@ -93,36 +93,37 @@ Usage: --max-contexts 8 \ --max-states 64 -(8) modified beam search (with RNNLM shallow fusion) +(8) modified beam search (with LM shallow fusion) ./lstm_transducer_stateless2/decode.py \ --epoch 35 \ --avg 15 \ --exp-dir ./lstm_transducer_stateless2/exp \ --max-duration 600 \ - --decoding-method modified_beam_search_rnnlm_shallow_fusion \ + --decoding-method modified_beam_search_lm_shallow_fusion \ --beam 4 \ - --rnn-lm-scale 0.3 \ - --rnn-lm-exp-dir /path/to/RNNLM \ + --lm-type rnn \ + --lm-scale 0.3 \ + --lm-exp-dir /path/to/LM \ --rnn-lm-epoch 99 \ --rnn-lm-avg 1 \ --rnn-lm-num-layers 3 \ --rnn-lm-tie-weights 1 -(9) modified beam search with RNNLM shallow fusion + LODR +(9) modified beam search with LM shallow fusion + LODR ./lstm_transducer_stateless2/decode.py \ --epoch 35 \ --avg 15 \ --max-duration 600 \ --exp-dir ./lstm_transducer_stateless2/exp \ - --decoding-method modified_beam_search_rnnlm_LODR \ + --decoding-method modified_beam_search_LODR \ --beam 4 \ - --max-contexts 4 \ - --rnn-lm-scale 0.4 \ - --rnn-lm-exp-dir /path/to/RNNLM/exp \ + --lm-type rnn \ + --lm-scale 0.4 \ + --lm-exp-dir /path/to/LM \ --rnn-lm-epoch 99 \ --rnn-lm-avg 1 \ --rnn-lm-num-layers 3 \ - --rnn-lm-tie-weights 1 \ + --rnn-lm-tie-weights 1 --tokens-ngram 2 \ --ngram-lm-scale -0.16 \ """ @@ -148,14 +149,14 @@ from beam_search import ( greedy_search, greedy_search_batch, modified_beam_search, + modified_beam_search_lm_shallow_fusion, + modified_beam_search_LODR, modified_beam_search_ngram_rescoring, - modified_beam_search_rnnlm_LODR, - modified_beam_search_rnnlm_shallow_fusion, ) from librispeech import LibriSpeech from train import add_model_arguments, get_params, get_transducer_model -from icefall import NgramLm +from icefall import LmScorer, NgramLm from icefall.checkpoint import ( average_checkpoints, average_checkpoints_with_averaged_model, @@ -163,7 +164,6 @@ from icefall.checkpoint import ( load_checkpoint, ) from icefall.lexicon import Lexicon -from icefall.rnn_lm.model import RnnLmModel from icefall.utils import ( AttributeDict, setup_logger, @@ -253,8 +253,8 @@ def get_parser(): - fast_beam_search_nbest_oracle - fast_beam_search_nbest_LG - modified_beam_search_ngram_rescoring - - modified_beam_search_rnnlm_shallow_fusion - - modified_beam_search_rnnlm_LODR + - modified_beam_search_lm_shallow_fusion + - modified_beam_search_LODR If you use fast_beam_search_nbest_LG, you have to specify `--lang-dir`, which should contain `LG.pt`. """, @@ -344,67 +344,28 @@ def get_parser(): ) parser.add_argument( - "--rnn-lm-scale", - type=float, - default=0.0, - help="""Used only when --method is modified-beam-search_rnnlm_shallow_fusion. - It specifies the path to RNN LM exp dir. - """, - ) - - parser.add_argument( - "--rnn-lm-exp-dir", - type=str, - default="rnn_lm/exp", - help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion. - It specifies the path to RNN LM exp dir. - """, - ) - - parser.add_argument( - "--rnn-lm-epoch", - type=int, - default=7, - help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion. - It specifies the checkpoint to use. - """, - ) - - parser.add_argument( - "--rnn-lm-avg", - type=int, - default=2, - help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion. - It specifies the number of checkpoints to average. - """, - ) - - parser.add_argument( - "--rnn-lm-embedding-dim", - type=int, - default=2048, - help="Embedding dim of the model", - ) - - parser.add_argument( - "--rnn-lm-hidden-dim", - type=int, - default=2048, - help="Hidden dim of the model", - ) - - parser.add_argument( - "--rnn-lm-num-layers", - type=int, - default=4, - help="Number of RNN layers the model", - ) - parser.add_argument( - "--rnn-lm-tie-weights", + "--use-shallow-fusion", type=str2bool, default=False, - help="""True to share the weights between the input embedding layer and the - last output linear layer + help="""Use neural network LM for shallow fusion. + If you want to use LODR, you will also need to set this to true + """, + ) + + parser.add_argument( + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. """, ) @@ -440,8 +401,7 @@ def decode_one_batch( decoding_graph: Optional[k2.Fsa] = None, ngram_lm: Optional[NgramLm] = None, ngram_lm_scale: float = 1.0, - rnnlm: Optional[RnnLmModel] = None, - rnnlm_scale: float = 1.0, + LM: Optional[LmScorer] = None, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -470,6 +430,9 @@ def decode_one_batch( The decoding graph. Can be either a `k2.trivial_graph` or LG, Used only when --decoding_method is fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural net LM for shallow fusion. Only used when `--use-shallow-fusion` + set to true. Returns: Return the decoding result. See above description for the format of the returned dict. @@ -581,20 +544,19 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion": - hyp_tokens = modified_beam_search_rnnlm_shallow_fusion( + elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": + hyp_tokens = modified_beam_search_lm_shallow_fusion( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, beam=params.beam_size, sp=sp, - rnnlm=rnnlm, - rnnlm_scale=rnnlm_scale, + LM=LM, ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search_rnnlm_LODR": - hyp_tokens = modified_beam_search_rnnlm_LODR( + elif params.decoding_method == "modified_beam_search_LODR": + hyp_tokens = modified_beam_search_LODR( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, @@ -602,8 +564,7 @@ def decode_one_batch( sp=sp, LODR_lm=ngram_lm, LODR_lm_scale=ngram_lm_scale, - rnnlm=rnnlm, - rnnlm_scale=rnnlm_scale, + LM=LM, ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) @@ -658,8 +619,7 @@ def decode_dataset( decoding_graph: Optional[k2.Fsa] = None, ngram_lm: Optional[NgramLm] = None, ngram_lm_scale: float = 1.0, - rnnlm: Optional[RnnLmModel] = None, - rnnlm_scale: float = 1.0, + LM: Optional[LmScorer] = None, ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. @@ -678,6 +638,8 @@ def decode_dataset( The decoding graph. Can be either a `k2.trivial_graph` or LG, Used only when --decoding_method is fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural network LM, used during shallow fusion Returns: Return a dict, whose key may be "greedy_search" if greedy search is used, or it may be "beam_7" if beam size of 7 is used. @@ -711,8 +673,7 @@ def decode_dataset( batch=batch, ngram_lm=ngram_lm, ngram_lm_scale=ngram_lm_scale, - rnnlm=rnnlm, - rnnlm_scale=rnnlm_scale, + LM=LM, ) for name, hyps in hyps_dict.items(): @@ -730,6 +691,7 @@ def decode_dataset( batch_str = f"{batch_idx}/{num_batches}" logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -781,6 +743,7 @@ def save_results( def main(): parser = get_parser() AsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) @@ -795,9 +758,9 @@ def main(): "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", "modified_beam_search", - "modified_beam_search_rnnlm_LODR", + "modified_beam_search_LODR", + "modified_beam_search_lm_shallow_fusion", "modified_beam_search_ngram_rescoring", - "modified_beam_search_rnnlm_shallow_fusion", ) params.res_dir = params.exp_dir / params.decoding_method @@ -820,12 +783,18 @@ def main(): else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" - if "rnnlm" in params.decoding_method: - params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}" + if "ngram" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + if params.use_shallow_fusion: + if params.lm_type == "rnn": + params.suffix += f"-rnnlm-lm-scale-{params.lm_scale}" + elif params.lm_type == "transformer": + params.suffix += f"-transformer-lm-scale-{params.lm_scale}" - if "LODR" in params.decoding_method: - params.suffix += "-LODR" + if "LODR" in params.decoding_method: + params.suffix += ( + f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" + ) if params.use_averaged_model: params.suffix += "-use-averaged-model" @@ -954,28 +923,19 @@ def main(): ngram_lm = None ngram_lm_scale = None - # only load rnnlm if used - if "rnnlm" in params.decoding_method: - rnn_lm_scale = params.rnn_lm_scale - - rnn_lm_model = RnnLmModel( - vocab_size=params.vocab_size, - embedding_dim=params.rnn_lm_embedding_dim, - hidden_dim=params.rnn_lm_hidden_dim, - num_layers=params.rnn_lm_num_layers, - tie_weights=params.rnn_lm_tie_weights, + # only load the neural network LM if doing shallow fusion + if params.use_shallow_fusion: + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, ) - assert params.rnn_lm_avg == 1 + LM.to(device) + LM.eval() - load_checkpoint( - f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt", - rnn_lm_model, - ) - rnn_lm_model.to(device) - rnn_lm_model.eval() else: - rnn_lm_model = None - rnn_lm_scale = 0.0 + LM = None if "fast_beam_search" in params.decoding_method: if params.decoding_method == "fast_beam_search_nbest_LG": @@ -1003,7 +963,9 @@ def main(): librispeech = LibriSpeech(manifest_dir=args.manifest_dir) test_clean_cuts = librispeech.test_clean_cuts() + # test_clean_cuts = test_clean_cuts.subset(first=500) test_other_cuts = librispeech.test_other_cuts() + # test_other_cuts = test_other_cuts.subset(first=500) test_clean_dl = asr_datamodule.test_dataloaders(test_clean_cuts) test_other_dl = asr_datamodule.test_dataloaders(test_other_cuts) @@ -1021,8 +983,7 @@ def main(): decoding_graph=decoding_graph, ngram_lm=ngram_lm, ngram_lm_scale=ngram_lm_scale, - rnnlm=rnn_lm_model, - rnnlm_scale=rnn_lm_scale, + LM=LM, ) save_results( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index b324cc9b7..7388af389 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -26,7 +26,9 @@ from model import Transducer from icefall import NgramLm, NgramLmStateCost from icefall.decode import Nbest, one_best_decoding +from icefall.lm_wrapper import LmScorer from icefall.rnn_lm.model import RnnLmModel +from icefall.transformer_lm.model import TransformerLM from icefall.utils import ( DecodingResults, add_eos, @@ -1846,254 +1848,14 @@ def modified_beam_search_ngram_rescoring( return ans -def modified_beam_search_rnnlm_shallow_fusion( - model: Transducer, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - sp: spm.SentencePieceProcessor, - rnnlm: RnnLmModel, - rnnlm_scale: float, - beam: int = 4, - return_timestamps: bool = False, -) -> List[List[int]]: - """Modified_beam_search + RNNLM shallow fusion - - Args: - model (Transducer): - The transducer model - encoder_out (torch.Tensor): - Encoder output in (N,T,C) - encoder_out_lens (torch.Tensor): - A 1-D tensor of shape (N,), containing the number of - valid frames in encoder_out before padding. - sp: - Sentence piece generator. - rnnlm (RnnLmModel): - RNNLM - rnnlm_scale (float): - scale of RNNLM in shallow fusion - beam (int, optional): - Beam size. Defaults to 4. - - Returns: - Return a list-of-list of token IDs. ans[i] is the decoding results - for the i-th utterance. - """ - assert encoder_out.ndim == 3, encoder_out.shape - assert encoder_out.size(0) >= 1, encoder_out.size(0) - assert rnnlm is not None - lm_scale = rnnlm_scale - vocab_size = rnnlm.vocab_size - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - blank_id = model.decoder.blank_id - sos_id = sp.piece_to_id("") - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - device = next(model.parameters()).device - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - # get initial lm score and lm state by scoring the "sos" token - sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device) - init_score, init_states = rnnlm.score_token(sos_token) - - B = [HypothesisList() for _ in range(N)] - for i in range(N): - B[i].add( - Hypothesis( - ys=[blank_id] * context_size, - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - state=init_states, - lm_score=init_score.reshape(-1), - timestamp=[], - ) - ) - - rnnlm.clean_cache() - encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) - - offset = 0 - finalized_B = [] - for (t, batch_size) in enumerate(batch_size_list): - start = offset - end = offset + batch_size - current_encoder_out = encoder_out.data[start:end] # get batch - current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) - offset = end - - finalized_B = B[batch_size:] + finalized_B - B = B[:batch_size] - - hyps_shape = get_hyps_shape(B).to(device) - - A = [list(b) for b in B] - B = [HypothesisList() for _ in range(batch_size)] - - ys_log_probs = torch.cat( - [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] - ) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyps in A for hyp in hyps], - device=device, - dtype=torch.int64, - ) # (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - - current_encoder_out = torch.index_select( - current_encoder_out, - dim=0, - index=hyps_shape.row_ids(1).to(torch.int64), - ) # (num_hyps, 1, 1, encoder_out_dim) - - logits = model.joiner( - current_encoder_out, - decoder_out, - project_input=False, - ) # (num_hyps, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - - log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) - - log_probs.add_(ys_log_probs) - - vocab_size = log_probs.size(-1) - - log_probs = log_probs.reshape(-1) - - row_splits = hyps_shape.row_splits(1) * vocab_size - log_probs_shape = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=log_probs.numel() - ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) - """ - for all hyps with a non-blank new token, score this token. - It is a little confusing here because this for-loop - looks very similar to the one below. Here, we go through all - top-k tokens and only add the non-blanks ones to the token_list. - The RNNLM will score those tokens given the LM states. Note that - the variable `scores` is the LM score after seeing the new - non-blank token. - """ - token_list = [] - hs = [] - cs = [] - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - new_token = topk_token_indexes[k] - if new_token not in (blank_id, unk_id): - assert new_token != 0, new_token - token_list.append([new_token]) - # store the LSTM states - hs.append(hyp.state[0]) - cs.append(hyp.state[1]) - - # forward RNNLM to get new states and scores - if len(token_list) != 0: - tokens_to_score = ( - torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1) - ) - - hs = torch.cat(hs, dim=1).to(device) - cs = torch.cat(cs, dim=1).to(device) - scores, lm_states = rnnlm.score_token(tokens_to_score, (hs, cs)) - - count = 0 # index, used to locate score and lm states - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - ys = hyp.ys[:] - - lm_score = hyp.lm_score - state = hyp.state - - hyp_log_prob = topk_log_probs[k] # get score of current hyp - new_token = topk_token_indexes[k] - new_timestamp = hyp.timestamp[:] - if new_token not in (blank_id, unk_id): - - ys.append(new_token) - new_timestamp.append(t) - hyp_log_prob += lm_score[new_token] * lm_scale # add the lm score - - lm_score = scores[count] - state = ( - lm_states[0][:, count, :].unsqueeze(1), - lm_states[1][:, count, :].unsqueeze(1), - ) - count += 1 - - new_hyp = Hypothesis( - ys=ys, - log_prob=hyp_log_prob, - state=state, - lm_score=lm_score, - timestamp=new_timestamp, - ) - B[i].add(new_hyp) - - B = B + finalized_B - best_hyps = [b.get_most_probable(length_norm=True) for b in B] - - sorted_ans = [h.ys[context_size:] for h in best_hyps] - sorted_timestamps = [h.timestamp for h in best_hyps] - ans = [] - ans_timestamps = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - ans_timestamps.append(sorted_timestamps[unsorted_indices[i]]) - - if not return_timestamps: - return ans - else: - return DecodingResults( - tokens=ans, - timestamps=ans_timestamps, - ) - - -def modified_beam_search_rnnlm_LODR( +def modified_beam_search_LODR( model: Transducer, encoder_out: torch.Tensor, encoder_out_lens: torch.Tensor, sp: spm.SentencePieceProcessor, LODR_lm: NgramLm, LODR_lm_scale: float, - rnnlm: RnnLmModel, - rnnlm_scale: float, + LM: LmScorer, beam: int = 4, ) -> List[List[int]]: """This function implements LODR (https://arxiv.org/abs/2203.16776) with @@ -2113,13 +1875,11 @@ def modified_beam_search_rnnlm_LODR( sp: Sentence piece generator. LODR_lm: - A low order n-gram LM + A low order n-gram LM, whose score will be subtracted during shallow fusion LODR_lm_scale: The scale of the LODR_lm - rnnlm (RnnLmModel): - RNNLM, the external language model - rnnlm_scale (float): - scale of RNNLM in shallow fusion + LM: + A neural net LM, e.g an RNNLM or transformer LM beam (int, optional): Beam size. Defaults to 4. @@ -2130,9 +1890,8 @@ def modified_beam_search_rnnlm_LODR( """ assert encoder_out.ndim == 3, encoder_out.shape assert encoder_out.size(0) >= 1, encoder_out.size(0) - assert rnnlm is not None - lm_scale = rnnlm_scale - vocab_size = rnnlm.vocab_size + assert LM is not None + lm_scale = LM.lm_scale packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( input=encoder_out, @@ -2154,7 +1913,8 @@ def modified_beam_search_rnnlm_LODR( # get initial lm score and lm state by scoring the "sos" token sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device) - init_score, init_states = rnnlm.score_token(sos_token) + lens = torch.tensor([1]).to(device) + init_score, init_states = LM.score_token(sos_token, lens) B = [HypothesisList() for _ in range(N)] for i in range(N): @@ -2162,7 +1922,7 @@ def modified_beam_search_rnnlm_LODR( Hypothesis( ys=[blank_id] * context_size, log_prob=torch.zeros(1, dtype=torch.float32, device=device), - state=init_states, # state of the RNNLM + state=init_states, # state of the NN LM lm_score=init_score.reshape(-1), state_cost=NgramLmStateCost( LODR_lm @@ -2170,7 +1930,6 @@ def modified_beam_search_rnnlm_LODR( ) ) - rnnlm.clean_cache() encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) offset = 0 @@ -2236,7 +1995,7 @@ def modified_beam_search_rnnlm_LODR( It is a little confusing here because this for-loop looks very similar to the one below. Here, we go through all top-k tokens and only add the non-blanks ones to the token_list. - The RNNLM will score those tokens given the LM states. Note that + LM will score those tokens given the LM states. Note that the variable `scores` is the LM score after seeing the new non-blank token. """ @@ -2256,21 +2015,41 @@ def modified_beam_search_rnnlm_LODR( new_token = topk_token_indexes[k] if new_token not in (blank_id, unk_id): - assert new_token != 0, new_token - token_list.append([new_token]) - # store the LSTM states - hs.append(hyp.state[0]) - cs.append(hyp.state[1]) + if LM.lm_type == "rnn": + token_list.append([new_token]) + # store the LSTM states + hs.append(hyp.state[0]) + cs.append(hyp.state[1]) + else: + # for transformer LM + token_list.append( + [sos_id] + hyp.ys[context_size:] + [new_token] + ) - # forward RNNLM to get new states and scores + # forward NN LM to get new states and scores if len(token_list) != 0: - tokens_to_score = ( - torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1) - ) + x_lens = torch.tensor([len(tokens) for tokens in token_list]).to(device) + if LM.lm_type == "rnn": + tokens_to_score = ( + torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1) + ) + hs = torch.cat(hs, dim=1).to(device) + cs = torch.cat(cs, dim=1).to(device) + state = (hs, cs) + else: + # for transformer LM + tokens_list = [torch.tensor(tokens) for tokens in token_list] + tokens_to_score = ( + torch.nn.utils.rnn.pad_sequence( + tokens_list, batch_first=True, padding_value=0.0 + ) + .to(device) + .to(torch.int64) + ) - hs = torch.cat(hs, dim=1).to(device) - cs = torch.cat(cs, dim=1).to(device) - scores, lm_states = rnnlm.score_token(tokens_to_score, (hs, cs)) + state = None + + scores, lm_states = LM.score_token(tokens_to_score, x_lens, state) count = 0 # index, used to locate score and lm states for i in range(batch_size): @@ -2305,18 +2084,19 @@ def modified_beam_search_rnnlm_LODR( state_cost.lm_score, hyp.state_cost.lm_score, ) - # score = score + RNNLM_score - LODR_score - # LODR_LM_scale is a negative number here + # score = score + TDLM_score - LODR_score + # LODR_LM_scale should be a negative number here hyp_log_prob += ( lm_score[new_token] * lm_scale + LODR_lm_scale * current_ngram_score ) # add the lm score lm_score = scores[count] - state = ( - lm_states[0][:, count, :].unsqueeze(1), - lm_states[1][:, count, :].unsqueeze(1), - ) + if LM.lm_type == "rnn": + state = ( + lm_states[0][:, count, :].unsqueeze(1), + lm_states[1][:, count, :].unsqueeze(1), + ) count += 1 else: state_cost = hyp.state_cost @@ -2340,3 +2120,263 @@ def modified_beam_search_rnnlm_LODR( ans.append(sorted_ans[unsorted_indices[i]]) return ans + + +def modified_beam_search_lm_shallow_fusion( + model: Transducer, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + sp: spm.SentencePieceProcessor, + LM: LmScorer, + beam: int = 4, + return_timestamps: bool = False, +) -> List[List[int]]: + """Modified_beam_search + NN LM shallow fusion + + Args: + model (Transducer): + The transducer model + encoder_out (torch.Tensor): + Encoder output in (N,T,C) + encoder_out_lens (torch.Tensor): + A 1-D tensor of shape (N,), containing the number of + valid frames in encoder_out before padding. + sp: + Sentence piece generator. + LM (LmScorer): + A neural net LM, e.g RNN or Transformer + beam (int, optional): + Beam size. Defaults to 4. + + Returns: + Return a list-of-list of token IDs. ans[i] is the decoding results + for the i-th utterance. + """ + assert encoder_out.ndim == 3, encoder_out.shape + assert encoder_out.size(0) >= 1, encoder_out.size(0) + assert LM is not None + lm_scale = LM.lm_scale + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + blank_id = model.decoder.blank_id + sos_id = sp.piece_to_id("") + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + device = next(model.parameters()).device + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + # get initial lm score and lm state by scoring the "sos" token + sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device) + lens = torch.tensor([1]).to(device) + init_score, init_states = LM.score_token(sos_token, lens) + + B = [HypothesisList() for _ in range(N)] + for i in range(N): + B[i].add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + state=init_states, + lm_score=init_score.reshape(-1), + timestamp=[], + ) + ) + + encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + + offset = 0 + finalized_B = [] + for (t, batch_size) in enumerate(batch_size_list): + start = offset + end = offset + batch_size + current_encoder_out = encoder_out.data[start:end] # get batch + current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) + offset = end + + finalized_B = B[batch_size:] + finalized_B + B = B[:batch_size] + + hyps_shape = get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.cat( + [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] + ) + + lm_scores = torch.cat( + [hyp.lm_score.reshape(1, -1) for hyps in A for hyp in hyps] + ) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + + current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) # (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) + + log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + """ + for all hyps with a non-blank new token, score this token. + It is a little confusing here because this for-loop + looks very similar to the one below. Here, we go through all + top-k tokens and only add the non-blanks ones to the token_list. + `LM` will score those tokens given the LM states. Note that + the variable `scores` is the LM score after seeing the new + non-blank token. + """ + token_list = [] # a list of list + hs = [] + cs = [] + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_token = topk_token_indexes[k] + if new_token not in (blank_id, unk_id): + if LM.lm_type == "rnn": + token_list.append([new_token]) + # store the LSTM states + hs.append(hyp.state[0]) + cs.append(hyp.state[1]) + else: + # for transformer LM + token_list.append( + [sos_id] + hyp.ys[context_size:] + [new_token] + ) + + if len(token_list) != 0: + x_lens = torch.tensor([len(tokens) for tokens in token_list]).to(device) + if LM.lm_type == "rnn": + tokens_to_score = ( + torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1) + ) + hs = torch.cat(hs, dim=1).to(device) + cs = torch.cat(cs, dim=1).to(device) + state = (hs, cs) + else: + # for transformer LM + tokens_list = [torch.tensor(tokens) for tokens in token_list] + tokens_to_score = ( + torch.nn.utils.rnn.pad_sequence( + tokens_list, batch_first=True, padding_value=0.0 + ) + .to(device) + .to(torch.int64) + ) + + state = None + + scores, lm_states = LM.score_token(tokens_to_score, x_lens, state) + + count = 0 # index, used to locate score and lm states + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + ys = hyp.ys[:] + + lm_score = hyp.lm_score + state = hyp.state + + hyp_log_prob = topk_log_probs[k] # get score of current hyp + new_token = topk_token_indexes[k] + new_timestamp = hyp.timestamp[:] + if new_token not in (blank_id, unk_id): + + ys.append(new_token) + new_timestamp.append(t) + + hyp_log_prob += lm_score[new_token] * lm_scale # add the lm score + + lm_score = scores[count] + if LM.lm_type == "rnn": + state = ( + lm_states[0][:, count, :].unsqueeze(1), + lm_states[1][:, count, :].unsqueeze(1), + ) + count += 1 + + new_hyp = Hypothesis( + ys=ys, + log_prob=hyp_log_prob, + state=state, + lm_score=lm_score, + timestamp=new_timestamp, + ) + B[i].add(new_hyp) + + B = B + finalized_B + best_hyps = [b.get_most_probable(length_norm=True) for b in B] + + sorted_ans = [h.ys[context_size:] for h in best_hyps] + sorted_timestamps = [h.timestamp for h in best_hyps] + ans = [] + ans_timestamps = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + ans_timestamps.append(sorted_timestamps[unsorted_indices[i]]) + + if not return_timestamps: + return ans + else: + return DecodingResults( + tokens=ans, + timestamps=ans_timestamps, + ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py index e00aab34a..109a94a69 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py @@ -92,36 +92,37 @@ Usage: --max-contexts 8 \ --max-states 64 -(8) modified beam search (with RNNLM shallow fusion) +(8) modified beam search (with LM shallow fusion) ./pruned_transducer_stateless3/decode.py \ --epoch 28 \ --avg 15 \ --exp-dir ./pruned_transducer_stateless3/exp \ --max-duration 600 \ - --decoding-method modified_beam_search_rnnlm_shallow_fusion \ - --beam 4 \ - --rnn-lm-scale 0.3 \ - --rnn-lm-exp-dir /path/to/RNNLM \ + --decoding-method modified_beam_search_lm_shallow_fusion \ + --beam-size 4 \ + --lm-type rnn \ + --lm-scale 0.3 \ + --lm-exp-dir /path/to/LM \ --rnn-lm-epoch 99 \ --rnn-lm-avg 1 \ --rnn-lm-num-layers 3 \ --rnn-lm-tie-weights 1 -(9) modified beam search with RNNLM shallow fusion + LODR +(9) modified beam search with LM shallow fusion + LODR ./pruned_transducer_stateless3/decode.py \ --epoch 28 \ --avg 15 \ --max-duration 600 \ --exp-dir ./pruned_transducer_stateless3/exp \ - --decoding-method modified_beam_search_rnnlm_LODR \ - --beam 4 \ - --max-contexts 4 \ - --rnn-lm-scale 0.4 \ - --rnn-lm-exp-dir /path/to/RNNLM/exp \ + --decoding-method modified_beam_search_LODR \ + --beam-size 4 \ + --lm-type rnn \ + --lm-scale 0.4 \ + --lm-exp-dir /path/to/LM \ --rnn-lm-epoch 99 \ --rnn-lm-avg 1 \ --rnn-lm-num-layers 3 \ - --rnn-lm-tie-weights 1 \ + --rnn-lm-tie-weights 1 --tokens-ngram 2 \ --ngram-lm-scale -0.16 \ """ @@ -149,14 +150,14 @@ from beam_search import ( greedy_search, greedy_search_batch, modified_beam_search, + modified_beam_search_lm_shallow_fusion, + modified_beam_search_LODR, modified_beam_search_ngram_rescoring, - modified_beam_search_rnnlm_LODR, - modified_beam_search_rnnlm_shallow_fusion, ) from librispeech import LibriSpeech from train import add_model_arguments, get_params, get_transducer_model -from icefall import NgramLm +from icefall import LmScorer, NgramLm from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.lexicon import Lexicon from icefall.rnn_lm.model import RnnLmModel @@ -240,8 +241,8 @@ def get_parser(): - fast_beam_search_nbest_oracle - fast_beam_search_nbest_LG - modified_beam_search_ngram_rescoring - - modified_beam_search_rnnlm_shallow_fusion - - modified_beam_search_rnnlm_LODR + - modified_beam_search_lm_shallow_fusion + - modified_beam_search_LODR If you use fast_beam_search_nbest_LG, you have to specify `--lang-dir`, which should contain `LG.pt`. """, @@ -392,58 +393,28 @@ def get_parser(): ) parser.add_argument( - "--rnn-lm-exp-dir", - type=str, - default="rnn_lm/exp", - help="""Used only when --method is rnn-lm. - It specifies the path to RNN LM exp dir. - """, - ) - - parser.add_argument( - "--rnn-lm-epoch", - type=int, - default=7, - help="""Used only when --method is rnn-lm. - It specifies the checkpoint to use. - """, - ) - - parser.add_argument( - "--rnn-lm-avg", - type=int, - default=2, - help="""Used only when --method is rnn-lm. - It specifies the number of checkpoints to average. - """, - ) - - parser.add_argument( - "--rnn-lm-embedding-dim", - type=int, - default=2048, - help="Embedding dim of the model", - ) - - parser.add_argument( - "--rnn-lm-hidden-dim", - type=int, - default=2048, - help="Hidden dim of the model", - ) - - parser.add_argument( - "--rnn-lm-num-layers", - type=int, - default=4, - help="Number of RNN layers the model", - ) - parser.add_argument( - "--rnn-lm-tie-weights", + "--use-shallow-fusion", type=str2bool, - default=True, - help="""True to share the weights between the input embedding layer and the - last output linear layer + default=False, + help="""Use neural network LM for shallow fusion. + If you want to use LODR, you will also need to set this to true + """, + ) + + parser.add_argument( + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. """, ) @@ -481,7 +452,7 @@ def decode_one_batch( ngram_lm: Optional[NgramLm] = None, ngram_lm_scale: float = 1.0, rnn_lm_model: Optional[RnnLmModel] = None, - rnnlm_scale: float = 1.0, + LM: Optional[LmScorer] = None, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -515,10 +486,9 @@ def decode_one_batch( fast_beam_search_nbest, fast_beam_search_nbest_oracle, or fast_beam_search_with_nbest_rescoring. It an FsaVec containing an acceptor. - rnn_lm_model: - A rnnlm which can be used for rescoring or shallow fusion - rnnlm_scale: - The scale of the rnnlm. + LM: + A neural net LM for shallow fusion. Only used when `--use-shallow-fusion` + set to true. ngram_lm: A ngram lm. Used in LODR decoding. ngram_lm_scale: @@ -697,20 +667,19 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion": - hyp_tokens = modified_beam_search_rnnlm_shallow_fusion( + elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": + hyp_tokens = modified_beam_search_lm_shallow_fusion( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, beam=params.beam_size, sp=sp, - rnnlm=rnn_lm_model, - rnnlm_scale=rnnlm_scale, + LM=LM, ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search_rnnlm_LODR": - hyp_tokens = modified_beam_search_rnnlm_LODR( + elif params.decoding_method == "modified_beam_search_LODR": + hyp_tokens = modified_beam_search_LODR( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, @@ -718,8 +687,7 @@ def decode_one_batch( sp=sp, LODR_lm=ngram_lm, LODR_lm_scale=ngram_lm_scale, - rnnlm=rnn_lm_model, - rnnlm_scale=rnnlm_scale, + LM=LM, ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) @@ -812,7 +780,7 @@ def decode_dataset( ngram_lm: Optional[NgramLm] = None, ngram_lm_scale: float = 1.0, rnn_lm_model: Optional[RnnLmModel] = None, - rnnlm_scale: float = 1.0, + LM: Optional[LmScorer] = None, ) -> Dict[str, List[Tuple[List[str], List[str]]]]: """Decode dataset. @@ -836,6 +804,8 @@ def decode_dataset( fast_beam_search_nbest, fast_beam_search_nbest_oracle, or fast_beam_search_with_nbest_rescoring. It's an FsaVec containing an acceptor. + LM: + A neural network LM, used during shallow fusion Returns: Return a dict, whose key may be "greedy_search" if greedy search is used, or it may be "beam_7" if beam size of 7 is used. @@ -871,7 +841,7 @@ def decode_dataset( ngram_lm=ngram_lm, ngram_lm_scale=ngram_lm_scale, rnn_lm_model=rnn_lm_model, - rnnlm_scale=rnnlm_scale, + LM=LM, ) for name, hyps in hyps_dict.items(): @@ -1005,6 +975,7 @@ def load_ngram_LM( def main(): parser = get_parser() AsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) @@ -1022,9 +993,9 @@ def main(): "modified_beam_search", "fast_beam_search_with_nbest_rescoring", "fast_beam_search_with_nbest_rnn_rescoring", - "modified_beam_search_rnnlm_LODR", + "modified_beam_search_LODR", + "modified_beam_search_lm_shallow_fusion", "modified_beam_search_ngram_rescoring", - "modified_beam_search_rnnlm_shallow_fusion", ) params.res_dir = params.exp_dir / params.decoding_method @@ -1055,12 +1026,18 @@ def main(): params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" params.suffix += f"-temperature-{params.temperature}" - if "rnnlm" in params.decoding_method: - params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}" - if "LODR" in params.decoding_method: - params.suffix += "-LODR" if "ngram" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + if params.use_shallow_fusion: + if params.lm_type == "rnn": + params.suffix += f"-rnnlm-lm-scale-{params.lm_scale}" + elif params.lm_type == "transformer": + params.suffix += f"-transformer-lm-scale-{params.lm_scale}" + + if "LODR" in params.decoding_method: + params.suffix += ( + f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" + ) setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") logging.info("Decoding started") @@ -1195,28 +1172,19 @@ def main(): ngram_lm = None ngram_lm_scale = None - # only load rnnlm if used - if "rnnlm" in params.decoding_method: - rnn_lm_scale = params.rnn_lm_scale - - rnn_lm_model = RnnLmModel( - vocab_size=params.vocab_size, - embedding_dim=params.rnn_lm_embedding_dim, - hidden_dim=params.rnn_lm_hidden_dim, - num_layers=params.rnn_lm_num_layers, - tie_weights=params.rnn_lm_tie_weights, + # only load the neural network LM if doing shallow fusion + if params.use_shallow_fusion: + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, ) - assert params.rnn_lm_avg == 1 + LM.to(device) + LM.eval() - load_checkpoint( - f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt", - rnn_lm_model, - ) - rnn_lm_model.to(device) - rnn_lm_model.eval() else: - rnn_lm_model = None - rnn_lm_scale = 0.0 + LM = None num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") @@ -1247,7 +1215,7 @@ def main(): ngram_lm=ngram_lm, ngram_lm_scale=ngram_lm_scale, rnn_lm_model=rnn_lm_model, - rnnlm_scale=rnn_lm_scale, + LM=LM, ) save_results( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py index 8b993f638..90b0fcf4b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py @@ -87,22 +87,39 @@ Usage: --max-contexts 8 \ --max-states 64 -(8) modified beam search with RNNLM shallow fusion (with LG) +(8) modified beam search with RNNLM shallow fusion ./pruned_transducer_stateless5/decode.py \ --epoch 35 \ --avg 15 \ --exp-dir ./pruned_transducer_stateless5/exp \ --max-duration 600 \ - --decoding-method fast_beam_search_nbest_LG \ - --beam 4 \ - --max-contexts 4 \ - --rnn-lm-scale 0.4 \ - --rnn-lm-exp-dir /path/to/RNNLM/exp \ + --decoding-method modified_beam_search_lm_shallow_fusion \ + --beam-size 4 \ + --lm-type rnn \ + --lm-scale 0.3 \ + --lm-exp-dir /path/to/LM \ --rnn-lm-epoch 99 \ --rnn-lm-avg 1 \ --rnn-lm-num-layers 3 \ --rnn-lm-tie-weights 1 +(9) modified beam search with LM shallow fusion + LODR +./pruned_transducer_stateless5/decode.py \ + --epoch 28 \ + --avg 15 \ + --max-duration 600 \ + --exp-dir ./pruned_transducer_stateless5/exp \ + --decoding-method modified_beam_search_LODR \ + --beam-size 4 \ + --lm-type rnn \ + --lm-scale 0.4 \ + --lm-exp-dir /path/to/LM \ + --rnn-lm-epoch 99 \ + --rnn-lm-avg 1 \ + --rnn-lm-num-layers 3 \ + --rnn-lm-tie-weights 1 + --tokens-ngram 2 \ + --ngram-lm-scale -0.16 \ """ @@ -128,10 +145,13 @@ from beam_search import ( greedy_search, greedy_search_batch, modified_beam_search, - modified_beam_search_rnnlm_shallow_fusion, + modified_beam_search_lm_shallow_fusion, + modified_beam_search_LODR, + modified_beam_search_ngram_rescoring, ) from train import add_model_arguments, get_params, get_transducer_model +from icefall import LmScorer, NgramLm from icefall.checkpoint import ( average_checkpoints, average_checkpoints_with_averaged_model, @@ -139,7 +159,6 @@ from icefall.checkpoint import ( load_checkpoint, ) from icefall.lexicon import Lexicon -from icefall.rnn_lm.model import RnnLmModel from icefall.utils import ( AttributeDict, setup_logger, @@ -229,7 +248,8 @@ def get_parser(): - fast_beam_search_nbest - fast_beam_search_nbest_oracle - fast_beam_search_nbest_LG - - modified_beam_search_rnnlm_shallow_fusion # for rnn lm shallow fusion + - modified_beam_search_lm_shallow_fusion # for rnn lm shallow fusion + - modified_beam_search_LODR If you use fast_beam_search_nbest_LG, you have to specify `--lang-dir`, which should contain `LG.pt`. """, @@ -342,69 +362,49 @@ def get_parser(): ) parser.add_argument( - "--rnn-lm-scale", - type=float, - default=0.0, - help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion. - It specifies the path to RNN LM exp dir. - """, - ) - - parser.add_argument( - "--rnn-lm-exp-dir", - type=str, - default="rnn_lm/exp", - help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion. - It specifies the path to RNN LM exp dir. - """, - ) - - parser.add_argument( - "--rnn-lm-epoch", - type=int, - default=7, - help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion. - It specifies the checkpoint to use. - """, - ) - - parser.add_argument( - "--rnn-lm-avg", - type=int, - default=2, - help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion. - It specifies the number of checkpoints to average. - """, - ) - - parser.add_argument( - "--rnn-lm-embedding-dim", - type=int, - default=2048, - help="Embedding dim of the model", - ) - - parser.add_argument( - "--rnn-lm-hidden-dim", - type=int, - default=2048, - help="Hidden dim of the model", - ) - - parser.add_argument( - "--rnn-lm-num-layers", - type=int, - default=4, - help="Number of RNN layers the model", - ) - parser.add_argument( - "--rnn-lm-tie-weights", + "--use-shallow-fusion", type=str2bool, default=False, - help="""True to share the weights between the input embedding layer and the - last output linear layer + help="""Use neural network LM for shallow fusion. + If you want to use LODR, you will also need to set this to true """, ) + + parser.add_argument( + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. + """, + ) + + parser.add_argument( + "--tokens-ngram", + type=int, + default=3, + help="""Token Ngram used for rescoring. + Used only when the decoding method is + modified_beam_search_ngram_rescoring, or LODR + """, + ) + + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="""ID of the backoff symbol. + Used only when the decoding method is + modified_beam_search_ngram_rescoring""", + ) add_model_arguments(parser) return parser @@ -417,8 +417,9 @@ def decode_one_batch( batch: dict, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, - rnnlm: Optional[RnnLmModel] = None, - rnnlm_scale: float = 1.0, + ngram_lm: Optional[NgramLm] = None, + ngram_lm_scale: float = 1.0, + LM: Optional[LmScorer] = None, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -447,6 +448,13 @@ def decode_one_batch( The decoding graph. Can be either a `k2.trivial_graph` or LG, Used only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural net LM for shallow fusion. Only used when `--use-shallow-fusion` + set to true. + ngram_lm: + A ngram lm. Used in LODR decoding. + ngram_lm_scale: + The scale of the ngram language model. Returns: Return the decoding result. See above description for the format of the returned dict. @@ -559,15 +567,38 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion": - hyp_tokens = modified_beam_search_rnnlm_shallow_fusion( + elif params.decoding_method == "modified_beam_search_ngram_rescoring": + hyp_tokens = modified_beam_search_ngram_rescoring( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": + hyp_tokens = modified_beam_search_lm_shallow_fusion( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, beam=params.beam_size, sp=sp, - rnnlm=rnnlm, - rnnlm_scale=rnnlm_scale, + LM=LM, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_LODR": + hyp_tokens = modified_beam_search_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + sp=sp, + LODR_lm=ngram_lm, + LODR_lm_scale=ngram_lm_scale, + LM=LM, ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) @@ -620,8 +651,9 @@ def decode_dataset( sp: spm.SentencePieceProcessor, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, - rnnlm: Optional[RnnLmModel] = None, - rnnlm_scale: float = 1.0, + ngram_lm: Optional[NgramLm] = None, + ngram_lm_scale: float = 1.0, + LM: Optional[LmScorer] = None, ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. @@ -640,6 +672,8 @@ def decode_dataset( The decoding graph. Can be either a `k2.trivial_graph` or LG, Used only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural network LM, used during shallow fusion Returns: Return a dict, whose key may be "greedy_search" if greedy search is used, or it may be "beam_7" if beam size of 7 is used. @@ -663,7 +697,6 @@ def decode_dataset( for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - logging.info(f"Decoding {batch_idx}-th batch") hyps_dict = decode_one_batch( params=params, @@ -672,8 +705,9 @@ def decode_dataset( decoding_graph=decoding_graph, word_table=word_table, batch=batch, - rnnlm=rnnlm, - rnnlm_scale=rnnlm_scale, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + LM=LM, ) for name, hyps in hyps_dict.items(): @@ -742,6 +776,7 @@ def save_results( def main(): parser = get_parser() LibriSpeechAsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) @@ -757,7 +792,8 @@ def main(): "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", "modified_beam_search", - "modified_beam_search_rnnlm_shallow_fusion", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_LODR", ) params.res_dir = params.exp_dir / params.decoding_method @@ -783,7 +819,18 @@ def main(): params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}" + if "ngram" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + if params.use_shallow_fusion: + if params.lm_type == "rnn": + params.suffix += f"-rnnlm-lm-scale-{params.lm_scale}" + elif params.lm_type == "transformer": + params.suffix += f"-transformer-lm-scale-{params.lm_scale}" + + if "LODR" in params.decoding_method: + params.suffix += ( + f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" + ) if params.use_averaged_model: params.suffix += "-use-averaged-model" @@ -895,24 +942,34 @@ def main(): model.to(device) model.eval() - rnn_lm_model = None - rnn_lm_scale = params.rnn_lm_scale - if params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion": - rnn_lm_model = RnnLmModel( - vocab_size=params.vocab_size, - embedding_dim=params.rnn_lm_embedding_dim, - hidden_dim=params.rnn_lm_hidden_dim, - num_layers=params.rnn_lm_num_layers, - tie_weights=params.rnn_lm_tie_weights, + # only load N-gram LM when needed + if "ngram" in params.decoding_method or "LODR" in params.decoding_method: + lm_filename = f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"lm filename: {lm_filename}") + ngram_lm = NgramLm( + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, ) - assert params.rnn_lm_avg == 1 + logging.info(f"num states: {ngram_lm.lm.num_states}") + ngram_lm_scale = params.ngram_lm_scale + else: + ngram_lm = None + ngram_lm_scale = None - load_checkpoint( - f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt", - rnn_lm_model, + # only load the neural network LM if doing shallow fusion + if params.use_shallow_fusion: + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, ) - rnn_lm_model.to(device) - rnn_lm_model.eval() + LM.to(device) + LM.eval() + + else: + LM = None if "fast_beam_search" in params.decoding_method: if "LG" in params.decoding_method: @@ -955,8 +1012,9 @@ def main(): sp=sp, word_table=word_table, decoding_graph=decoding_graph, - rnnlm=rnn_lm_model, - rnnlm_scale=rnn_lm_scale, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + LM=LM, ) save_results( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py index bc15948fc..b9bce465f 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py @@ -1,7 +1,8 @@ #!/usr/bin/env python3 # # Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao) +# Zengwei Yao, +# Xiaoyu Yang) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -91,6 +92,41 @@ Usage: --beam 20.0 \ --max-contexts 8 \ --max-states 64 + +(8) modified beam search with RNNLM shallow fusion +./pruned_transducer_stateless5/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless5/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search_lm_shallow_fusion \ + --beam-size 4 \ + --lm-type rnn \ + --lm-scale 0.3 \ + --lm-exp-dir /path/to/LM \ + --rnn-lm-epoch 99 \ + --rnn-lm-avg 1 \ + --rnn-lm-num-layers 3 \ + --rnn-lm-tie-weights 1 + +(9) modified beam search with LM shallow fusion + LODR +./pruned_transducer_stateless5/decode.py \ + --epoch 28 \ + --avg 15 \ + --max-duration 600 \ + --exp-dir ./pruned_transducer_stateless5/exp \ + --decoding-method modified_beam_search_LODR \ + --beam-size 4 \ + --lm-type rnn \ + --lm-scale 0.4 \ + --lm-exp-dir /path/to/LM \ + --rnn-lm-epoch 99 \ + --rnn-lm-avg 1 \ + --rnn-lm-num-layers 3 \ + --rnn-lm-tie-weights 1 + --tokens-ngram 2 \ + --ngram-lm-scale -0.16 \ + """ @@ -115,9 +151,13 @@ from beam_search import ( greedy_search, greedy_search_batch, modified_beam_search, + modified_beam_search_lm_shallow_fusion, + modified_beam_search_LODR, + modified_beam_search_ngram_rescoring, ) from train import add_model_arguments, get_params, get_transducer_model +from icefall import LmScorer, NgramLm from icefall.checkpoint import ( average_checkpoints, average_checkpoints_with_averaged_model, @@ -213,6 +253,8 @@ def get_parser(): - fast_beam_search_nbest - fast_beam_search_nbest_oracle - fast_beam_search_nbest_LG + - modified_beam_search_lm_shallow_fusion # for rnn lm shallow fusion + - modified_beam_search_LODR If you use fast_beam_search_nbest_LG, you have to specify `--lang-dir`, which should contain `LG.pt`. """, @@ -274,6 +316,7 @@ def get_parser(): default=2, help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) + parser.add_argument( "--max-sym-per-frame", type=int, @@ -323,6 +366,50 @@ def get_parser(): help="left context can be seen during decoding (in frames after subsampling)", ) + parser.add_argument( + "--use-shallow-fusion", + type=str2bool, + default=False, + help="""Use neural network LM for shallow fusion. + If you want to use LODR, you will also need to set this to true + """, + ) + + parser.add_argument( + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. + """, + ) + + parser.add_argument( + "--tokens-ngram", + type=int, + default=3, + help="""Token Ngram used for rescoring. + Used only when the decoding method is + modified_beam_search_ngram_rescoring, or LODR + """, + ) + + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="""ID of the backoff symbol. + Used only when the decoding method is + modified_beam_search_ngram_rescoring""", + ) add_model_arguments(parser) return parser @@ -335,6 +422,9 @@ def decode_one_batch( batch: dict, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, + ngram_lm: Optional[NgramLm] = None, + ngram_lm_scale: float = 1.0, + LM: Optional[LmScorer] = None, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -363,6 +453,13 @@ def decode_one_batch( The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used only when --decoding_method is fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural net LM for shallow fusion. Only used when `--use-shallow-fusion` + set to true. + ngram_lm: + A ngram lm. Used in LODR decoding. + ngram_lm_scale: + The scale of the ngram language model. Returns: Return the decoding result. See above description for the format of the returned dict. @@ -468,6 +565,30 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": + hyp_tokens = modified_beam_search_lm_shallow_fusion( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + sp=sp, + LM=LM, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_LODR": + hyp_tokens = modified_beam_search_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + sp=sp, + LODR_lm=ngram_lm, + LODR_lm_scale=ngram_lm_scale, + LM=LM, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) else: batch_size = encoder_out.size(0) @@ -517,6 +638,9 @@ def decode_dataset( sp: spm.SentencePieceProcessor, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, + ngram_lm: Optional[NgramLm] = None, + ngram_lm_scale: float = 1.0, + LM: Optional[LmScorer] = None, ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. @@ -535,6 +659,8 @@ def decode_dataset( The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used only when --decoding_method is fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural network LM, used during shallow fusion Returns: Return a dict, whose key may be "greedy_search" if greedy search is used, or it may be "beam_7" if beam size of 7 is used. @@ -566,6 +692,9 @@ def decode_dataset( decoding_graph=decoding_graph, word_table=word_table, batch=batch, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + LM=LM, ) for name, hyps in hyps_dict.items(): @@ -634,6 +763,7 @@ def save_results( def main(): parser = get_parser() LibriSpeechAsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) @@ -648,6 +778,8 @@ def main(): "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", "modified_beam_search", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_LODR", ) params.res_dir = params.exp_dir / params.decoding_method @@ -675,6 +807,19 @@ def main(): params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + if "ngram" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + if params.use_shallow_fusion: + if params.lm_type == "rnn": + params.suffix += f"-rnnlm-lm-scale-{params.lm_scale}" + elif params.lm_type == "transformer": + params.suffix += f"-transformer-lm-scale-{params.lm_scale}" + + if "LODR" in params.decoding_method: + params.suffix += ( + f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" + ) + if params.use_averaged_model: params.suffix += "-use-averaged-model" @@ -785,6 +930,34 @@ def main(): model.to(device) model.eval() + # only load N-gram LM when needed + if "ngram" in params.decoding_method or "LODR" in params.decoding_method: + lm_filename = f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"lm filename: {lm_filename}") + ngram_lm = NgramLm( + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {ngram_lm.lm.num_states}") + ngram_lm_scale = params.ngram_lm_scale + else: + ngram_lm = None + ngram_lm_scale = None + + # only load the neural network LM if doing shallow fusion + if params.use_shallow_fusion: + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, + ) + LM.to(device) + LM.eval() + + else: + LM = None if "fast_beam_search" in params.decoding_method: if params.decoding_method == "fast_beam_search_nbest_LG": lexicon = Lexicon(params.lang_dir) @@ -826,6 +999,9 @@ def main(): sp=sp, word_table=word_table, decoding_graph=decoding_graph, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + LM=LM, ) save_results( diff --git a/icefall/__init__.py b/icefall/__init__.py index 27ad74213..82d21706c 100644 --- a/icefall/__init__.py +++ b/icefall/__init__.py @@ -68,3 +68,5 @@ from .utils import ( ) from .ngram_lm import NgramLm, NgramLmStateCost + +from .lm_wrapper import LmScorer diff --git a/icefall/lm_wrapper.py b/icefall/lm_wrapper.py new file mode 100644 index 000000000..0468befd0 --- /dev/null +++ b/icefall/lm_wrapper.py @@ -0,0 +1,254 @@ +# Copyright (c) 2022 Xiaomi Corporation (authors: Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging + +import torch + +from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.rnn_lm.model import RnnLmModel +from icefall.transformer_lm.model import TransformerLM +from icefall.utils import AttributeDict, str2bool + + +class LmScorer(torch.nn.Module): + """This is a wrapper for NN LMs + The language models supported include: + RNN, + Transformer + """ + + def __init__( + self, + lm_type: str, + params: AttributeDict, + device, + lm_scale: float = 0.3, + ): + super(LmScorer, self).__init__() + assert lm_type in ["rnn", "transformer"], f"{lm_type} is not supported" + self.lm_type = lm_type + self.lm = self.get_lm(lm_type, device, params) + self.lm_scale = lm_scale + self.params = params + + @classmethod + def add_arguments(cls, parser): + # LM general arguments + parser.add_argument( + "--vocab-size", + type=int, + default=500, + ) + + parser.add_argument( + "--lm-epoch", + type=int, + default=7, + help="""Which epoch to be used + """, + ) + + parser.add_argument( + "--lm-avg", + type=int, + default=1, + help="""Number of checkpoints to be averaged + """, + ) + + parser.add_argument("--lm-exp-dir", type=str, help="Path to LM experiments") + + # Now RNNLM related arguments + parser.add_argument( + "--rnn-lm-embedding-dim", + type=int, + default=2048, + help="Embedding dim of the model", + ) + + parser.add_argument( + "--rnn-lm-hidden-dim", + type=int, + default=2048, + help="Hidden dim of the model", + ) + + parser.add_argument( + "--rnn-lm-num-layers", + type=int, + default=3, + help="Number of RNN layers the model", + ) + + parser.add_argument( + "--rnn-lm-tie-weights", + type=str2bool, + default=True, + help="""True to share the weights between the input embedding layer and the + last output linear layer + """, + ) + + # Now transformers + parser.add_argument( + "--transformer-lm-exp-dir", type=str, help="Directory of transformer LM exp" + ) + + parser.add_argument( + "--transformer-lm-dim-feedforward", + type=int, + default=2048, + help="Dimension of FFW module in transformer", + ) + + parser.add_argument( + "--transformer-lm-encoder-dim", + type=int, + default=768, + help="Encoder dimension of transformer", + ) + + parser.add_argument( + "--transformer-lm-embedding-dim", + type=int, + default=768, + help="Input embedding dimension of transformer", + ) + + parser.add_argument( + "--transformer-lm-nhead", + type=int, + default=8, + help="Number of attention heads in transformer", + ) + + parser.add_argument( + "--transformer-lm-num-layers", + type=int, + default=16, + help="Number of encoder layers in transformer", + ) + + parser.add_argument( + "--transformer-lm-tie-weights", + type=str2bool, + default=True, + help="If tie weights in transformer LM", + ) + + def get_lm(self, lm_type: str, device, params: AttributeDict) -> torch.nn.Module: + """Return the neural network LM + + Args: + lm_type (str): Type name of NN LM + """ + if lm_type == "rnn": + model = RnnLmModel( + vocab_size=params.vocab_size, + embedding_dim=params.rnn_lm_embedding_dim, + hidden_dim=params.rnn_lm_hidden_dim, + num_layers=params.rnn_lm_num_layers, + tie_weights=params.rnn_lm_tie_weights, + ) + + if params.lm_avg == 1: + load_checkpoint( + f"{params.lm_exp_dir}/epoch-{params.lm_epoch}.pt", model + ) + model.to(device) + else: + start = params.lm_epoch - params.lm_avg + 1 + filenames = [] + for i in range(start, params.lm_epoch + 1): + if start >= 0: + filenames.append(f"{params.lm_exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + + elif lm_type == "transformer": + model = TransformerLM( + vocab_size=params.vocab_size, + d_model=params.transformer_lm_encoder_dim, + embedding_dim=params.transformer_lm_embedding_dim, + dim_feedforward=params.transformer_lm_dim_feedforward, + nhead=params.transformer_lm_nhead, + num_layers=params.transformer_lm_num_layers, + tie_weights=params.transformer_lm_tie_weights, + params=params, + ) + + if params.lm_avg == 1: + load_checkpoint( + f"{params.lm_exp_dir}/epoch-{params.lm_epoch}.pt", model + ) + model.to(device) + else: + start = params.lm_epoch - params.lm_avg + 1 + filenames = [] + for i in range(start, params.lm_epoch + 1): + if start >= 0: + filenames.append(f"{params.lm_exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + raise NotImplementedError() + + return model + + def score_token(self, x: torch.Tensor, x_lens: torch.Tensor, state=None): + """Score the input and return the prediction + This requires the lm to have the method `score_token` + Args: + x (torch.Tensor): Input tokens + x_lens (torch.Tensor): Length of the input tokens + state (optional): LM states + + """ + return self.lm.score_token(x, x_lens, state) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + LmScorer.add_arguments(parser) + args = parser.parse_args() + + params = AttributeDict() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + Scorer = LmScorer(params=params, device=device) + Scorer.eval() + + x = ( + torch.tensor([[1, 4, 19, 256, 77], [1, 4, 19, 256, 77]]) + .to(device) + .to(torch.int64) + ) + x_lens = torch.tensor([5, 5]).to(device) + + state = None + + score, state = Scorer.score(x, x_lens) + print(score.shape) + print(score[0]) + print(score[1]) diff --git a/icefall/rnn_lm/model.py b/icefall/rnn_lm/model.py index 3598a4857..08eb753b5 100644 --- a/icefall/rnn_lm/model.py +++ b/icefall/rnn_lm/model.py @@ -153,9 +153,24 @@ class RnnLmModel(torch.nn.Module): def clean_cache(self): self.cache = {} - def score_token(self, tokens: torch.Tensor, state=None): + def score_token(self, x: torch.Tensor, x_lens: torch.Tensor, state=None): + """Score a batch of tokens + + Args: + x (torch.Tensor): + A batch of tokens + x_lens (torch.Tensor): + The length of tokens in the batch before padding + state (_type_, optional): + Either None or a tuple of two torch.Tensor. Each tensor has + the shape of (hidden_dim) + + + Returns: + _type_: _description_ + """ device = next(self.parameters()).device - batch_size = tokens.size(0) + batch_size = x.size(0) if state: h, c = state else: @@ -166,7 +181,7 @@ class RnnLmModel(torch.nn.Module): device ) - embedding = self.input_embedding(tokens) + embedding = self.input_embedding(x) rnn_out, states = self.rnn(embedding, (h, c)) logits = self.output_linear(rnn_out) diff --git a/icefall/rnn_lm/train.py b/icefall/rnn_lm/train.py index 803da99d6..f43e66cd2 100755 --- a/icefall/rnn_lm/train.py +++ b/icefall/rnn_lm/train.py @@ -531,6 +531,9 @@ def run(rank, world_size, args): tie_weights=params.tie_weights, ) + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + checkpoints = load_checkpoint_if_available(params=params, model=model) model.to(device) diff --git a/icefall/transformer_lm/attention.py b/icefall/transformer_lm/attention.py new file mode 100644 index 000000000..5ce83b15e --- /dev/null +++ b/icefall/transformer_lm/attention.py @@ -0,0 +1,510 @@ +# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from typing import List, Optional, Tuple + +import torch +from torch import Tensor, nn + +from icefall.transformer_lm.scaling import ( + ActivationBalancer, + BasicNorm, + DoubleSwish, + ScaledConv1d, + ScaledConv2d, + ScaledLinear, +) +from icefall.utils import is_jit_tracing + + +class RelPositionMultiheadAttention(nn.Module): + r"""Multi-Head Attention layer with relative position encoding + + See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" + + Args: + embed_dim: total dimension of the model. + num_heads: parallel attention heads. + dropout: a Dropout layer on attn_output_weights. Default: 0.0. + + Examples:: + + >>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads) + >>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb) + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + ) -> None: + super(RelPositionMultiheadAttention, self).__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + + self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True) + self.out_proj = ScaledLinear( + embed_dim, embed_dim, bias=True, initial_scale=0.25 + ) + + # linear transformation for positional encoding. + self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 + self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) + self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) + self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach()) + self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach()) + self._reset_parameters() + + def _pos_bias_u(self): + return self.pos_bias_u * self.pos_bias_u_scale.exp() + + def _pos_bias_v(self): + return self.pos_bias_v * self.pos_bias_v_scale.exp() + + def _reset_parameters(self) -> None: + nn.init.normal_(self.pos_bias_u, std=0.01) + nn.init.normal_(self.pos_bias_v, std=0.01) + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + pos_emb: Tensor, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = False, + attn_mask: Optional[Tensor] = None, + left_context: int = 0, + ) -> Tuple[Tensor, Optional[Tensor]]: + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + pos_emb: Positional embedding tensor + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. When given a binary mask and a value is True, + the corresponding value on the attention layer will be ignored. When given + a byte mask and a value is non-zero, the corresponding value on the attention + layer will be ignored + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + left_context (int): left context (in frames) used during streaming decoding. + this is used only in real streaming decoding, in other circumstances, + it MUST be 0. + + Shape: + - Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the position + with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + + - Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: :math:`(N, L, S)` where N is the batch size, + L is the target sequence length, S is the source sequence length. + """ + return self.multi_head_attention_forward( + query, + key, + value, + pos_emb, + self.embed_dim, + self.num_heads, + self.in_proj.get_weight(), + self.in_proj.get_bias(), + self.dropout, + self.out_proj.get_weight(), + self.out_proj.get_bias(), + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + left_context=left_context, + ) + + def rel_shift(self, x: Tensor, left_context: int = 0) -> Tensor: + """Compute relative positional encoding. + + Args: + x: Input tensor (batch, head, time1, 2*time1-1+left_context). + time1 means the length of query vector. + left_context (int): left context (in frames) used during streaming decoding. + this is used only in real streaming decoding, in other circumstances, + it MUST be 0. + + Returns: + Tensor: tensor of shape (batch, head, time1, time2) + (note: time2 has the same value as time1, but it is for + the key, while time1 is for the query). + """ + (batch_size, num_heads, time1, n) = x.shape + + time2 = time1 + left_context + if not is_jit_tracing(): + assert ( + n == left_context + 2 * time1 - 1 + ), f"{n} == {left_context} + 2 * {time1} - 1" + + if is_jit_tracing(): + rows = torch.arange(start=time1 - 1, end=-1, step=-1) + cols = torch.arange(time2) + rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) + indexes = rows + cols + + x = x.reshape(-1, n) + x = torch.gather(x, dim=1, index=indexes) + x = x.reshape(batch_size, num_heads, time1, time2) + return x + else: + # Note: TorchScript requires explicit arg for stride() + batch_stride = x.stride(0) + head_stride = x.stride(1) + time1_stride = x.stride(2) + n_stride = x.stride(3) + return x.as_strided( + (batch_size, num_heads, time1, time2), + (batch_stride, head_stride, time1_stride - n_stride, n_stride), + storage_offset=n_stride * (time1 - 1), + ) + + def multi_head_attention_forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + pos_emb: Tensor, + embed_dim_to_check: int, + num_heads: int, + in_proj_weight: Tensor, + in_proj_bias: Tensor, + dropout_p: float, + out_proj_weight: Tensor, + out_proj_bias: Tensor, + training: bool = True, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = False, + attn_mask: Optional[Tensor] = None, + left_context: int = 0, + ) -> Tuple[Tensor, Optional[Tensor]]: + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + pos_emb: Positional embedding tensor + embed_dim_to_check: total dimension of the model. + num_heads: parallel attention heads. + in_proj_weight, in_proj_bias: input projection weight and bias. + dropout_p: probability of an element to be zeroed. + out_proj_weight, out_proj_bias: the output projection weight and bias. + training: apply dropout if is ``True``. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. This is an binary mask. When the value is True, + the corresponding value on the attention layer will be filled with -inf. + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + left_context (int): left context (in frames) used during streaming decoding. + this is used only in real streaming decoding, in other circumstances, + it MUST be 0. + + Shape: + Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence + length, N is the batch size, E is the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions + will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + + Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: :math:`(N, L, S)` where N is the batch size, + L is the target sequence length, S is the source sequence length. + """ + + tgt_len, bsz, embed_dim = query.size() + if not is_jit_tracing(): + assert embed_dim == embed_dim_to_check + assert key.size(0) == value.size(0) and key.size(1) == value.size(1) + + head_dim = embed_dim // num_heads + if not is_jit_tracing(): + assert ( + head_dim * num_heads == embed_dim + ), "embed_dim must be divisible by num_heads" + + scaling = float(head_dim) ** -0.5 + + if torch.equal(query, key) and torch.equal(key, value): + # self-attention + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) + + elif torch.equal(key, value): + # encoder-decoder attention + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = 0 + _end = embed_dim + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = nn.functional.linear(query, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim + _end = None + _w = in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1) + + else: + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = 0 + _end = embed_dim + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = nn.functional.linear(query, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim + _end = embed_dim * 2 + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + k = nn.functional.linear(key, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim * 2 + _end = None + _w = in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + v = nn.functional.linear(value, _w, _b) + + if attn_mask is not None: + assert ( + attn_mask.dtype == torch.float32 + or attn_mask.dtype == torch.float64 + or attn_mask.dtype == torch.float16 + or attn_mask.dtype == torch.uint8 + or attn_mask.dtype == torch.bool + ), "Only float, byte, and bool types are supported for attn_mask, not {}".format( + attn_mask.dtype + ) + if attn_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for attn_mask is deprecated. Use bool tensor instead." + ) + attn_mask = attn_mask.to(torch.bool) + + if attn_mask.dim() == 2: + attn_mask = attn_mask.unsqueeze(0) + if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: + raise RuntimeError("The size of the 2D attn_mask is not correct.") + elif attn_mask.dim() == 3: + if list(attn_mask.size()) != [ + bsz * num_heads, + query.size(0), + key.size(0), + ]: + raise RuntimeError("The size of the 3D attn_mask is not correct.") + else: + raise RuntimeError( + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) + ) + # attn_mask's dim is 3 now. + + # convert ByteTensor key_padding_mask to bool + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." + ) + key_padding_mask = key_padding_mask.to(torch.bool) + + q = (q * scaling).contiguous().view(tgt_len, bsz, num_heads, head_dim) + k = k.contiguous().view(-1, bsz, num_heads, head_dim) + v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) + + src_len = k.size(0) + + if key_padding_mask is not None and not is_jit_tracing(): + assert key_padding_mask.size(0) == bsz, "{} == {}".format( + key_padding_mask.size(0), bsz + ) + assert key_padding_mask.size(1) == src_len, "{} == {}".format( + key_padding_mask.size(1), src_len + ) + + q = q.transpose(0, 1) # (batch, time1, head, d_k) + + pos_emb_bsz = pos_emb.size(0) + if not is_jit_tracing(): + assert pos_emb_bsz in (1, bsz) # actually it is 1 + + p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim) + # (batch, 2*time1, head, d_k) --> (batch, head, d_k, 2*time -1) + p = p.permute(0, 2, 3, 1) + + q_with_bias_u = (q + self._pos_bias_u()).transpose( + 1, 2 + ) # (batch, head, time1, d_k) + + q_with_bias_v = (q + self._pos_bias_v()).transpose( + 1, 2 + ) # (batch, head, time1, d_k) + + # compute attention score + # first compute matrix a and matrix c + # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 + k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) + + # compute matrix b and matrix d + matrix_bd = torch.matmul(q_with_bias_v, p) # (batch, head, time1, 2*time1-1) + matrix_bd = self.rel_shift(matrix_bd, left_context) + + attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) + + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) + + if not is_jit_tracing(): + assert list(attn_output_weights.size()) == [ + bsz * num_heads, + tgt_len, + src_len, + ] + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_output_weights.masked_fill_(attn_mask, float("-inf")) + else: + attn_output_weights += attn_mask + + if key_padding_mask is not None: + attn_output_weights = attn_output_weights.view( + bsz, num_heads, tgt_len, src_len + ) + attn_output_weights = attn_output_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + float("-inf"), + ) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, src_len + ) + + attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1) + + # If we are using dynamic_chunk_training and setting a limited + # num_left_chunks, the attention may only see the padding values which + # will also be masked out by `key_padding_mask`, at this circumstances, + # the whole column of `attn_output_weights` will be `-inf` + # (i.e. be `nan` after softmax), so, we fill `0.0` at the masking + # positions to avoid invalid loss value below. + if ( + attn_mask is not None + and attn_mask.dtype == torch.bool + and key_padding_mask is not None + ): + if attn_mask.size(0) != 1: + attn_mask = attn_mask.view(bsz, num_heads, tgt_len, src_len) + combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2) + else: + # attn_mask.shape == (1, tgt_len, src_len) + combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze( + 1 + ).unsqueeze(2) + + attn_output_weights = attn_output_weights.view( + bsz, num_heads, tgt_len, src_len + ) + attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, src_len + ) + + attn_output_weights = nn.functional.dropout( + attn_output_weights, p=dropout_p, training=training + ) + + attn_output = torch.bmm(attn_output_weights, v) + + if not is_jit_tracing(): + assert list(attn_output.size()) == [ + bsz * num_heads, + tgt_len, + head_dim, + ] + + attn_output = ( + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) + + if need_weights: + # average attention weights over heads + attn_output_weights = attn_output_weights.view( + bsz, num_heads, tgt_len, src_len + ) + return attn_output, attn_output_weights.sum(dim=1) / num_heads + else: + return attn_output, None diff --git a/icefall/transformer_lm/compute_perplexity.py b/icefall/transformer_lm/compute_perplexity.py new file mode 100644 index 000000000..72d7c477b --- /dev/null +++ b/icefall/transformer_lm/compute_perplexity.py @@ -0,0 +1,195 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +import math +from pathlib import Path + +import torch +from dataset import get_dataloader +from train import get_params + +from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.transformer_lm.model import TransformerLM +from icefall.utils import AttributeDict, setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=7, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + parser.add_argument( + "--avg", + type=int, + default=1, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="transformer_lm/exp_full_libri_16layer_maxlen200_8gpu", + ) + + parser.add_argument( + "--lm-data", + type=str, + help="Path to the LM test data for computing perplexity", + default="transformer_lm/libri_lm_training_bpe500/sorted_lm_data-test.pt", + ) + + parser.add_argument( + "--vocab-size", + type=int, + default=500, + help="Vocabulary size of the model", + ) + + parser.add_argument( + "--num-layers", + type=int, + default=16, + help="Number of RNN layers the model", + ) + + parser.add_argument( + "--tie-weights", + type=str2bool, + default=False, + help="""True to share the weights between the input embedding layer and the + last output linear layer + """, + ) + + parser.add_argument( + "--batch-size", + type=int, + default=50, + help="Number of RNN layers the model", + ) + + parser.add_argument( + "--max-sent-len", + type=int, + default=100, + help="Number of RNN layers the model", + ) + + return parser + + +def main(): + parser = get_parser() + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lm_data = Path(args.lm_data) + + params = get_params() + params.update(vars(args)) + + setup_logger(f"{params.exp_dir}/log-ppl/") + logging.info("Computing perplexity started") + logging.info(params) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + logging.info("About to create model") + model = TransformerLM( + vocab_size=params.vocab_size, + d_model=params.encoder_dim, + embedding_dim=params.embedding_dim, + dim_feedforward=params.dim_feedforward, + nhead=params.nhead, + num_layers=params.num_layers, + tie_weights=params.tie_weights, + params=params, + ) + + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + model.to(device) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + + model.eval() + num_param = sum([p.numel() for p in model.parameters()]) + num_param_requires_grad = sum( + [p.numel() for p in model.parameters() if p.requires_grad] + ) + + logging.info(f"Number of model parameters: {num_param}") + logging.info( + f"Number of model parameters (requires_grad): " + f"{num_param_requires_grad} " + f"({num_param_requires_grad/num_param_requires_grad*100}%)" + ) + + logging.info(f"Loading LM test data from {params.lm_data}") + test_dl = get_dataloader( + filename=params.lm_data, + is_distributed=False, + params=params, + ) + + tot_loss = 0.0 + num_tokens = 0 + num_sentences = 0 + for batch_idx, batch in enumerate(test_dl): + x, y, sentence_lengths = batch + x = x.to(device) + y = y.to(device) + sentence_lengths = sentence_lengths.to(device) + + nll = model(x, y, sentence_lengths) + loss = nll.sum().cpu().item() + + tot_loss += loss + num_tokens += sentence_lengths.sum().cpu().item() + num_sentences += x.size(0) + + ppl = math.exp(tot_loss / num_tokens) + logging.info( + f"total nll: {tot_loss}, num tokens: {num_tokens}, " + f"num sentences: {num_sentences}, ppl: {ppl:.3f}" + ) + + +if __name__ == "__main__": + main() diff --git a/icefall/transformer_lm/dataset.py b/icefall/transformer_lm/dataset.py new file mode 120000 index 000000000..5792a6cf0 --- /dev/null +++ b/icefall/transformer_lm/dataset.py @@ -0,0 +1 @@ +../rnn_lm/dataset.py \ No newline at end of file diff --git a/icefall/transformer_lm/encoder.py b/icefall/transformer_lm/encoder.py new file mode 100644 index 000000000..4357b83d7 --- /dev/null +++ b/icefall/transformer_lm/encoder.py @@ -0,0 +1,329 @@ +# Copyright (c) 2021 Xiaomi Corporation (authors: Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import math +from typing import List, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import Tensor, nn + +from icefall.transformer_lm.attention import RelPositionMultiheadAttention +from icefall.transformer_lm.scaling import ( + ActivationBalancer, + BasicNorm, + DoubleSwish, + ScaledConv1d, + ScaledConv2d, + ScaledLinear, +) +from icefall.utils import is_jit_tracing, make_pad_mask + + +class Transformer(torch.nn.Module): + """_summary_ + + Args: + input_dim (int): Input feature dimension + d_mode (int): The dimension of the transformer + dim_feedforward (int ): The dimension of the ffw module + nhead (int): The number of attention heads + dropout_rate (float): dropout rate + att_dropout (float): dropout rate in attention module + """ + + def __init__( + self, + input_dim: int, + d_model: int, + dim_feedforward: int, + nhead: int = 4, + num_layers: int = 6, + dropout_rate: float = 0.1, + att_dropout: float = 0.0, + ): + super().__init__() + + self.encoder_layers = num_layers + self.d_model = d_model + + self.embed = ScaledLinear(input_dim, d_model) + self.norm_before = BasicNorm(d_model, learn_eps=False) + + self.encoder_pos = RelPositionalEncoding(d_model, dropout_rate) + + encoder_layer = TransformerEncoderLayer( + d_model=d_model, + dim_feedforward=dim_feedforward, + nhead=nhead, + dropout_rate=dropout_rate, + ) + + self.encoder = TransformerEncoder(encoder_layer, num_layers) + + def _create_attention_mask(self, x_lens: torch.Tensor): + # create a 2D attention mask to mask out + # the upper right half of the attention matrix + max_len = max(x_lens) + ones = torch.ones(max_len, max_len, device=x_lens.device, dtype=torch.bool) + return torch.triu(ones, diagonal=1) + + def forward( + self, x: torch.Tensor, x_lens: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Transformer forward + + Args: + x (torch.Tensor): Input tensor (B,T,input_dim) + x_lens (torch.Tensor): The length of input tensors before padding (B,) + + Returns: + Return a tuple of 2 tensors: + - x: output feature of the transformer (B,T,d_model) + - x_lens: output feature lens of the transformer + """ + + attention_mask = self._create_attention_mask(x_lens) + src_key_padding_mask = make_pad_mask(x_lens) + + x = self.norm_before(self.embed(x)) + + x, pos_emb = self.encoder_pos(x) + x = x.permute(1, 0, 2) + + x = self.encoder( + x, + pos_emb, + mask=attention_mask, # pass the attention mast + src_key_padding_mask=src_key_padding_mask, + ) # (T, N, C) + + x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + return x, x_lens + + +class TransformerEncoder(torch.nn.Module): + def __init__(self, encoder_layer: torch.nn.Module, num_layers: int) -> None: + """TransformerEncoder is a stack of N encoder layers + + Args: + encoder_layer (torch.nn.Module): an instance of the TransformerEncoderLayer() + num_layers (int): Number of layers to be stacked + """ + super().__init__() + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for i in range(num_layers)] + ) + self.num_layers = num_layers + + def forward( + self, + src: torch.Tensor, + pos_emb: torch.Tensor, + src_key_padding_mask: Optional[torch.Tensor] = None, + mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """_summary_ + + Args: + src: the sequence to the encoder (required). + pos_emb: Positional embedding tensor (required). + mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Returns: + output: transformer encoded features + """ + output = src + + for layer_index, mod in enumerate(self.layers): + output = mod( + output, + pos_emb, + src_key_padding_mask=src_key_padding_mask, + src_mask=mask, + ) + + return output + + +class TransformerEncoderLayer(torch.nn.Module): + def __init__( + self, + d_model: int, + dim_feedforward: int, + nhead: int, + dropout_rate: float, + ): + """TransformerEncoderLayer is made up of self-attn and feedforward module + + Args: + d_model (int): The model size + dim_feedforward (int): Dimension of ffw module + nhead (int): Number of heads + dropout_rate (float): Dropout rate + """ + super().__init__() + + self.d_model = d_model + + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) + self.feed_forward = nn.Sequential( + ScaledLinear(d_model, dim_feedforward), + ActivationBalancer(channel_dim=-1), + DoubleSwish(), + nn.Dropout(dropout_rate), + ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), + ) + + self.norm_final = BasicNorm(d_model) + + self.balancer = ActivationBalancer( + channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0 + ) + + self.dropout = nn.Dropout(dropout_rate) + + def forward( + self, + src: torch.Tensor, + pos_emb: torch.Tensor, + src_key_padding_mask: Optional[torch.Tensor] = None, + src_mask: Optional[torch.Tensor] = None, + cache=None, + ): + """ + Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + pos_emb: Positional embedding tensor (required). + src_key_padding_mask: the mask for the src keys per batch (optional). + src_mask: the mask for the src sequence (optional). + """ + src_orig = src + + src_att = self.self_attn( + src, + src, + src, + pos_emb=pos_emb, + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask, + )[0] + + src = src + self.dropout(src_att) + + # feed forward module + src = src + self.dropout(self.feed_forward(src)) + + src = self.norm_final(self.balancer(src)) + + return src + + +class RelPositionalEncoding(torch.nn.Module): + """Relative positional encoding module. + + See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py + + Args: + d_model: Embedding dimension. + dropout_rate: Dropout rate. + max_len: Maximum input length. + + """ + + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: + """Construct an PositionalEncoding object.""" + super(RelPositionalEncoding, self).__init__() + if is_jit_tracing(): + # 10k frames correspond to ~100k ms, e.g., 100 seconds, i.e., + # It assumes that the maximum input won't have more than + # 10k frames. + # + # TODO(fangjun): Use torch.jit.script() for this module + max_len = 10000 + + self.d_model = d_model + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + + def extend_pe(self, x: torch.Tensor, left_context: int = 0) -> None: + """Reset the positional encodings.""" + x_size_1 = x.size(1) + left_context + if self.pe is not None: + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + if self.pe.size(1) >= x_size_1 * 2 - 1: + # Note: TorchScript doesn't implement operator== for torch.Device + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + # Suppose `i` means to the position of query vector and `j` means the + # position of key vector. We use position relative positions when keys + # are to the left (i>j) and negative relative positions otherwise (i Tuple[torch.Tensor, torch.Tensor]: + """Add positional encoding. + + Args: + x (torch.Tensor): Input tensor (batch, time, `*`). + left_context (int): left context (in frames) used during streaming decoding. + this is used only in real streaming decoding, in other circumstances, + it MUST be 0. + + Returns: + torch.Tensor: Encoded tensor (batch, time, `*`). + torch.Tensor: Encoded tensor (batch, 2*time-1, `*`). + + """ + self.extend_pe(x, left_context) + x_size_1 = x.size(1) + left_context + pos_emb = self.pe[ + :, + self.pe.size(1) // 2 + - x_size_1 + + 1 : self.pe.size(1) // 2 # noqa E203 + + x.size(1), + ] + return self.dropout(x), self.dropout(pos_emb) diff --git a/icefall/transformer_lm/export.py b/icefall/transformer_lm/export.py new file mode 100644 index 000000000..c08982e37 --- /dev/null +++ b/icefall/transformer_lm/export.py @@ -0,0 +1,186 @@ +#!/usr/bin/env python3 +# Copyright (c) 2022 Xiaomi Corporation (authors: Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. + +import argparse +import logging +from pathlib import Path + +import torch +from model import TransformerLM + +from icefall.checkpoint import load_checkpoint +from icefall.utils import AttributeDict, load_averaged_model, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=11, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + + parser.add_argument( + "--avg", + type=int, + default=5, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--vocab-size", + type=int, + default=500, + help="Vocabulary size of the model", + ) + + parser.add_argument( + "--embedding-dim", + type=int, + default=768, + help="Embedding dim of the model", + ) + + parser.add_argument( + "--encoder-dim", + type=int, + default=768, + help="Encoder dim of the model", + ) + + parser.add_argument( + "--dim_feedforward", + type=int, + default=2048, + help="Hidden dim of the model", + ) + + parser.add_argument( + "--nhead", + type=int, + default=8, + help="Number of attention heads", + ) + + parser.add_argument( + "--num-layers", + type=int, + default=16, + help="Number of Transformer layers", + ) + + parser.add_argument( + "--tie-weights", + type=str2bool, + default=True, + help="""True to share the weights between the input embedding layer and the + last output linear layer + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="rnn_lm/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=True, + help="""True to save a model after applying torch.jit.script. + """, + ) + + return parser + + +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = AttributeDict({}) + params.update(vars(args)) + + logging.info(params) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("About to create model") + model = TransformerLM( + vocab_size=params.vocab_size, + d_model=params.encoder_dim, + embedding_dim=params.embedding_dim, + dim_feedforward=params.dim_feedforward, + nhead=params.nhead, + num_layers=params.num_layers, + tie_weights=params.tie_weights, + params=params, + ) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + model.to(device) + + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + model = load_averaged_model( + params.exp_dir, model, params.epoch, params.avg, device + ) + + model.to("cpu") + model.eval() + + if params.jit: + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + filename = params.exp_dir / "cpu_jit.pt" + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torch.jit.script") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/icefall/transformer_lm/model.py b/icefall/transformer_lm/model.py new file mode 100644 index 000000000..79dda3168 --- /dev/null +++ b/icefall/transformer_lm/model.py @@ -0,0 +1,115 @@ +# Copyright (c) 2022 Xiaomi Corporation (authors: Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F + +from icefall.transformer_lm.encoder import Transformer +from icefall.utils import AttributeDict, add_eos, add_sos, make_pad_mask + + +class TransformerLM(torch.nn.Module): + def __init__( + self, + vocab_size: int, + embedding_dim: int, + d_model: int, + dim_feedforward: int, + nhead: int = 8, + num_layers: int = 16, + tie_weights: bool = True, + dropout: float = 0.1, + emb_dropout_rate: float = 0.0, + params: AttributeDict = None, + ): + super().__init__() + + self.vocab_size = vocab_size + self.params = params + + self.input_embedding = torch.nn.Embedding( + num_embeddings=vocab_size, + embedding_dim=embedding_dim, + ) + + self.encoder = Transformer( + input_dim=embedding_dim, + d_model=d_model, + dim_feedforward=dim_feedforward, + nhead=nhead, + num_layers=num_layers, + dropout_rate=dropout, + ) + + self.output_linear = torch.nn.Linear( + in_features=d_model, out_features=vocab_size + ) + if tie_weights: + logging.info("Tying weights") + assert d_model == embedding_dim, (d_model, embedding_dim) + self.output_linear.weight = self.input_embedding.weight + else: + logging.info("Not tying weights") + + def forward( + self, + x: torch.Tensor, + y: torch.Tensor, + x_lens: torch.Tensor, + return_logits: bool = False, + ): + """Forward transformer language model + + Args: + x (torch.Tensor): Input tokens (B,L) + y (torch.Tensor): Output tokens (with EOS appended) (B,L) + x_lens (torch.Tensor): Length of input tokens before padding (B,) + return_logits (bool, optional): Return logits instead of NLL + + """ + + x = self.input_embedding(x) + + x, x_lens = self.encoder(x, x_lens) + + logits = self.output_linear(x) + + if return_logits: + return logits + + nll_loss = F.cross_entropy( + logits.reshape(-1, self.vocab_size), y.reshape(-1), reduction="none" + ) + + mask = make_pad_mask(x_lens).reshape(-1) + nll_loss.masked_fill_(mask, 0) + + return nll_loss + + def score_token(self, x: torch.Tensor, x_lens: torch.Tensor, state=None): + + bs = x.size(0) + + state = None + logits = self.forward(x, x, x_lens, return_logits=True) + index = torch.arange(bs) + + last_logits = logits[index, x_lens - 1, :] + + return last_logits.log_softmax(-1), state diff --git a/icefall/transformer_lm/scaling.py b/icefall/transformer_lm/scaling.py new file mode 120000 index 000000000..0876c0704 --- /dev/null +++ b/icefall/transformer_lm/scaling.py @@ -0,0 +1 @@ +../../egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py \ No newline at end of file diff --git a/icefall/transformer_lm/train.py b/icefall/transformer_lm/train.py new file mode 100644 index 000000000..c36abfcdf --- /dev/null +++ b/icefall/transformer_lm/train.py @@ -0,0 +1,609 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +Usage: + ./transformer_lm/train.py \ + --start-epoch 0 \ + --world-size 2 \ + --num-epochs 1 \ + --use-fp16 0 \ + --num-layers 12 \ + --batch-size 400 + +""" + +import argparse +import logging +import math +from pathlib import Path +from shutil import copyfile +from typing import Optional, Tuple + +import torch +import torch.multiprocessing as mp +import torch.nn as nn +import torch.optim as optim +from dataset import get_dataloader +from lhotse.utils import fix_random_seed +from model import TransformerLM +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.nn.utils import clip_grad_norm_ +from torch.utils.tensorboard import SummaryWriter + +from icefall.checkpoint import load_checkpoint +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=0, + help="""Resume training from from this epoch. + If it is positive, it will load checkpoint from + exp_dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="transformer_lm/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, logs, etc, are saved + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=True, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--batch-size", + type=int, + default=400, + ) + + parser.add_argument( + "--lm-data", + type=str, + default="data/lm_training_bpe_500/sorted_lm_data.pt", + help="LM training data", + ) + + parser.add_argument( + "--lm-data-valid", + type=str, + default="data/lm_training_bpe_500/sorted_lm_data-valid.pt", + help="LM validation data", + ) + + parser.add_argument( + "--vocab-size", + type=int, + default=500, + help="Vocabulary size of the model", + ) + + parser.add_argument( + "--num-layers", + type=int, + default=12, + help="Number of Transformer layers in the model", + ) + + parser.add_argument( + "--tie-weights", + type=str2bool, + default=True, + help="""True to share the weights between the input embedding layer and the + last output linear layer + """, + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters.""" + + params = AttributeDict( + { + "max_sent_len": 200, + "sos_id": 1, + "eos_id": 1, + "blank_id": 0, + "lr": 1e-3, + "weight_decay": 1e-6, + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 200, + "reset_interval": 2000, + "valid_interval": 1000, + "nhead": 8, + "embedding_dim": 768, + "encoder_dim": 768, + "dim_feedforward": 2048, + "dropout": 0.1, + "env_info": get_env_info(), + } + ) + return params + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, +) -> None: + """Load checkpoint from file. + + If params.start_epoch is positive, it will load the checkpoint from + `params.start_epoch - 1`. Otherwise, this function does nothing. + + Apart from loading state dict for `model`, `optimizer` and `scheduler`, + it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + optimizer: + The optimizer that we are using. + scheduler: + The learning rate scheduler we are using. + Returns: + Return None. + """ + if params.start_epoch <= 0: + return + + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + logging.info(f"Loading checkpoint: {filename}") + saved_params = load_checkpoint( + filename, + model=model, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + params=params, + optimizer=optimizer, + scheduler=scheduler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + model: nn.Module, + x: torch.Tensor, + y: torch.Tensor, + sentence_lengths: torch.Tensor, + is_training: bool, +) -> Tuple[torch.Tensor, MetricsTracker]: + """Compute the negative log-likelihood loss given a model and its input. + Args: + model: + The NN model, + x: + A 2-D tensor. Each row contains BPE token IDs for a sentence. Also, + each row starts with SOS ID. + y: + A 2-D tensor. Each row is a shifted version of the corresponding row + in `x` but ends with an EOS ID (before padding). + sentence_lengths: + A 1-D tensor containing number of tokens of each sentence + before padding. + is_training: + True for training. False for validation. + """ + with torch.set_grad_enabled(is_training): + device = model.device + x = x.to(device) + y = y.to(device) + sentence_lengths = sentence_lengths.to(device) + + nll = model(x, y, sentence_lengths) + loss = nll.sum() + + num_tokens = sentence_lengths.sum().item() + + loss_info = MetricsTracker() + # Note: Due to how MetricsTracker() is designed, + # we use "frames" instead of "num_tokens" as a key here + loss_info["frames"] = num_tokens + loss_info["loss"] = loss.detach().item() + return loss, loss_info + + +def compute_validation_loss( + params: AttributeDict, + model: nn.Module, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process. The validation loss + is saved in `params.valid_loss`. + """ + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + x, y, sentence_lengths = batch + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + model=model, + x=x, + y=y, + sentence_lengths=sentence_lengths, + is_training=False, + ) + + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: nn.Module, + optimizer: torch.optim.Optimizer, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all sentences is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + """ + model.train() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + x, y, sentence_lengths = batch + batch_size = x.size(0) + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + model=model, + x=x, + y=y, + sentence_lengths=sentence_lengths, + is_training=True, + ) + + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + optimizer.zero_grad() + loss.backward() + clip_grad_norm_(model.parameters(), 5.0, 2.0) + optimizer.step() + + if batch_idx % params.log_interval == 0: + # Note: "frames" here means "num_tokens" + this_batch_ppl = math.exp(loss_info["loss"] / loss_info["frames"]) + tot_ppl = math.exp(tot_loss["loss"] / tot_loss["frames"]) + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}, ppl: {this_batch_ppl}] " + f"tot_loss[{tot_loss}, ppl: {tot_ppl}], " + f"batch size: {batch_size}" + ) + + if tb_writer is not None: + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + + tb_writer.add_scalar( + "train/current_ppl", this_batch_ppl, params.batch_idx_train + ) + + tb_writer.add_scalar("train/tot_ppl", tot_ppl, params.batch_idx_train) + + if batch_idx > 0 and batch_idx % params.valid_interval == 0: + logging.info("Computing validation loss") + + valid_info = compute_validation_loss( + params=params, + model=model, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + + valid_ppl = math.exp(valid_info["loss"] / valid_info["frames"]) + logging.info( + f"Epoch {params.cur_epoch}, validation: {valid_info}, " + f"ppl: {valid_ppl}" + ) + + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + tb_writer.add_scalar( + "train/valid_ppl", valid_ppl, params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + is_distributed = world_size > 1 + + fix_random_seed(params.seed) + if is_distributed: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + logging.info(params) + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + + logging.info(f"Device: {device}") + + logging.info("About to create model") + model = TransformerLM( + vocab_size=params.vocab_size, + d_model=params.encoder_dim, + embedding_dim=params.embedding_dim, + dim_feedforward=params.dim_feedforward, + nhead=params.nhead, + num_layers=params.num_layers, + tie_weights=params.tie_weights, + params=params, + ) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoints = load_checkpoint_if_available(params=params, model=model) + + model.to(device) + if is_distributed: + model = DDP(model, device_ids=[rank]) + + model.device = device + + optimizer = optim.Adam( + model.parameters(), + lr=params.lr, + weight_decay=params.weight_decay, + ) + if checkpoints: + logging.info("Load optimizer state_dict from checkpoint") + optimizer.load_state_dict(checkpoints["optimizer"]) + + logging.info(f"Loading LM training data from {params.lm_data}") + train_dl = get_dataloader( + filename=params.lm_data, + is_distributed=is_distributed, + params=params, + ) + + logging.info(f"Loading LM validation data from {params.lm_data_valid}") + valid_dl = get_dataloader( + filename=params.lm_data_valid, + is_distributed=is_distributed, + params=params, + ) + + # Note: No learning rate scheduler is used here + for epoch in range(params.start_epoch, params.num_epochs): + if is_distributed: + train_dl.sampler.set_epoch(epoch) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + optimizer=optimizer, + train_dl=train_dl, + valid_dl=valid_dl, + tb_writer=tb_writer, + world_size=world_size, + ) + + save_checkpoint( + params=params, + model=model, + optimizer=optimizer, + rank=rank, + ) + + logging.info("Done!") + + if is_distributed: + torch.distributed.barrier() + cleanup_dist() + + +def main(): + parser = get_parser() + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() From aa0fe4e4ac4d9bb4a1082709c103e76f70eb8b6f Mon Sep 17 00:00:00 2001 From: marcoyang1998 <45973641+marcoyang1998@users.noreply.github.com> Date: Thu, 29 Dec 2022 11:54:42 +0800 Subject: [PATCH 064/174] Fix typos in RESULTS.md (#797) --- egs/librispeech/ASR/RESULTS.md | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index 007d34a62..05422562c 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -318,13 +318,13 @@ Number of model parameters: 70369391, i.e., 70.37 M | | test-clean | test-other | comment | |----------------------|------------|-------------|----------------------------------------| -| greedy search | 2.17 | 5.23 | --epoch 39 --avg 6 --max-duration 600 | -| modified beam search | 2.15 | 5.20 | --epoch 39 --avg 6 --max-duration 600 | -| modified beam search + RNNLM shallow fusion | 1.99 | 4.73 | --epoch 39 --avg 6 --max-duration 600 | -| modified beam search + TransformerLM shallow fusion | 1.94 | 4.73 | --epoch 39 --avg 6 --max-duration 600 | -| modified beam search + RNNLM + LODR | 1.91 | 4.57 | --epoch 39 --avg 6 --max-duration 600 | -| modified beam search + TransformerLM + LODR | 1.91 | 4.51 | --epoch 39 --avg 6 --max-duration 600 | -| fast beam search | 2.15 | 5.22 | --epoch 39 --avg 6 --max-duration 600 | +| greedy search | 2.17 | 5.23 | --epoch 30 --avg 9 --max-duration 600 | +| modified beam search | 2.15 | 5.20 | --epoch 30 --avg 9 --max-duration 600 | +| modified beam search + RNNLM shallow fusion | 1.99 | 4.73 | --epoch 30 --avg 9 --max-duration 600 | +| modified beam search + TransformerLM shallow fusion | 1.94 | 4.73 | --epoch 30 --avg 9 --max-duration 600 | +| modified beam search + RNNLM + LODR | 1.91 | 4.57 | --epoch 30 --avg 9 --max-duration 600 | +| modified beam search + TransformerLM + LODR | 1.91 | 4.51 | --epoch 30 --avg 9 --max-duration 600 | +| fast beam search | 2.15 | 5.22 | --epoch 30 --avg 9 --max-duration 600 | The training commands are: ```bash From d167aad4abd5d330da9da1aa006478eb4361cd04 Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Fri, 30 Dec 2022 10:52:18 +0800 Subject: [PATCH 065/174] Add streaming zipformer (#787) * add streaming zipformer codes * add test_model.py * add export.py, pretrained.py, jit_pretrained.py * add cached_len for pooling module * add jit_trace_export.py and jit_trace_pretrained.py * fix bug in jit.trace * update RESULTS.md * add CI test * minor fix in pruned_transducer_stateless7/zipformer.py * update README.md --- ...nsducer-stateless7-streaming-2022-12-29.sh | 148 + ...speech-2022-12-29-stateless7-streaming.yml | 172 + egs/librispeech/ASR/README.md | 22 +- egs/librispeech/ASR/RESULTS.md | 78 + .../pruned_transducer_stateless7/scaling.py | 6 +- .../pruned_transducer_stateless7/zipformer.py | 2 +- .../asr_datamodule.py | 1 + .../beam_search.py | 1 + .../decode.py | 813 +++++ .../decode_stream.py | 151 + .../decoder.py | 1 + .../encoder_interface.py | 1 + .../export.py | 320 ++ .../jit_pretrained.py | 278 ++ .../jit_trace_export.py | 313 ++ .../jit_trace_pretrained.py | 295 ++ .../joiner.py | 1 + .../model.py | 1 + .../optim.py | 1 + .../pretrained.py | 355 ++ .../scaling.py | 1 + .../scaling_converter.py | 1 + .../streaming_beam_search.py | 1 + .../streaming_decode.py | 615 ++++ .../test_model.py | 150 + .../train.py | 1264 ++++++++ .../zipformer.py | 2881 +++++++++++++++++ 27 files changed, 7867 insertions(+), 6 deletions(-) create mode 100755 .github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh create mode 100644 .github/workflows/run-librispeech-2022-12-29-stateless7-streaming.yml create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/beam_search.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decoder.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/joiner.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/model.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/optim.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/scaling.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/test_model.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh new file mode 100755 index 000000000..afb0dc05a --- /dev/null +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh @@ -0,0 +1,148 @@ +#!/usr/bin/env bash + +set -e + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 + +log "Downloading pre-trained model from $repo_url" +git lfs install +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +log "Display test files" +tree $repo/ +soxi $repo/test_wavs/*.wav +ls -lh $repo/test_wavs/*.wav + +pushd $repo/exp +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/cpu_jit.pt" +git lfs pull --include "exp/pretrained.pt" +git lfs pull --include "exp/encoder_jit_trace.pt" +git lfs pull --include "exp/decoder_jit_trace.pt" +git lfs pull --include "exp/joiner_jit_trace.pt" +ln -s pretrained.pt epoch-99.pt +ls -lh *.pt +popd + +log "Export to torchscript model" +./pruned_transducer_stateless7_streaming/export.py \ + --exp-dir $repo/exp \ + --use-averaged-model false \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --decode-chunk-len 32 \ + --epoch 99 \ + --avg 1 \ + --jit 1 + +ls -lh $repo/exp/*.pt + +log "Decode with models exported by torch.jit.script()" + +./pruned_transducer_stateless7_streaming/jit_pretrained.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --nn-model-filename $repo/exp/cpu_jit.pt \ + --decode-chunk-len 32 \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + +log "Export to torchscript model by torch.jit.trace()" +./pruned_transducer_stateless7_streaming/jit_trace_export.py \ + --exp-dir $repo/exp \ + --use-averaged-model false \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --decode-chunk-len 32 \ + --epoch 99 \ + --avg 1 + +log "Decode with models exported by torch.jit.trace()" + +./pruned_transducer_stateless7_streaming/jit_trace_pretrained.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --encoder-model-filename $repo/exp/encoder_jit_trace.pt \ + --decoder-model-filename $repo/exp/decoder_jit_trace.pt \ + --joiner-model-filename $repo/exp/joiner_jit_trace.pt \ + --decode-chunk-len 32 \ + $repo/test_wavs/1089-134686-0001.wav + +for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./pruned_transducer_stateless7_streaming/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --decode-chunk-len 32 \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done + +for method in modified_beam_search beam_search fast_beam_search; do + log "$method" + + ./pruned_transducer_stateless7_streaming/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --decode-chunk-len 32 \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done + +echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" +echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" +if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then + mkdir -p pruned_transducer_stateless7_streaming/exp + ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless7_streaming/exp/epoch-999.pt + ln -s $PWD/$repo/data/lang_bpe_500 data/ + + ls -lh data + ls -lh pruned_transducer_stateless7_streaming/exp + + log "Decoding test-clean and test-other" + + # use a small value for decoding with CPU + max_duration=100 + num_decode_stream=200 + + for method in greedy_search fast_beam_search modified_beam_search; do + log "decoding with $method" + + ./pruned_transducer_stateless7_streaming/decode.py \ + --decoding-method $method \ + --epoch 999 \ + --avg 1 \ + --use-averaged-model 0 \ + --max-duration $max_duration \ + --decode-chunk-len 32 \ + --exp-dir pruned_transducer_stateless7_streaming/exp + done + + for method in greedy_search fast_beam_search modified_beam_search; do + log "Decoding with $method" + + ./pruned_transducer_stateless7_streaming/streaming_decode.py \ + --decoding-method $method \ + --epoch 999 \ + --avg 1 \ + --use-averaged-model 0 \ + --decode-chunk-len 32 \ + --num-decode-streams $num_decode_stream + --exp-dir pruned_transducer_stateless7_streaming/exp + done + + rm pruned_transducer_stateless7_streaming/exp/*.pt +fi diff --git a/.github/workflows/run-librispeech-2022-12-29-stateless7-streaming.yml b/.github/workflows/run-librispeech-2022-12-29-stateless7-streaming.yml new file mode 100644 index 000000000..6dd93946a --- /dev/null +++ b/.github/workflows/run-librispeech-2022-12-29-stateless7-streaming.yml @@ -0,0 +1,172 @@ +# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com) + +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: run-librispeech-2022-12-29-stateless7-streaming +# zipformer + +on: + push: + branches: + - master + pull_request: + types: [labeled] + + schedule: + # minute (0-59) + # hour (0-23) + # day of the month (1-31) + # month (1-12) + # day of the week (0-6) + # nightly build at 15:50 UTC time every day + - cron: "50 15 * * *" + +concurrency: + group: run_librispeech_2022_12_29_zipformer_streaming-${{ github.ref }} + cancel-in-progress: true + +jobs: + run_librispeech_2022_12_29_zipformer_streaming: + if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event.label.name == 'streaming-zipformer' || github.event_name == 'push' || github.event_name == 'schedule' + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python-version: [3.8] + + fail-fast: false + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + cache-dependency-path: '**/requirements-ci.txt' + + - name: Install Python dependencies + run: | + grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install + pip uninstall -y protobuf + pip install --no-binary protobuf protobuf + + - name: Cache kaldifeat + id: my-cache + uses: actions/cache@v2 + with: + path: | + ~/tmp/kaldifeat + key: cache-tmp-${{ matrix.python-version }}-2022-09-25 + + - name: Install kaldifeat + if: steps.my-cache.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/install-kaldifeat.sh + + - name: Cache LibriSpeech test-clean and test-other datasets + id: libri-test-clean-and-test-other-data + uses: actions/cache@v2 + with: + path: | + ~/tmp/download + key: cache-libri-test-clean-and-test-other + + - name: Download LibriSpeech test-clean and test-other + if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh + + - name: Prepare manifests for LibriSpeech test-clean and test-other + shell: bash + run: | + .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh + + - name: Cache LibriSpeech test-clean and test-other fbank features + id: libri-test-clean-and-test-other-fbank + uses: actions/cache@v2 + with: + path: | + ~/tmp/fbank-libri + key: cache-libri-fbank-test-clean-and-test-other-v2 + + - name: Compute fbank for LibriSpeech test-clean and test-other + if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh + + - name: Inference with pre-trained model + shell: bash + env: + GITHUB_EVENT_NAME: ${{ github.event_name }} + GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} + run: | + mkdir -p egs/librispeech/ASR/data + ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank + ls -lh egs/librispeech/ASR/data/* + + sudo apt-get -qq install git-lfs tree sox + export PYTHONPATH=$PWD:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH + + .github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh + + - name: Display decoding results for librispeech pruned_transducer_stateless7_streaming + if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' + shell: bash + run: | + cd egs/librispeech/ASR/ + tree ./pruned_transducer_stateless7_streaming/exp + + cd pruned_transducer_stateless7_streaming + echo "results for pruned_transducer_stateless7_streaming" + echo "===greedy search===" + find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===fast_beam_search===" + find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===modified beam search===" + find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===streaming greedy search===" + find exp/streaming/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/streaming/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===streaming fast_beam_search===" + find exp/streaming/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/streaming/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===streaming modified beam search===" + find exp/streaming/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/streaming/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + + - name: Upload decoding results for librispeech pruned_transducer_stateless7_streaming + uses: actions/upload-artifact@v2 + if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' + with: + name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless7-streaming-2022-12-29 + path: egs/librispeech/ASR/pruned_transducer_stateless7_streaming/exp/ diff --git a/egs/librispeech/ASR/README.md b/egs/librispeech/ASR/README.md index caa23a49f..94cb445a8 100644 --- a/egs/librispeech/ASR/README.md +++ b/egs/librispeech/ASR/README.md @@ -19,18 +19,36 @@ The following table lists the differences among them. | `pruned_transducer_stateless` | Conformer | Embedding + Conv1d | Using k2 pruned RNN-T loss | | `pruned_transducer_stateless2` | Conformer(modified) | Embedding + Conv1d | Using k2 pruned RNN-T loss | | `pruned_transducer_stateless3` | Conformer(modified) | Embedding + Conv1d | Using k2 pruned RNN-T loss + using GigaSpeech as extra training data | -| `pruned_transducer_stateless4` | Conformer(modified) | Embedding + Conv1d | same as pruned_transducer_stateless2 + save averaged models periodically during training | +| `pruned_transducer_stateless4` | Conformer(modified) | Embedding + Conv1d | same as pruned_transducer_stateless2 + save averaged models periodically during training + delay penalty | | `pruned_transducer_stateless5` | Conformer(modified) | Embedding + Conv1d | same as pruned_transducer_stateless4 + more layers + random combiner| | `pruned_transducer_stateless6` | Conformer(modified) | Embedding + Conv1d | same as pruned_transducer_stateless4 + distillation with hubert| | `pruned_transducer_stateless7` | Zipformer | Embedding + Conv1d | First experiment with Zipformer from Dan| | `pruned_transducer_stateless7_ctc` | Zipformer | Embedding + Conv1d | Same as pruned_transducer_stateless7, but with extra CTC head| +| `pruned_transducer_stateless7_ctc_bs` | Zipformer | Embedding + Conv1d | pruned_transducer_stateless7_ctc + blank skip | +| `pruned_transducer_stateless7_streaming` | Streaming Zipformer | Embedding + Conv1d | streaming version of pruned_transducer_stateless7 | | `pruned_transducer_stateless8` | Zipformer | Embedding + Conv1d | Same as pruned_transducer_stateless7, but using extra data from GigaSpeech| | `pruned_stateless_emformer_rnnt2` | Emformer(from torchaudio) | Embedding + Conv1d | Using Emformer from torchaudio for streaming ASR| | `conv_emformer_transducer_stateless` | ConvEmformer | Embedding + Conv1d | Using ConvEmformer for streaming ASR + mechanisms in reworked model | | `conv_emformer_transducer_stateless2` | ConvEmformer | Embedding + Conv1d | Using ConvEmformer with simplified memory for streaming ASR + mechanisms in reworked model | | `lstm_transducer_stateless` | LSTM | Embedding + Conv1d | Using LSTM with mechanisms in reworked model | -| `lstm_transducer_stateless2` | LSTM | Embedding + Conv1d | Using LSTM with mechanisms in reworked model + gigaspeech (multi-dataset setup) | +| `lstm_transducer_stateless2` | LSTM | Embedding + Conv1d | Using LSTM with mechanisms in reworked model + gigaspeech (multi-dataset setup) | +| `lstm_transducer_stateless3` | LSTM | Embedding + Conv1d | Using LSTM with mechanisms in reworked model + gradient filter + delay penalty | The decoder in `transducer_stateless` is modified from the paper [Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/). We place an additional Conv1d layer right after the input embedding layer. + +# CTC + +| | Encoder | Comment | +|------------------------------|--------------------|------------------------------| +| `conformer-ctc` | Conformer | Use auxiliary attention head | +| `conformer-ctc2` | Reworked Conformer | Use auxiliary attention head | +| `conformer-ctc3` | Reworked Conformer | Streaming version + delay penalty | + +# MMI + +| | Encoder | Comment | +|------------------------------|-----------|---------------------------------------------------| +| `conformer-mmi` | Conformer | | +| `zipformer-mmi` | Zipformer | CTC warmup + use HP as decoding graph for decoding | diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index 05422562c..b30cf7c1f 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -1,5 +1,83 @@ ## Results +### Streaming Zipformer-Transducer (Pruned Stateless Transducer + Streaming Zipformer) + +#### [pruned_transducer_stateless7_streaming](./pruned_transducer_stateless7_streaming) + +See for more details. + +You can find a pretrained model, training logs, decoding logs, and decoding +results at: + + +Number of model parameters: 70369391, i.e., 70.37 M + +##### training on full librispeech + +The WERs are: + +| decoding method | chunk size | test-clean | test-other | comment | decoding mode | +|----------------------|------------|------------|------------|---------------------|----------------------| +| greedy search | 320ms | 3.15 | 8.09 | --epoch 30 --avg 9 | simulated streaming | +| greedy search | 320ms | 3.17 | 8.24 | --epoch 30 --avg 9 | chunk-wise | +| fast beam search | 320ms | 3.2 | 8.04 | --epoch 30 --avg 9 | simulated streaming | +| fast beam search | 320ms | 3.36 | 8.19 | --epoch 30 --avg 9 | chunk-wise | +| modified beam search | 320ms | 3.11 | 7.93 | --epoch 30 --avg 9 | simulated streaming | +| modified beam search | 320ms | 3.12 | 8.11 | --epoch 30 --avg 9 | chunk-size | +| greedy search | 640ms | 2.97 | 7.5 | --epoch 30 --avg 9 | simulated streaming | +| greedy search | 640ms | 2.98 | 7.67 | --epoch 30 --avg 9 | chunk-wise | +| fast beam search | 640ms | 3.02 | 7.47 | --epoch 30 --avg 9 | simulated streaming | +| fast beam search | 640ms | 2.96 | 7.61 | --epoch 30 --avg 9 | chunk-wise | +| modified beam search | 640ms | 2.94 | 7.36 | --epoch 30 --avg 9 | simulated streaming | +| modified beam search | 640ms | 2.95 | 7.53 | --epoch 30 --avg 9 | chunk-size | + +Note: `simulated streaming` indicates feeding full utterance during decoding using `decode.py`, +while `chunk-size` indicates feeding certain number of frames at each time using `streaming_decode.py`. + +The training command is: + +```bash +./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --full-libri 1 \ + --max-duration 750 \ + --master-port 12345 +``` + +The tensorboard log can be found at + + +The simulated streaming decoding command (e.g., chunk-size=320ms) is: +```bash +for $m in greedy_search fast_beam_search modified_beam_search; do + ./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 30 \ + --avg 9 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method $m +done +``` + +The streaming chunk-size decoding command (e.g., chunk-size=320ms) is: +```bash +for m in greedy_search modified_beam_search fast_beam_search; do + ./pruned_transducer_stateless7_streaming/streaming_decode.py \ + --epoch 30 \ + --avg 9 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --decoding-method $m \ + --decode-chunk-len 32 \ + --num-decode-streams 2000 +done +``` + + ### zipformer_mmi (zipformer with mmi loss) See for more details. diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index 042c9c3e4..1cbde6db0 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -298,7 +298,7 @@ class SoftmaxFunction(torch.autograd.Function): def softmax(x: Tensor, dim: int): - if torch.jit.is_scripting(): + if torch.jit.is_scripting() or torch.jit.is_tracing(): return x.softmax(dim) return SoftmaxFunction.apply(x, dim) @@ -783,7 +783,7 @@ class WithLoss(torch.autograd.Function): def with_loss(x, y): - if torch.jit.is_scripting(): + if torch.jit.is_scripting() or torch.jit.is_tracing(): return x # returns x but adds y.sum() to the loss function. return WithLoss.apply(x, y) @@ -1013,7 +1013,7 @@ class DoubleSwish(torch.nn.Module): """Return double-swish activation function which is an approximation to Swish(Swish(x)), that we approximate closely with x * sigmoid(x-1). """ - if torch.jit.is_scripting(): + if torch.jit.is_scripting() or torch.jit.is_tracing(): return x * torch.sigmoid(x - 1.0) return DoubleSwishFunction.apply(x) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index ad3b88df0..d18258085 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -907,7 +907,7 @@ class RelPositionalEncoding(torch.nn.Module): self.d_model = d_model self.dropout = torch.nn.Dropout(dropout_rate) self.pe = None - self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + self.extend_pe(torch.tensor(0.0).expand(max_len)) def extend_pe(self, x: Tensor) -> None: """Reset the positional encodings.""" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py new file mode 120000 index 000000000..a074d6085 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/asr_datamodule.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/beam_search.py new file mode 120000 index 000000000..8554e44cc --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/beam_search.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py new file mode 100755 index 000000000..aebe2b94b --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py @@ -0,0 +1,813 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + feature_lens += 30 + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, 30), + value=LOG_EPS, + ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + assert model.encoder.decode_chunk_size == params.decode_chunk_len // 2, ( + model.encoder.decode_chunk_size, + params.decode_chunk_len, + ) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py new file mode 100644 index 000000000..0d7e86fcf --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py @@ -0,0 +1,151 @@ +# Copyright 2022 Xiaomi Corp. (authors: Wei Kang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import List, Optional, Tuple + +import k2 +import torch +from beam_search import Hypothesis, HypothesisList + +from icefall.utils import AttributeDict + + +class DecodeStream(object): + def __init__( + self, + params: AttributeDict, + cut_id: str, + initial_states: List[torch.Tensor], + decoding_graph: Optional[k2.Fsa] = None, + device: torch.device = torch.device("cpu"), + ) -> None: + """ + Args: + initial_states: + Initial decode states of the model, e.g. the return value of + `get_init_state` in conformer.py + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a HLG. + Used only when decoding_method is fast_beam_search. + device: + The device to run this stream. + """ + if params.decoding_method == "fast_beam_search": + assert decoding_graph is not None + assert device == decoding_graph.device + + self.params = params + self.cut_id = cut_id + self.LOG_EPS = math.log(1e-10) + + self.states = initial_states + + # It contains a 2-D tensors representing the feature frames. + self.features: torch.Tensor = None + + self.num_frames: int = 0 + # how many frames have been processed. (before subsampling). + # we only modify this value in `func:get_feature_frames`. + self.num_processed_frames: int = 0 + + self._done: bool = False + + # The transcript of current utterance. + self.ground_truth: str = "" + + # The decoding result (partial or final) of current utterance. + self.hyp: List = [] + + # how many frames have been processed, after subsampling (i.e. a + # cumulative sum of the second return value of + # encoder.streaming_forward + self.done_frames: int = 0 + + # It has two steps of feature subsampling in zipformer: out_lens=((x_lens-7)//2+1)//2 + # 1) feature embedding: out_lens=(x_lens-7)//2 + # 2) output subsampling: out_lens=(out_lens+1)//2 + self.pad_length = 7 + + if params.decoding_method == "greedy_search": + self.hyp = [params.blank_id] * params.context_size + elif params.decoding_method == "modified_beam_search": + self.hyps = HypothesisList() + self.hyps.add( + Hypothesis( + ys=[params.blank_id] * params.context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + ) + ) + elif params.decoding_method == "fast_beam_search": + # The rnnt_decoding_stream for fast_beam_search. + self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream( + decoding_graph + ) + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + + @property + def done(self) -> bool: + """Return True if all the features are processed.""" + return self._done + + @property + def id(self) -> str: + return self.cut_id + + def set_features( + self, + features: torch.Tensor, + tail_pad_len: int = 0, + ) -> None: + """Set features tensor of current utterance.""" + assert features.dim() == 2, features.dim() + self.features = torch.nn.functional.pad( + features, + (0, 0, 0, self.pad_length + tail_pad_len), + mode="constant", + value=self.LOG_EPS, + ) + self.num_frames = self.features.size(0) + + def get_feature_frames(self, chunk_size: int) -> Tuple[torch.Tensor, int]: + """Consume chunk_size frames of features""" + chunk_length = chunk_size + self.pad_length + + ret_length = min(self.num_frames - self.num_processed_frames, chunk_length) + + ret_features = self.features[ + self.num_processed_frames : self.num_processed_frames + ret_length # noqa + ] + + self.num_processed_frames += chunk_size + if self.num_processed_frames >= self.num_frames: + self._done = True + + return ret_features, ret_length + + def decoding_result(self) -> List[int]: + """Obtain current decoding result.""" + if self.params.decoding_method == "greedy_search": + return self.hyp[self.params.context_size :] # noqa + elif self.params.decoding_method == "modified_beam_search": + best_hyp = self.hyps.get_most_probable(length_norm=True) + return best_hyp.ys[self.params.context_size :] # noqa + else: + assert self.params.decoding_method == "fast_beam_search" + return self.hyp diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decoder.py new file mode 120000 index 000000000..33944d0d2 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decoder.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/decoder.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py new file mode 120000 index 000000000..b9aa0ae08 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/encoder_interface.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py new file mode 100755 index 000000000..5c06cc052 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py @@ -0,0 +1,320 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" + +Usage: + +(1) Export to torchscript model using torch.jit.script() + +./pruned_transducer_stateless7_streaming/export.py \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 9 \ + --jit 1 + +It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later +load it by `torch.jit.load("cpu_jit.pt")`. + +Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python +are on CPU. You can use `to("cuda")` to move them to a CUDA device. + +Check +https://github.com/k2-fsa/sherpa +for how to use the exported models outside of icefall. + +(2) Export `model.state_dict()` + +./pruned_transducer_stateless7_streaming/export.py \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 + +It will generate a file `pretrained.pt` in the given `exp_dir`. You can later +load it by `icefall.checkpoint.load_checkpoint()`. + +To use the generated file with `pruned_transducer_stateless7_streaming/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + ./pruned_transducer_stateless7_streaming/decode.py \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bpe_500/bpe.model + +Check ./pretrained.py for its usage. + +Note: If you don't want to train a model from scratch, we have +provided one for you. You can get it at + +https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11 + +with the following commands: + + sudo apt-get install git-lfs + git lfs install + git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11 + # You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/exp +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +import torch +import torch.nn as nn +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + It will generate a file named cpu_jit.pt + + Check ./jit_pretrained.py for how to use it. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + if params.jit is True: + convert_scaled_to_non_scaled(model, inplace=True) + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. + model.__class__.forward = torch.jit.ignore(model.__class__.forward) + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + filename = params.exp_dir / "cpu_jit.pt" + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torchscript. Export model.state_dict()") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py new file mode 100755 index 000000000..4fd5e1820 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py @@ -0,0 +1,278 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models, exported by `torch.jit.script()` +and uses them to decode waves. +You can use the following command to get the exported models: + +./pruned_transducer_stateless7_streaming/export.py \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 \ + --jit 1 + +Usage of this script: + +./pruned_transducer_stateless7_streaming/jit_pretrained.py \ + --nn-model-filename ./pruned_transducer_stateless7_streaming/exp/cpu_jit.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model-filename", + type=str, + required=True, + help="Path to the torchscript model cpu_jit.pt", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--decode-chunk-len", + type=int, + default=32, + help="The chunk size for decoding (in frames before subsampling)", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float = 16000 +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +def greedy_search( + model: torch.jit.ScriptModule, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, +) -> List[List[int]]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + model: + The transducer model. + encoder_out: + A 3-D tensor of shape (N, T, C) + encoder_out_lens: + A 1-D tensor of shape (N,). + Returns: + Return the decoded results for each utterance. + """ + assert encoder_out.ndim == 3 + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + device = encoder_out.device + blank_id = 0 # hard-code to 0 + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + context_size = model.decoder.context_size + hyps = [[blank_id] * context_size for _ in range(N)] + + decoder_input = torch.tensor( + hyps, + device=device, + dtype=torch.int64, + ) # (N, context_size) + + decoder_out = model.decoder( + decoder_input, + need_pad=torch.tensor([False]), + ).squeeze(1) + + offset = 0 + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = packed_encoder_out.data[start:end] + current_encoder_out = current_encoder_out + # current_encoder_out's shape: (batch_size, encoder_out_dim) + offset = end + + decoder_out = decoder_out[:batch_size] + + logits = model.joiner( + current_encoder_out, + decoder_out, + ) + # logits'shape (batch_size, vocab_size) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + hyps[i].append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = [h[-context_size:] for h in hyps[:batch_size]] + decoder_input = torch.tensor( + decoder_input, + device=device, + dtype=torch.int64, + ) + decoder_out = model.decoder( + decoder_input, + need_pad=torch.tensor([False]), + ) + decoder_out = decoder_out.squeeze(1) + + sorted_ans = [h[context_size:] for h in hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + device = torch.device("cpu") + + logging.info(f"device: {device}") + + model = torch.jit.load(args.nn_model_filename) + model.encoder.decode_chunk_size = args.decode_chunk_len // 2 + + model.eval() + + model.to(device) + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder( + x=features, + x_lens=feature_lengths, + ) + + hyps = greedy_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = sp.decode(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py new file mode 100755 index 000000000..a164f3f69 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py @@ -0,0 +1,313 @@ +#!/usr/bin/env python3 + +""" +Usage: +./pruned_transducer_stateless7_streaming/jit_trace_export.py \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 10 \ + --use-averaged-model=True \ + --decode-chunk-len 32 +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +import torch +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import AttributeDict, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless2/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + add_model_arguments(parser) + + return parser + + +def export_encoder_model_jit_trace( + encoder_model: torch.nn.Module, + encoder_filename: str, + params: AttributeDict, +) -> None: + """Export the given encoder model with torch.jit.trace() + + Note: The warmup argument is fixed to 1. + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported model. + """ + decode_chunk_len = params.decode_chunk_len # before subsampling + pad_length = 7 + s = f"decode_chunk_len: {decode_chunk_len}" + logging.info(s) + assert encoder_model.decode_chunk_size == decode_chunk_len // 2, ( + encoder_model.decode_chunk_size, + decode_chunk_len, + ) + + T = decode_chunk_len + pad_length + + x = torch.zeros(1, T, 80, dtype=torch.float32) + x_lens = torch.full((1,), T, dtype=torch.int32) + states = encoder_model.get_init_state(device=x.device) + + encoder_model.__class__.forward = encoder_model.__class__.streaming_forward + traced_model = torch.jit.trace(encoder_model, (x, x_lens, states)) + traced_model.save(encoder_filename) + logging.info(f"Saved to {encoder_filename}") + + +def export_decoder_model_jit_trace( + decoder_model: torch.nn.Module, + decoder_filename: str, +) -> None: + """Export the given decoder model with torch.jit.trace() + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The input decoder model + decoder_filename: + The filename to save the exported model. + """ + y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) + need_pad = torch.tensor([False]) + + traced_model = torch.jit.trace(decoder_model, (y, need_pad)) + traced_model.save(decoder_filename) + logging.info(f"Saved to {decoder_filename}") + + +def export_joiner_model_jit_trace( + joiner_model: torch.nn.Module, + joiner_filename: str, +) -> None: + """Export the given joiner model with torch.jit.trace() + + Note: The argument project_input is fixed to True. A user should not + project the encoder_out/decoder_out by himself/herself. The exported joiner + will do that for the user. + + Args: + joiner_model: + The input joiner model + joiner_filename: + The filename to save the exported model. + + """ + encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] + decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + + traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out)) + traced_model.save(joiner_filename) + logging.info(f"Saved to {joiner_filename}") + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True) + logging.info("Using torch.jit.trace()") + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / "encoder_jit_trace.pt" + export_encoder_model_jit_trace(model.encoder, encoder_filename, params) + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / "decoder_jit_trace.pt" + export_decoder_model_jit_trace(model.decoder, decoder_filename) + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / "joiner_jit_trace.pt" + export_joiner_model_jit_trace(model.joiner, joiner_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py new file mode 100755 index 000000000..f2ac1914d --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py @@ -0,0 +1,295 @@ +#!/usr/bin/env python3 +# flake8: noqa +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models exported by `torch.jit.trace()` +and uses them to decode waves. +You can use the following command to get the exported models: + +./pruned_transducer_stateless7_streaming/jit_trace_export.py \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 10 \ + --use-averaged-model=True \ + --decode-chunk-len 32 + +Usage of this script: + +./pruned_transducer_stateless7_streaming/jit_trace_pretrained.py \ + --encoder-model-filename ./pruned_transducer_stateless7_streaming/exp/encoder_jit_trace.pt \ + --decoder-model-filename ./pruned_transducer_stateless7_streaming/exp/decoder_jit_trace.pt \ + --joiner-model-filename ./pruned_transducer_stateless7_streaming/exp/joiner_jit_trace.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --decode-chunk-len 32 \ + /path/to/foo.wav \ +""" + +import argparse +import logging +import math +from typing import List, Optional + +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder torchscript model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder torchscript model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner torchscript model. ", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--decode-chunk-len", + type=int, + default=32, + help="The chunk size for decoding (in frames before subsampling)", + ) + + parser.add_argument( + "sound_file", + type=str, + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +def greedy_search( + decoder: torch.jit.ScriptModule, + joiner: torch.jit.ScriptModule, + encoder_out: torch.Tensor, + decoder_out: Optional[torch.Tensor] = None, + hyp: Optional[List[int]] = None, +): + assert encoder_out.ndim == 2 + context_size = 2 + blank_id = 0 + + if decoder_out is None: + assert hyp is None, hyp + hyp = [blank_id] * context_size + decoder_input = torch.tensor(hyp, dtype=torch.int32).unsqueeze(0) + # decoder_input.shape (1,, 1 context_size) + decoder_out = decoder(decoder_input, torch.tensor([False])).squeeze(1) + else: + assert decoder_out.ndim == 2 + assert hyp is not None, hyp + + T = encoder_out.size(0) + for i in range(T): + cur_encoder_out = encoder_out[i : i + 1] + joiner_out = joiner(cur_encoder_out, decoder_out).squeeze(0) + y = joiner_out.argmax(dim=0).item() + + if y != blank_id: + hyp.append(y) + decoder_input = hyp[-context_size:] + + decoder_input = torch.tensor(decoder_input, dtype=torch.int32).unsqueeze(0) + decoder_out = decoder(decoder_input, torch.tensor([False])).squeeze(1) + + return hyp, decoder_out + + +def create_streaming_feature_extractor(sample_rate) -> OnlineFeature: + """Create a CPU streaming feature extractor. + + At present, we assume it returns a fbank feature extractor with + fixed options. In the future, we will support passing in the options + from outside. + + Returns: + Return a CPU streaming feature extractor. + """ + opts = FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = sample_rate + opts.mel_opts.num_bins = 80 + return OnlineFbank(opts) + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + device = torch.device("cpu") + + logging.info(f"device: {device}") + + encoder = torch.jit.load(args.encoder_model_filename) + decoder = torch.jit.load(args.decoder_model_filename) + joiner = torch.jit.load(args.joiner_model_filename) + + encoder.eval() + decoder.eval() + joiner.eval() + + encoder.to(device) + decoder.to(device) + joiner.to(device) + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + logging.info("Constructing Fbank computer") + online_fbank = create_streaming_feature_extractor(args.sample_rate) + + logging.info(f"Reading sound files: {args.sound_file}") + wave_samples = read_sound_files( + filenames=[args.sound_file], + expected_sample_rate=args.sample_rate, + )[0] + logging.info(wave_samples.shape) + + logging.info("Decoding started") + chunk_length = args.decode_chunk_len + assert encoder.decode_chunk_size == chunk_length // 2, ( + encoder.decode_chunk_size, + chunk_length, + ) + + # we subsample features with ((x_len - 7) // 2 + 1) // 2 + pad_length = 7 + T = chunk_length + pad_length + + logging.info(f"chunk_length: {chunk_length}") + + states = encoder.get_init_state(device) + + tail_padding = torch.zeros(int(0.3 * args.sample_rate), dtype=torch.float32) + + wave_samples = torch.cat([wave_samples, tail_padding]) + + chunk = int(0.25 * args.sample_rate) # 0.2 second + num_processed_frames = 0 + + hyp = None + decoder_out = None + + start = 0 + while start < wave_samples.numel(): + logging.info(f"{start}/{wave_samples.numel()}") + end = min(start + chunk, wave_samples.numel()) + samples = wave_samples[start:end] + start += chunk + online_fbank.accept_waveform( + sampling_rate=args.sample_rate, + waveform=samples, + ) + while online_fbank.num_frames_ready - num_processed_frames >= T: + frames = [] + for i in range(T): + frames.append(online_fbank.get_frame(num_processed_frames + i)) + frames = torch.cat(frames, dim=0).unsqueeze(0) + x_lens = torch.tensor([T], dtype=torch.int32) + encoder_out, out_lens, states = encoder( + x=frames, + x_lens=x_lens, + states=states, + ) + num_processed_frames += chunk_length + + hyp, decoder_out = greedy_search( + decoder, joiner, encoder_out.squeeze(0), decoder_out, hyp + ) + + context_size = 2 + logging.info(args.sound_file) + logging.info(sp.decode(hyp[context_size:])) + + logging.info("Decoding Done") + + +torch.set_num_threads(4) +torch.set_num_interop_threads(1) +torch._C._jit_set_profiling_executor(False) +torch._C._jit_set_profiling_mode(False) +torch._C._set_graph_executor_optimize(False) +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/joiner.py new file mode 120000 index 000000000..ecfb6dd8a --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/joiner.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/joiner.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/model.py new file mode 120000 index 000000000..e17d4f734 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/model.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/model.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/optim.py new file mode 120000 index 000000000..81ac4a89a --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/optim.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/optim.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py new file mode 100755 index 000000000..fb77fdd42 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py @@ -0,0 +1,355 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads a checkpoint and uses it to decode waves. +You can generate the checkpoint with the following command: + +./pruned_transducer_stateless7_streaming/export.py \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 + +Usage of this script: + +(1) greedy search +./pruned_transducer_stateless7_streaming/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) beam search +./pruned_transducer_stateless7_streaming/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) modified beam search +./pruned_transducer_stateless7_streaming/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method modified_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(4) fast beam search +./pruned_transducer_stateless7_streaming/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method fast_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +You can also use `./pruned_transducer_stateless7_streaming/exp/epoch-xx.pt`. + +Note: ./pruned_transducer_stateless7_streaming/exp/pretrained.pt is generated by +./pruned_transducer_stateless7_streaming/export.py +""" + + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from beam_search import ( + beam_search, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "--method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. Used only when + --method is greedy_search. + """, + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + + params.update(vars(args)) + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + model.device = device + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) + + num_waves = encoder_out.size(0) + hyps = [] + msg = f"Using {params.method}" + if params.method == "beam_search": + msg += f" with beam size {params.beam_size}" + logging.info(msg) + + if params.method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + for i in range(num_waves): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError(f"Unsupported method: {params.method}") + + hyps.append(sp.decode(hyp).split()) + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/scaling.py new file mode 120000 index 000000000..2428b74b9 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/scaling.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/scaling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py new file mode 120000 index 000000000..b8b8ba432 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/scaling_converter.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py new file mode 120000 index 000000000..3a5f89833 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/streaming_beam_search.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py new file mode 100755 index 000000000..7a349ecb2 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py @@ -0,0 +1,615 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage: +./pruned_transducer_stateless7_streaming/streaming_decode.py \ + --epoch 28 \ + --avg 15 \ + --decode-chunk-len 32 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --decoding_method greedy_search \ + --num-decode-streams 2000 +""" + +import argparse +import logging +import math +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from decode_stream import DecodeStream +from kaldifeat import Fbank, FbankOptions +from lhotse import CutSet +from streaming_beam_search import ( + fast_beam_search_one_best, + greedy_search, + modified_beam_search, +) +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model +from zipformer import stack_states, unstack_states + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless2/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Supported decoding methods are: + greedy_search + modified_beam_search + fast_beam_search + """, + ) + + parser.add_argument( + "--num_active_paths", + type=int, + default=4, + help="""An interger indicating how many candidates we will keep for each + frame. Used only when --decoding-method is modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=32, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--num-decode-streams", + type=int, + default=2000, + help="The number of streams that can be decoded parallel.", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_chunk( + params: AttributeDict, + model: nn.Module, + decode_streams: List[DecodeStream], +) -> List[int]: + """Decode one chunk frames of features for each decode_streams and + return the indexes of finished streams in a List. + + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + decode_streams: + A List of DecodeStream, each belonging to a utterance. + Returns: + Return a List containing which DecodeStreams are finished. + """ + device = model.device + + features = [] + feature_lens = [] + states = [] + processed_lens = [] + + for stream in decode_streams: + feat, feat_len = stream.get_feature_frames(params.decode_chunk_len) + features.append(feat) + feature_lens.append(feat_len) + states.append(stream.states) + processed_lens.append(stream.done_frames) + + feature_lens = torch.tensor(feature_lens, device=device) + features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) + + # We subsample features with ((x_len - 7) // 2 + 1) // 2 and the max downsampling + # factor in encoders is 8. + # After feature embedding (x_len - 7) // 2, we have (23 - 7) // 2 = 8. + tail_length = 23 + if features.size(1) < tail_length: + pad_length = tail_length - features.size(1) + feature_lens += pad_length + features = torch.nn.functional.pad( + features, + (0, 0, 0, pad_length), + mode="constant", + value=LOG_EPS, + ) + + states = stack_states(states) + processed_lens = torch.tensor(processed_lens, device=device) + + encoder_out, encoder_out_lens, new_states = model.encoder.streaming_forward( + x=features, + x_lens=feature_lens, + states=states, + ) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + if params.decoding_method == "greedy_search": + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) + elif params.decoding_method == "fast_beam_search": + processed_lens = processed_lens + encoder_out_lens + fast_beam_search_one_best( + model=model, + encoder_out=encoder_out, + processed_lens=processed_lens, + streams=decode_streams, + beam=params.beam, + max_states=params.max_states, + max_contexts=params.max_contexts, + ) + elif params.decoding_method == "modified_beam_search": + modified_beam_search( + model=model, + streams=decode_streams, + encoder_out=encoder_out, + num_active_paths=params.num_active_paths, + ) + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + + states = unstack_states(new_states) + + finished_streams = [] + for i in range(len(decode_streams)): + decode_streams[i].states = states[i] + decode_streams[i].done_frames += encoder_out_lens[i] + if decode_streams[i].done: + finished_streams.append(i) + + return finished_streams + + +def decode_dataset( + cuts: CutSet, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + cuts: + Lhotse Cutset containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + device = model.device + + opts = FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + + log_interval = 50 + + decode_results = [] + # Contain decode streams currently running. + decode_streams = [] + for num, cut in enumerate(cuts): + # each utterance has a DecodeStream. + initial_states = model.encoder.get_init_state(device=device) + decode_stream = DecodeStream( + params=params, + cut_id=cut.id, + initial_states=initial_states, + decoding_graph=decoding_graph, + device=device, + ) + + audio: np.ndarray = cut.load_audio() + # audio.shape: (1, num_samples) + assert len(audio.shape) == 2 + assert audio.shape[0] == 1, "Should be single channel" + assert audio.dtype == np.float32, audio.dtype + + # The trained model is using normalized samples + assert audio.max() <= 1, "Should be normalized to [-1, 1])" + + samples = torch.from_numpy(audio).squeeze(0) + + fbank = Fbank(opts) + feature = fbank(samples.to(device)) + decode_stream.set_features(feature, tail_pad_len=params.decode_chunk_len) + decode_stream.ground_truth = cut.supervisions[0].text + + decode_streams.append(decode_stream) + + while len(decode_streams) >= params.num_decode_streams: + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + sp.decode(decode_streams[i].decoding_result()).split(), + ) + ) + del decode_streams[i] + + if num % log_interval == 0: + logging.info(f"Cuts processed until now is {num}.") + + # decode final chunks of last sequences + while len(decode_streams): + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + sp.decode(decode_streams[i].decoding_result()).split(), + ) + ) + del decode_streams[i] + + if params.decoding_method == "greedy_search": + key = "greedy_search" + elif params.decoding_method == "fast_beam_search": + key = ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ) + elif params.decoding_method == "modified_beam_search": + key = f"num_active_paths_{params.num_active_paths}" + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + return {key: decode_results} + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.res_dir = params.exp_dir / "streaming" / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + # for streaming + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}" + + # for fast_beam_search + if params.decoding_method == "fast_beam_search": + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + model.device = device + + decoding_graph = None + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_sets = ["test-clean", "test-other"] + test_cuts = [test_clean_cuts, test_other_cuts] + + for test_set, test_cut in zip(test_sets, test_cuts): + results_dict = decode_dataset( + cuts=test_cut, + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/test_model.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/test_model.py new file mode 100755 index 000000000..5400df804 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/test_model.py @@ -0,0 +1,150 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +To run this file, do: + + cd icefall/egs/librispeech/ASR + python ./pruned_transducer_stateless7_streaming/test_model.py +""" + +import torch +from scaling_converter import convert_scaled_to_non_scaled +from train import get_params, get_transducer_model + + +def test_model(): + params = get_params() + params.vocab_size = 500 + params.blank_id = 0 + params.context_size = 2 + params.num_encoder_layers = "2,4,3,2,4" + params.feedforward_dims = "1024,1024,2048,2048,1024" + params.nhead = "8,8,8,8,8" + params.encoder_dims = "384,384,384,384,384" + params.attention_dims = "192,192,192,192,192" + params.encoder_unmasked_dims = "256,256,256,256,256" + params.zipformer_downsampling_factors = "1,2,4,8,2" + params.cnn_module_kernels = "31,31,31,31,31" + params.decoder_dim = 512 + params.joiner_dim = 512 + params.num_left_chunks = 4 + params.short_chunk_size = 50 + params.decode_chunk_len = 32 + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + # Test jit script + convert_scaled_to_non_scaled(model, inplace=True) + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. + model.__class__.forward = torch.jit.ignore(model.__class__.forward) + print("Using torch.jit.script") + model = torch.jit.script(model) + + +def test_model_jit_trace(): + params = get_params() + params.vocab_size = 500 + params.blank_id = 0 + params.context_size = 2 + params.num_encoder_layers = "2,4,3,2,4" + params.feedforward_dims = "1024,1024,2048,2048,1024" + params.nhead = "8,8,8,8,8" + params.encoder_dims = "384,384,384,384,384" + params.attention_dims = "192,192,192,192,192" + params.encoder_unmasked_dims = "256,256,256,256,256" + params.zipformer_downsampling_factors = "1,2,4,8,2" + params.cnn_module_kernels = "31,31,31,31,31" + params.decoder_dim = 512 + params.joiner_dim = 512 + params.num_left_chunks = 4 + params.short_chunk_size = 50 + params.decode_chunk_len = 32 + model = get_transducer_model(params) + model.eval() + + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + convert_scaled_to_non_scaled(model, inplace=True) + + # Test encoder + def _test_encoder(): + encoder = model.encoder + assert encoder.decode_chunk_size == params.decode_chunk_len // 2, ( + encoder.decode_chunk_size, + params.decode_chunk_len, + ) + T = params.decode_chunk_len + 7 + + x = torch.zeros(1, T, 80, dtype=torch.float32) + x_lens = torch.full((1,), T, dtype=torch.int32) + states = encoder.get_init_state(device=x.device) + encoder.__class__.forward = encoder.__class__.streaming_forward + traced_encoder = torch.jit.trace(encoder, (x, x_lens, states)) + + states1 = encoder.get_init_state(device=x.device) + states2 = traced_encoder.get_init_state(device=x.device) + for i in range(5): + x = torch.randn(1, T, 80, dtype=torch.float32) + x_lens = torch.full((1,), T, dtype=torch.int32) + y1, _, states1 = encoder.streaming_forward(x, x_lens, states1) + y2, _, states2 = traced_encoder(x, x_lens, states2) + assert torch.allclose(y1, y2, atol=1e-6), (i, (y1 - y2).abs().mean()) + + # Test decoder + def _test_decoder(): + decoder = model.decoder + y = torch.zeros(10, decoder.context_size, dtype=torch.int64) + need_pad = torch.tensor([False]) + + traced_decoder = torch.jit.trace(decoder, (y, need_pad)) + d1 = decoder(y, need_pad) + d2 = traced_decoder(y, need_pad) + assert torch.equal(d1, d2), (d1 - d2).abs().mean() + + # Test joiner + def _test_joiner(): + joiner = model.joiner + encoder_out_dim = joiner.encoder_proj.weight.shape[1] + decoder_out_dim = joiner.decoder_proj.weight.shape[1] + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + + traced_joiner = torch.jit.trace(joiner, (encoder_out, decoder_out)) + j1 = joiner(encoder_out, decoder_out) + j2 = traced_joiner(encoder_out, decoder_out) + assert torch.equal(j1, j2), (j1 - j2).abs().mean() + + _test_encoder() + _test_decoder() + _test_joiner() + + +def main(): + test_model() + test_model_jit_trace() + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py new file mode 100755 index 000000000..2bdc882a5 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py @@ -0,0 +1,1264 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --full-libri 1 \ + --max-duration 300 + +# For mix precision training: + +./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --full-libri 1 \ + --max-duration 550 +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from optim import Eden, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,4,3,2,4", + help="Number of zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--feedforward-dims", + type=str, + default="1024,1024,2048,2048,1024", + help="Feedforward dimension of the zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--nhead", + type=str, + default="8,8,8,8,8", + help="Number of attention heads in the zipformer encoder layers.", + ) + + parser.add_argument( + "--encoder-dims", + type=str, + default="384,384,384,384,384", + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", + ) + + parser.add_argument( + "--attention-dims", + type=str, + default="192,192,192,192,192", + help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; + not the same as embedding dimension.""", + ) + + parser.add_argument( + "--encoder-unmasked-dims", + type=str, + default="256,256,256,256,256", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " + " worse.", + ) + + parser.add_argument( + "--zipformer-downsampling-factors", + type=str, + default="1,2,4,8,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--cnn-module-kernels", + type=str, + default="31,31,31,31,31", + help="Sizes of kernels in convolution modules", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--short-chunk-size", + type=int, + default=50, + help="""Chunk length of dynamic training, the chunk size would be either + max sequence length of current batch or uniformly sampled from (1, short_chunk_size). + """, + ) + + parser.add_argument( + "--num-left-chunks", + type=int, + default=4, + help="How many left context can be seen in chunks when calculating attention.", + ) + + parser.add_argument( + "--decode-chunk-len", + type=int, + default=32, + help="The chunk size for decoding (in frames before subsampling)", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.05, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=2000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Zipformer and Transformer + def to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + encoder = Zipformer( + num_features=params.feature_dim, + output_downsampling_factor=2, + zipformer_downsampling_factors=to_int_tuple( + params.zipformer_downsampling_factors + ), + encoder_dims=to_int_tuple(params.encoder_dims), + attention_dim=to_int_tuple(params.attention_dims), + encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), + nhead=to_int_tuple(params.nhead), + feedforward_dim=to_int_tuple(params.feedforward_dims), + cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), + num_encoder_layers=to_int_tuple(params.num_encoder_layers), + num_left_chunks=params.num_left_chunks, + short_chunk_size=params.short_chunk_size, + decode_chunk_size=params.decode_chunk_len // 2, + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + if params.full_libri is False: + params.valid_interval = 1600 + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + parameters_names = [] + parameters_names.append( + [name_param_pair[0] for name_param_pair in model.named_parameters()] + ) + optimizer = ScaledAdam( + model.parameters(), + lr=params.base_lr, + clipping_scale=2.0, + parameters_names=parameters_names, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + librispeech = LibriSpeechAsrDataModule(args) + + train_cuts = librispeech.train_clean_100_cuts() + if params.full_libri: + train_cuts += librispeech.train_clean_360_cuts() + train_cuts += librispeech.train_other_500_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py new file mode 100644 index 000000000..88beb38c1 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py @@ -0,0 +1,2881 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import itertools +import logging +import math +import random +import warnings +from typing import List, Optional, Tuple, Union + +import torch +from encoder_interface import EncoderInterface +from scaling import ( + ScaledLinear, # not as in other dirs.. just scales down initial parameter values. +) +from scaling import ( + ActivationBalancer, + BasicNorm, + DoubleSwish, + Identity, + MaxEig, + ScaledConv1d, + Whiten, + _diag, + penalize_abs_values_gt, + random_clamp, + softmax, +) +from torch import Tensor, nn + +from icefall.dist import get_rank +from icefall.utils import make_pad_mask, subsequent_chunk_mask + + +def stack_states(state_list: List[List[Tensor]]) -> List[Tensor]: + """Stack list of zipformer states that correspond to separate utterances + into a single emformer state, so that it can be used as an input for + zipformer when those utterances are formed into a batch. + + Note: + It is the inverse of :func:`unstack_states`. + + Args: + state_list: + Each element in state_list corresponding to the internal state + of the zipformer model for a single utterance. + ``states[i]`` is a list of 7 * num_encoders elements of i-th utterance. + ``states[i][0:num_encoders]`` is the cached numbers of past frames. + ``states[i][num_encoders:2*num_encoders]`` is the cached average tensors. + ``states[i][2*num_encoders:3*num_encoders]`` is the cached key tensors of the first attention modules. + ``states[i][3*num_encoders:4*num_encoders]`` is the cached value tensors of the first attention modules. + ``states[i][4*num_encoders:5*num_encoders]`` is the cached value tensors of the second attention modules. + ``states[i][5*num_encoders:6*num_encoders]`` is the cached left contexts of the first convolution modules. + ``states[i][6*num_encoders:7*num_encoders]`` is the cached left contexts of the second convolution modules. + + Returns: + A new state corresponding to a batch of utterances. + See the input argument of :func:`unstack_states` for the meaning + of the returned tensor. + """ + batch_size = len(state_list) + assert len(state_list[0]) % 7 == 0, len(state_list[0]) + num_encoders = len(state_list[0]) // 7 + + cached_len = [] + cached_avg = [] + cached_key = [] + cached_val = [] + cached_val2 = [] + cached_conv1 = [] + cached_conv2 = [] + + # For cached_len + len_list = [state_list[n][0:num_encoders] for n in range(batch_size)] + for i in range(num_encoders): + # len_avg: (num_layers, batch_size) + len_avg = torch.cat([len_list[n][i] for n in range(batch_size)], dim=1) + cached_len.append(len_avg) + + # For cached_avg + avg_list = [ + state_list[n][num_encoders : 2 * num_encoders] for n in range(batch_size) + ] + for i in range(num_encoders): + # avg: (num_layers, batch_size, D) + avg = torch.cat([avg_list[n][i] for n in range(batch_size)], dim=1) + cached_avg.append(avg) + + # For cached_key + key_list = [ + state_list[n][2 * num_encoders : 3 * num_encoders] for n in range(batch_size) + ] + for i in range(num_encoders): + # key: (num_layers, left_context_size, batch_size, D) + key = torch.cat([key_list[n][i] for n in range(batch_size)], dim=2) + cached_key.append(key) + + # For cached_val + val_list = [ + state_list[n][3 * num_encoders : 4 * num_encoders] for n in range(batch_size) + ] + for i in range(num_encoders): + # val: (num_layers, left_context_size, batch_size, D) + val = torch.cat([val_list[n][i] for n in range(batch_size)], dim=2) + cached_val.append(val) + + # For cached_val2 + val2_list = [ + state_list[n][4 * num_encoders : 5 * num_encoders] for n in range(batch_size) + ] + for i in range(num_encoders): + # val2: (num_layers, left_context_size, batch_size, D) + val2 = torch.cat([val2_list[n][i] for n in range(batch_size)], dim=2) + cached_val2.append(val2) + + # For cached_conv1 + conv1_list = [ + state_list[n][5 * num_encoders : 6 * num_encoders] for n in range(batch_size) + ] + for i in range(num_encoders): + # conv1: (num_layers, batch_size, D, kernel-1) + conv1 = torch.cat([conv1_list[n][i] for n in range(batch_size)], dim=1) + cached_conv1.append(conv1) + + # For cached_conv2 + conv2_list = [ + state_list[n][6 * num_encoders : 7 * num_encoders] for n in range(batch_size) + ] + for i in range(num_encoders): + # conv2: (num_layers, batch_size, D, kernel-1) + conv2 = torch.cat([conv2_list[n][i] for n in range(batch_size)], dim=1) + cached_conv2.append(conv2) + + states = ( + cached_len + + cached_avg + + cached_key + + cached_val + + cached_val2 + + cached_conv1 + + cached_conv2 + ) + return states + + +def unstack_states(states: List[Tensor]) -> List[List[Tensor]]: + """Unstack the zipformer state corresponding to a batch of utterances + into a list of states, where the i-th entry is the state from the i-th + utterance in the batch. + + Note: + It is the inverse of :func:`stack_states`. + + Args: + states: + A list of 7 * num_encoders elements: + ``states[0:num_encoders]`` is the cached numbers of past frames. + ``states[num_encoders:2*num_encoders]`` is the cached average tensors. + ``states[2*num_encoders:3*num_encoders]`` is the cached key tensors of the first attention modules. + ``states[3*num_encoders:4*num_encoders]`` is the cached value tensors of the first attention modules. + ``states[4*num_encoders:5*num_encoders]`` is the cached value tensors of the second attention modules. + ``states[5*num_encoders:6*num_encoders]`` is the cached left contexts of the first convolution modules. + ``states[6*num_encoders:7*num_encoders]`` is the cached left contexts of the second convolution modules. + + Returns: + A list of states. + ``states[i]`` is a list of 7 * num_encoders elements of i-th utterance. + """ + assert len(states) % 7 == 0, len(states) + num_encoders = len(states) // 7 + ( + cached_len, + cached_avg, + cached_key, + cached_val, + cached_val2, + cached_conv1, + cached_conv2, + ) = (states[i * num_encoders : (i + 1) * num_encoders] for i in range(7)) + + batch_size = cached_len[0].shape[1] + + len_list = [[] for _ in range(batch_size)] + for i in range(num_encoders): + # cached_len[i]: (num_layers, batch_size) + len_avg = cached_len[i].chunk(chunks=batch_size, dim=1) + for n in range(batch_size): + len_list[n].append(len_avg[n]) + + avg_list = [[] for _ in range(batch_size)] + for i in range(num_encoders): + # cached_avg[i]: (num_layers, batch_size, D) + avg = cached_avg[i].chunk(chunks=batch_size, dim=1) + for n in range(batch_size): + avg_list[n].append(avg[n]) + + key_list = [[] for _ in range(batch_size)] + for i in range(num_encoders): + # cached_key[i]: (num_layers, left_context, batch_size, D) + key = cached_key[i].chunk(chunks=batch_size, dim=2) + for n in range(batch_size): + key_list[n].append(key[n]) + + val_list = [[] for _ in range(batch_size)] + for i in range(num_encoders): + # cached_val[i]: (num_layers, left_context, batch_size, D) + val = cached_val[i].chunk(chunks=batch_size, dim=2) + for n in range(batch_size): + val_list[n].append(val[n]) + + val2_list = [[] for _ in range(batch_size)] + for i in range(num_encoders): + # cached_val2[i]: (num_layers, left_context, batch_size, D) + val2 = cached_val2[i].chunk(chunks=batch_size, dim=2) + for n in range(batch_size): + val2_list[n].append(val2[n]) + + conv1_list = [[] for _ in range(batch_size)] + for i in range(num_encoders): + # cached_conv1[i]: (num_layers, batch_size, D, kernel-1) + conv1 = cached_conv1[i].chunk(chunks=batch_size, dim=1) + for n in range(batch_size): + conv1_list[n].append(conv1[n]) + + conv2_list = [[] for _ in range(batch_size)] + for i in range(num_encoders): + # cached_conv2[i]: (num_layers, batch_size, D, kernel-1) + conv2 = cached_conv2[i].chunk(chunks=batch_size, dim=1) + for n in range(batch_size): + conv2_list[n].append(conv2[n]) + + state_list = [ + ( + len_list[i] + + avg_list[i] + + key_list[i] + + val_list[i] + + val2_list[i] + + conv1_list[i] + + conv2_list[i] + ) + for i in range(batch_size) + ] + return state_list + + +class Zipformer(EncoderInterface): + """ + Args: + num_features (int): Number of input features + d_model: (int,int): embedding dimension of 2 encoder stacks + attention_dim: (int,int): attention dimension of 2 encoder stacks + nhead (int, int): number of heads + dim_feedforward (int, int): feedforward dimension in 2 encoder stacks + num_encoder_layers (int): number of encoder layers + dropout (float): dropout rate + cnn_module_kernel (int): Kernel size of convolution module + vgg_frontend (bool): whether to use vgg frontend. + warmup_batches (float): number of batches to warm up over + """ + + def __init__( + self, + num_features: int, + output_downsampling_factor: int = 2, + encoder_dims: Tuple[int] = (384, 384), + attention_dim: Tuple[int] = (256, 256), + encoder_unmasked_dims: Tuple[int] = (256, 256), + zipformer_downsampling_factors: Tuple[int] = (2, 4), + nhead: Tuple[int] = (8, 8), + feedforward_dim: Tuple[int] = (1536, 2048), + num_encoder_layers: Tuple[int] = (12, 12), + dropout: float = 0.1, + cnn_module_kernels: Tuple[int] = (31, 31), + pos_dim: int = 4, + num_left_chunks: int = 4, + short_chunk_threshold: float = 0.75, + short_chunk_size: int = 50, + decode_chunk_size: int = 16, + warmup_batches: float = 4000.0, + ) -> None: + super(Zipformer, self).__init__() + + self.num_features = num_features + assert 0 < encoder_dims[0] <= encoder_dims[1] + self.encoder_dims = encoder_dims + self.encoder_unmasked_dims = encoder_unmasked_dims + self.zipformer_downsampling_factors = zipformer_downsampling_factors + self.output_downsampling_factor = output_downsampling_factor + + self.num_left_chunks = num_left_chunks + self.short_chunk_threshold = short_chunk_threshold + self.short_chunk_size = short_chunk_size + + # Used in decoding + self.decode_chunk_size = decode_chunk_size + + # will be written to, see set_batch_count() + self.batch_count = 0 + self.warmup_end = warmup_batches + + for u, d in zip(encoder_unmasked_dims, encoder_dims): + assert u <= d, (u, d) + + # self.encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7)//2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7)//2 + # (2) embedding: num_features -> encoder_dims + self.encoder_embed = Conv2dSubsampling( + num_features, encoder_dims[0], dropout=dropout + ) + + # each one will be ZipformerEncoder or DownsampledZipformerEncoder + encoders = [] + + self.num_encoders = len(encoder_dims) + for i in range(self.num_encoders): + encoder_layer = ZipformerEncoderLayer( + encoder_dims[i], + attention_dim[i], + nhead[i], + feedforward_dim[i], + dropout, + cnn_module_kernels[i], + pos_dim, + ) + + # For the segment of the warmup period, we let the Conv2dSubsampling + # layer learn something. Then we start to warm up the other encoders. + encoder = ZipformerEncoder( + encoder_layer, + num_encoder_layers[i], + dropout, + warmup_begin=warmup_batches * (i + 1) / (self.num_encoders + 1), + warmup_end=warmup_batches * (i + 2) / (self.num_encoders + 1), + ) + + if zipformer_downsampling_factors[i] != 1: + encoder = DownsampledZipformerEncoder( + encoder, + input_dim=encoder_dims[i - 1] if i > 0 else encoder_dims[0], + output_dim=encoder_dims[i], + downsample=zipformer_downsampling_factors[i], + ) + encoders.append(encoder) + self.encoders = nn.ModuleList(encoders) + + # initializes self.skip_layers and self.skip_modules + self._init_skip_modules() + + self.downsample_output = AttentionDownsample( + encoder_dims[-1], encoder_dims[-1], downsample=output_downsampling_factor + ) + + def _get_layer_skip_dropout_prob(self): + if not self.training: + return 0.0 + batch_count = self.batch_count + min_dropout_prob = 0.025 + + if batch_count > self.warmup_end: + return min_dropout_prob + else: + return 0.5 - (batch_count / self.warmup_end) * (0.5 - min_dropout_prob) + + def _init_skip_modules(self): + """ + If self.zipformer_downampling_factors = (1, 2, 4, 8, 4, 2), then at the input of layer + indexed 4 (in zero indexing), with has subsapling_factor=4, we combine the output of + layers 2 and 3; and at the input of layer indexed 5, which which has subsampling_factor=2, + we combine the outputs of layers 1 and 5. + """ + skip_layers = [] + skip_modules = [] + z = self.zipformer_downsampling_factors + for i in range(len(z)): + if i <= 1 or z[i - 1] <= z[i]: + skip_layers.append(None) + skip_modules.append(SimpleCombinerIdentity()) + else: + # TEMP + for j in range(i - 2, -1, -1): + if z[j] <= z[i] or j == 0: + # TEMP logging statement. + logging.info( + f"At encoder stack {i}, which has downsampling_factor={z[i]}, we will " + f"combine the outputs of layers {j} and {i-1}, with downsampling_factors={z[j]} and {z[i-1]}." + ) + skip_layers.append(j) + skip_modules.append( + SimpleCombiner( + self.encoder_dims[j], + self.encoder_dims[i - 1], + min_weight=(0.0, 0.25), + ) + ) + break + self.skip_layers = skip_layers + self.skip_modules = nn.ModuleList(skip_modules) + + def get_feature_masks(self, x: torch.Tensor) -> List[float]: + # Note: The actual return type is Union[List[float], List[Tensor]], + # but to make torch.jit.script() work, we use List[float] + """ + In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of + randomized feature masks, one per encoder. + On e.g. 15% of frames, these masks will zero out all enocder dims larger than + some supplied number, e.g. >256, so in effect on those frames we are using + a smaller encoer dim. + + We generate the random masks at this level because we want the 2 masks to 'agree' + all the way up the encoder stack. This will mean that the 1st mask will have + mask values repeated self.zipformer_subsampling_factor times. + + Args: + x: the embeddings (needed for the shape and dtype and device), of shape + (num_frames, batch_size, encoder_dims0) + """ + num_encoders = len(self.encoder_dims) + if torch.jit.is_scripting() or not self.training: + return [1.0] * num_encoders + + (num_frames0, batch_size, _encoder_dims0) = x.shape + + assert self.encoder_dims[0] == _encoder_dims0, ( + self.encoder_dims, + _encoder_dims0, + ) + + max_downsampling_factor = max(self.zipformer_downsampling_factors) + + num_frames_max = num_frames0 + max_downsampling_factor - 1 + + feature_mask_dropout_prob = 0.15 + + # frame_mask_max shape: (num_frames_max, batch_size, 1) + frame_mask_max = ( + torch.rand(num_frames_max, batch_size, 1, device=x.device) + > feature_mask_dropout_prob + ).to(x.dtype) + + feature_masks = [] + for i in range(num_encoders): + ds = self.zipformer_downsampling_factors[i] + upsample_factor = max_downsampling_factor // ds + + frame_mask = ( + frame_mask_max.unsqueeze(1) + .expand(num_frames_max, upsample_factor, batch_size, 1) + .reshape(num_frames_max * upsample_factor, batch_size, 1) + ) + num_frames = (num_frames0 + ds - 1) // ds + frame_mask = frame_mask[:num_frames] + feature_mask = torch.ones( + num_frames, + batch_size, + self.encoder_dims[i], + dtype=x.dtype, + device=x.device, + ) + u = self.encoder_unmasked_dims[i] + feature_mask[:, :, u:] *= frame_mask + feature_masks.append(feature_mask) + + return feature_masks + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + The input tensor. Its shape is (batch_size, seq_len, feature_dim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + chunk_size: + The chunk size used in evaluation mode. + Returns: + Return a tuple containing 2 tensors: + - embeddings: its shape is (batch_size, output_seq_len, encoder_dims[-1]) + - lengths, a tensor of shape (batch_size,) containing the number + of frames in `embeddings` before padding. + """ + x = self.encoder_embed(x) + + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + lengths = (x_lens - 7) >> 1 + assert x.size(0) == lengths.max().item(), (x.shape, lengths, lengths.max()) + mask = make_pad_mask(lengths) + + outputs = [] + feature_masks = self.get_feature_masks(x) + + if self.training: + # Training mode + max_ds = max(self.zipformer_downsampling_factors) + # Generate dynamic chunk-wise attention mask during training + max_len = x.size(0) // max_ds + short_chunk_size = self.short_chunk_size // max_ds + chunk_size = torch.randint(1, max_len, (1,)).item() + if chunk_size > (max_len * self.short_chunk_threshold): + # Full attention + chunk_size = x.size(0) + else: + # Chunk-wise attention + chunk_size = chunk_size % short_chunk_size + 1 + chunk_size *= max_ds + else: + chunk_size = self.decode_chunk_size + # Evaluation mode + for ds in self.zipformer_downsampling_factors: + assert chunk_size % ds == 0, (chunk_size, ds) + + attn_mask = ~subsequent_chunk_mask( + size=x.size(0), + chunk_size=chunk_size, + num_left_chunks=self.num_left_chunks, + device=x.device, + ) + + for i, (module, skip_module) in enumerate( + zip(self.encoders, self.skip_modules) + ): + ds = self.zipformer_downsampling_factors[i] + k = self.skip_layers[i] + if isinstance(k, int): + layer_skip_dropout_prob = self._get_layer_skip_dropout_prob() + if torch.jit.is_scripting(): + x = skip_module(outputs[k], x) + elif (not self.training) or random.random() > layer_skip_dropout_prob: + x = skip_module(outputs[k], x) + x = module( + x, + feature_mask=feature_masks[i], + src_key_padding_mask=None if mask is None else mask[..., ::ds], + attn_mask=attn_mask[::ds, ::ds], + ) + outputs.append(x) + + x = self.downsample_output(x) + # class Downsample has this rounding behavior.. + assert self.output_downsampling_factor == 2, self.output_downsampling_factor + lengths = (lengths + 1) >> 1 + + x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + return x, lengths + + def streaming_forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + states: List[Tensor], + ) -> Tuple[Tensor, Tensor, List[Tensor]]: + """ + Args: + x: + The input tensor. Its shape is (batch_size, seq_len, feature_dim). + seq_len is the input chunk length. + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + states: + A list of 7 * num_encoders elements: + ``states[0:num_encoders]`` is the cached numbers of past frames. + ``states[num_encoders:2*num_encoders]`` is the cached average tensors. + ``states[2*num_encoders:3*num_encoders]`` is the cached key tensors of the first attention modules. + ``states[3*num_encoders:4*num_encoders]`` is the cached value tensors of the first attention modules. + ``states[4*num_encoders:5*num_encoders]`` is the cached value tensors of the second attention modules. + ``states[5*num_encoders:6*num_encoders]`` is the cached left contexts of the first convolution modules. + ``states[6*num_encoders:7*num_encoders]`` is the cached left contexts of the second convolution modules. + + Returns: + Return a tuple containing 3 tensors: + - embeddings: its shape is (batch_size, output_seq_len, encoder_dims[-1]) + - lengths, a tensor of shape (batch_size,) containing the number + of frames in `embeddings` before padding. + - updated states. + """ + assert len(states) == 7 * self.num_encoders, (len(states), self.num_encoders) + + cached_len = states[: self.num_encoders] + cached_avg = states[self.num_encoders : 2 * self.num_encoders] + cached_key = states[2 * self.num_encoders : 3 * self.num_encoders] + cached_val = states[3 * self.num_encoders : 4 * self.num_encoders] + cached_val2 = states[4 * self.num_encoders : 5 * self.num_encoders] + cached_conv1 = states[5 * self.num_encoders : 6 * self.num_encoders] + cached_conv2 = states[6 * self.num_encoders : 7 * self.num_encoders] + + x = self.encoder_embed(x) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + lengths = (x_lens - 7) >> 1 + assert x.size(0) == lengths.max().item(), (x.shape, lengths, lengths.max()) + + outputs = [] + new_cached_len = [] + new_cached_avg = [] + new_cached_key = [] + new_cached_val = [] + new_cached_val2 = [] + new_cached_conv1 = [] + new_cached_conv2 = [] + + for i, (module, skip_module) in enumerate( + zip(self.encoders, self.skip_modules) + ): + k = self.skip_layers[i] + if isinstance(k, int): + x = skip_module(outputs[k], x) + x, len_avg, avg, key, val, val2, conv1, conv2 = module.streaming_forward( + x, + cached_len=cached_len[i], + cached_avg=cached_avg[i], + cached_key=cached_key[i], + cached_val=cached_val[i], + cached_val2=cached_val2[i], + cached_conv1=cached_conv1[i], + cached_conv2=cached_conv2[i], + ) + outputs.append(x) + # Update caches + new_cached_len.append(len_avg) + new_cached_avg.append(avg) + new_cached_key.append(key) + new_cached_val.append(val) + new_cached_val2.append(val2) + new_cached_conv1.append(conv1) + new_cached_conv2.append(conv2) + + x = self.downsample_output(x) + # class Downsample has this rounding behavior.. + assert self.output_downsampling_factor == 2, self.output_downsampling_factor + lengths = (lengths + 1) >> 1 + + x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + new_states = ( + new_cached_len + + new_cached_avg + + new_cached_key + + new_cached_val + + new_cached_val2 + + new_cached_conv1 + + new_cached_conv2 + ) + return x, lengths, new_states + + @torch.jit.export + def get_init_state( + self, + device: torch.device = torch.device("cpu"), + ) -> List[Tensor]: + """Get initial states. + A list of 7 * num_encoders elements: + ``states[0:num_encoders]`` is the cached numbers of past frames. + ``states[num_encoders:2*num_encoders]`` is the cached average tensors. + ``states[2*num_encoders:3*num_encoders]`` is the cached key tensors of the first attention modules. + ``states[3*num_encoders:4*num_encoders]`` is the cached value tensors of the first attention modules. + ``states[4*num_encoders:5*num_encoders]`` is the cached value tensors of the second attention modules. + ``states[5*num_encoders:6*num_encoders]`` is the cached left contexts of the first convolution modules. + ``states[6*num_encoders:7*num_encoders]`` is the cached left contexts of the second convolution modules. + """ + cached_len = [] + cached_avg = [] + cached_key = [] + cached_val = [] + cached_val2 = [] + cached_conv1 = [] + cached_conv2 = [] + + left_context_len = self.decode_chunk_size * self.num_left_chunks + + for i, encoder in enumerate(self.encoders): + num_layers = encoder.num_layers + ds = self.zipformer_downsampling_factors[i] + + len_avg = torch.zeros(num_layers, 1, dtype=torch.int32, device=device) + cached_len.append(len_avg) + + avg = torch.zeros(num_layers, 1, encoder.d_model, device=device) + cached_avg.append(avg) + + key = torch.zeros( + num_layers, + left_context_len // ds, + 1, + encoder.attention_dim, + device=device, + ) + cached_key.append(key) + + val = torch.zeros( + num_layers, + left_context_len // ds, + 1, + encoder.attention_dim // 2, + device=device, + ) + cached_val.append(val) + + val2 = torch.zeros( + num_layers, + left_context_len // ds, + 1, + encoder.attention_dim // 2, + device=device, + ) + cached_val2.append(val2) + + conv1 = torch.zeros( + num_layers, + 1, + encoder.d_model, + encoder.cnn_module_kernel - 1, + device=device, + ) + cached_conv1.append(conv1) + + conv2 = torch.zeros( + num_layers, + 1, + encoder.d_model, + encoder.cnn_module_kernel - 1, + device=device, + ) + cached_conv2.append(conv2) + + states = ( + cached_len + + cached_avg + + cached_key + + cached_val + + cached_val2 + + cached_conv1 + + cached_conv2 + ) + return states + + +class ZipformerEncoderLayer(nn.Module): + """ + ZipformerEncoderLayer is made up of self-attn, feedforward and convolution networks. + + Args: + d_model: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + feedforward_dim: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + cnn_module_kernel (int): Kernel size of convolution module. + + Examples:: + >>> encoder_layer = ZipformerEncoderLayer(d_model=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> pos_emb = torch.rand(32, 19, 512) + >>> out = encoder_layer(src, pos_emb) + """ + + def __init__( + self, + d_model: int, + attention_dim: int, + nhead: int, + feedforward_dim: int = 2048, + dropout: float = 0.1, + cnn_module_kernel: int = 31, + pos_dim: int = 4, + ) -> None: + super(ZipformerEncoderLayer, self).__init__() + + self.d_model = d_model + self.attention_dim = attention_dim + self.cnn_module_kernel = cnn_module_kernel + + # will be written to, see set_batch_count() + self.batch_count = 0 + + self.self_attn = RelPositionMultiheadAttention( + d_model, + attention_dim, + nhead, + pos_dim, + dropout=0.0, + ) + + self.pooling = PoolingModule(d_model) + + self.feed_forward1 = FeedforwardModule(d_model, feedforward_dim, dropout) + + self.feed_forward2 = FeedforwardModule(d_model, feedforward_dim, dropout) + + self.feed_forward3 = FeedforwardModule(d_model, feedforward_dim, dropout) + + self.conv_module1 = ConvolutionModule(d_model, cnn_module_kernel) + + self.conv_module2 = ConvolutionModule(d_model, cnn_module_kernel) + + self.norm_final = BasicNorm(d_model) + + self.bypass_scale = nn.Parameter(torch.tensor(0.5)) + + # try to ensure the output is close to zero-mean (or at least, zero-median). + self.balancer = ActivationBalancer( + d_model, + channel_dim=-1, + min_positive=0.45, + max_positive=0.55, + max_abs=6.0, + ) + self.whiten = Whiten( + num_groups=1, whitening_limit=5.0, prob=(0.025, 0.25), grad_scale=0.01 + ) + + def get_bypass_scale(self): + if torch.jit.is_scripting() or not self.training: + return self.bypass_scale + if random.random() < 0.1: + # ensure we get grads if self.bypass_scale becomes out of range + return self.bypass_scale + # hardcode warmup period for bypass scale + warmup_period = 20000.0 + initial_clamp_min = 0.75 + final_clamp_min = 0.25 + if self.batch_count > warmup_period: + clamp_min = final_clamp_min + else: + clamp_min = initial_clamp_min - (self.batch_count / warmup_period) * ( + initial_clamp_min - final_clamp_min + ) + return self.bypass_scale.clamp(min=clamp_min, max=1.0) + + def get_dynamic_dropout_rate(self): + # return dropout rate for the dynamic modules (self_attn, pooling, convolution); this + # starts at 0.2 and rapidly decreases to 0. Its purpose is to keep the training stable + # at the beginning, by making the network focus on the feedforward modules. + if torch.jit.is_scripting() or not self.training: + return 0.0 + warmup_period = 2000.0 + initial_dropout_rate = 0.2 + final_dropout_rate = 0.0 + if self.batch_count > warmup_period: + return final_dropout_rate + else: + return initial_dropout_rate - ( + initial_dropout_rate * final_dropout_rate + ) * (self.batch_count / warmup_period) + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """ + Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + pos_emb: Positional embedding tensor (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + batch_split: if not None, this layer will only be applied to + + Shape: + src: (S, N, E). + pos_emb: (N, 2*S-1, E) + src_mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, N is the batch size, E is the feature number + """ + src_orig = src + + # macaron style feed forward module + src = src + self.feed_forward1(src) + + # dropout rate for submodules that interact with time. + dynamic_dropout = self.get_dynamic_dropout_rate() + + # pooling module + if torch.jit.is_scripting(): + src = src + self.pooling(src, src_key_padding_mask=src_key_padding_mask) + elif random.random() >= dynamic_dropout: + src = src + self.pooling(src, src_key_padding_mask=src_key_padding_mask) + + if torch.jit.is_scripting(): + src_att, attn_weights = self.self_attn( + src, + pos_emb=pos_emb, + attn_mask=attn_mask, + key_padding_mask=src_key_padding_mask, + ) + src = src + src_att + + src = src + self.conv_module1( + src, src_key_padding_mask=src_key_padding_mask + ) + + src = src + self.feed_forward2(src) + + src = src + self.self_attn.forward2(src, attn_weights) + + src = src + self.conv_module2( + src, src_key_padding_mask=src_key_padding_mask + ) + else: + use_self_attn = random.random() >= dynamic_dropout + if use_self_attn: + src_att, attn_weights = self.self_attn( + src, + pos_emb=pos_emb, + attn_mask=attn_mask, + key_padding_mask=src_key_padding_mask, + ) + src = src + src_att + + if random.random() >= dynamic_dropout: + src = src + self.conv_module1( + src, src_key_padding_mask=src_key_padding_mask + ) + + src = src + self.feed_forward2(src) + + if use_self_attn: + src = src + self.self_attn.forward2(src, attn_weights) + + if random.random() >= dynamic_dropout: + src = src + self.conv_module2( + src, src_key_padding_mask=src_key_padding_mask + ) + + src = src + self.feed_forward3(src) + + src = self.norm_final(self.balancer(src)) + + delta = src - src_orig + + src = src_orig + delta * self.get_bypass_scale() + + return self.whiten(src) + + def streaming_forward( + self, + src: Tensor, + pos_emb: Tensor, + cached_len: Tensor, + cached_avg: Tensor, + cached_key: Tensor, + cached_val: Tensor, + cached_val2: Tensor, + cached_conv1: Tensor, + cached_conv2: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + """ + Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + pos_emb: Positional embedding tensor (required). + cached_len: processed number of past frames. + cached_avg: cached average of past frames. + cached_key: cached key tensor of left context for the first attention module. + cached_val: cached value tensor of left context for the first attention module. + cached_val2: cached value tensor of left context for the second attention module. + cached_conv1: cached left context for the first convolution module. + cached_conv2: cached left context for the second convolution module. + + Shape: + src: (S, N, E). + pos_emb: (N, left_context_len+2*S-1, E) + cached_len: (N,) + N is the batch size. + cached_avg: (N, C). + N is the batch size, C is the feature dimension. + cached_key: (left_context_len, N, K). + N is the batch size, K is the key dimension. + cached_val: (left_context_len, N, V). + N is the batch size, V is the key dimension. + cached_val2: (left_context_len, N, V). + N is the batch size, V is the key dimension. + cached_conv1: (N, C, kernel_size-1). + N is the batch size, C is the convolution channels. + cached_conv2: (N, C, kernel_size-1). + N is the batch size, C is the convolution channels. + """ + src_orig = src + + # macaron style feed forward module + src = src + self.feed_forward1(src) + + src_pool, cached_len, cached_avg = self.pooling.streaming_forward( + src, + cached_len=cached_len, + cached_avg=cached_avg, + ) + src = src + src_pool + + ( + src_attn, + attn_weights, + cached_key, + cached_val, + ) = self.self_attn.streaming_forward( + src, + pos_emb=pos_emb, + cached_key=cached_key, + cached_val=cached_val, + ) + src = src + src_attn + + src_conv, cached_conv1 = self.conv_module1.streaming_forward( + src, + cache=cached_conv1, + ) + src = src + src_conv + + src = src + self.feed_forward2(src) + + src_attn, cached_val2 = self.self_attn.streaming_forward2( + src, + attn_weights, + cached_val=cached_val2, + ) + src = src + src_attn + + src_conv, cached_conv2 = self.conv_module2.streaming_forward( + src, + cache=cached_conv2, + ) + src = src + src_conv + + src = src + self.feed_forward3(src) + + src = self.norm_final(self.balancer(src)) + + delta = src - src_orig + + src = src_orig + delta * self.bypass_scale + + return ( + src, + cached_len, + cached_avg, + cached_key, + cached_val, + cached_val2, + cached_conv1, + cached_conv2, + ) + + +class ZipformerEncoder(nn.Module): + r"""ZipformerEncoder is a stack of N encoder layers + + Args: + encoder_layer: an instance of the ZipformerEncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + + Examples:: + >>> encoder_layer = ZipformerEncoderLayer(d_model=512, nhead=8) + >>> zipformer_encoder = ZipformerEncoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> out = zipformer_encoder(src) + """ + + def __init__( + self, + encoder_layer: nn.Module, + num_layers: int, + dropout: float, + warmup_begin: float, + warmup_end: float, + ) -> None: + super().__init__() + # will be written to, see set_batch_count() Note: in inference time this + # may be zero but should be treated as large, we can check if + # self.training is true. + self.batch_count = 0 + self.warmup_begin = warmup_begin + self.warmup_end = warmup_end + # module_seed is for when we need a random number that is unique to the module but + # shared across jobs. It's used to randomly select how many layers to drop, + # so that we can keep this consistent across worker tasks (for efficiency). + self.module_seed = torch.randint(0, 1000, ()).item() + + self.encoder_pos = RelPositionalEncoding(encoder_layer.d_model, dropout) + + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for i in range(num_layers)] + ) + self.num_layers = num_layers + + self.d_model = encoder_layer.d_model + self.attention_dim = encoder_layer.attention_dim + self.cnn_module_kernel = encoder_layer.cnn_module_kernel + + assert 0 <= warmup_begin <= warmup_end, (warmup_begin, warmup_end) + + delta = (1.0 / num_layers) * (warmup_end - warmup_begin) + cur_begin = warmup_begin + for i in range(num_layers): + self.layers[i].warmup_begin = cur_begin + cur_begin += delta + self.layers[i].warmup_end = cur_begin + + def get_layers_to_drop(self, rnd_seed: int): + ans = set() + if not self.training: + return ans + + batch_count = self.batch_count + num_layers = len(self.layers) + + def get_layerdrop_prob(layer: int) -> float: + layer_warmup_begin = self.layers[layer].warmup_begin + layer_warmup_end = self.layers[layer].warmup_end + + initial_layerdrop_prob = 0.5 + final_layerdrop_prob = 0.05 + + if batch_count == 0: + # As a special case, if batch_count == 0, return 0 (drop no + # layers). This is rather ugly, I'm afraid; it is intended to + # enable our scan_pessimistic_batches_for_oom() code to work correctly + # so if we are going to get OOM it will happen early. + # also search for 'batch_count' with quotes in this file to see + # how we initialize the warmup count to a random number between + # 0 and 10. + return 0.0 + elif batch_count < layer_warmup_begin: + return initial_layerdrop_prob + elif batch_count > layer_warmup_end: + return final_layerdrop_prob + else: + # linearly interpolate + t = (batch_count - layer_warmup_begin) / layer_warmup_end + assert 0.0 <= t < 1.001, t + return initial_layerdrop_prob + t * ( + final_layerdrop_prob - initial_layerdrop_prob + ) + + shared_rng = random.Random(batch_count + self.module_seed) + independent_rng = random.Random(rnd_seed) + + layerdrop_probs = [get_layerdrop_prob(i) for i in range(num_layers)] + tot = sum(layerdrop_probs) + # Instead of drawing the samples independently, we first randomly decide + # how many layers to drop out, using the same random number generator between + # jobs so that all jobs drop out the same number (this is for speed). + # Then we use an approximate approach to drop out the individual layers + # with their specified probs while reaching this exact target. + num_to_drop = int(tot) + int(shared_rng.random() < (tot - int(tot))) + + layers = list(range(num_layers)) + independent_rng.shuffle(layers) + + # go through the shuffled layers until we get the required number of samples. + if num_to_drop > 0: + for layer in itertools.cycle(layers): + if independent_rng.random() < layerdrop_probs[layer]: + ans.add(layer) + if len(ans) == num_to_drop: + break + if shared_rng.random() < 0.005 or __name__ == "__main__": + logging.info( + f"warmup_begin={self.warmup_begin:.1f}, warmup_end={self.warmup_end:.1f}, " + f"batch_count={batch_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}" + ) + return ans + + def forward( + self, + src: Tensor, + # Note: The type of feature_mask should be Union[float, Tensor], + # but to make torch.jit.script() work, we use `float` here + feature_mask: float = 1.0, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required). + feature_mask: something that broadcasts with src, that we'll multiply `src` + by at every layer. + mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + src: (S, N, E). + pos_emb: (N, 2*S-1, E) + mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number + + Returns: (x, x_no_combine), both of shape (S, N, E) + """ + pos_emb = self.encoder_pos(src) + output = src + + if torch.jit.is_scripting(): + layers_to_drop = [] + else: + rnd_seed = src.numel() + random.randint(0, 1000) + layers_to_drop = self.get_layers_to_drop(rnd_seed) + + output = output * feature_mask + + for i, mod in enumerate(self.layers): + if not torch.jit.is_scripting(): + if i in layers_to_drop: + continue + output = mod( + output, + pos_emb, + attn_mask=attn_mask, + src_key_padding_mask=src_key_padding_mask, + ) + + output = output * feature_mask + + return output + + @torch.jit.export + def streaming_forward( + self, + src: Tensor, + cached_len: Tensor, + cached_avg: Tensor, + cached_key: Tensor, + cached_val: Tensor, + cached_val2: Tensor, + cached_conv1: Tensor, + cached_conv2: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required). + cached_len: number of past frames. + cached_avg: cached average of past frames. + cached_key: cached key tensor for first attention module. + cached_val: cached value tensor for first attention module. + cached_val2: cached value tensor for second attention module. + cached_conv1: cached left contexts for the first convolution module. + cached_conv2: cached left contexts for the second convolution module. + + Shape: + src: (S, N, E). + cached_len: (N,) + N is the batch size. + cached_avg: (num_layers, N, C). + N is the batch size, C is the feature dimension. + cached_key: (num_layers, left_context_len, N, K). + N is the batch size, K is the key dimension. + cached_val: (num_layers, left_context_len, N, V). + N is the batch size, V is the key dimension. + cached_val2: (num_layers, left_context_len, N, V). + N is the batch size, V is the key dimension. + cached_conv1: (num_layers, N, C, kernel_size-1). + N is the batch size, C is the convolution channels. + cached_conv2: (num_layers, N, C, kernel_size-1). + N is the batch size, C is the convolution channels. + + Returns: A tuple of 8 tensors: + - output tensor + - updated cached number of past frmaes. + - updated cached average of past frmaes. + - updated cached key tensor of of the first attention module. + - updated cached value tensor of of the first attention module. + - updated cached value tensor of of the second attention module. + - updated cached left contexts of the first convolution module. + - updated cached left contexts of the second convolution module. + """ + assert cached_len.size(0) == self.num_layers, ( + cached_len.size(0), + self.num_layers, + ) + assert cached_avg.size(0) == self.num_layers, ( + cached_avg.size(0), + self.num_layers, + ) + assert cached_key.size(0) == self.num_layers, ( + cached_key.size(0), + self.num_layers, + ) + assert cached_val.size(0) == self.num_layers, ( + cached_val.size(0), + self.num_layers, + ) + assert cached_val2.size(0) == self.num_layers, ( + cached_val2.size(0), + self.num_layers, + ) + assert cached_conv1.size(0) == self.num_layers, ( + cached_conv1.size(0), + self.num_layers, + ) + assert cached_conv2.size(0) == self.num_layers, ( + cached_conv2.size(0), + self.num_layers, + ) + + left_context_len = cached_key.shape[1] + pos_emb = self.encoder_pos(src, left_context_len) + output = src + + new_cached_len = [] + new_cached_avg = [] + new_cached_key = [] + new_cached_val = [] + new_cached_val2 = [] + new_cached_conv1 = [] + new_cached_conv2 = [] + for i, mod in enumerate(self.layers): + output, len_avg, avg, key, val, val2, conv1, conv2 = mod.streaming_forward( + output, + pos_emb, + cached_len=cached_len[i], + cached_avg=cached_avg[i], + cached_key=cached_key[i], + cached_val=cached_val[i], + cached_val2=cached_val2[i], + cached_conv1=cached_conv1[i], + cached_conv2=cached_conv2[i], + ) + # Update caches + new_cached_len.append(len_avg) + new_cached_avg.append(avg) + new_cached_key.append(key) + new_cached_val.append(val) + new_cached_val2.append(val2) + new_cached_conv1.append(conv1) + new_cached_conv2.append(conv2) + + return ( + output, + torch.stack(new_cached_len, dim=0), + torch.stack(new_cached_avg, dim=0), + torch.stack(new_cached_key, dim=0), + torch.stack(new_cached_val, dim=0), + torch.stack(new_cached_val2, dim=0), + torch.stack(new_cached_conv1, dim=0), + torch.stack(new_cached_conv2, dim=0), + ) + + +class DownsampledZipformerEncoder(nn.Module): + r""" + DownsampledZipformerEncoder is a zipformer encoder evaluated at a reduced frame rate, + after convolutional downsampling, and then upsampled again at the output, and combined + with the origin input, so that the output has the same shape as the input. + """ + + def __init__( + self, encoder: nn.Module, input_dim: int, output_dim: int, downsample: int + ): + super(DownsampledZipformerEncoder, self).__init__() + self.downsample_factor = downsample + self.downsample = AttentionDownsample(input_dim, output_dim, downsample) + self.encoder = encoder + self.num_layers = encoder.num_layers + self.d_model = encoder.d_model + self.attention_dim = encoder.attention_dim + self.cnn_module_kernel = encoder.cnn_module_kernel + self.upsample = SimpleUpsample(output_dim, downsample) + self.out_combiner = SimpleCombiner( + input_dim, output_dim, min_weight=(0.0, 0.25) + ) + + def forward( + self, + src: Tensor, + # Note: the type of feature_mask should be Unino[float, Tensor], + # but to make torch.jit.script() happ, we use float here + feature_mask: float = 1.0, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Downsample, go through encoder, upsample. + + Args: + src: the sequence to the encoder (required). + feature_mask: something that broadcasts with src, that we'll multiply `src` + by at every layer. feature_mask is expected to be already downsampled by + self.downsample_factor. + attn_mask: attention mask (optional). Should be downsampled already. + src_key_padding_mask: the mask for the src keys per batch (optional). Should be downsampled already. + + Shape: + src: (S, N, E). + attn_mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number + + Returns: output of shape (S, N, F) where F is the number of output features + (output_dim to constructor) + """ + src_orig = src + src = self.downsample(src) + + src = self.encoder( + src, + feature_mask=feature_mask, + attn_mask=attn_mask, + src_key_padding_mask=src_key_padding_mask, + ) + src = self.upsample(src) + # remove any extra frames that are not a multiple of downsample_factor + src = src[: src_orig.shape[0]] + + return self.out_combiner(src_orig, src) + + def streaming_forward( + self, + src: Tensor, + cached_len: Tensor, + cached_avg: Tensor, + cached_key: Tensor, + cached_val: Tensor, + cached_val2: Tensor, + cached_conv1: Tensor, + cached_conv2: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + r"""Downsample, go through encoder, upsample. + + Args: + src: the sequence to the encoder (required). + cached_avg: cached average value of past frames. + cached_len: length of past frames. + cached_key: cached key tensor for the first attention module. + cached_val: cached value tensor for the first attention module. + cached_val2: cached value tensor for the second attention module. + cached_conv1: cached left context for the first convolution module. + cached_conv2: cached left context for the second convolution module. + + Shape: + src: (S, N, E). + cached_len: (N,) + N is the batch size. + cached_avg: (num_layers, N, C). + N is the batch size, C is the feature dimension. + cached_key: (num_layers, left_context_len, N, K). + N is the batch size, K is the key dimension. + cached_val: (num_layers, left_context_len, N, V). + N is the batch size, V is the key dimension. + cached_val2: (num_layers, left_context_len, N, V). + N is the batch size, V is the key dimension. + cached_conv1: (num_layers, N, C, kernel_size-1). + N is the batch size, C is the convolution channels. + cached_conv2: (num_layers, N, C, kernel_size-1). + N is the batch size, C is the convolution channels. + Returns: output of shape (S, N, F) where F is the number of output features + (output_dim to constructor) + """ + src_orig = src + src = self.downsample(src) + + ( + src, + cached_len, + cached_avg, + cached_key, + cached_val, + cached_val2, + cached_conv1, + cached_conv2, + ) = self.encoder.streaming_forward( + src, + cached_len=cached_len, + cached_avg=cached_avg, + cached_key=cached_key, + cached_val=cached_val, + cached_val2=cached_val2, + cached_conv1=cached_conv1, + cached_conv2=cached_conv2, + ) + src = self.upsample(src) + # remove any extra frames that are not a multiple of downsample_factor + src = src[: src_orig.shape[0]] + + return ( + self.out_combiner(src_orig, src), + cached_len, + cached_avg, + cached_key, + cached_val, + cached_val2, + cached_conv1, + cached_conv2, + ) + + +class AttentionDownsample(torch.nn.Module): + """ + Does downsampling with attention, by weighted sum, and a projection.. + """ + + def __init__(self, in_channels: int, out_channels: int, downsample: int): + """ + Require out_channels > in_channels. + """ + super(AttentionDownsample, self).__init__() + self.query = nn.Parameter(torch.randn(in_channels) * (in_channels**-0.5)) + + # fill in the extra dimensions with a projection of the input + if out_channels > in_channels: + self.extra_proj = nn.Linear( + in_channels * downsample, out_channels - in_channels, bias=False + ) + else: + self.extra_proj = None + self.downsample = downsample + + def forward(self, src: Tensor) -> Tensor: + """ + x: (seq_len, 1, in_channels) + Returns a tensor of shape + ( (seq_len+downsample-1)//downsample, batch_size, out_channels) + """ + (seq_len, batch_size, in_channels) = src.shape + ds = self.downsample + d_seq_len = (seq_len + ds - 1) // ds + + # Pad to an exact multiple of self.downsample + if seq_len != d_seq_len * ds: + # right-pad src, repeating the last element. + pad = d_seq_len * ds - seq_len + src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2]) + src = torch.cat((src, src_extra), dim=0) + assert src.shape[0] == d_seq_len * ds, (src.shape[0], d_seq_len, ds) + + src = src.reshape(d_seq_len, ds, batch_size, in_channels) + scores = (src * self.query).sum(dim=-1, keepdim=True) + + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + scores = penalize_abs_values_gt(scores, limit=10.0, penalty=1.0e-04) + + weights = scores.softmax(dim=1) + + # ans1 is the first `in_channels` channels of the output + ans = (src * weights).sum(dim=1) + src = src.permute(0, 2, 1, 3).reshape(d_seq_len, batch_size, ds * in_channels) + + if self.extra_proj is not None: + ans2 = self.extra_proj(src) + ans = torch.cat((ans, ans2), dim=2) + return ans + + +class SimpleUpsample(torch.nn.Module): + """ + A very simple form of upsampling that mostly just repeats the input, but + also adds a position-specific bias. + """ + + def __init__(self, num_channels: int, upsample: int): + super(SimpleUpsample, self).__init__() + self.bias = nn.Parameter(torch.randn(upsample, num_channels) * 0.01) + + def forward(self, src: Tensor) -> Tensor: + """ + x: (seq_len, batch_size, num_channels) + Returns a tensor of shape + ( (seq_len*upsample), batch_size, num_channels) + """ + upsample = self.bias.shape[0] + (seq_len, batch_size, num_channels) = src.shape + src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels) + src = src + self.bias.unsqueeze(1) + src = src.reshape(seq_len * upsample, batch_size, num_channels) + return src + + +class SimpleCombinerIdentity(nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + + def forward(self, src1: Tensor, src2: Tensor) -> Tensor: + return src1 + + +class SimpleCombiner(torch.nn.Module): + """ + A very simple way of combining 2 vectors of 2 different dims, via a + learned weighted combination in the shared part of the dim. + Args: + dim1: the dimension of the first input, e.g. 256 + dim2: the dimension of the second input, e.g. 384. + The output will have the same dimension as dim2. + """ + + def __init__(self, dim1: int, dim2: int, min_weight: Tuple[float] = (0.0, 0.0)): + super(SimpleCombiner, self).__init__() + assert dim2 >= dim1, (dim2, dim1) + self.weight1 = nn.Parameter(torch.zeros(())) + self.min_weight = min_weight + + def forward(self, src1: Tensor, src2: Tensor) -> Tensor: + """ + src1: (*, dim1) + src2: (*, dim2) + + Returns: a tensor of shape (*, dim2) + """ + assert src1.shape[:-1] == src2.shape[:-1], (src1.shape, src2.shape) + + weight1 = self.weight1 + if not torch.jit.is_scripting(): + if ( + self.training + and random.random() < 0.25 + and self.min_weight != (0.0, 0.0) + ): + weight1 = weight1.clamp( + min=self.min_weight[0], max=1.0 - self.min_weight[1] + ) + + src1 = src1 * weight1 + src2 = src2 * (1.0 - weight1) + + src1_dim = src1.shape[-1] + src2_dim = src2.shape[-1] + if src1_dim != src2_dim: + if src1_dim < src2_dim: + src1 = torch.nn.functional.pad(src1, (0, src2_dim - src1_dim)) + else: + src1 = src1[:src2_dim] + + return src1 + src2 + + +class RelPositionalEncoding(torch.nn.Module): + """Relative positional encoding module. + + See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py + + Args: + d_model: Embedding dimension. + dropout_rate: Dropout rate. + max_len: Maximum input length. + + """ + + def __init__( + self, + d_model: int, + dropout_rate: float, + max_len: int = 5000, + ) -> None: + """Construct a PositionalEncoding object.""" + super(RelPositionalEncoding, self).__init__() + self.d_model = d_model + self.dropout = torch.nn.Dropout(dropout_rate) + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(max_len)) + + def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None: + """Reset the positional encodings.""" + x_size_left = x.size(0) + left_context_len + if self.pe is not None: + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + if self.pe.size(1) >= x_size_left * 2 - 1: + # Note: TorchScript doesn't implement operator== for torch.Device + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + # Suppose `i` means to the position of query vecotr and `j` means the + # position of key vector. We use position relative positions when keys + # are to the left (i>j) and negative relative positions otherwise (i Tensor: + """Add positional encoding. + + Args: + x (torch.Tensor): Input tensor (time, batch, `*`). + left_context_len: (int): Length of cached left context. + + Returns: + torch.Tensor: Encoded tensor (batch, left_context_len + 2*time-1, `*`). + + """ + self.extend_pe(x, left_context_len) + x_size_left = x.size(0) + left_context_len + pos_emb = self.pe[ + :, + self.pe.size(1) // 2 + - x_size_left + + 1 : self.pe.size(1) // 2 # noqa E203 + + x.size(0), + ] + return self.dropout(pos_emb) + + +class RelPositionMultiheadAttention(nn.Module): + r"""Multi-Head Attention layer with relative position encoding + + This is a quite heavily modified from: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context", + we have to write up the differences. + + + Args: + embed_dim: total dimension of the model. + attention_dim: dimension in the attention module, may be less or more than embed_dim + but must be a multiple of num_heads. + num_heads: parallel attention heads. + dropout: a Dropout layer on attn_output_weights. Default: 0.0. + + Examples:: + + >>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads) + >>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb) + """ + + def __init__( + self, + embed_dim: int, + attention_dim: int, + num_heads: int, + pos_dim: int, + dropout: float = 0.0, + ) -> None: + super(RelPositionMultiheadAttention, self).__init__() + self.embed_dim = embed_dim + self.attention_dim = attention_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = attention_dim // num_heads + self.pos_dim = pos_dim + assert self.head_dim % 2 == 0, self.head_dim + assert self.head_dim * num_heads == attention_dim, ( + self.head_dim, + num_heads, + attention_dim, + ) + + # the initial_scale is supposed to take over the "scaling" factor of + # head_dim ** -0.5, dividing it between the query and key. + in_proj_dim = ( + 2 * attention_dim + + attention_dim // 2 # query, key + + pos_dim * num_heads # value + ) # positional encoding query + + self.in_proj = ScaledLinear( + embed_dim, in_proj_dim, bias=True, initial_scale=self.head_dim**-0.25 + ) + + # self.whiten_values is applied on the values in forward(); + # it just copies the keys but prevents low-rank distribution by modifying grads. + self.whiten_values = Whiten( + num_groups=num_heads, + whitening_limit=2.0, + prob=(0.025, 0.25), + grad_scale=0.025, + ) + self.whiten_keys = Whiten( + num_groups=num_heads, + whitening_limit=2.0, + prob=(0.025, 0.25), + grad_scale=0.025, + ) + + # linear transformation for positional encoding. + self.linear_pos = ScaledLinear( + embed_dim, num_heads * pos_dim, bias=False, initial_scale=0.05 + ) + + # the following are for diagnosics only, see --print-diagnostics option. + # they only copy their inputs. + self.copy_pos_query = Identity() + self.copy_query = Identity() + + self.out_proj = ScaledLinear( + attention_dim // 2, embed_dim, bias=True, initial_scale=0.05 + ) + + self.in_proj2 = nn.Linear(embed_dim, attention_dim // 2, bias=False) + self.out_proj2 = ScaledLinear( + attention_dim // 2, embed_dim, bias=True, initial_scale=0.05 + ) + # self.whiten_values2 is applied on the values in forward2() + self.whiten_values2 = Whiten( + num_groups=num_heads, + whitening_limit=2.0, + prob=(0.025, 0.25), + grad_scale=0.025, + ) + + def forward( + self, + x: Tensor, + pos_emb: Tensor, + key_padding_mask: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Tensor]: + r""" + Args: + x: input to be projected to query, key, value + pos_emb: Positional embedding tensor + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. When given a binary mask and a value is True, + the corresponding value on the attention layer will be ignored. When given + a byte mask and a value is non-zero, the corresponding value on the attention + layer will be ignored + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + + Shape: + - Inputs: + - x: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the position + with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + + - Returns: (attn_output, attn_weights) + + - attn_output: :math:`(S, N, E)` where S is the sequence length, N is the batch size, + E is the embedding dimension. + - attn_weights: :math:`(N * N, S, S)` where N is the batch size, H is the num-heads + and S is the sequence length. + """ + x, weights = self.multi_head_attention_forward( + self.in_proj(x), + self.linear_pos(pos_emb), + self.attention_dim, + self.num_heads, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, + attn_mask=attn_mask, + ) + return x, weights + + def streaming_forward( + self, + x: Tensor, + pos_emb: Tensor, + cached_key: Tensor, + cached_val: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + r""" + Args: + x: input to be projected to query, key, value + pos_emb: Positional embedding tensor + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + + Shape: + - Inputs: + - x: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + - cached_key: :math:`(left_context_len, N, K)`, where N is the batch size, K is the key dimension. + - cached_val: :math:`(left_context_len, N, V)`, where N is the batch size, V is the value dimension. + + - Returns: (attn_output, attn_weights, cached_key, cached_val) + + - attn_output: :math:`(S, N, E)` where S is the sequence length, N is the batch size, + E is the embedding dimension. + - attn_weights: :math:`(N * N, S, S)` where N is the batch size, H is the num-heads + and S is the sequence length. + - cached_key: :math:`(left_context_len, N, K)`, updated cached attention key tensor of + left context + - cached_val: :math:`(left_context_len, N, K)`, updated cached attention value tensor of + """ + ( + x, + weights, + cached_key, + cached_val, + ) = self.streaming_multi_head_attention_forward( + self.in_proj(x), + self.linear_pos(pos_emb), + self.attention_dim, + self.num_heads, + self.out_proj.weight, + self.out_proj.bias, + cached_key=cached_key, + cached_val=cached_val, + ) + return x, weights, cached_key, cached_val + + def multi_head_attention_forward( + self, + x_proj: Tensor, + pos: Tensor, + attention_dim: int, + num_heads: int, + dropout_p: float, + out_proj_weight: Tensor, + out_proj_bias: Tensor, + training: bool = True, + key_padding_mask: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Tensor]: + r""" + Args: + x_proj: the projected input, to be split into query, key, value. + pos: head-specific biases arising from the positional embeddings. + attention_dim: dimension inside attention mechanism + num_heads: parallel attention heads. + dropout_p: probability of an element to be zeroed. + out_proj_weight, out_proj_bias: the output projection weight and bias. + training: apply dropout if is ``True``. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. This is an binary mask. When the value is True, + the corresponding value on the attention layer will be filled with -inf. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + + Shape: + Inputs: + - x: :math:`(L, N, 7 * A // 2)` where L is the target sequence length, N is the batch size, A is + the attention dimension. Will be split into (query, key, value, pos). + - pos: :math:`(N, 2*L-1, A//2)` or :math:`(1, 2*L-1, A//2)` where L is the sequence + length, N is the batch size, and A is the attention dim. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions + will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + + Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_weights: :math:`(N * H, S, S)` where N is the batch size, + H is the num-heads, S is the sequence length. + """ + + seq_len, bsz, _ = x_proj.size() + + head_dim = attention_dim // num_heads + pos_dim = self.pos_dim # positional-encoding dim per head + assert ( + head_dim * num_heads == attention_dim + ), f"attention_dim must be divisible by num_heads: {head_dim}, {num_heads}, {attention_dim}" + + # self-attention + q = x_proj[..., 0:attention_dim] + k = x_proj[..., attention_dim : 2 * attention_dim] + value_dim = attention_dim // 2 + v = x_proj[..., 2 * attention_dim : 2 * attention_dim + value_dim] + # p is the position-encoding query, its dimension is num_heads*pos_dim.. + p = x_proj[..., 2 * attention_dim + value_dim :] + + k = self.whiten_keys(k) # does nothing in the forward pass. + v = self.whiten_values(v) # does nothing in the forward pass. + q = self.copy_query(q) # for diagnostics only, does nothing. + p = self.copy_pos_query(p) # for diagnostics only, does nothing. + + if attn_mask is not None: + assert ( + attn_mask.dtype == torch.float32 + or attn_mask.dtype == torch.float64 + or attn_mask.dtype == torch.float16 + or attn_mask.dtype == torch.uint8 + or attn_mask.dtype == torch.bool + ), "Only float, byte, and bool types are supported for attn_mask, not {}".format( + attn_mask.dtype + ) + if attn_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for attn_mask is deprecated. Use bool tensor instead." + ) + attn_mask = attn_mask.to(torch.bool) + + if attn_mask.dim() == 2: + attn_mask = attn_mask.unsqueeze(0) + if list(attn_mask.size()) != [1, seq_len, seq_len]: + raise RuntimeError("The size of the 2D attn_mask is not correct.") + elif attn_mask.dim() == 3: + if list(attn_mask.size()) != [ + bsz * num_heads, + seq_len, + seq_len, + ]: + raise RuntimeError("The size of the 3D attn_mask is not correct.") + else: + raise RuntimeError( + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) + ) + # attn_mask's dim is 3 now. + + # convert ByteTensor key_padding_mask to bool + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." + ) + key_padding_mask = key_padding_mask.to(torch.bool) + + q = q.reshape(seq_len, bsz, num_heads, head_dim) + p = p.reshape(seq_len, bsz, num_heads, pos_dim) + k = k.reshape(seq_len, bsz, num_heads, head_dim) + v = v.reshape(seq_len, bsz * num_heads, head_dim // 2).transpose(0, 1) + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz, "{} == {}".format( + key_padding_mask.size(0), bsz + ) + assert key_padding_mask.size(1) == seq_len, "{} == {}".format( + key_padding_mask.size(1), seq_len + ) + + q = q.permute(1, 2, 0, 3) # (batch, head, time1, head_dim) + p = p.permute(1, 2, 0, 3) # (batch, head, time1, pos_dim) + k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) + + seq_len2 = 2 * seq_len - 1 + pos = pos.reshape(1, seq_len2, num_heads, pos_dim).permute(0, 2, 3, 1) + # pos shape now: (batch, head, pos_dim, seq_len2) + + # (batch, head, time1, pos_dim) x (1, head, pos_dim, seq_len2) -> (batch, head, time1, seq_len2) + # [where seq_len2 represents relative position.] + pos_weights = torch.matmul(p, pos) + # the following .as_strided() expression converts the last axis of pos_weights from relative + # to absolute position. I don't know whether I might have got the time-offsets backwards or + # not, but let this code define which way round it is supposed to be. + pos_weights = pos_weights.as_strided( + (bsz, num_heads, seq_len, seq_len), + ( + pos_weights.stride(0), + pos_weights.stride(1), + pos_weights.stride(2) - pos_weights.stride(3), + pos_weights.stride(3), + ), + storage_offset=pos_weights.stride(3) * (seq_len - 1), + ) + + # caution: they are really scores at this point. + attn_output_weights = torch.matmul(q, k) + pos_weights + + if not torch.jit.is_scripting(): + if training and random.random() < 0.1: + # This is a harder way of limiting the attention scores to not be too large. + # It incurs a penalty if any of them has an absolute value greater than 50.0. + # this should be outside the normal range of the attention scores. We use + # this mechanism instead of, say, a limit on entropy, because once the entropy + # gets very small gradients through the softmax can become very small, and + # some mechanisms like that become ineffective. + attn_output_weights = penalize_abs_values_gt( + attn_output_weights, limit=25.0, penalty=1.0e-04 + ) + + # attn_output_weights: (batch, head, time1, time2) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, seq_len, seq_len + ) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_output_weights = attn_output_weights.masked_fill( + attn_mask, float("-inf") + ) + else: + attn_output_weights = attn_output_weights + attn_mask + + if key_padding_mask is not None: + attn_output_weights = attn_output_weights.view( + bsz, num_heads, seq_len, seq_len + ) + attn_output_weights = attn_output_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + float("-inf"), + ) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, seq_len, seq_len + ) + + # Using this version of softmax, defined in scaling.py, + # should save a little of the memory used in backprop by, if + # we are in automatic mixed precision mode (amp) == autocast, + # only storing the half-precision output for backprop purposes. + attn_output_weights = softmax(attn_output_weights, dim=-1) + + # If we are using chunk-wise attention mask and setting a limited + # num_left_chunks, the attention may only see the padding values which + # will also be masked out by `key_padding_mask`. At this circumstances, + # the whole column of `attn_output_weights` will be `-inf` + # (i.e. be `nan` after softmax). So we fill `0.0` at the masking + # positions to avoid invalid loss value below. + if ( + attn_mask is not None + and attn_mask.dtype == torch.bool + and key_padding_mask is not None + ): + if attn_mask.size(0) != 1: + attn_mask = attn_mask.view(bsz, num_heads, seq_len, seq_len) + combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2) + else: + # attn_mask.shape == (1, tgt_len, src_len) + combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze( + 1 + ).unsqueeze(2) + + attn_output_weights = attn_output_weights.view( + bsz, num_heads, seq_len, seq_len + ) + attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, seq_len, seq_len + ) + + attn_output_weights = nn.functional.dropout( + attn_output_weights, p=dropout_p, training=training + ) + + attn_output = torch.bmm(attn_output_weights, v) + assert list(attn_output.size()) == [bsz * num_heads, seq_len, head_dim // 2] + attn_output = ( + attn_output.transpose(0, 1) + .contiguous() + .view(seq_len, bsz, attention_dim // 2) + ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) + + return attn_output, attn_output_weights + + def streaming_multi_head_attention_forward( + self, + x_proj: Tensor, + pos: Tensor, + attention_dim: int, + num_heads: int, + out_proj_weight: Tensor, + out_proj_bias: Tensor, + cached_key: Tensor, + cached_val: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + r""" + Args: + x_proj: the projected input, to be split into query, key, value. + pos: head-specific biases arising from the positional embeddings. + attention_dim: dimension inside attention mechanism + num_heads: parallel attention heads. + out_proj_weight, out_proj_bias: the output projection weight and bias. + cached_key: cached attention key tensor of left context. + cached_val: cached attention value tensor of left context. + + Shape: + Inputs: + - x: :math:`(L, N, 7 * A // 2)` where L is the target sequence length, N is the batch size, A is + the attention dimension. Will be split into (query, key, value, pos). + - pos: :math:`(N, 2*L-1, A//2)` or :math:`(1, 2*L-1, A//2)` where L is the sequence + length, N is the batch size, and A is the attention dim. + If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions + will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + + Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_weights: :math:`(N * H, S, S)` where N is the batch size, + H is the num-heads, S is the sequence length. + - cached_key: :math:`(left_context_len, N, K)`, updated cached attention key tensor of left context. + - cached_val: :math:`(left_context_len, N, K)`, updated cached attention value tensor of left context. + """ + + seq_len, bsz, _ = x_proj.size() + + head_dim = attention_dim // num_heads + pos_dim = self.pos_dim # positional-encoding dim per head + assert ( + head_dim * num_heads == attention_dim + ), f"attention_dim must be divisible by num_heads: {head_dim}, {num_heads}, {attention_dim}" + + # self-attention + q = x_proj[..., 0:attention_dim] + k = x_proj[..., attention_dim : 2 * attention_dim] + value_dim = attention_dim // 2 + v = x_proj[..., 2 * attention_dim : 2 * attention_dim + value_dim] + # p is the position-encoding query, its dimension is num_heads*pos_dim.. + p = x_proj[..., 2 * attention_dim + value_dim :] + + left_context_len = cached_key.shape[0] + assert left_context_len > 0, left_context_len + assert cached_key.shape[0] == cached_val.shape[0], ( + cached_key.shape, + cached_val.shape, + ) + # Pad cached left contexts + k = torch.cat([cached_key, k], dim=0) + v = torch.cat([cached_val, v], dim=0) + # Update cached left contexts + cached_key = k[-left_context_len:, ...] + cached_val = v[-left_context_len:, ...] + + # The length of key and value + kv_len = k.shape[0] + + q = q.reshape(seq_len, bsz, num_heads, head_dim) + p = p.reshape(seq_len, bsz, num_heads, pos_dim) + k = k.reshape(kv_len, bsz, num_heads, head_dim) + v = v.reshape(kv_len, bsz * num_heads, head_dim // 2).transpose(0, 1) + + q = q.permute(1, 2, 0, 3) # (batch, head, time1, head_dim) + p = p.permute(1, 2, 0, 3) # (batch, head, time1, pos_dim) + k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) + + seq_len2 = 2 * seq_len - 1 + left_context_len + pos = pos.reshape(1, seq_len2, num_heads, pos_dim).permute(0, 2, 3, 1) + # pos shape now: (batch, head, pos_dim, seq_len2) + + # (batch, head, time1, pos_dim) x (1, head, pos_dim, seq_len2) -> (batch, head, time1, seq_len2) + # [where seq_len2 represents relative position.] + pos_weights = torch.matmul(p, pos) + # the following .as_strided() expression converts the last axis of pos_weights from relative + # to absolute position. I don't know whether I might have got the time-offsets backwards or + # not, but let this code define which way round it is supposed to be. + pos_weights = pos_weights.as_strided( + (bsz, num_heads, seq_len, kv_len), + ( + pos_weights.stride(0), + pos_weights.stride(1), + pos_weights.stride(2) - pos_weights.stride(3), + pos_weights.stride(3), + ), + storage_offset=pos_weights.stride(3) * (seq_len - 1), + ) + + # caution: they are really scores at this point. + attn_output_weights = torch.matmul(q, k) + pos_weights + + # attn_output_weights: (batch, head, time1, time2) + attn_output_weights = attn_output_weights.view(bsz * num_heads, seq_len, kv_len) + + # Using this version of softmax, defined in scaling.py, + # should save a little of the memory used in backprop by, if + # we are in automatic mixed precision mode (amp) == autocast, + # only storing the half-precision output for backprop purposes. + attn_output_weights = softmax(attn_output_weights, dim=-1) + + attn_output = torch.bmm(attn_output_weights, v) + assert list(attn_output.size()) == [bsz * num_heads, seq_len, head_dim // 2] + attn_output = ( + attn_output.transpose(0, 1) + .contiguous() + .view(seq_len, bsz, attention_dim // 2) + ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) + + return attn_output, attn_output_weights, cached_key, cached_val + + def forward2( + self, + x: Tensor, + attn_weights: Tensor, + ) -> Tensor: + """ + Second forward function, where we re-use the attn_weights returned by the first forward function + but with different input. + Args: + x: input, of shape (seq_len, batch_size, embed_dim) + attn_weights: attention weights returned by forward(), of shape (batch_size * num_heads, seq_len, seq_len) + Returns: + output of the same shape as x, i.e. (seq_len, batch_size, embed_dim) + """ + num_heads = self.num_heads + (seq_len, bsz, embed_dim) = x.shape + head_dim = self.attention_dim // num_heads + # v: (tgt_len, bsz, embed_dim // 2) + v = self.in_proj2(x) + v = self.whiten_values2(v) # does nothing in the forward pass. + v = v.reshape(seq_len, bsz * num_heads, head_dim // 2).transpose(0, 1) + + # now v: (bsz * num_heads, seq_len, head_dim // 2) + attn_output = torch.bmm(attn_weights, v) + + if not torch.jit.is_scripting(): + if random.random() < 0.001 or __name__ == "__main__": + self._print_attn_stats(attn_weights, attn_output) + + # attn_output: (bsz * num_heads, seq_len, head_dim) + attn_output = ( + attn_output.transpose(0, 1) + .contiguous() + .view(seq_len, bsz, self.attention_dim // 2) + ) + # returned value is of shape (seq_len, bsz, embed_dim), like x. + return self.out_proj2(attn_output) + + def streaming_forward2( + self, + x: Tensor, + attn_weights: Tensor, + cached_val: Tensor, + ) -> Tuple[Tensor, Tensor]: + """ + Second forward function, where we re-use the attn_weights returned by the first forward function + but with different input. + Args: + x: input, of shape (seq_len, batch_size, embed_dim) + attn_weights: attention weights returned by forward(), of shape (batch_size * num_heads, seq_len, seq_len) + cached_val: cached attention value tensor of left context. + Returns: + - output of the same shape as x, i.e. (seq_len, batch_size, embed_dim) + - updated cached attention value tensor of left context. + """ + num_heads = self.num_heads + (seq_len, bsz, embed_dim) = x.shape + head_dim = self.attention_dim // num_heads + # v: (tgt_len, bsz, embed_dim // 2) + v = self.in_proj2(x) + + left_context_len = cached_val.shape[0] + assert left_context_len > 0, left_context_len + v = torch.cat([cached_val, v], dim=0) + cached_val = v[-left_context_len:] + + seq_len2 = left_context_len + seq_len + v = v.reshape(seq_len2, bsz * num_heads, head_dim // 2).transpose(0, 1) + + # now v: (bsz * num_heads, seq_len, head_dim // 2) + attn_output = torch.bmm(attn_weights, v) + + # attn_output: (bsz * num_heads, seq_len, head_dim) + attn_output = ( + attn_output.transpose(0, 1) + .contiguous() + .view(seq_len, bsz, self.attention_dim // 2) + ) + # returned value is of shape (seq_len, bsz, embed_dim), like x. + return self.out_proj2(attn_output), cached_val + + def _print_attn_stats(self, attn_weights: Tensor, attn_output: Tensor): + # attn_weights: (batch_size * num_heads, seq_len, seq_len) + # attn_output: (bsz * num_heads, seq_len, head_dim) + (n, seq_len, head_dim) = attn_output.shape + num_heads = self.num_heads + bsz = n // num_heads + + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=False): + attn_weights = attn_weights.to(torch.float32) + attn_output = attn_output.to(torch.float32) + attn_weights_entropy = ( + -((attn_weights + 1.0e-20).log() * attn_weights) + .sum(dim=-1) + .reshape(bsz, num_heads, seq_len) + .mean(dim=(0, 2)) + ) + attn_output = attn_output.reshape(bsz, num_heads, seq_len, head_dim) + attn_output = attn_output.permute(1, 0, 2, 3).reshape( + num_heads, bsz * seq_len, head_dim + ) + attn_output_mean = attn_output.mean(dim=1, keepdim=True) + attn_output = attn_output - attn_output_mean + attn_covar = torch.matmul(attn_output.transpose(1, 2), attn_output) / ( + bsz * seq_len + ) + # attn_covar: (num_heads, head_dim, head_dim) + # eigs, _ = torch.symeig(attn_covar) + # logging.info(f"attn_weights_entropy = {attn_weights_entropy}, output_eigs = {eigs}") + + attn_covar = _diag(attn_covar).mean(dim=1) # (num_heads,) + embed_dim = self.in_proj2.weight.shape[1] + in_proj_covar = ( + self.in_proj2.weight.reshape(num_heads, head_dim, embed_dim) ** 2 + ).mean(dim=(1, 2)) + out_proj_covar = ( + self.out_proj2.weight.reshape(embed_dim, num_heads, head_dim) ** 2 + ).mean(dim=(0, 2)) + logging.info( + f"attn_weights_entropy = {attn_weights_entropy}, covar={attn_covar}, in_proj_covar={in_proj_covar}, out_proj_covar={out_proj_covar}" + ) + + +class PoolingModule(nn.Module): + """ + Averages the input over the time dimension and project with a square matrix. + """ + + def __init__(self, d_model: int): + super().__init__() + self.proj = ScaledLinear(d_model, d_model, initial_scale=0.1, bias=False) + + def forward( + self, + x: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """ + Args: + x: a Tensor of shape (T, N, C) + src_key_padding_mask: a Tensor of bool, of shape (N, T), with True in masked + positions. + + Returns: + - output, a Tensor of shape (T, N, C). + """ + if src_key_padding_mask is not None: + # False in padding positions + padding_mask = src_key_padding_mask.logical_not().to(x.dtype) # (N, T) + # Cumulated numbers of frames from start + cum_mask = padding_mask.cumsum(dim=1) # (N, T) + x = x.cumsum(dim=0) # (T, N, C) + pooling_mask = padding_mask / cum_mask + pooling_mask = pooling_mask.transpose(0, 1).contiguous().unsqueeze(-1) + # now pooling_mask: (T, N, 1) + x = x * pooling_mask # (T, N, C) + else: + num_frames = x.shape[0] + cum_mask = torch.arange(1, num_frames + 1).unsqueeze(1) # (T, 1) + x = x.cumsum(dim=0) # (T, N, C) + pooling_mask = (1.0 / cum_mask).unsqueeze(2) + # now pooling_mask: (T, N, 1) + x = x * pooling_mask + + x = self.proj(x) + return x + + def streaming_forward( + self, + x: Tensor, + cached_len: Tensor, + cached_avg: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor]: + """ + Args: + x: a Tensor of shape (T, N, C) + cached_len: a Tensor of int, of shape (N,), containing the number of + past frames in batch. + cached_avg: a Tensor of shape (N, C), the average over all past frames + in batch. + + Returns: + A tuple of 2 tensors: + - output, a Tensor of shape (T, N, C). + - updated cached_avg, a Tensor of shape (N, C). + """ + x = x.cumsum(dim=0) # (T, N, C) + x = x + (cached_avg * cached_len.unsqueeze(1)).unsqueeze(0) + # Cumulated numbers of frames from start + cum_mask = torch.arange(1, x.size(0) + 1, device=x.device) + cum_mask = cum_mask.unsqueeze(1) + cached_len.unsqueeze(0) # (T, N) + pooling_mask = (1.0 / cum_mask).unsqueeze(2) + # now pooling_mask: (T, N, 1) + x = x * pooling_mask # (T, N, C) + + cached_len = cached_len + x.size(0) + cached_avg = x[-1] + + x = self.proj(x) + return x, cached_len, cached_avg + + +class FeedforwardModule(nn.Module): + """Feedforward module in Zipformer model.""" + + def __init__(self, d_model: int, feedforward_dim: int, dropout: float): + super(FeedforwardModule, self).__init__() + self.in_proj = nn.Linear(d_model, feedforward_dim) + self.balancer = ActivationBalancer( + feedforward_dim, channel_dim=-1, max_abs=10.0, min_prob=0.25 + ) + self.activation = DoubleSwish() + self.dropout = nn.Dropout(dropout) + self.out_proj = ScaledLinear(feedforward_dim, d_model, initial_scale=0.01) + + def forward(self, x: Tensor): + x = self.in_proj(x) + x = self.balancer(x) + x = self.activation(x) + x = self.dropout(x) + x = self.out_proj(x) + return x + + +class ConvolutionModule(nn.Module): + """ConvolutionModule in Zipformer model. + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py + + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernerl size of conv layers. + bias (bool): Whether to use bias in conv layers (default=True). + + """ + + def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: + """Construct an ConvolutionModule object.""" + super(ConvolutionModule, self).__init__() + # kernerl_size should be a odd number for 'SAME' padding + assert (kernel_size - 1) % 2 == 0, kernel_size + + self.pointwise_conv1 = nn.Conv1d( + channels, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + + # after pointwise_conv1 we put x through a gated linear unit (nn.functional.glu). + # For most layers the normal rms value of channels of x seems to be in the range 1 to 4, + # but sometimes, for some reason, for layer 0 the rms ends up being very large, + # between 50 and 100 for different channels. This will cause very peaky and + # sparse derivatives for the sigmoid gating function, which will tend to make + # the loss function not learn effectively. (for most layers the average absolute values + # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion, + # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different + # layers, which likely breaks down as 0.5 for the "linear" half and + # 0.2 to 0.3 for the part that goes into the sigmoid. The idea is that if we + # constrain the rms values to a reasonable range via a constraint of max_abs=10.0, + # it will be in a better position to start learning something, i.e. to latch onto + # the correct range. + self.deriv_balancer1 = ActivationBalancer( + 2 * channels, + channel_dim=1, + max_abs=10.0, + min_positive=0.05, + max_positive=1.0, + ) + + # Will pad cached left context + self.lorder = kernel_size - 1 + self.depthwise_conv = nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + padding=0, + groups=channels, + bias=bias, + ) + + self.deriv_balancer2 = ActivationBalancer( + channels, + channel_dim=1, + min_positive=0.05, + max_positive=1.0, + max_abs=20.0, + ) + + self.activation = DoubleSwish() + + self.pointwise_conv2 = ScaledConv1d( + channels, + channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + initial_scale=0.05, + ) + + def forward( + self, + x: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """Compute convolution module. + + Args: + x: Input tensor (#time, batch, channels). + src_key_padding_mask: the mask for the src keys per batch (optional): + (batch, #time), contains bool in masked positions. + + Returns: + - Output tensor (#time, batch, channels). + """ + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). + + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channels, time) + + x = self.deriv_balancer1(x) + x = nn.functional.glu(x, dim=1) # (batch, channels, time) + + if src_key_padding_mask is not None: + x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) + + # 1D Depthwise Conv + # Make depthwise_conv causal by + # manualy padding self.lorder zeros to the left + x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) + x = self.depthwise_conv(x) + + x = self.deriv_balancer2(x) + x = self.activation(x) + + x = self.pointwise_conv2(x) # (batch, channel, time) + + return x.permute(2, 0, 1) + + def streaming_forward( + self, + x: Tensor, + cache: Tensor, + ) -> Tuple[Tensor, Tensor]: + """Compute convolution module. + + Args: + x: Input tensor (#time, batch, channels). + src_key_padding_mask: the mask for the src keys per batch: + (batch, #time), contains bool in masked positions. + cache: Cached left context for depthwise_conv, with shape of + (batch, channels, #kernel_size-1). Only used in real streaming decoding. + + Returns: + A tuple of 2 tensors: + - Output tensor (#time, batch, channels). + - New cached left context, with shape of (batch, channels, #kernel_size-1). + """ + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). + + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channels, time) + + x = self.deriv_balancer1(x) + x = nn.functional.glu(x, dim=1) # (batch, channels, time) + + # 1D Depthwise Conv + assert cache.shape == (x.size(0), x.size(1), self.lorder), ( + cache.shape, + (x.size(0), x.size(1), self.lorder), + ) + x = torch.cat([cache, x], dim=2) + # Update cache + cache = x[:, :, -self.lorder :] + x = self.depthwise_conv(x) + + x = self.deriv_balancer2(x) + x = self.activation(x) + + x = self.pointwise_conv2(x) # (batch, channel, time) + + return x.permute(2, 0, 1), cache + + +class Conv2dSubsampling(nn.Module): + """Convolutional 2D subsampling (to 1/4 length). + + Convert an input of shape (N, T, idim) to an output + with shape (N, T', odim), where + T' = (T-3)//2 - 2 == (T-7)//2 + + It is based on + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + layer1_channels: int = 8, + layer2_channels: int = 32, + layer3_channels: int = 128, + dropout: float = 0.1, + ) -> None: + """ + Args: + in_channels: + Number of channels in. The input shape is (N, T, in_channels). + Caution: It requires: T >=7, in_channels >=7 + out_channels + Output dim. The output shape is (N, (T-7)//2, out_channels) + layer1_channels: + Number of channels in layer1 + layer2_channels: + Number of channels in layer2 + layer3_channels: + Number of channels in layer3 + """ + assert in_channels >= 7, in_channels + super().__init__() + + self.conv = nn.Sequential( + nn.Conv2d( + in_channels=1, + out_channels=layer1_channels, + kernel_size=3, + padding=(0, 1), # (time, freq) + ), + ActivationBalancer(layer1_channels, channel_dim=1), + DoubleSwish(), + nn.Conv2d( + in_channels=layer1_channels, + out_channels=layer2_channels, + kernel_size=3, + stride=2, + padding=0, + ), + ActivationBalancer(layer2_channels, channel_dim=1), + DoubleSwish(), + nn.Conv2d( + in_channels=layer2_channels, + out_channels=layer3_channels, + kernel_size=3, + stride=(1, 2), # (time, freq) + ), + ActivationBalancer(layer3_channels, channel_dim=1), + DoubleSwish(), + ) + out_height = (((in_channels - 1) // 2) - 1) // 2 + self.out = ScaledLinear(out_height * layer3_channels, out_channels) + self.dropout = nn.Dropout(dropout) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Subsample x. + + Args: + x: + Its shape is (N, T, idim). + + Returns: + Return a tensor of shape (N, (T-7)//2, odim) + """ + # On entry, x is (N, T, idim) + x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) + x = self.conv(x) + # Now x is of shape (N, odim, (T-7)//2, ((idim-1)//2 - 1)//2) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).reshape(b, t, c * f)) + # Now x is of shape (N, (T-7)//2, odim) + x = self.dropout(x) + return x + + +def _test_zipformer_main(): + feature_dim = 50 + batch_size = 5 + seq_len = 47 + feature_dim = 50 + # Just make sure the forward pass runs. + + c = Zipformer( + num_features=feature_dim, + encoder_dims=(64, 96), + encoder_unmasked_dims=(48, 64), + nhead=(4, 4), + decode_chunk_size=4, + ) + # Just make sure the forward pass runs. + f = c( + torch.randn(batch_size, seq_len, feature_dim), + torch.full((batch_size,), seq_len, dtype=torch.int64), + ) + assert ((seq_len - 7) // 2 + 1) // 2 == f[0].shape[1], (seq_len, f.shape[1]) + f[0].sum().backward() + c.eval() + f = c( + torch.randn(batch_size, seq_len, feature_dim), + torch.full((batch_size,), seq_len, dtype=torch.int64), + ) + f # to remove flake8 warnings + + +def _test_conv2d_subsampling(): + num_features = 80 + encoder_dims = 384 + dropout = 0.1 + encoder_embed = Conv2dSubsampling(num_features, encoder_dims, dropout=dropout) + for i in range(20, 40): + x = torch.rand(2, i, num_features) + y = encoder_embed(x) + assert (x.shape[1] - 7) // 2 == y.shape[1], (x.shape[1], y.shape[1]) + + +def _test_pooling_module(): + N, S, C = 2, 12, 32 + chunk_len = 4 + m = PoolingModule(d_model=C) + + # test chunk-wise forward with padding_mask + x = torch.randn(S, N, C) + y = m(x) + cached_len = torch.zeros(N, dtype=torch.int32) + cached_avg = torch.zeros(N, C) + for i in range(S // chunk_len): + start = i * chunk_len + end = start + chunk_len + x_chunk = x[start:end] + y_chunk, cached_len, cached_avg = m.streaming_forward( + x_chunk, + cached_len=cached_len, + cached_avg=cached_avg, + ) + assert torch.allclose(y_chunk, y[start:end]), (y_chunk, y[start:end]) + + +def _test_state_stack_unstack(): + m = Zipformer( + num_features=80, + encoder_dims=(64, 96), + encoder_unmasked_dims=(48, 64), + nhead=(4, 4), + zipformer_downsampling_factors=(4, 8), + num_left_chunks=2, + decode_chunk_size=8, + ) + s1 = m.get_init_state() + s2 = m.get_init_state() + states = stack_states([s1, s2]) + new_s1, new_s2 = unstack_states(states) + for i in range(m.num_encoders * 7): + for x, y in zip(s1[i], new_s1[i]): + assert torch.equal(x, y) + for x, y in zip(s2[i], new_s2[i]): + assert torch.equal(x, y) + + +if __name__ == "__main__": + logging.getLogger().setLevel(logging.INFO) + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + _test_zipformer_main() + _test_conv2d_subsampling() + _test_pooling_module() + _test_state_stack_unstack() From a54b748a02550fe523b05c6ccc5845170ba99114 Mon Sep 17 00:00:00 2001 From: behnamasefi Date: Fri, 30 Dec 2022 03:06:09 +0000 Subject: [PATCH 066/174] check for utterance len (#795) Co-authored-by: behnam --- .../pruned_transducer_stateless7_ctc/train.py | 28 ++++++++++++++++++- .../train.py | 28 ++++++++++++++++++- 2 files changed, 54 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py index 162ad8412..5a05e1836 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py @@ -1086,7 +1086,33 @@ def run(rank, world_size, args): # You should use ../local/display_manifest_statistics.py to get # an utterance duration distribution for your dataset to select # the threshold - return 1.0 <= c.duration <= 20.0 + if c.duration < 1.0 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True train_cuts = train_cuts.filter(remove_short_and_long_utt) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py index 63e9d6e90..522ecc974 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py @@ -1077,7 +1077,33 @@ def run(rank, world_size, args): # You should use ../local/display_manifest_statistics.py to get # an utterance duration distribution for your dataset to select # the threshold - return 1.0 <= c.duration <= 20.0 + if c.duration < 1.0 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True train_cuts = train_cuts.filter(remove_short_and_long_utt) From 67ae5fdf2bf2b09d2ce9e5acb7dab12b2d2fc441 Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Fri, 30 Dec 2022 15:21:18 +0800 Subject: [PATCH 067/174] Doc streaming zipformer (#798) * add doc for streaming_zipformer * update README.md --- .../Streaming-ASR/librispeech/index.rst | 2 + .../librispeech/zipformer_transducer.rst | 654 ++++++++++++++++++ .../README.md | 3 + egs/librispeech/ASR/zipformer_mmi/README.md | 2 +- 4 files changed, 660 insertions(+), 1 deletion(-) create mode 100644 docs/source/recipes/Streaming-ASR/librispeech/zipformer_transducer.rst create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/README.md diff --git a/docs/source/recipes/Streaming-ASR/librispeech/index.rst b/docs/source/recipes/Streaming-ASR/librispeech/index.rst index 546ce168b..d52e08058 100644 --- a/docs/source/recipes/Streaming-ASR/librispeech/index.rst +++ b/docs/source/recipes/Streaming-ASR/librispeech/index.rst @@ -7,3 +7,5 @@ LibriSpeech pruned_transducer_stateless lstm_pruned_stateless_transducer + + zipformer_transducer diff --git a/docs/source/recipes/Streaming-ASR/librispeech/zipformer_transducer.rst b/docs/source/recipes/Streaming-ASR/librispeech/zipformer_transducer.rst new file mode 100644 index 000000000..f0e8961d7 --- /dev/null +++ b/docs/source/recipes/Streaming-ASR/librispeech/zipformer_transducer.rst @@ -0,0 +1,654 @@ +Zipformer Transducer +==================== + +This tutorial shows you how to run a **streaming** zipformer transducer model +with the `LibriSpeech `_ dataset. + +.. Note:: + + The tutorial is suitable for `pruned_transducer_stateless7_streaming `_, + +.. HINT:: + + We assume you have read the page :ref:`install icefall` and have setup + the environment for ``icefall``. + +.. HINT:: + + We recommend you to use a GPU or several GPUs to run this recipe. + +.. hint:: + + Please scroll down to the bottom of this page to find download links + for pretrained models if you don't want to train a model from scratch. + + +We use pruned RNN-T to compute the loss. + +.. note:: + + You can find the paper about pruned RNN-T at the following address: + + ``_ + +The transducer model consists of 3 parts: + + - Encoder, a.k.a, the transcription network. We use a Zipformer model (proposed by Daniel Povey) + - Decoder, a.k.a, the prediction network. We use a stateless model consisting of + ``nn.Embedding`` and ``nn.Conv1d`` + - Joiner, a.k.a, the joint network. + +.. caution:: + + Contrary to the conventional RNN-T models, we use a stateless decoder. + That is, it has no recurrent connections. + + +Data preparation +---------------- + +.. hint:: + + The data preparation is the same as other recipes on LibriSpeech dataset, + if you have finished this step, you can skip to ``Training`` directly. + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./prepare.sh + +The script ``./prepare.sh`` handles the data preparation for you, **automagically**. +All you need to do is to run it. + +The data preparation contains several stages, you can use the following two +options: + + - ``--stage`` + - ``--stop-stage`` + +to control which stage(s) should be run. By default, all stages are executed. + + +For example, + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./prepare.sh --stage 0 --stop-stage 0 + +means to run only stage 0. + +To run stage 2 to stage 5, use: + +.. code-block:: bash + + $ ./prepare.sh --stage 2 --stop-stage 5 + +.. HINT:: + + If you have pre-downloaded the `LibriSpeech `_ + dataset and the `musan `_ dataset, say, + they are saved in ``/tmp/LibriSpeech`` and ``/tmp/musan``, you can modify + the ``dl_dir`` variable in ``./prepare.sh`` to point to ``/tmp`` so that + ``./prepare.sh`` won't re-download them. + +.. NOTE:: + + All generated files by ``./prepare.sh``, e.g., features, lexicon, etc, + are saved in ``./data`` directory. + +We provide the following YouTube video showing how to run ``./prepare.sh``. + +.. note:: + + To get the latest news of `next-gen Kaldi `_, please subscribe + the following YouTube channel by `Nadira Povey `_: + + ``_ + +.. youtube:: ofEIoJL-mGM + + +Training +-------- + +Configurable options +~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./pruned_transducer_stateless7_streaming/train.py --help + + +shows you the training options that can be passed from the commandline. +The following options are used quite often: + + - ``--exp-dir`` + + The directory to save checkpoints, training logs and tensorboard. + + - ``--full-libri`` + + If it's True, the training part uses all the training data, i.e., + 960 hours. Otherwise, the training part uses only the subset + ``train-clean-100``, which has 100 hours of training data. + + .. CAUTION:: + The training set is perturbed by speed with two factors: 0.9 and 1.1. + If ``--full-libri`` is True, each epoch actually processes + ``3x960 == 2880`` hours of data. + + - ``--num-epochs`` + + It is the number of epochs to train. For instance, + ``./pruned_transducer_stateless7_streaming/train.py --num-epochs 30`` trains for 30 epochs + and generates ``epoch-1.pt``, ``epoch-2.pt``, ..., ``epoch-30.pt`` + in the folder ``./pruned_transducer_stateless7_streaming/exp``. + + - ``--start-epoch`` + + It's used to resume training. + ``./pruned_transducer_stateless7_streaming/train.py --start-epoch 10`` loads the + checkpoint ``./pruned_transducer_stateless7_streaming/exp/epoch-9.pt`` and starts + training from epoch 10, based on the state from epoch 9. + + - ``--world-size`` + + It is used for multi-GPU single-machine DDP training. + + - (a) If it is 1, then no DDP training is used. + + - (b) If it is 2, then GPU 0 and GPU 1 are used for DDP training. + + The following shows some use cases with it. + + **Use case 1**: You have 4 GPUs, but you only want to use GPU 0 and + GPU 2 for training. You can do the following: + + .. code-block:: bash + + $ cd egs/librispeech/ASR + $ export CUDA_VISIBLE_DEVICES="0,2" + $ ./pruned_transducer_stateless7_streaming/train.py --world-size 2 + + **Use case 2**: You have 4 GPUs and you want to use all of them + for training. You can do the following: + + .. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./pruned_transducer_stateless7_streaming/train.py --world-size 4 + + **Use case 3**: You have 4 GPUs but you only want to use GPU 3 + for training. You can do the following: + + .. code-block:: bash + + $ cd egs/librispeech/ASR + $ export CUDA_VISIBLE_DEVICES="3" + $ ./pruned_transducer_stateless7_streaming/train.py --world-size 1 + + .. caution:: + + Only multi-GPU single-machine DDP training is implemented at present. + Multi-GPU multi-machine DDP training will be added later. + + - ``--max-duration`` + + It specifies the number of seconds over all utterances in a + batch, before **padding**. + If you encounter CUDA OOM, please reduce it. + + .. HINT:: + + Due to padding, the number of seconds of all utterances in a + batch will usually be larger than ``--max-duration``. + + A larger value for ``--max-duration`` may cause OOM during training, + while a smaller value may increase the training time. You have to + tune it. + + - ``--use-fp16`` + + If it is True, the model will train with half precision, from our experiment + results, by using half precision you can train with two times larger ``--max-duration`` + so as to get almost 2X speed up. + + We recommend using ``--use-fp16 True``. + + - ``--short-chunk-size`` + + When training a streaming attention model with chunk masking, the chunk size + would be either max sequence length of current batch or uniformly sampled from + (1, short_chunk_size). The default value is 50, you don't have to change it most of the time. + + - ``--num-left-chunks`` + + It indicates how many left context (in chunks) that can be seen when calculating attention. + The default value is 4, you don't have to change it most of the time. + + + - ``--decode-chunk-len`` + + The chunk size for decoding (in frames before subsampling). It is used for validation. + The default value is 32 (i.e., 320ms). + + +Pre-configured options +~~~~~~~~~~~~~~~~~~~~~~ + +There are some training options, e.g., number of encoder layers, +encoder dimension, decoder dimension, number of warmup steps etc, +that are not passed from the commandline. +They are pre-configured by the function ``get_params()`` in +`pruned_transducer_stateless7_streaming/train.py `_ + +You don't need to change these pre-configured parameters. If you really need to change +them, please modify ``./pruned_transducer_stateless7_streaming/train.py`` directly. + + +Training logs +~~~~~~~~~~~~~ + +Training logs and checkpoints are saved in ``--exp-dir`` (e.g. ``pruned_transducer_stateless7_streaming/exp``. +You will find the following files in that directory: + + - ``epoch-1.pt``, ``epoch-2.pt``, ... + + These are checkpoint files saved at the end of each epoch, containing model + ``state_dict`` and optimizer ``state_dict``. + To resume training from some checkpoint, say ``epoch-10.pt``, you can use: + + .. code-block:: bash + + $ ./pruned_transducer_stateless7_streaming/train.py --start-epoch 11 + + - ``checkpoint-436000.pt``, ``checkpoint-438000.pt``, ... + + These are checkpoint files saved every ``--save-every-n`` batches, + containing model ``state_dict`` and optimizer ``state_dict``. + To resume training from some checkpoint, say ``checkpoint-436000``, you can use: + + .. code-block:: bash + + $ ./pruned_transducer_stateless7_streaming/train.py --start-batch 436000 + + - ``tensorboard/`` + + This folder contains tensorBoard logs. Training loss, validation loss, learning + rate, etc, are recorded in these logs. You can visualize them by: + + .. code-block:: bash + + $ cd pruned_transducer_stateless7_streaming/exp/tensorboard + $ tensorboard dev upload --logdir . --description "pruned transducer training for LibriSpeech with icefall" + + .. hint:: + + If you don't have access to google, you can use the following command + to view the tensorboard log locally: + + .. code-block:: bash + + cd pruned_transducer_stateless7_streaming/exp/tensorboard + tensorboard --logdir . --port 6008 + + It will print the following message: + + .. code-block:: + + Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all + TensorBoard 2.8.0 at http://localhost:6008/ (Press CTRL+C to quit) + + Now start your browser and go to ``_ to view the tensorboard + logs. + + + - ``log/log-train-xxxx`` + + It is the detailed training log in text format, same as the one + you saw printed to the console during training. + +Usage example +~~~~~~~~~~~~~ + +You can use the following command to start the training using 4 GPUs: + +.. code-block:: bash + + export CUDA_VISIBLE_DEVICES="0,1,2,3" + ./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --full-libri 1 \ + --max-duration 550 + +Decoding +-------- + +The decoding part uses checkpoints saved by the training part, so you have +to run the training part first. + +.. hint:: + + There are two kinds of checkpoints: + + - (1) ``epoch-1.pt``, ``epoch-2.pt``, ..., which are saved at the end + of each epoch. You can pass ``--epoch`` to + ``pruned_transducer_stateless7_streaming/decode.py`` to use them. + + - (2) ``checkpoints-436000.pt``, ``epoch-438000.pt``, ..., which are saved + every ``--save-every-n`` batches. You can pass ``--iter`` to + ``pruned_transducer_stateless7_streaming/decode.py`` to use them. + + We suggest that you try both types of checkpoints and choose the one + that produces the lowest WERs. + +.. tip:: + + To decode a streaming model, you can use either ``simulate streaming decoding`` in ``decode.py`` or + ``real chunk-wise streaming decoding`` in ``streaming_decode.py``. The difference between ``decode.py`` and + ``streaming_decode.py`` is that, ``decode.py`` processes the whole acoustic frames at one time with masking (i.e. same as training), + but ``streaming_decode.py`` processes the acoustic frames chunk by chunk. + +.. NOTE:: + + ``simulate streaming decoding`` in ``decode.py`` and ``real chunk-size streaming decoding`` in ``streaming_decode.py`` should + produce almost the same results given the same ``--decode-chunk-len``. + + +Simulate streaming decoding +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./pruned_transducer_stateless7_streaming/decode.py --help + +shows the options for decoding. +The following options are important for streaming models: + + ``--decode-chunk-len`` + + It is same as in ``train.py``, which specifies the chunk size for decoding (in frames before subsampling). + The default value is 32 (i.e., 320ms). + + +The following shows two examples (for the two types of checkpoints): + +.. code-block:: bash + + for m in greedy_search fast_beam_search modified_beam_search; do + for epoch in 30; do + for avg in 12 11 10 9 8; do + ./pruned_transducer_stateless7_streaming/decode.py \ + --epoch $epoch \ + --avg $avg \ + --decode-chunk-len 32 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decoding-method $m + done + done + done + + +.. code-block:: bash + + for m in greedy_search fast_beam_search modified_beam_search; do + for iter in 474000; do + for avg in 8 10 12 14 16 18; do + ./pruned_transducer_stateless7_streaming/decode.py \ + --iter $iter \ + --avg $avg \ + --decode-chunk-len 32 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decoding-method $m + done + done + done + + +Real streaming decoding +~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./pruned_transducer_stateless7_streaming/streaming_decode.py --help + +shows the options for decoding. +The following options are important for streaming models: + + ``--decode-chunk-len`` + + It is same as in ``train.py``, which specifies the chunk size for decoding (in frames before subsampling). + The default value is 32 (i.e., 320ms). + For ``real streaming decoding``, we will process ``decode-chunk-len`` acoustic frames at each time. + + ``--num-decode-streams`` + + The number of decoding streams that can be run in parallel (very similar to the ``bath size``). + For ``real streaming decoding``, the batches will be packed dynamically, for example, if the + ``num-decode-streams`` equals to 10, then, sequence 1 to 10 will be decoded at first, after a while, + suppose sequence 1 and 2 are done, so, sequence 3 to 12 will be processed parallelly in a batch. + + +The following shows two examples (for the two types of checkpoints): + +.. code-block:: bash + + for m in greedy_search fast_beam_search modified_beam_search; do + for epoch in 30; do + for avg in 12 11 10 9 8; do + ./pruned_transducer_stateless7_streaming/decode.py \ + --epoch $epoch \ + --avg $avg \ + --decode-chunk-len 32 \ + --num-decode-streams 100 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --decoding-method $m + done + done + done + + +.. code-block:: bash + + for m in greedy_search fast_beam_search modified_beam_search; do + for iter in 474000; do + for avg in 8 10 12 14 16 18; do + ./pruned_transducer_stateless7_streaming/decode.py \ + --iter $iter \ + --avg $avg \ + --decode-chunk-len 16 \ + --num-decode-streams 100 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --decoding-method $m + done + done + done + + +.. tip:: + + Supporting decoding methods are as follows: + + - ``greedy_search`` : It takes the symbol with largest posterior probability + of each frame as the decoding result. + + - ``beam_search`` : It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf and + `espnet/nets/beam_search_transducer.py `_ + is used as a reference. Basicly, it keeps topk states for each frame, and expands the kept states with their own contexts to + next frame. + + - ``modified_beam_search`` : It implements the same algorithm as ``beam_search`` above, but it + runs in batch mode with ``--max-sym-per-frame=1`` being hardcoded. + + - ``fast_beam_search`` : It implements graph composition between the output ``log_probs`` and + given ``FSAs``. It is hard to describe the details in several lines of texts, you can read + our paper in https://arxiv.org/pdf/2211.00484.pdf or our `rnnt decode code in k2 `_. ``fast_beam_search`` can decode with ``FSAs`` on GPU efficiently. + + - ``fast_beam_search_LG`` : The same as ``fast_beam_search`` above, ``fast_beam_search`` uses + an trivial graph that has only one state, while ``fast_beam_search_LG`` uses an LG graph + (with N-gram LM). + + - ``fast_beam_search_nbest`` : It produces the decoding results as follows: + + - (1) Use ``fast_beam_search`` to get a lattice + - (2) Select ``num_paths`` paths from the lattice using ``k2.random_paths()`` + - (3) Unique the selected paths + - (4) Intersect the selected paths with the lattice and compute the + shortest path from the intersection result + - (5) The path with the largest score is used as the decoding output. + + - ``fast_beam_search_nbest_LG`` : It implements same logic as ``fast_beam_search_nbest``, the + only difference is that it uses ``fast_beam_search_LG`` to generate the lattice. + +.. NOTE:: + + The supporting decoding methods in ``streaming_decode.py`` might be less than that in ``decode.py``, if needed, + you can implement them by yourself or file a issue in `icefall `_ . + + +Export Model +------------ + +Currently it supports exporting checkpoints from ``pruned_transducer_stateless7_streaming/exp`` in the following ways. + +Export ``model.state_dict()`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Checkpoints saved by ``pruned_transducer_stateless7_streaming/train.py`` also include +``optimizer.state_dict()``. It is useful for resuming training. But after training, +we are interested only in ``model.state_dict()``. You can use the following +command to extract ``model.state_dict()``. + +.. code-block:: bash + + # Assume that --epoch 30 --avg 9 produces the smallest WER + # (You can get such information after running ./pruned_transducer_stateless7_streaming/decode.py) + + epoch=30 + avg=9 + + ./pruned_transducer_stateless7_streaming/export.py \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch $epoch \ + --avg $avg \ + --use-averaged-model=True \ + --decode-chunk-len 32 + +It will generate a file ``./pruned_transducer_stateless7_streaming/exp/pretrained.pt``. + +.. hint:: + + To use the generated ``pretrained.pt`` for ``pruned_transducer_stateless7_streaming/decode.py``, + you can run: + + .. code-block:: bash + + cd pruned_transducer_stateless7_streaming/exp + ln -s pretrained.pt epoch-999.pt + + And then pass ``--epoch 999 --avg 1 --use-averaged-model 0`` to + ``./pruned_transducer_stateless7_streaming/decode.py``. + +To use the exported model with ``./pruned_transducer_stateless7_streaming/pretrained.py``, you +can run: + +.. code-block:: bash + + ./pruned_transducer_stateless7_streaming/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method greedy_search \ + --decode-chunk-len 32 \ + /path/to/foo.wav \ + /path/to/bar.wav + + +Export model using ``torch.jit.script()`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: bash + + ./pruned_transducer_stateless7_streaming/export.py \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 9 \ + --decode-chunk-len 32 \ + --jit 1 + +.. caution:: + + ``--decode-chunk-len`` is required to export a ScriptModule. + +It will generate a file ``cpu_jit.pt`` in the given ``exp_dir``. You can later +load it by ``torch.jit.load("cpu_jit.pt")``. + +Note ``cpu`` in the name ``cpu_jit.pt`` means the parameters when loaded into Python +are on CPU. You can use ``to("cuda")`` to move them to a CUDA device. + +Export model using ``torch.jit.trace()`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: bash + + epoch=30 + avg=9 + + ./pruned_transducer_stateless7_streaming/jit_trace_export.py \ + --bpe-model data/lang_bpe_500/bpe.model \ + --use-averaged-model=True \ + --decode-chunk-len 32 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --epoch $epoch \ + --avg $avg + +.. caution:: + + ``--decode-chunk-len`` is required to export a ScriptModule. + +It will generate 3 files: + + - ``./pruned_transducer_stateless7_streaming/exp/encoder_jit_trace.pt`` + - ``./pruned_transducer_stateless7_streaming/exp/decoder_jit_trace.pt`` + - ``./pruned_transducer_stateless7_streaming/exp/joiner_jit_trace.pt`` + +To use the generated files with ``./pruned_transducer_stateless7_streaming/jit_trace_pretrained.py``: + +.. code-block:: bash + + ./pruned_transducer_stateless7_streaming/jit_trace_pretrained.py \ + --encoder-model-filename ./pruned_transducer_stateless7_streaming/exp/encoder_jit_trace.pt \ + --decoder-model-filename ./pruned_transducer_stateless7_streaming/exp/decoder_jit_trace.pt \ + --joiner-model-filename ./pruned_transducer_stateless7_streaming/exp/joiner_jit_trace.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --decode-chunk-len 32 \ + /path/to/foo.wav + + +Download pretrained models +-------------------------- + +If you don't want to train from scratch, you can download the pretrained models +by visiting the following links: + + - `pruned_transducer_stateless7_streaming `_ + + See ``_ + for the details of the above pretrained models + +Deploy with Sherpa +------------------ + +Please see ``_ +for how to deploy the models in ``sherpa``. diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/README.md b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/README.md new file mode 100644 index 000000000..6e461e196 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/README.md @@ -0,0 +1,3 @@ +This recipe implements Streaming Zipformer-Transducer model. + +See https://k2-fsa.github.io/icefall/recipes/Streaming-ASR/librispeech/zipformer_transducer.html for detailed tutorials. diff --git a/egs/librispeech/ASR/zipformer_mmi/README.md b/egs/librispeech/ASR/zipformer_mmi/README.md index 8ca844180..e9a37a52a 100644 --- a/egs/librispeech/ASR/zipformer_mmi/README.md +++ b/egs/librispeech/ASR/zipformer_mmi/README.md @@ -1,6 +1,6 @@ This recipe implements Zipformer-MMI model. -See https://k2-fsa.github.io/icefall/recipes/librispeech/zipformer_mmi.html for detailed tutorials. +See https://k2-fsa.github.io/icefall/recipes/Non-streaming-ASR/librispeech/zipformer_mmi.html for detailed tutorials. It uses **CTC loss for warm-up** and then switches to MMI loss during training. From 2fd970b6821d47dacb2e6513321520db21fff67b Mon Sep 17 00:00:00 2001 From: Daniil Date: Sun, 1 Jan 2023 19:08:32 -0500 Subject: [PATCH 068/174] not removing result_dir in tedlium conformer ctc2 + add lm stem to compile_hlg_using_openfst.py + add MASTER_ADDR to be prvided to setup_dist (#801) --- .../ASR/local/compile_hlg_using_openfst.py | 19 ++++++++++++++----- egs/tedlium3/ASR/conformer_ctc2/decode.py | 7 ++----- icefall/dist.py | 8 ++++++-- 3 files changed, 22 insertions(+), 12 deletions(-) diff --git a/egs/librispeech/ASR/local/compile_hlg_using_openfst.py b/egs/librispeech/ASR/local/compile_hlg_using_openfst.py index 9e5e3df69..15fc47ef1 100755 --- a/egs/librispeech/ASR/local/compile_hlg_using_openfst.py +++ b/egs/librispeech/ASR/local/compile_hlg_using_openfst.py @@ -24,7 +24,7 @@ This script takes as input lang_dir and generates HLG from Caution: We use a lexicon that contains disambiguation symbols - - G, the LM, built from data/lm/G_3_gram.fst.txt + - G, the LM, built from data/lm/G_n_gram.fst.txt The generated HLG is saved in $lang_dir/HLG_fst.pt @@ -46,6 +46,13 @@ from icefall.lexicon import Lexicon def get_args(): parser = argparse.ArgumentParser() + parser.add_argument( + "--lm", + type=str, + default="G_3_gram", + help="""Stem name for LM used in HLG compiling. + """, + ) parser.add_argument( "--lang-dir", type=str, @@ -56,11 +63,13 @@ def get_args(): return parser.parse_args() -def compile_HLG(lang_dir: str) -> kaldifst.StdVectorFst: +def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> kaldifst.StdVectorFst: """ Args: lang_dir: The language directory, e.g., data/lang_phone or data/lang_bpe_5000. + lm: + The language stem base name. Return: An FST representing HLG. @@ -71,8 +80,8 @@ def compile_HLG(lang_dir: str) -> kaldifst.StdVectorFst: kaldifst.arcsort(L, sort_type="olabel") logging.info(f"L: #states {L.num_states}") - G_filename_txt = "data/lm/G_3_gram.fst.txt" - G_filename_binary = "data/lm/G_3_gram.fst" + G_filename_txt = f"data/lm/{lm}.fst.txt" + G_filename_binary = f"data/lm/{lm}.fst" if Path(G_filename_binary).is_file(): logging.info(f"Loading {G_filename_binary}") G = kaldifst.StdVectorFst.read(G_filename_binary) @@ -171,7 +180,7 @@ def main(): logging.info(f"{filename} already exists - skipping") return - HLG = compile_HLG(lang_dir) + HLG = compile_HLG(lang_dir, args.lm) logging.info(f"Saving HLG to {filename}") torch.save(HLG.as_dict(), filename) diff --git a/egs/tedlium3/ASR/conformer_ctc2/decode.py b/egs/tedlium3/ASR/conformer_ctc2/decode.py index ce4dcd142..28d39de70 100755 --- a/egs/tedlium3/ASR/conformer_ctc2/decode.py +++ b/egs/tedlium3/ASR/conformer_ctc2/decode.py @@ -20,7 +20,6 @@ import argparse import logging -import shutil from collections import defaultdict from pathlib import Path from typing import Dict, List, Optional, Tuple @@ -183,7 +182,7 @@ def get_parser() -> argparse.ArgumentParser: parser.add_argument( "--result-dir", type=str, - default="conformer_ctc2/exp", + default="conformer_ctc2/exp/results", help="Directory to store results.", ) @@ -635,9 +634,7 @@ def main() -> None: args.lm_path = Path(args.lm_path) args.result_dir = Path(args.result_dir) - if args.result_dir.is_dir(): - shutil.rmtree(args.result_dir) - args.result_dir.mkdir() + args.result_dir.mkdir(exist_ok=True) params = get_params() params.update(vars(args)) diff --git a/icefall/dist.py b/icefall/dist.py index 9df1c5bd1..672948623 100644 --- a/icefall/dist.py +++ b/icefall/dist.py @@ -21,12 +21,16 @@ import torch from torch import distributed as dist -def setup_dist(rank, world_size, master_port=None, use_ddp_launch=False): +def setup_dist( + rank, world_size, master_addr=None, master_port=None, use_ddp_launch=False +): """ rank and world_size are used only if use_ddp_launch is False. """ if "MASTER_ADDR" not in os.environ: - os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_ADDR"] = ( + "localhost" if master_addr is None else str(master_addr) + ) if "MASTER_PORT" not in os.environ: os.environ["MASTER_PORT"] = "12354" if master_port is None else str(master_port) From 80cce141b4235c9bf0d6a903f202e1217d56c18b Mon Sep 17 00:00:00 2001 From: marcoyang1998 <45973641+marcoyang1998@users.noreply.github.com> Date: Tue, 3 Jan 2023 15:40:53 +0800 Subject: [PATCH 069/174] Full libri fix manifest (#804) * modify the name of the directory of vq manifest * fix missing manifest in full libri training --- .../ASR/distillation_with_hubert.sh | 22 +++++++++++++++---- .../pruned_transducer_stateless6/vq_utils.py | 15 ++++++++++--- 2 files changed, 30 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/distillation_with_hubert.sh b/egs/librispeech/ASR/distillation_with_hubert.sh index d5d3008aa..a38cf590c 100755 --- a/egs/librispeech/ASR/distillation_with_hubert.sh +++ b/egs/librispeech/ASR/distillation_with_hubert.sh @@ -43,7 +43,7 @@ mkdir -p $exp_dir # full_libri can be "True" or "False" # "True" -> use full librispeech dataset for distillation # "False" -> use train-clean-100 subset for distillation -full_libri=False +full_libri=True # use_extracted_codebook can be "True" or "False" # "True" -> stage 0 and stage 1 would be skipped, @@ -145,8 +145,12 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then log "Currently we only uploaded codebook indexes from teacher model hubert_xtralarge_ll60k_finetune_ls960" exit 1 fi + # The codebook indexes to be downloaded are generated using the following setup: + embedding_layer=36 + num_codebooks=8 + mkdir -p $exp_dir/vq - codebook_dir=$exp_dir/vq/$teacher_model_id + codebook_dir=$exp_dir/vq/${teacher_model_id}_layer${embedding_layer}_cb${num_codebooks} mkdir -p codebook_dir codebook_download_dir=$exp_dir/download_codebook if [ -d $codebook_download_dir ]; then @@ -164,8 +168,9 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then git lfs install git clone https://huggingface.co/marcoyang/pruned_transducer_stateless6_hubert_xtralarge_ll60k_finetune_ls960 $codebook_download_dir - mkdir -p data/vq_fbank - mv $codebook_download_dir/*.jsonl.gz data/vq_fbank/ + vq_fbank=data/vq_fbank_layer${embedding_layer}_cb${num_codebooks}/ + mkdir -p $vq_fbank + mv $codebook_download_dir/*.jsonl.gz $vq_fbank mkdir -p $codebook_dir/splits4 mv $codebook_download_dir/*.h5 $codebook_dir/splits4/ log "Remove $codebook_download_dir" @@ -181,6 +186,15 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then --max-duration 100 \ --teacher-model-id $teacher_model_id \ --use-extracted-codebook $use_extracted_codebook + + if [ "$full_libri" == "True" ]; then + # Merge the 3 subsets and create a full one + rm ${vq_fbank}/librispeech_cuts_train-all-shuf.jsonl.gz + cat <(gunzip -c ${vq_fbank}/librispeech_cuts_train-clean-100.jsonl.gz) \ + <(gunzip -c ${vq_fbank}/librispeech_cuts_train-clean-360.jsonl.gz) \ + <(gunzip -c ${vq_fbank}/librispeech_cuts_train-other-500.jsonl.gz) | \ + shuf | gzip -c > ${vq_fbank}/librispeech_cuts_train-all-shuf.jsonl.gz + fi fi if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py index bf072d865..14ff86f23 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py @@ -68,7 +68,10 @@ class CodebookIndexExtractor: def init_dirs(self): # vq_dir is the root dir for quantization, containing: # training data, trained quantizer, and extracted codebook indexes - self.vq_dir = self.params.exp_dir / f"vq/{self.params.teacher_model_id}/" + self.vq_dir = ( + self.params.exp_dir + / f"vq/{self.params.teacher_model_id}_layer{self.params.embedding_layer}_cb{self.params.num_codebooks}/" + ) self.vq_dir.mkdir(parents=True, exist_ok=True) # manifest_dir contains: @@ -79,7 +82,10 @@ class CodebookIndexExtractor: # It's doesn't matter whether ori_manifest_dir is str or Path. # Set it to Path to be consistent. self.ori_manifest_dir = Path("./data/fbank/") - self.dst_manifest_dir = Path("./data/vq_fbank/") + self.dst_manifest_dir = Path( + f"./data/vq_fbank_layer" + + f"{self.params.embedding_layer}_cb{self.params.num_codebooks}/" + ) self.dst_manifest_dir.mkdir(parents=True, exist_ok=True) @@ -284,7 +290,10 @@ class CodebookIndexExtractor: Merge generated vq included manfiests and storage to self.dst_manifest_dir. """ for subset in self.params.subsets: - vq_manifests = f"{self.manifest_dir}/with_codebook_indexes-librispeech-cuts_train-{subset}*.jsonl.gz" + vq_manifests = ( + f"{self.manifest_dir}/" + + f"with_codebook_indexes-librispeech-cuts_train-{subset}*.jsonl.gz" + ) dst_vq_manifest = ( self.dst_manifest_dir / f"librispeech_cuts_train-{subset}-vq.jsonl.gz" ) From 0f26edfde96d48406f2f227ed87584c6a94f3f68 Mon Sep 17 00:00:00 2001 From: Yunusemre Date: Tue, 3 Jan 2023 08:59:44 +0000 Subject: [PATCH 070/174] Add Zipformer Onnx Support (#778) * add export script * add zipformer onnx pretrained script * add onnx zipformer test * fix style * add zipformer onnx to workflow * replace is_in_onnx_export with is_tracing * add github.event.label.name == 'onnx' * add is_tracing to necessary conditions * fix pooling_mask * add onnx_check * add onnx_check to scripts * add is_tracing to scaling.py --- ...pruned-transducer-stateless7-2022-11-11.sh | 30 ++ .../run-librispeech-2022-11-11-stateless7.yml | 2 +- .../pruned_transducer_stateless7/export.py | 267 +++++++++++- .../onnx_check.py | 286 +++++++++++++ .../onnx_pretrained.py | 388 ++++++++++++++++++ .../pruned_transducer_stateless7/scaling.py | 7 +- .../pruned_transducer_stateless7/test_onnx.py | 374 +++++++++++++++++ .../pruned_transducer_stateless7/zipformer.py | 57 ++- 8 files changed, 1383 insertions(+), 28 deletions(-) create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7/onnx_check.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7/onnx_pretrained.py create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless7/test_onnx.py diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh index 8e485d2e6..999841b80 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh @@ -30,6 +30,15 @@ ln -s pretrained.pt epoch-99.pt ls -lh *.pt popd +log "Test exporting to ONNX format" +./pruned_transducer_stateless7/export.py \ + --exp-dir $repo/exp \ + --use-averaged-model false \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 99 \ + --avg 1 \ + --onnx 1 + log "Export to torchscript model" ./pruned_transducer_stateless7/export.py \ --exp-dir $repo/exp \ @@ -41,6 +50,27 @@ log "Export to torchscript model" ls -lh $repo/exp/*.pt +log "Decode with ONNX models" + +./pruned_transducer_stateless7/onnx_check.py \ + --jit-filename $repo/exp/cpu_jit.pt \ + --onnx-encoder-filename $repo/exp/encoder.onnx \ + --onnx-decoder-filename $repo/exp/decoder.onnx \ + --onnx-joiner-filename $repo/exp/joiner.onnx \ + --onnx-joiner-encoder-proj-filename $repo/exp/joiner_encoder_proj.onnx \ + --onnx-joiner-decoder-proj-filename $repo/exp/joiner_decoder_proj.onnx + +./pruned_transducer_stateless7/onnx_pretrained.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --encoder-model-filename $repo/exp/encoder.onnx \ + --decoder-model-filename $repo/exp/decoder.onnx \ + --joiner-model-filename $repo/exp/joiner.onnx \ + --joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \ + --joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + log "Decode with models exported by torch.jit.script()" ./pruned_transducer_stateless7/jit_pretrained.py \ diff --git a/.github/workflows/run-librispeech-2022-11-11-stateless7.yml b/.github/workflows/run-librispeech-2022-11-11-stateless7.yml index 365e2761a..7694e8bf5 100644 --- a/.github/workflows/run-librispeech-2022-11-11-stateless7.yml +++ b/.github/workflows/run-librispeech-2022-11-11-stateless7.yml @@ -39,7 +39,7 @@ concurrency: jobs: run_librispeech_2022_11_11_zipformer: - if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' + if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' runs-on: ${{ matrix.os }} strategy: matrix: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7/export.py index 3e3160e7e..db8b5eb2b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/export.py @@ -41,7 +41,31 @@ Check https://github.com/k2-fsa/sherpa for how to use the exported models outside of icefall. -(2) Export `model.state_dict()` +(2) Export to ONNX format + +./pruned_transducer_stateless7/export.py \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 \ + --onnx 1 + +It will generate the following files in the given `exp_dir`. +Check `onnx_check.py` for how to use them. + + - encoder.onnx + - decoder.onnx + - joiner.onnx + - joiner_encoder_proj.onnx + - joiner_decoder_proj.onnx + +Please see ./onnx_pretrained.py for usage of the generated files + +Check +https://github.com/k2-fsa/sherpa-onnx +for how to use the exported models outside of icefall. + +(3) Export `model.state_dict()` ./pruned_transducer_stateless7/export.py \ --exp-dir ./pruned_transducer_stateless7/exp \ @@ -172,6 +196,23 @@ def get_parser(): """, ) + parser.add_argument( + "--onnx", + type=str2bool, + default=False, + help="""If True, --jit is ignored and it exports the model + to onnx format. It will generate the following files: + + - encoder.onnx + - decoder.onnx + - joiner.onnx + - joiner_encoder_proj.onnx + - joiner_decoder_proj.onnx + + Refer to ./onnx_check.py and ./onnx_pretrained.py for how to use them. + """, + ) + parser.add_argument( "--context-size", type=int, @@ -184,6 +225,204 @@ def get_parser(): return parser +def export_encoder_model_onnx( + encoder_model: nn.Module, + encoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the given encoder model to ONNX format. + The exported model has two inputs: + + - x, a tensor of shape (N, T, C); dtype is torch.float32 + - x_lens, a tensor of shape (N,); dtype is torch.int64 + + and it has two outputs: + + - encoder_out, a tensor of shape (N, T, C) + - encoder_out_lens, a tensor of shape (N,) + + Note: The warmup argument is fixed to 1. + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + x = torch.zeros(1, 101, 80, dtype=torch.float32) + x_lens = torch.tensor([101], dtype=torch.int64) + + # encoder_model = torch.jit.script(encoder_model) + # It throws the following error for the above statement + # + # RuntimeError: Exporting the operator __is_ to ONNX opset version + # 11 is not supported. Please feel free to request support or + # submit a pull request on PyTorch GitHub. + # + # I cannot find which statement causes the above error. + # torch.onnx.export() will use torch.jit.trace() internally, which + # works well for the current reworked model + torch.onnx.export( + encoder_model, + (x, x_lens), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "x_lens"], + output_names=["encoder_out", "encoder_out_lens"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "x_lens": {0: "N"}, + "encoder_out": {0: "N", 1: "T"}, + "encoder_out_lens": {0: "N"}, + }, + ) + logging.info(f"Saved to {encoder_filename}") + + +def export_decoder_model_onnx( + decoder_model: nn.Module, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, decoder_model.context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, 1, C) + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) + need_pad = False # Always False, so we can use torch.jit.trace() here + # Note(fangjun): torch.jit.trace() is more efficient than torch.jit.script() + # in this case + torch.onnx.export( + decoder_model, + (y, need_pad), + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y", "need_pad"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + logging.info(f"Saved to {decoder_filename}") + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. + The exported joiner model has two inputs: + + - projected_encoder_out: a tensor of shape (N, joiner_dim) + - projected_decoder_out: a tensor of shape (N, joiner_dim) + + and produces one output: + + - logit: a tensor of shape (N, vocab_size) + + The exported encoder_proj model has one input: + + - encoder_out: a tensor of shape (N, encoder_out_dim) + + and produces one output: + + - projected_encoder_out: a tensor of shape (N, joiner_dim) + + The exported decoder_proj model has one input: + + - decoder_out: a tensor of shape (N, decoder_out_dim) + + and produces one output: + + - projected_decoder_out: a tensor of shape (N, joiner_dim) + """ + encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx") + decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx") + + encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] + decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] + joiner_dim = joiner_model.decoder_proj.weight.shape[0] + + projected_encoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32) + + project_input = False + # Note: It uses torch.jit.trace() internally + torch.onnx.export( + joiner_model, + (projected_encoder_out, projected_decoder_out, project_input), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "encoder_out", + "decoder_out", + "project_input", + ], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + logging.info(f"Saved to {joiner_filename}") + + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + torch.onnx.export( + joiner_model.encoder_proj, + encoder_out, + encoder_proj_filename, + verbose=False, + opset_version=opset_version, + input_names=["encoder_out"], + output_names=["projected_encoder_out"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "projected_encoder_out": {0: "N"}, + }, + ) + logging.info(f"Saved to {encoder_proj_filename}") + + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + torch.onnx.export( + joiner_model.decoder_proj, + decoder_out, + decoder_proj_filename, + verbose=False, + opset_version=opset_version, + input_names=["decoder_out"], + output_names=["projected_decoder_out"], + dynamic_axes={ + "decoder_out": {0: "N"}, + "projected_decoder_out": {0: "N"}, + }, + ) + logging.info(f"Saved to {decoder_proj_filename}") + + @torch.no_grad() def main(): args = get_parser().parse_args() @@ -292,7 +531,31 @@ def main(): model.to("cpu") model.eval() - if params.jit is True: + if params.onnx is True: + convert_scaled_to_non_scaled(model, inplace=True) + opset_version = 13 + logging.info("Exporting to onnx format") + encoder_filename = params.exp_dir / "encoder.onnx" + export_encoder_model_onnx( + model.encoder, + encoder_filename, + opset_version=opset_version, + ) + + decoder_filename = params.exp_dir / "decoder.onnx" + export_decoder_model_onnx( + model.decoder, + decoder_filename, + opset_version=opset_version, + ) + + joiner_filename = params.exp_dir / "joiner.onnx" + export_joiner_model_onnx( + model.joiner, + joiner_filename, + opset_version=opset_version, + ) + elif params.jit is True: convert_scaled_to_non_scaled(model, inplace=True) # We won't use the forward() method of the model in C++, so just ignore # it here. diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_check.py b/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_check.py new file mode 100755 index 000000000..63acc0922 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_check.py @@ -0,0 +1,286 @@ +#!/usr/bin/env python3 +# +# Copyright 2022 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script checks that exported onnx models produce the same output +with the given torchscript model for the same input. +""" + +import argparse +import logging + +import onnxruntime as ort +import torch + +from icefall import is_module_available + +if not is_module_available("onnxruntime"): + raise ValueError("Please 'pip install onnxruntime' first.") + + +ort.set_default_logger_severity(3) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--jit-filename", + required=True, + type=str, + help="Path to the torchscript model", + ) + + parser.add_argument( + "--onnx-encoder-filename", + required=True, + type=str, + help="Path to the onnx encoder model", + ) + + parser.add_argument( + "--onnx-decoder-filename", + required=True, + type=str, + help="Path to the onnx decoder model", + ) + + parser.add_argument( + "--onnx-joiner-filename", + required=True, + type=str, + help="Path to the onnx joiner model", + ) + + parser.add_argument( + "--onnx-joiner-encoder-proj-filename", + required=True, + type=str, + help="Path to the onnx joiner encoder projection model", + ) + + parser.add_argument( + "--onnx-joiner-decoder-proj-filename", + required=True, + type=str, + help="Path to the onnx joiner decoder projection model", + ) + + return parser + + +def test_encoder( + model: torch.jit.ScriptModule, + encoder_session: ort.InferenceSession, +): + inputs = encoder_session.get_inputs() + outputs = encoder_session.get_outputs() + input_names = [n.name for n in inputs] + output_names = [n.name for n in outputs] + + assert inputs[0].shape == ["N", "T", 80] + assert inputs[1].shape == ["N"] + + for N in [1, 5]: + for T in [12, 50]: + print("N, T", N, T) + x = torch.rand(N, T, 80, dtype=torch.float32) + x_lens = torch.randint(low=10, high=T + 1, size=(N,)) + x_lens[0] = T + + encoder_inputs = { + input_names[0]: x.numpy(), + input_names[1]: x_lens.numpy(), + } + + torch_encoder_out, torch_encoder_out_lens = model.encoder(x, x_lens) + + encoder_out, encoder_out_lens = encoder_session.run( + output_names, + encoder_inputs, + ) + + torch_encoder_out, torch_encoder_out_lens = model.encoder(x, x_lens) + + encoder_out = torch.from_numpy(encoder_out) + assert torch.allclose(encoder_out, torch_encoder_out, atol=1e-05), ( + (encoder_out - torch_encoder_out).abs().max(), + encoder_out.shape, + torch_encoder_out.shape, + ) + + +def test_decoder( + model: torch.jit.ScriptModule, + decoder_session: ort.InferenceSession, +): + inputs = decoder_session.get_inputs() + outputs = decoder_session.get_outputs() + input_names = [n.name for n in inputs] + output_names = [n.name for n in outputs] + + assert inputs[0].shape == ["N", 2] + for N in [1, 5, 10]: + y = torch.randint(low=1, high=500, size=(10, 2)) + + decoder_inputs = {input_names[0]: y.numpy()} + decoder_out = decoder_session.run( + output_names, + decoder_inputs, + )[0] + decoder_out = torch.from_numpy(decoder_out) + + torch_decoder_out = model.decoder(y, need_pad=False) + assert torch.allclose(decoder_out, torch_decoder_out, atol=1e-5), ( + (decoder_out - torch_decoder_out).abs().max() + ) + + +def test_joiner( + model: torch.jit.ScriptModule, + joiner_session: ort.InferenceSession, + joiner_encoder_proj_session: ort.InferenceSession, + joiner_decoder_proj_session: ort.InferenceSession, +): + joiner_inputs = joiner_session.get_inputs() + joiner_outputs = joiner_session.get_outputs() + joiner_input_names = [n.name for n in joiner_inputs] + joiner_output_names = [n.name for n in joiner_outputs] + + assert joiner_inputs[0].shape == ["N", 1, 1, 512] + assert joiner_inputs[1].shape == ["N", 1, 1, 512] + + joiner_encoder_proj_inputs = joiner_encoder_proj_session.get_inputs() + encoder_proj_input_name = joiner_encoder_proj_inputs[0].name + + assert joiner_encoder_proj_inputs[0].shape == ["N", 384] + + joiner_encoder_proj_outputs = joiner_encoder_proj_session.get_outputs() + encoder_proj_output_name = joiner_encoder_proj_outputs[0].name + + joiner_decoder_proj_inputs = joiner_decoder_proj_session.get_inputs() + decoder_proj_input_name = joiner_decoder_proj_inputs[0].name + + assert joiner_decoder_proj_inputs[0].shape == ["N", 512] + + joiner_decoder_proj_outputs = joiner_decoder_proj_session.get_outputs() + decoder_proj_output_name = joiner_decoder_proj_outputs[0].name + + for N in [1, 5, 10]: + encoder_out = torch.rand(N, 384) + decoder_out = torch.rand(N, 512) + + projected_encoder_out = torch.rand(N, 1, 1, 512) + projected_decoder_out = torch.rand(N, 1, 1, 512) + + joiner_inputs = { + joiner_input_names[0]: projected_encoder_out.numpy(), + joiner_input_names[1]: projected_decoder_out.numpy(), + } + joiner_out = joiner_session.run(joiner_output_names, joiner_inputs)[0] + joiner_out = torch.from_numpy(joiner_out) + + torch_joiner_out = model.joiner( + projected_encoder_out, + projected_decoder_out, + project_input=False, + ) + assert torch.allclose(joiner_out, torch_joiner_out, atol=1e-5), ( + (joiner_out - torch_joiner_out).abs().max() + ) + + # Now test encoder_proj + joiner_encoder_proj_inputs = {encoder_proj_input_name: encoder_out.numpy()} + joiner_encoder_proj_out = joiner_encoder_proj_session.run( + [encoder_proj_output_name], joiner_encoder_proj_inputs + )[0] + joiner_encoder_proj_out = torch.from_numpy(joiner_encoder_proj_out) + + torch_joiner_encoder_proj_out = model.joiner.encoder_proj(encoder_out) + assert torch.allclose( + joiner_encoder_proj_out, torch_joiner_encoder_proj_out, atol=1e-5 + ), ((joiner_encoder_proj_out - torch_joiner_encoder_proj_out).abs().max()) + + # Now test decoder_proj + joiner_decoder_proj_inputs = {decoder_proj_input_name: decoder_out.numpy()} + joiner_decoder_proj_out = joiner_decoder_proj_session.run( + [decoder_proj_output_name], joiner_decoder_proj_inputs + )[0] + joiner_decoder_proj_out = torch.from_numpy(joiner_decoder_proj_out) + + torch_joiner_decoder_proj_out = model.joiner.decoder_proj(decoder_out) + assert torch.allclose( + joiner_decoder_proj_out, torch_joiner_decoder_proj_out, atol=1e-5 + ), ((joiner_decoder_proj_out - torch_joiner_decoder_proj_out).abs().max()) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + logging.info(vars(args)) + + model = torch.jit.load(args.jit_filename) + + options = ort.SessionOptions() + options.inter_op_num_threads = 1 + options.intra_op_num_threads = 1 + + logging.info("Test encoder") + encoder_session = ort.InferenceSession( + args.onnx_encoder_filename, + sess_options=options, + ) + test_encoder(model, encoder_session) + + logging.info("Test decoder") + decoder_session = ort.InferenceSession( + args.onnx_decoder_filename, + sess_options=options, + ) + test_decoder(model, decoder_session) + + logging.info("Test joiner") + joiner_session = ort.InferenceSession( + args.onnx_joiner_filename, + sess_options=options, + ) + joiner_encoder_proj_session = ort.InferenceSession( + args.onnx_joiner_encoder_proj_filename, + sess_options=options, + ) + joiner_decoder_proj_session = ort.InferenceSession( + args.onnx_joiner_decoder_proj_filename, + sess_options=options, + ) + test_joiner( + model, + joiner_session, + joiner_encoder_proj_session, + joiner_decoder_proj_session, + ) + logging.info("Finished checking ONNX models") + + +if __name__ == "__main__": + torch.manual_seed(20220727) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_pretrained.py new file mode 100755 index 000000000..3a06ee293 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_pretrained.py @@ -0,0 +1,388 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads ONNX models and uses them to decode waves. +You can use the following command to get the exported models: + +./pruned_transducer_stateless7/export.py \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 \ + --onnx 1 + +Usage of this script: + +./pruned_transducer_stateless7/onnx_pretrained.py \ + --encoder-model-filename ./pruned_transducer_stateless7/exp/encoder.onnx \ + --decoder-model-filename ./pruned_transducer_stateless7/exp/decoder.onnx \ + --joiner-model-filename ./pruned_transducer_stateless7/exp/joiner.onnx \ + --joiner-encoder-proj-model-filename ./pruned_transducer_stateless7/exp/joiner_encoder_proj.onnx \ + --joiner-decoder-proj-model-filename ./pruned_transducer_stateless7/exp/joiner_decoder_proj.onnx \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import kaldifeat +import numpy as np +import onnxruntime as ort +import sentencepiece as spm +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder onnx model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder onnx model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner onnx model. ", + ) + + parser.add_argument( + "--joiner-encoder-proj-model-filename", + type=str, + required=True, + help="Path to the joiner encoder_proj onnx model. ", + ) + + parser.add_argument( + "--joiner-decoder-proj-model-filename", + type=str, + required=True, + help="Path to the joiner decoder_proj onnx model. ", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="Context size of the decoder model", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +def greedy_search( + decoder: ort.InferenceSession, + joiner: ort.InferenceSession, + joiner_encoder_proj: ort.InferenceSession, + joiner_decoder_proj: ort.InferenceSession, + encoder_out: np.ndarray, + encoder_out_lens: np.ndarray, + context_size: int, +) -> List[List[int]]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + decoder: + The decoder model. + joiner: + The joiner model. + joiner_encoder_proj: + The joiner encoder projection model. + joiner_decoder_proj: + The joiner decoder projection model. + encoder_out: + A 3-D tensor of shape (N, T, C) + encoder_out_lens: + A 1-D tensor of shape (N,). + context_size: + The context size of the decoder model. + Returns: + Return the decoded results for each utterance. + """ + encoder_out = torch.from_numpy(encoder_out) + encoder_out_lens = torch.from_numpy(encoder_out_lens) + assert encoder_out.ndim == 3 + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + projected_encoder_out = joiner_encoder_proj.run( + [joiner_encoder_proj.get_outputs()[0].name], + {joiner_encoder_proj.get_inputs()[0].name: packed_encoder_out.data.numpy()}, + )[0] + + blank_id = 0 # hard-code to 0 + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + hyps = [[blank_id] * context_size for _ in range(N)] + + decoder_input_nodes = decoder.get_inputs() + decoder_output_nodes = decoder.get_outputs() + + joiner_input_nodes = joiner.get_inputs() + joiner_output_nodes = joiner.get_outputs() + + decoder_input = torch.tensor( + hyps, + dtype=torch.int64, + ) # (N, context_size) + + decoder_out = decoder.run( + [decoder_output_nodes[0].name], + { + decoder_input_nodes[0].name: decoder_input.numpy(), + }, + )[0].squeeze(1) + projected_decoder_out = joiner_decoder_proj.run( + [joiner_decoder_proj.get_outputs()[0].name], + {joiner_decoder_proj.get_inputs()[0].name: decoder_out}, + )[0] + + projected_decoder_out = torch.from_numpy(projected_decoder_out) + + offset = 0 + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = projected_encoder_out[start:end] + # current_encoder_out's shape: (batch_size, encoder_out_dim) + offset = end + + projected_decoder_out = projected_decoder_out[:batch_size] + + logits = joiner.run( + [joiner_output_nodes[0].name], + { + joiner_input_nodes[0].name: np.expand_dims( + np.expand_dims(current_encoder_out, axis=1), axis=1 + ), + joiner_input_nodes[1] + .name: projected_decoder_out.unsqueeze(1) + .unsqueeze(1) + .numpy(), + }, + )[0] + logits = torch.from_numpy(logits).squeeze(1).squeeze(1) + # logits'shape (batch_size, vocab_size) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + hyps[i].append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = [h[-context_size:] for h in hyps[:batch_size]] + decoder_input = torch.tensor( + decoder_input, + dtype=torch.int64, + ) + decoder_out = decoder.run( + [decoder_output_nodes[0].name], + { + decoder_input_nodes[0].name: decoder_input.numpy(), + }, + )[0].squeeze(1) + projected_decoder_out = joiner_decoder_proj.run( + [joiner_decoder_proj.get_outputs()[0].name], + {joiner_decoder_proj.get_inputs()[0].name: decoder_out}, + )[0] + projected_decoder_out = torch.from_numpy(projected_decoder_out) + + sorted_ans = [h[context_size:] for h in hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + encoder = ort.InferenceSession( + args.encoder_model_filename, + sess_options=session_opts, + ) + + decoder = ort.InferenceSession( + args.decoder_model_filename, + sess_options=session_opts, + ) + + joiner = ort.InferenceSession( + args.joiner_model_filename, + sess_options=session_opts, + ) + + joiner_encoder_proj = ort.InferenceSession( + args.joiner_encoder_proj_model_filename, + sess_options=session_opts, + ) + + joiner_decoder_proj = ort.InferenceSession( + args.joiner_decoder_proj_model_filename, + sess_options=session_opts, + ) + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = args.sample_rate + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + expected_sample_rate=args.sample_rate, + ) + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64) + + encoder_input_nodes = encoder.get_inputs() + encoder_out_nodes = encoder.get_outputs() + encoder_out, encoder_out_lens = encoder.run( + [encoder_out_nodes[0].name, encoder_out_nodes[1].name], + { + encoder_input_nodes[0].name: features.numpy(), + encoder_input_nodes[1].name: feature_lengths.numpy(), + }, + ) + + hyps = greedy_search( + decoder=decoder, + joiner=joiner, + joiner_encoder_proj=joiner_encoder_proj, + joiner_decoder_proj=joiner_decoder_proj, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + context_size=args.context_size, + ) + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = sp.decode(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index 1cbde6db0..156b91f09 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -261,7 +261,7 @@ class RandomGrad(torch.nn.Module): self.min_abs = min_abs def forward(self, x: Tensor): - if torch.jit.is_scripting() or not self.training: + if torch.jit.is_scripting() or not self.training or torch.jit.is_tracing(): return x else: return RandomGradFunction.apply(x, self.min_abs) @@ -530,7 +530,7 @@ class ActivationBalancer(torch.nn.Module): self.register_buffer("count", torch.tensor(0, dtype=torch.int64)) def forward(self, x: Tensor) -> Tensor: - if torch.jit.is_scripting() or not x.requires_grad: + if torch.jit.is_scripting() or not x.requires_grad or torch.jit.is_tracing(): return _no_op(x) count = self.cpu_count @@ -790,7 +790,7 @@ def with_loss(x, y): def _no_op(x: Tensor) -> Tensor: - if torch.jit.is_scripting(): + if torch.jit.is_scripting() or torch.jit.is_tracing(): return x else: # a no-op function that will have a node in the autograd graph, @@ -862,6 +862,7 @@ class MaxEig(torch.nn.Module): torch.jit.is_scripting() or self.max_var_per_eig <= 0 or random.random() > self.cur_prob + or torch.jit.is_tracing() ): return _no_op(x) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/test_onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless7/test_onnx.py new file mode 100644 index 000000000..2440d267c --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/test_onnx.py @@ -0,0 +1,374 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file is to test that models can be exported to onnx. +""" +import os + +from icefall import is_module_available + +if not is_module_available("onnxruntime"): + raise ValueError("Please 'pip install onnxruntime' first.") + +import onnxruntime as ort +import torch +from scaling_converter import convert_scaled_to_non_scaled +from zipformer import ( + Conv2dSubsampling, + RelPositionalEncoding, + Zipformer, + ZipformerEncoder, + ZipformerEncoderLayer, +) + +ort.set_default_logger_severity(3) + + +def test_conv2d_subsampling(): + filename = "conv2d_subsampling.onnx" + opset_version = 13 + N = 30 + T = 50 + num_features = 80 + d_model = 512 + x = torch.rand(N, T, num_features) + + encoder_embed = Conv2dSubsampling(num_features, d_model) + encoder_embed.eval() + encoder_embed = convert_scaled_to_non_scaled(encoder_embed, inplace=True) + + torch.onnx.export( + encoder_embed, + x, + filename, + verbose=False, + opset_version=opset_version, + input_names=["x"], + output_names=["y"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "y": {0: "N", 1: "T"}, + }, + ) + + options = ort.SessionOptions() + options.inter_op_num_threads = 1 + options.intra_op_num_threads = 1 + + session = ort.InferenceSession( + filename, + sess_options=options, + ) + + input_nodes = session.get_inputs() + assert input_nodes[0].name == "x" + assert input_nodes[0].shape == ["N", "T", num_features] + + inputs = {input_nodes[0].name: x.numpy()} + + onnx_y = session.run(["y"], inputs)[0] + + onnx_y = torch.from_numpy(onnx_y) + torch_y = encoder_embed(x) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() + + os.remove(filename) + + +def test_rel_pos(): + filename = "rel_pos.onnx" + + opset_version = 13 + N = 30 + T = 50 + num_features = 80 + d_model = 512 + x = torch.rand(N, T, num_features) + + encoder_pos = RelPositionalEncoding(d_model, dropout_rate=0.1) + encoder_pos.eval() + encoder_pos = convert_scaled_to_non_scaled(encoder_pos, inplace=True) + + x = x.permute(1, 0, 2) + + torch.onnx.export( + encoder_pos, + x, + filename, + verbose=False, + opset_version=opset_version, + input_names=["x"], + output_names=["pos_emb"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "pos_emb": {0: "N", 1: "T"}, + }, + ) + + options = ort.SessionOptions() + options.inter_op_num_threads = 1 + options.intra_op_num_threads = 1 + + session = ort.InferenceSession( + filename, + sess_options=options, + ) + + input_nodes = session.get_inputs() + assert input_nodes[0].name == "x" + assert input_nodes[0].shape == ["N", "T", num_features] + + inputs = {input_nodes[0].name: x.numpy()} + onnx_pos_emb = session.run(["pos_emb"], inputs) + onnx_pos_emb = torch.from_numpy(onnx_pos_emb[0]) + + torch_pos_emb = encoder_pos(x) + assert torch.allclose(onnx_pos_emb, torch_pos_emb, atol=1e-05), ( + (onnx_pos_emb - torch_pos_emb).abs().max() + ) + print(onnx_pos_emb.abs().sum(), torch_pos_emb.abs().sum()) + + os.remove(filename) + + +def test_zipformer_encoder_layer(): + filename = "zipformer_encoder_layer.onnx" + opset_version = 13 + N = 30 + T = 50 + + d_model = 384 + attention_dim = 192 + nhead = 8 + feedforward_dim = 1024 + dropout = 0.1 + cnn_module_kernel = 31 + pos_dim = 4 + + x = torch.rand(N, T, d_model) + + encoder_pos = RelPositionalEncoding(d_model, dropout) + encoder_pos.eval() + encoder_pos = convert_scaled_to_non_scaled(encoder_pos, inplace=True) + + x = x.permute(1, 0, 2) + pos_emb = encoder_pos(x) + + encoder_layer = ZipformerEncoderLayer( + d_model, + attention_dim, + nhead, + feedforward_dim, + dropout, + cnn_module_kernel, + pos_dim, + ) + encoder_layer.eval() + encoder_layer = convert_scaled_to_non_scaled(encoder_layer, inplace=True) + + torch.onnx.export( + encoder_layer, + (x, pos_emb), + filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "pos_emb"], + output_names=["y"], + dynamic_axes={ + "x": {0: "T", 1: "N"}, + "pos_emb": {0: "N", 1: "T"}, + "y": {0: "T", 1: "N"}, + }, + ) + + options = ort.SessionOptions() + options.inter_op_num_threads = 1 + options.intra_op_num_threads = 1 + + session = ort.InferenceSession( + filename, + sess_options=options, + ) + + input_nodes = session.get_inputs() + inputs = { + input_nodes[0].name: x.numpy(), + input_nodes[1].name: pos_emb.numpy(), + } + onnx_y = session.run(["y"], inputs)[0] + onnx_y = torch.from_numpy(onnx_y) + + torch_y = encoder_layer(x, pos_emb) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() + + print(onnx_y.abs().sum(), torch_y.abs().sum(), onnx_y.shape, torch_y.shape) + + os.remove(filename) + + +def test_zipformer_encoder(): + filename = "zipformer_encoder.onnx" + + opset_version = 13 + N = 3 + T = 15 + + d_model = 512 + attention_dim = 192 + nhead = 8 + feedforward_dim = 1024 + dropout = 0.1 + cnn_module_kernel = 31 + pos_dim = 4 + num_encoder_layers = 12 + + warmup_batches = 4000.0 + warmup_begin = warmup_batches / (num_encoder_layers + 1) + warmup_end = warmup_batches / (num_encoder_layers + 1) + + x = torch.rand(N, T, d_model) + + encoder_layer = ZipformerEncoderLayer( + d_model, + attention_dim, + nhead, + feedforward_dim, + dropout, + cnn_module_kernel, + pos_dim, + ) + encoder = ZipformerEncoder( + encoder_layer, num_encoder_layers, dropout, warmup_begin, warmup_end + ) + encoder.eval() + encoder = convert_scaled_to_non_scaled(encoder, inplace=True) + + # jit_model = torch.jit.trace(encoder, (pos_emb)) + + torch_y = encoder(x) + + torch.onnx.export( + encoder, + (x), + filename, + verbose=False, + opset_version=opset_version, + input_names=["x"], + output_names=["y"], + dynamic_axes={ + "x": {0: "T", 1: "N"}, + "y": {0: "T", 1: "N"}, + }, + ) + + options = ort.SessionOptions() + options.inter_op_num_threads = 1 + options.intra_op_num_threads = 1 + + session = ort.InferenceSession( + filename, + sess_options=options, + ) + + input_nodes = session.get_inputs() + inputs = { + input_nodes[0].name: x.numpy(), + } + onnx_y = session.run(["y"], inputs)[0] + onnx_y = torch.from_numpy(onnx_y) + + torch_y = encoder(x) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() + + print(onnx_y.abs().sum(), torch_y.abs().sum(), onnx_y.shape, torch_y.shape) + + os.remove(filename) + + +def test_zipformer(): + filename = "zipformer.onnx" + opset_version = 11 + N = 3 + T = 15 + num_features = 80 + x = torch.rand(N, T, num_features) + x_lens = torch.full((N,), fill_value=T, dtype=torch.int64) + + zipformer = Zipformer(num_features=num_features) + zipformer.eval() + zipformer = convert_scaled_to_non_scaled(zipformer, inplace=True) + + # jit_model = torch.jit.trace(zipformer, (x, x_lens)) + torch.onnx.export( + zipformer, + (x, x_lens), + filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "x_lens"], + output_names=["y", "y_lens"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "x_lens": {0: "N"}, + "y": {0: "N", 1: "T"}, + "y_lens": {0: "N"}, + }, + ) + options = ort.SessionOptions() + options.inter_op_num_threads = 1 + options.intra_op_num_threads = 1 + + session = ort.InferenceSession( + filename, + sess_options=options, + ) + + input_nodes = session.get_inputs() + inputs = { + input_nodes[0].name: x.numpy(), + input_nodes[1].name: x_lens.numpy(), + } + onnx_y, onnx_y_lens = session.run(["y", "y_lens"], inputs) + onnx_y = torch.from_numpy(onnx_y) + onnx_y_lens = torch.from_numpy(onnx_y_lens) + + torch_y, torch_y_lens = zipformer(x, x_lens) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() + + assert torch.allclose(onnx_y_lens, torch_y_lens, atol=1e-05), ( + (onnx_y_lens - torch_y_lens).abs().max() + ) + print(onnx_y.abs().sum(), torch_y.abs().sum(), onnx_y.shape, torch_y.shape) + print(onnx_y_lens, torch_y_lens) + + os.remove(filename) + + +@torch.no_grad() +def main(): + test_conv2d_subsampling() + test_rel_pos() + test_zipformer_encoder_layer() + test_zipformer_encoder() + test_zipformer() + + +if __name__ == "__main__": + torch.manual_seed(20221011) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index d18258085..b1717ec64 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -210,7 +210,7 @@ class Zipformer(EncoderInterface): (num_frames, batch_size, encoder_dims0) """ num_encoders = len(self.encoder_dims) - if torch.jit.is_scripting() or not self.training: + if torch.jit.is_scripting() or not self.training or torch.jit.is_tracing(): return [1.0] * num_encoders (num_frames0, batch_size, _encoder_dims0) = x.shape @@ -293,7 +293,7 @@ class Zipformer(EncoderInterface): k = self.skip_layers[i] if isinstance(k, int): layer_skip_dropout_prob = self._get_layer_skip_dropout_prob() - if torch.jit.is_scripting(): + if torch.jit.is_scripting() or torch.jit.is_tracing(): x = skip_module(outputs[k], x) elif (not self.training) or random.random() > layer_skip_dropout_prob: x = skip_module(outputs[k], x) @@ -386,7 +386,7 @@ class ZipformerEncoderLayer(nn.Module): ) def get_bypass_scale(self): - if torch.jit.is_scripting() or not self.training: + if torch.jit.is_scripting() or not self.training or torch.jit.is_tracing(): return self.bypass_scale if random.random() < 0.1: # ensure we get grads if self.bypass_scale becomes out of range @@ -407,7 +407,7 @@ class ZipformerEncoderLayer(nn.Module): # return dropout rate for the dynamic modules (self_attn, pooling, convolution); this # starts at 0.2 and rapidly decreases to 0. Its purpose is to keep the training stable # at the beginning, by making the network focus on the feedforward modules. - if torch.jit.is_scripting() or not self.training: + if torch.jit.is_scripting() or not self.training or torch.jit.is_tracing(): return 0.0 warmup_period = 2000.0 initial_dropout_rate = 0.2 @@ -452,12 +452,12 @@ class ZipformerEncoderLayer(nn.Module): dynamic_dropout = self.get_dynamic_dropout_rate() # pooling module - if torch.jit.is_scripting(): + if torch.jit.is_scripting() or torch.jit.is_tracing(): src = src + self.pooling(src, key_padding_mask=src_key_padding_mask) elif random.random() >= dynamic_dropout: src = src + self.pooling(src, key_padding_mask=src_key_padding_mask) - if torch.jit.is_scripting(): + if torch.jit.is_scripting() or torch.jit.is_tracing(): src_att, attn_weights = self.self_attn( src, pos_emb=pos_emb, @@ -658,7 +658,7 @@ class ZipformerEncoder(nn.Module): pos_emb = self.encoder_pos(src) output = src - if torch.jit.is_scripting(): + if torch.jit.is_scripting() or torch.jit.is_tracing(): layers_to_drop = [] else: rnd_seed = src.numel() + random.randint(0, 1000) @@ -667,7 +667,7 @@ class ZipformerEncoder(nn.Module): output = output * feature_mask for i, mod in enumerate(self.layers): - if not torch.jit.is_scripting(): + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): if i in layers_to_drop: continue output = mod( @@ -864,7 +864,7 @@ class SimpleCombiner(torch.nn.Module): assert src1.shape[:-1] == src2.shape[:-1], (src1.shape, src2.shape) weight1 = self.weight1 - if not torch.jit.is_scripting(): + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): if ( self.training and random.random() < 0.25 @@ -1258,21 +1258,31 @@ class RelPositionMultiheadAttention(nn.Module): # the following .as_strided() expression converts the last axis of pos_weights from relative # to absolute position. I don't know whether I might have got the time-offsets backwards or # not, but let this code define which way round it is supposed to be. - pos_weights = pos_weights.as_strided( - (bsz, num_heads, seq_len, seq_len), - ( - pos_weights.stride(0), - pos_weights.stride(1), - pos_weights.stride(2) - pos_weights.stride(3), - pos_weights.stride(3), - ), - storage_offset=pos_weights.stride(3) * (seq_len - 1), - ) + if torch.jit.is_tracing(): + (batch_size, num_heads, time1, n) = pos_weights.shape + rows = torch.arange(start=time1 - 1, end=-1, step=-1) + cols = torch.arange(seq_len) + rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) + indexes = rows + cols + pos_weights = pos_weights.reshape(-1, n) + pos_weights = torch.gather(pos_weights, dim=1, index=indexes) + pos_weights = pos_weights.reshape(batch_size, num_heads, time1, seq_len) + else: + pos_weights = pos_weights.as_strided( + (bsz, num_heads, seq_len, seq_len), + ( + pos_weights.stride(0), + pos_weights.stride(1), + pos_weights.stride(2) - pos_weights.stride(3), + pos_weights.stride(3), + ), + storage_offset=pos_weights.stride(3) * (seq_len - 1), + ) # caution: they are really scores at this point. attn_output_weights = torch.matmul(q, k) + pos_weights - if not torch.jit.is_scripting(): + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): if training and random.random() < 0.1: # This is a harder way of limiting the attention scores to not be too large. # It incurs a penalty if any of them has an absolute value greater than 50.0. @@ -1383,7 +1393,7 @@ class RelPositionMultiheadAttention(nn.Module): # now v: (bsz * num_heads, seq_len, head_dim // 2) attn_output = torch.bmm(attn_weights, v) - if not torch.jit.is_scripting(): + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): if random.random() < 0.001 or __name__ == "__main__": self._print_attn_stats(attn_weights, attn_output) @@ -1458,7 +1468,10 @@ class PoolingModule(nn.Module): a Tensor of shape (1, N, C) """ if key_padding_mask is not None: - pooling_mask = key_padding_mask.logical_not().to(x.dtype) # (N, T) + if torch.jit.is_tracing(): + pooling_mask = (~key_padding_mask).to(x.dtype) + else: + pooling_mask = key_padding_mask.logical_not().to(x.dtype) # (N, T) pooling_mask = pooling_mask / pooling_mask.sum(dim=1, keepdim=True) pooling_mask = pooling_mask.transpose(0, 1).contiguous().unsqueeze(-1) # now pooling_mask: (T, N, 1) From 8642dbc0bd4174acb6612b6510f971f98a16f7d3 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 4 Jan 2023 12:21:19 +0800 Subject: [PATCH 071/174] Fix setup_dist (#806) --- icefall/dist.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/icefall/dist.py b/icefall/dist.py index 672948623..922f31a2f 100644 --- a/icefall/dist.py +++ b/icefall/dist.py @@ -22,7 +22,7 @@ from torch import distributed as dist def setup_dist( - rank, world_size, master_addr=None, master_port=None, use_ddp_launch=False + rank, world_size, master_port=None, use_ddp_launch=False, master_addr=None ): """ rank and world_size are used only if use_ddp_launch is False. From b9626f2e0684dd59d761c6bd6e9b6127a387d11c Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Thu, 5 Jan 2023 17:18:43 +0800 Subject: [PATCH 072/174] fix typo for ctc-decode.py (#815) Co-authored-by: yifanyang --- .../ASR/pruned_transducer_stateless7_ctc/ctc_decode.py | 2 +- .../ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py index 9c23e7d66..4b373e4c7 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py @@ -44,7 +44,7 @@ Usage: --exp-dir ./pruned_transducer_stateless7_ctc/exp \ --max-duration 600 \ --hlg-scale 0.8 \ - --decoding-method 1best + --decoding-method nbest (4) nbest-rescoring ./pruned_transducer_stateless7_ctc/ctc_decode.py \ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py index 0ef733226..f137485b2 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py @@ -42,7 +42,7 @@ Usage: --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ --max-duration 600 \ --hlg-scale 0.8 \ - --decoding-method 1best + --decoding-method nbest (4) nbest-rescoring ./pruned_transducer_stateless7_ctc_bs/ctc_decode.py \ --epoch 30 \ From 9a9c5a0f9b083a729ee00d439df1054f517e1b6d Mon Sep 17 00:00:00 2001 From: kobenaxie <572745565@qq.com> Date: Fri, 6 Jan 2023 11:16:22 +0800 Subject: [PATCH 073/174] remove unused codes. (#821) --- .../emformer2.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer2.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer2.py index 188059044..f0c92a9b4 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer2.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer2.py @@ -1512,24 +1512,6 @@ class EmformerEncoder(nn.Module): ) return states - attn_caches = [ - [ - torch.zeros(self.memory_size, self.d_model, device=device), - torch.zeros(self.left_context_length, self.d_model, device=device), - torch.zeros(self.left_context_length, self.d_model, device=device), - ] - for _ in range(self.num_encoder_layers) - ] - conv_caches = [ - torch.zeros(self.d_model, self.cnn_module_kernel - 1, device=device) - for _ in range(self.num_encoder_layers) - ] - states: Tuple[List[List[torch.Tensor]], List[torch.Tensor]] = ( - attn_caches, - conv_caches, - ) - return states - class Emformer(EncoderInterface): def __init__( From 9453eb1c709140becd3373666bb51e996e8f7260 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 6 Jan 2023 17:00:27 +0800 Subject: [PATCH 074/174] Fix doc for building ncnn (#822) --- docs/README.md | 24 ++++++++++++++ .../lstm_pruned_stateless_transducer.rst | 31 ++++++++++++++++--- 2 files changed, 51 insertions(+), 4 deletions(-) create mode 100644 docs/README.md diff --git a/docs/README.md b/docs/README.md new file mode 100644 index 000000000..3abb38f8b --- /dev/null +++ b/docs/README.md @@ -0,0 +1,24 @@ + +## Usage + +```bash +cd /path/to/icefall/docs +pip install -r requirements.txt +make clean +make html +cd build/html +python3 -m http.server 8000 +``` + +It prints: + +``` +Serving HTTP on 0.0.0.0 port 8000 (http://0.0.0.0:8000/) ... +``` + +Open your browser and go to to view the generated +documentation. + +Done! + +**Hint**: You can change the port number when starting the server. diff --git a/docs/source/recipes/Streaming-ASR/librispeech/lstm_pruned_stateless_transducer.rst b/docs/source/recipes/Streaming-ASR/librispeech/lstm_pruned_stateless_transducer.rst index 643855cc2..d09421eb5 100644 --- a/docs/source/recipes/Streaming-ASR/librispeech/lstm_pruned_stateless_transducer.rst +++ b/docs/source/recipes/Streaming-ASR/librispeech/lstm_pruned_stateless_transducer.rst @@ -531,16 +531,36 @@ First, let us install a modified version of ``ncnn``: git clone https://github.com/csukuangfj/ncnn cd ncnn git submodule update --recursive --init - python3 setup.py bdist_wheel - ls -lh dist/ - pip install ./dist/*.whl + + # Note: We don't use "python setup.py install" or "pip install ." here + + mkdir -p build-wheel + cd build-wheel + + cmake \ + -DCMAKE_BUILD_TYPE=Release \ + -DNCNN_PYTHON=ON \ + -DNCNN_BUILD_BENCHMARK=OFF \ + -DNCNN_BUILD_EXAMPLES=OFF \ + -DNCNN_BUILD_TOOLS=OFF \ + .. + + make -j4 + + cd .. + + # Note: $PWD here is /path/to/ncnn + + export PYTHONPATH=$PWD/python:$PYTHONPATH + export PATH=$PWD/tools/pnnx/build/src:$PATH + export PATH=$PWD/build/tools/quantize:$PATH # now build pnnx cd tools/pnnx mkdir build cd build + cmake .. make -j4 - export PATH=$PWD/src:$PATH ./src/pnnx @@ -549,6 +569,9 @@ First, let us install a modified version of ``ncnn``: We assume that you have added the path to the binary ``pnnx`` to the environment variable ``PATH``. + We also assume that you have added ``build/tools/quantize`` to the environment + variable ``PATH`` so that you are able to use ``ncnn2int8`` later. + Second, let us export the model using ``torch.jit.trace()`` that is suitable for ``pnnx``: From 42cc10117eed5960e7219bbb9501a0beda602cfa Mon Sep 17 00:00:00 2001 From: marcoyang1998 <45973641+marcoyang1998@users.noreply.github.com> Date: Mon, 9 Jan 2023 15:08:39 +0800 Subject: [PATCH 075/174] Fix ncnn install (#824) * add README to docs * fix ncnn installation --- .../librispeech/lstm_pruned_stateless_transducer.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/recipes/Streaming-ASR/librispeech/lstm_pruned_stateless_transducer.rst b/docs/source/recipes/Streaming-ASR/librispeech/lstm_pruned_stateless_transducer.rst index d09421eb5..22addd1d2 100644 --- a/docs/source/recipes/Streaming-ASR/librispeech/lstm_pruned_stateless_transducer.rst +++ b/docs/source/recipes/Streaming-ASR/librispeech/lstm_pruned_stateless_transducer.rst @@ -542,7 +542,7 @@ First, let us install a modified version of ``ncnn``: -DNCNN_PYTHON=ON \ -DNCNN_BUILD_BENCHMARK=OFF \ -DNCNN_BUILD_EXAMPLES=OFF \ - -DNCNN_BUILD_TOOLS=OFF \ + -DNCNN_BUILD_TOOLS=ON \ .. make -j4 @@ -553,7 +553,7 @@ First, let us install a modified version of ``ncnn``: export PYTHONPATH=$PWD/python:$PYTHONPATH export PATH=$PWD/tools/pnnx/build/src:$PATH - export PATH=$PWD/build/tools/quantize:$PATH + export PATH=$PWD/build-wheel/tools/quantize:$PATH # now build pnnx cd tools/pnnx From fcffa593f011bd3213af5af044eb3ce2ede666c1 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 10 Jan 2023 15:38:33 +0800 Subject: [PATCH 076/174] Add FAQs to doc (#827) * Add FAQs * small fixes --- docs/source/faqs.rst | 67 +++++++++++++++++++++++++++++++++++++++++++ docs/source/index.rst | 1 + 2 files changed, 68 insertions(+) create mode 100644 docs/source/faqs.rst diff --git a/docs/source/faqs.rst b/docs/source/faqs.rst new file mode 100644 index 000000000..c70ded431 --- /dev/null +++ b/docs/source/faqs.rst @@ -0,0 +1,67 @@ +Frequently Asked Questions (FAQs) +================================= + +In this section, we collect issues reported by users and post the corresponding +solutions. + + +OSError: libtorch_hip.so: cannot open shared object file: no such file or directory +----------------------------------------------------------------------------------- + +One user is using the following code to install ``torch`` and ``torchaudio``: + +.. code-block:: bash + + pip install \ + torch==1.10.0+cu111 \ + torchvision==0.11.0+cu111 \ + torchaudio==0.10.0 \ + -f https://download.pytorch.org/whl/torch_stable.html + +and it throws the following error when running ``tdnn/train.py``: + +.. code-block:: + + OSError: libtorch_hip.so: cannot open shared object file: no such file or directory + +The fix is to specify the CUDA version while installing ``torchaudio``. That +is, change ``torchaudio==0.10.0`` to ``torchaudio==0.10.0+cu11```. Therefore, +the correct command is: + +.. code-block:: bash + + pip install \ + torch==1.10.0+cu111 \ + torchvision==0.11.0+cu111 \ + torchaudio==0.10.0+cu111 \ + -f https://download.pytorch.org/whl/torch_stable.html + +AttributeError: module 'distutils' has no attribute 'version' +------------------------------------------------------------- + +The error log is: + +.. code-block:: + + Traceback (most recent call last): + File "./tdnn/train.py", line 14, in + from asr_datamodule import YesNoAsrDataModule + File "/home/xxx/code/next-gen-kaldi/icefall/egs/yesno/ASR/tdnn/asr_datamodule.py", line 34, in + from icefall.dataset.datamodule import DataModule + File "/home/xxx/code/next-gen-kaldi/icefall/icefall/__init__.py", line 3, in + from . import ( + File "/home/xxx/code/next-gen-kaldi/icefall/icefall/decode.py", line 23, in + from icefall.utils import add_eos, add_sos, get_texts + File "/home/xxx/code/next-gen-kaldi/icefall/icefall/utils.py", line 39, in + from torch.utils.tensorboard import SummaryWriter + File "/home/xxx/tool/miniconda3/envs/yyy/lib/python3.8/site-packages/torch/utils/tensorboard/__init__.py", line 4, in + LooseVersion = distutils.version.LooseVersion + AttributeError: module 'distutils' has no attribute 'version' + +The fix is: + +.. code-block:: bash + + pip uninstall setuptools + + pip install setuptools==58.0.4 diff --git a/docs/source/index.rst b/docs/source/index.rst index 4ea446259..8d76eb68b 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -21,6 +21,7 @@ speech recognition recipes using `k2 `_. :caption: Contents: installation/index + faqs model-export/index .. toctree:: From c05f5d76df6e9cc208b99308a8e426e54e9be69e Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 10 Jan 2023 20:52:13 +0800 Subject: [PATCH 077/174] fix decoding for ncnn (#828) --- .../streaming-ncnn-decode.py | 8 +++++--- .../ASR/lstm_transducer_stateless2/ncnn-decode.py | 8 +++++--- .../lstm_transducer_stateless2/streaming-ncnn-decode.py | 8 +++++--- 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming-ncnn-decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming-ncnn-decode.py index b21fe5c7e..e4104a5bb 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming-ncnn-decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming-ncnn-decode.py @@ -131,6 +131,8 @@ class Model: encoder_net = ncnn.Net() encoder_net.opt.use_packing_layout = False encoder_net.opt.use_fp16_storage = False + encoder_net.opt.num_threads = 4 + encoder_param = args.encoder_param_filename encoder_model = args.encoder_bin_filename @@ -144,6 +146,7 @@ class Model: decoder_model = args.decoder_bin_filename decoder_net = ncnn.Net() + decoder_net.opt.num_threads = 4 decoder_net.load_param(decoder_param) decoder_net.load_model(decoder_model) @@ -154,6 +157,8 @@ class Model: joiner_param = args.joiner_param_filename joiner_model = args.joiner_bin_filename joiner_net = ncnn.Net() + joiner_net.opt.num_threads = 4 + joiner_net.load_param(joiner_param) joiner_net.load_model(joiner_model) @@ -176,7 +181,6 @@ class Model: - next_states, a list of tensors containing the next states """ with self.encoder_net.create_extractor() as ex: - ex.set_num_threads(4) ex.input("in0", ncnn.Mat(x.numpy()).clone()) # layer0 in2-in5 @@ -220,7 +224,6 @@ class Model: assert decoder_input.dtype == torch.int32 with self.decoder_net.create_extractor() as ex: - ex.set_num_threads(4) ex.input("in0", ncnn.Mat(decoder_input.numpy()).clone()) ret, ncnn_out0 = ex.extract("out0") assert ret == 0, ret @@ -229,7 +232,6 @@ class Model: def run_joiner(self, encoder_out, decoder_out): with self.joiner_net.create_extractor() as ex: - ex.set_num_threads(4) ex.input("in0", ncnn.Mat(encoder_out.numpy()).clone()) ex.input("in1", ncnn.Mat(decoder_out.numpy()).clone()) ret, ncnn_out0 = ex.extract("out0") diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py index 3b471fa85..3bd1b0a09 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py @@ -104,6 +104,8 @@ class Model: encoder_net = ncnn.Net() encoder_net.opt.use_packing_layout = False encoder_net.opt.use_fp16_storage = False + encoder_net.opt.num_threads = 4 + encoder_param = args.encoder_param_filename encoder_model = args.encoder_bin_filename @@ -118,6 +120,7 @@ class Model: decoder_net = ncnn.Net() decoder_net.opt.use_packing_layout = False + decoder_net.opt.num_threads = 4 decoder_net.load_param(decoder_param) decoder_net.load_model(decoder_model) @@ -129,6 +132,8 @@ class Model: joiner_model = args.joiner_bin_filename joiner_net = ncnn.Net() joiner_net.opt.use_packing_layout = False + joiner_net.opt.num_threads = 4 + joiner_net.load_param(joiner_param) joiner_net.load_model(joiner_model) @@ -136,7 +141,6 @@ class Model: def run_encoder(self, x, states): with self.encoder_net.create_extractor() as ex: - ex.set_num_threads(10) ex.input("in0", ncnn.Mat(x.numpy()).clone()) x_lens = torch.tensor([x.size(0)], dtype=torch.float32) ex.input("in1", ncnn.Mat(x_lens.numpy()).clone()) @@ -165,7 +169,6 @@ class Model: assert decoder_input.dtype == torch.int32 with self.decoder_net.create_extractor() as ex: - ex.set_num_threads(10) ex.input("in0", ncnn.Mat(decoder_input.numpy()).clone()) ret, ncnn_out0 = ex.extract("out0") assert ret == 0, ret @@ -174,7 +177,6 @@ class Model: def run_joiner(self, encoder_out, decoder_out): with self.joiner_net.create_extractor() as ex: - ex.set_num_threads(10) ex.input("in0", ncnn.Mat(encoder_out.numpy()).clone()) ex.input("in1", ncnn.Mat(decoder_out.numpy()).clone()) ret, ncnn_out0 = ex.extract("out0") diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py index baff15ea6..02ed16a8c 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py @@ -92,6 +92,8 @@ class Model: encoder_net = ncnn.Net() encoder_net.opt.use_packing_layout = False encoder_net.opt.use_fp16_storage = False + encoder_net.opt.num_threads = 4 + encoder_param = args.encoder_param_filename encoder_model = args.encoder_bin_filename @@ -106,6 +108,7 @@ class Model: decoder_net = ncnn.Net() decoder_net.opt.use_packing_layout = False + decoder_net.opt.num_threads = 4 decoder_net.load_param(decoder_param) decoder_net.load_model(decoder_model) @@ -117,6 +120,8 @@ class Model: joiner_model = args.joiner_bin_filename joiner_net = ncnn.Net() joiner_net.opt.use_packing_layout = False + joiner_net.opt.num_threads = 4 + joiner_net.load_param(joiner_param) joiner_net.load_model(joiner_model) @@ -124,7 +129,6 @@ class Model: def run_encoder(self, x, states): with self.encoder_net.create_extractor() as ex: - # ex.set_num_threads(10) ex.input("in0", ncnn.Mat(x.numpy()).clone()) x_lens = torch.tensor([x.size(0)], dtype=torch.float32) ex.input("in1", ncnn.Mat(x_lens.numpy()).clone()) @@ -153,7 +157,6 @@ class Model: assert decoder_input.dtype == torch.int32 with self.decoder_net.create_extractor() as ex: - # ex.set_num_threads(10) ex.input("in0", ncnn.Mat(decoder_input.numpy()).clone()) ret, ncnn_out0 = ex.extract("out0") assert ret == 0, ret @@ -162,7 +165,6 @@ class Model: def run_joiner(self, encoder_out, decoder_out): with self.joiner_net.create_extractor() as ex: - # ex.set_num_threads(10) ex.input("in0", ncnn.Mat(encoder_out.numpy()).clone()) ex.input("in1", ncnn.Mat(decoder_out.numpy()).clone()) ret, ncnn_out0 = ex.extract("out0") From 8582b6e41acbd1258633492212c06589d1370960 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 11 Jan 2023 15:34:30 +0800 Subject: [PATCH 078/174] Add doc about converting conv-emformer to sherpa-ncnn (#830) --- docs/source/conf.py | 6 + ...nv-emformer-transducer-for-ncnn-output.txt | 21 + ...-decode-conv-emformer-transducer-libri.txt | 7 + docs/source/model-export/export-ncnn.rst | 492 +++++++++++++++++- .../lstm_pruned_stateless_transducer.rst | 9 +- 5 files changed, 526 insertions(+), 9 deletions(-) create mode 100644 docs/source/model-export/code/export-conv-emformer-transducer-for-ncnn-output.txt create mode 100644 docs/source/model-export/code/test-stremaing-ncnn-decode-conv-emformer-transducer-libri.txt diff --git a/docs/source/conf.py b/docs/source/conf.py index 221d9d734..33429f74c 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -78,3 +78,9 @@ html_context = { } todo_include_todos = True + +rst_epilog = """ +.. _sherpa-ncnn: https://github.com/k2-fsa/sherpa-ncnn +.. _git-lfs: https://git-lfs.com/ +.. _ncnn: https://github.com/tencent/ncnn +""" diff --git a/docs/source/model-export/code/export-conv-emformer-transducer-for-ncnn-output.txt b/docs/source/model-export/code/export-conv-emformer-transducer-for-ncnn-output.txt new file mode 100644 index 000000000..ecbdd4b31 --- /dev/null +++ b/docs/source/model-export/code/export-conv-emformer-transducer-for-ncnn-output.txt @@ -0,0 +1,21 @@ +2023-01-11 12:15:38,677 INFO [export-for-ncnn.py:220] device: cpu +2023-01-11 12:15:38,681 INFO [export-for-ncnn.py:229] {'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, 'best_v +alid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 50, 'reset_interval': 200, 'valid_interval': 3000, 'feature_dim': 80, 'subsampl +ing_factor': 4, 'decoder_dim': 512, 'joiner_dim': 512, 'model_warm_step': 3000, 'env_info': {'k2-version': '1.23.2', 'k2-build-type': +'Release', 'k2-with-cuda': True, 'k2-git-sha1': 'a34171ed85605b0926eebbd0463d059431f4f74a', 'k2-git-date': 'Wed Dec 14 00:06:38 2022', + 'lhotse-version': '1.12.0.dev+missing.version.file', 'torch-version': '1.10.0+cu102', 'torch-cuda-available': False, 'torch-cuda-vers +ion': '10.2', 'python-version': '3.8', 'icefall-git-branch': 'fix-stateless3-train-2022-12-27', 'icefall-git-sha1': '530e8a1-dirty', ' +icefall-git-date': 'Tue Dec 27 13:59:18 2022', 'icefall-path': '/star-fj/fangjun/open-source/icefall', 'k2-path': '/star-fj/fangjun/op +en-source/k2/k2/python/k2/__init__.py', 'lhotse-path': '/star-fj/fangjun/open-source/lhotse/lhotse/__init__.py', 'hostname': 'de-74279 +-k2-train-3-1220120619-7695ff496b-s9n4w', 'IP address': '127.0.0.1'}, 'epoch': 30, 'iter': 0, 'avg': 1, 'exp_dir': PosixPath('icefa +ll-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp'), 'bpe_model': './icefall-asr-librispeech-conv-emformer-transdu +cer-stateless2-2022-07-05//data/lang_bpe_500/bpe.model', 'jit': False, 'context_size': 2, 'use_averaged_model': False, 'encoder_dim': +512, 'nhead': 8, 'dim_feedforward': 2048, 'num_encoder_layers': 12, 'cnn_module_kernel': 31, 'left_context_length': 32, 'chunk_length' +: 32, 'right_context_length': 8, 'memory_size': 32, 'blank_id': 0, 'vocab_size': 500} +2023-01-11 12:15:38,681 INFO [export-for-ncnn.py:231] About to create model +2023-01-11 12:15:40,053 INFO [checkpoint.py:112] Loading checkpoint from icefall-asr-librispeech-conv-emformer-transducer-stateless2-2 +022-07-05/exp/epoch-30.pt +2023-01-11 12:15:40,708 INFO [export-for-ncnn.py:315] Number of model parameters: 75490012 +2023-01-11 12:15:41,681 INFO [export-for-ncnn.py:318] Using torch.jit.trace() +2023-01-11 12:15:41,681 INFO [export-for-ncnn.py:320] Exporting encoder +2023-01-11 12:15:41,682 INFO [export-for-ncnn.py:149] chunk_length: 32, right_context_length: 8 diff --git a/docs/source/model-export/code/test-stremaing-ncnn-decode-conv-emformer-transducer-libri.txt b/docs/source/model-export/code/test-stremaing-ncnn-decode-conv-emformer-transducer-libri.txt new file mode 100644 index 000000000..114fe7342 --- /dev/null +++ b/docs/source/model-export/code/test-stremaing-ncnn-decode-conv-emformer-transducer-libri.txt @@ -0,0 +1,7 @@ +2023-01-11 14:02:12,216 INFO [streaming-ncnn-decode.py:320] {'tokens': './icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/data/lang_bpe_500/tokens.txt', 'encoder_param_filename': './icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.ncnn.param', 'encoder_bin_filename': './icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.ncnn.bin', 'decoder_param_filename': './icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.ncnn.param', 'decoder_bin_filename': './icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.ncnn.bin', 'joiner_param_filename': './icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.ncnn.param', 'joiner_bin_filename': './icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.ncnn.bin', 'sound_filename': './icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/test_wavs/1089-134686-0001.wav'} +T 51 32 +2023-01-11 14:02:13,141 INFO [streaming-ncnn-decode.py:328] Constructing Fbank computer +2023-01-11 14:02:13,151 INFO [streaming-ncnn-decode.py:331] Reading sound files: ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/test_wavs/1089-134686-0001.wav +2023-01-11 14:02:13,176 INFO [streaming-ncnn-decode.py:336] torch.Size([106000]) +2023-01-11 14:02:17,581 INFO [streaming-ncnn-decode.py:380] ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/test_wavs/1089-134686-0001.wav +2023-01-11 14:02:17,581 INFO [streaming-ncnn-decode.py:381] AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS diff --git a/docs/source/model-export/export-ncnn.rst b/docs/source/model-export/export-ncnn.rst index 3dbb8b514..11471d611 100644 --- a/docs/source/model-export/export-ncnn.rst +++ b/docs/source/model-export/export-ncnn.rst @@ -1,12 +1,492 @@ Export to ncnn ============== -We support exporting LSTM transducer models to `ncnn `_. - -Please refer to :ref:`export-model-for-ncnn` for details. +We support exporting both +`LSTM transducer models `_ +and +`ConvEmformer transducer models `_ +to `ncnn `_. We also provide ``_ performing speech recognition using ``ncnn`` with exported models. -It has been tested on Linux, macOS, Windows, and Raspberry Pi. The project is -self-contained and can be statically linked to produce a binary containing -everything needed. +It has been tested on Linux, macOS, Windows, ``Android``, and ``Raspberry Pi``. + +`sherpa-ncnn`_ is self-contained and can be statically linked to produce +a binary containing everything needed. Please refer +to its documentation for details: + + - ``_ + + +Export LSTM transducer models +----------------------------- + +Please refer to :ref:`export-lstm-transducer-model-for-ncnn` for details. + + + +Export ConvEmformer transducer models +------------------------------------- + +We use the pre-trained model from the following repository as an example: + + - ``_ + +We will show you step by step how to export it to `ncnn`_ and run it with `sherpa-ncnn`_. + +.. hint:: + + We use ``Ubuntu 18.04``, ``torch 1.10``, and ``Python 3.8`` for testing. + +.. caution:: + + Please use a more recent version of PyTorch. For instance, ``torch 1.8`` + may ``not`` work. + +1. Download the pre-trained model +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. hint:: + + You can also refer to ``_ to download the pre-trained model. + + You have to install `git-lfs`_ before you continue. + +.. code-block:: bash + + cd egs/librispeech/ASR + + GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05 + cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05 + + git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt" + git lfs pull --include "data/lang_bpe_500/bpe.model" + + cd .. + +.. note:: + + We download ``exp/pretrained-xxx.pt``, not ``exp/cpu-jit_xxx.pt``. + + +In the above code, we download the pre-trained model into the directory +``egs/librispeech/ASR/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05``. + +2. Install ncnn and pnnx +^^^^^^^^^^^^^^^^^^^^^^^^ + +.. code-block:: bash + + # We put ncnn into $HOME/open-source/ncnn + # You can change it to anywhere you like + + cd $HOME + mkdir -p open-source + cd open-source + + git clone https://github.com/csukuangfj/ncnn + cd ncnn + git submodule update --recursive --init + + # Note: We don't use "python setup.py install" or "pip install ." here + + mkdir -p build-wheel + cd build-wheel + + cmake \ + -DCMAKE_BUILD_TYPE=Release \ + -DNCNN_PYTHON=ON \ + -DNCNN_BUILD_BENCHMARK=OFF \ + -DNCNN_BUILD_EXAMPLES=OFF \ + -DNCNN_BUILD_TOOLS=ON \ + .. + + make -j4 + + cd .. + + # Note: $PWD here is $HOME/open-source/ncnn + + export PYTHONPATH=$PWD/python:$PYTHONPATH + export PATH=$PWD/tools/pnnx/build/src:$PATH + export PATH=$PWD/build-wheel/tools/quantize:$PATH + + # Now build pnnx + cd tools/pnnx + mkdir build + cd build + cmake .. + make -j4 + + ./src/pnnx + +Congratulations! You have successfully installed the following components: + + - ``pnxx``, which is an executable located in + ``$HOME/open-source/ncnn/tools/pnnx/build/src``. We will use + it to convert models exported by ``torch.jit.trace()``. + - ``ncnn2int8``, which is an executable located in + ``$HOME/open-source/ncnn/build-wheel/tools/quantize``. We will use + it to quantize our models to ``int8``. + - ``ncnn.cpython-38-x86_64-linux-gnu.so``, which is a Python module located + in ``$HOME/open-source/ncnn/python/ncnn``. + + .. note:: + + I am using ``Python 3.8``, so it + is ``ncnn.cpython-38-x86_64-linux-gnu.so``. If you use a different + version, say, ``Python 3.9``, the name would be + ``ncnn.cpython-39-x86_64-linux-gnu.so``. + + Also, if you are not using Linux, the file name would also be different. + But that does not matter. As long as you can compile it, it should work. + +We have set up ``PYTHONPATH`` so that you can use ``import ncnn`` in your +Python code. We have also set up ``PATH`` so that you can use +``pnnx`` and ``ncnn2int8`` later in your terminal. + +.. caution:: + + Please don't use ``_. + We have made some modifications to the offical `ncnn`_. + + We will synchronize ``_ periodically + with the official one. + +3. Export the model via torch.jit.trace() +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +First, let us rename our pre-trained model: + +.. code-block:: + + cd egs/librispeech/ASR + + cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp + + ln -s pretrained-epoch-30-avg-10-averaged.pt epoch-30.pt + + cd ../.. + +Next, we use the following code to export our model: + +.. code-block:: bash + + dir=./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/ + + ./conv_emformer_transducer_stateless2/export-for-ncnn.py \ + --exp-dir $dir/exp \ + --bpe-model $dir/data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 1 \ + --use-averaged-model 0 \ + \ + --num-encoder-layers 12 \ + --chunk-length 32 \ + --cnn-module-kernel 31 \ + --left-context-length 32 \ + --right-context-length 8 \ + --memory-size 32 \ + --encoder-dim 512 + +.. hint:: + + We have renamed our model to ``epoch-30.pt`` so that we can use ``--epoch 30``. + There is only one pre-trained model, so we use ``--avg 1 --use-averaged-model 0``. + + If you have trained a model by yourself and if you have all checkpoints + available, please first use ``decode.py`` to tune ``--epoch --avg`` + and select the best combination with with ``--use-averaged-model 1``. + +.. note:: + + You will see the following log output: + + .. literalinclude:: ./code/export-conv-emformer-transducer-for-ncnn-output.txt + + The log shows the model has ``75490012`` number of parameters, i.e., ``~75 M``. + + .. code-block:: + + ls -lh icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/pretrained-epoch-30-avg-10-averaged.pt + + -rw-r--r-- 1 kuangfangjun root 289M Jan 11 12:05 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/pretrained-epoch-30-avg-10-averaged.pt + + You can see that the file size of the pre-trained model is ``289 MB``, which + is roughly ``4 x 75 M``. + +After running ``conv_emformer_transducer_stateless2/export-for-ncnn.py``, +we will get the following files: + +.. code-block:: bash + + ls -lh icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/*pnnx* + + -rw-r--r-- 1 kuangfangjun root 1010K Jan 11 12:15 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.pt + -rw-r--r-- 1 kuangfangjun root 283M Jan 11 12:15 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.pt + -rw-r--r-- 1 kuangfangjun root 3.0M Jan 11 12:15 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.pt + + +.. _conv-emformer-step-3-export-torchscript-model-via-pnnx: + +3. Export torchscript model via pnnx +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. hint:: + + Make sure you have set up the ``PATH`` environment variable. Otherwise, + it will throw an error saying that ``pnnx`` could not be found. + +Now, it's time to export our models to `ncnn`_ via ``pnnx``. + +.. code-block:: + + cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/ + + pnnx ./encoder_jit_trace-pnnx.pt + pnnx ./decoder_jit_trace-pnnx.pt + pnnx ./joiner_jit_trace-pnnx.pt + +It will generate the following files: + +.. code-block:: bash + + ls -lh icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/*ncnn*{bin,param} + + -rw-r--r-- 1 kuangfangjun root 503K Jan 11 12:38 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.ncnn.bin + -rw-r--r-- 1 kuangfangjun root 437 Jan 11 12:38 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.ncnn.param + -rw-r--r-- 1 kuangfangjun root 142M Jan 11 12:36 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.ncnn.bin + -rw-r--r-- 1 kuangfangjun root 79K Jan 11 12:36 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.ncnn.param + -rw-r--r-- 1 kuangfangjun root 1.5M Jan 11 12:38 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.ncnn.bin + -rw-r--r-- 1 kuangfangjun root 488 Jan 11 12:38 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.ncnn.param + +There are two types of files: + +- ``param``: It is a text file containing the model architectures. You can + use a text editor to view its content. +- ``bin``: It is a binary file containing the model parameters. + +We compare the file sizes of the models below before and after converting via ``pnnx``: + +.. see https://tableconvert.com/restructuredtext-generator + ++----------------------------------+------------+ +| File name | File size | ++==================================+============+ +| encoder_jit_trace-pnnx.pt | 283 MB | ++----------------------------------+------------+ +| decoder_jit_trace-pnnx.pt | 1010 KB | ++----------------------------------+------------+ +| joiner_jit_trace-pnnx.pt | 3.0 MB | ++----------------------------------+------------+ +| encoder_jit_trace-pnnx.ncnn.bin | 142 MB | ++----------------------------------+------------+ +| decoder_jit_trace-pnnx.ncnn.bin | 503 KB | ++----------------------------------+------------+ +| joiner_jit_trace-pnnx.ncnn.bin | 1.5 MB | ++----------------------------------+------------+ + +You can see that the file size of the models after converting is about one half +of the models before converting: + + - encoder: 283 MB vs 142 MB + - decoder: 1010 KB vs 503 KB + - joiner: 3.0 MB vs 1.5 MB + +The reason is that by default ``pnnx`` converts ``float32`` parameters +to ``float16``. A ``float32`` parameter occupies 4 bytes, while it is 2 bytes +for ``float16``. Thus, it is ``twice smaller`` after conversion. + +.. hint:: + + If you use ``pnnx ./encoder_jit_trace-pnnx.pt fp16=0``, then ``pnnx`` + won't convert ``float32`` to ``float16``. + +4. Test the exported models in icefall +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. note:: + + We assume you have set up the environment variable ``PYTHONPATH`` when + building `ncnn`_. + +Now we have successfully converted our pre-trained model to `ncnn`_ format. +The generated 6 files are what we need. You can use the following code to +test the converted models: + +.. code-block:: bash + + ./conv_emformer_transducer_stateless2/streaming-ncnn-decode.py \ + --tokens ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/data/lang_bpe_500/tokens.txt \ + --encoder-param-filename ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.ncnn.param \ + --encoder-bin-filename ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.ncnn.bin \ + --decoder-param-filename ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.ncnn.param \ + --decoder-bin-filename ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.ncnn.bin \ + --joiner-param-filename ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.ncnn.param \ + --joiner-bin-filename ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.ncnn.bin \ + ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/test_wavs/1089-134686-0001.wav + +.. hint:: + + `ncnn`_ supports only ``batch size == 1``, so ``streaming-ncnn-decode.py`` accepts + only 1 wave file as input. + +The output is given below: + +.. literalinclude:: ./code/test-stremaing-ncnn-decode-conv-emformer-transducer-libri.txt + +Congratulations! You have successfully exported a model from PyTorch to `ncnn`_! + + +5. Modify the exported encoder for sherpa-ncnn +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +In order to use the exported models in `sherpa-ncnn`_, we have to modify +``encoder_jit_trace-pnnx.ncnn.param``. + +Let us have a look at the first few lines of ``encoder_jit_trace-pnnx.ncnn.param``: + +.. code-block:: + + 7767517 + 1060 1342 + Input in0 0 1 in0 + +**Explanation** of the above three lines: + + 1. ``7767517``, it is a magic number and should not be changed. + 2. ``1060 1342``, the first number ``1060`` specifies the number of layers + in this file, while ``1342`` specifies the number intermediate outputs of + this file + 3. ``Input in0 0 1 in0``, ``Input`` is the layer type of this layer; ``in0`` + is the layer name of this layer; ``0`` means this layer has no input; + ``1`` means this layer has one output. ``in0`` is the output name of + this layer. + +We need to add 1 extra line and the result looks like below: + +.. code-block:: bash + + 7767517 + 1061 1342 + SherpaMetaData sherpa_meta_data1 0 0 0=1 1=12 2=32 3=31 4=8 5=32 6=8 7=512 + Input in0 0 1 in0 + +**Explanation** + + 1. ``7767517``, it is still the same + 2. ``1061 1342``, we have added an extra layer, so we need to update ``1060`` to ``1061``. + We don't need to change ``1342`` since the newly added layer has no inputs and outputs. + 3. ``SherpaMetaData sherpa_meta_data1 0 0 0=1 1=12 2=32 3=31 4=8 5=32 6=8 7=512`` + This line is newly added. Its explanation is given below: + + - ``SherpaMetaData`` is the type of this layer. Must be ``SherpaMetaData``. + - ``sherpa_meta_data1`` is the name of this layer. Must be ``sherpa_meta_data1``. + - ``0 0`` means this layer has no inputs and output. Must be ``0 0`` + - ``0=1``, 0 is the key and 1 is the value. MUST be ``0=1`` + - ``1=12``, 1 is the key and 12 is the value of the + parameter ``--num-encoder-layers`` that you provided when running + ``conv_emformer_transducer_stateless2/export-for-ncnn.py``. + - ``2=32``, 2 is the key and 32 is the value of the + parameter ``--memory-size`` that you provided when running + ``conv_emformer_transducer_stateless2/export-for-ncnn.py``. + - ``3=31``, 3 is the key and 31 is the value of the + parameter ``--cnn-module-kernel`` that you provided when running + ``conv_emformer_transducer_stateless2/export-for-ncnn.py``. + - ``4=8``, 4 is the key and 8 is the value of the + parameter ``--left-context-length`` that you provided when running + ``conv_emformer_transducer_stateless2/export-for-ncnn.py``. + - ``5=32``, 5 is the key and 32 is the value of the + parameter ``--chunk-length`` that you provided when running + ``conv_emformer_transducer_stateless2/export-for-ncnn.py``. + - ``6=8``, 6 is the key and 8 is the value of the + parameter ``--right-context-length`` that you provided when running + ``conv_emformer_transducer_stateless2/export-for-ncnn.py``. + - ``7=512``, 7 is the key and 512 is the value of the + parameter ``--encoder-dim`` that you provided when running + ``conv_emformer_transducer_stateless2/export-for-ncnn.py``. + + For ease of reference, we list the key-value pairs that you need to add + in the following table. If your model has a different setting, please + change the values for ``SherpaMetaData`` accordingly. Otherwise, you + will be ``SAD``. + + +------+-----------------------------+ + | key | value | + +======+=============================+ + | 0 | 1 (fixed) | + +------+-----------------------------+ + | 1 | ``--num-encoder-layers`` | + +------+-----------------------------+ + | 2 | ``--memory-size`` | + +------+-----------------------------+ + | 3 | ``--cnn-module-kernel`` | + +------+-----------------------------+ + | 4 | ``--left-context-length`` | + +------+-----------------------------+ + | 5 | ``--chunk-length`` | + +------+-----------------------------+ + | 6 | ``--right-context-length`` | + +------+-----------------------------+ + | 7 | ``--encoder-dim`` | + +------+-----------------------------+ + + 4. ``Input in0 0 1 in0``. No need to change it. + +.. caution:: + + When you add a new layer ``SherpaMetaData``, please remember to update the + number of layers. In our case, update ``1060`` to ``1061``. Otherwise, + you will be SAD later. + +.. hint:: + + After adding the new layer ``SherpaMetaData``, you cannot use this model + with ``streaming-ncnn-decode.py`` anymore since ``SherpaMetaData`` is + supported only in `sherpa-ncnn`_. + +.. hint:: + + `ncnn`_ is very flexible. You can add new layers to it just by text-editing + the ``param`` file! You don't need to change the ``bin`` file. + +Now you can use this model in `sherpa-ncnn`_. +Please refer to the following documentation: + + - Linux/macOS/Windows/arm/aarch64: ``_ + - Android: ``_ + - Python: ``_ + +We have a list of pre-trained models that have been exported for `sherpa-ncnn`_: + + - ``_ + + You can find more usages there. + +6. (Optional) int8 quantization with sherpa-ncnn +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +This step is optional. + +In this step, we describe how to quantize our model with ``int8``. + +Change :ref:`conv-emformer-step-3-export-torchscript-model-via-pnnx` to +disable ``fp16`` when using ``pnnx``: + +.. code-block:: + + cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/ + + pnnx ./encoder_jit_trace-pnnx.pt fp16=0 + pnnx ./decoder_jit_trace-pnnx.pt + pnnx ./joiner_jit_trace-pnnx.pt fp16=0 + +.. note:: + + We add ``fp16=0`` when exporting the encoder and joiner. ``ncnn`` does not + support quantizing the decoder model yet. We will update this documentation + once ``ncnn`` supports it. (Maybe in this year, 2023). + +TODO(fangjun): Finish it. + +Have fun with `sherpa-ncnn`_! diff --git a/docs/source/recipes/Streaming-ASR/librispeech/lstm_pruned_stateless_transducer.rst b/docs/source/recipes/Streaming-ASR/librispeech/lstm_pruned_stateless_transducer.rst index 22addd1d2..ce8ba1453 100644 --- a/docs/source/recipes/Streaming-ASR/librispeech/lstm_pruned_stateless_transducer.rst +++ b/docs/source/recipes/Streaming-ASR/librispeech/lstm_pruned_stateless_transducer.rst @@ -515,10 +515,10 @@ To use the generated files with ``./lstm_transducer_stateless2/jit_pretrained``: Please see ``_ for how to use the exported models in ``sherpa``. -.. _export-model-for-ncnn: +.. _export-lstm-transducer-model-for-ncnn: -Export model for ncnn -~~~~~~~~~~~~~~~~~~~~~ +Export LSTM transducer models for ncnn +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ We support exporting pretrained LSTM transducer models to `ncnn `_ using @@ -657,3 +657,6 @@ by visiting the following links: You can find more usages of the pretrained models in ``_ + +Export ConvEmformer transducer models for ncnn +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ From 142420b3afa7b07c95f733c2e72ee80078364a44 Mon Sep 17 00:00:00 2001 From: marcoyang1998 <45973641+marcoyang1998@users.noreply.github.com> Date: Wed, 11 Jan 2023 16:45:24 +0800 Subject: [PATCH 079/174] Add docs for distillation (#812) * add README to docs * update documents for distillation * upload png files --- .../librispeech/distillation.rst | 220 ++++++++++++++++++ .../images/distillation_codebook.png | Bin 0 -> 57170 bytes .../images/distillation_directory.png | Bin 0 -> 43816 bytes .../Non-streaming-ASR/librispeech/index.rst | 1 + .../ASR/distillation_with_hubert.sh | 6 +- 5 files changed, 224 insertions(+), 3 deletions(-) create mode 100644 docs/source/recipes/Non-streaming-ASR/librispeech/distillation.rst create mode 100644 docs/source/recipes/Non-streaming-ASR/librispeech/images/distillation_codebook.png create mode 100644 docs/source/recipes/Non-streaming-ASR/librispeech/images/distillation_directory.png diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/distillation.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/distillation.rst new file mode 100644 index 000000000..aa379c3f8 --- /dev/null +++ b/docs/source/recipes/Non-streaming-ASR/librispeech/distillation.rst @@ -0,0 +1,220 @@ +Distillation with HuBERT +======================== + +This totorial shows you how to perform knowledge distillation in ``icefall`` +with the `LibriSpeech `_ dataset. The distillation method +used here is called "Multi Vector Quantization Knowledge Distillation" (MVQ-KD). +Please have a look at our paper `Predicting Multi-Codebook Vector Quantization Indexes for Knowledge Distillation `_ +for more details about MVQ-KD. + +.. note:: + + This tutorial is based on recipe + `pruned_transducer_stateless4 `_. + Currently, we only implement MVQ-KD in this recipe. However, MVQ-KD is theoretically applicable to all recipes + with only minor changes needed. Feel free to try out MVQ-KD in different recipes. If you + encounter any problems, please open an issue here `icefall `_. + +.. note:: + + We assume you have read the page :ref:`install icefall` and have setup + the environment for ``icefall``. + +.. HINT:: + + We recommend you to use a GPU or several GPUs to run this recipe. + +Data preparation +---------------- + +We first prepare necessary training data for ``LibriSpeech``. +This is the same as in `Pruned_transducer_statelessX <./pruned_transducer_stateless.rst>`_. + +.. hint:: + + The data preparation is the same as other recipes on LibriSpeech dataset, + if you have finished this step, you can skip to ``Codebook index preparation`` directly. + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./prepare.sh + +The script ``./prepare.sh`` handles the data preparation for you, **automagically**. +All you need to do is to run it. + +The data preparation contains several stages, you can use the following two +options: + + - ``--stage`` + - ``--stop-stage`` + +to control which stage(s) should be run. By default, all stages are executed. + +For example, + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./prepare.sh --stage 0 --stop-stage 0 # run only stage 0 + $ ./prepare.sh --stage 2 --stop-stage 5 # run from stage 2 to stage 5 + +.. HINT:: + + If you have pre-downloaded the `LibriSpeech `_ + dataset and the `musan `_ dataset, say, + they are saved in ``/tmp/LibriSpeech`` and ``/tmp/musan``, you can modify + the ``dl_dir`` variable in ``./prepare.sh`` to point to ``/tmp`` so that + ``./prepare.sh`` won't re-download them. + +.. NOTE:: + + All generated files by ``./prepare.sh``, e.g., features, lexicon, etc, + are saved in ``./data`` directory. + +We provide the following YouTube video showing how to run ``./prepare.sh``. + +.. note:: + + To get the latest news of `next-gen Kaldi `_, please subscribe + the following YouTube channel by `Nadira Povey `_: + + ``_ + +.. youtube:: ofEIoJL-mGM + + +Codebook index preparation +-------------------------- + +Here, we prepare necessary data for MVQ-KD. This requires the generation +of codebook indexes (please read our `paper `_. +if you are interested in details). In this tutorial, we use the pre-computed +codebook indexes for convenience. The only thing you need to do is to +run ``./distillation_with_hubert.sh``. + +.. note:: + There are 5 stages in total, the first and second stage will be automatically skipped + when choosing to downloaded codebook indexes prepared by `icefall`_. + Of course, you can extract and compute the codebook indexes by yourself. This + will require you downloading a HuBERT-XL model and it can take a while for + the extraction of codebook indexes. + + +As usual, you can control the stages you want to run by specifying the following +two options: + + - ``--stage`` + - ``--stop-stage`` + +For example, + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./distillation_with_hubert.sh --stage 0 --stop-stage 0 # run only stage 0 + $ ./distillation_with_hubert.sh --stage 2 --stop-stage 4 # run from stage 2 to stage 5 + +Here are a few options in ``./distillation_with_hubert.sh`` +you need to know before you proceed. + +- ``--full_libri`` If True, use full 960h data. Otherwise only ``train-clean-100`` will be used +- ``--use_extracted_codebook`` If True, the first two stages will be skipped and the codebook + indexes uploaded by us will be downloaded. + +Since we are using the pre-computed codebook indexes, we set +``use_extracted_codebook=True``. If you want to do full `LibriSpeech`_ +experiments, please set ``full_libri=True``. + +The following command downloads the pre-computed codebook indexes +and prepares MVQ-augmented training manifests. + +.. code-block:: bash + + $ ./distillation_with_hubert.sh --stage 2 --stop-stage 2 # run only stage 2 + +Please see the +following screenshot for the output of an example execution. + +.. figure:: ./images/distillation_codebook.png + :width: 800 + :alt: Downloading codebook indexes and preparing training manifest. + :align: center + + Downloading codebook indexes and preparing training manifest. + +.. hint:: + + The codebook indexes we prepared for you in this tutorial + are extracted from the 36-th layer of a fine-tuned HuBERT-XL model + with 8 codebooks. If you want to try other configurations, please + set ``use_extracted_codebook=False`` and set ``embedding_layer`` and + ``num_codebooks`` by yourself. + +Now, you should see the following files under the direcory ``./data/vq_fbank_layer36_cb8``. + +.. figure:: ./images/distillation_directory.png + :width: 800 + :alt: MVQ-augmented training manifests + :align: center + + MVQ-augmented training manifests. + +Whola! You are ready to perform knowledge distillation training now! + +Training +-------- + +To perform training, please run stage 3 by executing the following command. + +.. code-block:: bash + + $ ./prepare.sh --stage 3 --stop-stage 3 # run MVQ training + +Here is the code snippet for training: + +.. code-block:: bash + + WORLD_SIZE=$(echo ${CUDA_VISIBLE_DEVICES} | awk '{n=split($1, _, ","); print n}') + + ./pruned_transducer_stateless6/train.py \ + --manifest-dir ./data/vq_fbank_layer36_cb8 \ + --master-port 12359 \ + --full-libri $full_libri \ + --spec-aug-time-warp-factor -1 \ + --max-duration 300 \ + --world-size ${WORLD_SIZE} \ + --num-epochs 30 \ + --exp-dir $exp_dir \ + --enable-distillation True \ + --codebook-loss-scale 0.01 + +There are a few training arguments in the following +training commands that should be paid attention to. + - ``--enable-distillation`` If True, knowledge distillation training is enabled. + - ``--codebook-loss-scale`` The scale of the knowledge distillation loss. + - ``--manifest-dir`` The path to the MVQ-augmented manifest. + + +Decoding +-------- + +After training finished, you can test the performance on using +the following command. + +.. code-block:: bash + + export CUDA_VISIBLE_DEVICES=0 + ./pruned_transducer_stateless6/train.py \ + --decoding-method "modified_beam_search" \ + --epoch 30 \ + --avg 10 \ + --max-duration 200 \ + --exp-dir $exp_dir \ + --enable-distillation True + +You should get similar results as `here `_. + +That's all! Feel free to experiment with your own setups and report your results. +If you encounter any problems during training, please open up an issue `here `_. + diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/images/distillation_codebook.png b/docs/source/recipes/Non-streaming-ASR/librispeech/images/distillation_codebook.png new file mode 100644 index 0000000000000000000000000000000000000000..1a40d6c6ec3caf9c551b62f6c1ad6c1058823048 GIT binary patch literal 57170 zcmbrm2UJtrw?2x6!x02*5Rj%=K)^&m=_G<8DhPr|krL@D1f;hV6cMG0s1Yeq5fJGh zp(hmSB_N#uK`EiPKp>>O*v|Rg`+xU8?i+890fxWh);pE`j1gAzh04}EpeIQR%0{4Y zAuv&bz<%l?xNyMK;=lv^}?fWL54 z;8~OX8jHuK;&TCy*{PbjfQBa;Ai&B7?RqoOYpq9cROFf{rtJ_FBoqyt(qxv-nj-JL z^=7S~*Iu8e(oi=NZFw(NOse#*hEz;WyfG?4QE2YqyUTtru=ns-*8*$hFZK29s)Ggd zXzLZfM51e>A0++#Mt<*x7(G}bo;}!x&PFO?NoB0vU;h3n7{jR7Sg!ZN%)lnum35HY zV!XasebBTp6BE>sui6=sk%$-oiQ$P<>S`A7>X(S%C4_BG^Z~$9kHZJP2t)8U;&Dgj zPP@%0_p%HFFtDd3cAlb-g86+@xoKHwcr&@fVRAoBu9QHn7AyqnIzeAmra;=_z6_K( z;q1t(-1Fz~mLXTQtc++H&9!^~h-Cg4+S?ofm8%h;3clO#{c%ws6aM!SzKtxAQ(a#+ zBZ7)Jqw&_FXcwyw8@X7V%g1gmEX>_%rA$N`ZR3m2B0k=})3ASk%Nnfgqk)D;e(v^F zcSUd-*4yj+b^z(yxpLL0c%FG@vQysnltv-2?i1pgv9)33wbPysPz>SZUd13TG^41T zWe(H-ZGKPZGP7s(vnKAiwgB9c>E0loVKJc4-{6(M&{D;_f{#;7#faU$_GTEE?E$DQ z(DJTUn^t5Dfv{^1A}krNKzbZi@nM2GY=jR~JQ6A8s`$D@zpG5cLJZXh;t7+G16vUv z zg0IAQy8qs+u@RV~cERW!d`h&~6iI+&7hhs~HP*|6=xPnH{qWj#fUo~R&jo?mHC7P( z8cVe>#)+n4(|1CwbikbNJP3WZSfqWJJfl@MhYl&MkX#!{SFAXqO}_@$4kMp|9Hod0 z|B&LIz|vy;ll?yh*wT(h^STzE;+dIpcp_M&j|!;~5G?b)oyX&=-W<#0hIm5VuN@H< z;N3-VS}NI4Nd{j>Q2h$SA7Pj8)m~MC;32JSLWbh{K|X(*<87tiJ^_c+4jKTbKS37N zMlEo$(^Nj}V^8SO*7;ZyGy*@umm8%;Sv2rgwJE;1G{3ih0wa_ig1pNt8XU%DANj46 z%`TwOy&aiD>M#Ck07#i0QxEP;6Wr*F*fc_aZGwqSCV{)ll8AFhtH7Ccqr;Wid@=D3 z4shTEm9Jk38guPw-^hRqe84{m_^s@X{+D;l^dXhGA%Kq;eZjQ0H2-sXkalNZqv1lO zW+G4r^EGuUj|o)khO2cZ($QzsZAE!!L%c>o#Tx-&$ZO&?n}Ml-rr*2luSSg7TV0ge zie0!s)>a;X3&OixInG$f&v!o6yAj?E*`>Jir&>XE_p{EEM6KydVqD`Ebr79iK+c*3 zWu|G7%?t9)3KiM$K=A7-+YRJIg5THuJvMTGK68y9C^O^Qe1H@~tZgpLqFXmmH$O0Rn<`oS zaHz{>DXRk%**@$sv7c(|Y=O?kKU*1^NbNB0rkLkV0%L-h&_lGcjj_N^JDw-q198qU zoc_W=X-Y+IDS5mY!Mw3`3Q6N%y&%vT)Tp^6M~H>0kkb5Us;)6jMM_XqRt<_=rS|PD z>Qt*iho7lZ>YIGZ zhfe1r!G{egC3++8+kBq-t-{KJWs%e}^`L0~@|y$&nXA^{k|(|g<73>)+uZKB?qyrG zyNG$=EvIa(ma|9P@4WT5RAN#`JL$mmX50ria{|l}N?jv*kZ$fvGTii+&}ymsT?XPAHfbxpHKGYTUe~ zELeUO-b7wg9fklU9KyT=pIvG3JvIKWogXUZV7p+i2d_pRxuv-S1Y^8iV^?deA4>Tc=9V#JJeJ=;9&ZC?+hlr|8Z!>Sf*6TP)Uo5H9%xoeNF$8>Af90 z{XThMkz#F&ijNw50f+~d^ZPMUfwZ!W8;c}3ILUZ8L?S(RIDb)ep?~k{${7m3iWBk_ zY*eeM*y&mVKMQ$x|4`3{k=q(2ZSekNZ)Mdw-6jqD@M|W;+I3fqRkQI}WeQocbF2r_ z<0ElMyEqxzS3H%9Tq4N83JH&~VKh6{o}QVy>+E4kd*3V{4S>p|{u>{x`TKLzoddL~ z(21ep+~n9E$OC>mV(n|b9kIZ#*!`PfD!{KdQk&=(RiM2mmGo1ZMzA)7T0X2y&`Iww zFFRr$?9C02>3Qp;L&4XkSrDVefCc~Rvfe%{aEZBdiN&H=Rhk`fUcZ>>ZE_~;c(J#%FnQv3ZnLt} z#L}sR)Tte^wZ1C59rB{cZcvwxQGkR#HHodLJolYfdGe)*oUG?m=~;iycFil%$lnG7*>=n3E6swtCdvi4 zDy}_&za(1F7xM1Tp_(?K@>QIgpAFL)C3>4 zZVO+N#kZqoCY$t!3CM%m$LLkG52%G=Foc;*+;EjC7B4b~56?TeJD|`saZ*U|0J;0H zfzK5})1vk)`~g40TEX#1Z^MiYctE{~=Aa$5_Z)OSsiSoGq#JI-dLlzy0mD>D3wZ82 zq=vH1f9K~WQvInKKrU!!4iVd+n)4rHtwU6)D$cA0qO=t4BI20XTJ@3AqfsI5Va+QD zf4C7f)h=VAxCW;&cC7fpo#0JEeyXCfm|UvSP+9xqGT0)cDBGYa;xhhD%*|;^i)$ib zJHOp)ePi(z)dV1~FXH@#a)nLoJPiC6-0WAb1P#opshe|<4xgF#E{o?Xy>L%0_SDr^ zkF3G&!1kuvTb3)Y+()#X*NPofdrtHxCj!&*9UW{ZL%)8P#tckeb&q)MJ+d9sU~s9w zW$WU)xN3T6kdl<(Y*C!URiCQMLaC|-uz!NP&U1O_1r9J>QhU# zRkW5RNK?kA55&I&naLhVbL^R#+`RA9@NE_N9`eV3r#jm){$_ za-D@gO*L|#*P4Y}5%1RXkQ+y)<8#HU?u4Ye($pj!i2yIkkyY|b_26U?eJ>C-K6Y4N z1?wUH6|0{kbaYN^Xq@KL@%Z>s7QYXMhMiU~bHtqd{H|awWD!s&r3ytodg<4;%|{%B zL=d7otAUO==I;BLi;~})54#6!JC$FY$OWcDN`;8G6fSQLXuCxppFdj6{0?7AuX}c0 zYc*bWJ5l-D8IKx~CZ)?B>CvN!1ceXRGwLKtyIv_u$zX|(+Fs8kZfn%fcyCv3MS(%H zSEAZHgn6P`npM;&*GZq^B$N&2E>EXpkD>*B;J@a44_v2!@F#qDV{!X2VGM?+7d29F z%nHX3@VnBqSdxNVbvS4q?bEphp^7 z5>GpU$E^3~zVzkSLw*ZZD4}cJT3T=U1-X@I;O&j1!08R@O0MgyLs>0@mt* z-PE%L;l^(94VqRSzp9{%+qSUQkljcYG!fL^dxOq|3FX>?wZktwK<5`&aiCc-TQ2y%5f6W zGpR$%#kD%Sx@ZVQ$(bx+yN^H|Qk+(A5TxJZ3Ztv5Z?`u@ zK2I1fAQvZG5m?%iby*JwsZy+zJiH|M!FJ*@PV>g4-2*h*)F*zWPF!~K;HPu_jU8l7 z`U(1@!GO;an5<=EwV#W>XX-nAG<~Q1vZS3}marPlCcJ4v;rp3^vRF8dmyrGZ?#mzg zRuVtrR=HJw(!L>~g<5sLmTk4DP927lc1jsxy^RCx6Kzg|dpN*;e`1Fu z3>d3$?Da?Qh6+IwpIDzKe4#XrHzC*-8dd0?VGmr&uZ8(M zA7U(3q_a7IDQ7PQ%aC5Fzd?P3R+_27?sLM$;;q8bWqbMHG0*Z>JAbP3A zorfjAs#4n`f!UNS1MdAbd^dbfel|46sMfBkX+G}S(fo8;%vq`Pufcd*n z!X3PQeiWmHM_Cz2ptP#V#K~*S(VSvwl}D@{v!IXt<(n{|&ku%mp%#gXg+Q+n3Qb!3 zAk&p{%^R6a)U>#B$@Y!SUA}>JA!25ba5QCofuU3%4QX=0rH zMbg|m#rm&Ka$uH@OZ`Vw@#2x?p5RSHAl{wGF#pdGVTPv0Yjdl_?a-sC|JobkU3Dz*~JlA7?{8}G$!nVdu zr~JYDO}pl}kbC3STy{yA>nr0DI6*j7W8Bn!vSjVx|2~!bli1aq5C#cJuC^&&Jjr5+ z{JFcLT$}zY^ZaoK_~1YD#X1|g%kLm8L|bsyTef5dt8!?1*$Ih&@i&*XnUN$9H&YguoHNJvxa%U>)QUAw*Gi}yw3#9Yc`mKIS|O6 z#3+0OaE+o+Z?S18W~taFb(ML}O%o&pzW?se$Xtx`!|6K#^!1*?><^nfN7#kzb(=iC)^fz{tOq2zQo*&SWDPN zUN-9*Gzj-UdWHpl)ThGqqe<{nGY|iHj%%gJ7yq$PczYWUnc)jv4U({6dtZ20`TEK5 zwn4ZG4@>?RvB!CS<=~nBum+dgxamKW9!jhmST^7;PHPLN9X5TWD(1cl z9HE=)&=+1T(kcdDVpB9wdjY(=+e#h(8H2}*n7*^5BNFbBxjZQEg616)`JQ3GUHC0I zY3ehk>Z8`8={x$e08c~S=iK2MmWZFsWx`y(HsYTFw1?9l-pCo)saj?BT9`(W>W=%t z-d-WztIg+I`Z$B+{F*#JVKr>-^juP*jXE!a5d5Q+kp%y&XRJ*w>m&0IOQd8P@<=BE zJN0$)BmI{=@ygI?N+cZ)CR)VXn#p$7kwU;TiSagazAQIH34y^`vgOD3_|4Ik`qYw?Njnd zjr#^NzLw9^Aud~z*yPNe>tL&c$Vo=CxvGy@iAk$6=in7!GGR( zguHAcr-siykBoF+L8bq>e`YNJC?wFxW`*+oLqd(qMg8Yq(dJN1f3BJn|D`1GUv|-8 z6=x5yK0xnpZLBh!z)!@YZj+w09Pd8I8B747|(Y2E2xyGhuEL-u8RVNDjrht zKRnT>B%m10YqG5zr(-`dfAlO(Pq1Mdno-7{I`2W>;Nzn1_;;pNvZLt_87+EDS78zk z0v_NUh_I0><1S$YwOPmoOWl0b1P)aASMCDcO0z)v7V9ul>0LK-GE~mD8Y5eP(9N@v zGhA$Ul{o+x%?P z2dKm6agMMkR(5zv+cd|?;Z=1?mpe6wkLXd4kOqVLsCSZGTya)Oj81MQ$#Q`l_C4gO zAL;SG9zcE)k~MyaX`)DfW)ObmXlF0qvzNGV6Dp+4tP7y}vVrCT`(V5oC0||s_rIc- zJ)mj|eXJ>Er;p3n$cf)4osVsjr4ONb*DA%xA z3XQ2S+c#6$TOZdv@ZJm}Qz6~Az9;gVK!sV?+7r{V$J`~VF3x$qj~7h8eOjcI-Kyfl zkfaPKwlK;f1SIg;IF9v!L+1ZkHSHq*e&*sb|Nrm4|D7&Z&52haA&w8_`0i!MH_GUu zPt4}B@pIbGHBCx#C7@TS%vHZ62noWbv$g|$h;~mg?_C=`3%+%0d0@&?T$?~i%iR@P z#{}{x36B@DFw1ttn-^Wcn~NdpKPFIP?EGV+xJPueaX84>wm&$BRcjF5B>vV@>j}r# zqL`aB=2A(uh$btuJauHkfuIZ#_kfM6%+M)~$D zmW>bjax6J|8WDeA`~{f)onGEhH^1_F*Lth{xyKHx?H8D*7q*s%8d~M@KMh1@2a_C@ zd#HPWV`RO_;(=Cooa@xrV-qV28wDyPe6?9%2IQx@?YRB2j*u3u)sy2QdRdy`i;Y87XC!S@={dd{m@E9bClh?y8#JDZ6T={md=oc)zzuzn7IFJa@7bR9|X8O|Z9F z9ipni0_<0K5@IN@*Cp|UX>gHS_}{*#(1$cE?h3rQ(dbleX7xHh=|Y>Ptk*|cm(yAG z?20X~6>@?~nrpxRnK{6|ya;|tg^}maiM0--M)y80C%~;8G0uu>3q6li4w)nJnYvpO-ePmg@@^`? z;-l`4?A|)#RDQ6pLy1G!4BNg@qv$VU3fKzR9hhw}OMo?Ch55cR=eJbbYA3CfNaLh1 zAGk|cu*&nCh<2(kP7P@vec43?+V4NS;?-X1p!~}|-@$0!(RbT5;*D5A+^pJNO*+M9 zETGnbI-NjR+M!X1m~Kj1nf7zH$0i* z%RJtg!f!OPaaa>lN#QPku}(&nFE(cp2w`uastdWwOf6E^R{LKY?#HbJm5FII4zqLN zl~|gbauUcbyS#?7Q0?|}53zP~;E8zeNIj-b`@@@hD8ipMt|^+yx7Pt#usDEdXy=7iLAO8sXI*(UDuE+Jpd1xE5xjI6cTmwQ!T9Sx82ur7KB_mEZ3;tv69w!r-h zCemJa18^hlKiQm@wM|hIk4#kYF*p^y)I>6eV^~UY9BR}oyy=z`lSBxh8c_EYF5l4v z-ym%gCfuG29U|D-2@*0W8{3!(WwGAkcqeVA3MWeXYLw%PPIQS-;#pPtvE-kjdff9p z(j-lMczaim*XuW|4W5vFFrG8Zp{2en9gXQct>>`QJ($ISB2%Itp06N9Uuz-x9kc(j zU1=HD84Zfg_G>hZz>tPjOUWw*NHjTWK&G~Eo$fPaO@wO6JMDv?K}upum_D3WG^5-i z>a*yj(p@A>B4IKS0xW%}v>gabJ~Ozf`6uHg2 zrFmu{4QtFw*?$%ilZLU}SDgz?B~-f>)7~dHyUQ|neehgB3)ybdK47FE^XNaMg8KA= z@zK7xvDewKpLPOW49Gga^5Cg}k|FA~2Z4|L&Pc#%!bg-W1y@cSQp-;35(wKrK=x_E zz~UEG4~Wa^SIea#hGiy;yY&6?lt`6NO%~8jowx2B)JGjwjWcQ3^mExYf0OQW6zlN+ zG26Y`?uTne_8{YG&t$9{kPQYd;83X{=Udbli>IoVCL2S{-yBrUnYGKKS{I=*D0+p> zjxUx6g2Wrk0c`n;{_6Kjex`ln;ha!uCkeijxE#D%j9)?7&UT3no#L^zb>dnHXz6dg zS#*y2g!XgxT!9GWSo%)X<|n+_v*m|D5Ig$M%qWiAZSSuK+6WxFM}$qUhvKgvLF(2z zBC7_!GG&$ajTaB>K=U4JHvaKkD84hU>OUYOp65T%{ZFNPdyc0{;?#OACWSX05b;zr)Dh4=Fy=hM)$3gi>$a!E@qfmeyg|B1)cB&354W z2Os3uU@FbF_`?V1iL0Kh^)U)Vmx@53s$LS7F(fuK6O87I&>xQb1FUM~yKLopb*C05 z+Gql9v^?)1B#Y@K_Yn-aN;I*oMlX-2k`OAWrYfxgAKzrufY$5f56z=nnloNU&7&^n z&XYIkL;Zj!Ihq}sCE8f~wt*<{{IDk8iMs#frD(b57Z`&omqB&~TeP09 zi_~luQ8=pYICc!~!QRzBv3_BpSg2YWzasv?b6E$_=6%X>qOTzsAfswztvb=wG`hLr zto=LDLkVnhr~p^P(=Qav7P53d#DKy}EIz#6*^;81xcus&%xBPP-1DFBEClk5S>E{Y z*irD!(#FsTt{6g)V4tnIa=_XXH|)X}(x29P1+HR;jCMl4!^(juqreh>AKesbpVAH)F|yKlzwT2BcDCr&G-|Qh%zc)DlV_)5Z?gT z)FtHyS>*_6W5BD#g@ohSrSykueYAzpjgN#I70@_1|@2Iwg_YJ zK8k%%;*j;}HYC_lqp##E`BA$qZ1Cv?(YA_7oC|j2H=V*b$6=|Sm_E85Z!(U%Vt<<3 z;?$trh5e3`sf)F6av4+iS(N69{Fe5}{0X1ac~9#5H-}X?+uw7N5bR5KDt0vBEOd zK^`7&MLiwV!_`YN57p;?V13yEJoI}N%|}9iXr?zXTAT| zPHqXPRf_79PP;llZleScH_k$GcUhRWz=Y&O&Mb5T~4rChQMVlnv(U^m_TRas+DA^c#)ad2-&tqT}gOzR4CWa|vF4 zBwUp(+L?lPs&%}^O2kJiHhvh~hn12|878|a4Q235pZue;EgFNvTLYh0=q||~hDcd= z;gR1y2Xh+hQs)4;%i9dr{Ya;;X|{h|eqg&1IXt6I&lX<6KEp_Zmlrc#m=)^$OeDV(@$bM>eh2GbxcB&Tl=gz49=+2NFfg*}-8 zy7NBT`b}0}xG5=z%IhhA&q@KV<5F)5YRCGKgWS3^nm#StxCOm@?>R6ww6E=jf0|4h zZ^N@h#2EbKU!!W7qmA$BSTh07)@$rCzTjA7Q+6Z)Hd=nBMTC-NEBB(z+=$)v+Fjewd*3hlSFUq1 zY{Z!XJ(1GHx;VavrfQ~M(i5{apoVJ$#xXrSsKY(OV&CAd7*nSl;r%(TX5!AltpgQX z4)~8)^WM1NOVpcha<&!FLp0}anjH^0KK0x1Z3L-{c803#Q-Gh?sk+uSBiQCLB4o~f zR|YCK40>MTr+$&3>X$h6=!pMScV}$`+)P4TTfiJqdaSt7FV)e<-R>M@VbYbP+M)AH z$eMalkdW`un>d40+)r*8N+A4J`yzLrviz`5d7?c4Ocd#pNnGxQ0gTZ(i=SnncBi@b1sf^CCP`F1oP&0!$Z_?;w-%Ft)cQoa*V}Go zDH>Z}3I0;LnsG+VsYJKthwJq&92WvsobFNv-78X|X~{RcvtQCWSveEBCY5(FF(%;~ zAHAQ8#Gs{)B|<;uwzk1b4Kg9M>N*rz>ul(Am88mRq?de;%RMa#PIjA#0LO+CiQ%tM zOv-8TAy)w#ThvJAwq5+zZ)bcTssa>hbOpiY*$v5Tyc?b_MX^<8y+nU9%E|?wt^3Ezcdm`F7(g3ESDTfyE=^+2x+^=i5OKGux0O@)OAoe>o;lqez!DUOcpEa4Ok~-F`}N zUj^M?l1r*2T(Prx6(sF1h(4l~NhR!kYW?z9@LkkBOtja^faTNfx5WG> z*HlCr414;6cExuEDQiCceAo~uGxg-iAPa$^5@2f*m z;I|2c>sVvvftjpDDW5!V+vzKT_9jihVZ&a&awWj*DC0HRwW=|ipT~Bu*KfwG=E5we zQWG^H6v$|28?9GAjj||+I8?{vZoK(X1ZBfmpHTw|p0ew|%rj01GYFSz+#*Ry2EMzF@~wJ^F)ahl;RVVU z*@r4F1~)qYCY;7)IBmnf{txH-|A!7e4xGhrvW-zY-ZEOg5CZbA9b>Lo>=x(+HMpbf z9sA-Ud6;WW3&sX{lr1PLQ0fne$!1Ut!ZQIoHVEGo^O-Wtywx8d7kd2?4(Q*fHP5>D znB+g#iRLp6q+onUXi;Kgn5?s?+6zq=S>jn=8$M%x`TTmwjUbsnfvd)cb_308x-#>A zUUmh{c3x##TUDEtfs9Yf(!A#=^!iiJuf+L_9svm6%>NjZT8zK9!;_D8>N{l&8Jw|} z;^%?6tDUozy&O>4uHQJD^^?isC}RdBH21(J(vxm`>C~;x<@5b@Pa-4J&Y-brWq}}4 zgBD-z+UUsDpgi)78eiURT!M5!pyW~JgzmtAa&f+pWw5*aZV;`}5rBe<)oeCH~d$KK@PpK4SsMH9*M?wD|&i=lVjoi$I&%oC|z%)D% zq$epPq?9T^HhtpD>{UZt&JPb@=3pu*DTu2&1;OL4+G)BcyM1cGwxMp-=_q>DBw-j- zAn%qPjmSx5!c&0hWhsEq)zLlzuN=w6RF%)hjYt=PSUsEZuY(!Jj!VTIHl~)$i0`;b z!KrKg{Zl8jNxdhEs`VE}%DF5zM*GuspU!JFO%GPxLk_kOaGr0%eyzCB*i};7^a1wU#(0EIxO|88+KIdW zxPEtVB=mBeE?W7E%@VF%#p8{m@>SeyS*o`xC1C?Rz-h3Y$iV#exgx)v|8CLPFf8~! z$N4`w>{N20Qqxt5G!<5D;WUFiPw2&6S%o)XZt@(q##oD5iFNaS*jA8#|6y@~P;vTD zNg_cD|3YOR^#gh{q`??KSsvcN`XykaT$d9JRx^WMd8Wt8ttB%DO5V z`WZ*p1N5f}2o2^QY^}}iL&RbgNhsbNd4%R;>q9DtO7OIQ|X9oa>o3t>C zmE(D-S%|OsLNe9K1}yRG6$;px@Z>&P03m(D`KR702@ui(7gd2pP z%pX{N4oM*=8h^zW+YU3NTemMh+33Gm7~@_%5hZEs^(rnmJZ(TF`X+W>@fxAk1E~IQ zWU#rbXjQdxiWI&~f3syKpsg>qZLpKGlLwqr9|45@81FlF6K$t8*H2U@Nl_Tl{q=Vu zo*w-JX4QT~2cV4+DmZJnD9vaM!E3+L-u5>8zDE%HuMHW2KgI2uI8F^%b(PaC@S{q$ zQo|R%oi-uFcu|KiVcRyr*qL1?zA6K=e}@Oc|F&TKUn&mwvsTVe=F3)%gN`RNl<*72 zth;t?IQ)hImxZG*;2DVhB!rdNHSjH~obTs*Bs+hk1;aavSuGV0O6tA~YnQ(sKJ~sG zdRc^B@rZG%u}s_k8$n=c`0!A`^5FBTd&RD*^Q-!-{iX7_yV(vQ5W{erA76B-r&1{KYFt-Da(} zK>YFS5xI%kx2L$FKVg4o@K)|GTE$rXc0j~{ke1cJ2AsZS#sf9Ra(08Xf2$6ZvF+79 zBrr42-Tfw?j-;r4p^knpK9%s$0&vN~o_I6%bS{r8auK5dP=ap43OzM*;yUa6PP@KW zSF+t!fnbeW(9XoBK-&wGGZ?wAxXZBHquV=!7i8!)y-Pa@%m^RAPZMdyaiW3LWNX0c zutZ$$%7=JW)fhf>-s z&zxyAdLK8TB@mA&-D=l&fA(D(xZ8r9qPXSntetbn#a|(z+0gOBw)MT7dt(KAt&$vf zxY54RitOo#wT4x+V@UeYkqHLPjo9>VGc^{a8SvBl*C>0j>0!vmWFF035{LV0%VW^> zn2?Vu5rzSsPT=KGT!haP{3n@jv-fyY)1{m3)xLH{(n;KDx@W#}S&Y{NaYDG_cmUox+` z^858s6ZuL|j|=0jmxk|%Z$ejKQ=Vkt&vyr&bVKt_70gNk`*iYSGIz-^jkrTzdqhlX zqzdNN>nJ8JUvx5#ih5uYF(2GoM%mZN^UWl4>-N$y?nl+v_cb@&dc{os)>K-mGo8mY z|GtkH5Ji|4QI~m7S*hwQHhE!lOHd`?4&o7Y7MIwra0#2A(X0hiskpYq*?{*dweZ@| z9Pw`J={DFtsbIaJeC7NNJp?QG_ZlY`24X~0LX&k@- z?D)S>GTy9|PR!q^(gFFM^8r`A-W3S0rF-)I>58aMq0($(w9moJ#b$K16 zf+0-e!%r^h;I+|1F|2+qkUn7V!3*a*=bk_CGsm<|)hHPG?(h2V zT+R8Kuj#t-1?vGWZls)C>;RdET3M>4?}paB()js(R;zG1!**Q5Wos!gBMNC}67KLs zG+_8hFcKF+5v{EBGOA=6RZpNhJv#nnub-u%x`VblH35Cq%l9xrNi&-pt9Fcm?Q2Kz zzR2}`n%N^1X~&t&cZya(n`bT9VG$eW_DnCtq5nq1aGm7r;p^W_qJXp<5SOo)5x2$Z z>zuBNo9{B3Zqn?IX40?d^zwfE3-HPI=!@B*`I5C1m)aQ{@@;j0uyRD5@Du`~tUm*I z(3nT7H;ruj92eJzl3lWLU2BL3T_HY$CRJR-ZqD#=ZEvqk#YecVx^Jj)5O{3+;envo zj>4cmMnM7wK|0Ix8`|D}kZZtgD_$b)?k*O2HJnrPRV-Ny<17EzCCXQ8Xl#23i|dhY zr5MlzeD74=xLqWfX?z*bG@wt9w3|(TsJy8?vjO?22u7`iPQP@}bo&Qhms#djm`j+Y z`YyKrz}Hmxf8gt4sGy5YOcL_P5#S%!p14Z0k42qzrr(5VhvQE%C+do4uVXqkwCKfl ze0sSScTVUfS#TY&HCduPQXBTc|84cIwE-3%;G>v5vo537cs8!6Wc$S3Y)m%)|IDO} zrq^9Urf%QI%!TNE|Baow3O*Zf=gZ1*w_Pn96AC7)o!aT=l@zNPuX0%RkZ)w+?Y8v= zJ`abindzNFs|MFUm+Ot~m^Z=jaDRM8(^KQSnwhJ(%@Y0%mwU@(U|Kk2F5~qUXDuwqg@K__KDD>r$~mc4RB6Di=^6AYFXBA(*VMTN-`UFTC2_xcc#-%02Q|@Q zUqv@3@1s74*3+BO4?E=D*GHz5&I3M&XIy|R;q+c$Bfzf6*^T}J$xQyQ`2(Zt*L)Y9 z^#UuNobIj?{hvt_mh8Wz$voU?PUCfB!Pop_-3Ig#V+P8LtzlSYPp@bWwukhgsA}u&TG@! z15%h9)Jxift?bYXR)tK{L~|o*ABO=S{{tfpIhi<6Q5Xn%YB9d5%+~DQJPEX`m3;2` z|H_kaZOt-L47mSyJP8+L=MP+1``?HXvt+!xy%G8{p)BV&PacYDllTwUx&rv)k$)KS zzjm$ufXSg!JD5f$&SLU!%o)qUoR#9T>pb<0`)v~`Kl*CUpjhd3 z#~5oi&^G6qM3C(Q6MF-Kvw2ph*_!lh#rW!WaB$F~aOTxaD&viz2kN3nzd_>7p}7N1 zbFwzp`qfMo;nemu!{VSFSIUdwS+^I!SsK@v-_Il@=0$``-^jNrDrUExRozY?2QHl$ zaHJmnIh&P8&dJ*(9>le?l01eO{jI4brCFJsu@5IXL`(UW`ow8HqZG}1X=Iz4K3&(! z@%TDefns;Uns z#@CXNS6=HNY609faP1HAQo$i!ic_q^r2+QX&}-`oglTR8+NF6oPa+X;Qaw{LPtg3? zRutZcuA4k~&(i#9J~~9BzcO5AgVg#`KnzOs*TX(foP^I=L{5>L3I(pFien(8q!A)> zb_hmAS5AI>8AOzvZM7LI{W*sku#{t|r>t(a)0a zB@#g;Kt9W}zS+@m; zhaQ~0N8!`a*sZDOqW+rlz*OpM%JOzBrKR@9{K)p;g{`@x=VqJ2!wTQcB+g4?0+<|r zba6!sd*?l2%KH-INH~W#ij<@Wg?A&{-nt^p*Ahj&&!R)FW!%pX^!4u5C}b)5$fxD*g=pA6Fc@bfSe^whnn zl_g>2v7-UAH3~oEwVIX~5Bi*IqwvgQZX60T9-bqAl3~gtJfrpOyL@UUcPO4R=P)&c zTQ+rbs@(%&8(pgfu&A{WTr4rLAT4WAG^p=Hbw;*`H~}ZC^eG8;#PyBgblo|^ zapU2N%6jCt!JwrJsRDkO8FyhL&2r6>~zxa$vW8veU2^OZ?czTdmlH3zQ)j_cp#TTl|Q0f^3C9LFtmV$PKkVay5vl zA~qnL(7-3PtciJ=>SztU>_KKcY)>dC$esYeunJJQ*CRP)Dxa}0cfu?U5)G;OX>Hz; zyk|2znWq|LCSDJMG8av154RoiBZWk0kH&N5; zFjLg3OUPg8|4mMGdq|P+VC4<=EenUCngq-CejizjEZUZ^J>*X0Bolj%Kz=22%9!?e$@X8uPMybt;F3yyw64gO2V^nb4%-Jji<8xfM#2#V7YLuCa0!RV3|6H~szMA;faF!G4VP(h%d2w&&E< zrLQ6k@2{d0;H@{D8_LKq{@f6IQxSl>nfNBZ9fsvelg4-?Mva|kX1F@ z?L3+W%4AAOU9y%>(=Dn@4^2KW|DAY$Lyx`S<+CQT=$)NvB;Gz2xd;{S@3$_G$Ap)h>u_n@Q>!VXpr=5rgg{ zyPUrAegE31(wFZ31*4DGv%P`3r{Zut2R6PL{BBfQCd3{Zlt^mJKt1)m(2}2ND_|p` zO2#Qx=n!I<;l8hI#K~?j=4%Uk0;wP81l{JgK}s=p`F{^2o%4s--OT^}R|SNNCE&Q% z%kKzw^2M~&mj-ds^~D4A!|=Gk&8T+QVP39%&)KXaJX&@OM(F0hvyRG zoL>YvN@x{%#XUPkSc{2a*hf7o2ab0++Vofe(hg z)<(lTqBEii&b%Y8*&Q>c3su+92QaO(ns1^Qc|GKZsX^(4PKwGP(xqrX zN<;)yI!K2g(jgG3frKU{p|?OFB*44CeBU>}_kQa&|8Tijl6CLRz4x5G&))lVKAa;l z4tY7e^6`md?iZA+jxBah9*b`m_O3FrN19_c98#EiYl`%a_iXk>u+Z4>gxE0cbkT`e zNYZe*uRY;9df8%3&HKyNEVKj(_Irp*XFL_|p}XgIP4%jCQ!BW^@5-R5l_DrR(p~;k z0;J39%Dkkyd5(R^K;Ql!1kwcc)qfoIN_?|QI_allDA-?=D|FhGoOMV4#Nvu8>jBLT zOfZ4m1V@PW7hV6OJXZH$Qsz~zwoAa+{;L#$)Rr`=j zdUBQfPg3nG87`;9r>ah|v?W!A1kkf%@u}?=v%hhy*#MC~;FDExL>>w9hX8?~x5x68 zL|wj1E+6VI^>6tEXNp^ZD}G%#TUtKJ{W}R4LD0gKqa@Zroy%6nDn6Ti5bKIHy<=_0 z%?%61RVd`wfLExy<{kAd7eSBb@8cnSihG`Wu=))Fq+rfebY!bX!Po}717>Posz0IH zYQ$6xjnWy<9Ke%lKi)s=E33yB7_|OYK5QDWYOztY)NQzuJfR@F@uZ%?7#Q5<(t}yU zPOjDA?}C_8?sdkG6`kNpuql6T1cpHmHJ?MYC8)HQ)U?bxHOv_g)Bip)@+hHDOpyocgC%V{rw2t-3d=`!s=o10J7W z7o3|L2JeeQCB5XI)tcADd{Lb4GW;iD{379Y5tRvX=9Yi1ZBwMKfDJU}?9YJx(#2+;cBN`3GYz>~i@{h=% zw3^LYg_(sfquh@Z3=)F=;h3L+2*lI2`lbXL)e8-g6S=Z}y)$}rk}_RQ6=(jV*<~R< zKQqDqDXxip6eBZF0?gD+uMC2Iyyuy4Fw}llI{?RL`x?ylWgKw2n6+WS3S4Z=o+x^& zAc*ZhT%3>L#hps>H*a{9-6yuuY6vRIkI1U=$}oJVdv$2~-Uh-nKfC^=hYL0d%_|O= z!}o8CB>r1@Tbx0!U*fjDA-7{e?s^qzJ#-=gxPO0C8J)B^Sou7%4F*MkqbvuW_vby@ z-<%zruS)j>y(SZ*zPL6@f3qbt2G21rM9VgvueQOR%5WAH|TMOYL;*GW@G_%no3?z40mzig2ckC(WbY%7w#&Q|-dHKwZ;Jwb83i}e_= zpiBvX_TN||eNI6@5$T zpd5X4tbM=4ykB?{ZjkxX?Dht|dH{?MHC`s8z+FT-PdL%>4pBRL^-7`odtz6Vl1-WH z*WJteo?DGLs-?y;B6A3Pu<}cPU{h$1ntUi(m^R|{Y0uK`EX8dzCUT~*z^39yOE>6Q zDC_8Fsg1&doDBMPu@)-CKc$TpC6Fkh=r3z0w?=L#bK1!jpf!M2`=?W^Z}3Vl1POIc zY$_2dW484jccKN1J)0Zmih9^9|JBxx@g7tTDzhIH-+B_VU>{)zZO&MTYon7f#m=m9 zVz{1@DgPP)LWB+$f+eOf)5^_-ri+_LXh;|AGkAg0d8c@k0mbjo7+^q`;*w_p{^sUe zPJQ(L=JOI>$m8-GKd`E>lKIm5N)cqGlEFz&I%U$K5Et*8LlzqwQhw)DdCat^fn3T& z!mDjCM_3WXWAdv#3!^=iBFXx+nN?^b9Q+x6uGz22qnWBZCeU~>3FIIt_tcr8qol|& zG6XnI3OK4le>3^yB<~yw7KKOY{~XLyS!P~V%k|U9iVpYxOYfE9!yTJC{?X<8@wW6i zih8`m*71kd=hs1BE%sB~kX~Z~eKKKw?a1in0S9^X#if5*)i;y(tm@lFPd7VGZe@zu z{UeA9!eN9VyOrl6h1=9sFYbEPV@~>&kE5eb>TGClloFHF17_z~%JeU&6u=2xq#JlY zJ%&_gPYRhsSD&PKuN2KRa0VUGZ+5>71jTSnZ*dV9w#RCmjraoeP4r9<$bA!EOe;;p#YIKRM?Na zYhVM8-TH0MT)v6yBu3&^TQABd5M8=pSRn*Mg5&!pV9YJJ$+XQ>7D$!q-l&wO`XmZ# zPU91k89>Ogr&ZK`haH!DM|z}O2}XYQbngF z{O=@^UL=2YdVkC@)LU%zq-{tQ>h1?mOH*-1y+hV${p&xeGaG-Dmc?c#AakzfO3Yl0 zxGM)Qg(;fIYCFe+3N7CS;CWKF+KK}RuYSUxMUc@yS0Yg0VShDoun+3C|7;=B%rPL9 z+s)QY&Q4z^A@kie>`_D#=(qq`(`qvfrUEM$`M8-zFkX8ejUdATJk%V%R(Msy`B7Ju z*<6W@i#{#XDys3~a4TIY4B-!V<@m0C0CV>oheIEa5XYi%*+cuyhm$`Ag>7tyXJw2y z)(dSMCNB4%_OC+%#^MryTSveqMbDjepY{Bn_T$gxV;qZ0{||^kenoH? zv$1W;<^gjowmb>n#VKq9;1iswhOo<%)oqn*fh#Q_-HrRzefQ8$`L_l}3?~>cKQAKc zBTID}wB1^7VO<4>uuQfFCWg0vKw3626>I8exuQth{NoB7BSpJlVc$ER^$){YHv15Y zlQu$5DzuYb?_vmzW~da~v3bct{}dqd3paTWh0k&rF+5fg!y`I-f?GTC_v*N>GAg~( zxwn0`96?vld23=uAJ?3dEmv_iZiFL-_XX6^j&Qxv5pv9FLH;F^`=ZVTN)q&U$jUkr ze*gGl@`xHndcJzdrMZv*H)b<8lrjGEL$_C>Dr*T7@-RKh%7SXp#zJ*p9Z>S4MPaaM!>%Y$Ncd9C_m*fPU&)Kr%_@v`B+~s7#yrty zIG8nk$?PrjPjica0@{G<^FdSV0-d6*t@rqQQtbWC|K*TF7)MZpLliJ^8{9M_{D#Ew zoOBil@_KgPRc4=0j2AcAQ9h*@&{5bx-lc@<8UHKEi=B33E01ep2Om6e^&ZP6%&lje zvyCQg3$~YCLkRaib=kHskrVf7=(Tn=nC4gRO!@Y^2gO46rFd8iP9v;!Q5rK4eTj30 z38ZhH?sKI~jW1ny5iQ|!p{ubpr}^lNbd-=lVKyAqUe>{S!FQuE5fKleyY&pctYE0g zVLDZM+(Fd8io)sTMWHww0?q&GaAK8$INf72X|(|@YE@{-x;L-lRGSDfw>aeISeN_F ziB=f{#}lbe?|`Ty8mm}0K?Lz(HgwG}N`U4pB~Pafq1ES6-PpCHnqe(my%K@lU;ia0 zKC-n`rfWlsT#Pv2i)vvh`|@zC&t;ik?)x~W<}uBa^|7HJEGc2-1{qp;IDV?YZ7uc| zk^E}rE_icR%igm?;Ex=raD3?2d7t(+>~B@q27EG_TEEbVzDHd9?yNp2*C}TbW7ItQ z@?X%bw@-!e0v5faqVo%VCblP&m|6qP${Tu)5*q?OP#mhPi>4+acz2k9e`GjlXq9RU zXSBUuA+XOnT+2RJ$nVAewA9!y^oA(OXlI|=ylP^{Du9(%S<&Rj_&@lP)(*foRos@@ zG5y}s-VzYgC$UK^M6t+CQJ1Yuf8uS^8=rJ%3p0yN{X-Hvu`~pO=myC)m64~zuZqr* zs^*r<+U{<2KJw01URaDJ;Co$GX2_*sWrQKP^|r`N#=+%sl$|nE4!G5;lYpjMF|r^k zG0>dJixZlitXFgv8dOSmM*iqbad(4M^?Qy5U}turVYPRV1CCtS}@Rgijst84!XI`#?R_AJ>8sBvjFzrDK(9ZdPNmba;k$ zxc|FjPZJ8o4I7vq1IA**S-G;E_t&sQ=OGUj!dOnnm^EyndFA9PVe{0?Su>M_{=&D# zuLoDX`@K%P;1ayG1!6k+b51-oXi?R2x)!% zZ%CXwFDtH5i_3=+L|yJQYNz0o1@%gwCO3)uIA50v7=PH(m&Baqz7qYPhFLXXPArBU zjfyC($8rA*P4mx}wse2-Q>BA6E@^vihPF<`7=TJvxubQ#d9X(IX@qN7MFxHS<@_%! zDoP+X<;O31P(z{eRmk7W@!nSC@Rl-E#+Jx zd%?`SQ?0_DWg{F2S?axgO3^$a7)n1m&dMovLFwP9^MBE6d4a3S?UbY()+&>1g)4T` zQ*=qLFfsirGc<$iMI{7k?3-l4@f=#?`sI7V2lW-YomFzYeycDo?^_v&ppz$6;U@D2 zXw(QT(Q^kS@4NK9^VkPlhCjSB{>FkeHXhhlRauVZFA`P3TvgZHZnOBnZC3DtDoi~L z1yl!BRd5EUa4NNF)`Q$jd%cmFCn zaB{PM!qJ4QLb=1$M%>A50M(>8=9dcWEATDCh9h$lZmG4OO)x1rP(Dt7;4>61@Z^Hv+Mn@s6KTX>(8d%e_8@GEY3L}SCEfrpe5c}aZtc)%d0+l3({K8{P3PzPo&~R%(KC9r zR(X~k?O^;RS<*j97wG*DD0I)`3pg+QgZMJ&7<0qTk)j;b+0+T-tv%R6qp`%f3C z3)jR>=yHCK+>R9WQ~YTmcZi9cuvi_<(L8qu>N&!Z4s< zNW(EyzpyqyQT!QnA}_*;a79P`PHCC*_&ep2S=H6pVzBodOW~Ap>4yvnEDZ$dSQ4*Z zakv_-eC{-y(R{%-9`-vWIzM%x3M@EG7PtOJTwYZwC~2sJp)rl1x|c3F6V5X7;$8QF zc&By`HybPi80tJtNS6OWs*h16UeLiBMEVHN&EinikL+fNj`)3r3(-Q;lB_Ua_#xKa zj*LRpbA`kuiJAg9IykQNh6Q)i%#bL@^I`q<(9^#?3beQg-K<{{bbNG1{mOIEI-ln# z&#jiz`$o#l2D#BgA^>kZ;aTaG{kB<>@2civR*FEuKC$u20^ z#zXMK174BK-25dCY;JMC>1ySzVIbo=p^IknSNS^F0|Vde1U2uOx7I!a(m@~!Lxs<) z_}Zl)mjy9wSn73}kTuyoni-eL#o8r*xrYBu{_?P&T6q`z3aGN`oxD3zee;fIM}E3h z>f|W1aPz{x-Hz_TLHEwFX&Ah(BxOAAdSta0_I0M5XiDgvWjje&&ipar&Nnq5PbXQ& zPDQicul@126o-AkJMDMjoF4Dj>#90)^1>%6zw~+i+@BK~JzYCaemgbFj3BBt($=_y zFM9Jb`b)W*KJ8mb(_|P4T>1+FjYNl*r*B@d~3{xvzeq*;~+qzoW+pSWu4{PlaAoKSSeSPix_NE4Et*PwL zNNu9&&#yo*^N+Oaa-yw2wWwBilOh2u>-$c$pjotdJ!!@rV$A8SIMgx_F-)#Ew9Spn zzJWdH9`6o;l^HFv^GrU^lJ9U|>L z<&YIcu}qa8lhCkGTdxg|@)*7443UaAnAi%Os20wety({^zL6-u>YDDcSAHOc%Dt4-Qi7MEf8ZPO^Mfgme)f2AaAam}rKw>x<2-5E}O2@D)!{{3*2UZARTinKMv zz%i1yQ6xamQ^^aeGcZcz{2bROZ5L!F_RU~LfqA>XX+wF7j$1H)GfO_qf5$>E#2iR{ zteyBVJMviQV3rOJmm*V?n;w=hh3SEK54?mLPcEbHQsWSOeRnxE~jW#v;AQ$gMO|NTeCUKEZOsM?_`%872vZVTR z^5j-6`=a?{bhkcTr#`-W+ixy1!R)dtD}}DpCBK4;5vx{c6u8*+HuetSD5N1`*`KEIYB zSL0=@YLyL}XH=XYKC&nRbzPv-4pz-7#H%%!Sd+J{MS}u$IfqFO^gxx;Kz7_d*;JSx z)Y+A&8z|_cghJNunu(r6L&}Zb?3-OQU!=P};B}brRefDQHha>lYBG2y->*yaiQ2rf zT$L_%RCVDFGD6a)92jgKt}!HD6W!J z2>!7y@y*+&b-&Dbx7xfJUPdSgVaDvrNRwXUcI(004-jyO1}mlLDO%1q-2?_EIX7_p z-U~)(AE_T1pZv_Co zyjQDXUz*uBrrh%D#L~<3x65?n@2oJg_5AkyP`(3($?&%@3JY{8 z3f_KDU=UJeUO@{W!|e3Wi#)tC^uB!H;c$6DxZVV==Fa5 zl|ps|WI*`W?6B(7l$dMSFFZxWpT#bkP?4Qe#>5u^4&Fx+>(`D6%lx|p2BrSH1Xezg z3Zl#|!aX|z!)FY2s2`yRlI4F&Sf3P_zO^Atgnk9?MOpXnb8J70KZFWq)Sk6Bi34S^ zdjEDs*qDzxUw8*%Lr%&HaXu!lIT(4k$ad5E+0xxnG!O?IfzHRfBD}vh1ur`)({zR# z!WbqWl`4gqLk%m+1oR!R!iBf_S*9(q0$q1UEd2L#fcAA4aj zo;&^piYJhSe}=lmCKtJ!Plb1X!CA{RGH^w~L_c$y628z)iB|T`!6u{qZM$<4yj8hc zWFcf-d%0-?DVGe?y9>3Gq`Vv6I=*T~@tI`K;L578&nHI-_iyx~*7GqswNsD#l>eK0T+<%SuLaa& zP3GItXF}MiDgL>v;b(Wx1!{Qb%{AO-oxyGSYvqGhDq`-dme+A=NLtv>7e=1*|NG2L z=53u`)~$ymS4EntPEl49zN@^iyhV4sL_45E<9P(S&=pxb37w;x%-YQznO9nV(dM$y z|5y9GR=jmGFnQ9d&+os;x;;Nr#Jn!7rdq`ZGfW$J zwK#0f(u5=gW>ePV7G^T1>Q%~@=YH!a7QW6B`=DG68RGNZ$Qz0_N|XF}d1vx6BYbeIuREW5SXJ2g4-6 zzuSn_2ZHyvwu3D(ey$R{6+@{%t~Ey>y}o28=uXAduRi&v&mN`BGP5Fy&2QGYV$|rg zSconwI>NkuqkiMVHw16;SPL#x+KDc-Xxkp#a0(~W>JV1%dcJG&XHg`cNP4DG-*0jG zr-&d+B>~iRziozpdF)X~bChzQ9Wa2+EiJbJ6J{%F_IrbNPZ@TWtrgB;Ke-s8X-!-X zy!J!grjqyG^Kni;$t^YmUieWyPwK?>&VS_?F$5aXH}|VmL%9P1o5dq4rwYuV%1Aau zHxML|(lG=xA?vn*yPwUGxPm=M zRYb0FhOq6t*RCanr3SG77C}vvl(=kznY&n+^QPtCR z%U-+nMA-AoHLkYZ0N41CvTZV4)mB-|6yXkq@ARk3k~;?M8cMLi@}Zyj$LmDdWXo8 zUt4cPuHu@(Tonnk8R z$yW8n(q@;>)AJ9R+mx&L>m*V(Vq4DXmlTeg-p!ERPE=ky|R>9)}7x5e}d| zvFDTaU98)_uf%)ZGvsZ^j<9FAr`8obpDZ_JVo?`+{p>{{b90ZoBOCJA<~TQ_Ysq+9 zrfU5ZT0`b$R1cF@`KLW+Oz+#xUwso$!{wirr2R`KbVPWMz40aSGFWf#I92&rAwU|4 zTEL^ZNVA!88y|~)$QhdYE8lg5JH-?5`VbHkoS!l}6-!zBGCoy2T8+S?2M#KLX^bUCm-l%;7{!KhEjR;3yY0uljzXnj*Eoywg5HMl z?5w`FKLYRmB~d6D=J?``TyU+4*?4^{*B;be`Zv_Ak7$>WIE#6ZYcBX1Lt}lcSnoJ) z&+W7t@6vkOAKa(Q$uy$@^0mi@W03OBud|YVcRW{j=cO5%BVdcT+w%TU?4pAC_| zDlP55w)~P+K&1gm=l~usYv|iFky)Xai5&K9+CZuK1*1 z)BpVaUlFGEt#OgZC=2(bZs4#}V?R&4sb8ayX673c9wv_JH^$+`VV4B^MI*M^H2=WP z93IOT)AgVF-WK%;nM)#33`eo;NzYEAe>47USI-O93SIU#f13%8QZ`t8BYoB3t?ZhU zRe)7fSv<-@H4!@Yy6-*ak1R_8K>1a=Vg2XRV zbFXSw&i)7FdRp4{UwR#UAKOF{)Y?B)q`JLw?*)ag3i5bBb6dGA-}K@ffWSW|FSHb9 zWkXDN53VVcrpr*~TyB1-c0+NP@2hmo)VFNlJJ*mMw|X{gZzu>YGkWx#HN%FRvwa9R z6I?o=%Xlwj-p-YaBmbDq1BxQ^`Q|UOZ73LEo4uS~b9d`u%aV=hk$r_46VJQQZTOGX zre?a;MOMn8IOkRy1|!%aijds*sv$0Pro3B00%e;RQF~D}h~=$(-$NOv3_(y@*1rO= z40}`{h8u|Xh<$ae~m zU#_&wj$U&&q72LfEHHp?#~eT>zDjH1D*LACd6LomrNe2zTlU__K=f*#-2RIZp+>v+ zqRHva$I2qOgE2WVQOVX0cydzpX{)lYRzjSnmz-QZ&KxzLK~_~R!^M0P!#doW=#@t& z-E#)Tv8=d2yI0C&1cZ>b-uS2H!m9NjMRd`eSU zpU_%XN3cDYr5^M3tHK$O-BP<=0dV{(S5do3cSf60a?4r!s9QE4Gxht`gm3`of3ZoEIz zp9Bk42j=3VFR2#`#;h8E5`&-w`tzHWxQEK*qQ1DQ6MK7yFP_t8r1W}&VCb`Jge7cK zq0Pb+<08GinY?@Dw^W)jm^)ndG`Z(|ubnmQ`G$(Kv18GMMmdlC$?1`(9nBEHv#J3+ z>&5sTo;47mh<@`l1=3T!;qYR@MRd#7BG_M?ZlBPHJr`(3Gj12SbiAD=i9MLS9g_7H z*X_!kbMKqSSOq-oaN5Ze)cWQ}>kkvwVHajT^-+qrWLGz>SIw3Vgr}71%6ROc)0Wa} zhL#G%@fk@x0Ph51c@#vy5lvxpG5Zp{r80^XtMr9bhHYfY_PVw?S4R! zD{4)Jv8mNV$AH4*-QLUEzO$svwLIf<$b=& z$N9LNYI?LpUCl96xl{5H)!fFY)@9>RDSgqk6Y(UZbf7lP_?;m_SjEV6CGT*dcSpug^}E)W z$Qn2=4xOim%IISlzoG)H+xejnR86#D-yZsjuDW|!KW?@m$Ni+0@$ntq*_Kp>!re}G zmfTXM%`|4;u|~|R+Zi7Bw~rux)31))`|d9P0y=-jw8mGFuPrI&y!So6i||DOV`mYL zz_eP{D6+Q3HR@l>k&X#1pWDHn!y;by#od(nT6`@J%QkF2IZJ<^oAmo=rO{#aKJY}s z8D6gL4EjWH6Ri0@1{t3#85^>Bn@CoS?vS7ECdy*1SwwOjyh1BxOjR%lA%fYC)9}in z#`k>~ohJjx6EMU>jH~&|-};228?Z&$goVb4@dP(b3zlpM#VWppO3Q+YV$PQ2t{x+= zGyURfHdoa=Ef~K6u>iZ(&X{p~=hU7Lzw-ZbMd;?#9Z!+;7I_+nN|@l4AR_(E!%NoQ z)IzSGH~+J6Xn0~o5M_}bscC%Z#4wRNc>l~(cIDn#Y18(j^dS%NzkG$NSwvE3VBSc< zhLPYZ{f9vZgP_a#()v;F9^|yjMc%DWrvVgvR>oL$(Li`FE*0_Tgf=}mNiyt3OFi9W zN$c%YI}x?PJb6JfG>Y=BsWKTC%oB2kXvgrKQ$rVsymG73hhv0eqE*d@M4?lPBjz^$@_xB8j^?c+zP|9#p;{ZQn zMdB}!`CJk@R#yj?<>a!k_s&WJ=CsJ41wj(pyl%xxpSkcpi5H-;ih!7K@2&zi-XO0O z0g&=z?V?e-e_osI!OOf67DhKc+;v4&1!idVC@pG@0!$T%-u*H7~I!D^?eNqs{&jG&Y8M z*Bg@PQqz);lH@nmO-7MTbiG-6h{F`NF@3bI$*I-Gq7B%x)DQ=sDNkfJ{JYkwmQj*- zE3J?Bo_&hV)17u*`KKvC<1564>+6nQFH3kRcz4vQPN*MCKG#*@aS2c7CDTfQXs_k8 z9~TQlHKp%K1`t%xQg15U{2VVLYVo`nAsVOXr0`$}p>kw$Zt`H@f&R3YqCxxSu&L76 zl+Rk+ZhfmFC+u%~P0mO<4VvA?$vBUdc7T^cT?%Q15Qv25(1cp3jC$HI8GpX~1`4~J zV#U6s|Pjs)Sw(RKQc3&wR13MV6^iP4NiN=&%i}E*Ni#lvc(1X9=b(|Yhnb@+#X;j3& z#-%y`Sc>s~NaA{)kg^Th?%Fm~*`t(zM;x^yi`&o1bXe?i6hpx#Krl5^Y(|^$wjZ#$ zPtCs$R&hAqc-EJ-{%m724tCRzZ(e{aNf&{62fIVw{LyXdxnzSs8rj4guA`5xd6O!% zmzO%WUIA(xV*1`t^=raUjqRf{f+;1pE{sinM?}avDK{etJpz{CdAwYOe%0SxwU#+u zDYnc2Qa#+$cj`RL-HaP~F$X;L=>A zyDzM-URtSf6|GuA>ueQ9&Q*j!hr3)%gc`st%0%d0nK>r?i1N~HwjfI+AFl2visT37 z55*uJRu65^;k}HI%7LMS9P33pgIusK`UquT?{6+f&KjA3KARiByhtfcf%qI=ksVkN z3um^zVq5ZsBFx)qUq|K}_~fL!Gd7Gd(s}G~?BOb#H9kj!GxtKo{F(t;OXh?c^tegp zFwam?ipcyUiDhpr{lfmY28z}Cty4ZmLWv5!=a!ULKZ9cf9??ki@ux^rE@WQ2p!_nJ z91Lz#zFti9Jb>)A?4{&D`RpL7OEv<&xCOFHxuMNIb&0AfH@Wnw2dus_Yj%?AL;}luM_{rdV$f!g|N2DfH@fQm|YrM4K&7>6BmXAY1ZNX)5e*xS@Y@%C44w_O^ zoGWnCwipeB@*7^w$Z2@RX>KP+uLJ|~=SvZy;muZ8YgSDha`BzS{pwsOBT6Se7O?cn zJrbCRSz*V2OByo|+e+fGs3_K;c)3QH77ng?zKV~0y6u^R83Jzd@56S0na1L<-##_i z$sua?-;z9hQ5$j>x~>rLdj2g`VaN3xnT<5+9_;oIQwBPlHH+gU-ht(p)9hMAj7SjA z18ABgt2VzClRTbMo5M3&NTuz}o0rd0(I%!b2{b)IGLP+i-bJPnv1lk2RTC+vpvWz{u$Sn@RgB*PgbVHM>pvFW1H(myAZgh^arbBn1@hE)myqBXVOJ=>x_C9uK$g^s$^rD0&W@u8L1Wo33hNet` zBMO~hqM~NApqY%5D20!#mCeK(7LSW$1UvHGqA-w2%nY{#f_-)aI{sL*S@7KpVLj(1 zoYzdOW8daddJXNza;N)8b_fX<^R+DOtgH#8y{x&t8CT`gfn@jw&wZWq-x4$ZM!1-4 zPzlU?MAq(E_Ko;hXA0W;@(W9KhBr9gp*5K8_H3Y}|I)MB3hef5-eQ;VIN^g=uUb_1 z^}1R-RB6~7#BaH>Y0zQ+js@~z;rW5M;}r$P?^9C-ICJV#?vV6Xtv9o#$VQ1@^?}Vp z_6rTi|8?!sL*YR!xpx+tm5gE6XEb5nTb1+j$u5h5Srsydv+Nw6vC`uCmZ65tHD1yQ zZ=Rs??&6Q~{r&2S)~`pqJh#7W-XcC&O6KX?S1um8Z_1Z?u(4xO7}z37mfS2Zlaa@1 z(es5b?AL!K4{Pp`LO-RX=gMzib&iQ=wygp8q21|oM3)A}+r-I@Lvk~Zn~=RHsk&Aa z#`UfS;6gs#m?<*qH6>vux?K*fJ3Vx6xI%l8aBN5=w_J)^v?ir(!`KpB^XT;Cw+5bW zYk)`LX>vY0)L;cx^-iqq++T7w~``LoW888T71iWbo%iz~~^)4bOw zY%kI8u~Gy!o7|Q(CacHmSr-LASNp%FBaM@7omNdQdcSaRv+qV`a%pDkEU>)dQfTmM z0?mLDk?lOpi7Ee7PH>D|!K*|(TBAh}k+5%N<>;=`$T3I3@G&?LX7LC%H!j8;?{+Ad z_hyr*=gbKr%JDj`$fTkkrMy|BVRFChhP*?lXNtOKJ-i?JGJ-~6e9!$tTtljpu%-Ic z$2sA?$w!*QXz#OAC;QI<7s)0qzs<4qD?EUOHKV~hHd=erMXJAFfd>_kvi#H4Op8z? z(OBILN9FMoZPjP^fUNvUnG-(rd_nPZ3C5H^CJ=*;qi+2DQgBdWKk_ZoI-(gErHa>r z&H+xBbL0;ha)VeCnfrI#TvtZ1EwA`KZ>y#AOQbWNxb>~A_8`A+xz61KSpvHXcccq5 zK}y-JX2eXUj#;RtmdBCpuj~)S(dWo(pJ=2;maMj@&$qCRtj<;&RaMydI=bxinjThx za?bpR8@PlA4|So(`-@3KR90t!F!v{oCs0PFJr3U6Vn7nHqX<|3HgB~jB3Py>OS{rp zdXw!K6a5sbbCXb6D&cmq_!-O_k~DsF=j!QtC_{4Br4kH(%FQ819R{hbEo3ib)yLc` zF}@ptq{q2}kx4oG&cUyrmyhW9p`vnfbLmf5<8`#8{`>3bSr5L)&bfog&)(U;jFo&5F@_B0jx zQX*?Y@k434OXelNTqggfobK@5c>85g-Qz>qY0|+)Zy(DOrXdifWeTzO9mrv}|B9-8 zfa0 zXz6%{L1f}kgbF6>oHZ$f95ZIt!Y(7Fd-gA#I@abrok+x%(5&TQEC2>*t;30Lcq=PUJ!V_i4P9Le7I%4N12j&Z_qa+m>28vdp6Sl=AEm!g69X%0}11Ex2N z*kJUJX--#kxjy>#c_s+pGVTE6Tj{Aitq2-=eC$`Fr8MZiJyhEeulHeTo2}!Bxh-T7 zZVarxOItsKMJudUbF!cNB(r7??1ZZSf3Mgpy}8C64oF7iu(03lJYt8Z%~rO7M5COtFk%LL_FNUt4(YSpPY6{)eVYE8?(u$y#VYds-Q^YLouZoUP9E z^mz+QXMR-!p)o9V^C%?pTlFso4(HTrqbt=Z#h8>>tvTt^e^YeS>Gb6Aid=$oz~w9o zv?FZ;d6%TdYnu~OKg>w@H$S)jyphrhS)3bC=bXL{2$JWQobSEHz7JnT9yv`R!(0f=q<$ z{58%rrS;*ooJ%l6?~f%Dj7N_sLup6+d#xEYw8jO`p{}U8XnESkj7J;!;l(lB_yQJs z?6<{_0MqyQ8*EtAWkMcPl8&O*0!k8_`%&R4fb#`J|Il_AaK6Z3bk_p-x7x*H!u_A? zk>g~l2R2C)edXJ(h85!bZ9hr&$3GF{3w5-s&YZgT-l2IOfP;I+dIq)wC}~vpgT9pzR`$) zv_^8tW!Aqgd36%CXi{0{XV~t%={G~1HK9l@yZcLWK>q>;QB6N+*bsFop~a5~?(2YR zWoDvucN8V{pHY;c1hZBIxjTybHSo_-6h-hWLDG$ABjR+7<(^pUbLo#sE$=;9rm;;n zQm0!zD;Ddli1!!Yz$r4g=-3&b@0sAG>f61y@=F}}tfpCw(`BQ_PHVbh%_bZ4H?POK zy-5D?Ecq6ajk+nn0j}%E`h^jnInPmVz^lV+)nInYZX0CQiaXxl#2BmW#eH8k5vk9F{mlZIH&8zw%6lNJ&;3u?7#XJ4W3k#$$I3-)($PE|iwTb^j6Sp(fT##M9eh@yz1T$gNGLtqN$=dA zW1dm}^FH|F6S})+hv2svqM5jLLr3fCigasr!PKf7EXucyRB+21ZOgn$O3{3q;W+Lq zgjpxXuob#BZ*3!ma%q;;pOsOk`16&cZxP0CKwgszwq#d1$-IZ#hB`pK0=i`m<9UY^ ziJym48g&L$jjAI$N@P<}?@O$67-qC71A|Go)W}J?8tjj17tr>%ORHsXrjAk|j2=S6 zEzyKf)vYjE@8Y_=;c=;q^`^EN!hgU3zf1W1_=n<@wRjO@vH>jCw+dC^pPQ8fX&>EET+_uuNV1oVN(NCF*I=j9^<%w*~0s|d`FxNZ~v+Hj>B)b zzZ>qX1=G4G;5vA`#VV@pgV1S$cprcO&hifg+LI2+bac1HZ=t-ym&YUPvGuMijqbt)7 zIZk&{DsoDM1?nYMXyFjV=CKLSk-~I}QhPB9g%nyE{%I(?@$AOh6l-QxfF_oGVi{KA ztRfmJ><4o~MZj-YYeeX-hSDMfl@{Th^%T>@r4Rzn+B64PK~j?tiR*uKr3MKM$^CDV zi)k|iXgI}2+d*LmY*BBAUAD{;S9Yr)o=(j|4i=-Y3t}D+_pg!Y-M%yB@G6_gW^#kV zX=~az(W25G!m6zE>S3s3-bg>TJ8mdp!NHaTtO=I+%V)C_G<^B1&Ua~9^Ha*2jdmSD z`%gC;m5%X~<)~wIj0(qr4w~(Tk_U)1%IfDfsP#Jy&26frI>2jt%QRpyk#AjmL>d>i z=&M+Bk4L9gzJRA`fp@fn`j8V>9tI=d6uBL0Re3*iGtBSy>|4+@X+`a)#LHC=HjNrT z=CBr@LBA}u5LZpU3w0<5PXx(G5ca@C@3NLF>#L_51VKF&H#ORyGAcN&eu_4&Ymb_o zX~&UJy%RIRG3UI{q*soEw-Z)|Rn3eCJO!ukQ9|1NHhuS?M6mvV@jtP|)I4XAORR2S zmAZRQD*@Om@2C^35%Y^+?5+?BBNtEYU1K{meN#>%I=-Qao{*2aAW;qpAe8gV;S^pm zqx*`hetuh8&-8o0;zf zMW#@mLizJVVH^|xnyyk}DFmn33qcc@!!;YvB)FO-GAvAZMP}D?+jPpavY?@3Gg;XU zFKk)<8c8m8jg!PI>E1v_`3amY{eVicJY`>XI%Jr>Y6cqx$2*jPslM9h<&zzXBR;|- z7fr0+ylns%{`#8CGY@l5!f63%J_&y8-)j?vTPzCaQzoySG+Jrpt^md`SuboufMFMgZLD@OW11Hmp6$$y6WWZ@D`aVVbK@bg zqJhrF_WBLqo%gS68_7Us5jlOqeSKmkvzh8^hWRXT1jrAOItz(yN`#3n1gUMh&(LA) z5;^&gZ{Nul*#ipRRLan9?RKaw zZYIf7G%-GY?TmM6i~%nW^g#fN@Sk;AK3#aEo90y00LZ(qgQ$BhDDyW`y{Lk%&fN~n2CXltp~tXeZ^tE#Ob2x3gCh8RMq62w$f5OW9_ zy;o?vy7ylDx$phF&+{Jl>kr3aJd%qmzww6TLEE_|(d>KU(oFx? z#zpL(8W+Cd|DtgTsKN;@@6<@5Gb@Ne<{i1YFMDSxC}+?CC>Frd?;0F*DZ z%RB#?2`Z}ETt6#lDXQ67V7kM7i{^=h>C)oWI)zQEPuHb_9j?jUOKOEWjVD)P7Stl- zrg!=GAX&FMP$uod8VKfC{z;NpLDwEi zdP}zX$;un77(Hw=b5IU`#5(&i^E^MD+%LiXg2)e5iGo0^P1Gi#G#-K z89|@kf+b;F)VQ{Hi%mC7&_UXS9Z7zo{(co9L`!+6;Sb~-zc;p5gYvRgKT;tboeljF zU$gg)qW59~l$PWnfqt>Da90`RyfUispcq!3)((;<0cb7mn+bRDuT2Hq_qnifY*jCa zx{0O%YqGx0KS3|$mpu@yq@yvo01$z@|H+OPH+f_GT*~$#S^gVVXEf`G>WeF3Jz;4p zFDI6lGdZFH^w+D0WJhzP>6JI7El_{C(L~5_1pAE3Lhjf;*?nr$1Ig2SMnms^cgMwT zc+n1;2&`V45gatVe&z4R_tt4FYrZn`@yh#4TqD|oYuU%PpqmNg{^N#bPT`<>cRameZkk3~)8Whx|R+w`I-I$R@`bnfkHe z*Ck#hNhPx0BGC8v$YIqY?RgJ39j@M=+Vz??=|7-!7e-R7G4Jke-ne*)DDU@H;U(0M zGh%^hHkmfu*p<)n=~S)-+dp1NJ~QUm?YQd>CqzLXv?}+HDhSVjW?Xb(U3B_m2{90?Ldxlm+iAkd!XPF-& zFn4qF@f}^3275`SW|AmipOGF0*kS6cHb#S=^efL@ol8ZEqxI1aaq!m^V-V%r>T#vt z0Px!ba{ZZ1>?ig>*)7hb&TIR`jZc>_7I9uv-gSil56dnth{wi;r6aGtEh>qaNH67QRGyyQ_axI>x#XOZ_ki}SkTYy;q3UFG*y7#o(#{06Sy z`wo!5dexz!SQnz+S}n)3gO9pPx>>{QFgJ4iLV~m}LqF>p#$=~m`887I>b>cGs27cy zo6+JWpQG;!iT|l4x0j;m4gHnxmTWJ+a?G^L2$8!F0}4l`g#!UE}U1H=>jL*294F%*pk& z=GafcG$3|m>_p)YZ7j{3SKtcN5#Zx3K(}e^;FBnuL&Cl!KgMhDS5U=X|(!%f@n~PM#`LK`+!=2`1&g*%YvKxP3EpYZ-Sg`;y5Abr@aAp+K|M5=}uF$ zxbm>rXzMmlx0o9Dj+?M5nZa76NQA>tsjb;_+OZ4Eo!hU_tQRcI@g}B=e#KU$kZP|} z=!wPNgzY_?BMYC)K2Tkt>_`Qh>E7<<1VjU1of})k&uX%I2q9zab+hh}z%$kguh#?b z?OLEi<#j$r&u@^Xblr*5*$WKXx+upW^6L9A0rBzj@Y9;q7* zUV&ru(T(iXOGYXoOnN5zSM-oOt4Hfp$nm;&r{xUlKcRbT*G`&B)eYonK6yVUo=pxj zy!(4y;poeQ%FBEO$rYVKwzGH509(g@?oX3puzp+V%Ki4t{=42Z>*uEY)#9DvmcP3v zMbMB!ah)oT3Q0>noyA4a774(dMx-7Y%I|M~$eviO0YqG6b=9r=^%TFm5$PQX*1J-U zCe`X2j0V4xp@`oM|A@J^M>PYSBNy*ZY_JwSi&PP~Bl69icHgC$^;)k30rwppN<`He zb{fRD3+28TxKx$epJ=3zq$IiIH(L{2`#3D?rpsW~vn*2BS}OoLX!Z;cs8i}Ts|XxJ zxvU$)S4aK?+(X@{ZFf|Qnyo*Ge!?Fv6v2Jj@-pAk7o3Vb*h?~Z;^I)^5c?{UhGN9! z8*(_OLpf*R87kf9FM|e;c^tsh7dT-CoK*#| z@$r2XFaCWdYGEZ>*p3{Na)uYz*(2`RXTcVRTG2-Dc5NrD7M8rAv%6yfLw19c zmadzbG_{MBx=a}{9Njns!2PNUzUG3K;p9LZKK-7#e1KDz7F0=q+xrW}-cPIS{1QFK_8@G*xGnPX{ zZ8gQ#(wyFmrXjN@lv*bzM9-JC!)(agZC|=Cz>Hwm?aQeih(M5TW0TEjiG%`@4B>jW zQ5)XvoKS0X*PA+{`S7TjWz&}I;_v^G!1A4j>i)gAVGk4tc?sN zc+^cYs18`TpE7wKNJ>l=)Zl&xjnty~`!7%;rqlKD|Cv+90=wFGoWyXrBy^qT#?YA+y@ERnKA z+#kR(Fqeop7`%3qsTW1;-XBN4gL158MSPD~^`#Ss)oWTt-$7^cHyFWqIX=~2c zjY!j9RmO7TPjaGw(UaS#SAXf1CB7G=HY`XF551} z%NRs$StA;wf}!s0!yd|}N$E#__NkGY{Jp4`_E(i2nz5a(zq*Tg$aT*V1SvHpP$s?$ z%WEDXhH>+H!7&nAh(lpjoWh8~QY4F@KRZ>A5;|4H7d7?(yUp$}#tPo6$h(sil8mW} zs8{miJ>{|xj2l=uYvgp3xB4Kx079skLx$c z?qjKle&KugIllSYm)9XcV;sHxjD-4PPiBdWY^EVv_n-RWUIc^wQqObp1iQ=qn&(r7 z?3hm7DoSwvg{0J?qhLmkOq^d)66;u%O^QEEMAYKR1)>^Q4#|3IaZt+EQ*`mHr7*dP zmV<``8*ZGs547~T$93Dh&&++RSrg38KT1P^z`rz2_g$$a{ry`=!E42hla`x@K9m(m z-Yv9!OX;nAJf9~fH97DSJ?Xnm{K1bj6H8^O6G$`d;yei&%M6>3C_SeIUsg!CSw;A< z!}9EpL;t^qKr1e-jR8ov%b^=lJ070=b)pEjbO=n|@n$ZYH5Lw`EEdtHz^>MqC`k0G zlh6@rm?pOdz=~7(F3(l7hhKU)&Jq^+u2g^%{j+5^#H6(pUrQEe)}QzX2-G_`58#q| zzS-5i_NYyOKiMB$QhIMED(2fN|5ipSDY^_m;c%-0^n1Xs4sP?%aX1)iIpE@Uu22ql zn_x=)1D723dUy)~an<-A{AxPlKEp|%?t*1OY|i{G3l)MEcCf0X;>f~`?Mccg#!WB@ zz4|ASrUhEPVNNa1$IjK`hpmVXF>>0nbsTB;3id)*>imY&P8G;H97G9#v;wsPs`Nl< zw&^>A?WJnGm5Ss$@Rxlax425-Nax#+k1asfZ-tNBUEOv zQgd3Qs>ljYQc)6EjmNJ1`6uSvEjNUC{i@HdH?1d(n2c^efiO$J+#aKkISzUPjxvge z4fw@c{?xSM)(jnRU1qGpPh71_ZW_?KcCRqU8C9MAka2AsVP zCqJ~#T(YWdu_jik<2*bpsCnCCSq*?(l+b1_C6KlAYiaWa2Tprg6|v8B)@Mc38~QF( z{h#S0`*VCaC|DTxvyI zB3+$Mn9rF(tRvg8sA?yo(p?h25_AQgW<74(8N=u-Qkt?Lha0Miq=oN6NMVaKY<&UU}`W!7WktxwixhN+RwWqjbpG7Z(02zQ4 z@bgjAMp9R|K`9ue5d}$w3@^QyyaCYy@c0FV9y#rZycs=U1YJYb79=?_Z_HpnPWjI< zUJ%ORyTlMwRLY*NbQ`%93zTO3Oh?_}?tuM;a2;8mp_M(BC+F*LC|8xo#_!NoYAakw z7P_Py^_d*Gaqh6pi4tQ`1;|=L@8#+4 zDVoK6R`Xlgsw#><$?dwqV@_o|BrixqtS|6<2NamvV?ZF^Ig_p*l9P{4YH*QJW6(hP z3yw|bfY}|yNYZE5n2{UoVY=;MepZsF@E~f|_u@FAK!zx>k~GwUBRESZT7i0{3N?>G z0g~1@q8_S{QG9@DZ*a^*LifdKqqx-4en{kXBaDlg)~iThh23FWpDyP6n=>_P6gFf> z*Y%$IpL}T)&Y=cD8O5-yJ#suB5jY?_r=&8_J@v~S01y-InISLorA&T3zCNud^#elp zLa2LjeD_a+^mKZP{DpC92!xZac;7h5T zDoi8}p!Uc1FkSmPE)z{7Ht_^S1GiIuHxgi50wQY@rDMr$k7a%$LCKD_W0CU^}N$D}9aGg-$xYPeyb?lh7K?y;G=b zvzyVF2`5Zv@GIbD8m^A_EPF=pb0igX0vlKj6K*9W&BP8E<91OveG5r*;Y!9t*8?kV zsn85Jmh*TJFC{xo^cW)5|7sUIYONRcruhY&xZu@gH1e@9Z3U71mu~dsKh%wmLd`1& zYKl#}n&{4(3Hze1(~`Xi&mq|*bBtj3_ZFX2{4P?#^K!s#+hK8Xe%w}BUsnlP@>>(R zyr*_Vy3n-MQUygnNv{_LOdr5l2dv8->lA4RDEQSF|25XR9O|bpW1v4%J?!fTho6O; zC=+FucaIYSF?Q3v=h$$1lB|a&a$MX^hJRoCMV8ZI#%LZ*bCSNi5E#+Udr&3=l)a?T z)7mPeiUH@{NNwq(FU682z#Y&w~OMPM$5yRj*6iv>*lnQxo+`R z#I)6g1WTi%mgxAvz>>?2MTn!v94!0dRQ8$`Bh?WG?oao#lZqwEMG?lZ_!g0TJsKNf^luL}n84w@FDjt3&jJsW?4D!Ls|IPd-t8JwpZ1F5M-G8R*teNBb!4)cn@Kw26Lr{l@CDA1d z;%ur>_MSpvO9-2J@&GyK`?!o1c>jIdA#mkrwWvN|KwE9Tpk=+hSw*W~QcL0)D*@>B zj{_BW-|yQOi9iwsD;~BmKawAxaH2>&LkcX}wi?7aTB0SvkF5V^`G&w=-YNua9s3VT z@r_OZc)vk2dczFSm1FxPHOp;MRBLkel4Y!wfwy+8*%Tob`D9O!EyvK0wYY&29ClOH zcs+4pdKcE@nw@0^=`6L%CbE% zzBcObc30^@nMgs#b%3Z2?y?F!Ds~yI zK$4_cVd?kNK^h91yI$Vb#A>I5&JK#p}m>Py8q~Fz{f>>}o*U z%1bwZwC0F%pf-2Nw`{+dSkrWt%SQsg+#7c{=!EJU_u}S?%wmP6Ck6CHnD;Pi zf(3gko3QfPLIvVnT^4sVyDwc&4(-ySZUtq}!~i8bg|_+PbG-l#?1Dk%?zMUh_o)r= zBI*n+=h%0aT7!{w&GZ8g#Lq7dir;c`RF2+a>FFY$xG(cdSF_vYb)J&B>>sFaOI_lC zSP7_1zc-aCRQ01^2VJ060@k!=WifJO$ZeLgsc@@xDmhLskxFmMe5C9XI~Ew)Jao`V zG>AsHutu%tSaICN1bjOsImV4@CurU?+WPP{6HQ~-dE4W&61>Z$@4-(x`*JbYRE>L! zWvIpHBQQJM7Rzy~IpbX}2A(vM3Z7~ozb6ZEs+l79qqg5ewg@_iCz@2pS&YIlPS{CM z?2aW369Ieuj%-?i2TnSdq&{lZPxh2%FMq~4LfPB$U9O5R#7#`#ULjKE*^0DmUv>p+ zrdPqj9Z{!{(ub-FH+-(ws&dH~J7Yy6^1V?o!%k6KrIK$N+txgiRRpNPDfBOxy}0Ol zJNe`xMozIYcoGtv9t`>bbp-c5tVb8`&07hfzKs6&rTvglK9H$RNwdV0$H~wOe^vI0 zy5B1M@K;&R7plk4ax44CtJm3I>kcI}-d*#TH;s&wjJu2|&SOZRG!vXZr-2E8+U9+U zmzFUizy&)!_CG#}h#pziIHOCLY|eku602_m3>UPks`muus^JnXKEH33VH<;B)0 zlvjR$esL`^dh&nIlBQWfYBnY!wqWHd&pNOtyN71C?I1viy+)pn2u*|^wN6bM?-S)Q zEL4=s?2TwnY>4QR(JuG*sO9D&+ko=HcHh1oxci_>JwCK`>vt;3u}4^F6pyIVk&A zaD@@>5!1Dujx~9ZLY9DVikX9laG8ucD(<}Ks^aAEj7;nQSN%bt@NDL%d=tOe zu9lY?;g-)PU(1Kj^x2)XpKgVDZ;AW?x3X8X9C#<16eZ9Hxv4K|5F`}TPTrxpO+xXh zINC8~ntY=5aW4EEIBxceEVI1;7lh#yhV!2a26ckil{=dzcBY_&zH%q=IgtX}B;>}` zc4wqN<7*OHcJeaEkwzzO$H5r?B_F8bCo z;H_vRdLY&B7l(3R?{}+@{SE8O7`yL?Gs>xCFN7D}e|Yps+FvGa3&BGotue<+pmJxR8k`U5g0VHA zQTexEKO>@z6po_x(CZ9^_`h#Wt2^Ft!dK3@^}<`e;+e{#Q(5?9feZ;NQ_}s;@Ei~Y zc5VjOw`A*E=kqGJu^Cc_Sa|iOs#FcER@f2T(#X1~61Rt8^;xpv6?r`t&79?n_ww`@ zY*Ykk+gncL=tNfOJ@LY0^Z^-_$uMe>AitFr(rnE@jZ zJ>XACc6npHiDcneyl-BQT39;p@(`(Qbby3^ob4Y~49a-fsm?l`q2nW)}!fv9(f~kf&^3s zLS4-*GO?W(S)r%aUEFJZAz_3X2IM!k%5=0!VY;B0oE#{Qc#5%IF)0fyTZ!BH@1io! ztyBJT{^|1fhmDu@YzS(PXV!x<&=y*oczQaj%Kp73TI>M=Sa=r`(vVVF@4R){&n$mI zMg4JGxavkLB4qbN{C1%xbH(pS*LCqbeKte6UVprMU>Gu_Guz~#*6PzFrZh7gd_+0O znno9ynv4+_tpXPw1WZ(jn2J6C+-?V14 z-y;;$4eyx^LL9+`9>Phf5D1B^ymr}y5U(9w>naJp8K(I58MNjg1lgu2|DoT`SKb`Sg#asrfy(}pQAj+V&N22~^jc$fS~ zm%P8wZN-EuK%JI5$v4IcRZ6+dQG~?1n44r6rczTph2Bf(r%8WG@GI)4XqLPS9csi6 z4pwnSl3k$NI=#VwE3>_{psM!_${I{?%UfO^K<@|yAI~wxq2OOK5pw=G`x149iUQt; z>Cg4N`|hx8l6Y7bxM7+HkCm@F`-}CwvU?HG;44-yMEkXB+BopTq-oB?@96~g;;2t< zftg?N^2+al!7YHv42Z%GkBEOv*gSqD8Nn`w3-Xwsd?37Bq4h209yl)%0H#d4wi4z{8E5<>n5#3>HVfL)gha?o9{O) zN!-pQ;X6Es`@z2$)rK4<*5cSpYjOM{+^+*~R^q^ONYU2j92mNBIFv;#9XuF0*t*Wb z<>5%O*KKL+O{V-UWLM=*?9;Qtmw4$O5suA3`P<~m>E}SZdPc4uFNH*-R$ySi#Ae5CP}XcE8aq8Df_49dgEpg0I6)_p?Uywap!h#!5x(-+EW zSy*=7aYKuEu;Q_*u+VYj!}~_ZldK!8U@Ap+E$lsNe2koXRdO%yLjcU6e7$e zV3d-Ft;s()#kpaR61{iS^i7R%@3_U*jT`NB%{n6u)|SWHwuo0R0qA&@ZSSJw;a9y+a4&N%hgx;s zk@z_O+%n5}x~cbG)MIy{Zq!58P3j+QRLeqSgQwRJhz*N)cG07Y?CpYusgndhP(WXnxZIfCzw*=8LQl1$j%7dZOyQh2}wHst($AhT+HxZz^TxAAGw z$JBg?j{Bzj_ZvZ3`JkS#(1~%THXx&?l}HR)Mzv4hI9$SxPjq6?xsWtcKXv2L9O6RK zo&%7yu?j14HA?Q+5x4lt7{~RM9DJc)o%^`Bmqaud_Y(d%Ngu13aN?L-=`}0BA-4}; zfUl}~(i!)Z3aj!gKD3Q)Oe=Q^J^D7REP&`d+X7JgPBALB6)iPW3!l{)qU{v|m=}jw`#@}!1YOUMx+{G63G9Gd;FTn z`M;*?CgGQOM56BUz=!|FdDdiqmTNrwZ2mM>qk}U|J=>+Y;`QeX$VvjHDgjX}_Wh;Z zc_-(!E{{=Sje-(JMtc|KYAaU6hb?JFVqF*&P`=GJeHKs9#eClYSO*tacy zixrd-7V{%$`GbXas&_kucB%Y1Iiz#T+6~crqQ*0Z)Q|};v0n7JBK?{bGg~z6=o}); zT#moRRd@Eo3~#hdDXqZ^KnDXTw1n-B>C39L!63xoTqYD44IvlM&4_yVLKkYk(gR81zumYwA6MY49sX6O% z{}3{6!+aKYP4+^c&FNE~?|;D2O3h<{qm7H<1bttsv@{ZP7fi|UfeF1SEIG>f{eCb54&z{nHo`N`~%$` z5VHppl#_l6^h*A@@lO+`Ul&b!0s)v9JW8?Yak4-!ITuN5eAuEM-_e^gAjzNRTDTut zl&N667&kvdSv0KiXfj94`A;XOmmzIGr8W$xn$2l(|D7ISalODn?z_BcN`r^8#j5Ua z*=YJcp^W-N9n`^j>4e5^c}ClGOTnK}>Tq}i^PQVwL4_me#2Gbgy3i(Ezb?rM(wEt6wyA4cjBNwI&K#xE+)Z>_WLkr}bd}5( zpb>E*&ms1K?U|DVkcRtxe9v&!J*no2@UAOBP`lMd|5TE`$R0-JR81RZ3)o=Nr#F3a zxJVU~K6BLCpd0?g_B8IQiWe_Ju}}=A)T&i)TWm|oi83A0 zFm~EO0QV_j$X>9=i%f1HlDB zIPqx-24FXI8tvJl;`AS!Gynh!;85cQ!Ha&}8C+L=h*@KCSyb&iDV#R*2cR1l1mVn^ z$L6&&JUsbd_1y&(c!bE$yH~EhZt^6P*H6Y5bq!EGJIhu=E#_o%Lfq)`k5h5%LZg_g zc=Pu|F(~}&4Q>9}x9^6xJoj_pe>STU`tHG`33o>Z?et&Z3nP020XF|_Aq!sdiuS?Z9RQr^Y7V-DIChh1&fK55S!yiWu5O7e5u35rU_n>y1Z~3MA*oy z-B`73N3FOyNx_N3GJpxqK$Ur6r`DjsC34K@=AlCo53_Fi#kXQ(!~G&|nMqRbefzT< zLrRs%O%|WiLd5A^nstEo46OT9=Zyg-MPIP`(NHZj(Jwj{fRkIiNrVpDm%j3;T6TS- zmAHQH%y4v3x1Jr~QPDsAG}y;Pwwy%P2-@kRw&Wx>O^%zXDU0JP;ZfzbJV&=@P_Ro$yCY+J0^=SV_$<9Qu-1{C638G24Jh=4<4*EZGNTRUw9RXy$xhricRpue8jgB8sMB2?5PRSq@m1p#vadNQ=S);Qhm z9}YJ1)s(3j1Ue#R`83FO@F*?k(p)g% zBl;^{zKQlVuxgo1$eG#aR|d4(^G=sJ%;1wIX-A$bpG2?Xw*z*NZwo^j%w->Z6Sxv~ z0O$q5J>0fd=(j#x)H(-zr{V3|6h#OnmzilV&C~I)XR*>_ZbTs!D z{38s!z-@nbVYR8P=5EBz=zGC&!E8mvypF*lF+9Jm2i6YL+TghWMNoFH=G5d=aiuFj z&1;lh1PvAsdLkNr@i*`f7_s}NU9CE)w~aj(#S?SiMt2lz=6x(b3$HbNogp3?=<;cW z(2X2ko?ApaHG4L~1)H7S2{c*Iby?LBQ44lSFcd5o)8x$lKu|~AHL!@>!B>vA8~1~b z-(=np)jAMYcU&i0)udMoBCnuy+=58-6j)9$NiLZNuxXJ2RrU3(mm;>L;u(yVv288G ztKf$(4M{)1Gq21gDkOXu;2|XBtUS9u;+(P$EAMpxJzI~6jbn-o*12gv_25F8%D6wB zcwPLoy*R3JoB2gP@Lv)ChlCi}?*Eu6EvzE_Xi0UH2x-)*0P;4&eo60b^MoQ)XHHQw zuP6<3oa)cCKnr|^h4zPnS`Y&jgomT0>%4?lt<}uo(h^OCg_Bu(CY%NvZxN#8pB_|g zeF|wLPME2fcN2s%00~4iv}UT|3c>@$7_@fMXSx^*~@?xw-e`oDSPOpu89i zN;r!smYLz*mOV1Dc4eH<2YCT-)5w(XD%FEQ?y@iuz5b;s-8)MAd(Jg$H&;4Wxas$= zZuRjflx{^c7W2Q?S4k~9e4jfCbS(N^9aQ1pi=tk%11AF$g3MBH6}gYp8`_Px7{7C* zy*(KS?$%T8SR)M)Hk*Ugw303Ysv#LC!2WU<;XrlAJ_EqR3mz5;&!g)xyS1F83T(4U zKGg`>AM%r`FoUX#DB~=6-e>-H6%7pTnT*vT_FYr0x{2&v{aJI!!5wb=pYKSl`B97* z01?5}@n47t_Ao;E$i5^kT#0!?ALNm;xE89w>Jj}eqJb7u1SM?Psv^4XaoH327yEC2 zGP&w*(+D}~+6vJX8hq@GG3)Y-sD0RF{HoW;p6Ep={6)sQRou67LhH#m|n+) zA{70Sn#jY$(co%O)eH4T%E0KNr?e(=q4ozntwV0l&!Y=UO28rN($#`UkZZ% z(`y1PA6QN?>Pj9{3qQ-hMK;Q23IIsI&LGUn>%rn`t^H!Yob^!MC7m)p0BrRtFJ%hH zFzqOQuW~!3Qogo*Ao9d_o1eZAf&T^g)j*uv*WS;JHe#90Z$1pT(v0vQKVj{4bT=dP2K`h5l#g1N~w_ zsO1Oq*91&nBkXctmXq3Nk8|T>dXqneAt+L)1kMv6Ij}~@FANhpD{1L<$FcJ8>y0T;Tkggxo-<5F1!wH@K`WuJ$eUx8d~DZ9)4dcKI)1 zTtky&*Ax)y&~HA!XHyC&&Ckp!6`Ov9ItW;OH$KIM&J_dPckqtj;CpQ>?iyOGr|w5oQJV8kQ&debVOhc@ z4p{hOn-004FjD3~IqvV>4v{phf^e^rk9mhM8^)mn{QFi%!Sqw288-2&9%ps++-}Bb zEvw$2m-$~3)Yu6{k_hB8SG+ve#Aml^Z^y+Oh%D*99ZYZD`%J~$SNpTx-Gev7_uLHd zRrxXVnUcUG*s#X#;=`Z@yK?+;WpYq$vC6FrLJP>o=w}sw4^OizOQ&~8{nrjPok<6N zWz*#z4gNx7u|UfBZVP|YNZRp7ZjuN8iq6c!6{M@DG#rG=sR5*O^CGBQB>jao!v$wwGe~+Nn;0Z2t zQm9eUVBG+gwvMP*=7#~e{`s+nCBN=P`$$a@2VUX7W>aQLUv%Y$u!2(@I$>MvAl@DHQf#S0<}*6QWicNAvQ4ysm(d2viG)#n)&OX13%(>;(^DTO17 zlF4&lR1tEn1e^ng=7H7>pSM22sQv_Omp+)#8~i|i`1 zfZ?aWnC!#Gm`r+amoupc)Z4|BUd)PUri~gnDANwL#jzZ5Pa}!{Xj9{z|JkNS?%l*d zyK7-NqJYL@E5G$4*oOURyz%$J?)L;vFNlw)6aW=}k8pCC0_THa3(t5xPIIYhKPfw~ zO*V1|fRNBXz}Busgo=k2SL(lF@7JU&Jgi@%iWjg?xcQ#79J}%L#c-%Oqw~v};QW?h zzgeVI=`|jnoc})Qc%x$g-mh;r7&_7~cgBAb+iV{ATVnh5z-*89H3IucK1V2n{TzJ% zvg*o7>{(V%xds%(*)efec|Ac=_lTPy3G(S5kmFZN80Rd}2mtMPWt8~!c3{$Y_Tu&h zrk?z6Gx3=@EL?K)kT-WLpU6FC%LnOuuGJDF<)501#*R_tjg^F0X;t-1c}2T{13vyw?c^#b)W+~4suRFe`?e3T z#lo&^@T)?poWGPVTd{quvOp^lYmTeY`EQIBOCu0tFHiq_q8dlxe+W_Scg=At$Fd2GOPTc7094XGNbapc5k3^Zj-xXs+!b0C8Kf3}v`$l!RyBBt>W7 zQe?{0{;SkR$$t8au~?Svz)8JEdo4B2Jix2oy!6w(ZxliaXg*xuRMB8Mma%Gl@}4x36*pkF~MZQKf3b z_4yQnAMLjwCFC*iv{UYH0~^WB!|6S;Z1_8bli|-;WiL@SxRI5_%-P*cn}mr?YyRGY zIV|)5p$qf|;csMz%1=D6JCdTUdM0r_yAp14Unm)1DtAV z&%#GkuHjp+c4Bx0Cx_L5sHM5KQb<`^BeHx>z2GAr}D@VXA%*|JNLDjaY7dyIRSzPRdI$z-xQzcP`a*? zZ2|al_TRym!~bLWay;zq;yBC!=(8KfE9|f0{t&$j-}eD}+uLGAT91~P12xZ(*Fb+Z z^>p(zV3i#>I%oQO&ue>Kr;ks6h__K!PFf!c5A1*X z6Mz}%dh6jg_46=5#(lZq&(HMR)a9`|n=wtJr|vK#Tmr{wJvZ+nEX1zC*!DtVM! zz!`prpSK@Y0;@*cJ|b=hh0Px?QzFLhUmjJt3;u$f^m}Qzl`^?HhZj%Jo|szgg^y2G zZf{l}M|yuDB(>31;yZ?}c2PM8@fo%=Sg)*G#-PyI5`V{VUFcd?-X z7u#T50->o{Hc%cdWGp0p$y-`S@h`mEweOkfygD;0vHnzd+HW9@W2z=viczg1`Ie#Z z^k{!<9=W!0jC*bCwEAHKih3WK@w7kYvPYhH1FF ze}Gh%o80HEa6wSeXM20Lspz*hWRPzPrl_t9cq|HJi-@`FG~$KSo`b7rm+9$9GGjNo z(Y(1>o}hRnX~7)jwO6xZLNP9DKoLx|!>H#?Mg)Q&YpEtbiq`lenK#H@X42c+l4tCEuIlXKUc^{`A2{`=${-6PaLfLoKn?CC8hyR+h zCQ!f5fBqNF+W$o>n_DMDEq=+bh=4wvV%Yn-%y`k`3VOT=mKD|eI43vZMI9(CmMHag*TuOZOeM1XnzuT_a2>yo;sc0G3u$|+3v zI&1AM^OH%OaQ#Dpw60k-OV@}1`Y^?RsFnN`80Xt`7li}k{MAQR{_1zqTN=T5OHqDc zA5oDEwvX}+wYJEbVNTyLQ%?At72dDn^4H?b%W1~Un_3VJ#Fr8b;WNgLte~JuXe}}1 zI33HA2|oCAP#@DKsz?8HfhL0BXKX!Ij)$&zHP*YCeJ~p$WIII+70`GRYsrU7tns-H zaYbiUy`p-?Snv35!cIpAuR1SXa<-kuy6I!1^aHXk8Mt0b)9Q*Xpxm6<<|b>;>2vBR zpI(i4I6u3FUnmp9W@kAm}yMIyG(1pl%jV;?GZsXZyl1_&GiVJ zsdTjIpp?>VX3MxCcbIi_TXUe1ie}utv}$S@IGvSu*83E)gp*$FnlH@UyGqIdw`mAt z_MR4MXH6q#P&|OvfPXj0t`7OcBfC$U8Z6htyu-Uiv@1xx^w!ZXP1)vyuD!48`tED* z{n?*y^fha9>D`LErKF^!uUx)(LrUrwu$0ucWB=H(^^5V>;Oko-+b}oGFGv;BmAPAm zpONRS&Pz#^ChQU4`DLrT>+WR-jFi;AhCkoin$h`gQc~~gu3S8SGZZ#k2Y#>7o;WYT z9n8?3#{4m$V1O*7*-N`Ve*@@r-v4oY_qenCesC=D{?1~pJcaX0&HH^@Ojw52O6;e8 z-3e#hDnjbkclFu|52hw29!PY!g(`AE&DJkw+uic0_2_mAEUMcek;LPgnp+L|&sViv z=km}0S(1_(I_!!4>2Kw^XZ}n+ZJpUHgSIJQuke_IGHNzzBa}IgbQlJ^xhUbr83kL2 zd0gJ+{3f_=Z65v`!frBBytFLyh0ubwhI?V+$LE{K74{_NfQj5v+J+GRvh>QvS`E%} zCq6_p9eEO?-KO9gdPL0Pk~UbIAW1uV(?Y?C!KxRnbKwstEu%J@YYb5>J9|Q7etlKV z8Vo4%^&X49}=92hD4HYQ4p?!EJUogmwn#fsL(VtaPZt;g?Sc=Btc4E9n4%KZ)uBXl5uX|d)qCn4 zs5=*1aW^O~8A&>x&FiY&0^3RdS}a))y4k+;dsyz|ICS%_$U5S4Rul_kprRd{Mi5Jd z(G5F@aSH*TfTJwzRQ=;mua*x8vVfVlxyr&HT$2!7I1U61XPF~za6lz8y7?<6@QR?8 zbpxQNm|W>(e}MFYAtIgV75-)_;u0IY;su>Eg^Z2?3?EkfZFy+e?OEMJRPCXcOT&|E z_Et8>RxC2BxeC2KCA3QX>crjheM|N_(-YnmKl>fX;})otkK9+L+CS9WK<`GsI?;I_ zy0hThH-lepWM}CaH*S@_*X>qJ0ZU2QP3C_2&lNP}Z-?E$MyDy2J=NRFdN%X{L- z_ck|K4!xjpqA}0aPB6NMvZ4wiMP%igM$>f-5E^DO^RH_p{2^z>c&wmy9l$9dfbm3Q zN*xHKNmy_7~|ZX@w?rQ1d2q=htW9S@H9D0hEuQklzrV)g&UvBg-4j ze)G*cVU1JW$7q-HzB_z@sjIhE+$DES$u-+!@JKDJRoUDVYtTaYzGUIYX-}`(C$aZw z`AtZ6Jaa$2(GP0?M`@1Vx6BABdS5I_X#4t1TWlA(Po}$2>%f%a(dF+nH4696Cc%fu zOAOeBT_~1W@17Be<%M<(BztTbxYd9Jz|ie8u=aDp1pcj*`k4ocSf0oJj6IpfG~`Y%FJ z#RT;%w6KM}+NUi^cWQLA+yWo@7)qz0)9b z#vH!%5I3~?Ui=5l?F`%}cz-v$pk3L?qenX(0zw*<7D2(^qSF*CxAZo^){5>ROb|LU=-7j6KD}yNB zM28ZO^_!5})O?Mf>bA+{oR+y0NxuY^(Q#vz^}#E{WWB;hV8i)dlw2CAor(Dbnf%cj zT{@)nG!A4?e{HTHklgE6Et3SWj8;8ghDPD`q|Fu1OO&ROKd&jDdh!9fR@ADIIxR;_ ztn^*Z&HK?lK-8)0{xFFu0ICiK?G8v+%KTO}ZPRAEbR8bm8+?vo1C^1A?_ciT@-&5%k)L8- z1{Mo0HGQJ2_IIQN{=hH&2xZ(IC&mIMi@G|qS^|FOis!+1oA_$inLU?fo< z`FQ*dK0^96|bEDqJ&wX(&)le%h8(CNo(9r$j z$uBWt@jAC=%BE*@&{f9sl66c;7w7TwPqNE@7~$=al)V8S&3@d;P=q3Vz`F6XTC!G1 zv2*78b*}IQ6p-VL6uK80{K12Gd#_P0XW(M{J#jwg>sn^-%XFW$O!k;3DFp z`~cCPY;rpwan_s0y@qj*{7}_cw`R)+^>^x{>?Uoa@B7H3Q@ugk?;-TSwu=tZ?K0@m zeKKSAuctSsb4tmw_d@%$WttsrA3&AMc6uWUr=JhH@k0s@{F=T`{N+K?7X*=RD($rkuSW+Fl#@c9qK$f?+4iPC49*j4uVi#fW%ZSfBA{Bm zwZT`j&-_@+a;4=-I;>%XBic?;A8nZLOTte9U43!8>W}zjWH@_$|3*=M_*{SD2~QiB zS}vDpuzN7zh*w~nAltBE5mK5myr>}`cWO~%U4a&y$Ba43@d~ZA$gxh@gOqp{^5fbQ ztzm#Kqf>|_U?zX5I<@8>;rFy;E_lJi?s=tu==h9etS>T`QwK)l)whaS7M>7d-uM)_;n+(=7L*3Pb(oGX+vf&4zO2EpX<0uK?oMYtiY_$z1i-5@z;=uy0E`rThGi*Ia>C?q{}r)=A5hY$_gK+xGgf z%Q00Pbys$*%)ZH=(8m@TcM45T>yAAv=HIiDnbqqB-vm@!8DxW8g<1Mqg|E`qmgU<_ z#iYn2weLLBh|tdYSuLvv@yZrJZ8j&YV_lKt<3GQ;0hCK}kh>nx ztb_AMXeyl2CPNWD7gMuaIr7~DE|pQJat+EwZy&Exc_ZBn=gJ-+%PFiUvuf@@6X8mD z^l509!mAdC2fSon)$@b{L6$GrEG10Vk(TPrZ2xfqJ|rL7=!DyRXrBBQ@N=BnA(vF(5~5| zFND{8x;TZp&}9A7{tlvJX_m(sDkh&(!=gQd*LRM5Z6W-ki!_@NZ0l>iZ`{jVzowRD zu5;Xh0;3P;klpfDZQFPDcLrD?1a~0ZLjBE1XE;^*XDFeS@)BvaLgBf4GQb;uK7hW+ zV=$E?a=yO#kmjJt(yyDBGpp2D-YzxG*L%$uC&(IdNa#EkAzZm@twI~{&M6#=HSxEhEK;d4A+ z71a+G>i52=&3O?wRGoFHRN+Gk$6~cpTj{JB&TQ}UJS#G1=i2QN--B+@FwC99E|?3g z8V_-o7b`(0Pf91KrtXPRM2Vap*1=eEvhK8D$sWImE)y#l^<8VfPYfW@ZQb6?$QbmN z-bJh|@a%uT%qte1^#Xk3*!GZj7c-YqC9S*BXI}Cz-Jz8p^m#?Qn1Z>Lr`j;HSgam& z7MdYYG5{;Ri#w^tIZn!+^J(VmLO!(YOrQKg#rOTl*_o%80O9WxE&p*-4Ym~v?4xHP z%aV!ZgqWd_dQpiXhlEdGsgQBJgOBJ7P;)LL;uhjy+0~X49N5+VVQrxjm_s5HFuZKxGo$6S)3;{a8?&b2xcM;`Pn#EuMOj z9uEUxRu1OmiNG)V6(628=;+77dwYs)cd9%A$O}Zr|rJ_vX=MOU&9XBbA}-&7Os4Fio(OX504n zxNCHoX72eA@jYMgT317`Ps;d{bH;XFH)(4(RJ-1y&2SguC!a@n2UpS`)(u)?P0BOM(7Af?VZ#t-X;%UNNig zzRiI*Ws6M?&{Z5&P$f7B)gAVaC@W4I43zC3u@~h;rL0&fy_~#*6#wxSx>sZBM}4V+ zAgaD7#qEJOWRo(rwK&-Mb?LbR^)~cprFG-7WpHLr7{xK`jjn@8k+gPhOWgSk zN#tE5s%@C5^26vmedFzrS(C3uTcdtW3N(W*RnNy1BD_ZPhkxm|fN$OWvlnbWw*|SP z_O|t1r1!3W%nE*K#P5tqZP6(`#tL>ZUir8|$vylC;FO>qZaMy`6qr;N4T_@|t>BCU zodqBPsxoqRs&^ptpEh4w-gf<)m`3vxB*F`+ukb+J=2N@Z9L-J34u=-Pb+SE|R0 zN+(!7xUvrwY6w@OiONfky1v61K`o_A&One#_d%&g{maYN`vg+0JGWp(J16j5BP$oQ zbZ+zpkt{r^7TCW1F8*Y(%%_pGj4D5>EHjNi2r_u>IEXRpl-<=_PkP0`R=s&|?$7=|hmV+E?Ygj#-%q zg=4s9kaiEN7#CUCnTgUDvImlyCfD}DkAQ6QGzl}&i(~Z&MzVuE6^KgEzs^X}9ocI) zCMy{3eym)TmRZ((mEHamH0q+g57Vdulv<_-?l<_tqJT(}vnm+>p_GE=0Z@9T@VrGt zA6>>q8IqtH{Ev;!vgfL=3wOfb7`Js%#GU5w)pi5!n}o^cS$7}twRer*`Vh%$wUSFcg7{L`IJkkJAvf|$xR7}idaMKXjFIpLZ)H1BI=%y+7UZ#hU1G&+Nzc@( z+78c&%iZ_LhZuv^;v;7q5H1^yMEUir8HxV5+jMbIRYY_^_Q1@+p z6%kNU)qsgr@M-eGYByG$W!PDkD%CF?o?kK;CaVU#1Qv9q+u8Ii5ABEhy)b~cN@Unp zr=_H-l`=nYcYUS>yq^3?>hS@huVzpaoVc6 zyZ%N4rGP)DW&W*vujVH!o4=JmyWG$HJ6`-=Y%er5E3rXCeI#$o1yHv@%YDnHibA^~ ziJB?a8Y=cjm5!M^%q7(3MsC=-`@Dc|#7i#s+lpHT{zQVrRde|%I~;m5>U4SSc*N#L z1DEG92kD)diOk$s9QE@14#-32QhV>F8lK`T@h`uI4HyCr)a3FN7rmT4d=t#$`QFg3 zhFMx6sd8U<1Vl7Y9F0b6_m1_~1j9 z?Lo_Vuig)m8!3uf>KRt1A$Q(w(|y{rJFpx8uq>eGwgg-mMFBicvSK1k=u=V<7Ea@h5TDjmYJX^ zgp1M;wpEl@d=T{dU!~Dc4WyYr?-nWNn3)eAx*YecYi9pCLuBx# z_mx+gO!v>bDlZ{|73|U#ob$R#N_y-QPvv;K;YQj$%O;w}W2b1R6xiqTwCtKi*g`}G zv1SfKSzcKh$f!7BX5_FJwYX%a7pAy=gY7?QV5I5#T9~M-ET`*lVL;NPpQmgC+hSr= zb2qr0fSG>)2T_K9qkZ9O#UA7>B)WJ&5Ri%OXE&4CgXu|_fL(Bpz)6r@KJSq`+E+~Q zG*pT1Un|E=M$}6zy00r;ZPO9+-K`B7Ex9nC6W%N4f{&9XrNJ}}bvQi=Z%w%ez&&V{E_?6}Z`!p6-IcHlYEzPQ6Xd?-r=K-S(=%zTfnEK}IewHCQGg=v z)B2)sCz+K|FT#Yr_OOaR$NI;co#>4PE1C6rH@0&^vPI{^kchrK57vAi@&aVaiPDFF z;V$K_y{Iy6I7W8}T%a~YZcZ)>#z>ij->u{J9f%q@z4%QWzkb98-lW_%%U z;;%47R71d_N{sfcNM*QU{)yw)pY^b6Ux3V=|fR>Fk+N!00Ho9>&;99}>tQ~8uEXiSt{c($qs6n}rY zx;gxz9cx5X>Tmt@O95eOltjR|66^bf*}>W^tdf^N_eAkR)<~h;TKd6&!n&|tKYK4y z&d?u@GS>d}?g`+|51R$^ey}UWVV4)i+5^K5h?1jrC~$~?5i1Ib_s%Z--XisMPTk|N z_o@xRdg;Mqk)RMQJT!RF=G030W({2OnVKY-vcNx2SjWFq>xe0wh$&*e8!q;^4v21I z`On<&JHcY({9YjK<`{hUbn&X6!Yr65yKON=P^n8U^k?h`5GE=Wuzq@B3 zznHnU!ZUt9Gy7x}-y(6^* zxj7x7drg2Hxp^bL94nh%8hUcS8B>?{3ftrG|1Z2(1q3s#&eU%^B6C4{>`A&WD;<5F z9Mwb9Axn`H?-0G($^2&lEr;fApHn`jKOu*oDNmcdSd+3 zru^;>wNn8qA2Zz7)%YB}gtS;S>-o{hSAN)oz3G{EplHV>C_%otEy85pR*3j6yrTy$ zq8>loF?hZ~S1xUF)wn2Y?tskT#1DJa)rBa;$~oWITPgwRJ{#KlKe;jQL+ydsfdh*UY%)SJsYN<=xc_3!-!!0EJ&*fhBdG1;OX@UI}s^V$sEcn?zsH`MZtlT!*W z6?dfmqqBW;bl)v#h7oF+Q<}q{=Z;R9cPY$zsEOrstW$aCwGR&;wz@ryv>czA5{wOI zN2F}3Lo2DjcPZJE=LXi^4FpZUZdhn4LcQpJg(@o6jL$ZrR+&>y7eeA|by9iQ%zXYj z%G|fu^WOr!o{{dbuZsL$qOMYSuia#+a`94`IaE_u5~kL3kznjmLa?P?(A5#YpI&;^ zs#T>KK;^$1ukO|EFW9JPwy#XoGFmdax#>CbMDaO-Zj>_{+xca*#hK|=E;4no(TYSF z*46$>0Tb&BX5xld(Czn%-bblX(0P7KC`*D zWHGh!t@1=H?bJ!5?V$L}8`JnygeW!WfJTUNgB&pUMd{d!b=})m$)Ge`{Q06`_wZ^n zr!k8hYRZwAH@eJLcc%zU-AlXKFrMo5Egnp0SyIoqA5GM#-in|-Liy{>$3gfy9Zm1^ zkJm5wyVq+~#nt9AqmOPnEgh;_ub9W22s&Fs7YT&>H*Q$^!;1)8Bqv9?!vM^d_NWcB z0FneJZjC(p4VWbKs2uv)Lt&OmsY~qP9A8eK9WyPxLMo#N$aZ`OveP#IqeJlwGe z>9paE+)uwRxqsd-OrVDFP%3icV^m0ys|M8 zK1}6Ul$23317U1;TSr3b_Rf-5^Z|K_l| zqI({ISqxp)QGBVS%(|F(INT)(8*!Bsh2J+du|2{se7$4yvug?AJtdj(=%XLR%sK-X zK#JNh4`tdq81L%2{S?5PQ7Y&^K|oDmX5+3aciA{L2M?RpPc?97)5Y6aNno`SZh?VK zlZ>a_1twF$l*teKxM*b~s@WZZ(4DPP%DOKEg12*zu5QhcOSs!IRnrP|3cwe`W%}M8 z`21aFRnIz7rdR>N7M~yGzYX^!NGqsA)OUzhdC;!yxKr~6yFQX?$se8-gWin4I1kAQ z9aKJ){KPz=#xGS6ugntwM9;WSWfS4@$B5JF85zVINH6Qb zo8>dNL_u~DTk$f*FPL`zh2GXmC`26`Y>s~%b>&KSUdThZNk=m1}o7hl|&aw zr)Ns{m?bA*UNEkQRI-H%=tC1$7z~>;eEfEs)%UYopie10%=8`YN+@tJ5+q?CTOBUh zI7JucjpjzoKD_;k6cG>Hq+120uz~a=CX{k?vK-A%@IL%p+QKDTs%&`053qeG0$e9^ zn$;lRHS{}7DWQF%AZSXmJZEd>zr3M$~`~cfFh0|%}WTN z@LgTD64lnNZ72n$NuN3Z&zn{7a9;x<8hf;N!cmmwKbr`7lFV49$gk|qEE%vbNubM8 zv|UZ|;=Yt}COT+&c4}Rj#KCphH)-<=e8=&R6bd~ke6qI5rl2h0i~Ykl!~@1kdxXWT zJ);cBD!2~!deyY|(GgQXUhElCeX9X7Uh?b(vlZ7RaY!Xhh%xs=Vh==Q_-6!e5WTVQ zcrTmFgMLph9GVxUxX7wZJ&5G)|seWDNz!!UH(xrIa4)h5@`sO1nUvZ1`ttbDB z^Zf(3Hn!tcV6)C@F1ATJM0V8jU!boP;7up?@vJ1N0sHrdy$+?4ts8~HQz3$FuGIfA z(6{3nI0UdD--vDX#Yh@Wo@uUKLA+0>zfa^^oS05#{npG#MI_-oS*-=TG~W4OCgGcAH&yERwmM& zKW(EM2x-BuNUHW8ivCxXapKZ$MKG1jdUMg%;=ia0MYNG3lz;bWa4yKyOJX=3&hh9#fM$E~dq9J{2%e6yO1+@^tX1$N5_T0R>uDEi5)F6zl+~Y6rB4{uBYHuA0$VP zVNWTL2=V23BwwS1^X$o>3scPbp{uS3U3{WpC~wWg9HqPji9U%;P#eXj8ueBo#Zk@^ zR_~T4du;W=z(zuP4$UrLvmsP=kCfEh6HOoNPj}0d^T|?DCq2^NS!Ow)H!V_=483PbCD%^#+U`ZVC0?bIZ+} zQ`G~1(QlRUKLX@`h@bo#R87E=f0@%{4YMprml|~M@w*}td-=ifwDvWJ1mrAgnl+B6 zCrS|?`ZHz!2J58!{-Fon_O~BGe))e1_s;!q$GdCy`hSCk)u;ZATDbe)RR7fvyTO;q zoByX3Sij}7IV$b^TYK+qcQpXlZ}e;lf(O$wH+l1Ukz!=TzQ6rLN?E%0|Iwxe;QzOC zzMGz@2F2f8-Tt>~K40JBtbZ&2_bailb6GsW&5INbawj%V;0+^&X+j1QcxT3Dltg*4 zzOg{^8Bx^BFn2}c%*VAb5L=|pHlcvUw291(BR$Nf;*0O61!>piXN~|}5Q!%whlC=w@5L|h!vxj60mo{;m?eu^Pp;+LS zDQUstlEo$m^9crYi^gJZtr<>{;4sKeWY-SpzZ^_=fcmPeh;YjlF0!K|&V@~7A+xp9 zlome6J)r!=Z9g^rk}@DgeJL4KKVIQ4GtSZ=sl+R;o9x6glN5EWnPm#bwoLv9>DPIl zSS;%9zRZu@E6zsL3g>L@73DK4dlr`#ziMt^hA$S9?=qV_1~27~6e=#sG5TQv#qoiO zU?d%8LZsyDSD=SUKkEbJ?4i0y^;+T_ZMDLWivAbLl=<;KFkZY6m}yJ_Q;6wxylkBU z{s8M{PbTMm<-V}d$?)|51eqh|4xIOUxp|m<&7Taa1uq2fA26LPVOAz|wKsFuG z*4Z`A%z3P~-Ql{r9J@=W__4T@ev~urS<1kuTiL|rdG3bE>VW|bx?bXfp>s3n5fX$F znVn#jXhW5&lh5RxaHTyCZ@!dEUhEw&Nh}Lz*MhR63K)QizM^4g=aQ8F5DHnJ6#kSh z|L=s6^Qj8-TiN2)OmkgTEju0{16bPf3Gw8PT^xn}ld)Dtw2gdNFHy6%n0#^d5;PK) z1@k~u1XLPqSTx=}o_UN?r*G&=TfD2Kt-?+u5BJH8Eayc0td^boebr}Q-5J8f1QVt?%Im^ijGjHl3OYTuEZu)o-FyoG#gn^Ugov29kn0Hni3dS3e0P&D z>$Y@vz$^=MyZOCyR^ygnpuY=FI>=)R=bMTX!WIuJq&kIR-u9-JbDBn|Z@|~=d9Z-4 z;l?FNN}Y*jV?){xHv8!JCIWZqtYiI|W{*>v34-gz??*=DB0DY&IGv%=DL{0k<+ zjSaE^!t;>1@XPp{4()PIdADj@Xj!m>F(vH2w?mf2&BCE0q|CAiiaqc7${_1-hQIXg z$TV#K3!b-AJ*2RdaCzwy;MV*DZ@}v!Qq{ix6f8Pv;=$-}MFX>S`N(DD0J@}zpo=bq zzaP(&Q!}L@Ur59+{tZFQPPtq)otjUfWAlF z?UmE2x;Dn*pQSKd*AFlEc@cfuf>wJkQ_p?$+8jU;(~`{S?E+PJVINoz-p`NigO)*@ zsuO|Wiqe5t;Uvq4PNE=5n^=*7L5W@%63$#Wc4FqROPHUJJQh)Lo)8$2Q+Jz=$hSpd zslI14ed^QXWD4s2YQnB75(^IeCYdfpXqL-M^hF{WA8q|_Z2L^RnTmk2aw`~sr5_mq zLE^r-uI{Id_pMVW9qSMywVTz%$Pv=WNrWo*7%TIsAid(Ai9aVPpw^Hj8~CCfL)c>v zrA}O~GRw5nelWXIV?G%YQsx79a|Q(?UM;dBGfe@7WB$ zFm3H)WUv#&Pf)i4HY+jj_68r@QMW^dY+|>{}US}04M~D zH6>qzFSfKGOhNNIvEdHLR)?B8!I~0@cL+Yj&+92%4UAB?zO*ju7%9DE_>8$8>$*)$V+{#l%2GOQT2 zA;QtuvYO_O;#y1F`QNd1Jpc4e5+veT9HY+jQ+O>k+MpL0lpl7J!vYT9(2cDZ=5li) zbfrx_Y`k_{Pok?JtWK4Exiu4p0fh%TqP|Zw9R&j3TUKy58iS|sSL)By_U+OzYbH>R zIPA@U>#l_7$EW3Tvhv!42n6dEQMKB0t5U!ISuylbrl8@{r(qOND8iH*Ujz5Oj{RuI( zfrPkyt|#ywCGGl6j1kpa>@vhj)cAyDU#uaeb2D=QXpbE#J(6t`W+F<*%?5p_JrvQk z@S|1c-Cz7K^f>&=pqUcyf+9h_@Rx0xhGho zq-$#p4Vxd52tFli8j{KdSIuO36}ZUJO(zgf>zjXLbhUyzj$IY>I>KXevL`S3ArIBl zX|wg*R>eCN+w_=2I+No&>^;Yh=POL~wOf%g)xVo+?+EuUzdKSw`lgM2FrdK&jswx7#X*XWE#FeAVw}xU&sQH zj`(F_kNv~KZbbjqlR%tTT9{3)o>1DMZ*~d~43=bRBfQy4g}&di*X+fb)%Zcery`db zL3NCb333tB>s9=%&FI-U*$kIqlT&iO>=cG3FZ}q&WXHScv(5js@jiRw_nW?$;7t^& zBv}Mwp%u92(q~wDN6^U#{0n-hdKznh?~W^6BU`Q%nV>$dx&TbW=N5rW2-IHctMoDf z#S>*J0B!9S9NsWrs#PK1<_&=3KSG`X7-z_a>tQS9y~n3#|ONzF2K} z@*KYvg{Kly*$Qvo1K1(wo ztV|JG(>r1ynw>!IJ1`Uwd62cPCD583Y` z8^JY&zp}sQlx|I$24h~bktGQUS{r)7+^kodm49%O4%QQNC$dIH_RBAVbaH~8vyI%K z9ZKt2qmPcq1|<`fk~KRH-KRUJcdUZW=gX)dWw)JI$Fr8ey~ z2)1viG;?HU3h&NdudVYDoP$^wna=0gAqPaA!omUTzwF)lm@RvEcH-k|-le3MA_ZBH zF}ALw-{N`Z^y+1gKUJ&Wk-P{8yMNldmfd?wwOL)PSX9EYP7W@2wdsT10FDHrBll#Y zHT8?e2Ql|kPo&VVA$!mBtF246*?GF!QFrMi45qh6Znc-Q`tN zbUpbGn)9JUjh7O{y94yhSSc!yYnEGbX-}w|`!R2+k^1ZAL{FA_yj;@M3UM6diWR zW+91sTP<=sh)LK>Lt~~3NeG5hRIrt2#{1z!5|2lXzH<#3{U3^+6N-{Cc(yw=Xh9T+ zy4KZL94qUpS_MqnG_0>-ENxB5^lIV(peGg0+b7C8LIS2Io!@`SA`Pmok2MvDQhn3O z#G;3~ckksQ{pf)oN+K;n8=`H~05DVEx7PiCjOo!-(L;nYw9g)=+b^qF$+nxZ$gzth z-hlNRh(z8YN&wt23E%B8Z~@6+SO|LjV!f+{Grta30^%S?OcS*r9N$bO@4ZICtD=Jb z{Ixe+^&HYgta#t>7cGl>iph^CK z-9~~3&&QcfTPZ-0XQjshxVDw%fFT2y94hD(ylFq?ba-~HUck?-01}V(50`7E%u9`y zo<{C7fk)lhhfEuvmdr5D>O%m?XC&OnXSsKKB)R9DLIwyd+eX)krW%LqzV2aA%F8e> zx0jzzl1p({0&f2&Z^LH(HlhQ5_NF(j9MLK**5h@oWQ)>!byTBhR@0y5ys`_nA64RH z!P>+M;l1_*sE#I;as%n6OnRvsIqVD18p;oz2$rjVJQ?$`K2~czFJQK(;+XAta+)?G z6QC5Yj`o!UXC`b^Al}l)%8+Qix%ObtgfQ7Ci*y?jClL;*(6@#u_kBHqrRJc$=$m0Y z4%|*L-1T&8%#ho9%1{g89)82QbVnQ*tuoD-+G7k@!~aIPc#VJEZ2B6eFPrB~ja?sq zXS>!G@xxN&pSpfp4MLQMaO-_UgiBTfFxu?!1y$9ks%jcyaE6Mr4B zdDa%Z`GyOUPRsmmDV|wp0+)j%(ZKVh-B=3(5G*tJ8f?svP~F7PO;qQ|!O(#{%VF&8 z-}}Cw-*Znrpcci}{}QbB`A-Lf{+Fd`S-qJXwny|3=V?p!8Lp0u*zy>Cl>FA6yh z_lucCv3L4}9+a>KNu~|D3pljpG;4ov#9_CKJr*W3)8V&c+gC3`WbB7Jgxh2sSwCRY z6Gy5^&T10_p%HgsKQZF6yTyrgQnjDoTXL}SnnrxK^HweY*9_Z#hr!0AtQhc2yhwAy zR0~R^%uSA~b$H4k(ZI+s5Lf| zK<|geC!ZSIViIseTpM&LLiJx*==7#*WGW-FrA^7H;Vp038Lzsl+d_$>;mQhxsYBx0 z3uhNw{9&_cZU1M1E(ks@#!$StRkkR!)Akl1ulM zOU0IXx4=!1;jLt&u1hOKQ263qQC4p^B`ES>cdhdulK;qLMa6T$Ll;~uMNhHc)V|vN z_f*!3l7jyf+*S=(&Vy$PpW3t{>k>ZF?G(CfILh%szIS;Vy!x+ot{3^2z>#xBa&d=C zvnDmG8x`ZK8^Z`?S!UV~{%Y>8q(V#!h`cO<#VdJXy-nP+rxjmw!|?Ha%e5zfoy$1> z#`2;>DZ=L&?kr`hX~9{pQeSRPMJ{!c>XcM(MxTr0PsXa|Lh6$dRlyklI>(O~gRkw} z5geT7AwDwgN;U6p51V#RP~*g-(bn`<4P5HCRF3nTV%e0sibd#z>AL=_IGc9l$m*JW z5H2hmjA7$}J zQ^}jwI|X-651(E)S7&A^5tk^lkjY&iRJ|DBVmOHde6S4eSPfIVrSZ;5}RlNSG?-V zYOY;-7TZx&OTNLg9N@d6*`2B<3xfb4u3zCAz8#h+J9dY*u(-JYFZ)cPXfy>;_A;qi zbdq%oJs$)8k2F(i^DDKQ*}TKy10^cy-w5Jg)mFjBq3`~!Z8pn$Mz7xS7;07L(;8G! z1Xt_rvp60$On4V^zCy0EcWZ$6v8bzR4frxYH}He&hTW61wN1|IZBGUEmfTum#w5I% z)NRG%TbK5)aGbmO`%MkNF_KhlBFO2aPFrd6ynn=WK+oiZ>lQgqf^Qvpf5I{rxcypT zKZ$PFT`NE6z~4c!zHax0)T3|6i%E7mxyTDs7)6#!RJve}a44iL<|PUxS=#teG(vDX z3G$i=BlohiB^!d2vFG~pZ-i}`$i@x-nZf@vEz19s7I}I5|BDvohaoBy{)-kFoVBQR zwC=;}15|HB4hQqI^bBp)Z;CWZXN8LsVS6Hq>jc@gp+N?@zDqBx&#;2GHn@)0_d13VyVTi6&~(;dI~TYw;|J7G%!DoW(M+~ZOs znM)K751$^CQ?QF4md`cuT)I#aVEdWVw^Cc%EpF;7Q||wc$|)hJZg$FPuwb|%Mv-$D zlU_o3x~52_1Rh#<914S{njWt^s%;Dy=e%bU{z?}eBmW9ILwOb~*2tM_%tc_H>uwr$ z1!8J-q>;XwuAQjP#CA*&6bxVDJWl4Fu8-BI`H@>Wm#u^zl`BX}&{2v@IFm0&R_P3! zjG&p&{}ei=5O{j!nk(lm4%W%*Xjr;5$1&pfO**53>b}D0NF!h@?n)szerC%F@~+3v zm8b+0i{c?qc7@!2=aMkE zEMDf?`rzg+F^Adu*5##?0#jPHQ9Xs5qIh3S16pa$rhCbs4Zo`UG%avW9_5%SUF^z8 zZ|~1$I&KF5_L0LJ-2YrZ_QMt}M82K#uWEKR|3euQ23KG`d5>G$B5)4 zW;=Ccp=GRzw6VVow|<>u{iT2tK4!}BsYsjIbPMpDeayE!6IF>7WyQ!I_6*_`-Wn5b zVz_v~FT>OSBza44yF5ZyFfV(<2{g(5uFSL*p2oYAKg7$HzQ=#kXZ+=f-4X#siGgM{ z{X3yM3T!8_&n+&6{MXWPZN-x`vH53M#4hi+U0h>8c^8~vH1#{qpO@xlFu9dJ9L>&; zUz|+@?0KJ{j<%fql&Lu_-WWJP_{kcuqdAjM7LA)09G%s|#No^&dN#eNlwIowL7Khz zT)PTjQ050}tv;utS+$cvI!g+Yd@@X}gIOzNhkxlO9#F0 z22&=ozc4H+?ZhweN`|cL^10i%U|6m=kms|v=Q=O@1#g(;bPB&kh^X3Cy3gR8$(7M% zx3~`jQfXdu@|b|6Q+oafyJG`pCUR{UG8?b5%8oTJ1v4AF`m&j=>8Wz*$!ghL?zl#u zUs{yEb{VTRAXOI#9u}k*lwI=NO5jWEgAqV`gq`D2LI=$PCTH84`Y_t#Sb z`>pV>xYsH2hWXdXE=z7 zL20P+kMYOCiVwg!vRqW>Oq;W=8RhgVP?>jDhyOo~t`BKyjnCbL*#AY|d51NXZ~fjG zN7PYd6qOD-2LY%5_I(qvxd5fnSRAqie9E9h2 z(1EXBP3I?cN)z2GpEV%YHp&Ta=iDj^_R~I_%fH!+|KMgk@KzF|4d{%%Vz*b^Wr6Py zk!dwi&PRArg+FXzJnK9h!LyA$=i^T9H)Nj13lGM+wspuR9 zmE(oyUJ9i2`P2CtgwF&@`Ep4}#lO&n%I7N(2tXI=hT6Zf@IiKk6lrV^w204%M2ETE zZf?Uq#Ew+ePUO)^q@DiaCGHpN{mox<4IcN@#UL)i;hTbI<3yVw)oxO`GE9tb6a2&u zVlgQ;kVFj)lB^m>)mUa4-@$D4y?ZcKM(EI2Z&r5wA=n2Iu$`4*qa}S*;6rw=-LL==?a02Oo z0v+iS*QO(xsvmI*c=sMk&61jSx`S|`0`8(JQas*#WX|&v^=!Ge$TJ6P_EryDDYa0} zl=$rj@}4$vzi{7{#VnHMK9AWbsdyrzWRd%j_CP)NrgR9Lr8d{3$0D-?T4#qv9n0F5 zU-{W~`byc@a&E}MPgEiYgv3w^_3kQI;0Zr=Le^-}{nyY=Rgn!3^z^6CmF$V$vXh!_ z)hOTpNbM?2%JS9||1ENN0HY&lA3f*z%g2UG&|@}PfIaM=$w6QC{g0J{!Zaf^-}L4o z%j{g%56rr)s6ofUI!6Pqg9?CeU%Mw%Kc*dWWc@~aZzl8n#)bL9ed>u(gl6;v?xLbc zt$=W<1JA3k+}S}u9cHFCWREjdFDBar`&;_7W5t6kQpcYze2AG!RHyES`1o19%JDQI z1cTF z2;+AR21Ush`!=Iaj1hr32eEAbV9S;h5ym2Y`{=X#i4mpa2Rdcy+&bOE#aWqXsgw1D zVa#N6_4a)g=xhyH_v3T4t=m`ZVJ)+zaR>Y+QR|^;N>i~vx^z20 zt&Fz^JrxSC^4GcZkPAGOVoriBhc9lcl(@o}pQAq|CWxM!{8gj9Tux(B z4w(F0%RJ6k2?4Lw2lAG3{1edj7IaiNq(zc92*E1%=@bop73(aRI(6?$rMDZQZCn}B z^Xyi$z(u#or?e-43Uo-Rh!(N?k10Uc79?IibZSh9XZgWY zmhoS}_VCSGBOB&RCw{o?JpO#cyX~5c{0ZEr(w`C}(onbM0eguKRii&x<6!8!@N74G zZzE(>O6v5Lf%+>;M(#X$FLOpHFFuQTm21dPJ8e=d|5@eu|GyH@J1tf4 z?9MPKc5&i4-z}uHr2mnUtjp2-D3em|PUaz!EVi>6#Fpkin=Jn@&V4>Gpu#CsvYVqX z)}fbWuQA%hl`K>LQG}9y*%fZP^v0u&!)>B_8lruWgxwTIKKIqauM|ACrU_?S+;nf^osfugG+emGO2Z%N;MD2>Uu(cmptYAV6lh@N|E_y z37T241ev|uQd_Xapv(;14@$@@>$i|%%3FW_edsEO`5f?}tNB&n%(G&fS7=y*XTGnu-D5{A>h9DO1*T7t4`|##+QsJ&GCXaD%{DekcBpjHTirlqzvwy-FlSSun<~T#0E%{D$H=) z=8>9cC8NIBM077kq|AGV5dt-e2p=N7m1NVZxwQ|nyZvUW7e`&AN47jxB)M)ikyb;! z(5M^VvgMzffS9NLFPMNV{@YAIpZ-%OAe)cJQn9Mkgk&m-yyXRZda=tpvwu!FsqX<` z0kZ19P8-pL(wLWLBVbroHQ2j5g_UZIx9&rNaHG;WnlCZa7zxK;TIZ%+3R=c{;e`$m zho1A5VeKa3Iu64d{e!_N_3~ERn^cQ>*lZ4pZot6>rYsR#>M}E)}P?R?!_pv_PoWZXEpej z^ohJeOP64z_QqDzSLOM-$t zF#S-SS4}^9qLBE~kEm>I62%DAFfn;zzqh`hW1ol5IMUQ@^gi9A`K z)03d1_Qm04MHA3G)Xnzy%$`@UMZJ8g;Gn1TnD;t%le~s+@7R<)em<5ME-ArP+7})7 zR+BoJ$jT&&-%6IQ1I#SFx)>X&h7AXMcBUy`pC}=4L|ny8BzkVR1W8wsdUfAVcb2!f zjNHp6S)xB_2I2FJkeI8qi{JBkmN8Dm73vm0S7@~A`MAmr5REtp1=)=}v&f|`qz?o* z+ZUw}yi_H*k?zFUhEiGrs1H^}vB`xGm_?S2d5d)*Lo4{!uEWHxxU8Ae<`SQFMeH5|D?spa_#Sq)l?91biBbdGb`qfJ?zB#K#?il8B=?Z;dyy>eZqr z^x36inO>;9JIiEs^C#XDdUK%XI;SiSPrd}OUS7{8s%p~DCKyAI(kEeF zr5p=g^*bj zpJ@EhcWyxs)ihk&r6P~Mh-9|@)G+Z?x3ds}odHCU%#D+7j#A#bWe_UuA=$3gb0y1* zh5}h$ClIzw_Yf|#h4QSry@_K(wz@ZUxk6vX9Cm>isA)b{SG=0#wRvgv$fGE_+}6S; zDRl3EY8fc7>>evaoyJMz-p>gVr8s0oOkn%|Jf(tIQsaC03eZOW53RD1@Vu&w<1U%>nNmlFbtpNew5D_r`5V8u z9x5x-)HS!AOdX6jH9v&!vrt8mQZK_wKbwC>HH9D};wP)>()rGwolB4VBS%Y!fQ%C> z5S1GA>=|{i%aKwTDQ?n1_qP+j0y!CPwy~iPQYLTy;<2Te$RNaywm-2gcGBbq>p>$L zu~_eS@#Gt|y;w9k6B;x2@iQ^K)vk#$66!nAdwHq8pur+LGu(`EKaf=`qcUmqgnj=I zoc1GQgK=C}Y*2IkW9V^&e2j}gE-rJME*v64wCSQ)&JPEZr1UWJ6 zYE_DxA^v2DRz)>=nD{NcC6%XXBCT^sl`A2-Ptz)iiew73PZrNYH!{6ED=y)w%Rj3>LSW;LK;X==c zfjDDM*N@NUP9Mo7oywN*QPqt7dL>ViuHa27qBGHs3l7k?(DYSS7O2Q28~;py;wN1P z=jjQJFje20#zpm1^1LRm)We>BpH|c%HCI&b8Ye2?5&i9^&WnOesm{N;;Q~W6M(ZxAW=qwnT+*ke&BL znhu}`%5Tdrow0{JU9xc|@*9N)kBzON1F4OL!-?qKF}>FqOt%ZL#u}Su;|9GdAw!wH zE{f%85ZU22<;7uy*dz79S>YpYMib1XDSEr&+`i5@FZ01)uY@U#xYl=2ZgmM)7-_Q= z<_51tJE3>N46J6OMnVVPx)lF{L0&C)-ODuP>7g9@H$^~o=o>c{+XtouJ&I{`uWe&n z*ANc`KH3g#$J;oft;jod=SuH_E{n`Rfd0YJv-khd(SsTMZ+G+<{(Fv|@J9FF7V6r9 zwa=x7#&Wj3lBuJ(kj_312-&WHHP8i1?9x^9z5ga__T^CUYZ&RP+p5p_mys)|n!AiC zEl950PwxgCViPu=n3G7r&O;WICabQ&MO{uObd3ea#0_j|F`InVc21*-v;1&6GZQoR ztB2>tJs^#Dcgc2XBDX`Gvp89FV#$$^gb5V;4Cjtw{(aoW8^Wumo+*aiNj2?MKB4L^ z-q(LVaM^G0W9)O|4ZZ{Me`V?k8vh5To}X7uJ(Lww&pC@P-%LH6>Qz(E=)=D-^;{Et zlQZR%*Ed~W^rzVrV(4ZjI%=#E~PpEB8x@5=(v1_Wo)4V zI*zo|9N!|%Dqnu_H@UL^qRFS~2V6ph*3y*mBI&|P9IQt78FewOjv=$V*ohP3)m&0bLFTW`|A`5_`bMoqcO3^`i)p|UK z^+iy2VN2~yNx;V|Q=E3j*j6A1|gsSUETku^JkBRHZ zyi+1kE+P@%jXd5Twm6V9k-Xl|Q>U}ieL2hf(ojY3=3_J^J`cFV%juF-pSM6hiVvmN zy`R&pIax_Xjn49B5w5)KP+}6<32Ji12S}wu`h8=^T-KQM-ZD1vIVwqbvZf@b;}h*1 z&HsT<)gxw*_i~pxNLgIO7A7o{%=O20%ve|?;&EJdpB^qj+_*is zSis)91Nfh1bTPNbaS1URW5zN)6-}g+YK5*Kbt$sPM6I`yuWVlN}n+n$9JKB1wfB3atJe<@JO0$m|RB&2@~9S^VVt zMe$W^ybn4UQm$26z3>z^JHUs=-%32J&Q@DU^fp5#n3gbibc-x4q8@w4U|MV2J7z)K zE5p5>%<^c@iCLV|8HlV>IMqY4lx?uENk^|CZLD3=TUM#=K~DCUf8^qk@HzqK=A+=l z3aL4MU+9p6u*Iofi9Sw}fE8og@}s#8nuUjXHks7AAXj)M_2kfbr!h`aoEhgvRoKkK zk!%5jt0(70KLO9CM!@~yIU=xs8%oQ|zI@rG9fXJwBkA-`vqv6=riRKe609=%)=wHH zqhDuzIf+lx$aKF)j*>`ozdJuU9;S48`#e;o*h~9xd;6E;b-9u^^Foq?0;kF@C?yn2 zwxeg>4x2ld26;zq39mMb1NGG)FEI}KI8-F$L97OL$d&I9$}*Upk21xR{Xw3-e9(P$ zSFMCoJ~>Gn#8clka`(ganTW_y=Hs!#;Q`IP(zLBIGD1Iq%Fko5P4h0A>gJlL%us~o zG_QOXw!4{eV&Y2ZO?2?wS2BDDLCm*CXwt3cz3r?&%0qv=m8cz!srR zW);j%#f@;Qu0MR!Q=UB->~vw4+EN#7ajkQH`X;9AR0Q2O-CN|c&krt3r5@DFd$1!>^4y#>Jyq$Ni;`&msp2a#4e0B9K3k%rXE5B|^$d2YEz@b`tpwuiU6MCV$)2t$^q1>R9`sq|hH+DdrosbvnkN zK?hcET&|rI!GGrhrj(tIX$`E?sIWyYPiZeK(y4`n&B--K7jR4o<(oN!wMZDX9?*aa zlOknD@)6P}<4@}|lqMWCZ!WZ!28#|(A#6MH+`X$!?O+t{?s<=AkKbp!SB8GsTYVuQ z!rc)?=;Oln%;hg{h9Y2vaA((`&c44_c22D*I|CwVzg`1rEs0`fyTwsfMbGj($E0I= zx*KwR_qq+VVQ$t2$b+VA#r$r5Mmhssq2sdkuvv)nIJlX9F4TW8D8dO;PGenFVmq*| zK~JFtVNl2u-$SFNIrV!Z7{6fSMJBne>~mA8Lt`!z7GZI_p{I{7@3?fwno7{Dj`5dt zSc5_v;KY9_5V8&M&sQ31X?nOX!vFqe<)YfD}(YRCk8H?I`)G zzD(u)vX|#-Vj#?T1t38ea)B#Gk!Am&%N8k+J)tKpQH#=T?5u;FBX^8sk$Pq=QfGS; zfGkgEe!+-AQE3C|-nNK~u%;QG2@VkEndUChONkfkn!*de*{Wm(AZ2keyo2%5Dzx@C zZs`L*DO>At$J^&2R%yo#Ao-WGcDc zVyl&4$O~~N-^-J|A2qCI-mSBU~S~iW?748oYVfed=%#Ti8EUV=k(H&cJ6MiCDVjv7CZZ@LwgkejnIUA)dIa zLZZa)bF^~)gdnY8@FM!)vam(P$hmHs(x7)&)y9OpSxt@_-QGs(^rVGsIa#5Z{=EHZ zk!>*A>@d<6z`*zQzmwvyeQ$tuAMLG~oJP<>YUGZV;jimd2Sz)I2WeC)>l)aVw%0q2 z%&F*HU9^y2xS9ZNU&zM3(4LuaYk%W2~SrcjD-sKil(~by!+2 zs-RO4PJgW#fvb#K9_EHIHR-uwB8qr)@CC&1++&=b&Z&M=hAk&MPd_!X@cmppu@SIK zk&z>sexBY`$&n-Wj*J!uRz36|+=?NTt!dul5=Uz8)SQp8Lj*oi16gd+lOG+-^om!xjW!3dJ-o=glUI`t^v~Bbu#{Y z`RKm!99}&t2!F{=CV(JtQd5?9e$3=KHk^~%&KZaz3?S!^P9frbVh>pf0(PHr))nsV zZ=NX=Wk92KV2?D+)ii|!yzx7VOK{yPtkX{$PCC9BX$`qq2PCxbC~sG|12H3w#0XcMNKBkhww$Ux z{fK=TbX^#6N<4!QqO`B7;z7hwz`7(7!J+SI`k1n8jL*HQS{|i!rYCszz9cm9wRWx9 z-UV7qP#=Jk8X#HCjBEHsoDcgl>DaX0`VP0ZwNyjizBjfeC_8qtlcddCfn#Y)lcNHD zuU?;n6|;2nTNam@(_#gvlQL8c+_C|`Tl37;Z7ylQa)z>Q0mVguvw&VZFxj@TWn4H$JJ{n1<9U| zo6o;BVA5&4@L}ofKZGI$G7B|5XcPh~>tu=i8*=MtHD%hP@8jXLypLahRD2e=hwLIY zR1fj;Ee7v*0FrYXipeqH+WowgMdz@&$hmy?4$f=4?)3w0h~C?MEkh<4O2`XZ#7NGN zT|qb9)Yb;`6Mz}17ksvjHw`y%K5_t-{Js|)cXcSraGw#iaf0aE-qIGaV#*P}vTDka z^Jy1hfgulax<3|1 zkWKTSj)^ohx^Gt&#ueK(7K`WBkF;iHdP2j?u!T_Y>31 zQZJ2tud)t~kqT8nn9NHeavISf19EMQRLSSMh%NB?;--D2Fz=lwv%n6d2#Hw-TFi+4 zv_=+8Yg?>D(_K)rp2@}&>g3a9QFE%iOoEcXwNbEdGP(MQj-uUiacB9dqKhN7BT9Z- z8boaT#Y}j{105Mor&FpAfGvqyP<2UtMu(EOmduh9>1=$J9fe`vg&jfEau0Z%(#=AP zFK~;F#(_M{>m*;XtwRx_4BRA*Qt@XF(hp0EMb?7iJ`6|;rAfroYJBN3`tfsSVsOg~ z=4*W*LMJv;bLFdBO7;c&G7n1-^NBXv$i03V=rwiN+}BE_;`L#=MoLu3rM!!PSxUnq z=*m~G6#vc~e{*d>4xqAd2djzO*I`p#yrnk{z`d-7IKw5?BG9s%Z+4*&6{@y~XwU__ znha}D_o>b=swg)a@lH6TWPKP`;?psfYRi_l`USLUnp2;pq&{V~)IDjssLe`2cM=sm{ zn=Qwqf7O;VbMh~2IT-?E`T_L|(>sj07w)9bzSEANtlzDH^m`XTTZ1C$7pYsx#9`kd z`}TwVleh}mvDoKD*^;vxLXl8FD*CsRS&Qe221278b@zmS33K6+uFfpKhZ?fN+a!9* zkXh5*A!LdtJ=edREkfDhv^)r^Le}U;`9F^P&Ru#VIS7>Iz$1nfTW2`(3>CGHr%^%c z%X)v6Z#)^9=Jp=33qDF;s;Xqy*`l8GiC1j%PA~pW8+sd)l1emD+mN#vmq7Xu)EB2b zYJSJNj>|pgyLhDQoaNfz{Pbu|CU^m?Cu3ZwLwwGlOffq@U>}Nn^>Zw!UjN^E} zW{z|l>sl&^ur~=fbeamMtE?P4QAlGcF79%(l_E>I#X-H{%}NSEFW7y^X=*gQ;aWv* z7P~tk&xL18RNK#sf%GX?drf+#y~l=LOo9FCNiHGn(HE@UtXn6_eV_wsPh)Je7r+>Mr(;{o-1JCp3vk$p{XWJ%>@eIU1H=59m!72d7FepK4aowJ zPpViaQO#vf1Q}3|K7mJ@`KZFg)%;V)uRy93S^m_#(=OhQ^N7f&W2w`z6G&fj#I`y4 z+w`f+i)*Vk-*Ni@fBZqrE{1!ILz*aZ@KD8%;NK;EXy%?lu6NEinT5kpK5g$hSIqI) z?a}1-;tnVY^#*|n4@k_gyib5Kh2b9GDjGO(D6WuQbYnKXz3A&awa_IFC$;uowe$gB zwpq{fIEOSKuMbmL$?GHed3^&}Ag}NK%>$fj&31RbnAm`^jFhRed)>~cZxgLpc|#5j z)jw9T`xBkKRo>!pQW5bJNXMntx!{+&;b@!-aSe?1DWM3TVn=<8b2wpG@8Is5&+0V5 z{F-F|rXFeLKj-YF>1=S>&#R*gpmpV>>bW@-2*%nOL7hd8|Q?LfO!)e+K`Q?@oQ^e3AYlOvJD3O7n6YEVc{to@8rbHD+1}9rk^K2f9}( z+yRLgIja?L5_Y?6{JeOX&n<5HferxFR(6WKsUe|@-jN4u8&%jFn^5CV)2hdifd!YO zPOAD@hov}Ja&YTNQCq%+&n74$dj+BN{l#4QbCKH%6uE|oNn7(6sD!p9z?j+o&TN(m}dHDzX_ z1Zsoht3*CY_{wQ>P)uGLTl$sUXk05?ynI|mE)a@->_v8pEstW1|U*vHr1 zlqb&KDc83yv=!gdX0+i1Ztxupj`xNwbJ#JgNF%U%!*?3s=u_aVdr8(l+d(D41b4bA z5N3?(%lU*U!iYM+3mo$FS?3wItfn+~olAuU;T&9taf|zPf?`ll${u^r(e0ylr`5I7 zn@GL~FnV^GPB$Pr?f|ePXY+i0Y{GF|(7^UNm zcSzh(ZP*+Cf}9z^bc?rwWi$djhAr-ZDvH59W$7^NJ$8p>RQ#|*<{BgN|AOHF^C=$3(ZEUW5Xf&PdNwoV?g)Knd= z5%^W1Q-=7ugYvX|=TahJ1AuA(V9IF#s^No;fy`dm73*sN5db6V=jkN!NIi}GQ>DfC zCqT63DjUFQotG7$2CaXwC}@n1UwasNF}@bSmK!aP$w@=W`vicPn-<{L7>S*H+Mc(hT^U{WJfn=KXC42*eu7k1WkT`LeF2 zk8P=;@a2XP_vCLr{-!I98S&1Th0~Jlkg#(A!lycjk64WxvxFJl`tY))mj2gj$g zQEl?oCTiMc;I>*hcY*9OKkO6OdXTgcw59EIK#s#EADHomId=^7x^Cn5pYf@;>kQST z?*Q%|{U<{T%q{rH^2q1b-jm}e=OdO9ZRNpwSQq#(TLMq-wHo^6+tpd8`sW%b`VVWM zAoxFB15xU&nr7Ebj3qC2YCky-JISwxlMY$*Twst~GlZ7o+;9L6!u^#?Z& z*Z5{>9}OJw+mXwTS+v8Hcg1b>f_av1F$_>0ku9lWT@pt4gcK0Yf?u#-k36LS9-;hc zE*+JLVv1|@crEtwe}U3gR2 z=@)2f2zsfeUo?7FD&?x{XAW<4h2(^T-5!u=GH0=~Rq+0>4{qF@uLn z`6r_F1g2@>K;cZtQL!_xU;BW))dKuqNqUVI0WhspwV;=Rq=ZQ#^t}W!R=%0x9ALBN zH;OhQ?FN^3wSng@qIauuclro(vgdMF%Je0ZkcCUQdt_4wb=uJp;r?db|->M9t)7+GMd7#b(>ntpa#wNssQ{G-sfOYl3)jnBWiu*$#b zdt@U1i;=2hmQnky&uW4~eEyAK;wt|}mbA*hQ33cjCxgE6Z}zFhGsHK4iNU`D{h*^q zXy#?diScf&504xXf|-u6&O)CoH?{qKNp<9UtL@FknJ`);=@{m?LUxp_qxx?X4)dF0 zbmj1`$eV(y#`7{B9(!{ryZ}d_tmmgjmY?$^Xl3nxEzek ztm0>YjnXu}_Pbp-{0oZjzuFK6xRAJfKJ6mugtYNkv1UhT((uyF7jrUJC~@Amg3HxB zhI-o{I=6?u_8N}}tsrNm*O0TArz^$!m|+yA}wEUf=?>lw7tdM+>aPR%8kTh^4DBu0uSKJ=}dvvL$z`}M|E-^dv| z(KmqOni{SDCcHekkfHUJsJeQR<`@pTh@WMcx?J7<4VIC_{07T_LD9g+Vgt${lvZP2 zu@9CL9-5#ku27O|Rhpk~-Nv%}JOGPor5E5yu`kA??VQhkNe-5rA<{m&4s@L4PK(s} znSmaAx~x5eZo|8P<$WEdE{G&sqXh2IU0=Vy>^GIiMzkDqo^t<0f8=*~<$7b%?P$j* z>I5fm5dX`gFq6+47A;t)aR3;>6}0o`2?V0G>%I92-W$xN>$)t{127LCCC@{2=?~b5 zE$wX<8Qzik4Wr?!~ZJ7snarrJV!OIezOI!PkVysKlRBF^4%nOk5%d_{|9T+zt`QiVp^)&t=%b+pF zW?o=8!y=W9G|SWSH(dXkkI4fc=VjBU>bVA~;<2{hdDSzumSAo6p^~DM8ZRS<-KW3L zTC9OtOG?AqtmR;Q$RWp+B7)Q8*LN;Yv*g;8;3I}2nOndI@&osS_m%*`=#%_lH10X2 zlx>QRdE)z0y@h`UZ>n1Bvwtu7`@-3l0c1jc(BwE2#WJK+ zs@s46hVA<1E>nEB>MnaHXJ4-F_bKQ83>2;r zDI&N`Psf{VA2(dzbi|BA={jj;Sg3 z!X+s)J2Yk0d=OA4O32L;d|b+bm}w_W#ZC#QADP5tnXwbpPv;jR8b||3UQx^au7pIF z%u9y$5NDqO?!9H6ZvW@GVK7w9Jj$p=Wa}sZIvzk8-q3}=g zLz#u7(3#^pvH{+lYOa9-^J59PGe9gM zdM%b<3Xw8znMJ_p{rVGT_4OvhhM6VII!|4>R)q!;$Ni3i7ac{TZ4(V94^`{=wv);? zr++-AdAoNAm$22!u1tG(&t)BjA;Wme{&``p=J(He{))bNYtntY@HgH;pI-U#-On8Q z2fy(Sa2NQz1B>Sz_sNK9lgMnfia^86m9vXW(9KZ1;8bAB^?)mOa|cPZOz3Q9M$mbO zxF_O`j_DtVTO!Tj{H9boAXuL)*yfbXXE9Fi>0NuIyd$hQxc9F}o`$jyFaAfHQagB1 zMOiyWWB#KlLHl@mjAxc(-r+%?Gn^WiOR&briUW^u7+>@&ofAysnpB|hS_9_pp}oJj zTp0N|?#S9F=C&mu4YHcTF}v0sOI8=Bt1W*`h}$_%N;h6x`{=u`Dgo=-fJPGvB(&>t zs%6Fn4ZMS%u6Tjh9xul>0>dWBWOmHb>$k7`^(QxI9f~xhJm5Qgm(O~Anf6rE`u4kj zH~-;3C5RHj_=0E%Pxb(y1lTV>yp|AHHqm`l-?W1DM$bLYYLbMuC9l24b?ac?_0WRC z&q-{p9mY);xL@^9=DeA*g%9SI$>E8)WMzih{W5~snPu^c8?FNqYuDq(Dqrqv>6OeB zO-KW_WgV<)3wlV!c6Ch=X<(h4D(#l7YyTo4B0xjiLFPRv?eo+-W}~@%X3a^#ga<(P zD*%aM{K!7PP@$#u+HIK@+t{fVRZ76e0ph}y?}-0MrEiG;zn4R`+i}={;5&rL0>=%Eeh8RS3z`1Bzl}(SOUK+t$|V3I#jnmD8+Su=ullY zwNi47pw1zpfu|7klF0Up6qg?+k*$P;b25&%u+W3q!q*;BMfjn>UmA-$>>826oh z3D&Pi<~Zf@o*Nyllz8d8fl_Nt?gOMr4e5-8*M`db!Jqq)u9!@;`M+#Vr>@?-i_;3x zcGU6Enuk3hKas5KOcY=!#Tzl|OM_p#Z=;>5>wK-O{wRkUbBhWAbU$|XLPr}0m89?y z3$tH2>C@EKb*8ZHT@<#`dFZrLtUpy2aA7$*|6v@|rMMF9O5)<=`!tiycDnE}J@Y>U zn4U-B0H#NM&2mC;hZLj%7{lzE&|Yd2&VAk-(|oNnp<)z?lCB&c(32^=CXOMS(%-ow zS^+~)seCD#oYnLXRZ+e6^Ck|~Mv2G_?H(gX0S9h@o(m#Ei{3E>{}mcA3l4 z?f7$7+B564Xy^B;{HIybsM^Y8uRfXD(kfv0MorBdgAzRg@YuHWALmq+ykRp9Ku zRxf}>LjQ_9`nN}SKm%a~QA2ZPg~GHF7#_$;A*PycQtr&d2dk@VX2{!VLQZdlAUx1q z;z?E3_p@M^1ply(6XqY*J3DVT0}+Q*`!tl(>R}zZlqs5Brk8scJGu&YuStlb)zEXoIiM#RQ4+Nc^UC2{ZCTeQ{uia8KB5(0Lg zC=+p7EuYM0C&#V7aV_gA*GsVAX`9OoPu55*)Ubh>Q!C*CuydnF6siDbPkN5)*Mq8y z9uPp~MsUo%(l`2?{w!z&fnI&GkVOi!sR!rMe0P>?h>V1RfMqT7G?uZ^m~;T++UPT^ zgohEWwMAH;n{_ zOCx<=y19Mcy6Q$Jm+osgybYka4x1bTj`3A-q%uPJV^WP@^diBlBu=>-vklKzNlW*8 zFf+poxA~QM0cBH{9|B){*#4f%S{a;gWBe53a2pt$XMOmQ-;fqfG;s$*4?q+FmYGoO zS$_Ogm8>QHsV@)J``?MB*F4h;&9Q#MDC=cF@`a>&f?)xuVQMSDxcuDXSJHv$nohi}+ zmPjH@LOK-+shqX`*vVy4o6& z-L+rh*4-GrkhRnP=D#`ZbDi2gcHmSsjfeP8u95V|W*vGr3M}0bJ*9+KlwXO`%gckm znuw>q6nJy}Hx85El6Fj_W{0nIq?*@&G5i8-f8M0%+E)&cBKc$Z)t+>v(hA&Pt+ZxG z3@a=7m9{c%wbE7-R$_%2^>yVfjlF*`q5WN{6%P9AQY#RJb6Co1`0yO`<(@VLJkV0( zf8FlM$r@}*Xc+1m?Rcx$O#(2OB3*cHrh8XdVeLg`*R96e2tTe%Ytie&5uu$^TEEox zjGs4Tte^k++u!?r+iE^-U2F|d_a0L1=621kFW=j~%>Mo3O^G+({@S|1uhXeKncdf> z$Ig`ho@$fJE^0)(EWU$o9;;x41g`z++lv3DEC8%%)1ZKib(g3;MxZYVDg* z1tu!lP9BW1+iaOE2zGf+MyOcpkL2gLiX|D#=1D}%v^h3UBvl65)b*~9-QDtUf`33a zzZ9Kd1)pR1SWOoHxbE$p#;MYuThvNV<&BFQ8uv!HNj?I=&gHl996hIo2C43Y&LswE z8X19$RmcKtT!w7-zIau;6>S#RQ8;RR4~N~7iI2(KZe=vX^l4BoduW~(WFT`9^wO%} zqt1UQ4pzof=Ji|jqV>W?o%*>EZoq?Zmzs9YhcdM-@I4{lgI^-lmEc$6on(pjzr;6t zeaAQd@1?#d9&qGk+dk@K%#5~3oUq2NruNQpKhg5$$8M2bd!LT_1u$ zm-K2;kbyQ_sWSkz<|ny~i?3hO0c+>|6AH6Zc?@FxQ`gdREv+Pp*PE7v_ID4iWVxgo zRfFcc0bAS-r5JN5AEhrpJv`gG^c$oN-9tJ8@yP%puu{|11ONEf&%*q(2mV{iv@YhT@;#Gw7Z~kpr>=TJZ>HL#x5ID?^l1+^5|_K}$muo7TN`2x*!MZpiQYX$$59 zKSZmpPDE6mcaV=xF?0C+Y+0r!v&Ju6nz?9F z2|zacVfA`_uq{KUT{o)5qoAsxs{Ca;^Uv^khsxVf2k;_O5bkSq_(Bm`t3Z=}-7@Dj zmK#GBl!fka3YGN%__-Z9_ywnjQg#`MK-9siDY^f=6qhUv7!mc;=7L$_Eu{wY4@Mq^ zg_8mi452ArE;%C4(}y#mp+i4V{+4ANL-EXV@wub~=MtU=+-foHSaRqr4slzmXQ=4U z*SqX1Gq+Bg6tA4Wya?P7Cl2ZGxrUo*hlC z+-P@6+iX`*O6()SlD3`aFgb0%m{X@L3Kry(PC1Om2%g;nXPilGZ7~-Pz;sNtxGzJ> zPS?$(>WxCN6jcw)cAm?+N|DH8Ym>Iv^_ zG(xBKNI+kkC&K+r0lu?CVmcU?*Devft`MuMh5E#DHgy@$DRhA%Kd} z=Oa67#Zj8V>aroAsfr$+;cMwTnN@=((xJ=pP#>v-pwsv_W&Ks!(s8q!3XR?ziR#)B z0@?>>1t+#QA|cWyw?o+J=-m~$7N(@fZC7-=adz`u)8@I~mL!(WtZdbh2>`b0Xjdn{ zEcs4PMf<#=O#1V)cY5CI73r1|&?~E3b>4aysz@NSTBkp{T_(-)s<97RN#PpuNK1;a zhoAq*a~I205&klz_ES33z+RoGB0Zf%Kb!&wTzB91ah$M|E;XYg7Gi7)W8jcXn9v~q zrWX5I#?@N^`>dPFKHVY@%_q$Q3@cVAcgW=BEW0)0i3U80dc6~@-1f$aHQ41_UckTX zL*8}3(lU{CpZtK4H~oeaHL^leRNQtm&ttgjjFV}9AjB{!PD9=U(r^M4XaJlA@4Ybb zF`3Raq-f#YUo{vukbkGcjb(7e;$=NR(^u+?^zT!(bRRf<+3I((xgXmTwHtHZrrq(5 zG+r|7I}?IIX>5+R;9*@w^A`s-GYA4^>7-oNK=Q0H%uFANyd?cglxWs$ciVKV%3lC+ zYz;$8X5S%q>)ORL#j`RK&m=f~q|cq27gsA$p$7weX?f+WX!NRigs(P#43>2)d%O_q+n$YAZF1-3migP z*hwochW4Y5BC=@tK9V-RJD~~Y2sa)#TZ1M8^N?D6D@>GB{ff9PBiFu-oWHWYJRWg1e6r z)jz(Y7wXrdV3$(!3IezvxAIBSjp?mFW-VA)*^=erfUC^}UvzSN=D+S;7NXP8SZ}T~ z=B5&OX=f1_VqS`pjYl?okQeAMCLCGjFRY@cgS8Djl|aG`<{ml$q=I9H^n46sJN&AD znWdF|mEXvTXQ09EFT3@;ec6FGm^Y!R6&KSG1I!OC4qf|15gz46lgEkg4PZbkP@o|Z zf-lq2Ntva#CCB0xHDVz6sM)f(0~$6Xg7Mq&!Y8k9)XYA~t1Kme<_&LsFrD$D&Unqc zg(r&YfQo{q+1cSQKw$$@eJ>u6usqD<%l4;5Hs@zwGV3!o^K0TH(_?v*^C(nt*xeu$;0l^ZXFA+r-?=pZjxyWx!J4O z@t8CbYs>8995K2y8+K;y{8-SiBEDN^KZQy^e4&kXiFR+}h&aH3h%A|zM@Dv>l zVk2FO6sys!iQHNXJ2;kVD*0tFE(3h7$S7;~8Q}UGN%Fza?_uU=faO%{*Np~r0(e@z z>yC}BBYGlmcP3>DJk&gE-p|DOKyFoh5$ClTJ1P}7XiJ}DTtnZvW!a4}5>up9owpiJ zN#&0;{ah>`&q$VT7Nk7Yz143y-dttM$lOpWvB~YTtE-=_($LSn0#8=C=h44a#PLO}f`(K0^Wl|5I_sT^Of z%Q&t7sJkKP8kh0x`D3^|sMxg@E;G{xz@?3br@fAtZ2R$7p;9Yb_o}%$Cv-nIPN+jN zhkU9*xgt{NxI8=9uivdl#EjHjez9ar8n)buIc#%3;rFMUkNjKr7+_D?oF4hjgZW}X zB^+jbI(wwpbJ%9d*HCd2w16K+uJ(YzslfSbQu7!97;J1)*U(Vy7~C+A`IuF%mtcSof!WQQfzK z0lExL_78L!D#oU4ueCq5;8ZzXm!o2H(QuDpa~R)F9r(Iu#;^T^WDmT4ADAW6Js9WSIMpdBx1 z3RIX_BC<*umokK@T-3qNPT4-}OW*bnJnuQrd(L~_-}C$~I+P|l0Wsh8o1-bX{L-0> z!$NavOJ;sD>v`!2!fa37`fz?5Tf=zBIU*@IcqtIe!5YGA_Rd=rVMvA*#$G_vQR@5`H5ExCqB{D1d^6S`ScjOYI zfza)2$z(~bX0WzuxN<`@;n#^2iHD9KmY<^>Ta#P5H;h%L_AuY{%WFR*yy&)vm|_lr zA8ILHW6(7^u+`?Bpo`KV2AKj<`5=cHOLb&pk()Ew=WA{Htm&hLqFG#EM> zC|FnaxA-VEYf1qGbUfWH*4Z1Z?xj;lW>B(anH&GuH2HiC?0+NbT2B?ULEpGhY^-VB6i0Wft$_HdtI6{3Rheh-58v@6EMpw{5|+|n)K4#k z@$Psx((fssjRXLL5=qW{E;-$}3)}K`9t43Dr>M4(H4TRR#qCc!KFCI)?T03QNK8q> zxW=)SKB9<`_HVUG12a`E*B=jWciY=Z<&`@rv@I}T7e#3jLy~X+D*fH`2q0gf4;2mT z6KN+JuEG91#Y)>-0t?~Y-W5S87S-ULKb83<@e4-P7;vGqk`PX)k&^TH%?J05N|C!AL@S^Ht8O16-r&A(M7H9*OutzmVZ;g%Gzt0V_NDIPxL;_~5)0aK7s0?zcHQfbrz1&#!dk9%LOLw8bzC)F12S&4Od z0C!0xl77y!fG^q8tV^7;Z`bQDnG6 zdrHbcAW?@ejh?Psw0Y^UEn%a+@UF9j7hsN3yC~6D(igt=x?hZ|yC_l(X;>s10sv0i|CUoS+*Q#5%niYWjd^AGTyTc`wf Vhj(_{xmp|lePGzWTYHZE{1?T2|A_zq literal 0 HcmV?d00001 diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/index.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/index.rst index 3ebb36b25..bf439861a 100644 --- a/docs/source/recipes/Non-streaming-ASR/librispeech/index.rst +++ b/docs/source/recipes/Non-streaming-ASR/librispeech/index.rst @@ -9,3 +9,4 @@ LibriSpeech pruned_transducer_stateless zipformer_mmi zipformer_ctc_blankskip + distillation diff --git a/egs/librispeech/ASR/distillation_with_hubert.sh b/egs/librispeech/ASR/distillation_with_hubert.sh index a38cf590c..6aaa0333b 100755 --- a/egs/librispeech/ASR/distillation_with_hubert.sh +++ b/egs/librispeech/ASR/distillation_with_hubert.sh @@ -150,7 +150,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then num_codebooks=8 mkdir -p $exp_dir/vq - codebook_dir=$exp_dir/vq/${teacher_model_id}_layer${embedding_layer}_cb${num_codebooks} + codebook_dir=$exp_dir/vq/${teacher_model_id} mkdir -p codebook_dir codebook_download_dir=$exp_dir/download_codebook if [ -d $codebook_download_dir ]; then @@ -180,9 +180,9 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then ./pruned_transducer_stateless6/extract_codebook_index.py \ --full-libri $full_libri \ --exp-dir $exp_dir \ - --embedding-layer 36 \ + --embedding-layer $embedding_layer \ --num-utts 1000 \ - --num-codebooks 8 \ + --num-codebooks $num_codebooks \ --max-duration 100 \ --teacher-model-id $teacher_model_id \ --use-extracted-codebook $use_extracted_codebook From 958dbb3a1d02ecced9ff62624625892afb2206c3 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 11 Jan 2023 20:29:36 +0800 Subject: [PATCH 080/174] add doc for int8 quantization with sherpa-ncnn (#832) * add doc for int8 quantization with sherpa-ncnn * typo fixes --- ...te-int-8-scale-table-for-conv-emformer.txt | 104 ++++++ docs/source/model-export/export-ncnn.rst | 307 +++++++++++++++++- 2 files changed, 397 insertions(+), 14 deletions(-) create mode 100644 docs/source/model-export/code/generate-int-8-scale-table-for-conv-emformer.txt diff --git a/docs/source/model-export/code/generate-int-8-scale-table-for-conv-emformer.txt b/docs/source/model-export/code/generate-int-8-scale-table-for-conv-emformer.txt new file mode 100644 index 000000000..347e7e51a --- /dev/null +++ b/docs/source/model-export/code/generate-int-8-scale-table-for-conv-emformer.txt @@ -0,0 +1,104 @@ +Don't Use GPU. has_gpu: 0, config.use_vulkan_compute: 1 +num encoder conv layers: 88 +num joiner conv layers: 3 +num files: 3 +Processing ../test_wavs/1089-134686-0001.wav +Processing ../test_wavs/1221-135766-0001.wav +Processing ../test_wavs/1221-135766-0002.wav +Processing ../test_wavs/1089-134686-0001.wav +Processing ../test_wavs/1221-135766-0001.wav +Processing ../test_wavs/1221-135766-0002.wav +----------encoder---------- +conv_87 : max = 15.942385 threshold = 15.938493 scale = 7.968131 +conv_88 : max = 35.442448 threshold = 15.549335 scale = 8.167552 +conv_89 : max = 23.228289 threshold = 8.001738 scale = 15.871552 +linear_90 : max = 3.976146 threshold = 1.101789 scale = 115.267128 +linear_91 : max = 6.962030 threshold = 5.162033 scale = 24.602713 +linear_92 : max = 12.323041 threshold = 3.853959 scale = 32.953129 +linear_94 : max = 6.905416 threshold = 4.648006 scale = 27.323545 +linear_93 : max = 6.905416 threshold = 5.474093 scale = 23.200188 +linear_95 : max = 1.888012 threshold = 1.403563 scale = 90.483986 +linear_96 : max = 6.856741 threshold = 5.398679 scale = 23.524273 +linear_97 : max = 9.635942 threshold = 2.613655 scale = 48.590950 +linear_98 : max = 6.460340 threshold = 5.670146 scale = 22.398010 +linear_99 : max = 9.532276 threshold = 2.585537 scale = 49.119396 +linear_101 : max = 6.585871 threshold = 5.719224 scale = 22.205809 +linear_100 : max = 6.585871 threshold = 5.751382 scale = 22.081648 +linear_102 : max = 1.593344 threshold = 1.450581 scale = 87.551147 +linear_103 : max = 6.592681 threshold = 5.705824 scale = 22.257959 +linear_104 : max = 8.752957 threshold = 1.980955 scale = 64.110489 +linear_105 : max = 6.696240 threshold = 5.877193 scale = 21.608953 +linear_106 : max = 9.059659 threshold = 2.643138 scale = 48.048950 +linear_108 : max = 6.975461 threshold = 4.589567 scale = 27.671457 +linear_107 : max = 6.975461 threshold = 6.190381 scale = 20.515701 +linear_109 : max = 3.710759 threshold = 2.305635 scale = 55.082436 +linear_110 : max = 7.531228 threshold = 5.731162 scale = 22.159557 +linear_111 : max = 10.528083 threshold = 2.259322 scale = 56.211544 +linear_112 : max = 8.148807 threshold = 5.500842 scale = 23.087374 +linear_113 : max = 8.592566 threshold = 1.948851 scale = 65.166611 +linear_115 : max = 8.437109 threshold = 5.608947 scale = 22.642395 +linear_114 : max = 8.437109 threshold = 6.193942 scale = 20.503904 +linear_116 : max = 3.966980 threshold = 3.200896 scale = 39.676392 +linear_117 : max = 9.451303 threshold = 6.061664 scale = 20.951344 +linear_118 : max = 12.077262 threshold = 3.965800 scale = 32.023804 +linear_119 : max = 9.671615 threshold = 4.847613 scale = 26.198460 +linear_120 : max = 8.625638 threshold = 3.131427 scale = 40.556595 +linear_122 : max = 10.274080 threshold = 4.888716 scale = 25.978189 +linear_121 : max = 10.274080 threshold = 5.420480 scale = 23.429659 +linear_123 : max = 4.826197 threshold = 3.599617 scale = 35.281532 +linear_124 : max = 11.396383 threshold = 7.325849 scale = 17.335875 +linear_125 : max = 9.337198 threshold = 3.941410 scale = 32.221970 +linear_126 : max = 9.699965 threshold = 4.842878 scale = 26.224073 +linear_127 : max = 8.775370 threshold = 3.884215 scale = 32.696438 +linear_129 : max = 9.872276 threshold = 4.837319 scale = 26.254213 +linear_128 : max = 9.872276 threshold = 7.180057 scale = 17.687883 +linear_130 : max = 4.150427 threshold = 3.454298 scale = 36.765789 +linear_131 : max = 11.112692 threshold = 7.924847 scale = 16.025545 +linear_132 : max = 11.852893 threshold = 3.116593 scale = 40.749626 +linear_133 : max = 11.517084 threshold = 5.024665 scale = 25.275314 +linear_134 : max = 10.683807 threshold = 3.878618 scale = 32.743618 +linear_136 : max = 12.421055 threshold = 6.322729 scale = 20.086264 +linear_135 : max = 12.421055 threshold = 5.309880 scale = 23.917679 +linear_137 : max = 4.827781 threshold = 3.744595 scale = 33.915554 +linear_138 : max = 14.422395 threshold = 7.742882 scale = 16.402161 +linear_139 : max = 8.527538 threshold = 3.866123 scale = 32.849449 +linear_140 : max = 12.128619 threshold = 4.657793 scale = 27.266134 +linear_141 : max = 9.839593 threshold = 3.845993 scale = 33.021378 +linear_143 : max = 12.442304 threshold = 7.099039 scale = 17.889746 +linear_142 : max = 12.442304 threshold = 5.325038 scale = 23.849592 +linear_144 : max = 5.929444 threshold = 5.618206 scale = 22.605080 +linear_145 : max = 13.382126 threshold = 9.321095 scale = 13.625010 +linear_146 : max = 9.894987 threshold = 3.867645 scale = 32.836517 +linear_147 : max = 10.915313 threshold = 4.906028 scale = 25.886522 +linear_148 : max = 9.614287 threshold = 3.908151 scale = 32.496181 +linear_150 : max = 11.724932 threshold = 4.485588 scale = 28.312899 +linear_149 : max = 11.724932 threshold = 5.161146 scale = 24.606939 +linear_151 : max = 7.164453 threshold = 5.847355 scale = 21.719223 +linear_152 : max = 13.086471 threshold = 5.984121 scale = 21.222834 +linear_153 : max = 11.099524 threshold = 3.991601 scale = 31.816805 +linear_154 : max = 10.054585 threshold = 4.489706 scale = 28.286930 +linear_155 : max = 12.389185 threshold = 3.100321 scale = 40.963501 +linear_157 : max = 9.982999 threshold = 5.154796 scale = 24.637253 +linear_156 : max = 9.982999 threshold = 8.537706 scale = 14.875190 +linear_158 : max = 8.420287 threshold = 6.502287 scale = 19.531588 +linear_159 : max = 25.014746 threshold = 9.423280 scale = 13.477261 +linear_160 : max = 45.633553 threshold = 5.715335 scale = 22.220921 +linear_161 : max = 20.371849 threshold = 5.117830 scale = 24.815203 +linear_162 : max = 12.492933 threshold = 3.126283 scale = 40.623318 +linear_164 : max = 20.697504 threshold = 4.825712 scale = 26.317358 +linear_163 : max = 20.697504 threshold = 5.078367 scale = 25.008038 +linear_165 : max = 9.023975 threshold = 6.836278 scale = 18.577358 +linear_166 : max = 34.860619 threshold = 7.259792 scale = 17.493614 +linear_167 : max = 30.380934 threshold = 5.496160 scale = 23.107042 +linear_168 : max = 20.691216 threshold = 4.733317 scale = 26.831076 +linear_169 : max = 9.723948 threshold = 3.952728 scale = 32.129707 +linear_171 : max = 21.034811 threshold = 5.366547 scale = 23.665123 +linear_170 : max = 21.034811 threshold = 5.356277 scale = 23.710501 +linear_172 : max = 10.556884 threshold = 5.729481 scale = 22.166058 +linear_173 : max = 20.033039 threshold = 10.207264 scale = 12.442120 +linear_174 : max = 11.597379 threshold = 2.658676 scale = 47.768131 +----------joiner---------- +linear_2 : max = 19.293503 threshold = 14.305265 scale = 8.877850 +linear_1 : max = 10.812222 threshold = 8.766452 scale = 14.487047 +linear_3 : max = 0.999999 threshold = 0.999755 scale = 127.031174 +ncnn int8 calibration table create success, best wish for your int8 inference has a low accuracy loss...\(^0^)/...233... diff --git a/docs/source/model-export/export-ncnn.rst b/docs/source/model-export/export-ncnn.rst index 11471d611..ed0264089 100644 --- a/docs/source/model-export/export-ncnn.rst +++ b/docs/source/model-export/export-ncnn.rst @@ -204,7 +204,7 @@ Next, we use the following code to export our model: .. literalinclude:: ./code/export-conv-emformer-transducer-for-ncnn-output.txt - The log shows the model has ``75490012`` number of parameters, i.e., ``~75 M``. + The log shows the model has ``75490012`` parameters, i.e., ``~75 M``. .. code-block:: @@ -213,7 +213,7 @@ Next, we use the following code to export our model: -rw-r--r-- 1 kuangfangjun root 289M Jan 11 12:05 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/pretrained-epoch-30-avg-10-averaged.pt You can see that the file size of the pre-trained model is ``289 MB``, which - is roughly ``4 x 75 M``. + is roughly ``75490012*4/1024/1024 = 287.97 MB``. After running ``conv_emformer_transducer_stateless2/export-for-ncnn.py``, we will get the following files: @@ -286,8 +286,8 @@ We compare the file sizes of the models below before and after converting via `` | joiner_jit_trace-pnnx.ncnn.bin | 1.5 MB | +----------------------------------+------------+ -You can see that the file size of the models after converting is about one half -of the models before converting: +You can see that the file sizes of the models after conversion are about one half +of the models before conversion: - encoder: 283 MB vs 142 MB - decoder: 1010 KB vs 503 KB @@ -338,6 +338,8 @@ The output is given below: Congratulations! You have successfully exported a model from PyTorch to `ncnn`_! +.. _conv-emformer-modify-the-exported-encoder-for-sherpa-ncnn: + 5. Modify the exported encoder for sherpa-ncnn ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -356,14 +358,15 @@ Let us have a look at the first few lines of ``encoder_jit_trace-pnnx.ncnn.param 1. ``7767517``, it is a magic number and should not be changed. 2. ``1060 1342``, the first number ``1060`` specifies the number of layers - in this file, while ``1342`` specifies the number intermediate outputs of - this file + in this file, while ``1342`` specifies the number of intermediate outputs + of this file 3. ``Input in0 0 1 in0``, ``Input`` is the layer type of this layer; ``in0`` is the layer name of this layer; ``0`` means this layer has no input; - ``1`` means this layer has one output. ``in0`` is the output name of + ``1`` means this layer has one output; ``in0`` is the output name of this layer. -We need to add 1 extra line and the result looks like below: +We need to add 1 extra line and also increment the number of layers. +The result looks like below: .. code-block:: bash @@ -376,13 +379,13 @@ We need to add 1 extra line and the result looks like below: 1. ``7767517``, it is still the same 2. ``1061 1342``, we have added an extra layer, so we need to update ``1060`` to ``1061``. - We don't need to change ``1342`` since the newly added layer has no inputs and outputs. + We don't need to change ``1342`` since the newly added layer has no inputs or outputs. 3. ``SherpaMetaData sherpa_meta_data1 0 0 0=1 1=12 2=32 3=31 4=8 5=32 6=8 7=512`` This line is newly added. Its explanation is given below: - ``SherpaMetaData`` is the type of this layer. Must be ``SherpaMetaData``. - ``sherpa_meta_data1`` is the name of this layer. Must be ``sherpa_meta_data1``. - - ``0 0`` means this layer has no inputs and output. Must be ``0 0`` + - ``0 0`` means this layer has no inputs or output. Must be ``0 0`` - ``0=1``, 0 is the key and 1 is the value. MUST be ``0=1`` - ``1=12``, 1 is the key and 12 is the value of the parameter ``--num-encoder-layers`` that you provided when running @@ -483,10 +486,286 @@ disable ``fp16`` when using ``pnnx``: .. note:: - We add ``fp16=0`` when exporting the encoder and joiner. ``ncnn`` does not + We add ``fp16=0`` when exporting the encoder and joiner. `ncnn`_ does not support quantizing the decoder model yet. We will update this documentation - once ``ncnn`` supports it. (Maybe in this year, 2023). + once `ncnn`_ supports it. (Maybe in this year, 2023). -TODO(fangjun): Finish it. +It will generate the following files -Have fun with `sherpa-ncnn`_! +.. code-block:: bash + + ls -lh icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/*_jit_trace-pnnx.ncnn.{param,bin} + + -rw-r--r-- 1 kuangfangjun root 503K Jan 11 15:56 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.ncnn.bin + -rw-r--r-- 1 kuangfangjun root 437 Jan 11 15:56 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.ncnn.param + -rw-r--r-- 1 kuangfangjun root 283M Jan 11 15:56 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.ncnn.bin + -rw-r--r-- 1 kuangfangjun root 79K Jan 11 15:56 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.ncnn.param + -rw-r--r-- 1 kuangfangjun root 3.0M Jan 11 15:56 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.ncnn.bin + -rw-r--r-- 1 kuangfangjun root 488 Jan 11 15:56 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.ncnn.param + +Let us compare again the file sizes: + ++----------------------------------------+------------+ +| File name | File size | ++----------------------------------------+------------+ +| encoder_jit_trace-pnnx.pt | 283 MB | ++----------------------------------------+------------+ +| decoder_jit_trace-pnnx.pt | 1010 KB | ++----------------------------------------+------------+ +| joiner_jit_trace-pnnx.pt | 3.0 MB | ++----------------------------------------+------------+ +| encoder_jit_trace-pnnx.ncnn.bin (fp16) | 142 MB | ++----------------------------------------+------------+ +| decoder_jit_trace-pnnx.ncnn.bin (fp16) | 503 KB | ++----------------------------------------+------------+ +| joiner_jit_trace-pnnx.ncnn.bin (fp16) | 1.5 MB | ++----------------------------------------+------------+ +| encoder_jit_trace-pnnx.ncnn.bin (fp32) | 283 MB | ++----------------------------------------+------------+ +| joiner_jit_trace-pnnx.ncnn.bin (fp32) | 3.0 MB | ++----------------------------------------+------------+ + +You can see that the file sizes are doubled when we disable ``fp16``. + +.. note:: + + You can again use ``streaming-ncnn-decode.py`` to test the exported models. + +Next, follow :ref:`conv-emformer-modify-the-exported-encoder-for-sherpa-ncnn` +to modify ``encoder_jit_trace-pnnx.ncnn.param``. + +Change + +.. code-block:: bash + + 7767517 + 1060 1342 + Input in0 0 1 in0 + +to + +.. code-block:: bash + + 7767517 + 1061 1342 + SherpaMetaData sherpa_meta_data1 0 0 0=1 1=12 2=32 3=31 4=8 5=32 6=8 7=512 + Input in0 0 1 in0 + +.. caution:: + + Please follow :ref:`conv-emformer-modify-the-exported-encoder-for-sherpa-ncnn` + to change the values for ``SherpaMetaData`` if your model uses a different setting. + + +Next, let us compile `sherpa-ncnn`_ since we will quantize our models within +`sherpa-ncnn`_. + +.. code-block:: bash + + # We will download sherpa-ncnn to $HOME/open-source/ + # You can change it to anywhere you like. + cd $HOME + mkdir -p open-source + + cd open-source + git clone https://github.com/k2-fsa/sherpa-ncnn + cd sherpa-ncnn + mkdir build + cd build + cmake .. + make -j 4 + + ./bin/generate-int8-scale-table + + export PATH=$HOME/open-source/sherpa-ncnn/build/bin:$PATH + +The output of the above commands are: + +.. code-block:: bash + + (py38) kuangfangjun:build$ generate-int8-scale-table + Please provide 10 arg. Currently given: 1 + Usage: + generate-int8-scale-table encoder.param encoder.bin decoder.param decoder.bin joiner.param joiner.bin encoder-scale-table.txt joiner-scale-table.txt wave_filenames.txt + + Each line in wave_filenames.txt is a path to some 16k Hz mono wave file. + +We need to create a file ``wave_filenames.txt``, in which we need to put +some calibration wave files. For testing purpose, we put the ``test_wavs`` +from the pre-trained model repository ``_ + +.. code-block:: bash + + cd egs/librispeech/ASR + cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/ + + cat < wave_filenames.txt + ../test_wavs/1089-134686-0001.wav + ../test_wavs/1221-135766-0001.wav + ../test_wavs/1221-135766-0002.wav + EOF + +Now we can calculate the scales needed for quantization with the calibration data: + +.. code-block:: bash + + cd egs/librispeech/ASR + cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/ + + generate-int8-scale-table \ + ./encoder_jit_trace-pnnx.ncnn.param \ + ./encoder_jit_trace-pnnx.ncnn.bin \ + ./decoder_jit_trace-pnnx.ncnn.param \ + ./decoder_jit_trace-pnnx.ncnn.bin \ + ./joiner_jit_trace-pnnx.ncnn.param \ + ./joiner_jit_trace-pnnx.ncnn.bin \ + ./encoder-scale-table.txt \ + ./joiner-scale-table.txt \ + ./wave_filenames.txt + +The output logs are in the following: + +.. literalinclude:: ./code/generate-int-8-scale-table-for-conv-emformer.txt + +It generates the following two files: + +.. code-block:: bash + + $ ls -lh encoder-scale-table.txt joiner-scale-table.txt + -rw-r--r-- 1 kuangfangjun root 955K Jan 11 17:28 encoder-scale-table.txt + -rw-r--r-- 1 kuangfangjun root 18K Jan 11 17:28 joiner-scale-table.txt + +.. caution:: + + Definitely, you need more calibration data to compute the scale table. + +Finally, let us use the scale table to quantize our models into ``int8``. + +.. code-block:: bash + + ncnn2int8 + + usage: ncnn2int8 [inparam] [inbin] [outparam] [outbin] [calibration table] + +First, we quantize the encoder model: + +.. code-block:: bash + + cd egs/librispeech/ASR + cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/ + + ncnn2int8 \ + ./encoder_jit_trace-pnnx.ncnn.param \ + ./encoder_jit_trace-pnnx.ncnn.bin \ + ./encoder_jit_trace-pnnx.ncnn.int8.param \ + ./encoder_jit_trace-pnnx.ncnn.int8.bin \ + ./encoder-scale-table.txt + +Next, we quantize the joiner model: + +.. code-block:: bash + + ncnn2int8 \ + ./joiner_jit_trace-pnnx.ncnn.param \ + ./joiner_jit_trace-pnnx.ncnn.bin \ + ./joiner_jit_trace-pnnx.ncnn.int8.param \ + ./joiner_jit_trace-pnnx.ncnn.int8.bin \ + ./joiner-scale-table.txt + +The above two commands generate the following 4 files: + +.. code-block:: bash + + -rw-r--r-- 1 kuangfangjun root 99M Jan 11 17:34 encoder_jit_trace-pnnx.ncnn.int8.bin + -rw-r--r-- 1 kuangfangjun root 78K Jan 11 17:34 encoder_jit_trace-pnnx.ncnn.int8.param + -rw-r--r-- 1 kuangfangjun root 774K Jan 11 17:35 joiner_jit_trace-pnnx.ncnn.int8.bin + -rw-r--r-- 1 kuangfangjun root 496 Jan 11 17:35 joiner_jit_trace-pnnx.ncnn.int8.param + +Congratulations! You have successfully quantized your model from ``float32`` to ``int8``. + +.. caution:: + + ``ncnn.int8.param`` and ``ncnn.int8.bin`` must be used in pairs. + + You can replace ``ncnn.param`` and ``ncnn.bin`` with ``ncnn.int8.param`` + and ``ncnn.int8.bin`` in `sherpa-ncnn`_ if you like. + + For instance, to use only the ``int8`` encoder in ``sherpa-ncnn``, you can + replace the following invocation: + + .. code-block:: + + cd egs/librispeech/ASR + cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/ + + sherpa-ncnn \ + ../data/lang_bpe_500/tokens.txt \ + ./encoder_jit_trace-pnnx.ncnn.param \ + ./encoder_jit_trace-pnnx.ncnn.bin \ + ./decoder_jit_trace-pnnx.ncnn.param \ + ./decoder_jit_trace-pnnx.ncnn.bin \ + ./joiner_jit_trace-pnnx.ncnn.param \ + ./joiner_jit_trace-pnnx.ncnn.bin \ + ../test_wavs/1089-134686-0001.wav + + with + + .. code-block:: + + cd egs/librispeech/ASR + cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/ + + sherpa-ncnn \ + ../data/lang_bpe_500/tokens.txt \ + ./encoder_jit_trace-pnnx.ncnn.int8.param \ + ./encoder_jit_trace-pnnx.ncnn.int8.bin \ + ./decoder_jit_trace-pnnx.ncnn.param \ + ./decoder_jit_trace-pnnx.ncnn.bin \ + ./joiner_jit_trace-pnnx.ncnn.param \ + ./joiner_jit_trace-pnnx.ncnn.bin \ + ../test_wavs/1089-134686-0001.wav + + +The following table compares again the file sizes: + + ++----------------------------------------+------------+ +| File name | File size | ++----------------------------------------+------------+ +| encoder_jit_trace-pnnx.pt | 283 MB | ++----------------------------------------+------------+ +| decoder_jit_trace-pnnx.pt | 1010 KB | ++----------------------------------------+------------+ +| joiner_jit_trace-pnnx.pt | 3.0 MB | ++----------------------------------------+------------+ +| encoder_jit_trace-pnnx.ncnn.bin (fp16) | 142 MB | ++----------------------------------------+------------+ +| decoder_jit_trace-pnnx.ncnn.bin (fp16) | 503 KB | ++----------------------------------------+------------+ +| joiner_jit_trace-pnnx.ncnn.bin (fp16) | 1.5 MB | ++----------------------------------------+------------+ +| encoder_jit_trace-pnnx.ncnn.bin (fp32) | 283 MB | ++----------------------------------------+------------+ +| joiner_jit_trace-pnnx.ncnn.bin (fp32) | 3.0 MB | ++----------------------------------------+------------+ +| encoder_jit_trace-pnnx.ncnn.int8.bin | 99 MB | ++----------------------------------------+------------+ +| joiner_jit_trace-pnnx.ncnn.int8.bin | 774 KB | ++----------------------------------------+------------+ + +You can see that the file sizes of the model after ``int8`` quantization +are much smaller. + +.. hint:: + + Currently, only linear layers and convolutional layers are quantized + with ``int8``, so you don't see an exact ``4x`` reduction in file sizes. + +.. note:: + + You need to test the recognition accuracy after ``int8`` quantization. + +You can find the speed comparison at ``_. + + +That's it! Have fun with `sherpa-ncnn`_! From 5c8e9628cc39b9fb1e471d53df9aec06b2602b97 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 13 Jan 2023 15:21:29 +0800 Subject: [PATCH 081/174] update faq for libpython3.10.so not found (#838) --- docs/source/conf.py | 3 + docs/source/faqs.rst | 40 ++++++++++++ .../librispeech/distillation.rst | 65 ++++++++++--------- .../pruned_transducer_stateless.rst | 2 + 4 files changed, 79 insertions(+), 31 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 33429f74c..ef9fe1445 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -81,6 +81,9 @@ todo_include_todos = True rst_epilog = """ .. _sherpa-ncnn: https://github.com/k2-fsa/sherpa-ncnn +.. _icefall: https://github.com/k2-fsa/icefall .. _git-lfs: https://git-lfs.com/ .. _ncnn: https://github.com/tencent/ncnn +.. _LibriSpeech: https://www.openslr.org/12 +.. _musan: http://www.openslr.org/17/ """ diff --git a/docs/source/faqs.rst b/docs/source/faqs.rst index c70ded431..72b0302d7 100644 --- a/docs/source/faqs.rst +++ b/docs/source/faqs.rst @@ -65,3 +65,43 @@ The fix is: pip uninstall setuptools pip install setuptools==58.0.4 + +ImportError: libpython3.10.so.1.0: cannot open shared object file: No such file or directory +-------------------------------------------------------------------------------------------- + +If you are using ``conda`` and encounter the following issue: + +.. code-block:: + + Traceback (most recent call last): + File "/k2-dev/yangyifan/anaconda3/envs/icefall/lib/python3.10/site-packages/k2-1.23.3.dev20230112+cuda11.6.torch1.13.1-py3.10-linux-x86_64.egg/k2/__init__.py", line 24, in + from _k2 import DeterminizeWeightPushingType + ImportError: libpython3.10.so.1.0: cannot open shared object file: No such file or directory + + During handling of the above exception, another exception occurred: + + Traceback (most recent call last): + File "/k2-dev/yangyifan/icefall/egs/librispeech/ASR/./pruned_transducer_stateless7_ctc_bs/decode.py", line 104, in + import k2 + File "/k2-dev/yangyifan/anaconda3/envs/icefall/lib/python3.10/site-packages/k2-1.23.3.dev20230112+cuda11.6.torch1.13.1-py3.10-linux-x86_64.egg/k2/__init__.py", line 30, in + raise ImportError( + ImportError: libpython3.10.so.1.0: cannot open shared object file: No such file or directory + Note: If you're using anaconda and importing k2 on MacOS, + you can probably fix this by setting the environment variable: + export DYLD_LIBRARY_PATH=$CONDA_PREFIX/lib/python3.10/site-packages:$DYLD_LIBRARY_PATH + +Please first try to find where ``libpython3.10.so.1.0`` locates. + +For instance, + +.. code-block:: bash + + cd $CONDA_PREFIX/lib + find . -name "libpython*" + +If you are able to find it inside ``$CODNA_PREFIX/lib``, please set the +following environment variable: + +.. code-block:: bash + + export LD_LIBRARY_PATH=$CONDA_PREFIX/lib:$LD_LIBRARY_PATH diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/distillation.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/distillation.rst index aa379c3f8..ea9f350cd 100644 --- a/docs/source/recipes/Non-streaming-ASR/librispeech/distillation.rst +++ b/docs/source/recipes/Non-streaming-ASR/librispeech/distillation.rst @@ -1,16 +1,16 @@ Distillation with HuBERT ======================== -This totorial shows you how to perform knowledge distillation in ``icefall`` -with the `LibriSpeech `_ dataset. The distillation method -used here is called "Multi Vector Quantization Knowledge Distillation" (MVQ-KD). +This tutorial shows you how to perform knowledge distillation in `icefall`_ +with the `LibriSpeech`_ dataset. The distillation method +used here is called "Multi Vector Quantization Knowledge Distillation" (MVQ-KD). Please have a look at our paper `Predicting Multi-Codebook Vector Quantization Indexes for Knowledge Distillation `_ for more details about MVQ-KD. .. note:: This tutorial is based on recipe - `pruned_transducer_stateless4 `_. + `pruned_transducer_stateless4 `_. Currently, we only implement MVQ-KD in this recipe. However, MVQ-KD is theoretically applicable to all recipes with only minor changes needed. Feel free to try out MVQ-KD in different recipes. If you encounter any problems, please open an issue here `icefall `_. @@ -18,7 +18,7 @@ for more details about MVQ-KD. .. note:: We assume you have read the page :ref:`install icefall` and have setup - the environment for ``icefall``. + the environment for `icefall`_. .. HINT:: @@ -27,13 +27,13 @@ for more details about MVQ-KD. Data preparation ---------------- -We first prepare necessary training data for ``LibriSpeech``. -This is the same as in `Pruned_transducer_statelessX <./pruned_transducer_stateless.rst>`_. +We first prepare necessary training data for `LibriSpeech`_. +This is the same as in :ref:`non_streaming_librispeech_pruned_transducer_stateless`. .. hint:: The data preparation is the same as other recipes on LibriSpeech dataset, - if you have finished this step, you can skip to ``Codebook index preparation`` directly. + if you have finished this step, you can skip to :ref:`codebook_index_preparation` directly. .. code-block:: bash @@ -61,8 +61,8 @@ For example, .. HINT:: - If you have pre-downloaded the `LibriSpeech `_ - dataset and the `musan `_ dataset, say, + If you have pre-downloaded the `LibriSpeech`_ + dataset and the `musan`_ dataset, say, they are saved in ``/tmp/LibriSpeech`` and ``/tmp/musan``, you can modify the ``dl_dir`` variable in ``./prepare.sh`` to point to ``/tmp`` so that ``./prepare.sh`` won't re-download them. @@ -84,24 +84,27 @@ We provide the following YouTube video showing how to run ``./prepare.sh``. .. youtube:: ofEIoJL-mGM +.. _codebook_index_preparation: + Codebook index preparation -------------------------- Here, we prepare necessary data for MVQ-KD. This requires the generation of codebook indexes (please read our `paper `_. -if you are interested in details). In this tutorial, we use the pre-computed -codebook indexes for convenience. The only thing you need to do is to -run ``./distillation_with_hubert.sh``. +if you are interested in details). In this tutorial, we use the pre-computed +codebook indexes for convenience. The only thing you need to do is to +run `./distillation_with_hubert.sh `_. .. note:: - There are 5 stages in total, the first and second stage will be automatically skipped - when choosing to downloaded codebook indexes prepared by `icefall`_. - Of course, you can extract and compute the codebook indexes by yourself. This - will require you downloading a HuBERT-XL model and it can take a while for - the extraction of codebook indexes. - -As usual, you can control the stages you want to run by specifying the following + There are 5 stages in total, the first and second stage will be automatically skipped + when choosing to downloaded codebook indexes prepared by `icefall`_. + Of course, you can extract and compute the codebook indexes by yourself. This + will require you downloading a HuBERT-XL model and it can take a while for + the extraction of codebook indexes. + + +As usual, you can control the stages you want to run by specifying the following two options: - ``--stage`` @@ -115,7 +118,7 @@ For example, $ ./distillation_with_hubert.sh --stage 0 --stop-stage 0 # run only stage 0 $ ./distillation_with_hubert.sh --stage 2 --stop-stage 4 # run from stage 2 to stage 5 -Here are a few options in ``./distillation_with_hubert.sh`` +Here are a few options in `./distillation_with_hubert.sh `_ you need to know before you proceed. - ``--full_libri`` If True, use full 960h data. Otherwise only ``train-clean-100`` will be used @@ -126,14 +129,14 @@ Since we are using the pre-computed codebook indexes, we set ``use_extracted_codebook=True``. If you want to do full `LibriSpeech`_ experiments, please set ``full_libri=True``. -The following command downloads the pre-computed codebook indexes -and prepares MVQ-augmented training manifests. +The following command downloads the pre-computed codebook indexes +and prepares MVQ-augmented training manifests. .. code-block:: bash $ ./distillation_with_hubert.sh --stage 2 --stop-stage 2 # run only stage 2 -Please see the +Please see the following screenshot for the output of an example execution. .. figure:: ./images/distillation_codebook.png @@ -146,12 +149,12 @@ following screenshot for the output of an example execution. .. hint:: The codebook indexes we prepared for you in this tutorial - are extracted from the 36-th layer of a fine-tuned HuBERT-XL model + are extracted from the 36-th layer of a fine-tuned HuBERT-XL model with 8 codebooks. If you want to try other configurations, please - set ``use_extracted_codebook=False`` and set ``embedding_layer`` and + set ``use_extracted_codebook=False`` and set ``embedding_layer`` and ``num_codebooks`` by yourself. -Now, you should see the following files under the direcory ``./data/vq_fbank_layer36_cb8``. +Now, you should see the following files under the directory ``./data/vq_fbank_layer36_cb8``. .. figure:: ./images/distillation_directory.png :width: 800 @@ -165,7 +168,7 @@ Whola! You are ready to perform knowledge distillation training now! Training -------- -To perform training, please run stage 3 by executing the following command. +To perform training, please run stage 3 by executing the following command. .. code-block:: bash @@ -176,7 +179,7 @@ Here is the code snippet for training: .. code-block:: bash WORLD_SIZE=$(echo ${CUDA_VISIBLE_DEVICES} | awk '{n=split($1, _, ","); print n}') - + ./pruned_transducer_stateless6/train.py \ --manifest-dir ./data/vq_fbank_layer36_cb8 \ --master-port 12359 \ @@ -191,6 +194,7 @@ Here is the code snippet for training: There are a few training arguments in the following training commands that should be paid attention to. + - ``--enable-distillation`` If True, knowledge distillation training is enabled. - ``--codebook-loss-scale`` The scale of the knowledge distillation loss. - ``--manifest-dir`` The path to the MVQ-augmented manifest. @@ -204,7 +208,7 @@ the following command. .. code-block:: bash - export CUDA_VISIBLE_DEVICES=0 + export CUDA_VISIBLE_DEVICES=0 ./pruned_transducer_stateless6/train.py \ --decoding-method "modified_beam_search" \ --epoch 30 \ @@ -217,4 +221,3 @@ You should get similar results as `here `_. - diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/pruned_transducer_stateless.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/pruned_transducer_stateless.rst index 86d43c8fe..42fd3df77 100644 --- a/docs/source/recipes/Non-streaming-ASR/librispeech/pruned_transducer_stateless.rst +++ b/docs/source/recipes/Non-streaming-ASR/librispeech/pruned_transducer_stateless.rst @@ -1,3 +1,5 @@ +.. _non_streaming_librispeech_pruned_transducer_stateless: + Pruned transducer statelessX ============================ From 2a463a420d5080a93ac8933554e13f788a8a59e1 Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Mon, 16 Jan 2023 20:15:35 +0800 Subject: [PATCH 082/174] Filter uneven-sized batch (#843) * add filter_uneven_sized_batch fucntion * set --filter-uneven-sized-batch=True as default --- .../ASR/pruned_transducer_stateless7/train.py | 33 ++++++++++++++++- icefall/utils.py | 36 +++++++++++++++++++ 2 files changed, 68 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index 31a3a0505..a806244ff 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -82,7 +82,13 @@ from icefall.checkpoint import ( from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.hooks import register_inf_check_hooks -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + filter_uneven_sized_batch, + setup_logger, + str2bool, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -368,6 +374,21 @@ def get_parser(): help="Whether to use half precision training.", ) + parser.add_argument( + "--filter-uneven-sized-batch", + type=str2bool, + default=True, + help="""Whether to filter uneven-sized minibatch. + For the uneven-sized batch, the total duration after padding would possibly + cause OOM. Hence, for each batch, which is sorted descendingly by length, + we simply drop the last few shortest samples, so that the retained total frames + (after padding) would not exceed `allowed_max_frames`: + `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`, + where `max_frames = max_duration * 1000 // frame_shift_ms`. + We set allowed_excess_duration_ratio=0.1. + """, + ) + add_model_arguments(parser) return parser @@ -420,6 +441,9 @@ def get_params() -> AttributeDict: """ params = AttributeDict( { + "frame_shift_ms": 10.0, + # only used when params.filter_uneven_sized_batch is True + "allowed_excess_duration_ratio": 0.1, "best_train_loss": float("inf"), "best_valid_loss": float("inf"), "best_train_epoch": -1, @@ -642,6 +666,13 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ + if params.filter_uneven_sized_batch: + max_frames = params.max_duration * 1000 // params.frame_shift_ms + allowed_max_frames = int( + max_frames * (1.0 + params.allowed_excess_duration_ratio) + ) + batch = filter_uneven_sized_batch(batch, allowed_max_frames) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) diff --git a/icefall/utils.py b/icefall/utils.py index 99e51a2a9..ba0b7fe43 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -1395,3 +1395,39 @@ def is_module_available(*modules: str) -> bool: import importlib return all(importlib.util.find_spec(m) is not None for m in modules) + + +def filter_uneven_sized_batch(batch: dict, allowed_max_frames: int): + """For the uneven-sized batch, the total duration after padding would possibly + cause OOM. Hence, for each batch, which is sorted descendingly by length, + we simply drop the last few shortest samples, so that the retained total frames + (after padding) would not exceed the given allow_max_frames. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + allowed_max_frames: + The allowed max number of frames in batch. + """ + features = batch["inputs"] + supervisions = batch["supervisions"] + + N, T, _ = features.size() + assert T == supervisions["num_frames"].max(), (T, supervisions["num_frames"].max()) + keep_num_utt = allowed_max_frames // T + + if keep_num_utt >= N: + return batch + + # Note: we assume the samples in batch is sorted descendingly by length + logging.info( + f"Filtering uneven-sized batch, original batch size is {N}, " + f"retained batch size is {keep_num_utt}." + ) + batch["inputs"] = features[:keep_num_utt] + for k, v in supervisions.items(): + assert len(v) == N, (len(v), N) + batch["supervisions"][k] = v[:keep_num_utt] + + return batch From 0af3e7beda1cb47cba8b51ce71f691e86cae2091 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 16 Jan 2023 20:26:36 +0800 Subject: [PATCH 083/174] fix export for stateless4 (#844) --- egs/librispeech/ASR/pruned_transducer_stateless4/export.py | 2 ++ egs/librispeech/ASR/pruned_transducer_stateless4/lstmp.py | 1 + .../ASR/pruned_transducer_stateless4/scaling_converter.py | 1 + 3 files changed, 4 insertions(+) create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless4/lstmp.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless4/scaling_converter.py diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/export.py b/egs/librispeech/ASR/pruned_transducer_stateless4/export.py index 401b3ef3a..8f33f5b05 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/export.py @@ -50,6 +50,7 @@ from pathlib import Path import sentencepiece as spm import torch +from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( @@ -261,6 +262,7 @@ def main(): model.eval() if params.jit: + convert_scaled_to_non_scaled(model, inplace=True) # We won't use the forward() method of the model in C++, so just ignore # it here. # Otherwise, one of its arguments is a ragged tensor and is not diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/lstmp.py b/egs/librispeech/ASR/pruned_transducer_stateless4/lstmp.py new file mode 120000 index 000000000..9aa06f82f --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/lstmp.py @@ -0,0 +1 @@ +../pruned_transducer_stateless3/lstmp.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless4/scaling_converter.py new file mode 120000 index 000000000..3b667058d --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/scaling_converter.py @@ -0,0 +1 @@ +../pruned_transducer_stateless3/scaling_converter.py \ No newline at end of file From f5ff7a18ebf90c82dd73434b276328fcbe287c13 Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Tue, 17 Jan 2023 11:28:59 +0800 Subject: [PATCH 084/174] Fix the unclear description for streaming model (#849) --- docs/source/recipes/Streaming-ASR/introduction.rst | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/source/recipes/Streaming-ASR/introduction.rst b/docs/source/recipes/Streaming-ASR/introduction.rst index d81156659..e1382e77d 100644 --- a/docs/source/recipes/Streaming-ASR/introduction.rst +++ b/docs/source/recipes/Streaming-ASR/introduction.rst @@ -30,8 +30,9 @@ In icefall, we implement the streaming conformer the way just like what `WeNet < See :doc:`Pruned transducer statelessX ` for more details. .. HINT:: - If you want to adapt a non-streaming conformer model to be streaming, please refer - to `this pull request `_. + If you want to modify a non-streaming conformer recipe to support both streaming and non-streaming, please refer + to `this pull request `_. After adding the code needed by streaming training, + you have to re-train it with the extra arguments metioned in the docs above to get a streaming model. Streaming Emformer From 6b1ab71dc9c715fe08f5ba7dadc6d7c083be904c Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Fri, 27 Jan 2023 21:24:12 +0800 Subject: [PATCH 085/174] hardcode --filter-uneven-sized-batch (#854) --- .../ASR/pruned_transducer_stateless7/train.py | 32 ++++++------------- 1 file changed, 10 insertions(+), 22 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index a806244ff..6022406eb 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -374,21 +374,6 @@ def get_parser(): help="Whether to use half precision training.", ) - parser.add_argument( - "--filter-uneven-sized-batch", - type=str2bool, - default=True, - help="""Whether to filter uneven-sized minibatch. - For the uneven-sized batch, the total duration after padding would possibly - cause OOM. Hence, for each batch, which is sorted descendingly by length, - we simply drop the last few shortest samples, so that the retained total frames - (after padding) would not exceed `allowed_max_frames`: - `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`, - where `max_frames = max_duration * 1000 // frame_shift_ms`. - We set allowed_excess_duration_ratio=0.1. - """, - ) - add_model_arguments(parser) return parser @@ -442,7 +427,6 @@ def get_params() -> AttributeDict: params = AttributeDict( { "frame_shift_ms": 10.0, - # only used when params.filter_uneven_sized_batch is True "allowed_excess_duration_ratio": 0.1, "best_train_loss": float("inf"), "best_valid_loss": float("inf"), @@ -666,12 +650,16 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - if params.filter_uneven_sized_batch: - max_frames = params.max_duration * 1000 // params.frame_shift_ms - allowed_max_frames = int( - max_frames * (1.0 + params.allowed_excess_duration_ratio) - ) - batch = filter_uneven_sized_batch(batch, allowed_max_frames) + # For the uneven-sized batch, the total duration after padding would possibly + # cause OOM. Hence, for each batch, which is sorted descendingly by length, + # we simply drop the last few shortest samples, so that the retained total frames + # (after padding) would not exceed `allowed_max_frames`: + # `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`, + # where `max_frames = max_duration * 1000 // frame_shift_ms`. + # We set allowed_excess_duration_ratio=0.1. + max_frames = params.max_duration * 1000 // params.frame_shift_ms + allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio)) + batch = filter_uneven_sized_batch(batch, allowed_max_frames) device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] From 1ce2bc1ee08a9b31b00c12aeb0912f41ec399d3f Mon Sep 17 00:00:00 2001 From: Teo Wen Shen <36886809+teowenshen@users.noreply.github.com> Date: Sat, 28 Jan 2023 14:47:21 +0900 Subject: [PATCH 086/174] edit comments (#852) --- .../pruned_transducer_stateless7/zipformer.py | 16 +++++++-------- .../zipformer.py | 20 +++++++++---------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index b1717ec64..5cde57812 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -197,13 +197,13 @@ class Zipformer(EncoderInterface): """ In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of randomized feature masks, one per encoder. - On e.g. 15% of frames, these masks will zero out all enocder dims larger than + On e.g. 15% of frames, these masks will zero out all encoder dims larger than some supplied number, e.g. >256, so in effect on those frames we are using - a smaller encoer dim. + a smaller encoder dim. We generate the random masks at this level because we want the 2 masks to 'agree' all the way up the encoder stack. This will mean that the 1st mask will have - mask values repeated self.zipformer_subsampling_factor times. + mask values repeated self.zipformer_downsampling_factors times. Args: x: the embeddings (needed for the shape and dtype and device), of shape @@ -1009,10 +1009,10 @@ class RelPositionMultiheadAttention(nn.Module): # the initial_scale is supposed to take over the "scaling" factor of # head_dim ** -0.5, dividing it between the query and key. in_proj_dim = ( - 2 * attention_dim - + attention_dim // 2 # query, key - + pos_dim * num_heads # value - ) # positional encoding query + 2 * attention_dim # query, key + + attention_dim // 2 # value + + pos_dim * num_heads # positional encoding query + ) self.in_proj = ScaledLinear( embed_dim, in_proj_dim, bias=True, initial_scale=self.head_dim**-0.25 @@ -1509,7 +1509,7 @@ class FeedforwardModule(nn.Module): class ConvolutionModule(nn.Module): """ConvolutionModule in Zipformer model. - Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py Args: channels (int): The number of channels of conv layers. diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py index 88beb38c1..e13629384 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py @@ -421,13 +421,13 @@ class Zipformer(EncoderInterface): """ In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of randomized feature masks, one per encoder. - On e.g. 15% of frames, these masks will zero out all enocder dims larger than + On e.g. 15% of frames, these masks will zero out all encoder dims larger than some supplied number, e.g. >256, so in effect on those frames we are using - a smaller encoer dim. + a smaller encoder dim. We generate the random masks at this level because we want the 2 masks to 'agree' all the way up the encoder stack. This will mean that the 1st mask will have - mask values repeated self.zipformer_subsampling_factor times. + mask values repeated self.zipformer_downsampling_factors times. Args: x: the embeddings (needed for the shape and dtype and device), of shape @@ -1687,8 +1687,8 @@ class RelPositionalEncoding(torch.nn.Module): if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return - # Suppose `i` means to the position of query vecotr and `j` means the - # position of key vector. We use position relative positions when keys + # Suppose `i` means to the position of query vector and `j` means the + # position of key vector. We use positive relative positions when keys # are to the left (i>j) and negative relative positions otherwise (i Date: Sat, 28 Jan 2023 14:43:47 +0800 Subject: [PATCH 087/174] fix expired links (#856) --- egs/aishell/ASR/README.md | 2 +- egs/librispeech/ASR/README.md | 2 +- egs/timit/ASR/README.md | 2 +- egs/yesno/ASR/README.md | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/egs/aishell/ASR/README.md b/egs/aishell/ASR/README.md index 75fc6326e..f4a59e552 100644 --- a/egs/aishell/ASR/README.md +++ b/egs/aishell/ASR/README.md @@ -1,7 +1,7 @@ # Introduction -Please refer to +Please refer to for how to run models in this recipe. diff --git a/egs/librispeech/ASR/README.md b/egs/librispeech/ASR/README.md index 94cb445a8..9ffd78d5b 100644 --- a/egs/librispeech/ASR/README.md +++ b/egs/librispeech/ASR/README.md @@ -1,6 +1,6 @@ # Introduction -Please refer to for how to run models in this recipe. +Please refer to for how to run models in this recipe. [./RESULTS.md](./RESULTS.md) contains the latest results. diff --git a/egs/timit/ASR/README.md b/egs/timit/ASR/README.md index f10bfccfd..d493fc479 100644 --- a/egs/timit/ASR/README.md +++ b/egs/timit/ASR/README.md @@ -1,3 +1,3 @@ -Please refer to +Please refer to for how to run models in this recipe. diff --git a/egs/yesno/ASR/README.md b/egs/yesno/ASR/README.md index 7257bad9a..38b491fc6 100644 --- a/egs/yesno/ASR/README.md +++ b/egs/yesno/ASR/README.md @@ -10,5 +10,5 @@ get the following WER: ``` Please refer to - + for detailed instructions. From e277e31e37279da78ece356efc664e310ef18e5d Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Sun, 29 Jan 2023 15:35:36 +0800 Subject: [PATCH 088/174] update huggingface link of zipformer_ctc_blankskip.rst (#858) * update huggingface link * update link --------- Co-authored-by: yifanyang --- .../Non-streaming-ASR/librispeech/zipformer_ctc_blankskip.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_ctc_blankskip.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_ctc_blankskip.rst index 56a420605..4929df950 100644 --- a/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_ctc_blankskip.rst +++ b/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_ctc_blankskip.rst @@ -447,7 +447,8 @@ Download pretrained models If you don't want to train from scratch, you can download the pretrained models by visiting the following links: - - ``_ + - trained on LibriSpeech 100h: ``_ + - trained on LibriSpeech 960h: ``_ See ``_ for the details of the above pretrained models From e9019511eb1792b6fa2c166dbe4f6ab02e7e537f Mon Sep 17 00:00:00 2001 From: BuaaAlban Date: Tue, 31 Jan 2023 15:19:50 +0800 Subject: [PATCH 089/174] Fix bug in streaming_conformer_ctc egs (#862) * Update train.py Fix transducer lstm egs bug as mentioned in issue 579 * Update train.py fix dataloader bug --- .../ASR/streaming_conformer_ctc/train.py | 21 ++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/train.py b/egs/librispeech/ASR/streaming_conformer_ctc/train.py index 553b7d092..d265de45b 100755 --- a/egs/librispeech/ASR/streaming_conformer_ctc/train.py +++ b/egs/librispeech/ASR/streaming_conformer_ctc/train.py @@ -50,7 +50,7 @@ from icefall.utils import ( setup_logger, str2bool, ) - +from lhotse.cut import Cut def get_parser(): parser = argparse.ArgumentParser( @@ -645,8 +645,23 @@ def run(rank, world_size, args): optimizer.load_state_dict(checkpoints["optimizer"]) librispeech = LibriSpeechAsrDataModule(args) - train_dl = librispeech.train_dataloaders() - valid_dl = librispeech.valid_dataloaders() + + if params.full_libri: + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + return 1.0 <= c.duration <= 20.0 + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + train_dl = librispeech.train_dataloaders(train_cuts) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) scan_pessimistic_batches_for_oom( model=model, From d8234e199c65a5971827ddaaa4deb72bd173f0ae Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Tue, 31 Jan 2023 15:57:03 +0800 Subject: [PATCH 090/174] Add export to ONNX for Zipformer+CTC using blank skip (#861) * Add export to ONNX for Zipformer+CTC using blank skip --------- Co-authored-by: yifanyang --- .../export.py | 6 +- .../export_onnx.py | 665 ++++++++++++++++++ .../frame_reducer.py | 76 +- .../onnx_pretrained.py | 461 ++++++++++++ 4 files changed, 1188 insertions(+), 20 deletions(-) create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export_onnx.py mode change 100755 => 100644 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export.py index 96d316604..05df8cfff 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export.py @@ -72,14 +72,14 @@ Check ./pretrained.py for its usage. Note: If you don't want to train a model from scratch, we have provided one for you. You can get it at -https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11 +https://huggingface.co/yfyeung/icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29 with the following commands: sudo apt-get install git-lfs git lfs install - git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11 - # You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/exp + git clone https://huggingface.co/yfyeung/icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29 + # You will find the pre-trained model in icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29/exp """ import argparse diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export_onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export_onnx.py new file mode 100644 index 000000000..50efa6e60 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export_onnx.py @@ -0,0 +1,665 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang, +# Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" + +Usage: + +(1) Export to ONNX format + +./pruned_transducer_stateless7_ctc_bs/export_onnx.py \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 13 + +It will generate the following files in the given `exp_dir`. +Check `onnx_check.py` for how to use them. + + - encoder.onnx + - decoder.onnx + - joiner.onnx + - joiner_encoder_proj.onnx + - joiner_decoder_proj.onnx + - lconv.onnx + - frame_reducer.onnx + +Please see ./onnx_pretrained.py for usage of the generated files + +Check +https://github.com/k2-fsa/sherpa-onnx +for how to use the exported models outside of icefall. + +Note: If you don't want to train a model from scratch, we have +provided one for you. You can get it at + +https://huggingface.co/yfyeung/icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29 + +with the following commands: + + sudo apt-get install git-lfs + git lfs install + git clone https://huggingface.co/yfyeung/icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29 + # You will find the pre-trained model in icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29/exp +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +import torch +import torch.nn as nn +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_ctc_bs/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--onnx", + type=str2bool, + default=True, + help="""If True, --jit is ignored and it exports the model + to onnx format. It will generate the following files: + + - encoder.onnx + - decoder.onnx + - joiner.onnx + - joiner_encoder_proj.onnx + - joiner_decoder_proj.onnx + - lconv.onnx + - frame_reducer.onnx + + Refer to ./onnx_check.py and ./onnx_pretrained.py for how to use them. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +def export_encoder_model_onnx( + encoder_model: nn.Module, + encoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the given encoder model to ONNX format. + The exported model has two inputs: + + - x, a tensor of shape (N, T, C); dtype is torch.float32 + - x_lens, a tensor of shape (N,); dtype is torch.int64 + + and it has two outputs: + + - encoder_out, a tensor of shape (N, T, C) + - encoder_out_lens, a tensor of shape (N,) + + Note: The warmup argument is fixed to 1. + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + x = torch.zeros(15, 2000, 80, dtype=torch.float32) + x_lens = torch.tensor([2000] * 15, dtype=torch.int64) + + # encoder_model = torch.jit.script(encoder_model) + # It throws the following error for the above statement + # + # RuntimeError: Exporting the operator __is_ to ONNX opset version + # 11 is not supported. Please feel free to request support or + # submit a pull request on PyTorch GitHub. + # + # I cannot find which statement causes the above error. + # torch.onnx.export() will use torch.jit.trace() internally, which + # works well for the current reworked model + torch.onnx.export( + encoder_model, + (x, x_lens), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "x_lens"], + output_names=["encoder_out", "encoder_out_lens"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "x_lens": {0: "N"}, + "encoder_out": {0: "N", 1: "T"}, + "encoder_out_lens": {0: "N"}, + }, + ) + logging.info(f"Saved to {encoder_filename}") + + +def export_decoder_model_onnx( + decoder_model: nn.Module, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, decoder_model.context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, 1, C) + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + y = torch.zeros(15, decoder_model.context_size, dtype=torch.int64) + need_pad = False # Always False, so we can use torch.jit.trace() here + # Note(fangjun): torch.jit.trace() is more efficient than torch.jit.script() + # in this case + torch.onnx.export( + decoder_model, + (y, need_pad), + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y", "need_pad"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + logging.info(f"Saved to {decoder_filename}") + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. + The exported joiner model has two inputs: + + - projected_encoder_out: a tensor of shape (N, joiner_dim) + - projected_decoder_out: a tensor of shape (N, joiner_dim) + + and produces one output: + + - logit: a tensor of shape (N, vocab_size) + + The exported encoder_proj model has one input: + + - encoder_out: a tensor of shape (N, encoder_out_dim) + + and produces one output: + + - projected_encoder_out: a tensor of shape (N, joiner_dim) + + The exported decoder_proj model has one input: + + - decoder_out: a tensor of shape (N, decoder_out_dim) + + and produces one output: + + - projected_decoder_out: a tensor of shape (N, joiner_dim) + """ + encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx") + decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx") + + encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] + decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] + joiner_dim = joiner_model.decoder_proj.weight.shape[0] + + projected_encoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32) + + project_input = False + # Note: It uses torch.jit.trace() internally + torch.onnx.export( + joiner_model, + (projected_encoder_out, projected_decoder_out, project_input), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "encoder_out", + "decoder_out", + "project_input", + ], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + logging.info(f"Saved to {joiner_filename}") + + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + torch.onnx.export( + joiner_model.encoder_proj, + encoder_out, + encoder_proj_filename, + verbose=False, + opset_version=opset_version, + input_names=["encoder_out"], + output_names=["projected_encoder_out"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "projected_encoder_out": {0: "N"}, + }, + ) + logging.info(f"Saved to {encoder_proj_filename}") + + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + torch.onnx.export( + joiner_model.decoder_proj, + decoder_out, + decoder_proj_filename, + verbose=False, + opset_version=opset_version, + input_names=["decoder_out"], + output_names=["projected_decoder_out"], + dynamic_axes={ + "decoder_out": {0: "N"}, + "projected_decoder_out": {0: "N"}, + }, + ) + logging.info(f"Saved to {decoder_proj_filename}") + + +def export_lconv_onnx( + lconv: nn.Module, + lconv_filename: str, + opset_version: int = 11, +) -> None: + """Export the lconv to ONNX format. + + The exported lconv has two inputs: + + - lconv_input: a tensor of shape (N, T, C) + - src_key_padding_mask: a tensor of shape (N, T) + + and has one output: + + - lconv_out: a tensor of shape (N, T, C) + + Args: + lconv: + The lconv to be exported. + lconv_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + lconv_input = torch.zeros(15, 498, 384, dtype=torch.float32) + src_key_padding_mask = torch.zeros(15, 498, dtype=torch.bool) + + torch.onnx.export( + lconv, + (lconv_input, src_key_padding_mask), + lconv_filename, + verbose=False, + opset_version=opset_version, + input_names=["lconv_input", "src_key_padding_mask"], + output_names=["lconv_out"], + dynamic_axes={ + "lconv_input": {0: "N", 1: "T"}, + "src_key_padding_mask": {0: "N", 1: "T"}, + "lconv_out": {0: "N", 1: "T"}, + }, + ) + logging.info(f"Saved to {lconv_filename}") + + +def export_frame_reducer_onnx( + frame_reducer: nn.Module, + frame_reducer_filename: str, + opset_version: int = 11, +) -> None: + """Export the frame_reducer to ONNX format. + + The exported frame_reducer has four inputs: + + - x: a tensor of shape (N, T, C) + - x_lens: a tensor of shape (N, T) + - ctc_output: a tensor of shape (N, T, vocab_size) + - blank_id: an int, always 0 + + and has two outputs: + + - x_fr: a tensor of shape (N, T, C) + - x_lens_fr: a tensor of shape (N, T) + + Args: + frame_reducer: + The frame_reducer to be exported. + frame_reducer_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + x = torch.zeros(15, 498, 384, dtype=torch.float32) + x_lens = torch.tensor([498] * 15, dtype=torch.int64) + ctc_output = torch.randn(15, 498, 500, dtype=torch.float32) + + torch.onnx.export( + frame_reducer, + (x, x_lens, ctc_output), + frame_reducer_filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "x_lens", "ctc_output"], + output_names=["out", "out_lens"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "x_lens": {0: "N"}, + "ctc_output": {0: "N", 1: "T"}, + "out": {0: "N", 1: "T"}, + "out_lens": {0: "N"}, + }, + ) + logging.info(f"Saved to {frame_reducer_filename}") + + +def export_ctc_output_onnx( + ctc_output: nn.Module, + ctc_output_filename: str, + opset_version: int = 11, +) -> None: + """Export the frame_reducer to ONNX format. + + The exported frame_reducer has one inputs: + + - encoder_out: a tensor of shape (N, T, C) + + and has one output: + + - ctc_output: a tensor of shape (N, T, vocab_size) + + Args: + ctc_output: + The ctc_output to be exported. + ctc_output_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + encoder_out = torch.zeros(15, 498, 384, dtype=torch.float32) + + torch.onnx.export( + ctc_output, + (encoder_out), + ctc_output_filename, + verbose=False, + opset_version=opset_version, + input_names=["encoder_out"], + output_names=["ctc_output"], + dynamic_axes={ + "encoder_out": {0: "N", 1: "T"}, + "ctc_output": {0: "N", 1: "T"}, + }, + ) + logging.info(f"Saved to {ctc_output_filename}") + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True) + opset_version = 13 + logging.info("Exporting to onnx format") + encoder_filename = params.exp_dir / "encoder.onnx" + export_encoder_model_onnx( + model.encoder, + encoder_filename, + opset_version=opset_version, + ) + + decoder_filename = params.exp_dir / "decoder.onnx" + export_decoder_model_onnx( + model.decoder, + decoder_filename, + opset_version=opset_version, + ) + + joiner_filename = params.exp_dir / "joiner.onnx" + export_joiner_model_onnx( + model.joiner, + joiner_filename, + opset_version=opset_version, + ) + + lconv_filename = params.exp_dir / "lconv.onnx" + export_lconv_onnx( + model.lconv, + lconv_filename, + opset_version=opset_version, + ) + + frame_reducer_filename = params.exp_dir / "frame_reducer.onnx" + export_frame_reducer_onnx( + model.frame_reducer, + frame_reducer_filename, + opset_version=opset_version, + ) + + ctc_output_filename = params.exp_dir / "ctc_output.onnx" + export_ctc_output_onnx( + model.ctc_output, + ctc_output_filename, + opset_version=opset_version, + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py old mode 100755 new mode 100644 index 9fe88929d..4a19edf66 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py @@ -22,7 +22,8 @@ from typing import List, Optional, Tuple, Union import torch import torch.nn as nn -from torch.nn.utils.rnn import pad_sequence +import torch.nn.functional as F + from icefall.utils import make_pad_mask @@ -43,7 +44,6 @@ class FrameReducer(nn.Module): x: torch.Tensor, x_lens: torch.Tensor, ctc_output: torch.Tensor, - blank_id: int = 0, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: @@ -54,26 +54,68 @@ class FrameReducer(nn.Module): `x` before padding. ctc_output: The CTC output with shape [N, T, vocab_size]. - blank_id: - The ID of the blank symbol. Returns: - x_fr: + out: The frame reduced encoder output with shape [N, T', C]. - x_lens_fr: + out_lens: A tensor of shape (batch_size,) containing the number of frames in - `x_fr` before padding. + `out` before padding. """ + N, T, C = x.size() + padding_mask = make_pad_mask(x_lens) - non_blank_mask = (ctc_output[:, :, blank_id] < math.log(0.9)) * (~padding_mask) + non_blank_mask = (ctc_output[:, :, 0] < math.log(0.9)) * (~padding_mask) - frames_list: List[torch.Tensor] = [] - lens_list: List[int] = [] - for i in range(x.shape[0]): - frames = x[i][non_blank_mask[i]] - frames_list.append(frames) - lens_list.append(frames.shape[0]) - x_fr = pad_sequence(frames_list, batch_first=True) - x_lens_fr = torch.tensor(lens_list).to(device=x.device) + out_lens = non_blank_mask.sum(dim=1) + max_len = out_lens.max() + pad_lens_list = torch.full_like(out_lens, max_len.item()) - out_lens + max_pad_len = pad_lens_list.max() - return x_fr, x_lens_fr + out = F.pad(x, (0, 0, 0, max_pad_len)) + + valid_pad_mask = ~make_pad_mask(pad_lens_list) + total_valid_mask = torch.concat([non_blank_mask, valid_pad_mask], dim=1) + + out = out[total_valid_mask].reshape(N, -1, C) + + return out.to(device=x.device), out_lens.to(device=x.device) + + +if __name__ == "__main__": + import time + from torch.nn.utils.rnn import pad_sequence + + test_times = 10000 + frame_reducer = FrameReducer() + + # non zero case + x = torch.ones(15, 498, 384, dtype=torch.float32) + x_lens = torch.tensor([498] * 15, dtype=torch.int64) + ctc_output = torch.log(torch.randn(15, 498, 500, dtype=torch.float32)) + x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output) + + avg_time = 0 + for i in range(test_times): + delta_time = time.time() + x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output) + delta_time = time.time() - delta_time + avg_time += delta_time + print(x_fr.shape) + print(x_lens_fr) + print(avg_time / test_times) + + # all zero case + x = torch.zeros(15, 498, 384, dtype=torch.float32) + x_lens = torch.tensor([498] * 15, dtype=torch.int64) + ctc_output = torch.zeros(15, 498, 500, dtype=torch.float32) + + avg_time = 0 + for i in range(test_times): + delta_time = time.time() + x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output) + delta_time = time.time() - delta_time + avg_time += delta_time + print(x_fr.shape) + print(x_lens_fr) + print(avg_time / test_times) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py new file mode 100644 index 000000000..8ff02fbcb --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py @@ -0,0 +1,461 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads ONNX models and uses them to decode waves. +You can use the following command to get the exported models: + +./pruned_transducer_stateless7_ctc_bs/export_onnx.py \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 13 + +Usage of this script: + +./pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py \ + --encoder-model-filename ./pruned_transducer_stateless7_ctc_bs/exp/encoder.onnx \ + --decoder-model-filename ./pruned_transducer_stateless7_ctc_bs/exp/decoder.onnx \ + --joiner-model-filename ./pruned_transducer_stateless7_ctc_bs/exp/joiner.onnx \ + --joiner-encoder-proj-model-filename ./pruned_transducer_stateless7_ctc_bs/exp/joiner_encoder_proj.onnx \ + --joiner-decoder-proj-model-filename ./pruned_transducer_stateless7_ctc_bs/exp/joiner_decoder_proj.onnx \ + --lconv-filename ./pruned_transducer_stateless7_ctc_bs/exp/lconv.onnx \ + --frame-reducer-filename ./pruned_transducer_stateless7_ctc_bs/exp/frame_reducer.onnx \ + --ctc-output-filename ./pruned_transducer_stateless7_ctc_bs/exp/ctc_output.onnx \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import kaldifeat +import numpy as np +import onnxruntime as ort +import sentencepiece as spm +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence + +from icefall.utils import make_pad_mask + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder onnx model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder onnx model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner onnx model. ", + ) + + parser.add_argument( + "--joiner-encoder-proj-model-filename", + type=str, + required=True, + help="Path to the joiner encoder_proj onnx model. ", + ) + + parser.add_argument( + "--joiner-decoder-proj-model-filename", + type=str, + required=True, + help="Path to the joiner decoder_proj onnx model. ", + ) + + parser.add_argument( + "--lconv-filename", + type=str, + required=True, + help="Path to the lconv onnx model. ", + ) + + parser.add_argument( + "--frame-reducer-filename", + type=str, + required=True, + help="Path to the frame reducer onnx model. ", + ) + + parser.add_argument( + "--ctc-output-filename", + type=str, + required=True, + help="Path to the ctc_output onnx model. ", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="Context size of the decoder model", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +def greedy_search( + decoder: ort.InferenceSession, + joiner: ort.InferenceSession, + joiner_encoder_proj: ort.InferenceSession, + joiner_decoder_proj: ort.InferenceSession, + encoder_out: np.ndarray, + encoder_out_lens: np.ndarray, + context_size: int, +) -> List[List[int]]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + decoder: + The decoder model. + joiner: + The joiner model. + joiner_encoder_proj: + The joiner encoder projection model. + joiner_decoder_proj: + The joiner decoder projection model. + encoder_out: + A 3-D tensor of shape (N, T, C) + encoder_out_lens: + A 1-D tensor of shape (N,). + context_size: + The context size of the decoder model. + Returns: + Return the decoded results for each utterance. + """ + encoder_out = torch.from_numpy(encoder_out) + encoder_out_lens = torch.from_numpy(encoder_out_lens) + assert encoder_out.ndim == 3 + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + projected_encoder_out = joiner_encoder_proj.run( + [joiner_encoder_proj.get_outputs()[0].name], + {joiner_encoder_proj.get_inputs()[0].name: packed_encoder_out.data.numpy()}, + )[0] + + blank_id = 0 # hard-code to 0 + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + hyps = [[blank_id] * context_size for _ in range(N)] + + decoder_input_nodes = decoder.get_inputs() + decoder_output_nodes = decoder.get_outputs() + + joiner_input_nodes = joiner.get_inputs() + joiner_output_nodes = joiner.get_outputs() + + decoder_input = torch.tensor( + hyps, + dtype=torch.int64, + ) # (N, context_size) + + decoder_out = decoder.run( + [decoder_output_nodes[0].name], + { + decoder_input_nodes[0].name: decoder_input.numpy(), + }, + )[0].squeeze(1) + projected_decoder_out = joiner_decoder_proj.run( + [joiner_decoder_proj.get_outputs()[0].name], + {joiner_decoder_proj.get_inputs()[0].name: decoder_out}, + )[0] + + projected_decoder_out = torch.from_numpy(projected_decoder_out) + + offset = 0 + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = projected_encoder_out[start:end] + # current_encoder_out's shape: (batch_size, encoder_out_dim) + offset = end + + projected_decoder_out = projected_decoder_out[:batch_size] + + logits = joiner.run( + [joiner_output_nodes[0].name], + { + joiner_input_nodes[0].name: np.expand_dims( + np.expand_dims(current_encoder_out, axis=1), axis=1 + ), + joiner_input_nodes[1] + .name: projected_decoder_out.unsqueeze(1) + .unsqueeze(1) + .numpy(), + }, + )[0] + logits = torch.from_numpy(logits).squeeze(1).squeeze(1) + # logits'shape (batch_size, vocab_size) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + hyps[i].append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = [h[-context_size:] for h in hyps[:batch_size]] + decoder_input = torch.tensor( + decoder_input, + dtype=torch.int64, + ) + decoder_out = decoder.run( + [decoder_output_nodes[0].name], + { + decoder_input_nodes[0].name: decoder_input.numpy(), + }, + )[0].squeeze(1) + projected_decoder_out = joiner_decoder_proj.run( + [joiner_decoder_proj.get_outputs()[0].name], + {joiner_decoder_proj.get_inputs()[0].name: decoder_out}, + )[0] + projected_decoder_out = torch.from_numpy(projected_decoder_out) + + sorted_ans = [h[context_size:] for h in hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + encoder = ort.InferenceSession( + args.encoder_model_filename, + sess_options=session_opts, + ) + + decoder = ort.InferenceSession( + args.decoder_model_filename, + sess_options=session_opts, + ) + + joiner = ort.InferenceSession( + args.joiner_model_filename, + sess_options=session_opts, + ) + + joiner_encoder_proj = ort.InferenceSession( + args.joiner_encoder_proj_model_filename, + sess_options=session_opts, + ) + + joiner_decoder_proj = ort.InferenceSession( + args.joiner_decoder_proj_model_filename, + sess_options=session_opts, + ) + + lconv = ort.InferenceSession( + args.lconv_filename, + sess_options=session_opts, + ) + + frame_reducer = ort.InferenceSession( + args.frame_reducer_filename, + sess_options=session_opts, + ) + + ctc_output = ort.InferenceSession( + args.ctc_output_filename, + sess_options=session_opts, + ) + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = args.sample_rate + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + expected_sample_rate=args.sample_rate, + ) + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64) + + encoder_input_nodes = encoder.get_inputs() + encoder_out_nodes = encoder.get_outputs() + encoder_out, encoder_out_lens = encoder.run( + [encoder_out_nodes[0].name, encoder_out_nodes[1].name], + { + encoder_input_nodes[0].name: features.numpy(), + encoder_input_nodes[1].name: feature_lengths.numpy(), + }, + ) + + ctc_output_input_nodes = ctc_output.get_inputs() + ctc_output_out_nodes = ctc_output.get_outputs() + ctc_out = ctc_output.run( + [ctc_output_out_nodes[0].name], + { + ctc_output_input_nodes[0].name: encoder_out, + }, + )[0] + + lconv_input_nodes = lconv.get_inputs() + lconv_out_nodes = lconv.get_outputs() + encoder_out = lconv.run( + [lconv_out_nodes[0].name], + { + lconv_input_nodes[0].name: encoder_out, + lconv_input_nodes[1] + .name: make_pad_mask(torch.from_numpy(encoder_out_lens)) + .numpy(), + }, + )[0] + + frame_reducer_input_nodes = frame_reducer.get_inputs() + frame_reducer_out_nodes = frame_reducer.get_outputs() + encoder_out_fr, encoder_out_lens_fr = frame_reducer.run( + [frame_reducer_out_nodes[0].name, frame_reducer_out_nodes[1].name], + { + frame_reducer_input_nodes[0].name: encoder_out, + frame_reducer_input_nodes[1].name: encoder_out_lens, + frame_reducer_input_nodes[2].name: ctc_out, + }, + ) + + hyps = greedy_search( + decoder=decoder, + joiner=joiner, + joiner_encoder_proj=joiner_encoder_proj, + joiner_decoder_proj=joiner_decoder_proj, + encoder_out=encoder_out_fr, + encoder_out_lens=encoder_out_lens_fr, + context_size=args.context_size, + ) + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = sp.decode(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() From e36ea89112bb3d81602cb4df51bd68e6d06dc150 Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Wed, 1 Feb 2023 21:04:56 +0800 Subject: [PATCH 091/174] update result.md for pruned_transducer_stateless7_ctc_bs (#865) --- egs/librispeech/ASR/RESULTS.md | 105 ++++++++++++++++++++++++++++++--- 1 file changed, 98 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index b30cf7c1f..a3e44f09c 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -93,13 +93,13 @@ results at: Number of model parameters: 69136519, i.e., 69.14 M -| | test-clean | test-other | comment | -|--------------------------|------------|-------------|---------------------| -| 1best | 2.54 | 5.65 | --epoch 30 --avg 10 | -| nbest | 2.54 | 5.66 | --epoch 30 --avg 10 | -| nbest-rescoring-LG | 2.49 | 5.42 | --epoch 30 --avg 10 | -| nbest-rescoring-3-gram | 2.52 | 5.62 | --epoch 30 --avg 10 | -| nbest-rescoring-4-gram | 2.5 | 5.51 | --epoch 30 --avg 10 | +| | test-clean | test-other | comment | +| ---------------------- | ---------- | ---------- | ------------------- | +| 1best | 2.54 | 5.65 | --epoch 30 --avg 10 | +| nbest | 2.54 | 5.66 | --epoch 30 --avg 10 | +| nbest-rescoring-LG | 2.49 | 5.42 | --epoch 30 --avg 10 | +| nbest-rescoring-3-gram | 2.52 | 5.62 | --epoch 30 --avg 10 | +| nbest-rescoring-4-gram | 2.5 | 5.51 | --epoch 30 --avg 10 | The training commands are: ```bash @@ -134,6 +134,97 @@ for m in nbest nbest-rescoring-LG nbest-rescoring-3-gram nbest-rescoring-4-gram; done ``` +### pruned_transducer_stateless7_ctc_bs (zipformer with transducer loss and ctc loss using blank skip) + +See https://github.com/k2-fsa/icefall/pull/730 for more details. + +[pruned_transducer_stateless7_ctc_bs](./pruned_transducer_stateless7_ctc_bs) + +The tensorboard log can be found at + + +You can find a pretrained model, training logs, decoding logs, and decoding +results at: + + +Number of model parameters: 76804822, i.e., 76.80 M + +Test on 8-card V100 cluster, with 4-card busy and 4-card idle. + +#### greedy_search + +| model | test-clean | test-other | decoding time(s) | comment | +| ------------------------------------------------------------ | ---------- | ---------- | ---------------- | ------------------- | +| [pruned_transducer_stateless7_ctc_bs](./pruned_transducer_stateless7_ctc_bs) | 2.28 | 5.53 | 48.939 | --epoch 30 --avg 13 | +| [pruned_transducer_stateless7_ctc](./pruned_transducer_stateless7_ctc) | 2.24 | 5.18 | 91.900 | --epoch 30 --avg 8 | + +- [pruned_transducer_stateless7_ctc_bs](./pruned_transducer_stateless7_ctc_bs) applies blank skip both on training and decoding, and [pruned_transducer_stateless7_ctc](./pruned_transducer_stateless7_ctc) doesn`t apply blank skip. +- Applying blank skip both on training and decoding is **1.88 times** faster than the model that doesn't apply blank skip without obvious performance loss. + +#### modified_beam_search + +| model | test-clean | test-other | decoding time(s) | comment | +| ------------------------------------------------------------ | ---------- | ---------- | ---------------- | ------------------- | +| [pruned_transducer_stateless7_ctc_bs](./pruned_transducer_stateless7_ctc_bs) | 2.26 | 5.44 | 80.446 | --epoch 30 --avg 13 | +| [pruned_transducer_stateless7_ctc](./pruned_transducer_stateless7_ctc) | 2.20 | 5.12 | 283.676 | --epoch 30 --avg 8 | + +- [pruned_transducer_stateless7_ctc_bs](./pruned_transducer_stateless7_ctc_bs) applies blank skip both on training and decoding, and [pruned_transducer_stateless7_ctc](./pruned_transducer_stateless7_ctc) doesn`t apply blank skip. +- Applying blank skip both on training and decoding is **3.53 times** faster than the model that doesn't apply blank skip without obvious performance loss. + +The training commands for the model using blank skip ([pruned_transducer_stateless7_ctc_bs](./pruned_transducer_stateless7_ctc_bs)) are: + +```bash +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless7_ctc_bs/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --full-libri 1 \ + --use-fp16 1 \ + --max-duration 750 \ + --exp-dir pruned_transducer_stateless7_ctc_bs/exp \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --ctc-loss-scale 0.2 \ + --master-port 12535 +``` + +The decoding commands for the transducer branch of the model using blank skip ([pruned_transducer_stateless7_ctc_bs](./pruned_transducer_stateless7_ctc_bs)) are: + +```bash +for m in greedy_search modified_beam_search fast_beam_search; do + for epoch in 30; do + for avg in 15; do + ./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \ + --epoch $epoch \ + --avg $avg \ + --use-averaged-model 1 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --max-duration 600 \ + --decoding-method $m + done + done +done +``` + +The decoding commands for the transducer branch of the model without blank skip ([pruned_transducer_stateless7_ctc](./pruned_transducer_stateless7_ctc)) are: + +```bash +for m in greedy_search modified_beam_search fast_beam_search; do + for epoch in 30; do + for avg in 15; do + ./pruned_transducer_stateless7_ctc/decode.py \ + --epoch $epoch \ + --avg $avg \ + --use-averaged-model 1 \ + --exp-dir ./pruned_transducer_stateless7_ctc/exp \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --max-duration 600 \ + --decoding-method $m + done + done +done +``` ### pruned_transducer_stateless7_ctc (zipformer with transducer loss and ctc loss) From 1e6d6f816001dbcf1275204385740a81bcc2ff14 Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Fri, 3 Feb 2023 11:54:57 +0800 Subject: [PATCH 092/174] shuffle full Librispeech for zipformer recipes (#869) * shuffle libri --- egs/librispeech/ASR/pruned_transducer_stateless7/train.py | 6 +++--- .../ASR/pruned_transducer_stateless7_ctc/train.py | 6 +++--- .../ASR/pruned_transducer_stateless7_ctc_bs/train.py | 8 ++++---- .../ASR/pruned_transducer_stateless7_streaming/train.py | 6 +++--- egs/librispeech/ASR/pruned_transducer_stateless8/train.py | 6 +++--- egs/librispeech/ASR/streaming_conformer_ctc/train.py | 3 ++- 6 files changed, 18 insertions(+), 17 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index 6022406eb..792a243e5 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -1043,10 +1043,10 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py index 5a05e1836..718381baa 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py @@ -1072,10 +1072,10 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py index 522ecc974..b282ab9db 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py @@ -55,9 +55,9 @@ import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule from decoder import Decoder +from frame_reducer import FrameReducer from joiner import Joiner from lconv import LConv -from frame_reducer import FrameReducer from lhotse.cut import Cut from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed @@ -1063,10 +1063,10 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py index 2bdc882a5..c7a2a136d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py @@ -1049,10 +1049,10 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py index abe249c7b..b0abad5ae 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py @@ -1154,10 +1154,10 @@ def run(rank, world_size, args): librispeech = LibriSpeech(manifest_dir=args.manifest_dir) - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() train_cuts = filter_short_and_long_utterances(train_cuts, sp) diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/train.py b/egs/librispeech/ASR/streaming_conformer_ctc/train.py index d265de45b..bb55ed6bb 100755 --- a/egs/librispeech/ASR/streaming_conformer_ctc/train.py +++ b/egs/librispeech/ASR/streaming_conformer_ctc/train.py @@ -30,6 +30,7 @@ import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule from conformer import Conformer +from lhotse.cut import Cut from lhotse.utils import fix_random_seed from torch import Tensor from torch.nn.parallel import DistributedDataParallel as DDP @@ -50,7 +51,7 @@ from icefall.utils import ( setup_logger, str2bool, ) -from lhotse.cut import Cut + def get_parser(): parser = argparse.ArgumentParser( From bffce413f07d938e35d69d5eb2f360c7ff842502 Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Fri, 3 Feb 2023 12:32:06 +0800 Subject: [PATCH 093/174] Fix filename ctc_guild_decode_bs.py -> ctc_guide_decode_bs.py (#870) * fix filename ctc_guild_decode_bs.py -> ctc_guide_decode_bs.py --------- Co-authored-by: yifanyang --- .../librispeech/zipformer_ctc_blankskip.rst | 14 +++++++------- egs/librispeech/ASR/RESULTS.md | 2 +- ...c_guild_decode_bs.py => ctc_guide_decode_bs.py} | 14 +++++++------- .../pruned_transducer_stateless7_ctc_bs/lconv.py | 2 +- 4 files changed, 16 insertions(+), 16 deletions(-) rename egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/{ctc_guild_decode_bs.py => ctc_guide_decode_bs.py} (98%) diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_ctc_blankskip.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_ctc_blankskip.rst index 4929df950..aa73bfe33 100644 --- a/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_ctc_blankskip.rst +++ b/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_ctc_blankskip.rst @@ -299,11 +299,11 @@ to run the training part first. - (1) ``epoch-1.pt``, ``epoch-2.pt``, ..., which are saved at the end of each epoch. You can pass ``--epoch`` to - ``pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py`` to use them. + ``pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py`` to use them. - (2) ``checkpoints-436000.pt``, ``epoch-438000.pt``, ..., which are saved every ``--save-every-n`` batches. You can pass ``--iter`` to - ``pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py`` to use them. + ``pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py`` to use them. We suggest that you try both types of checkpoints and choose the one that produces the lowest WERs. @@ -311,7 +311,7 @@ to run the training part first. .. code-block:: bash $ cd egs/librispeech/ASR - $ ./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py --help + $ ./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py --help shows the options for decoding. @@ -320,7 +320,7 @@ The following shows the example using ``epoch-*.pt``: .. code-block:: bash for m in greedy_search fast_beam_search modified_beam_search; do - ./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \ + ./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py \ --epoch 30 \ --avg 13 \ --exp-dir pruned_transducer_stateless7_ctc_bs/exp \ @@ -333,7 +333,7 @@ To test CTC branch, you can use the following command: .. code-block:: bash for m in ctc-decoding 1best; do - ./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \ + ./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py \ --epoch 30 \ --avg 13 \ --exp-dir pruned_transducer_stateless7_ctc_bs/exp \ @@ -367,7 +367,7 @@ It will generate a file ``./pruned_transducer_stateless7_ctc_bs/exp/pretrained.p .. hint:: - To use the generated ``pretrained.pt`` for ``pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py``, + To use the generated ``pretrained.pt`` for ``pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py``, you can run: .. code-block:: bash @@ -376,7 +376,7 @@ It will generate a file ``./pruned_transducer_stateless7_ctc_bs/exp/pretrained.p ln -s pretrained epoch-9999.pt And then pass ``--epoch 9999 --avg 1 --use-averaged-model 0`` to - ``./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py``. + ``./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py``. To use the exported model with ``./pruned_transducer_stateless7_ctc_bs/pretrained.py``, you can run: diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index a3e44f09c..1a894498e 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -194,7 +194,7 @@ The decoding commands for the transducer branch of the model using blank skip ([ for m in greedy_search modified_beam_search fast_beam_search; do for epoch in 30; do for avg in 15; do - ./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \ + ./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py \ --epoch $epoch \ --avg $avg \ --use-averaged-model 1 \ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py similarity index 98% rename from egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py rename to egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py index 9c2166aaf..01ba7b711 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py @@ -21,7 +21,7 @@ """ Usage: (1) greedy search -./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \ +./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py \ --epoch 28 \ --avg 15 \ --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ @@ -29,7 +29,7 @@ Usage: --decoding-method greedy_search (2) beam search (not recommended) -./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \ +./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py \ --epoch 28 \ --avg 15 \ --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ @@ -38,7 +38,7 @@ Usage: --beam-size 4 (3) modified beam search -./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \ +./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py \ --epoch 28 \ --avg 15 \ --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ @@ -47,7 +47,7 @@ Usage: --beam-size 4 (4) fast beam search (one best) -./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \ +./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py \ --epoch 28 \ --avg 15 \ --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ @@ -58,7 +58,7 @@ Usage: --max-states 64 (5) fast beam search (nbest) -./pruned_transducer_stateless7_ctc/ctc_guild_decode_bs.py \ +./pruned_transducer_stateless7_ctc/ctc_guide_decode_bs.py \ --epoch 28 \ --avg 15 \ --exp-dir ./pruned_transducer_stateless7_ctc/exp \ @@ -71,7 +71,7 @@ Usage: --nbest-scale 0.5 (6) fast beam search (nbest oracle WER) -./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \ +./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py \ --epoch 28 \ --avg 15 \ --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ @@ -84,7 +84,7 @@ Usage: --nbest-scale 0.5 (7) fast beam search (with LG) -./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \ +./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py \ --epoch 28 \ --avg 15 \ --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/lconv.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/lconv.py index bfd49d533..a902358ae 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/lconv.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/lconv.py @@ -62,7 +62,7 @@ class LConv(nn.Module): kernel_size=kernel_size, stride=1, padding=(kernel_size - 1) // 2, - groups=channels, + groups=2 * channels, bias=bias, ) From 029c8566e424b44e64da70c6fb532caace9c7d54 Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Fri, 3 Feb 2023 17:49:54 +0800 Subject: [PATCH 094/174] Small fix for frame_reducer.py (#871) --- .../ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py index 4a19edf66..bc3fc57eb 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py @@ -44,6 +44,7 @@ class FrameReducer(nn.Module): x: torch.Tensor, x_lens: torch.Tensor, ctc_output: torch.Tensor, + blank_id: int = 0, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: @@ -54,6 +55,8 @@ class FrameReducer(nn.Module): `x` before padding. ctc_output: The CTC output with shape [N, T, vocab_size]. + blank_id: + The blank id of ctc_output. Returns: out: The frame reduced encoder output with shape [N, T', C]. @@ -65,7 +68,7 @@ class FrameReducer(nn.Module): N, T, C = x.size() padding_mask = make_pad_mask(x_lens) - non_blank_mask = (ctc_output[:, :, 0] < math.log(0.9)) * (~padding_mask) + non_blank_mask = (ctc_output[:, :, blank_id] < math.log(0.9)) * (~padding_mask) out_lens = non_blank_mask.sum(dim=1) max_len = out_lens.max() From bf5f0342a24b2dd92a908980ba5e8619ca2a08f4 Mon Sep 17 00:00:00 2001 From: Yuekai Zhang Date: Mon, 6 Feb 2023 10:37:07 +0800 Subject: [PATCH 095/174] Add streaming onnx export for zipformer (#831) * add streaming onnx export for zipformer * update triton support * add comments * add ci test * add onnxmltools for fp16 onnx export --- ...nsducer-stateless7-streaming-2022-12-29.sh | 10 + ...speech-2022-12-29-stateless7-streaming.yml | 2 +- .../export.py | 563 +++++++++++++++++- .../onnx_model_wrapper.py | 231 +++++++ .../zipformer.py | 60 +- requirements-ci.txt | 1 + 6 files changed, 843 insertions(+), 24 deletions(-) create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh index afb0dc05a..bcbc91a44 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh @@ -33,6 +33,16 @@ ln -s pretrained.pt epoch-99.pt ls -lh *.pt popd +log "Test exporting to ONNX format" +./pruned_transducer_stateless7_streaming/export.py \ + --exp-dir $repo/exp \ + --use-averaged-model false \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 99 \ + --avg 1 \ + --fp16 \ + --onnx 1 + log "Export to torchscript model" ./pruned_transducer_stateless7_streaming/export.py \ --exp-dir $repo/exp \ diff --git a/.github/workflows/run-librispeech-2022-12-29-stateless7-streaming.yml b/.github/workflows/run-librispeech-2022-12-29-stateless7-streaming.yml index 6dd93946a..a1f3b4f75 100644 --- a/.github/workflows/run-librispeech-2022-12-29-stateless7-streaming.yml +++ b/.github/workflows/run-librispeech-2022-12-29-stateless7-streaming.yml @@ -39,7 +39,7 @@ concurrency: jobs: run_librispeech_2022_12_29_zipformer_streaming: - if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event.label.name == 'streaming-zipformer' || github.event_name == 'push' || github.event_name == 'schedule' + if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event.label.name == 'streaming-zipformer' || github.event_name == 'push' || github.event_name == 'schedule' runs-on: ${{ matrix.os }} strategy: matrix: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py index 5c06cc052..1bc54fa26 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py @@ -72,25 +72,81 @@ Check ./pretrained.py for its usage. Note: If you don't want to train a model from scratch, we have provided one for you. You can get it at -https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11 +https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 with the following commands: sudo apt-get install git-lfs git lfs install - git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11 + git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 # You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/exp + +(3) Export to ONNX format with pretrained.pt + +cd ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp +ln -s pretrained.pt epoch-999.pt +./pruned_transducer_stateless7_streaming/export.py \ + --exp-dir ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --use-averaged-model False \ + --epoch 999 \ + --avg 1 \ + --fp16 \ + --onnx 1 + +It will generate the following files in the given `exp_dir`. +Check `onnx_check.py` for how to use them. + + - encoder.onnx + - decoder.onnx + - joiner.onnx + - joiner_encoder_proj.onnx + - joiner_decoder_proj.onnx + +Check +https://github.com/k2-fsa/sherpa-onnx +for how to use the exported models outside of icefall. + +(4) Export to ONNX format for triton server + +cd ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp +ln -s pretrained.pt epoch-999.pt +./pruned_transducer_stateless7_streaming/export.py \ + --exp-dir ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --use-averaged-model False \ + --epoch 999 \ + --avg 1 \ + --fp16 \ + --onnx-triton 1 \ + --onnx 1 + +It will generate the following files in the given `exp_dir`. +Check `onnx_check.py` for how to use them. + + - encoder.onnx + - decoder.onnx + - joiner.onnx + +Check +https://github.com/k2-fsa/sherpa/tree/master/triton +for how to use the exported models outside of icefall. + """ + import argparse import logging from pathlib import Path +import onnxruntime import sentencepiece as spm import torch import torch.nn as nn +from onnx_model_wrapper import OnnxStreamingEncoder, TritonOnnxDecoder, TritonOnnxJoiner from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model +from zipformer import stack_states from icefall.checkpoint import ( average_checkpoints, @@ -172,6 +228,42 @@ def get_parser(): """, ) + parser.add_argument( + "--onnx", + type=str2bool, + default=False, + help="""If True, --jit is ignored and it exports the model + to onnx format. It will generate the following files: + + - encoder.onnx + - decoder.onnx + - joiner.onnx + - joiner_encoder_proj.onnx + - joiner_decoder_proj.onnx + + Refer to ./onnx_check.py and ./onnx_pretrained.py for how to use them. + """, + ) + + parser.add_argument( + "--onnx-triton", + type=str2bool, + default=False, + help="""If True, --onnx would export model into the following files: + + - encoder.onnx + - decoder.onnx + - joiner.onnx + These files would be used for https://github.com/k2-fsa/sherpa/tree/master/triton. + """, + ) + + parser.add_argument( + "--fp16", + action="store_true", + help="whether to export fp16 onnx model, default false", + ) + parser.add_argument( "--context-size", type=int, @@ -184,6 +276,391 @@ def get_parser(): return parser +def test_acc(xlist, blist, rtol=1e-3, atol=1e-5, tolerate_small_mismatch=True): + for a, b in zip(xlist, blist): + try: + torch.testing.assert_allclose(a, b, rtol=rtol, atol=atol) + except AssertionError as error: + if tolerate_small_mismatch: + print("small mismatch detected", error) + else: + return False + return True + + +def export_encoder_model_onnx( + encoder_model: nn.Module, + encoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the given encoder model to ONNX format. + The exported model has two inputs: + + - x, a tensor of shape (N, T, C); dtype is torch.float32 + - x_lens, a tensor of shape (N,); dtype is torch.int64 + + and it has two outputs: + + - encoder_out, a tensor of shape (N, T, C) + - encoder_out_lens, a tensor of shape (N,) + + Note: The warmup argument is fixed to 1. + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + batch_size = 17 + seq_len = 101 + torch.manual_seed(0) + x = torch.rand(batch_size, seq_len, 80, dtype=torch.float32) + x_lens = torch.tensor([seq_len - i for i in range(batch_size)], dtype=torch.int64) + + # encoder_model = torch.jit.script(encoder_model) + # It throws the following error for the above statement + # + # RuntimeError: Exporting the operator __is_ to ONNX opset version + # 11 is not supported. Please feel free to request support or + # submit a pull request on PyTorch GitHub. + # + # I cannot find which statement causes the above error. + # torch.onnx.export() will use torch.jit.trace() internally, which + # works well for the current reworked model + initial_states = [encoder_model.get_init_state() for _ in range(batch_size)] + states = stack_states(initial_states) + + left_context_len = encoder_model.decode_chunk_size * encoder_model.num_left_chunks + encoder_attention_dim = encoder_model.encoders[0].attention_dim + + len_cache = torch.cat(states[: encoder_model.num_encoders]).transpose(0, 1) # B,15 + avg_cache = torch.cat( + states[encoder_model.num_encoders : 2 * encoder_model.num_encoders] + ).transpose( + 0, 1 + ) # [B,15,384] + cnn_cache = torch.cat(states[5 * encoder_model.num_encoders :]).transpose( + 0, 1 + ) # [B,2*15,384,cnn_kernel-1] + pad_tensors = [ + torch.nn.functional.pad( + tensor, + ( + 0, + encoder_attention_dim - tensor.shape[-1], + 0, + 0, + 0, + left_context_len - tensor.shape[1], + 0, + 0, + ), + ) + for tensor in states[ + 2 * encoder_model.num_encoders : 5 * encoder_model.num_encoders + ] + ] + attn_cache = torch.cat(pad_tensors).transpose(0, 2) # [B,64,15*3,192] + + encoder_model_wrapper = OnnxStreamingEncoder(encoder_model) + + torch.onnx.export( + encoder_model_wrapper, + (x, x_lens, len_cache, avg_cache, attn_cache, cnn_cache), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "x", + "x_lens", + "len_cache", + "avg_cache", + "attn_cache", + "cnn_cache", + ], + output_names=[ + "encoder_out", + "encoder_out_lens", + "new_len_cache", + "new_avg_cache", + "new_attn_cache", + "new_cnn_cache", + ], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "x_lens": {0: "N"}, + "encoder_out": {0: "N", 1: "T"}, + "encoder_out_lens": {0: "N"}, + "len_cache": {0: "N"}, + "avg_cache": {0: "N"}, + "attn_cache": {0: "N"}, + "cnn_cache": {0: "N"}, + "new_len_cache": {0: "N"}, + "new_avg_cache": {0: "N"}, + "new_attn_cache": {0: "N"}, + "new_cnn_cache": {0: "N"}, + }, + ) + logging.info(f"Saved to {encoder_filename}") + + # Test onnx encoder with torch native encoder + encoder_model.eval() + ( + encoder_out_torch, + encoder_out_lens_torch, + new_states_torch, + ) = encoder_model.streaming_forward( + x=x, + x_lens=x_lens, + states=states, + ) + ort_session = onnxruntime.InferenceSession( + str(encoder_filename), providers=["CPUExecutionProvider"] + ) + ort_inputs = { + "x": x.numpy(), + "x_lens": x_lens.numpy(), + "len_cache": len_cache.numpy(), + "avg_cache": avg_cache.numpy(), + "attn_cache": attn_cache.numpy(), + "cnn_cache": cnn_cache.numpy(), + } + ort_outs = ort_session.run(None, ort_inputs) + + assert test_acc( + [encoder_out_torch.numpy(), encoder_out_lens_torch.numpy()], ort_outs[:2] + ) + logging.info(f"{encoder_filename} acc test succeeded.") + + +def export_decoder_model_onnx( + decoder_model: nn.Module, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, decoder_model.context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, 1, C) + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) + need_pad = False # Always False, so we can use torch.jit.trace() here + # Note(fangjun): torch.jit.trace() is more efficient than torch.jit.script() + # in this case + torch.onnx.export( + decoder_model, + (y, need_pad), + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y", "need_pad"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + logging.info(f"Saved to {decoder_filename}") + + +def export_decoder_model_onnx_triton( + decoder_model: nn.Module, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, decoder_model.context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, 1, C) + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) + + decoder_model = TritonOnnxDecoder(decoder_model) + + torch.onnx.export( + decoder_model, + (y,), + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + logging.info(f"Saved to {decoder_filename}") + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. + The exported joiner model has two inputs: + + - projected_encoder_out: a tensor of shape (N, joiner_dim) + - projected_decoder_out: a tensor of shape (N, joiner_dim) + + and produces one output: + + - logit: a tensor of shape (N, vocab_size) + + The exported encoder_proj model has one input: + + - encoder_out: a tensor of shape (N, encoder_out_dim) + + and produces one output: + + - projected_encoder_out: a tensor of shape (N, joiner_dim) + + The exported decoder_proj model has one input: + + - decoder_out: a tensor of shape (N, decoder_out_dim) + + and produces one output: + + - projected_decoder_out: a tensor of shape (N, joiner_dim) + """ + encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx") + decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx") + + encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] + decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] + joiner_dim = joiner_model.decoder_proj.weight.shape[0] + + projected_encoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32) + + project_input = False + # Note: It uses torch.jit.trace() internally + torch.onnx.export( + joiner_model, + (projected_encoder_out, projected_decoder_out, project_input), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "encoder_out", + "decoder_out", + "project_input", + ], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + logging.info(f"Saved to {joiner_filename}") + + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + torch.onnx.export( + joiner_model.encoder_proj, + encoder_out, + encoder_proj_filename, + verbose=False, + opset_version=opset_version, + input_names=["encoder_out"], + output_names=["projected_encoder_out"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "projected_encoder_out": {0: "N"}, + }, + ) + logging.info(f"Saved to {encoder_proj_filename}") + + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + torch.onnx.export( + joiner_model.decoder_proj, + decoder_out, + decoder_proj_filename, + verbose=False, + opset_version=opset_version, + input_names=["decoder_out"], + output_names=["projected_decoder_out"], + dynamic_axes={ + "decoder_out": {0: "N"}, + "projected_decoder_out": {0: "N"}, + }, + ) + logging.info(f"Saved to {decoder_proj_filename}") + + +def export_joiner_model_onnx_triton( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. + The exported model has two inputs: + - encoder_out: a tensor of shape (N, encoder_out_dim) + - decoder_out: a tensor of shape (N, decoder_out_dim) + and has one output: + - joiner_out: a tensor of shape (N, vocab_size) + Note: The argument project_input is fixed to True. A user should not + project the encoder_out/decoder_out by himself/herself. The exported joiner + will do that for the user. + """ + encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] + decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + + joiner_model = TritonOnnxJoiner(joiner_model) + # Note: It uses torch.jit.trace() internally + torch.onnx.export( + joiner_model, + (encoder_out, decoder_out), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=["encoder_out", "decoder_out"], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + logging.info(f"Saved to {joiner_filename}") + + @torch.no_grad() def main(): args = get_parser().parse_args() @@ -292,7 +769,87 @@ def main(): model.to("cpu") model.eval() - if params.jit is True: + if params.onnx: + convert_scaled_to_non_scaled(model, inplace=True) + opset_version = 13 + logging.info("Exporting to onnx format") + encoder_filename = params.exp_dir / "encoder.onnx" + export_encoder_model_onnx( + model.encoder, + encoder_filename, + opset_version=opset_version, + ) + if not params.onnx_triton: + decoder_filename = params.exp_dir / "decoder.onnx" + export_decoder_model_onnx( + model.decoder, + decoder_filename, + opset_version=opset_version, + ) + + joiner_filename = params.exp_dir / "joiner.onnx" + export_joiner_model_onnx( + model.joiner, + joiner_filename, + opset_version=opset_version, + ) + else: + decoder_filename = params.exp_dir / "decoder.onnx" + export_decoder_model_onnx_triton( + model.decoder, + decoder_filename, + opset_version=opset_version, + ) + + joiner_filename = params.exp_dir / "joiner.onnx" + export_joiner_model_onnx_triton( + model.joiner, + joiner_filename, + opset_version=opset_version, + ) + + if params.fp16: + try: + import onnxmltools + from onnxmltools.utils.float16_converter import convert_float_to_float16 + except ImportError: + print("Please install onnxmltools!") + import sys + + sys.exit(1) + + def export_onnx_fp16(onnx_fp32_path, onnx_fp16_path): + onnx_fp32_model = onnxmltools.utils.load_model(onnx_fp32_path) + onnx_fp16_model = convert_float_to_float16(onnx_fp32_model) + onnxmltools.utils.save_model(onnx_fp16_model, onnx_fp16_path) + + encoder_fp16_filename = params.exp_dir / "encoder_fp16.onnx" + export_onnx_fp16(encoder_filename, encoder_fp16_filename) + + decoder_fp16_filename = params.exp_dir / "decoder_fp16.onnx" + export_onnx_fp16(decoder_filename, decoder_fp16_filename) + + joiner_fp16_filename = params.exp_dir / "joiner_fp16.onnx" + export_onnx_fp16(joiner_filename, joiner_fp16_filename) + + if not params.onnx_triton: + encoder_proj_filename = str(joiner_filename).replace( + ".onnx", "_encoder_proj.onnx" + ) + encoder_proj_fp16_filename = ( + params.exp_dir / "joiner_encoder_proj_fp16.onnx" + ) + export_onnx_fp16(encoder_proj_filename, encoder_proj_fp16_filename) + + decoder_proj_filename = str(joiner_filename).replace( + ".onnx", "_decoder_proj.onnx" + ) + decoder_proj_fp16_filename = ( + params.exp_dir / "joiner_decoder_proj_fp16.onnx" + ) + export_onnx_fp16(decoder_proj_filename, decoder_proj_fp16_filename) + + elif params.jit: convert_scaled_to_non_scaled(model, inplace=True) # We won't use the forward() method of the model in C++, so just ignore # it here. diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py new file mode 100644 index 000000000..f52deecc9 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py @@ -0,0 +1,231 @@ +from typing import Optional, Tuple + +import torch + + +class OnnxStreamingEncoder(torch.nn.Module): + """This class warps the streaming Zipformer to reduce the number of + state tensors for onnx. + https://github.com/k2-fsa/icefall/pull/831 + """ + + def __init__(self, encoder): + """ + Args: + encoder: A Instance of Zipformer Class + """ + super().__init__() + self.model = encoder + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + len_cache: torch.tensor, + avg_cache: torch.tensor, + attn_cache: torch.tensor, + cnn_cache: torch.tensor, + ) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + ]: + """ + Args: + x: + The input tensor. Its shape is (batch_size, seq_len, feature_dim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + len_cache: + The cached numbers of past frames. + avg_cache: + The cached average tensors. + attn_cache: + The cached key tensors of the first attention modules. + The cached value tensors of the first attention modules. + The cached value tensors of the second attention modules. + cnn_cache: + The cached left contexts of the first convolution modules. + The cached left contexts of the second convolution modules. + + Returns: + Return a tuple containing 2 tensors: + + """ + num_encoder_layers = [] + encoder_attention_dims = [] + states = [] + for i, encoder in enumerate(self.model.encoders): + num_encoder_layers.append(encoder.num_layers) + encoder_attention_dims.append(encoder.attention_dim) + + len_cache = len_cache.transpose(0, 1) # sum(num_encoder_layers)==15, [15, B] + offset = 0 + for num_layer in num_encoder_layers: + states.append(len_cache[offset : offset + num_layer]) + offset += num_layer + + avg_cache = avg_cache.transpose(0, 1) # [15, B, 384] + offset = 0 + for num_layer in num_encoder_layers: + states.append(avg_cache[offset : offset + num_layer]) + offset += num_layer + + attn_cache = attn_cache.transpose(0, 2) # [15*3, 64, B, 192] + left_context_len = attn_cache.shape[1] + offset = 0 + for i, num_layer in enumerate(num_encoder_layers): + ds = self.model.zipformer_downsampling_factors[i] + states.append( + attn_cache[offset : offset + num_layer, : left_context_len // ds] + ) + offset += num_layer + for i, num_layer in enumerate(num_encoder_layers): + encoder_attention_dim = encoder_attention_dims[i] + ds = self.model.zipformer_downsampling_factors[i] + states.append( + attn_cache[ + offset : offset + num_layer, + : left_context_len // ds, + :, + : encoder_attention_dim // 2, + ] + ) + offset += num_layer + for i, num_layer in enumerate(num_encoder_layers): + ds = self.model.zipformer_downsampling_factors[i] + states.append( + attn_cache[ + offset : offset + num_layer, + : left_context_len // ds, + :, + : encoder_attention_dim // 2, + ] + ) + offset += num_layer + + cnn_cache = cnn_cache.transpose(0, 1) # [30, B, 384, cnn_kernel-1] + offset = 0 + for num_layer in num_encoder_layers: + states.append(cnn_cache[offset : offset + num_layer]) + offset += num_layer + for num_layer in num_encoder_layers: + states.append(cnn_cache[offset : offset + num_layer]) + offset += num_layer + + encoder_out, encoder_out_lens, new_states = self.model.streaming_forward( + x=x, + x_lens=x_lens, + states=states, + ) + + new_len_cache = torch.cat(states[: self.model.num_encoders]).transpose( + 0, 1 + ) # [B,15] + new_avg_cache = torch.cat( + states[self.model.num_encoders : 2 * self.model.num_encoders] + ).transpose( + 0, 1 + ) # [B,15,384] + new_cnn_cache = torch.cat(states[5 * self.model.num_encoders :]).transpose( + 0, 1 + ) # [B,2*15,384,cnn_kernel-1] + assert len(set(encoder_attention_dims)) == 1 + pad_tensors = [ + torch.nn.functional.pad( + tensor, + ( + 0, + encoder_attention_dims[0] - tensor.shape[-1], + 0, + 0, + 0, + left_context_len - tensor.shape[1], + 0, + 0, + ), + ) + for tensor in states[ + 2 * self.model.num_encoders : 5 * self.model.num_encoders + ] + ] + new_attn_cache = torch.cat(pad_tensors).transpose(0, 2) # [B,64,15*3,192] + + return ( + encoder_out, + encoder_out_lens, + new_len_cache, + new_avg_cache, + new_attn_cache, + new_cnn_cache, + ) + + +class TritonOnnxDecoder(torch.nn.Module): + """This class warps the Decoder in decoder.py + to remove the scalar input "need_pad". + Triton currently doesn't support scalar input. + https://github.com/triton-inference-server/server/issues/2333 + """ + + def __init__( + self, + decoder: torch.nn.Module, + ): + """ + Args: + decoder: A instance of Decoder + """ + super().__init__() + self.model = decoder + + def forward(self, y: torch.Tensor) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, U). + Returns: + Return a tensor of shape (N, U, decoder_dim). + """ + # False to not pad the input. Should be False during inference. + need_pad = False + return self.model(y, need_pad) + + +class TritonOnnxJoiner(torch.nn.Module): + """This class warps the Joiner in joiner.py + to remove the scalar input "project_input". + Triton currently doesn't support scalar input. + https://github.com/triton-inference-server/server/issues/2333 + "project_input" is set to True. + Triton solutions only need export joiner to a single joiner.onnx. + """ + + def __init__( + self, + joiner: torch.nn.Module, + ): + super().__init__() + self.model = joiner + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + encoder_out: + Output from the encoder. Its shape is (N, T, s_range, C). + decoder_out: + Output from the decoder. Its shape is (N, T, s_range, C). + Returns: + Return a tensor of shape (N, T, s_range, C). + """ + # Apply input projections encoder_proj and decoder_proj. + project_input = True + return self.model(encoder_out, decoder_out, project_input) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py index e13629384..1b267c1c5 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py @@ -2084,16 +2084,26 @@ class RelPositionMultiheadAttention(nn.Module): # the following .as_strided() expression converts the last axis of pos_weights from relative # to absolute position. I don't know whether I might have got the time-offsets backwards or # not, but let this code define which way round it is supposed to be. - pos_weights = pos_weights.as_strided( - (bsz, num_heads, seq_len, seq_len), - ( - pos_weights.stride(0), - pos_weights.stride(1), - pos_weights.stride(2) - pos_weights.stride(3), - pos_weights.stride(3), - ), - storage_offset=pos_weights.stride(3) * (seq_len - 1), - ) + if torch.jit.is_tracing(): + (batch_size, num_heads, time1, n) = pos_weights.shape + rows = torch.arange(start=time1 - 1, end=-1, step=-1) + cols = torch.arange(seq_len) + rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) + indexes = rows + cols + pos_weights = pos_weights.reshape(-1, n) + pos_weights = torch.gather(pos_weights, dim=1, index=indexes) + pos_weights = pos_weights.reshape(batch_size, num_heads, time1, seq_len) + else: + pos_weights = pos_weights.as_strided( + (bsz, num_heads, seq_len, seq_len), + ( + pos_weights.stride(0), + pos_weights.stride(1), + pos_weights.stride(2) - pos_weights.stride(3), + pos_weights.stride(3), + ), + storage_offset=pos_weights.stride(3) * (seq_len - 1), + ) # caution: they are really scores at this point. attn_output_weights = torch.matmul(q, k) + pos_weights @@ -2275,16 +2285,26 @@ class RelPositionMultiheadAttention(nn.Module): # the following .as_strided() expression converts the last axis of pos_weights from relative # to absolute position. I don't know whether I might have got the time-offsets backwards or # not, but let this code define which way round it is supposed to be. - pos_weights = pos_weights.as_strided( - (bsz, num_heads, seq_len, kv_len), - ( - pos_weights.stride(0), - pos_weights.stride(1), - pos_weights.stride(2) - pos_weights.stride(3), - pos_weights.stride(3), - ), - storage_offset=pos_weights.stride(3) * (seq_len - 1), - ) + if torch.jit.is_tracing(): + (batch_size, num_heads, time1, n) = pos_weights.shape + rows = torch.arange(start=time1 - 1, end=-1, step=-1) + cols = torch.arange(kv_len) + rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) + indexes = rows + cols + pos_weights = pos_weights.reshape(-1, n) + pos_weights = torch.gather(pos_weights, dim=1, index=indexes) + pos_weights = pos_weights.reshape(batch_size, num_heads, time1, kv_len) + else: + pos_weights = pos_weights.as_strided( + (bsz, num_heads, seq_len, kv_len), + ( + pos_weights.stride(0), + pos_weights.stride(1), + pos_weights.stride(2) - pos_weights.stride(3), + pos_weights.stride(3), + ), + storage_offset=pos_weights.stride(3) * (seq_len - 1), + ) # caution: they are really scores at this point. attn_output_weights = torch.matmul(q, k) + pos_weights diff --git a/requirements-ci.txt b/requirements-ci.txt index b8e49899e..50d4e5e3f 100644 --- a/requirements-ci.txt +++ b/requirements-ci.txt @@ -22,5 +22,6 @@ typeguard==2.13.3 multi_quantization onnx +onnxmltools onnxruntime kaldifst From caf23546edea120f402b03916d3a5647f54a28d8 Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Mon, 6 Feb 2023 12:17:45 +0800 Subject: [PATCH 096/174] No more T < S after frame_reducer (#875) * No more T < S after frame_reducer * Fix for style check * Adjust the permissions * Add support for inference to frame_reducer * Fix for flake8 check --------- Co-authored-by: yifanyang --- .../__init__.py | 0 .../export_onnx.py | 0 .../frame_reducer.py | 74 +++++++++++++++---- .../lconv.py | 0 .../model.py | 10 ++- .../onnx_pretrained.py | 0 .../train.py | 3 +- 7 files changed, 65 insertions(+), 22 deletions(-) mode change 100755 => 100644 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/__init__.py mode change 100644 => 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export_onnx.py mode change 100755 => 100644 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/lconv.py mode change 100755 => 100644 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py mode change 100644 => 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/__init__.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/__init__.py old mode 100755 new mode 100644 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export_onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export_onnx.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py index bc3fc57eb..0841f7cf1 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py @@ -1,7 +1,8 @@ #!/usr/bin/env python3 # -# Copyright 2022 Xiaomi Corp. (authors: Yifan Yang, -# Zengwei Yao) +# Copyright 2022 Xiaomi Corp. (authors: Yifan Yang, +# Zengwei Yao, +# Wei Kang) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -18,7 +19,7 @@ # limitations under the License. import math -from typing import List, Optional, Tuple, Union +from typing import Optional, Tuple import torch import torch.nn as nn @@ -44,6 +45,7 @@ class FrameReducer(nn.Module): x: torch.Tensor, x_lens: torch.Tensor, ctc_output: torch.Tensor, + y_lens: Optional[torch.Tensor] = None, blank_id: int = 0, ) -> Tuple[torch.Tensor, torch.Tensor]: """ @@ -55,6 +57,9 @@ class FrameReducer(nn.Module): `x` before padding. ctc_output: The CTC output with shape [N, T, vocab_size]. + y_lens: + A tensor of shape (batch_size,) containing the number of frames in + `y` before padding. blank_id: The blank id of ctc_output. Returns: @@ -64,15 +69,45 @@ class FrameReducer(nn.Module): A tensor of shape (batch_size,) containing the number of frames in `out` before padding. """ - N, T, C = x.size() padding_mask = make_pad_mask(x_lens) non_blank_mask = (ctc_output[:, :, blank_id] < math.log(0.9)) * (~padding_mask) + if y_lens is not None: + # Limit the maximum number of reduced frames + limit_lens = T - y_lens + max_limit_len = limit_lens.max().int() + fake_limit_indexes = torch.topk( + ctc_output[:, :, blank_id], max_limit_len + ).indices + T = ( + torch.arange(max_limit_len) + .expand_as( + fake_limit_indexes, + ) + .to(device=x.device) + ) + T = torch.remainder(T, limit_lens.unsqueeze(1)) + limit_indexes = torch.gather(fake_limit_indexes, 1, T) + limit_mask = torch.full_like( + non_blank_mask, + False, + device=x.device, + ).scatter_(1, limit_indexes, True) + + non_blank_mask = non_blank_mask | ~limit_mask + out_lens = non_blank_mask.sum(dim=1) max_len = out_lens.max() - pad_lens_list = torch.full_like(out_lens, max_len.item()) - out_lens + pad_lens_list = ( + torch.full_like( + out_lens, + max_len.item(), + device=x.device, + ) + - out_lens + ) max_pad_len = pad_lens_list.max() out = F.pad(x, (0, 0, 0, max_pad_len)) @@ -82,26 +117,30 @@ class FrameReducer(nn.Module): out = out[total_valid_mask].reshape(N, -1, C) - return out.to(device=x.device), out_lens.to(device=x.device) + return out, out_lens if __name__ == "__main__": import time - from torch.nn.utils.rnn import pad_sequence test_times = 10000 + device = "cuda:0" frame_reducer = FrameReducer() # non zero case - x = torch.ones(15, 498, 384, dtype=torch.float32) - x_lens = torch.tensor([498] * 15, dtype=torch.int64) - ctc_output = torch.log(torch.randn(15, 498, 500, dtype=torch.float32)) - x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output) + x = torch.ones(15, 498, 384, dtype=torch.float32, device=device) + x_lens = torch.tensor([498] * 15, dtype=torch.int64, device=device) + y_lens = torch.tensor([150] * 15, dtype=torch.int64, device=device) + ctc_output = torch.log( + torch.randn(15, 498, 500, dtype=torch.float32, device=device), + ) avg_time = 0 for i in range(test_times): + torch.cuda.synchronize(device=x.device) delta_time = time.time() - x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output) + x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output, y_lens) + torch.cuda.synchronize(device=x.device) delta_time = time.time() - delta_time avg_time += delta_time print(x_fr.shape) @@ -109,14 +148,17 @@ if __name__ == "__main__": print(avg_time / test_times) # all zero case - x = torch.zeros(15, 498, 384, dtype=torch.float32) - x_lens = torch.tensor([498] * 15, dtype=torch.int64) - ctc_output = torch.zeros(15, 498, 500, dtype=torch.float32) + x = torch.zeros(15, 498, 384, dtype=torch.float32, device=device) + x_lens = torch.tensor([498] * 15, dtype=torch.int64, device=device) + y_lens = torch.tensor([150] * 15, dtype=torch.int64, device=device) + ctc_output = torch.zeros(15, 498, 500, dtype=torch.float32, device=device) avg_time = 0 for i in range(test_times): + torch.cuda.synchronize(device=x.device) delta_time = time.time() - x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output) + x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output, y_lens) + torch.cuda.synchronize(device=x.device) delta_time = time.time() - delta_time avg_time += delta_time print(x_fr.shape) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/lconv.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/lconv.py old mode 100755 new mode 100644 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py old mode 100755 new mode 100644 index 86acc5a10..0582b289f --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py @@ -131,6 +131,10 @@ class Transducer(nn.Module): # compute ctc log-probs ctc_output = self.ctc_output(encoder_out) + # y_lens + row_splits = y.shape.row_splits(1) + y_lens = row_splits[1:] - row_splits[:-1] + # blank skip blank_id = self.decoder.blank_id @@ -146,16 +150,14 @@ class Transducer(nn.Module): encoder_out, x_lens, ctc_output, + y_lens, blank_id, ) else: encoder_out_fr = encoder_out x_lens_fr = x_lens - # Now for the decoder, i.e., the prediction network - row_splits = y.shape.row_splits(1) - y_lens = row_splits[1:] - row_splits[:-1] - + # sos_y sos_y = add_sos(y, sos_id=blank_id) # sos_y_padded: [B, S + 1], start with SOS. diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py index b282ab9db..ea280e642 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python3 # Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, # Wei Kang, # Mingshuang Luo, @@ -35,7 +34,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" --use-fp16 1 \ --exp-dir pruned_transducer_stateless7_ctc_bs/exp \ --full-libri 1 \ - --max-duration 550 + --max-duration 750 """ From 5a05b957300ee21a4d2370039ac612e7265e0834 Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Mon, 6 Feb 2023 23:21:46 +0800 Subject: [PATCH 097/174] add params.hlg_scale (#880) --- egs/librispeech/ASR/conformer_ctc3/decode.py | 199 ++++++------------- 1 file changed, 61 insertions(+), 138 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc3/decode.py b/egs/librispeech/ASR/conformer_ctc3/decode.py index 8eca2ae02..39186e546 100755 --- a/egs/librispeech/ASR/conformer_ctc3/decode.py +++ b/egs/librispeech/ASR/conformer_ctc3/decode.py @@ -58,7 +58,6 @@ For example: --left-context 64 \ --manifest-dir data/fbank_ali Note: It supports calculating symbol delay with following decoding methods: - - ctc-greedy-search - ctc-decoding - 1best """ @@ -96,10 +95,8 @@ from icefall.decode import ( from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, - DecodingResults, get_texts, get_texts_with_timestamp, - make_pad_mask, parse_hyp_and_timestamp, setup_logger, store_transcripts_and_timestamps, @@ -177,20 +174,18 @@ def get_parser(): - (0) ctc-decoding. Use CTC decoding. It uses a sentence piece model, i.e., lang_dir/bpe.model, to convert word pieces to words. It needs neither a lexicon nor an n-gram LM. - - (1) ctc-greedy-search. It only use CTC output and a sentence piece - model for decoding. It produces the same results with ctc-decoding. - - (2) 1best. Extract the best path from the decoding lattice as the + - (1) 1best. Extract the best path from the decoding lattice as the decoding result. - - (3) nbest. Extract n paths from the decoding lattice; the path + - (2) nbest. Extract n paths from the decoding lattice; the path with the highest score is the decoding result. - - (4) nbest-rescoring. Extract n paths from the decoding lattice, + - (3) nbest-rescoring. Extract n paths from the decoding lattice, rescore them with an n-gram LM (e.g., a 4-gram LM), the path with the highest score is the decoding result. - - (5) whole-lattice-rescoring. Rescore the decoding lattice with an + - (4) whole-lattice-rescoring. Rescore the decoding lattice with an n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice is the decoding result. you have trained an RNN LM using ./rnn_lm/train.py - - (6) nbest-oracle. Its WER is the lower bound of any n-best + - (5) nbest-oracle. Its WER is the lower bound of any n-best rescoring method can achieve. Useful for debugging n-best rescoring method. """, @@ -250,6 +245,14 @@ def get_parser(): help="left context can be seen during decoding (in frames after subsampling)", ) + parser.add_argument( + "--hlg-scale", + type=float, + default=0.8, + help="""The scale to be applied to `hlg.scores`. + """, + ) + add_model_arguments(parser) return parser @@ -270,47 +273,6 @@ def get_decoding_params() -> AttributeDict: return params -def ctc_greedy_search( - ctc_probs: torch.Tensor, - nnet_output_lens: torch.Tensor, -) -> List[List[int]]: - """Apply CTC greedy search - Args: - ctc_probs (torch.Tensor): (batch, max_len, feat_dim) - nnet_output_lens (torch.Tensor): (batch, ) - Returns: - List[List[int]]: best path result - """ - topk_prob, topk_index = ctc_probs.topk(1, dim=2) # (B, maxlen, 1) - topk_index = topk_index.squeeze(2) # (B, maxlen) - mask = make_pad_mask(nnet_output_lens) - topk_index = topk_index.masked_fill_(mask, 0) # (B, maxlen) - hyps = [hyp.tolist() for hyp in topk_index] - scores = topk_prob.max(1) - ret_hyps = [] - timestamps = [] - for i in range(len(hyps)): - hyp, time = remove_duplicates_and_blank(hyps[i]) - ret_hyps.append(hyp) - timestamps.append(time) - return ret_hyps, timestamps, scores - - -def remove_duplicates_and_blank(hyp: List[int]) -> Tuple[List[int], List[int]]: - # modified from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/common.py - new_hyp: List[int] = [] - time: List[int] = [] - cur = 0 - while cur < len(hyp): - if hyp[cur] != 0: - new_hyp.append(hyp[cur]) - time.append(cur) - prev = cur - while cur < len(hyp) and hyp[cur] == hyp[prev]: - cur += 1 - return new_hyp, time - - def decode_one_batch( params: AttributeDict, model: nn.Module, @@ -402,26 +364,11 @@ def decode_one_batch( nnet_output = model.get_ctc_output(encoder_out) # nnet_output is (N, T, C) - if params.decoding_method == "ctc-greedy-search": - hyps, timestamps, _ = ctc_greedy_search( - nnet_output, - encoder_out_lens, - ) - res = DecodingResults(hyps=hyps, timestamps=timestamps) - hyps, timestamps = parse_hyp_and_timestamp( - res=res, - sp=bpe_model, - subsampling_factor=params.subsampling_factor, - frame_shift_ms=params.frame_shift_ms, - ) - key = "ctc-greedy-search" - return {key: (hyps, timestamps)} - supervision_segments = torch.stack( ( supervisions["sequence_idx"], supervisions["start_frame"] // params.subsampling_factor, - supervisions["num_frames"] // params.subsampling_factor, + encoder_out_lens.cpu(), ), 1, ).to(torch.int32) @@ -434,75 +381,6 @@ def decode_one_batch( assert bpe_model is not None decoding_graph = H - if params.decoding_method in ["1best", "nbest", "nbest-oracle"]: - hlg_scale_list = [0.2, 0.4, 0.6, 0.8, 1.0] - - ori_scores = decoding_graph.scores.clone() - - ans = {} - for hlg_scale in hlg_scale_list: - decoding_graph.scores = ori_scores * hlg_scale - lattice = get_lattice( - nnet_output=nnet_output, - decoding_graph=decoding_graph, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - subsampling_factor=params.subsampling_factor, - ) - key_suffix = f"-HLG-scale-{hlg_scale}" - - if params.decoding_method == "nbest-oracle": - # Note: You can also pass rescored lattices to it. - # We choose the HLG decoded lattice for speed reasons - # as HLG decoding is faster and the oracle WER - # is only slightly worse than that of rescored lattices. - best_path = nbest_oracle( - lattice=lattice, - num_paths=params.num_paths, - ref_texts=supervisions["text"], - word_table=word_table, - nbest_scale=params.nbest_scale, - oov="", - ) - hyps = get_texts(best_path) - hyps = [[word_table[i] for i in ids] for ids in hyps] - key = f"oracle-{params.num_paths}-nbest-scale-{params.nbest_scale}" # noqa - timestamps = [[] for _ in range(len(hyps))] - ans[key + key_suffix] = (hyps, timestamps) - - elif params.decoding_method in ["1best", "nbest"]: - if params.decoding_method == "1best": - best_path = one_best_decoding( - lattice=lattice, - use_double_scores=params.use_double_scores, - ) - key = "no-rescore" - res = get_texts_with_timestamp(best_path) - hyps, timestamps = parse_hyp_and_timestamp( - res=res, - subsampling_factor=params.subsampling_factor, - frame_shift_ms=params.frame_shift_ms, - word_table=word_table, - ) - else: - best_path = nbest_decoding( - lattice=lattice, - num_paths=params.num_paths, - use_double_scores=params.use_double_scores, - nbest_scale=params.nbest_scale, - ) - key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa - hyps = get_texts(best_path) - hyps = [[word_table[i] for i in ids] for ids in hyps] - timestamps = [[] for _ in range(len(hyps))] - - ans[key + key_suffix] = (hyps, timestamps) - - return ans - lattice = get_lattice( nnet_output=nnet_output, decoding_graph=decoding_graph, @@ -532,6 +410,51 @@ def decode_one_batch( key = "ctc-decoding" return {key: (hyps, timestamps)} + if params.decoding_method == "nbest-oracle": + # Note: You can also pass rescored lattices to it. + # We choose the HLG decoded lattice for speed reasons + # as HLG decoding is faster and the oracle WER + # is only slightly worse than that of rescored lattices. + best_path = nbest_oracle( + lattice=lattice, + num_paths=params.num_paths, + ref_texts=supervisions["text"], + word_table=word_table, + nbest_scale=params.nbest_scale, + oov="", + ) + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + timestamps = [[] for _ in range(len(hyps))] + key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}_hlg_scale_{params.hlg_scale}" # noqa + return {key: (hyps, timestamps)} + + if params.decoding_method in ["1best", "nbest"]: + if params.decoding_method == "1best": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + key = f"no_rescore_hlg_scale_{params.hlg_scale}" + res = get_texts_with_timestamp(best_path) + hyps, timestamps = parse_hyp_and_timestamp( + res=res, + subsampling_factor=params.subsampling_factor, + frame_shift_ms=params.frame_shift_ms, + word_table=word_table, + ) + else: + best_path = nbest_decoding( + lattice=lattice, + num_paths=params.num_paths, + use_double_scores=params.use_double_scores, + nbest_scale=params.nbest_scale, + ) + key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}-hlg-scale-{params.hlg_scale}" # noqa + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + timestamps = [[] for _ in range(len(hyps))] + return {key: (hyps, timestamps)} + assert params.decoding_method in [ "nbest-rescoring", "whole-lattice-rescoring", @@ -757,7 +680,6 @@ def main(): params.update(vars(args)) assert params.decoding_method in ( - "ctc-greedy-search", "ctc-decoding", "1best", "nbest", @@ -811,7 +733,7 @@ def main(): params.sos_id = sos_id params.eos_id = eos_id - if params.decoding_method in ["ctc-decoding", "ctc-greedy-search"]: + if params.decoding_method == "ctc-decoding": HLG = None H = k2.ctc_topo( max_token=max_token_id, @@ -828,6 +750,7 @@ def main(): ) assert HLG.requires_grad is False + HLG.scores *= params.hlg_scale if not hasattr(HLG, "lm_scores"): HLG.lm_scores = HLG.scores.clone() From 52f3a747bec5d06ad0eba6a77f5abaf21110e5ec Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 7 Feb 2023 12:12:26 +0800 Subject: [PATCH 098/174] Refactor onnx export for streaming zipformer (#879) --- ...nsducer-stateless7-streaming-2022-12-29.sh | 3 +- .github/scripts/test-onnx-export.sh | 70 ++ .github/workflows/test-onnx-export.yml | 75 ++ .../export-onnx.py | 639 ++++++++++++++++++ .../onnx_check.py | 267 ++++++++ .../onnx_model_wrapper.py | 2 +- .../onnx_pretrained.py | 510 ++++++++++++++ .../zipformer.py | 11 +- 8 files changed, 1572 insertions(+), 5 deletions(-) create mode 100755 .github/scripts/test-onnx-export.sh create mode 100644 .github/workflows/test-onnx-export.yml create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_check.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh index bcbc91a44..2611c8efc 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh @@ -22,13 +22,14 @@ tree $repo/ soxi $repo/test_wavs/*.wav ls -lh $repo/test_wavs/*.wav -pushd $repo/exp +pushd $repo git lfs pull --include "data/lang_bpe_500/bpe.model" git lfs pull --include "exp/cpu_jit.pt" git lfs pull --include "exp/pretrained.pt" git lfs pull --include "exp/encoder_jit_trace.pt" git lfs pull --include "exp/decoder_jit_trace.pt" git lfs pull --include "exp/joiner_jit_trace.pt" +cd exp ln -s pretrained.pt epoch-99.pt ls -lh *.pt popd diff --git a/.github/scripts/test-onnx-export.sh b/.github/scripts/test-onnx-export.sh new file mode 100755 index 000000000..20aa02950 --- /dev/null +++ b/.github/scripts/test-onnx-export.sh @@ -0,0 +1,70 @@ +#!/usr/bin/env bash + +set -e + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 + +log "==========================================================================" +log "Downloading pre-trained model from $repo_url" +git lfs install +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained.pt" +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +log "Export via torch.jit.trace()" + +./pruned_transducer_stateless7_streaming/jit_trace_export.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --decode-chunk-len 32 \ + --exp-dir $repo/exp/ + +log "Test exporting to ONNX format" + +./pruned_transducer_stateless7_streaming/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --decode-chunk-len 32 \ + --exp-dir $repo/exp/ + +ls -lh $repo/exp + +log "Run onnx_check.py" + +./pruned_transducer_stateless7_streaming/onnx_check.py \ + --jit-encoder-filename $repo/exp/encoder_jit_trace.pt \ + --jit-decoder-filename $repo/exp/decoder_jit_trace.pt \ + --jit-joiner-filename $repo/exp/joiner_jit_trace.pt \ + --onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx + +log "Run onnx_pretrained.py" + +./pruned_transducer_stateless7_streaming/onnx_pretrained.py \ + --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav + +rm -rf $repo +log "--------------------------------------------------------------------------" diff --git a/.github/workflows/test-onnx-export.yml b/.github/workflows/test-onnx-export.yml new file mode 100644 index 000000000..c7729dedb --- /dev/null +++ b/.github/workflows/test-onnx-export.yml @@ -0,0 +1,75 @@ +name: test-onnx-export + +on: + push: + branches: + - master + pull_request: + types: [labeled] + + schedule: + # minute (0-59) + # hour (0-23) + # day of the month (1-31) + # month (1-12) + # day of the week (0-6) + # nightly build at 15:50 UTC time every day + - cron: "50 15 * * *" + +concurrency: + group: test_onnx_export-${{ github.ref }} + cancel-in-progress: true + +jobs: + test_onnx_export: + if: github.event.label.name == 'ready' || github.event.label.name == 'onnx' || github.event_name == 'push' || github.event_name == 'schedule' + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python-version: [3.8] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + cache-dependency-path: '**/requirements-ci.txt' + + - name: Install Python dependencies + run: | + grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install + pip uninstall -y protobuf + pip install --no-binary protobuf protobuf + + - name: Cache kaldifeat + id: my-cache + uses: actions/cache@v2 + with: + path: | + ~/tmp/kaldifeat + key: cache-tmp-${{ matrix.python-version }}-2022-09-25 + + - name: Install kaldifeat + if: steps.my-cache.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/install-kaldifeat.sh + + - name: Test ONNX export + shell: bash + env: + GITHUB_EVENT_NAME: ${{ github.event_name }} + GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} + run: | + export PYTHONPATH=$PWD:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH + + .github/scripts/test-onnx-export.sh diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py new file mode 100755 index 000000000..a72472495 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py @@ -0,0 +1,639 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang) + +""" +This script exports a transducer model from PyTorch to ONNX. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained.pt" +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +2. Export the model to ONNX + +./pruned_transducer_stateless7_streaming/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --decode-chunk-len 32 \ + --exp-dir $repo/exp/ + +It will generate the following 3 files in $repo/exp + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +See ./onnx_pretrained.py for how to use the exported models. +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, List, Tuple + +import onnx +import sentencepiece as spm +import torch +import torch.nn as nn +from decoder import Decoder +from scaling_converter import convert_scaled_to_non_scaled +from torch import Tensor +from train import add_model_arguments, get_params, get_transducer_model +from zipformer import Zipformer + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +class OnnxEncoder(nn.Module): + """A wrapper for Zipformer and the encoder_proj from the joiner""" + + def __init__(self, encoder: Zipformer, encoder_proj: nn.Linear): + """ + Args: + encoder: + A zipformer encoder. + encoder_proj: + The projection layer for encoder from the joiner. + """ + super().__init__() + self.encoder = encoder + self.encoder_proj = encoder_proj + + def forward(self, x: Tensor, states: List[Tensor]) -> Tuple[Tensor, List[Tensor]]: + """Please see the help information of Zipformer.streaming_forward""" + N = x.size(0) + T = x.size(1) + x_lens = torch.tensor([T] * N, device=x.device) + + output, _, new_states = self.encoder.streaming_forward( + x=x, + x_lens=x_lens, + states=states, + ) + + output = self.encoder_proj(output) + # Now output is of shape (N, T, joiner_dim) + + return output, new_states + + +class OnnxDecoder(nn.Module): + """A wrapper for Decoder and the decoder_proj from the joiner""" + + def __init__(self, decoder: Decoder, decoder_proj: nn.Linear): + super().__init__() + self.decoder = decoder + self.decoder_proj = decoder_proj + + def forward(self, y: torch.Tensor) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, context_size). + Returns + Return a 2-D tensor of shape (N, joiner_dim) + """ + need_pad = False + decoder_output = self.decoder(y, need_pad=need_pad) + decoder_output = decoder_output.squeeze(1) + output = self.decoder_proj(decoder_output) + + return output + + +class OnnxJoiner(nn.Module): + """A wrapper for the joiner""" + + def __init__(self, output_linear: nn.Linear): + super().__init__() + self.output_linear = output_linear + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + logit = encoder_out + decoder_out + logit = self.output_linear(torch.tanh(logit)) + return logit + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +def export_encoder_model_onnx( + encoder_model: OnnxEncoder, + encoder_filename: str, + opset_version: int = 11, +) -> None: + """ + Onnx model inputs: + - 0: src + - many state tensors (the exact number depending on the actual model) + + Onnx model outputs: + - 0: output, its shape is (N, T, joiner_dim) + - many state tensors (the exact number depending on the actual model) + + Args: + encoder_model: + The model to be exported + encoder_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + + encoder_model.encoder.__class__.forward = ( + encoder_model.encoder.__class__.streaming_forward + ) + + decode_chunk_len = encoder_model.encoder.decode_chunk_size * 2 + pad_length = 7 + T = decode_chunk_len + pad_length + logging.info(f"decode_chunk_len: {decode_chunk_len}") + logging.info(f"pad_length: {pad_length}") + logging.info(f"T: {T}") + + x = torch.rand(1, T, 80, dtype=torch.float32) + + init_state = encoder_model.encoder.get_init_state() + + num_encoders = encoder_model.encoder.num_encoders + logging.info(f"num_encoders: {num_encoders}") + logging.info(f"len(init_state): {len(init_state)}") + + inputs = {} + input_names = ["x"] + + outputs = {} + output_names = ["encoder_out"] + + def build_inputs_outputs(tensors, name, N): + for i, s in enumerate(tensors): + logging.info(f"{name}_{i}.shape: {s.shape}") + inputs[f"{name}_{i}"] = {N: "N"} + outputs[f"new_{name}_{i}"] = {N: "N"} + input_names.append(f"{name}_{i}") + output_names.append(f"new_{name}_{i}") + + num_encoder_layers = ",".join(map(str, encoder_model.encoder.num_encoder_layers)) + encoder_dims = ",".join(map(str, encoder_model.encoder.encoder_dims)) + attention_dims = ",".join(map(str, encoder_model.encoder.attention_dims)) + cnn_module_kernels = ",".join(map(str, encoder_model.encoder.cnn_module_kernels)) + ds = encoder_model.encoder.zipformer_downsampling_factors + left_context_len = encoder_model.encoder.left_context_len + left_context_len = [left_context_len // k for k in ds] + left_context_len = ",".join(map(str, left_context_len)) + + meta_data = { + "model_type": "streaming_zipformer", + "version": "1", + "model_author": "k2-fsa", + "decode_chunk_len": str(decode_chunk_len), # 32 + "pad_length": str(pad_length), # 7 + "num_encoder_layers": num_encoder_layers, + "encoder_dims": encoder_dims, + "attention_dims": attention_dims, + "cnn_module_kernels": cnn_module_kernels, + "left_context_len": left_context_len, + } + logging.info(f"meta_data: {meta_data}") + + # (num_encoder_layers, 1) + cached_len = init_state[num_encoders * 0 : num_encoders * 1] + + # (num_encoder_layers, 1, encoder_dim) + cached_avg = init_state[num_encoders * 1 : num_encoders * 2] + + # (num_encoder_layers, left_context_len, 1, attention_dim) + cached_key = init_state[num_encoders * 2 : num_encoders * 3] + + # (num_encoder_layers, left_context_len, 1, attention_dim//2) + cached_val = init_state[num_encoders * 3 : num_encoders * 4] + + # (num_encoder_layers, left_context_len, 1, attention_dim//2) + cached_val2 = init_state[num_encoders * 4 : num_encoders * 5] + + # (num_encoder_layers, 1, encoder_dim, cnn_module_kernel-1) + cached_conv1 = init_state[num_encoders * 5 : num_encoders * 6] + + # (num_encoder_layers, 1, encoder_dim, cnn_module_kernel-1) + cached_conv2 = init_state[num_encoders * 6 : num_encoders * 7] + + build_inputs_outputs(cached_len, "cached_len", 1) + build_inputs_outputs(cached_avg, "cached_avg", 1) + build_inputs_outputs(cached_key, "cached_key", 2) + build_inputs_outputs(cached_val, "cached_val", 2) + build_inputs_outputs(cached_val2, "cached_val2", 2) + build_inputs_outputs(cached_conv1, "cached_conv1", 1) + build_inputs_outputs(cached_conv2, "cached_conv2", 1) + + logging.info(inputs) + logging.info(outputs) + logging.info(input_names) + logging.info(output_names) + + torch.onnx.export( + encoder_model, + (x, init_state), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=input_names, + output_names=output_names, + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "encoder_out": {0: "N", 1: "T"}, + **inputs, + **outputs, + }, + ) + + add_meta_data(filename=encoder_filename, meta_data=meta_data) + + +def export_decoder_model_onnx( + decoder_model: nn.Module, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, joiner_dim) + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + context_size = decoder_model.decoder.context_size + vocab_size = decoder_model.decoder.vocab_size + y = torch.zeros(10, context_size, dtype=torch.int64) + torch.onnx.export( + decoder_model, + y, + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + meta_data = { + "context_size": str(context_size), + "vocab_size": str(vocab_size), + } + add_meta_data(filename=decoder_filename, meta_data=meta_data) + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. + The exported joiner model has two inputs: + + - encoder_out: a tensor of shape (N, joiner_dim) + - decoder_out: a tensor of shape (N, joiner_dim) + + and produces one output: + + - logit: a tensor of shape (N, vocab_size) + """ + joiner_dim = joiner_model.output_linear.weight.shape[1] + logging.info(f"joiner dim: {joiner_dim}") + + projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + + torch.onnx.export( + joiner_model, + (projected_encoder_out, projected_decoder_out), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "encoder_out", + "decoder_out", + ], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + meta_data = { + "joiner_dim": str(joiner_dim), + } + add_meta_data(filename=joiner_filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + setup_logger(f"{params.exp_dir}/log-export/log-export-onnx") + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True) + encoder = OnnxEncoder( + encoder=model.encoder, + encoder_proj=model.joiner.encoder_proj, + ) + + decoder = OnnxDecoder( + decoder=model.decoder, + decoder_proj=model.joiner.decoder_proj, + ) + + joiner = OnnxJoiner(output_linear=model.joiner.output_linear) + + encoder_num_param = sum([p.numel() for p in encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + if params.iter > 0: + suffix = f"iter-{params.iter}" + else: + suffix = f"epoch-{params.epoch}" + + suffix += f"-avg-{params.avg}" + if params.use_averaged_model: + suffix += "-with-averaged-model" + + opset_version = 13 + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" + export_encoder_model_onnx( + encoder, + encoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported encoder to {encoder_filename}") + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" + export_decoder_model_onnx( + decoder, + decoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported decoder to {decoder_filename}") + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" + export_joiner_model_onnx( + joiner, + joiner_filename, + opset_version=opset_version, + ) + logging.info(f"Exported joiner to {joiner_filename}") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_check.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_check.py new file mode 100755 index 000000000..72ad59a55 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_check.py @@ -0,0 +1,267 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang) + +""" +This script checks that exported ONNX models produce the same output +with the given torchscript model for the same input. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained.pt" +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +2. Export the model via torch.jit.trace() + +./pruned_transducer_stateless7_streaming/jit_trace_export.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --decode-chunk-len 32 \ + --exp-dir $repo/exp/ + +It will generate the following 3 files inside $repo/exp + + - encoder_jit_trace.pt + - decoder_jit_trace.pt + - joiner_jit_trace.pt + +3. Export the model to ONNX + +./pruned_transducer_stateless7_streaming/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --decode-chunk-len 32 \ + --exp-dir $repo/exp/ + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +4. Run this file + +./pruned_transducer_stateless7_streaming/onnx_check.py \ + --jit-encoder-filename $repo/exp/encoder_jit_trace.pt \ + --jit-decoder-filename $repo/exp/decoder_jit_trace.pt \ + --jit-joiner-filename $repo/exp/joiner_jit_trace.pt \ + --onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx +""" + +import argparse +import logging + +from onnx_pretrained import OnnxModel +from zipformer import stack_states + +from icefall import is_module_available + +if not is_module_available("onnxruntime"): + raise ValueError("Please 'pip install onnxruntime' first.") + +import onnxruntime as ort +import torch + +ort.set_default_logger_severity(3) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--jit-encoder-filename", + required=True, + type=str, + help="Path to the torchscript encoder model", + ) + + parser.add_argument( + "--jit-decoder-filename", + required=True, + type=str, + help="Path to the torchscript decoder model", + ) + + parser.add_argument( + "--jit-joiner-filename", + required=True, + type=str, + help="Path to the torchscript joiner model", + ) + + parser.add_argument( + "--onnx-encoder-filename", + required=True, + type=str, + help="Path to the ONNX encoder model", + ) + + parser.add_argument( + "--onnx-decoder-filename", + required=True, + type=str, + help="Path to the ONNX decoder model", + ) + + parser.add_argument( + "--onnx-joiner-filename", + required=True, + type=str, + help="Path to the ONNX joiner model", + ) + + return parser + + +def test_encoder( + torch_encoder_model: torch.jit.ScriptModule, + torch_encoder_proj_model: torch.jit.ScriptModule, + onnx_model: OnnxModel, +): + N = torch.randint(1, 100, size=(1,)).item() + T = onnx_model.segment + C = 80 + x_lens = torch.tensor([T] * N) + torch_states = [torch_encoder_model.get_init_state() for _ in range(N)] + torch_states = stack_states(torch_states) + + onnx_model.init_encoder_states(N) + + for i in range(5): + logging.info(f"test_encoder: iter {i}") + x = torch.rand(N, T, C) + torch_encoder_out, _, torch_states = torch_encoder_model( + x, x_lens, torch_states + ) + torch_encoder_out = torch_encoder_proj_model(torch_encoder_out) + + onnx_encoder_out = onnx_model.run_encoder(x) + + assert torch.allclose(torch_encoder_out, onnx_encoder_out, atol=1e-4), ( + (torch_encoder_out - onnx_encoder_out).abs().max() + ) + + +def test_decoder( + torch_decoder_model: torch.jit.ScriptModule, + torch_decoder_proj_model: torch.jit.ScriptModule, + onnx_model: OnnxModel, +): + context_size = onnx_model.context_size + vocab_size = onnx_model.vocab_size + for i in range(10): + N = torch.randint(1, 100, size=(1,)).item() + logging.info(f"test_decoder: iter {i}, N={N}") + x = torch.randint( + low=1, + high=vocab_size, + size=(N, context_size), + dtype=torch.int64, + ) + torch_decoder_out = torch_decoder_model(x, need_pad=torch.tensor([False])) + torch_decoder_out = torch_decoder_proj_model(torch_decoder_out) + torch_decoder_out = torch_decoder_out.squeeze(1) + + onnx_decoder_out = onnx_model.run_decoder(x) + assert torch.allclose(torch_decoder_out, onnx_decoder_out, atol=1e-4), ( + (torch_decoder_out - onnx_decoder_out).abs().max() + ) + + +def test_joiner( + torch_joiner_model: torch.jit.ScriptModule, + onnx_model: OnnxModel, +): + encoder_dim = torch_joiner_model.encoder_proj.weight.shape[1] + decoder_dim = torch_joiner_model.decoder_proj.weight.shape[1] + for i in range(10): + N = torch.randint(1, 100, size=(1,)).item() + logging.info(f"test_joiner: iter {i}, N={N}") + encoder_out = torch.rand(N, encoder_dim) + decoder_out = torch.rand(N, decoder_dim) + + projected_encoder_out = torch_joiner_model.encoder_proj(encoder_out) + projected_decoder_out = torch_joiner_model.decoder_proj(decoder_out) + + torch_joiner_out = torch_joiner_model(encoder_out, decoder_out) + onnx_joiner_out = onnx_model.run_joiner( + projected_encoder_out, projected_decoder_out + ) + + assert torch.allclose(torch_joiner_out, onnx_joiner_out, atol=1e-4), ( + (torch_joiner_out - onnx_joiner_out).abs().max() + ) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + logging.info(vars(args)) + + torch_encoder_model = torch.jit.load(args.jit_encoder_filename) + torch_decoder_model = torch.jit.load(args.jit_decoder_filename) + torch_joiner_model = torch.jit.load(args.jit_joiner_filename) + + onnx_model = OnnxModel( + encoder_model_filename=args.onnx_encoder_filename, + decoder_model_filename=args.onnx_decoder_filename, + joiner_model_filename=args.onnx_joiner_filename, + ) + + logging.info("Test encoder") + # When exporting the model to onnx, we have already put the encoder_proj + # inside the encoder. + test_encoder(torch_encoder_model, torch_joiner_model.encoder_proj, onnx_model) + + logging.info("Test decoder") + # When exporting the model to onnx, we have already put the decoder_proj + # inside the decoder. + test_decoder(torch_decoder_model, torch_joiner_model.decoder_proj, onnx_model) + + logging.info("Test joiner") + test_joiner(torch_joiner_model, onnx_model) + + logging.info("Finished checking ONNX models") + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +# See https://github.com/pytorch/pytorch/issues/38342 +# and https://github.com/pytorch/pytorch/issues/33354 +# +# If we don't do this, the delay increases whenever there is +# a new request that changes the actual batch size. +# If you use `py-spy dump --pid --native`, you will +# see a lot of time is spent in re-compiling the torch script model. +torch._C._jit_set_profiling_executor(False) +torch._C._jit_set_profiling_mode(False) +torch._C._set_graph_executor_optimize(False) +if __name__ == "__main__": + torch.manual_seed(20230207) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py index f52deecc9..71a418742 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py @@ -12,7 +12,7 @@ class OnnxStreamingEncoder(torch.nn.Module): def __init__(self, encoder): """ Args: - encoder: A Instance of Zipformer Class + encoder: An instance of Zipformer Class """ super().__init__() self.model = encoder diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py new file mode 100755 index 000000000..265684d18 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py @@ -0,0 +1,510 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) + +""" +This script loads ONNX models exported by ./export-onnx.py +and uses them to decode waves. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained.pt" +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +2. Export the model to ONNX + +./pruned_transducer_stateless7_streaming/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --decode-chunk-len 32 \ + --exp-dir $repo/exp/ + +It will generate the following 3 files in $repo/exp + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +3. Run this file with the exported ONNX models + +./pruned_transducer_stateless7_streaming/onnx_pretrained.py \ + --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav + +Note: Even though this script only supports decoding a single file, +the exported ONNX models do support batch processing. +""" + +import argparse +import logging +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import onnxruntime as ort +import torch +import torchaudio +from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder onnx model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder onnx model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner onnx model. ", + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt.""", + ) + + parser.add_argument( + "sound_file", + type=str, + help="The input sound file to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +class OnnxModel: + def __init__( + self, + encoder_model_filename: str, + decoder_model_filename: str, + joiner_model_filename: str, + ): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.session_opts = session_opts + + self.init_encoder(encoder_model_filename) + self.init_decoder(decoder_model_filename) + self.init_joiner(joiner_model_filename) + + def init_encoder(self, encoder_model_filename: str): + self.encoder = ort.InferenceSession( + encoder_model_filename, + sess_options=self.session_opts, + ) + self.init_encoder_states() + + def init_encoder_states(self, batch_size: int = 1): + encoder_meta = self.encoder.get_modelmeta().custom_metadata_map + + decode_chunk_len = int(encoder_meta["decode_chunk_len"]) + pad_length = int(encoder_meta["pad_length"]) + + num_encoder_layers = encoder_meta["num_encoder_layers"] + encoder_dims = encoder_meta["encoder_dims"] + attention_dims = encoder_meta["attention_dims"] + cnn_module_kernels = encoder_meta["cnn_module_kernels"] + left_context_len = encoder_meta["left_context_len"] + + def to_int_list(s): + return list(map(int, s.split(","))) + + num_encoder_layers = to_int_list(num_encoder_layers) + encoder_dims = to_int_list(encoder_dims) + attention_dims = to_int_list(attention_dims) + cnn_module_kernels = to_int_list(cnn_module_kernels) + left_context_len = to_int_list(left_context_len) + + logging.info(f"decode_chunk_len: {decode_chunk_len}") + logging.info(f"pad_length: {pad_length}") + logging.info(f"num_encoder_layers: {num_encoder_layers}") + logging.info(f"encoder_dims: {encoder_dims}") + logging.info(f"attention_dims: {attention_dims}") + logging.info(f"cnn_module_kernels: {cnn_module_kernels}") + logging.info(f"left_context_len: {left_context_len}") + + num_encoders = len(num_encoder_layers) + + cached_len = [] + cached_avg = [] + cached_key = [] + cached_val = [] + cached_val2 = [] + cached_conv1 = [] + cached_conv2 = [] + + N = batch_size + + for i in range(num_encoders): + cached_len.append(torch.zeros(num_encoder_layers[i], N, dtype=torch.int64)) + cached_avg.append(torch.zeros(num_encoder_layers[i], N, encoder_dims[i])) + cached_key.append( + torch.zeros( + num_encoder_layers[i], left_context_len[i], N, attention_dims[i] + ) + ) + cached_val.append( + torch.zeros( + num_encoder_layers[i], + left_context_len[i], + N, + attention_dims[i] // 2, + ) + ) + cached_val2.append( + torch.zeros( + num_encoder_layers[i], + left_context_len[i], + N, + attention_dims[i] // 2, + ) + ) + cached_conv1.append( + torch.zeros( + num_encoder_layers[i], N, encoder_dims[i], cnn_module_kernels[i] - 1 + ) + ) + cached_conv2.append( + torch.zeros( + num_encoder_layers[i], N, encoder_dims[i], cnn_module_kernels[i] - 1 + ) + ) + + self.cached_len = cached_len + self.cached_avg = cached_avg + self.cached_key = cached_key + self.cached_val = cached_val + self.cached_val2 = cached_val2 + self.cached_conv1 = cached_conv1 + self.cached_conv2 = cached_conv2 + + self.num_encoders = num_encoders + + self.segment = decode_chunk_len + pad_length + self.offset = decode_chunk_len + + def init_decoder(self, decoder_model_filename: str): + self.decoder = ort.InferenceSession( + decoder_model_filename, + sess_options=self.session_opts, + ) + + decoder_meta = self.decoder.get_modelmeta().custom_metadata_map + self.context_size = int(decoder_meta["context_size"]) + self.vocab_size = int(decoder_meta["vocab_size"]) + + logging.info(f"context_size: {self.context_size}") + logging.info(f"vocab_size: {self.vocab_size}") + + def init_joiner(self, joiner_model_filename: str): + self.joiner = ort.InferenceSession( + joiner_model_filename, + sess_options=self.session_opts, + ) + + joiner_meta = self.joiner.get_modelmeta().custom_metadata_map + self.joiner_dim = int(joiner_meta["joiner_dim"]) + + logging.info(f"joiner_dim: {self.joiner_dim}") + + def _build_encoder_input_output( + self, + x: torch.Tensor, + ) -> Tuple[Dict[str, np.ndarray], List[str]]: + encoder_input = {"x": x.numpy()} + encoder_output = ["encoder_out"] + + def build_states_input(states: List[torch.Tensor], name: str): + for i, s in enumerate(states): + if isinstance(s, torch.Tensor): + encoder_input[f"{name}_{i}"] = s.numpy() + else: + encoder_input[f"{name}_{i}"] = s + + encoder_output.append(f"new_{name}_{i}") + + build_states_input(self.cached_len, "cached_len") + build_states_input(self.cached_avg, "cached_avg") + build_states_input(self.cached_key, "cached_key") + build_states_input(self.cached_val, "cached_val") + build_states_input(self.cached_val2, "cached_val2") + build_states_input(self.cached_conv1, "cached_conv1") + build_states_input(self.cached_conv2, "cached_conv2") + + return encoder_input, encoder_output + + def _update_states(self, states: List[np.ndarray]): + num_encoders = self.num_encoders + + self.cached_len = states[num_encoders * 0 : num_encoders * 1] + self.cached_avg = states[num_encoders * 1 : num_encoders * 2] + self.cached_key = states[num_encoders * 2 : num_encoders * 3] + self.cached_val = states[num_encoders * 3 : num_encoders * 4] + self.cached_val2 = states[num_encoders * 4 : num_encoders * 5] + self.cached_conv1 = states[num_encoders * 5 : num_encoders * 6] + self.cached_conv2 = states[num_encoders * 6 : num_encoders * 7] + + def run_encoder(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: + A 3-D tensor of shape (N, T, C) + Returns: + Return a 3-D tensor of shape (N, T', joiner_dim) where + T' is usually equal to ((T-7)//2+1)//2 + """ + encoder_input, encoder_output_names = self._build_encoder_input_output(x) + out = self.encoder.run(encoder_output_names, encoder_input) + + self._update_states(out[1:]) + + return torch.from_numpy(out[0]) + + def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor: + """ + Args: + decoder_input: + A 2-D tensor of shape (N, context_size) + Returns: + Return a 2-D tensor of shape (N, joiner_dim) + """ + out = self.decoder.run( + [self.decoder.get_outputs()[0].name], + {self.decoder.get_inputs()[0].name: decoder_input.numpy()}, + )[0] + + return torch.from_numpy(out) + + def run_joiner( + self, encoder_out: torch.Tensor, decoder_out: torch.Tensor + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + out = self.joiner.run( + [self.joiner.get_outputs()[0].name], + { + self.joiner.get_inputs()[0].name: encoder_out.numpy(), + self.joiner.get_inputs()[1].name: decoder_out.numpy(), + }, + )[0] + return torch.from_numpy(out) + + return torch.from_numpy(out) + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +def create_streaming_feature_extractor() -> OnlineFeature: + """Create a CPU streaming feature extractor. + + At present, we assume it returns a fbank feature extractor with + fixed options. In the future, we will support passing in the options + from outside. + + Returns: + Return a CPU streaming feature extractor. + """ + opts = FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + return OnlineFbank(opts) + + +def greedy_search( + model: OnnxModel, + encoder_out: torch.Tensor, + context_size: int, + decoder_out: Optional[torch.Tensor] = None, + hyp: Optional[List[int]] = None, +) -> List[int]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + model: + The transducer model. + encoder_out: + A 3-D tensor of shape (1, T, joiner_dim) + context_size: + The context size of the decoder model. + decoder_out: + Optional. Decoder output of the previous chunk. + hyp: + Decoding results for previous chunks. + Returns: + Return the decoded results so far. + """ + + blank_id = 0 + + if decoder_out is None: + assert hyp is None, hyp + hyp = [blank_id] * context_size + decoder_input = torch.tensor([hyp], dtype=torch.int64) + decoder_out = model.run_decoder(decoder_input) + else: + assert hyp is not None, hyp + + encoder_out = encoder_out.squeeze(0) + T = encoder_out.size(0) + for t in range(T): + cur_encoder_out = encoder_out[t : t + 1] + joiner_out = model.run_joiner(cur_encoder_out, decoder_out).squeeze(0) + y = joiner_out.argmax(dim=0).item() + if y != blank_id: + hyp.append(y) + decoder_input = hyp[-context_size:] + decoder_input = torch.tensor([decoder_input], dtype=torch.int64) + decoder_out = model.run_decoder(decoder_input) + + return hyp, decoder_out + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + model = OnnxModel( + encoder_model_filename=args.encoder_model_filename, + decoder_model_filename=args.decoder_model_filename, + joiner_model_filename=args.joiner_model_filename, + ) + + sample_rate = 16000 + + logging.info("Constructing Fbank computer") + online_fbank = create_streaming_feature_extractor() + + logging.info(f"Reading sound files: {args.sound_file}") + waves = read_sound_files( + filenames=[args.sound_file], + expected_sample_rate=sample_rate, + )[0] + + tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32) + wave_samples = torch.cat([waves, tail_padding]) + + num_processed_frames = 0 + segment = model.segment + offset = model.offset + + context_size = model.context_size + hyp = None + decoder_out = None + + chunk = int(1 * sample_rate) # 1 second + start = 0 + while start < wave_samples.numel(): + end = min(start + chunk, wave_samples.numel()) + samples = wave_samples[start:end] + start += chunk + + online_fbank.accept_waveform( + sampling_rate=sample_rate, + waveform=samples, + ) + + while online_fbank.num_frames_ready - num_processed_frames >= segment: + frames = [] + for i in range(segment): + frames.append(online_fbank.get_frame(num_processed_frames + i)) + num_processed_frames += offset + frames = torch.cat(frames, dim=0) + frames = frames.unsqueeze(0) + encoder_out = model.run_encoder(frames) + hyp, decoder_out = greedy_search( + model, + encoder_out, + context_size, + decoder_out, + hyp, + ) + + symbol_table = k2.SymbolTable.from_file(args.tokens) + + text = "" + for i in hyp[context_size:]: + text += symbol_table[i] + text = text.replace("▁", " ").strip() + + logging.info(args.sound_file) + logging.info(text) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py index 1b267c1c5..f7e52a9e6 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py @@ -270,7 +270,7 @@ class Zipformer(EncoderInterface): dim_feedforward (int, int): feedforward dimension in 2 encoder stacks num_encoder_layers (int): number of encoder layers dropout (float): dropout rate - cnn_module_kernel (int): Kernel size of convolution module + cnn_module_kernels (int): Kernel size of convolution module vgg_frontend (bool): whether to use vgg frontend. warmup_batches (float): number of batches to warm up over """ @@ -311,6 +311,8 @@ class Zipformer(EncoderInterface): # Used in decoding self.decode_chunk_size = decode_chunk_size + self.left_context_len = self.decode_chunk_size * self.num_left_chunks + # will be written to, see set_batch_count() self.batch_count = 0 self.warmup_end = warmup_batches @@ -330,7 +332,10 @@ class Zipformer(EncoderInterface): # each one will be ZipformerEncoder or DownsampledZipformerEncoder encoders = [] + self.num_encoder_layers = num_encoder_layers self.num_encoders = len(encoder_dims) + self.attention_dims = attention_dim + self.cnn_module_kernels = cnn_module_kernels for i in range(self.num_encoders): encoder_layer = ZipformerEncoderLayer( encoder_dims[i], @@ -382,7 +387,7 @@ class Zipformer(EncoderInterface): def _init_skip_modules(self): """ - If self.zipformer_downampling_factors = (1, 2, 4, 8, 4, 2), then at the input of layer + If self.zipformer_downsampling_factors = (1, 2, 4, 8, 4, 2), then at the input of layer indexed 4 (in zero indexing), with has subsapling_factor=4, we combine the output of layers 2 and 3; and at the input of layer indexed 5, which which has subsampling_factor=2, we combine the outputs of layers 1 and 5. @@ -695,7 +700,7 @@ class Zipformer(EncoderInterface): num_layers = encoder.num_layers ds = self.zipformer_downsampling_factors[i] - len_avg = torch.zeros(num_layers, 1, dtype=torch.int32, device=device) + len_avg = torch.zeros(num_layers, 1, dtype=torch.int64, device=device) cached_len.append(len_avg) avg = torch.zeros(num_layers, 1, encoder.d_model, device=device) From 8d3810e289d44b9e1b86e2f48b6f4f3b3185d62a Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 7 Feb 2023 15:01:59 +0800 Subject: [PATCH 099/174] Simplify ONNX export (#881) * Simplify ONNX export * Fix ONNX CI tests --- ...pruned-transducer-stateless3-2022-05-13.sh | 30 -- .github/scripts/test-onnx-export.sh | 56 +- ...runed-transducer-stateless3-2022-05-13.yml | 2 +- .../pruned_transducer_stateless2/joiner.py | 2 - .../export-onnx.py | 497 ++++++++++++++++++ .../pruned_transducer_stateless3/export.py | 270 +--------- .../onnx_check.py | 301 +++++------ .../onnx_pretrained.py | 344 ++++++------ .../export-onnx.py | 2 +- .../onnx_check.py | 6 - .../onnx_pretrained.py | 1 - 11 files changed, 874 insertions(+), 637 deletions(-) create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh index 880767443..ceb77c7c3 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh @@ -27,14 +27,6 @@ ln -s pretrained-iter-1224000-avg-14.pt pretrained.pt ln -s pretrained-iter-1224000-avg-14.pt epoch-99.pt popd -log "Test exporting to ONNX format" - -./pruned_transducer_stateless3/export.py \ - --exp-dir $repo/exp \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --epoch 99 \ - --avg 1 \ - --onnx 1 log "Export to torchscript model" ./pruned_transducer_stateless3/export.py \ @@ -51,30 +43,8 @@ log "Export to torchscript model" --avg 1 \ --jit-trace 1 -ls -lh $repo/exp/*.onnx ls -lh $repo/exp/*.pt -log "Decode with ONNX models" - -./pruned_transducer_stateless3/onnx_check.py \ - --jit-filename $repo/exp/cpu_jit.pt \ - --onnx-encoder-filename $repo/exp/encoder.onnx \ - --onnx-decoder-filename $repo/exp/decoder.onnx \ - --onnx-joiner-filename $repo/exp/joiner.onnx \ - --onnx-joiner-encoder-proj-filename $repo/exp/joiner_encoder_proj.onnx \ - --onnx-joiner-decoder-proj-filename $repo/exp/joiner_decoder_proj.onnx - -./pruned_transducer_stateless3/onnx_pretrained.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --encoder-model-filename $repo/exp/encoder.onnx \ - --decoder-model-filename $repo/exp/decoder.onnx \ - --joiner-model-filename $repo/exp/joiner.onnx \ - --joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \ - --joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav - log "Decode with models exported by torch.jit.trace()" ./pruned_transducer_stateless3/jit_pretrained.py \ diff --git a/.github/scripts/test-onnx-export.sh b/.github/scripts/test-onnx-export.sh index 20aa02950..13a5aa765 100755 --- a/.github/scripts/test-onnx-export.sh +++ b/.github/scripts/test-onnx-export.sh @@ -10,9 +10,8 @@ log() { cd egs/librispeech/ASR -repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 - log "==========================================================================" +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 log "Downloading pre-trained model from $repo_url" git lfs install GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url @@ -68,3 +67,56 @@ log "Run onnx_pretrained.py" rm -rf $repo log "--------------------------------------------------------------------------" + +log "==========================================================================" +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13 +log "Downloading pre-trained model from $repo_url" +git lfs install +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-iter-1224000-avg-14.pt" + +cd exp +ln -s pretrained-iter-1224000-avg-14.pt epoch-9999.pt +popd + +log "Export via torch.jit.script()" + +./pruned_transducer_stateless3/export.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 9999 \ + --avg 1 \ + --exp-dir $repo/exp/ \ + --jit 1 + +log "Test exporting to ONNX format" + +./pruned_transducer_stateless3/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 9999 \ + --avg 1 \ + --exp-dir $repo/exp/ + +ls -lh $repo/exp + +log "Run onnx_check.py" + +./pruned_transducer_stateless3/onnx_check.py \ + --jit-filename $repo/exp/cpu_jit.pt \ + --onnx-encoder-filename $repo/exp/encoder-epoch-9999-avg-1.onnx \ + --onnx-decoder-filename $repo/exp/decoder-epoch-9999-avg-1.onnx \ + --onnx-joiner-filename $repo/exp/joiner-epoch-9999-avg-1.onnx + +log "Run onnx_pretrained.py" + +./pruned_transducer_stateless3/onnx_pretrained.py \ + --encoder-model-filename $repo/exp/encoder-epoch-9999-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-9999-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-9999-avg-1.onnx \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml b/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml index 2c2bcab0c..f67f7599b 100644 --- a/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml +++ b/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml @@ -39,7 +39,7 @@ concurrency: jobs: run_librispeech_pruned_transducer_stateless3_2022_05_13: - if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' + if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' runs-on: ${{ matrix.os }} strategy: matrix: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py index 1954f4724..9f88bd029 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py @@ -56,8 +56,6 @@ class Joiner(nn.Module): """ if not is_jit_tracing(): assert encoder_out.ndim == decoder_out.ndim - assert encoder_out.ndim in (2, 4) - assert encoder_out.shape == decoder_out.shape if project_input: logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py new file mode 100755 index 000000000..1af68be70 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py @@ -0,0 +1,497 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang) + +""" +This script exports a transducer model from PyTorch to ONNX. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-iter-1224000-avg-14.pt" + +cd exp +ln -s pretrained-iter-1224000-avg-14.pt epoch-9999.pt +popd + +2. Export the model to ONNX + +./pruned_transducer_stateless3/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 9999 \ + --avg 1 \ + --exp-dir $repo/exp/ + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-9999-avg-1.onnx + - decoder-epoch-9999-avg-1.onnx + - joiner-epoch-9999-avg-1.onnx + +See ./onnx_pretrained.py and ./onnx_check.py for how to +use the exported ONNX models. +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, Tuple + +import onnx +import sentencepiece as spm +import torch +import torch.nn as nn +from conformer import Conformer +from decoder import Decoder +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.utils import setup_logger + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless3/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +class OnnxEncoder(nn.Module): + """A wrapper for Conformer and the encoder_proj from the joiner""" + + def __init__(self, encoder: Conformer, encoder_proj: nn.Linear): + """ + Args: + encoder: + A Conformer encoder. + encoder_proj: + The projection layer for encoder from the joiner. + """ + super().__init__() + self.encoder = encoder + self.encoder_proj = encoder_proj + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Please see the help information of Conformer.forward + + Args: + x: + A 3-D tensor of shape (N, T, C) + x_lens: + A 1-D tensor of shape (N,). Its dtype is torch.int64 + Returns: + Return a tuple containing: + - encoder_out, A 3-D tensor of shape (N, T', joiner_dim) + - encoder_out_lens, A 1-D tensor of shape (N,) + """ + encoder_out, encoder_out_lens = self.encoder(x, x_lens) + + encoder_out = self.encoder_proj(encoder_out) + # Now encoder_out is of shape (N, T, joiner_dim) + + return encoder_out, encoder_out_lens + + +class OnnxDecoder(nn.Module): + """A wrapper for Decoder and the decoder_proj from the joiner""" + + def __init__(self, decoder: Decoder, decoder_proj: nn.Linear): + super().__init__() + self.decoder = decoder + self.decoder_proj = decoder_proj + + def forward(self, y: torch.Tensor) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, context_size). + Returns + Return a 2-D tensor of shape (N, joiner_dim) + """ + need_pad = False + decoder_output = self.decoder(y, need_pad=need_pad) + decoder_output = decoder_output.squeeze(1) + output = self.decoder_proj(decoder_output) + + return output + + +class OnnxJoiner(nn.Module): + """A wrapper for the joiner""" + + def __init__(self, output_linear: nn.Linear): + super().__init__() + self.output_linear = output_linear + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + logit = encoder_out + decoder_out + logit = self.output_linear(torch.tanh(logit)) + return logit + + +def export_encoder_model_onnx( + encoder_model: OnnxEncoder, + encoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the given encoder model to ONNX format. + The exported model has two inputs: + + - x, a tensor of shape (N, T, C); dtype is torch.float32 + - x_lens, a tensor of shape (N,); dtype is torch.int64 + + and it has two outputs: + + - encoder_out, a tensor of shape (N, T', joiner_dim) + - encoder_out_lens, a tensor of shape (N,) + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + x = torch.zeros(1, 100, 80, dtype=torch.float32) + x_lens = torch.tensor([100], dtype=torch.int64) + + torch.onnx.export( + encoder_model, + (x, x_lens), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "x_lens"], + output_names=["encoder_out", "encoder_out_lens"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "x_lens": {0: "N"}, + "encoder_out": {0: "N", 1: "T"}, + "encoder_out_lens": {0: "N"}, + }, + ) + + +def export_decoder_model_onnx( + decoder_model: OnnxDecoder, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, decoder_model.context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, joiner_dim) + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + context_size = decoder_model.decoder.context_size + vocab_size = decoder_model.decoder.vocab_size + + y = torch.zeros(10, context_size, dtype=torch.int64) + torch.onnx.export( + decoder_model, + y, + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + + meta_data = { + "context_size": str(context_size), + "vocab_size": str(vocab_size), + } + add_meta_data(filename=decoder_filename, meta_data=meta_data) + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. + The exported joiner model has two inputs: + + - encoder_out: a tensor of shape (N, joiner_dim) + - decoder_out: a tensor of shape (N, joiner_dim) + + and produces one output: + + - logit: a tensor of shape (N, vocab_size) + """ + joiner_dim = joiner_model.output_linear.weight.shape[1] + logging.info(f"joiner dim: {joiner_dim}") + + projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + + torch.onnx.export( + joiner_model, + (projected_encoder_out, projected_decoder_out), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "encoder_out", + "decoder_out", + ], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + meta_data = { + "joiner_dim": str(joiner_dim), + } + add_meta_data(filename=joiner_filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + setup_logger(f"{params.exp_dir}/log-export/log-export-onnx") + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params, enable_giga=False) + + model.to(device) + + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True) + + encoder = OnnxEncoder( + encoder=model.encoder, + encoder_proj=model.joiner.encoder_proj, + ) + + decoder = OnnxDecoder( + decoder=model.decoder, + decoder_proj=model.joiner.decoder_proj, + ) + + joiner = OnnxJoiner(output_linear=model.joiner.output_linear) + + encoder_num_param = sum([p.numel() for p in encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + if params.iter > 0: + suffix = f"iter-{params.iter}" + else: + suffix = f"epoch-{params.epoch}" + + suffix += f"-avg-{params.avg}" + + opset_version = 13 + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" + export_encoder_model_onnx( + encoder, + encoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported encoder to {encoder_filename}") + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" + export_decoder_model_onnx( + decoder, + decoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported decoder to {decoder_filename}") + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" + export_joiner_model_onnx( + joiner, + joiner_filename, + opset_version=opset_version, + ) + logging.info(f"Exported joiner to {joiner_filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py index 239bdc12f..f30c9df6a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py @@ -52,32 +52,7 @@ It will also generate 3 other files: `encoder_jit_script.pt`, It will generates 3 files: `encoder_jit_trace.pt`, `decoder_jit_trace.pt`, and `joiner_jit_trace.pt`. - -(3) Export to ONNX format - -./pruned_transducer_stateless3/export.py \ - --exp-dir ./pruned_transducer_stateless3/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --epoch 20 \ - --avg 10 \ - --onnx 1 - -It will generate the following files in the given `exp_dir`. -Check `onnx_check.py` for how to use them. - - - encoder.onnx - - decoder.onnx - - joiner.onnx - - joiner_encoder_proj.onnx - - joiner_decoder_proj.onnx - -Please see ./onnx_pretrained.py for usage of the generated files - -Check -https://github.com/k2-fsa/sherpa-onnx -for how to use the exported models outside of icefall. - -(4) Export `model.state_dict()` +(3) Export `model.state_dict()` ./pruned_transducer_stateless3/export.py \ --exp-dir ./pruned_transducer_stateless3/exp \ @@ -210,23 +185,6 @@ def get_parser(): """, ) - parser.add_argument( - "--onnx", - type=str2bool, - default=False, - help="""If True, --jit is ignored and it exports the model - to onnx format. It will generate the following files: - - - encoder.onnx - - decoder.onnx - - joiner.onnx - - joiner_encoder_proj.onnx - - joiner_decoder_proj.onnx - - Refer to ./onnx_check.py and ./onnx_pretrained.py for how to use them. - """, - ) - parser.add_argument( "--context-size", type=int, @@ -370,206 +328,6 @@ def export_joiner_model_jit_trace( logging.info(f"Saved to {joiner_filename}") -def export_encoder_model_onnx( - encoder_model: nn.Module, - encoder_filename: str, - opset_version: int = 11, -) -> None: - """Export the given encoder model to ONNX format. - The exported model has two inputs: - - - x, a tensor of shape (N, T, C); dtype is torch.float32 - - x_lens, a tensor of shape (N,); dtype is torch.int64 - - and it has two outputs: - - - encoder_out, a tensor of shape (N, T, C) - - encoder_out_lens, a tensor of shape (N,) - - Note: The warmup argument is fixed to 1. - - Args: - encoder_model: - The input encoder model - encoder_filename: - The filename to save the exported ONNX model. - opset_version: - The opset version to use. - """ - x = torch.zeros(1, 100, 80, dtype=torch.float32) - x_lens = torch.tensor([100], dtype=torch.int64) - - # encoder_model = torch.jit.script(encoder_model) - # It throws the following error for the above statement - # - # RuntimeError: Exporting the operator __is_ to ONNX opset version - # 11 is not supported. Please feel free to request support or - # submit a pull request on PyTorch GitHub. - # - # I cannot find which statement causes the above error. - # torch.onnx.export() will use torch.jit.trace() internally, which - # works well for the current reworked model - warmup = 1.0 - torch.onnx.export( - encoder_model, - (x, x_lens, warmup), - encoder_filename, - verbose=False, - opset_version=opset_version, - input_names=["x", "x_lens", "warmup"], - output_names=["encoder_out", "encoder_out_lens"], - dynamic_axes={ - "x": {0: "N", 1: "T"}, - "x_lens": {0: "N"}, - "encoder_out": {0: "N", 1: "T"}, - "encoder_out_lens": {0: "N"}, - }, - ) - logging.info(f"Saved to {encoder_filename}") - - -def export_decoder_model_onnx( - decoder_model: nn.Module, - decoder_filename: str, - opset_version: int = 11, -) -> None: - """Export the decoder model to ONNX format. - - The exported model has one input: - - - y: a torch.int64 tensor of shape (N, decoder_model.context_size) - - and has one output: - - - decoder_out: a torch.float32 tensor of shape (N, 1, C) - - Note: The argument need_pad is fixed to False. - - Args: - decoder_model: - The decoder model to be exported. - decoder_filename: - Filename to save the exported ONNX model. - opset_version: - The opset version to use. - """ - y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) - need_pad = False # Always False, so we can use torch.jit.trace() here - # Note(fangjun): torch.jit.trace() is more efficient than torch.jit.script() - # in this case - torch.onnx.export( - decoder_model, - (y, need_pad), - decoder_filename, - verbose=False, - opset_version=opset_version, - input_names=["y", "need_pad"], - output_names=["decoder_out"], - dynamic_axes={ - "y": {0: "N"}, - "decoder_out": {0: "N"}, - }, - ) - logging.info(f"Saved to {decoder_filename}") - - -def export_joiner_model_onnx( - joiner_model: nn.Module, - joiner_filename: str, - opset_version: int = 11, -) -> None: - """Export the joiner model to ONNX format. - The exported joiner model has two inputs: - - - projected_encoder_out: a tensor of shape (N, joiner_dim) - - projected_decoder_out: a tensor of shape (N, joiner_dim) - - and produces one output: - - - logit: a tensor of shape (N, vocab_size) - - The exported encoder_proj model has one input: - - - encoder_out: a tensor of shape (N, encoder_out_dim) - - and produces one output: - - - projected_encoder_out: a tensor of shape (N, joiner_dim) - - The exported decoder_proj model has one input: - - - decoder_out: a tensor of shape (N, decoder_out_dim) - - and produces one output: - - - projected_decoder_out: a tensor of shape (N, joiner_dim) - """ - encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx") - - decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx") - - encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] - decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] - joiner_dim = joiner_model.decoder_proj.weight.shape[0] - - projected_encoder_out = torch.rand(1, joiner_dim, dtype=torch.float32) - projected_decoder_out = torch.rand(1, joiner_dim, dtype=torch.float32) - - project_input = False - # Note: It uses torch.jit.trace() internally - torch.onnx.export( - joiner_model, - (projected_encoder_out, projected_decoder_out, project_input), - joiner_filename, - verbose=False, - opset_version=opset_version, - input_names=[ - "projected_encoder_out", - "projected_decoder_out", - "project_input", - ], - output_names=["logit"], - dynamic_axes={ - "projected_encoder_out": {0: "N"}, - "projected_decoder_out": {0: "N"}, - "logit": {0: "N"}, - }, - ) - logging.info(f"Saved to {joiner_filename}") - - encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) - torch.onnx.export( - joiner_model.encoder_proj, - encoder_out, - encoder_proj_filename, - verbose=False, - opset_version=opset_version, - input_names=["encoder_out"], - output_names=["projected_encoder_out"], - dynamic_axes={ - "encoder_out": {0: "N"}, - "projected_encoder_out": {0: "N"}, - }, - ) - logging.info(f"Saved to {encoder_proj_filename}") - - decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) - torch.onnx.export( - joiner_model.decoder_proj, - decoder_out, - decoder_proj_filename, - verbose=False, - opset_version=opset_version, - input_names=["decoder_out"], - output_names=["projected_decoder_out"], - dynamic_axes={ - "decoder_out": {0: "N"}, - "projected_decoder_out": {0: "N"}, - }, - ) - logging.info(f"Saved to {decoder_proj_filename}") - - @torch.no_grad() def main(): args = get_parser().parse_args() @@ -636,31 +394,7 @@ def main(): model.to("cpu") model.eval() - if params.onnx is True: - convert_scaled_to_non_scaled(model, inplace=True) - opset_version = 11 - logging.info("Exporting to onnx format") - encoder_filename = params.exp_dir / "encoder.onnx" - export_encoder_model_onnx( - model.encoder, - encoder_filename, - opset_version=opset_version, - ) - - decoder_filename = params.exp_dir / "decoder.onnx" - export_decoder_model_onnx( - model.decoder, - decoder_filename, - opset_version=opset_version, - ) - - joiner_filename = params.exp_dir / "joiner.onnx" - export_joiner_model_onnx( - model.joiner, - joiner_filename, - opset_version=opset_version, - ) - elif params.jit is True: + if params.jit is True: convert_scaled_to_non_scaled(model, inplace=True) logging.info("Using torch.jit.script()") # We won't use the forward() method of the model in C++, so just ignore diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py index 163d737e3..6541f0295 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py @@ -19,21 +19,70 @@ """ This script checks that exported onnx models produce the same output with the given torchscript model for the same input. + +We use the pre-trained model from +https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-iter-1224000-avg-14.pt" + +cd exp +ln -s pretrained-iter-1224000-avg-14.pt epoch-9999.pt +popd + +2. Export the model via torchscript (torch.jit.script()) + +./pruned_transducer_stateless3/export.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 9999 \ + --avg 1 \ + --exp-dir $repo/exp/ \ + --jit 1 + +It will generate the following file in $repo/exp: + - cpu_jit.pt + +3. Export the model to ONNX + +./pruned_transducer_stateless3/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 9999 \ + --avg 1 \ + --exp-dir $repo/exp/ + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-9999-avg-1.onnx + - decoder-epoch-9999-avg-1.onnx + - joiner-epoch-9999-avg-1.onnx + +4. Run this file + +./pruned_transducer_stateless3/onnx_check.py \ + --jit-filename $repo/exp/cpu_jit.pt \ + --onnx-encoder-filename $repo/exp/encoder-epoch-9999-avg-1.onnx \ + --onnx-decoder-filename $repo/exp/decoder-epoch-9999-avg-1.onnx \ + --onnx-joiner-filename $repo/exp/joiner-epoch-9999-avg-1.onnx """ import argparse import logging from icefall import is_module_available +from onnx_pretrained import OnnxModel -if not is_module_available("onnxruntime"): - raise ValueError("Please 'pip install onnxruntime' first.") - -import onnxruntime as ort import torch -ort.set_default_logger_severity(3) - def get_parser(): parser = argparse.ArgumentParser( @@ -68,163 +117,81 @@ def get_parser(): help="Path to the onnx joiner model", ) - parser.add_argument( - "--onnx-joiner-encoder-proj-filename", - required=True, - type=str, - help="Path to the onnx joiner encoder projection model", - ) - - parser.add_argument( - "--onnx-joiner-decoder-proj-filename", - required=True, - type=str, - help="Path to the onnx joiner decoder projection model", - ) - return parser def test_encoder( - model: torch.jit.ScriptModule, - encoder_session: ort.InferenceSession, + torch_model: torch.jit.ScriptModule, + onnx_model: OnnxModel, ): - inputs = encoder_session.get_inputs() - outputs = encoder_session.get_outputs() - input_names = [n.name for n in inputs] - output_names = [n.name for n in outputs] + C = 80 + for i in range(10): + N = torch.randint(low=1, high=20, size=(1,)).item() + T = torch.randint(low=50, high=100, size=(1,)).item() + logging.info(f"test_encoder: iter {i}, N={N}, T={T}") - assert inputs[0].shape == ["N", "T", 80] - assert inputs[1].shape == ["N"] + x = torch.rand(N, T, C) + x_lens = torch.randint(low=10, high=T + 1, size=(N,)) + x_lens[0] = T - for N in [1, 5]: - for T in [12, 25]: - print("N, T", N, T) - x = torch.rand(N, T, 80, dtype=torch.float32) - x_lens = torch.randint(low=10, high=T + 1, size=(N,)) - x_lens[0] = T + torch_encoder_out, torch_encoder_out_lens = torch_model.encoder(x, x_lens) + torch_encoder_out = torch_model.joiner.encoder_proj(torch_encoder_out) - encoder_inputs = { - input_names[0]: x.numpy(), - input_names[1]: x_lens.numpy(), - } - encoder_out, encoder_out_lens = encoder_session.run( - output_names, - encoder_inputs, - ) + onnx_encoder_out, onnx_encoder_out_lens = onnx_model.run_encoder(x, x_lens) - torch_encoder_out, torch_encoder_out_lens = model.encoder(x, x_lens) - - encoder_out = torch.from_numpy(encoder_out) - assert torch.allclose(encoder_out, torch_encoder_out, atol=1e-05), ( - (encoder_out - torch_encoder_out).abs().max(), - encoder_out.shape, - torch_encoder_out.shape, - ) + assert torch.allclose(torch_encoder_out, onnx_encoder_out, atol=1e-05), ( + (torch_encoder_out - onnx_encoder_out).abs().max() + ) def test_decoder( - model: torch.jit.ScriptModule, - decoder_session: ort.InferenceSession, + torch_model: torch.jit.ScriptModule, + onnx_model: OnnxModel, ): - inputs = decoder_session.get_inputs() - outputs = decoder_session.get_outputs() - input_names = [n.name for n in inputs] - output_names = [n.name for n in outputs] + context_size = onnx_model.context_size + vocab_size = onnx_model.vocab_size + for i in range(10): + N = torch.randint(1, 100, size=(1,)).item() + logging.info(f"test_decoder: iter {i}, N={N}") + x = torch.randint( + low=1, + high=vocab_size, + size=(N, context_size), + dtype=torch.int64, + ) + torch_decoder_out = torch_model.decoder(x, need_pad=torch.tensor([False])) + torch_decoder_out = torch_model.joiner.decoder_proj(torch_decoder_out) + torch_decoder_out = torch_decoder_out.squeeze(1) - assert inputs[0].shape == ["N", 2] - for N in [1, 5, 10]: - y = torch.randint(low=1, high=500, size=(10, 2)) - - decoder_inputs = {input_names[0]: y.numpy()} - decoder_out = decoder_session.run( - output_names, - decoder_inputs, - )[0] - decoder_out = torch.from_numpy(decoder_out) - - torch_decoder_out = model.decoder(y, need_pad=False) - assert torch.allclose(decoder_out, torch_decoder_out, atol=1e-5), ( - (decoder_out - torch_decoder_out).abs().max() + onnx_decoder_out = onnx_model.run_decoder(x) + assert torch.allclose(torch_decoder_out, onnx_decoder_out, atol=1e-4), ( + (torch_decoder_out - onnx_decoder_out).abs().max() ) def test_joiner( - model: torch.jit.ScriptModule, - joiner_session: ort.InferenceSession, - joiner_encoder_proj_session: ort.InferenceSession, - joiner_decoder_proj_session: ort.InferenceSession, + torch_model: torch.jit.ScriptModule, + onnx_model: OnnxModel, ): - joiner_inputs = joiner_session.get_inputs() - joiner_outputs = joiner_session.get_outputs() - joiner_input_names = [n.name for n in joiner_inputs] - joiner_output_names = [n.name for n in joiner_outputs] + encoder_dim = torch_model.joiner.encoder_proj.weight.shape[1] + decoder_dim = torch_model.joiner.decoder_proj.weight.shape[1] + for i in range(10): + N = torch.randint(1, 100, size=(1,)).item() + logging.info(f"test_joiner: iter {i}, N={N}") + encoder_out = torch.rand(N, encoder_dim) + decoder_out = torch.rand(N, decoder_dim) - assert joiner_inputs[0].shape == ["N", 512] - assert joiner_inputs[1].shape == ["N", 512] + projected_encoder_out = torch_model.joiner.encoder_proj(encoder_out) + projected_decoder_out = torch_model.joiner.decoder_proj(decoder_out) - joiner_encoder_proj_inputs = joiner_encoder_proj_session.get_inputs() - encoder_proj_input_name = joiner_encoder_proj_inputs[0].name - - assert joiner_encoder_proj_inputs[0].shape == ["N", 512] - - joiner_encoder_proj_outputs = joiner_encoder_proj_session.get_outputs() - encoder_proj_output_name = joiner_encoder_proj_outputs[0].name - - joiner_decoder_proj_inputs = joiner_decoder_proj_session.get_inputs() - decoder_proj_input_name = joiner_decoder_proj_inputs[0].name - - assert joiner_decoder_proj_inputs[0].shape == ["N", 512] - - joiner_decoder_proj_outputs = joiner_decoder_proj_session.get_outputs() - decoder_proj_output_name = joiner_decoder_proj_outputs[0].name - - for N in [1, 5, 10]: - encoder_out = torch.rand(N, 512) - decoder_out = torch.rand(N, 512) - - projected_encoder_out = torch.rand(N, 512) - projected_decoder_out = torch.rand(N, 512) - - joiner_inputs = { - joiner_input_names[0]: projected_encoder_out.numpy(), - joiner_input_names[1]: projected_decoder_out.numpy(), - } - joiner_out = joiner_session.run(joiner_output_names, joiner_inputs)[0] - joiner_out = torch.from_numpy(joiner_out) - - torch_joiner_out = model.joiner( - projected_encoder_out, - projected_decoder_out, - project_input=False, - ) - assert torch.allclose(joiner_out, torch_joiner_out, atol=1e-5), ( - (joiner_out - torch_joiner_out).abs().max() + torch_joiner_out = torch_model.joiner(encoder_out, decoder_out) + onnx_joiner_out = onnx_model.run_joiner( + projected_encoder_out, projected_decoder_out ) - # Now test encoder_proj - joiner_encoder_proj_inputs = {encoder_proj_input_name: encoder_out.numpy()} - joiner_encoder_proj_out = joiner_encoder_proj_session.run( - [encoder_proj_output_name], joiner_encoder_proj_inputs - )[0] - joiner_encoder_proj_out = torch.from_numpy(joiner_encoder_proj_out) - - torch_joiner_encoder_proj_out = model.joiner.encoder_proj(encoder_out) - assert torch.allclose( - joiner_encoder_proj_out, torch_joiner_encoder_proj_out, atol=1e-5 - ), ((joiner_encoder_proj_out - torch_joiner_encoder_proj_out).abs().max()) - - # Now test decoder_proj - joiner_decoder_proj_inputs = {decoder_proj_input_name: decoder_out.numpy()} - joiner_decoder_proj_out = joiner_decoder_proj_session.run( - [decoder_proj_output_name], joiner_decoder_proj_inputs - )[0] - joiner_decoder_proj_out = torch.from_numpy(joiner_decoder_proj_out) - - torch_joiner_decoder_proj_out = model.joiner.decoder_proj(decoder_out) - assert torch.allclose( - joiner_decoder_proj_out, torch_joiner_decoder_proj_out, atol=1e-5 - ), ((joiner_decoder_proj_out - torch_joiner_decoder_proj_out).abs().max()) + assert torch.allclose(torch_joiner_out, onnx_joiner_out, atol=1e-4), ( + (torch_joiner_out - onnx_joiner_out).abs().max() + ) @torch.no_grad() @@ -232,48 +199,38 @@ def main(): args = get_parser().parse_args() logging.info(vars(args)) - model = torch.jit.load(args.jit_filename) + torch_model = torch.jit.load(args.jit_filename) - options = ort.SessionOptions() - options.inter_op_num_threads = 1 - options.intra_op_num_threads = 1 + onnx_model = OnnxModel( + encoder_model_filename=args.onnx_encoder_filename, + decoder_model_filename=args.onnx_decoder_filename, + joiner_model_filename=args.onnx_joiner_filename, + ) logging.info("Test encoder") - encoder_session = ort.InferenceSession( - args.onnx_encoder_filename, - sess_options=options, - ) - test_encoder(model, encoder_session) + test_encoder(torch_model, onnx_model) logging.info("Test decoder") - decoder_session = ort.InferenceSession( - args.onnx_decoder_filename, - sess_options=options, - ) - test_decoder(model, decoder_session) + test_decoder(torch_model, onnx_model) logging.info("Test joiner") - joiner_session = ort.InferenceSession( - args.onnx_joiner_filename, - sess_options=options, - ) - joiner_encoder_proj_session = ort.InferenceSession( - args.onnx_joiner_encoder_proj_filename, - sess_options=options, - ) - joiner_decoder_proj_session = ort.InferenceSession( - args.onnx_joiner_decoder_proj_filename, - sess_options=options, - ) - test_joiner( - model, - joiner_session, - joiner_encoder_proj_session, - joiner_decoder_proj_session, - ) + test_joiner(torch_model, onnx_model) logging.info("Finished checking ONNX models") +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +# See https://github.com/pytorch/pytorch/issues/38342 +# and https://github.com/pytorch/pytorch/issues/33354 +# +# If we don't do this, the delay increases whenever there is +# a new request that changes the actual batch size. +# If you use `py-spy dump --pid --native`, you will +# see a lot of time is spent in re-compiling the torch script model. +torch._C._jit_set_profiling_executor(False) +torch._C._jit_set_profiling_mode(False) +torch._C._set_graph_executor_optimize(False) if __name__ == "__main__": torch.manual_seed(20220727) formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py index 550cf6aad..5adb6c16a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py @@ -18,35 +18,61 @@ This script loads ONNX models and uses them to decode waves. You can use the following command to get the exported models: -./pruned_transducer_stateless3/export.py \ - --exp-dir ./pruned_transducer_stateless3/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --epoch 20 \ - --avg 10 \ - --onnx 1 +We use the pre-trained model from +https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13 +as an example to show how to use this file. -Usage of this script: +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-iter-1224000-avg-14.pt" + +cd exp +ln -s pretrained-iter-1224000-avg-14.pt epoch-9999.pt +popd + +2. Export the model to ONNX + +./pruned_transducer_stateless3/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 9999 \ + --avg 1 \ + --exp-dir $repo/exp/ + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-9999-avg-1.onnx + - decoder-epoch-9999-avg-1.onnx + - joiner-epoch-9999-avg-1.onnx + +3. Run this file ./pruned_transducer_stateless3/onnx_pretrained.py \ - --encoder-model-filename ./pruned_transducer_stateless3/exp/encoder.onnx \ - --decoder-model-filename ./pruned_transducer_stateless3/exp/decoder.onnx \ - --joiner-model-filename ./pruned_transducer_stateless3/exp/joiner.onnx \ - --joiner-encoder-proj-model-filename ./pruned_transducer_stateless3/exp/joiner_encoder_proj.onnx \ - --joiner-decoder-proj-model-filename ./pruned_transducer_stateless3/exp/joiner_decoder_proj.onnx \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - /path/to/foo.wav \ - /path/to/bar.wav + --encoder-model-filename $repo/exp/encoder-epoch-9999-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-9999-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-9999-avg-1.onnx \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav """ import argparse import logging import math -from typing import List +from typing import List, Tuple +import k2 import kaldifeat import numpy as np import onnxruntime as ort -import sentencepiece as spm import torch import torchaudio from torch.nn.utils.rnn import pad_sequence @@ -79,23 +105,9 @@ def get_parser(): ) parser.add_argument( - "--joiner-encoder-proj-model-filename", + "--tokens", type=str, - required=True, - help="Path to the joiner encoder_proj onnx model. ", - ) - - parser.add_argument( - "--joiner-decoder-proj-model-filename", - type=str, - required=True, - help="Path to the joiner decoder_proj onnx model. ", - ) - - parser.add_argument( - "--bpe-model", - type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -115,16 +127,122 @@ def get_parser(): help="The sample rate of the input sound file", ) - parser.add_argument( - "--context-size", - type=int, - default=2, - help="Context size of the decoder model", - ) - return parser +class OnnxModel: + def __init__( + self, + encoder_model_filename: str, + decoder_model_filename: str, + joiner_model_filename: str, + ): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.session_opts = session_opts + + self.init_encoder(encoder_model_filename) + self.init_decoder(decoder_model_filename) + self.init_joiner(joiner_model_filename) + + def init_encoder(self, encoder_model_filename: str): + self.encoder = ort.InferenceSession( + encoder_model_filename, + sess_options=self.session_opts, + ) + + def init_decoder(self, decoder_model_filename: str): + self.decoder = ort.InferenceSession( + decoder_model_filename, + sess_options=self.session_opts, + ) + + decoder_meta = self.decoder.get_modelmeta().custom_metadata_map + self.context_size = int(decoder_meta["context_size"]) + self.vocab_size = int(decoder_meta["vocab_size"]) + + logging.info(f"context_size: {self.context_size}") + logging.info(f"vocab_size: {self.vocab_size}") + + def init_joiner(self, joiner_model_filename: str): + self.joiner = ort.InferenceSession( + joiner_model_filename, + sess_options=self.session_opts, + ) + + joiner_meta = self.joiner.get_modelmeta().custom_metadata_map + self.joiner_dim = int(joiner_meta["joiner_dim"]) + + logging.info(f"joiner_dim: {self.joiner_dim}") + + def run_encoder( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D tensor of shape (N, T, C) + x_lens: + A 2-D tensor of shape (N,). Its dtype is torch.int64 + Returns: + Return a tuple containing: + - encoder_out, its shape is (N, T', joiner_dim) + - encoder_out_lens, its shape is (N,) + """ + out = self.encoder.run( + [ + self.encoder.get_outputs()[0].name, + self.encoder.get_outputs()[1].name, + ], + { + self.encoder.get_inputs()[0].name: x.numpy(), + self.encoder.get_inputs()[1].name: x_lens.numpy(), + }, + ) + return torch.from_numpy(out[0]), torch.from_numpy(out[1]) + + def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor: + """ + Args: + decoder_input: + A 2-D tensor of shape (N, context_size) + Returns: + Return a 2-D tensor of shape (N, joiner_dim) + """ + out = self.decoder.run( + [self.decoder.get_outputs()[0].name], + {self.decoder.get_inputs()[0].name: decoder_input.numpy()}, + )[0] + + return torch.from_numpy(out) + + def run_joiner( + self, encoder_out: torch.Tensor, decoder_out: torch.Tensor + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + out = self.joiner.run( + [self.joiner.get_outputs()[0].name], + { + self.joiner.get_inputs()[0].name: encoder_out.numpy(), + self.joiner.get_inputs()[1].name: decoder_out.numpy(), + }, + )[0] + + return torch.from_numpy(out) + + def read_sound_files( filenames: List[str], expected_sample_rate: float ) -> List[torch.Tensor]: @@ -149,36 +267,22 @@ def read_sound_files( def greedy_search( - decoder: ort.InferenceSession, - joiner: ort.InferenceSession, - joiner_encoder_proj: ort.InferenceSession, - joiner_decoder_proj: ort.InferenceSession, - encoder_out: np.ndarray, - encoder_out_lens: np.ndarray, - context_size: int, + model: OnnxModel, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, ) -> List[List[int]]: """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. Args: - decoder: - The decoder model. - joiner: - The joiner model. - joiner_encoder_proj: - The joiner encoder projection model. - joiner_decoder_proj: - The joiner decoder projection model. + model: + The transducer model. encoder_out: - A 3-D tensor of shape (N, T, C) + A 3-D tensor of shape (N, T, joiner_dim) encoder_out_lens: A 1-D tensor of shape (N,). - context_size: - The context size of the decoder model. Returns: Return the decoded results for each utterance. """ - encoder_out = torch.from_numpy(encoder_out) - encoder_out_lens = torch.from_numpy(encoder_out_lens) - assert encoder_out.ndim == 3 + assert encoder_out.ndim == 3, encoder_out.shape assert encoder_out.size(0) >= 1, encoder_out.size(0) packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( @@ -188,11 +292,6 @@ def greedy_search( enforce_sorted=False, ) - projected_encoder_out = joiner_encoder_proj.run( - [joiner_encoder_proj.get_outputs()[0].name], - {joiner_encoder_proj.get_inputs()[0].name: packed_encoder_out.data.numpy()}, - )[0] - blank_id = 0 # hard-code to 0 batch_size_list = packed_encoder_out.batch_sizes.tolist() @@ -201,50 +300,27 @@ def greedy_search( assert torch.all(encoder_out_lens > 0), encoder_out_lens assert N == batch_size_list[0], (N, batch_size_list) + context_size = model.context_size hyps = [[blank_id] * context_size for _ in range(N)] - decoder_input_nodes = decoder.get_inputs() - decoder_output_nodes = decoder.get_outputs() - - joiner_input_nodes = joiner.get_inputs() - joiner_output_nodes = joiner.get_outputs() - decoder_input = torch.tensor( hyps, dtype=torch.int64, ) # (N, context_size) - decoder_out = decoder.run( - [decoder_output_nodes[0].name], - { - decoder_input_nodes[0].name: decoder_input.numpy(), - }, - )[0].squeeze(1) - projected_decoder_out = joiner_decoder_proj.run( - [joiner_decoder_proj.get_outputs()[0].name], - {joiner_decoder_proj.get_inputs()[0].name: decoder_out}, - )[0] - - projected_decoder_out = torch.from_numpy(projected_decoder_out) + decoder_out = model.run_decoder(decoder_input) offset = 0 for batch_size in batch_size_list: start = offset end = offset + batch_size - current_encoder_out = projected_encoder_out[start:end] - # current_encoder_out's shape: (batch_size, encoder_out_dim) + current_encoder_out = packed_encoder_out.data[start:end] + # current_encoder_out's shape: (batch_size, joiner_dim) offset = end - projected_decoder_out = projected_decoder_out[:batch_size] + decoder_out = decoder_out[:batch_size] + logits = model.run_joiner(current_encoder_out, decoder_out) - logits = joiner.run( - [joiner_output_nodes[0].name], - { - joiner_input_nodes[0].name: current_encoder_out, - joiner_input_nodes[1].name: projected_decoder_out.numpy(), - }, - )[0] - logits = torch.from_numpy(logits).squeeze(1).squeeze(1) # logits'shape (batch_size, vocab_size) assert logits.ndim == 2, logits.shape @@ -261,17 +337,7 @@ def greedy_search( decoder_input, dtype=torch.int64, ) - decoder_out = decoder.run( - [decoder_output_nodes[0].name], - { - decoder_input_nodes[0].name: decoder_input.numpy(), - }, - )[0].squeeze(1) - projected_decoder_out = joiner_decoder_proj.run( - [joiner_decoder_proj.get_outputs()[0].name], - {joiner_decoder_proj.get_inputs()[0].name: decoder_out}, - )[0] - projected_decoder_out = torch.from_numpy(projected_decoder_out) + decoder_out = model.run_decoder(decoder_input) sorted_ans = [h[context_size:] for h in hyps] ans = [] @@ -287,39 +353,12 @@ def main(): parser = get_parser() args = parser.parse_args() logging.info(vars(args)) - - session_opts = ort.SessionOptions() - session_opts.inter_op_num_threads = 1 - session_opts.intra_op_num_threads = 1 - - encoder = ort.InferenceSession( - args.encoder_model_filename, - sess_options=session_opts, + model = OnnxModel( + encoder_model_filename=args.encoder_model_filename, + decoder_model_filename=args.decoder_model_filename, + joiner_model_filename=args.joiner_model_filename, ) - decoder = ort.InferenceSession( - args.decoder_model_filename, - sess_options=session_opts, - ) - - joiner = ort.InferenceSession( - args.joiner_model_filename, - sess_options=session_opts, - ) - - joiner_encoder_proj = ort.InferenceSession( - args.joiner_encoder_proj_model_filename, - sess_options=session_opts, - ) - - joiner_decoder_proj = ort.InferenceSession( - args.joiner_decoder_proj_model_filename, - sess_options=session_opts, - ) - - sp = spm.SentencePieceProcessor() - sp.load(args.bpe_model) - logging.info("Constructing Fbank computer") opts = kaldifeat.FbankOptions() opts.device = "cpu" @@ -347,30 +386,27 @@ def main(): ) feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64) - - encoder_input_nodes = encoder.get_inputs() - encoder_out_nodes = encoder.get_outputs() - encoder_out, encoder_out_lens = encoder.run( - [encoder_out_nodes[0].name, encoder_out_nodes[1].name], - { - encoder_input_nodes[0].name: features.numpy(), - encoder_input_nodes[1].name: feature_lengths.numpy(), - }, - ) + encoder_out, encoder_out_lens = model.run_encoder(features, feature_lengths) hyps = greedy_search( - decoder=decoder, - joiner=joiner, - joiner_encoder_proj=joiner_encoder_proj, - joiner_decoder_proj=joiner_decoder_proj, + model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, - context_size=args.context_size, ) s = "\n" + + symbol_table = k2.SymbolTable.from_file(args.tokens) + + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += symbol_table[i] + return text.replace("▁", " ").strip() + + context_size = model.context_size for filename, hyp in zip(args.sound_files, hyps): - words = sp.decode(hyp) - s += f"{filename}:\n{words}\n\n" + words = token_ids_to_words(hyp[context_size:]) + s += f"{filename}:\n{words}\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py index a72472495..35d6b0556 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py @@ -146,7 +146,7 @@ class OnnxEncoder(nn.Module): """ Args: encoder: - A zipformer encoder. + A Zipformer encoder. encoder_proj: The projection layer for encoder from the joiner. """ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_check.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_check.py index 72ad59a55..6c78ba70b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_check.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_check.py @@ -76,14 +76,8 @@ from zipformer import stack_states from icefall import is_module_available -if not is_module_available("onnxruntime"): - raise ValueError("Please 'pip install onnxruntime' first.") - -import onnxruntime as ort import torch -ort.set_default_logger_severity(3) - def get_parser(): parser = argparse.ArgumentParser( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py index 265684d18..715560c70 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py @@ -333,7 +333,6 @@ class OnnxModel: self.joiner.get_inputs()[1].name: decoder_out.numpy(), }, )[0] - return torch.from_numpy(out) return torch.from_numpy(out) From ffbf6d919931fb1034fa28550b4221a705f2dff4 Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Tue, 7 Feb 2023 16:19:08 +0800 Subject: [PATCH 100/174] Add generate_averaged_model.py (#882) --- .../generate_averaged_model.py | 203 ++++++++++++++++++ 1 file changed, 203 insertions(+) create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7/generate_averaged_model.py diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/generate_averaged_model.py b/egs/librispeech/ASR/pruned_transducer_stateless7/generate_averaged_model.py new file mode 100755 index 000000000..381772ce7 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/generate_averaged_model.py @@ -0,0 +1,203 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) use the checkpoint exp_dir/epoch-xxx.pt +./pruned_transducer_stateless7/generate_averaged_model.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp + +It will generate a file `epoch-28-avg-15.pt` in the given `exp_dir`. +You can later load it by `torch.load("epoch-28-avg-15.pt")`. + +(2) use the checkpoint exp_dir/checkpoint-iter.pt +./pruned_transducer_stateless7/generate_averaged_model.py \ + --iter 22000 \ + --avg 5 \ + --exp-dir ./pruned_transducer_stateless7/exp + +It will generate a file `iter-22000-avg-5.pt` in the given `exp_dir`. +You can later load it by `torch.load("iter-22000-avg-5.pt")`. +""" + + +import argparse +from pathlib import Path +from typing import Dict, List + +import sentencepiece as spm +import torch +from asr_datamodule import LibriSpeechAsrDataModule + +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints_with_averaged_model, + find_checkpoints, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + print("Script started") + + device = torch.device("cpu") + print(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + print("About to create model") + model = get_transducer_model(params) + + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + print( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + filename = params.exp_dir / f"iter-{params.iter}-avg-{params.avg}.pt" + torch.save({"model": model.state_dict()}, filename) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + print( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt" + torch.save({"model": model.state_dict()}, filename) + + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + print("Done!") + + +if __name__ == "__main__": + main() From 7ae03f6c887cca5cdf88a6c84a51c66790832a88 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 7 Feb 2023 17:47:08 +0800 Subject: [PATCH 101/174] Add onnx export support for pruned_transducer_stateless5 (#883) --- .github/scripts/test-onnx-export.sh | 72 +++ .../export-onnx.py | 2 +- .../pruned_transducer_stateless5/conformer.py | 33 +- .../export-onnx.py | 565 ++++++++++++++++++ .../onnx_check.py | 1 + .../onnx_pretrained.py | 1 + 6 files changed, 663 insertions(+), 11 deletions(-) create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless5/onnx_check.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless5/onnx_pretrained.py diff --git a/.github/scripts/test-onnx-export.sh b/.github/scripts/test-onnx-export.sh index 13a5aa765..2d5cc6bbf 100755 --- a/.github/scripts/test-onnx-export.sh +++ b/.github/scripts/test-onnx-export.sh @@ -120,3 +120,75 @@ log "Run onnx_pretrained.py" $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav + +rm -rf $repo +log "--------------------------------------------------------------------------" + + +log "==========================================================================" +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless5-2022-05-13 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-epoch-39-avg-7.pt" + +cd exp +ln -s pretrained-epoch-39-avg-7.pt epoch-99.pt +popd + +log "Export via torch.jit.script()" + +./pruned_transducer_stateless5/export.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 99 \ + --avg 1 \ + --use-averaged-model 0 \ + --exp-dir $repo/exp \ + --num-encoder-layers 18 \ + --dim-feedforward 2048 \ + --nhead 8 \ + --encoder-dim 512 \ + --decoder-dim 512 \ + --joiner-dim 512 \ + --jit 1 + +log "Test exporting to ONNX format" + +./pruned_transducer_stateless5/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 99 \ + --avg 1 \ + --use-averaged-model 0 \ + --exp-dir $repo/exp \ + --num-encoder-layers 18 \ + --dim-feedforward 2048 \ + --nhead 8 \ + --encoder-dim 512 \ + --decoder-dim 512 \ + --joiner-dim 512 + +ls -lh $repo/exp + +log "Run onnx_check.py" + +./pruned_transducer_stateless5/onnx_check.py \ + --jit-filename $repo/exp/cpu_jit.pt \ + --onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx + +log "Run onnx_pretrained.py" + +./pruned_transducer_stateless5/onnx_pretrained.py \ + --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + +rm -rf $repo +log "--------------------------------------------------------------------------" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py index 1af68be70..ca8be307c 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py @@ -6,7 +6,7 @@ This script exports a transducer model from PyTorch to ONNX. We use the pre-trained model from -https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 +https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13 as an example to show how to use this file. 1. Download the pre-trained model diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py index b3a7d71bc..8bbceec61 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py @@ -32,7 +32,7 @@ from scaling import ( ) from torch import Tensor, nn -from icefall.utils import make_pad_mask, subsequent_chunk_mask +from icefall.utils import is_jit_tracing, make_pad_mask, subsequent_chunk_mask class Conformer(EncoderInterface): @@ -1012,15 +1012,28 @@ class RelPositionMultiheadAttention(nn.Module): n == left_context + 2 * time1 - 1 ), f"{n} == {left_context} + 2 * {time1} - 1" # Note: TorchScript requires explicit arg for stride() - batch_stride = x.stride(0) - head_stride = x.stride(1) - time1_stride = x.stride(2) - n_stride = x.stride(3) - return x.as_strided( - (batch_size, num_heads, time1, time2), - (batch_stride, head_stride, time1_stride - n_stride, n_stride), - storage_offset=n_stride * (time1 - 1), - ) + + if is_jit_tracing(): + rows = torch.arange(start=time1 - 1, end=-1, step=-1) + cols = torch.arange(time2) + rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) + indexes = rows + cols + + x = x.reshape(-1, n) + x = torch.gather(x, dim=1, index=indexes) + x = x.reshape(batch_size, num_heads, time1, time2) + return x + else: + # Note: TorchScript requires explicit arg for stride() + batch_stride = x.stride(0) + head_stride = x.stride(1) + time1_stride = x.stride(2) + n_stride = x.stride(3) + return x.as_strided( + (batch_size, num_heads, time1, time2), + (batch_stride, head_stride, time1_stride - n_stride, n_stride), + storage_offset=n_stride * (time1 - 1), + ) def multi_head_attention_forward( self, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx.py new file mode 100755 index 000000000..743fe8a92 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx.py @@ -0,0 +1,565 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang) + +""" +This script exports a transducer model from PyTorch to ONNX. + +We use the pre-trained model from +https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless5-2022-05-13 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless5-2022-05-13 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-epoch-39-avg-7.pt" + +cd exp +ln -s pretrained-epoch-39-avg-7.pt epoch-99.pt +popd + +2. Export the model to ONNX + +./pruned_transducer_stateless5/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 99 \ + --avg 1 \ + --use-averaged-model 0 \ + --exp-dir $repo/exp \ + --num-encoder-layers 18 \ + --dim-feedforward 2048 \ + --nhead 8 \ + --encoder-dim 512 \ + --decoder-dim 512 \ + --joiner-dim 512 + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +See ./onnx_pretrained.py and ./onnx_check.py for how to +use the exported ONNX models. +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, Tuple + +import onnx +import sentencepiece as spm +import torch +import torch.nn as nn +from conformer import Conformer +from decoder import Decoder +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless5/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +class OnnxEncoder(nn.Module): + """A wrapper for Conformer and the encoder_proj from the joiner""" + + def __init__(self, encoder: Conformer, encoder_proj: nn.Linear): + """ + Args: + encoder: + A Conformer encoder. + encoder_proj: + The projection layer for encoder from the joiner. + """ + super().__init__() + self.encoder = encoder + self.encoder_proj = encoder_proj + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Please see the help information of Conformer.forward + + Args: + x: + A 3-D tensor of shape (N, T, C) + x_lens: + A 1-D tensor of shape (N,). Its dtype is torch.int64 + Returns: + Return a tuple containing: + - encoder_out, A 3-D tensor of shape (N, T', joiner_dim) + - encoder_out_lens, A 1-D tensor of shape (N,) + """ + encoder_out, encoder_out_lens = self.encoder(x, x_lens) + + encoder_out = self.encoder_proj(encoder_out) + # Now encoder_out is of shape (N, T, joiner_dim) + + return encoder_out, encoder_out_lens + + +class OnnxDecoder(nn.Module): + """A wrapper for Decoder and the decoder_proj from the joiner""" + + def __init__(self, decoder: Decoder, decoder_proj: nn.Linear): + super().__init__() + self.decoder = decoder + self.decoder_proj = decoder_proj + + def forward(self, y: torch.Tensor) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, context_size). + Returns + Return a 2-D tensor of shape (N, joiner_dim) + """ + need_pad = False + decoder_output = self.decoder(y, need_pad=need_pad) + decoder_output = decoder_output.squeeze(1) + output = self.decoder_proj(decoder_output) + + return output + + +class OnnxJoiner(nn.Module): + """A wrapper for the joiner""" + + def __init__(self, output_linear: nn.Linear): + super().__init__() + self.output_linear = output_linear + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + logit = encoder_out + decoder_out + logit = self.output_linear(torch.tanh(logit)) + return logit + + +def export_encoder_model_onnx( + encoder_model: OnnxEncoder, + encoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the given encoder model to ONNX format. + The exported model has two inputs: + + - x, a tensor of shape (N, T, C); dtype is torch.float32 + - x_lens, a tensor of shape (N,); dtype is torch.int64 + + and it has two outputs: + + - encoder_out, a tensor of shape (N, T', joiner_dim) + - encoder_out_lens, a tensor of shape (N,) + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + x = torch.zeros(1, 100, 80, dtype=torch.float32) + x_lens = torch.tensor([100], dtype=torch.int64) + + torch.onnx.export( + encoder_model, + (x, x_lens), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "x_lens"], + output_names=["encoder_out", "encoder_out_lens"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "x_lens": {0: "N"}, + "encoder_out": {0: "N", 1: "T"}, + "encoder_out_lens": {0: "N"}, + }, + ) + + +def export_decoder_model_onnx( + decoder_model: OnnxDecoder, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, decoder_model.context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, joiner_dim) + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + context_size = decoder_model.decoder.context_size + vocab_size = decoder_model.decoder.vocab_size + + y = torch.zeros(10, context_size, dtype=torch.int64) + torch.onnx.export( + decoder_model, + y, + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + + meta_data = { + "context_size": str(context_size), + "vocab_size": str(vocab_size), + } + add_meta_data(filename=decoder_filename, meta_data=meta_data) + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. + The exported joiner model has two inputs: + + - encoder_out: a tensor of shape (N, joiner_dim) + - decoder_out: a tensor of shape (N, joiner_dim) + + and produces one output: + + - logit: a tensor of shape (N, vocab_size) + """ + joiner_dim = joiner_model.output_linear.weight.shape[1] + logging.info(f"joiner dim: {joiner_dim}") + + projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + + torch.onnx.export( + joiner_model, + (projected_encoder_out, projected_decoder_out), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "encoder_out", + "decoder_out", + ], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + meta_data = { + "joiner_dim": str(joiner_dim), + } + add_meta_data(filename=joiner_filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + setup_logger(f"{params.exp_dir}/log-export/log-export-onnx") + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True) + + encoder = OnnxEncoder( + encoder=model.encoder, + encoder_proj=model.joiner.encoder_proj, + ) + + decoder = OnnxDecoder( + decoder=model.decoder, + decoder_proj=model.joiner.decoder_proj, + ) + + joiner = OnnxJoiner(output_linear=model.joiner.output_linear) + + encoder_num_param = sum([p.numel() for p in encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + if params.iter > 0: + suffix = f"iter-{params.iter}" + else: + suffix = f"epoch-{params.epoch}" + + suffix += f"-avg-{params.avg}" + + opset_version = 13 + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" + export_encoder_model_onnx( + encoder, + encoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported encoder to {encoder_filename}") + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" + export_decoder_model_onnx( + decoder, + decoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported decoder to {decoder_filename}") + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" + export_joiner_model_onnx( + joiner, + joiner_filename, + opset_version=opset_version, + ) + logging.info(f"Exported joiner to {joiner_filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_check.py b/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_check.py new file mode 120000 index 000000000..66d63b807 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_check.py @@ -0,0 +1 @@ +../pruned_transducer_stateless3/onnx_check.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_pretrained.py new file mode 120000 index 000000000..7607623c8 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_pretrained.py @@ -0,0 +1 @@ +../pruned_transducer_stateless3/onnx_pretrained.py \ No newline at end of file From d12e6f098c79341d11f628dc8f2605e6365e5ecd Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Tue, 7 Feb 2023 21:43:16 +0800 Subject: [PATCH 102/174] Get (start, end) timestamps for CTC models (#876) * parse timestamps and texts for BPE-based models * parse timestamps (frame indexes) and texts for other cases * add test functions * add parse_fsa_timestamps_and_texts function, test in conformer_ctc3/decode.py * calculate symbol delay for (start, end) timestamps --- egs/librispeech/ASR/conformer_ctc3/decode.py | 63 ++-- icefall/utils.py | 367 ++++++++++++++++++- test/test_parse_timestamp.py | 154 ++++++++ 3 files changed, 545 insertions(+), 39 deletions(-) create mode 100755 test/test_parse_timestamp.py diff --git a/egs/librispeech/ASR/conformer_ctc3/decode.py b/egs/librispeech/ASR/conformer_ctc3/decode.py index 39186e546..3b24ad597 100755 --- a/egs/librispeech/ASR/conformer_ctc3/decode.py +++ b/egs/librispeech/ASR/conformer_ctc3/decode.py @@ -96,8 +96,7 @@ from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, get_texts, - get_texts_with_timestamp, - parse_hyp_and_timestamp, + parse_fsa_timestamps_and_texts, setup_logger, store_transcripts_and_timestamps, str2bool, @@ -396,13 +395,8 @@ def decode_one_batch( best_path = one_best_decoding( lattice=lattice, use_double_scores=params.use_double_scores ) - # Note: `best_path.aux_labels` contains token IDs, not word IDs - # since we are using H, not HLG here. - # - # token_ids is a lit-of-list of IDs - res = get_texts_with_timestamp(best_path) - hyps, timestamps = parse_hyp_and_timestamp( - res=res, + timestamps, hyps = parse_fsa_timestamps_and_texts( + best_paths=best_path, sp=bpe_model, subsampling_factor=params.subsampling_factor, frame_shift_ms=params.frame_shift_ms, @@ -435,12 +429,11 @@ def decode_one_batch( lattice=lattice, use_double_scores=params.use_double_scores ) key = f"no_rescore_hlg_scale_{params.hlg_scale}" - res = get_texts_with_timestamp(best_path) - hyps, timestamps = parse_hyp_and_timestamp( - res=res, + timestamps, hyps = parse_fsa_timestamps_and_texts( + best_paths=best_path, + word_table=word_table, subsampling_factor=params.subsampling_factor, frame_shift_ms=params.frame_shift_ms, - word_table=word_table, ) else: best_path = nbest_decoding( @@ -504,7 +497,18 @@ def decode_dataset( sos_id: int, eos_id: int, G: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str], List[float], List[float]]]]: +) -> Dict[ + str, + List[ + Tuple[ + str, + List[str], + List[str], + List[Tuple[float, float]], + List[Tuple[float, float]], + ] + ], +]: """Decode dataset. Args: @@ -555,7 +559,7 @@ def decode_dataset( time = [] if s.alignment is not None and "word" in s.alignment: time = [ - aliword.start + (aliword.start, aliword.end) for aliword in s.alignment["word"] if aliword.symbol != "" ] @@ -601,7 +605,15 @@ def save_results( test_set_name: str, results_dict: Dict[ str, - List[Tuple[List[str], List[str], List[str], List[float], List[float]]], + List[ + Tuple[ + List[str], + List[str], + List[str], + List[Tuple[float, float]], + List[Tuple[float, float]], + ] + ], ], ): test_set_wers = dict() @@ -621,7 +633,11 @@ def save_results( ) with open(errs_filename, "w") as f: wer, mean_delay, var_delay = write_error_stats_with_timestamps( - f, f"{test_set_name}-{key}", results, enable_log=True + f, + f"{test_set_name}-{key}", + results, + enable_log=True, + with_end_time=True, ) test_set_wers[key] = wer test_set_delays[key] = (mean_delay, var_delay) @@ -637,16 +653,17 @@ def save_results( for key, val in test_set_wers: print("{}\t{}".format(key, val), file=f) - test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0]) + # sort according to the mean start symbol delay + test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0][0]) delays_info = ( params.res_dir / f"symbol-delay-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(delays_info, "w") as f: - print("settings\tsymbol-delay", file=f) + print("settings\t(start, end) symbol-delay (s) (start, end)", file=f) for key, val in test_set_delays: print( - "{}\tmean: {}s, variance: {}".format(key, val[0], val[1]), + "{}\tmean: {}, variance: {}".format(key, val[0], val[1]), file=f, ) @@ -657,10 +674,12 @@ def save_results( note = "" logging.info(s) - s = "\nFor {}, symbol-delay of different settings are:\n".format(test_set_name) + s = "\nFor {}, (start, end) symbol-delay (s) of different settings are:\n".format( + test_set_name + ) note = "\tbest for {}".format(test_set_name) for key, val in test_set_delays: - s += "{}\tmean: {}s, variance: {}{}\n".format(key, val[0], val[1], note) + s += "{}\tmean: {}, variance: {}{}\n".format(key, val[0], val[1], note) note = "" logging.info(s) diff --git a/icefall/utils.py b/icefall/utils.py index ba0b7fe43..2358ed02f 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -1,5 +1,6 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang -# Mingshuang Luo) +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Mingshuang Luo, +# Zengwei Yao) # # See ../../LICENSE for clarification regarding multiple authors # @@ -453,11 +454,32 @@ def store_transcripts_and_timestamps( for cut_id, ref, hyp, time_ref, time_hyp in texts: print(f"{cut_id}:\tref={ref}", file=f) print(f"{cut_id}:\thyp={hyp}", file=f) + if len(time_ref) > 0: - s = "[" + ", ".join(["%0.3f" % i for i in time_ref]) + "]" + if isinstance(time_ref[0], tuple): + # each element is pair + s = ( + "[" + + ", ".join(["(%0.3f, %.03f)" % (i, j) for (i, j) in time_ref]) + + "]" + ) + else: + # each element is a float number + s = "[" + ", ".join(["%0.3f" % i for i in time_ref]) + "]" print(f"{cut_id}:\ttimestamp_ref={s}", file=f) - s = "[" + ", ".join(["%0.3f" % i for i in time_hyp]) + "]" - print(f"{cut_id}:\ttimestamp_hyp={s}", file=f) + + if len(time_hyp) > 0: + if isinstance(time_hyp[0], tuple): + # each element is pair + s = ( + "[" + + ", ".join(["(%0.3f, %.03f)" % (i, j) for (i, j) in time_hyp]) + + "]" + ) + else: + # each element is a float number + s = "[" + ", ".join(["%0.3f" % i for i in time_hyp]) + "]" + print(f"{cut_id}:\ttimestamp_hyp={s}", file=f) def write_error_stats( @@ -624,9 +646,18 @@ def write_error_stats( def write_error_stats_with_timestamps( f: TextIO, test_set_name: str, - results: List[Tuple[str, List[str], List[str], List[float], List[float]]], + results: List[ + Tuple[ + str, + List[str], + List[str], + List[Union[float, Tuple[float, float]]], + List[Union[float, Tuple[float, float]]], + ] + ], enable_log: bool = True, -) -> Tuple[float, float, float]: + with_end_time: bool = False, +) -> Tuple[float, Union[float, Tuple[float, float]], Union[float, Tuple[float, float]]]: """Write statistics based on predicted results and reference transcripts as well as their timestamps. @@ -659,6 +690,8 @@ def write_error_stats_with_timestamps( enable_log: If True, also print detailed WER to the console. Otherwise, it is written only to the given file. + with_end_time: + Whether use end timestamps. Returns: Return total word error rate and mean delay. @@ -704,7 +737,15 @@ def write_error_stats_with_timestamps( words[ref_word][0] += 1 num_corr += 1 if has_time: - all_delay.append(time_hyp[p_hyp] - time_ref[p_ref]) + if with_end_time: + all_delay.append( + ( + time_hyp[p_hyp][0] - time_ref[p_ref][0], + time_hyp[p_hyp][1] - time_ref[p_ref][1], + ) + ) + else: + all_delay.append(time_hyp[p_hyp] - time_ref[p_ref]) p_hyp += 1 p_ref += 1 if has_time: @@ -716,16 +757,39 @@ def write_error_stats_with_timestamps( ins_errs = sum(ins.values()) del_errs = sum(dels.values()) tot_errs = sub_errs + ins_errs + del_errs - tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len) + tot_err_rate = float("%.2f" % (100.0 * tot_errs / ref_len)) - mean_delay = "inf" - var_delay = "inf" + if with_end_time: + mean_delay = (float("inf"), float("inf")) + var_delay = (float("inf"), float("inf")) + else: + mean_delay = float("inf") + var_delay = float("inf") num_delay = len(all_delay) if num_delay > 0: - mean_delay = sum(all_delay) / num_delay - var_delay = sum([(i - mean_delay) ** 2 for i in all_delay]) / num_delay - mean_delay = "%.3f" % mean_delay - var_delay = "%.3f" % var_delay + if with_end_time: + all_delay_start = [i[0] for i in all_delay] + mean_delay_start = sum(all_delay_start) / num_delay + var_delay_start = ( + sum([(i - mean_delay_start) ** 2 for i in all_delay_start]) / num_delay + ) + + all_delay_end = [i[1] for i in all_delay] + mean_delay_end = sum(all_delay_end) / num_delay + var_delay_end = ( + sum([(i - mean_delay_end) ** 2 for i in all_delay_end]) / num_delay + ) + + mean_delay = ( + float("%.3f" % mean_delay_start), + float("%.3f" % mean_delay_end), + ) + var_delay = (float("%.3f" % var_delay_start), float("%.3f" % var_delay_end)) + else: + mean_delay = sum(all_delay) / num_delay + var_delay = sum([(i - mean_delay) ** 2 for i in all_delay]) / num_delay + mean_delay = float("%.3f" % mean_delay) + var_delay = float("%.3f" % var_delay) if enable_log: logging.info( @@ -734,7 +798,8 @@ def write_error_stats_with_timestamps( f"{del_errs} del, {sub_errs} sub ]" ) logging.info( - f"[{test_set_name}] %symbol-delay mean: {mean_delay}s, variance: {var_delay} " # noqa + f"[{test_set_name}] %symbol-delay mean (s): " + f"{mean_delay}, variance: {var_delay} " # noqa f"computed on {num_delay} correct words" ) @@ -817,7 +882,8 @@ def write_error_stats_with_timestamps( hyp_count = corr + hyp_sub + ins print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f) - return float(tot_err_rate), float(mean_delay), float(var_delay) + + return tot_err_rate, mean_delay, var_delay class MetricsTracker(collections.defaultdict): @@ -1431,3 +1497,270 @@ def filter_uneven_sized_batch(batch: dict, allowed_max_frames: int): batch["supervisions"][k] = v[:keep_num_utt] return batch + + +def parse_bpe_start_end_pairs( + tokens: List[str], is_first_token: List[bool] +) -> List[Tuple[int, int]]: + """Parse pairs of start and end frame indexes for each word. + + Args: + tokens: + List of BPE tokens. + is_first_token: + List of bool values, which indicates whether it is the first token, + i.e., not repeat or blank. + + Returns: + List of (start-frame-index, end-frame-index) pairs for each word. + """ + assert len(tokens) == len(is_first_token), (len(tokens), len(is_first_token)) + + start_token = b"\xe2\x96\x81".decode() # '_' + blank_token = "" + + non_blank_idx = [i for i in range(len(tokens)) if tokens[i] != blank_token] + num_non_blank = len(non_blank_idx) + + pairs = [] + start = -1 + end = -1 + for j in range(num_non_blank): + # The index in all frames + i = non_blank_idx[j] + + found_start = False + if is_first_token[i] and (j == 0 or tokens[i].startswith(start_token)): + found_start = True + if tokens[i] == start_token: + if j == num_non_blank - 1: + # It is the last non-blank token + found_start = False + elif is_first_token[non_blank_idx[j + 1]] and tokens[ + non_blank_idx[j + 1] + ].startswith(start_token): + # The next not-blank token is a first-token and also starts with start_token + found_start = False + if found_start: + start = i + + if start != -1: + found_end = False + if j == num_non_blank - 1: + # It is the last non-blank token + found_end = True + elif is_first_token[non_blank_idx[j + 1]] and tokens[ + non_blank_idx[j + 1] + ].startswith(start_token): + # The next not-blank token is a first-token and also starts with start_token + found_end = True + if found_end: + end = i + + if start != -1 and end != -1: + if not all([tokens[t] == start_token for t in range(start, end + 1)]): + # except the case of all start_token + pairs.append((start, end)) + # Reset start and end + start = -1 + end = -1 + + return pairs + + +def parse_bpe_timestamps_and_texts( + best_paths: k2.Fsa, sp: spm.SentencePieceProcessor +) -> Tuple[List[Tuple[int, int]], List[List[str]]]: + """Parse timestamps (frame indexes) and texts. + + Args: + best_paths: + A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e. + containing multiple FSAs, which is expected to be the result + of k2.shortest_path (otherwise the returned values won't + be meaningful). Its attribtutes `labels` and `aux_labels` + are both BPE tokens. + sp: + The BPE model. + + Returns: + utt_index_pairs: + A list of pair list. utt_index_pairs[i] is a list of + (start-frame-index, end-frame-index) pairs for each word in + utterance-i. + utt_words: + A list of str list. utt_words[i] is a word list of utterence-i. + """ + shape = best_paths.arcs.shape().remove_axis(1) + + # labels: [utt][arcs] + labels = k2.RaggedTensor(shape, best_paths.labels.contiguous()) + # remove -1's. + labels = labels.remove_values_eq(-1) + labels = labels.tolist() + + # aux_labels: [utt][arcs] + aux_labels = k2.RaggedTensor(shape, best_paths.aux_labels.contiguous()) + + # remove -1's. + all_aux_labels = aux_labels.remove_values_eq(-1) + # len(all_aux_labels[i]) is equal to the number of frames + all_aux_labels = all_aux_labels.tolist() + + # remove 0's and -1's. + out_aux_labels = aux_labels.remove_values_leq(0) + # len(out_aux_labels[i]) is equal to the number of output BPE tokens + out_aux_labels = out_aux_labels.tolist() + + utt_index_pairs = [] + utt_words = [] + for i in range(len(labels)): + tokens = sp.id_to_piece(labels[i]) + words = sp.decode(out_aux_labels[i]).split() + + # Indicates whether it is the first token, i.e., not-repeat and not-blank. + is_first_token = [a != 0 for a in all_aux_labels[i]] + index_pairs = parse_bpe_start_end_pairs(tokens, is_first_token) + assert len(index_pairs) == len(words), (len(index_pairs), len(words), tokens) + utt_index_pairs.append(index_pairs) + utt_words.append(words) + + return utt_index_pairs, utt_words + + +def parse_timestamps_and_texts( + best_paths: k2.Fsa, word_table: k2.SymbolTable +) -> Tuple[List[Tuple[int, int]], List[List[str]]]: + """Parse timestamps (frame indexes) and texts. + + Args: + best_paths: + A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e. + containing multiple FSAs, which is expected to be the result + of k2.shortest_path (otherwise the returned values won't + be meaningful). Attribtute `labels` is the prediction unit, + e.g., phone or BPE tokens. Attribute `aux_labels` is the word index. + word_table: + The word symbol table. + + Returns: + utt_index_pairs: + A list of pair list. utt_index_pairs[i] is a list of + (start-frame-index, end-frame-index) pairs for each word in + utterance-i. + utt_words: + A list of str list. utt_words[i] is a word list of utterence-i. + """ + # [utt][words] + word_ids = get_texts(best_paths) + + shape = best_paths.arcs.shape().remove_axis(1) + + # labels: [utt][arcs] + labels = k2.RaggedTensor(shape, best_paths.labels.contiguous()) + # remove -1's. + labels = labels.remove_values_eq(-1) + labels = labels.tolist() + + # aux_labels: [utt][arcs] + aux_shape = shape.compose(best_paths.aux_labels.shape) + aux_labels = k2.RaggedTensor(aux_shape, best_paths.aux_labels.values.contiguous()) + aux_labels = aux_labels.tolist() + + utt_index_pairs = [] + utt_words = [] + for i, (label, aux_label) in enumerate(zip(labels, aux_labels)): + num_arcs = len(label) + # The last arc of aux_label is the arc entering the final state + assert num_arcs == len(aux_label) - 1, (num_arcs, len(aux_label)) + + index_pairs = [] + start = -1 + end = -1 + for arc in range(num_arcs): + # len(aux_label[arc]) is 0 or 1 + if label[arc] != 0 and len(aux_label[arc]) != 0: + if start != -1 and end != -1: + index_pairs.append((start, end)) + start = arc + if label[arc] != 0: + end = arc + if start != -1 and end != -1: + index_pairs.append((start, end)) + + words = [word_table[w] for w in word_ids[i]] + assert len(index_pairs) == len(words), (len(index_pairs), len(words)) + + utt_index_pairs.append(index_pairs) + utt_words.append(words) + + return utt_index_pairs, utt_words + + +def parse_fsa_timestamps_and_texts( + best_paths: k2.Fsa, + sp: Optional[spm.SentencePieceProcessor] = None, + word_table: Optional[k2.SymbolTable] = None, + subsampling_factor: int = 4, + frame_shift_ms: float = 10, +) -> Tuple[List[Tuple[float, float]], List[List[str]]]: + """Parse timestamps (in seconds) and texts for given decoded fsa paths. + Currently it supports two cases: + (1) ctc-decoding, the attribtutes `labels` and `aux_labels` + are both BPE tokens. In this case, sp should be provided. + (2) HLG-based 1best, the attribtute `labels` is the prediction unit, + e.g., phone or BPE tokens; attribute `aux_labels` is the word index. + In this case, word_table should be provided. + + Args: + best_paths: + A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e. + containing multiple FSAs, which is expected to be the result + of k2.shortest_path (otherwise the returned values won't + be meaningful). + sp: + The BPE model. + word_table: + The word symbol table. + subsampling_factor: + The subsampling factor of the model. + frame_shift_ms: + Frame shift in milliseconds between two contiguous frames. + + Returns: + utt_time_pairs: + A list of pair list. utt_time_pairs[i] is a list of + (start-time, end-time) pairs for each word in + utterance-i. + utt_words: + A list of str list. utt_words[i] is a word list of utterence-i. + """ + if sp is not None: + assert word_table is None, "word_table is not needed if sp is provided." + utt_index_pairs, utt_words = parse_bpe_timestamps_and_texts( + best_paths=best_paths, sp=sp + ) + elif word_table is not None: + assert sp is None, "sp is not needed if word_table is provided." + utt_index_pairs, utt_words = parse_timestamps_and_texts( + best_paths=best_paths, word_table=word_table + ) + else: + raise ValueError("Either sp or word_table should be provided.") + + utt_time_pairs = [] + for utt in utt_index_pairs: + start = convert_timestamp( + frames=[i[0] for i in utt], + subsampling_factor=subsampling_factor, + frame_shift_ms=frame_shift_ms, + ) + end = convert_timestamp( + # The duration in frames is (end_frame_index - start_frame_index + 1) + frames=[i[1] + 1 for i in utt], + subsampling_factor=subsampling_factor, + frame_shift_ms=frame_shift_ms, + ) + utt_time_pairs.append(list(zip(start, end))) + + return utt_time_pairs, utt_words diff --git a/test/test_parse_timestamp.py b/test/test_parse_timestamp.py new file mode 100755 index 000000000..92bfb49c6 --- /dev/null +++ b/test/test_parse_timestamp.py @@ -0,0 +1,154 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) +# +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from pathlib import Path + +import k2 +import sentencepiece as spm +import torch + +from icefall.lexicon import Lexicon +from icefall.utils import parse_bpe_timestamps_and_texts, parse_timestamps_and_texts + +ICEFALL_DIR = Path(__file__).resolve().parent.parent + + +def test_parse_bpe_timestamps_and_texts(): + lang_dir = ICEFALL_DIR / "egs/librispeech/ASR/data/lang_bpe_500" + if not lang_dir.is_dir(): + print(f"{lang_dir} does not exist.") + return + + sp = spm.SentencePieceProcessor() + sp.load(str(lang_dir / "bpe.model")) + + text_1 = "HELLO WORLD" + token_ids_1 = sp.encode(text_1, out_type=int) + # out_type=str: ['_HE', 'LL', 'O', '_WORLD'] + # out_type=int: [22, 58, 24, 425] + + # [22, 22, 58, 24, 0, 0, 425, 425, 425, 0, 0] + labels_1 = ( + token_ids_1[0:1] * 2 + + token_ids_1[1:3] + + [0] * 2 + + token_ids_1[3:4] * 3 + + [0] * 2 + ) + # [22, 0, 58, 24, 0, 0, 425, 0, 0, 0, 0, -1] + aux_labels_1 = ( + token_ids_1[0:1] + + [0] + + token_ids_1[1:3] + + [0] * 2 + + token_ids_1[3:4] + + [0] * 4 + + [-1] + ) + fsa_1 = k2.linear_fsa(labels_1) + fsa_1.aux_labels = torch.tensor(aux_labels_1).to(torch.int32) + + text_2 = "SAY GOODBYE" + token_ids_2 = sp.encode(text_2, out_type=int) + # out_type=str: ['_SAY', '_GOOD', 'B', 'Y', 'E'] + # out_type=int: [289, 286, 41, 16, 11] + + # [289, 0, 0, 286, 286, 41, 16, 11, 0, 0] + labels_2 = ( + token_ids_2[0:1] + [0] * 2 + token_ids_2[1:2] * 2 + token_ids_2[2:5] + [0] * 2 + ) + # [289, 0, 0, 286, 0, 41, 16, 11, 0, 0, -1] + aux_labels_2 = ( + token_ids_2[0:1] + + [0] * 2 + + token_ids_2[1:2] + + [0] + + token_ids_2[2:5] + + [0] * 2 + + [-1] + ) + fsa_2 = k2.linear_fsa(labels_2) + fsa_2.aux_labels = torch.tensor(aux_labels_2).to(torch.int32) + + fsa_vec = k2.create_fsa_vec([fsa_1, fsa_2]) + + utt_index_pairs, utt_words = parse_bpe_timestamps_and_texts(fsa_vec, sp) + assert utt_index_pairs[0] == [(0, 3), (6, 8)], utt_index_pairs[0] + assert utt_words[0] == ["HELLO", "WORLD"], utt_words[0] + assert utt_index_pairs[1] == [(0, 0), (3, 7)], utt_index_pairs[1] + assert utt_words[1] == ["SAY", "GOODBYE"], utt_words[1] + + +def test_parse_timestamps_and_texts(): + lang_dir = ICEFALL_DIR / "egs/librispeech/ASR/data/lang_bpe_500" + if not lang_dir.is_dir(): + print(f"{lang_dir} does not exist.") + return + + lexicon = Lexicon(lang_dir) + + sp = spm.SentencePieceProcessor() + sp.load(str(lang_dir / "bpe.model")) + word_table = lexicon.word_table + + text_1 = "HELLO WORLD" + token_ids_1 = sp.encode(text_1, out_type=int) + # out_type=str: ['_HE', 'LL', 'O', '_WORLD'] + # out_type=int: [22, 58, 24, 425] + word_ids_1 = [word_table[s] for s in text_1.split()] # [79677, 196937] + # [22, 22, 58, 24, 0, 0, 425, 425, 425, 0, 0] + labels_1 = ( + token_ids_1[0:1] * 2 + + token_ids_1[1:3] + + [0] * 2 + + token_ids_1[3:4] * 3 + + [0] * 2 + ) + # [[79677], [], [], [], [], [], [196937], [], [], [], [], []] + aux_labels_1 = [word_ids_1[0:1]] + [[]] * 5 + [word_ids_1[1:2]] + [[]] * 5 + + fsa_1 = k2.linear_fsa(labels_1) + fsa_1.aux_labels = k2.RaggedTensor(aux_labels_1) + + text_2 = "SAY GOODBYE" + token_ids_2 = sp.encode(text_2, out_type=int) + # out_type=str: ['_SAY', '_GOOD', 'B', 'Y', 'E'] + # out_type=int: [289, 286, 41, 16, 11] + word_ids_2 = [word_table[s] for s in text_2.split()] # [154967, 72079] + # [289, 0, 0, 286, 286, 41, 16, 11, 0, 0] + labels_2 = ( + token_ids_2[0:1] + [0] * 2 + token_ids_2[1:2] * 2 + token_ids_2[2:5] + [0] * 2 + ) + # [[154967], [], [], [72079], [], [], [], [], [], [], []] + aux_labels_2 = [word_ids_2[0:1]] + [[]] * 2 + [word_ids_2[1:2]] + [[]] * 7 + + fsa_2 = k2.linear_fsa(labels_2) + fsa_2.aux_labels = k2.RaggedTensor(aux_labels_2) + + fsa_vec = k2.create_fsa_vec([fsa_1, fsa_2]) + + utt_index_pairs, utt_words = parse_timestamps_and_texts(fsa_vec, word_table) + assert utt_index_pairs[0] == [(0, 3), (6, 8)], utt_index_pairs[0] + assert utt_words[0] == ["HELLO", "WORLD"], utt_words[0] + assert utt_index_pairs[1] == [(0, 0), (3, 7)], utt_index_pairs[1] + assert utt_words[1] == ["SAY", "GOODBYE"], utt_words[1] + + +if __name__ == "__main__": + test_parse_bpe_timestamps_and_texts() + test_parse_timestamps_and_texts() From af735eb75bf55e6b1e41602105a6f939aedbaf5c Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Wed, 8 Feb 2023 21:54:35 +0800 Subject: [PATCH 103/174] Get alignments using lhotse workflows align-with-torchaudio (#888) * add lhotse workflow align-with-torchaudio * modify related decode.py files --- egs/librispeech/ASR/add_alignments.sh | 50 +++++++++++++++++-- egs/librispeech/ASR/conformer_ctc3/decode.py | 5 +- .../ASR/lstm_transducer_stateless3/decode.py | 5 +- .../pruned_transducer_stateless4/decode.py | 5 +- 4 files changed, 48 insertions(+), 17 deletions(-) diff --git a/egs/librispeech/ASR/add_alignments.sh b/egs/librispeech/ASR/add_alignments.sh index 5e4480bf6..6c47d25a2 100755 --- a/egs/librispeech/ASR/add_alignments.sh +++ b/egs/librispeech/ASR/add_alignments.sh @@ -2,11 +2,51 @@ set -eou pipefail -alignments_dir=data/alignment +# align could be in ("mfa", "torchaudio") +# We recommend "torchaudio" +align="torchaudio" + +# It adds alignments to the existing fbank features dir (e.g., data/fbank) +# and save cuts to a new dir (e.g., data/fbank_ali). cuts_in_dir=data/fbank cuts_out_dir=data/fbank_ali -python3 ./local/add_alignment_librispeech.py \ - --alignments-dir $alignments_dir \ - --cuts-in-dir $cuts_in_dir \ - --cuts-out-dir $cuts_out_dir +if [ $align == "mfa" ]; then + # It add alignments from https://github.com/CorentinJ/librispeech-alignments, + # generated using the Montreal Forced Aligner (https://montreal-forced-aligner.readthedocs.io). + alignments_dir=data/alignment + + python3 ./local/add_alignment_librispeech.py \ + --alignments-dir $alignments_dir \ + --cuts-in-dir $cuts_in_dir \ + --cuts-out-dir $cuts_out_dir +elif [ $align == "torchaudio" ]; then + # See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/bin/modes/workflows.py for details. + # + # It use a pretrained ASR model from torchaudio to generate alignments. + # It will attach word-level alignment information (start, end, and score) to the + # supervisions in each cut. + mkdir -p $cuts_out_dir + + parts=( + train-clean-100 + train-clean-360 + train-other-500 + test-clean + test-other + dev-clean + dev-other + ) + + echo "The alignments will be saved to $cuts_out_dir" + for part in ${parts[@]}; do + echo "Start to align $part" + lhotse workflows align-with-torchaudio --dont-normalize-text \ + $cuts_in_dir/librispeech_cuts_${part}.jsonl.gz \ + $cuts_out_dir/librispeech_cuts_${part}.jsonl.gz + done + echo "Finished" +else + echo "align is expected to be in ('mfa', 'torchaudio'), but got $align" + exit 1 +fi diff --git a/egs/librispeech/ASR/conformer_ctc3/decode.py b/egs/librispeech/ASR/conformer_ctc3/decode.py index 3b24ad597..2300fecc3 100755 --- a/egs/librispeech/ASR/conformer_ctc3/decode.py +++ b/egs/librispeech/ASR/conformer_ctc3/decode.py @@ -40,10 +40,7 @@ Usage: To evaluate symbol delay, you should: (1) Generate cuts with word-time alignments: -./local/add_alignment_librispeech.py \ - --alignments-dir data/alignment \ - --cuts-in-dir data/fbank \ - --cuts-out-dir data/fbank_ali +./add_alignments.sh (2) Set the argument "--manifest-dir data/fbank_ali" while decoding. For example: ./conformer_ctc3/decode.py \ diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py index b7953e5e3..832b99433 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py @@ -94,10 +94,7 @@ Usage: To evaluate symbol delay, you should: (1) Generate cuts with word-time alignments: -./local/add_alignment_librispeech.py \ - --alignments-dir data/alignment \ - --cuts-in-dir data/fbank \ - --cuts-out-dir data/fbank_ali +./add_alignments.sh (2) Set the argument "--manifest-dir data/fbank_ali" while decoding. For example: ./lstm_transducer_stateless3/decode.py \ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py index f5cbc21f7..5fa129a89 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py @@ -109,10 +109,7 @@ Usage: To evaluate symbol delay, you should: (1) Generate cuts with word-time alignments: -./local/add_alignment_librispeech.py \ - --alignments-dir data/alignment \ - --cuts-in-dir data/fbank \ - --cuts-out-dir data/fbank_ali +./add_alignments.sh (2) Set the argument "--manifest-dir data/fbank_ali" while decoding. For example: ./pruned_transducer_stateless4/decode.py \ From 2b995639b7120fcda061978008ed3bc0855fef3a Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 9 Feb 2023 00:02:38 +0800 Subject: [PATCH 104/174] Add ONNX support for Zipformer and ConvEmformer (#884) --- ...former-transducer-stateless2-2022-12-05.sh | 79 --- ...h-lstm-transducer-stateless2-2022-09-03.sh | 98 --- ...pruned-transducer-stateless7-2022-11-11.sh | 30 - ...nsducer-stateless7-streaming-2022-12-29.sh | 10 - .github/scripts/test-ncnn-export.sh | 133 ++++ .github/scripts/test-onnx-export.sh | 157 +++++ .../run-librispeech-2022-11-11-stateless7.yml | 2 +- ...speech-2022-12-29-stateless7-streaming.yml | 2 +- ...-lstm-transducer-stateless2-2022-09-03.yml | 2 +- ...s2-2022-12-05.yml => test-ncnn-export.yml} | 20 +- docs/source/model-export/export-onnx.rst | 109 +-- .../lstm_pruned_stateless_transducer.rst | 12 +- .../emformer2.py | 129 ++-- .../export-for-ncnn.py | 17 +- .../export-onnx.py | 644 ++++++++++++++++++ .../export.py | 2 + .../onnx_pretrained.py | 456 +++++++++++++ .../train2.py | 2 + .../lstm_transducer_stateless/export-onnx.py | 1 + .../lstm_transducer_stateless/onnx_check.py | 1 + .../onnx_pretrained.py | 1 + .../export-for-ncnn.py | 337 +++++++++ .../lstm_transducer_stateless2/export-onnx.py | 593 ++++++++++++++++ .../ASR/lstm_transducer_stateless2/export.py | 298 +------- .../lstm_transducer_stateless2/ncnn-decode.py | 25 +- .../lstm_transducer_stateless2/onnx_check.py | 261 +++++++ .../onnx_pretrained.py | 428 ++++++++++++ .../streaming-ncnn-decode.py | 22 +- .../lstm_transducer_stateless3/export-onnx.py | 1 + .../lstm_transducer_stateless3/onnx_check.py | 1 + .../onnx_pretrained.py | 1 + .../onnx_check.py | 6 +- .../export-onnx.py | 560 +++++++++++++++ .../pruned_transducer_stateless7/export.py | 267 +------- .../onnx_check.py | 287 +------- .../onnx_pretrained.py | 389 +---------- .../pruned_transducer_stateless7/zipformer.py | 12 +- .../export-onnx.py | 8 +- .../onnx_pretrained.py | 9 +- 39 files changed, 3806 insertions(+), 1606 deletions(-) delete mode 100755 .github/scripts/run-librispeech-conv-emformer-transducer-stateless2-2022-12-05.sh create mode 100755 .github/scripts/test-ncnn-export.sh rename .github/workflows/{run-librispeech-conv-emformer-transducer-stateless2-2022-12-05.yml => test-ncnn-export.yml} (75%) create mode 100755 egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py create mode 100755 egs/librispeech/ASR/conv_emformer_transducer_stateless2/onnx_pretrained.py create mode 120000 egs/librispeech/ASR/lstm_transducer_stateless/export-onnx.py create mode 120000 egs/librispeech/ASR/lstm_transducer_stateless/onnx_check.py create mode 120000 egs/librispeech/ASR/lstm_transducer_stateless/onnx_pretrained.py create mode 100755 egs/librispeech/ASR/lstm_transducer_stateless2/export-for-ncnn.py create mode 100755 egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx.py create mode 100755 egs/librispeech/ASR/lstm_transducer_stateless2/onnx_check.py create mode 100755 egs/librispeech/ASR/lstm_transducer_stateless2/onnx_pretrained.py create mode 120000 egs/librispeech/ASR/lstm_transducer_stateless3/export-onnx.py create mode 120000 egs/librispeech/ASR/lstm_transducer_stateless3/onnx_check.py create mode 120000 egs/librispeech/ASR/lstm_transducer_stateless3/onnx_pretrained.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7/export-onnx.py mode change 100755 => 120000 egs/librispeech/ASR/pruned_transducer_stateless7/onnx_check.py mode change 100755 => 120000 egs/librispeech/ASR/pruned_transducer_stateless7/onnx_pretrained.py diff --git a/.github/scripts/run-librispeech-conv-emformer-transducer-stateless2-2022-12-05.sh b/.github/scripts/run-librispeech-conv-emformer-transducer-stateless2-2022-12-05.sh deleted file mode 100755 index 32c939206..000000000 --- a/.github/scripts/run-librispeech-conv-emformer-transducer-stateless2-2022-12-05.sh +++ /dev/null @@ -1,79 +0,0 @@ -#!/usr/bin/env bash -# -set -e - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05 - -log "Downloading pre-trained model from $repo_url" -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) -pushd $repo -git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt" -git lfs pull --include "data/lang_bpe_500/bpe.model" -cd exp -ln -s pretrained-epoch-30-avg-10-averaged.pt epoch-99.pt -popd - -log "Display test files" -tree $repo/ -soxi $repo/test_wavs/*.wav -ls -lh $repo/test_wavs/*.wav - -log "Install ncnn and pnnx" - -# We are using a modified ncnn here. Will try to merge it to the official repo -# of ncnn -git clone https://github.com/csukuangfj/ncnn -pushd ncnn -git submodule init -git submodule update python/pybind11 -python3 setup.py bdist_wheel -ls -lh dist/ -pip install dist/*.whl -cd tools/pnnx -mkdir build -cd build -cmake -D Python3_EXECUTABLE=/opt/hostedtoolcache/Python/3.8.14/x64/bin/python3 .. -make -j4 pnnx - -./src/pnnx || echo "pass" - -popd - -log "Test exporting to pnnx format" - -./conv_emformer_transducer_stateless2/export-for-ncnn.py \ - --exp-dir $repo/exp \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --epoch 99 \ - --avg 1 \ - --use-averaged-model 0 \ - \ - --num-encoder-layers 12 \ - --chunk-length 32 \ - --cnn-module-kernel 31 \ - --left-context-length 32 \ - --right-context-length 8 \ - --memory-size 32 - -./ncnn/tools/pnnx/build/src/pnnx $repo/exp/encoder_jit_trace-pnnx.pt -./ncnn/tools/pnnx/build/src/pnnx $repo/exp/decoder_jit_trace-pnnx.pt -./ncnn/tools/pnnx/build/src/pnnx $repo/exp/joiner_jit_trace-pnnx.pt - -./conv_emformer_transducer_stateless2/streaming-ncnn-decode.py \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \ - --encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \ - --decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \ - --decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \ - --joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \ - --joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \ - $repo/test_wavs/1089-134686-0001.wav diff --git a/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh b/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh index 9b883f889..91cdea01a 100755 --- a/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh +++ b/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh @@ -28,63 +28,6 @@ ln -s pretrained-iter-468000-avg-16.pt pretrained.pt ln -s pretrained-iter-468000-avg-16.pt epoch-99.pt popd -log "Install ncnn and pnnx" - -# We are using a modified ncnn here. Will try to merge it to the official repo -# of ncnn -git clone https://github.com/csukuangfj/ncnn -pushd ncnn -git submodule init -git submodule update python/pybind11 -python3 setup.py bdist_wheel -ls -lh dist/ -pip install dist/*.whl -cd tools/pnnx -mkdir build -cd build -cmake .. -make -j4 pnnx - -./src/pnnx || echo "pass" - -popd - -log "Test exporting to pnnx format" - -./lstm_transducer_stateless2/export.py \ - --exp-dir $repo/exp \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --epoch 99 \ - --avg 1 \ - --use-averaged-model 0 \ - --pnnx 1 - -./ncnn/tools/pnnx/build/src/pnnx $repo/exp/encoder_jit_trace-pnnx.pt -./ncnn/tools/pnnx/build/src/pnnx $repo/exp/decoder_jit_trace-pnnx.pt -./ncnn/tools/pnnx/build/src/pnnx $repo/exp/joiner_jit_trace-pnnx.pt - -./lstm_transducer_stateless2/ncnn-decode.py \ - --bpe-model-filename $repo/data/lang_bpe_500/bpe.model \ - --encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \ - --encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \ - --decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \ - --decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \ - --joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \ - --joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \ - $repo/test_wavs/1089-134686-0001.wav - -./lstm_transducer_stateless2/streaming-ncnn-decode.py \ - --bpe-model-filename $repo/data/lang_bpe_500/bpe.model \ - --encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \ - --encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \ - --decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \ - --decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \ - --joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \ - --joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \ - $repo/test_wavs/1089-134686-0001.wav - - - log "Test exporting with torch.jit.trace()" ./lstm_transducer_stateless2/export.py \ @@ -106,47 +49,6 @@ log "Decode with models exported by torch.jit.trace()" $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav -log "Test exporting to ONNX" - -./lstm_transducer_stateless2/export.py \ - --exp-dir $repo/exp \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --epoch 99 \ - --avg 1 \ - --use-averaged-model 0 \ - --onnx 1 - -log "Decode with ONNX models " - -./lstm_transducer_stateless2/streaming-onnx-decode.py \ - --bpe-model-filename $repo/data/lang_bpe_500/bpe.model \ - --encoder-model-filename $repo//exp/encoder.onnx \ - --decoder-model-filename $repo/exp/decoder.onnx \ - --joiner-model-filename $repo/exp/joiner.onnx \ - --joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \ - --joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \ - $repo/test_wavs/1089-134686-0001.wav - -./lstm_transducer_stateless2/streaming-onnx-decode.py \ - --bpe-model-filename $repo/data/lang_bpe_500/bpe.model \ - --encoder-model-filename $repo//exp/encoder.onnx \ - --decoder-model-filename $repo/exp/decoder.onnx \ - --joiner-model-filename $repo/exp/joiner.onnx \ - --joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \ - --joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \ - $repo/test_wavs/1221-135766-0001.wav - -./lstm_transducer_stateless2/streaming-onnx-decode.py \ - --bpe-model-filename $repo/data/lang_bpe_500/bpe.model \ - --encoder-model-filename $repo//exp/encoder.onnx \ - --decoder-model-filename $repo/exp/decoder.onnx \ - --joiner-model-filename $repo/exp/joiner.onnx \ - --joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \ - --joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \ - $repo/test_wavs/1221-135766-0002.wav - - - for sym in 1 2 3; do log "Greedy search with --max-sym-per-frame $sym" diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh index 999841b80..8e485d2e6 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh @@ -30,15 +30,6 @@ ln -s pretrained.pt epoch-99.pt ls -lh *.pt popd -log "Test exporting to ONNX format" -./pruned_transducer_stateless7/export.py \ - --exp-dir $repo/exp \ - --use-averaged-model false \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --epoch 99 \ - --avg 1 \ - --onnx 1 - log "Export to torchscript model" ./pruned_transducer_stateless7/export.py \ --exp-dir $repo/exp \ @@ -50,27 +41,6 @@ log "Export to torchscript model" ls -lh $repo/exp/*.pt -log "Decode with ONNX models" - -./pruned_transducer_stateless7/onnx_check.py \ - --jit-filename $repo/exp/cpu_jit.pt \ - --onnx-encoder-filename $repo/exp/encoder.onnx \ - --onnx-decoder-filename $repo/exp/decoder.onnx \ - --onnx-joiner-filename $repo/exp/joiner.onnx \ - --onnx-joiner-encoder-proj-filename $repo/exp/joiner_encoder_proj.onnx \ - --onnx-joiner-decoder-proj-filename $repo/exp/joiner_decoder_proj.onnx - -./pruned_transducer_stateless7/onnx_pretrained.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --encoder-model-filename $repo/exp/encoder.onnx \ - --decoder-model-filename $repo/exp/decoder.onnx \ - --joiner-model-filename $repo/exp/joiner.onnx \ - --joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \ - --joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav - log "Decode with models exported by torch.jit.script()" ./pruned_transducer_stateless7/jit_pretrained.py \ diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh index 2611c8efc..584f5d488 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh @@ -34,16 +34,6 @@ ln -s pretrained.pt epoch-99.pt ls -lh *.pt popd -log "Test exporting to ONNX format" -./pruned_transducer_stateless7_streaming/export.py \ - --exp-dir $repo/exp \ - --use-averaged-model false \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --epoch 99 \ - --avg 1 \ - --fp16 \ - --onnx 1 - log "Export to torchscript model" ./pruned_transducer_stateless7_streaming/export.py \ --exp-dir $repo/exp \ diff --git a/.github/scripts/test-ncnn-export.sh b/.github/scripts/test-ncnn-export.sh new file mode 100755 index 000000000..c6d70ae7a --- /dev/null +++ b/.github/scripts/test-ncnn-export.sh @@ -0,0 +1,133 @@ +#!/usr/bin/env bash + +set -e + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/librispeech/ASR + +log "Install ncnn and pnnx" + +# We are using a modified ncnn here. Will try to merge it to the official repo +# of ncnn +git clone https://github.com/csukuangfj/ncnn +pushd ncnn +git submodule init +git submodule update python/pybind11 +python3 setup.py bdist_wheel +ls -lh dist/ +pip install dist/*.whl +cd tools/pnnx +mkdir build +cd build + +echo "which python3" + +which python3 +#/opt/hostedtoolcache/Python/3.8.16/x64/bin/python3 + +cmake -D Python3_EXECUTABLE=$(which python3) .. +make -j4 pnnx + +./src/pnnx || echo "pass" + +popd + +log "==========================================================================" +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt" + +cd exp +ln -s pretrained-epoch-30-avg-10-averaged.pt epoch-99.pt +popd + +log "Export via torch.jit.trace()" + +./conv_emformer_transducer_stateless2/export-for-ncnn.py \ + --exp-dir $repo/exp \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 99 \ + --avg 1 \ + --use-averaged-model 0 \ + \ + --num-encoder-layers 12 \ + --chunk-length 32 \ + --cnn-module-kernel 31 \ + --left-context-length 32 \ + --right-context-length 8 \ + --memory-size 32 + +./ncnn/tools/pnnx/build/src/pnnx $repo/exp/encoder_jit_trace-pnnx.pt +./ncnn/tools/pnnx/build/src/pnnx $repo/exp/decoder_jit_trace-pnnx.pt +./ncnn/tools/pnnx/build/src/pnnx $repo/exp/joiner_jit_trace-pnnx.pt + +python3 ./conv_emformer_transducer_stateless2/streaming-ncnn-decode.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \ + --encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \ + --decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \ + --decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \ + --joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \ + --joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \ + $repo/test_wavs/1089-134686-0001.wav + +rm -rf $repo +log "--------------------------------------------------------------------------" + +log "==========================================================================" +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-iter-468000-avg-16.pt" + +cd exp +ln -s pretrained-iter-468000-avg-16.pt epoch-99.pt +popd + +log "Export via torch.jit.trace()" + +./lstm_transducer_stateless2/export-for-ncnn.py \ + --exp-dir $repo/exp \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 99 \ + --avg 1 \ + --use-averaged-model 0 + +./ncnn/tools/pnnx/build/src/pnnx $repo/exp/encoder_jit_trace-pnnx.pt +./ncnn/tools/pnnx/build/src/pnnx $repo/exp/decoder_jit_trace-pnnx.pt +./ncnn/tools/pnnx/build/src/pnnx $repo/exp/joiner_jit_trace-pnnx.pt + +python3 ./lstm_transducer_stateless2/streaming-ncnn-decode.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \ + --encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \ + --decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \ + --decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \ + --joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \ + --joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \ + $repo/test_wavs/1089-134686-0001.wav + +python3 ./lstm_transducer_stateless2/ncnn-decode.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \ + --encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \ + --decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \ + --decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \ + --joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \ + --joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \ + $repo/test_wavs/1089-134686-0001.wav + +rm -rf $repo +log "--------------------------------------------------------------------------" diff --git a/.github/scripts/test-onnx-export.sh b/.github/scripts/test-onnx-export.sh index 2d5cc6bbf..39467c44a 100755 --- a/.github/scripts/test-onnx-export.sh +++ b/.github/scripts/test-onnx-export.sh @@ -10,6 +10,8 @@ log() { cd egs/librispeech/ASR + + log "==========================================================================" repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 log "Downloading pre-trained model from $repo_url" @@ -192,3 +194,158 @@ log "Run onnx_pretrained.py" rm -rf $repo log "--------------------------------------------------------------------------" + +log "==========================================================================" +repo_url= + +rm -rf $repo +log "--------------------------------------------------------------------------" +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +log "Export via torch.jit.script()" + +./pruned_transducer_stateless7/export.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --jit 1 + +log "Test exporting to ONNX format" + +./pruned_transducer_stateless7/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --feedforward-dims "1024,1024,2048,2048,1024" + +ls -lh $repo/exp + +log "Run onnx_check.py" + +./pruned_transducer_stateless7/onnx_check.py \ + --jit-filename $repo/exp/cpu_jit.pt \ + --onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx + +log "Run onnx_pretrained.py" + +./pruned_transducer_stateless7/onnx_pretrained.py \ + --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + +log "==========================================================================" +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt" + +cd exp +ln -s pretrained-epoch-30-avg-10-averaged.pt epoch-99.pt +popd + +log "Test exporting to ONNX format" + +./conv_emformer_transducer_stateless2/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --num-encoder-layers 12 \ + --chunk-length 32 \ + --cnn-module-kernel 31 \ + --left-context-length 32 \ + --right-context-length 8 \ + --memory-size 32 + +log "Run onnx_pretrained.py" + +./conv_emformer_transducer_stateless2/onnx_pretrained.py \ + --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1221-135766-0001.wav + +rm -rf $repo +log "--------------------------------------------------------------------------" + +log "==========================================================================" +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-iter-468000-avg-16.pt" + +cd exp +ln -s pretrained-iter-468000-avg-16.pt epoch-99.pt +popd + +log "Export via torch.jit.trace()" + +./lstm_transducer_stateless2/export.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp/ \ + --jit-trace 1 + +log "Test exporting to ONNX format" + +./lstm_transducer_stateless2/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp + +ls -lh $repo/exp + +log "Run onnx_check.py" + +./lstm_transducer_stateless2/onnx_check.py \ + --jit-encoder-filename $repo/exp/encoder_jit_trace.pt \ + --jit-decoder-filename $repo/exp/decoder_jit_trace.pt \ + --jit-joiner-filename $repo/exp/joiner_jit_trace.pt \ + --onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx + +log "Run onnx_pretrained.py" + +./lstm_transducer_stateless2/onnx_pretrained.py \ + --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1221-135766-0001.wav + +rm -rf $repo +log "--------------------------------------------------------------------------" diff --git a/.github/workflows/run-librispeech-2022-11-11-stateless7.yml b/.github/workflows/run-librispeech-2022-11-11-stateless7.yml index 7694e8bf5..365e2761a 100644 --- a/.github/workflows/run-librispeech-2022-11-11-stateless7.yml +++ b/.github/workflows/run-librispeech-2022-11-11-stateless7.yml @@ -39,7 +39,7 @@ concurrency: jobs: run_librispeech_2022_11_11_zipformer: - if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' + if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' runs-on: ${{ matrix.os }} strategy: matrix: diff --git a/.github/workflows/run-librispeech-2022-12-29-stateless7-streaming.yml b/.github/workflows/run-librispeech-2022-12-29-stateless7-streaming.yml index a1f3b4f75..6dd93946a 100644 --- a/.github/workflows/run-librispeech-2022-12-29-stateless7-streaming.yml +++ b/.github/workflows/run-librispeech-2022-12-29-stateless7-streaming.yml @@ -39,7 +39,7 @@ concurrency: jobs: run_librispeech_2022_12_29_zipformer_streaming: - if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event.label.name == 'streaming-zipformer' || github.event_name == 'push' || github.event_name == 'schedule' + if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event.label.name == 'streaming-zipformer' || github.event_name == 'push' || github.event_name == 'schedule' runs-on: ${{ matrix.os }} strategy: matrix: diff --git a/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml b/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml index 3752f67e3..f737d9a25 100644 --- a/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml +++ b/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml @@ -22,7 +22,7 @@ concurrency: jobs: run_librispeech_lstm_transducer_stateless2_2022_09_03: - if: github.event.label.name == 'ready' || github.event.label.name == 'LODR' || github.event.label.name == 'shallow-fusion' || github.event.label.name == 'ncnn' || github.event.label.name == 'onnx' || github.event_name == 'push' || github.event_name == 'schedule' + if: github.event.label.name == 'ready' || github.event.label.name == 'LODR' || github.event.label.name == 'shallow-fusion' || github.event_name == 'push' || github.event_name == 'schedule' runs-on: ${{ matrix.os }} strategy: matrix: diff --git a/.github/workflows/run-librispeech-conv-emformer-transducer-stateless2-2022-12-05.yml b/.github/workflows/test-ncnn-export.yml similarity index 75% rename from .github/workflows/run-librispeech-conv-emformer-transducer-stateless2-2022-12-05.yml rename to .github/workflows/test-ncnn-export.yml index b9a1582c4..e10cfe76b 100644 --- a/.github/workflows/run-librispeech-conv-emformer-transducer-stateless2-2022-12-05.yml +++ b/.github/workflows/test-ncnn-export.yml @@ -1,4 +1,4 @@ -name: run-librispeech-conv-emformer-transducer-stateless2-2022-12-05 +name: test-ncnn-export on: push: @@ -16,15 +16,18 @@ on: # nightly build at 15:50 UTC time every day - cron: "50 15 * * *" +concurrency: + group: test_ncnn_export-${{ github.ref }} + cancel-in-progress: true + jobs: - run_librispeech_conv_emformer_transducer_stateless2_2022_12_05: + test_ncnn_export: if: github.event.label.name == 'ready' || github.event.label.name == 'ncnn' || github.event_name == 'push' || github.event_name == 'schedule' runs-on: ${{ matrix.os }} strategy: matrix: os: [ubuntu-latest] python-version: [3.8] - fail-fast: false steps: @@ -41,7 +44,7 @@ jobs: - name: Install Python dependencies run: | - grep -v '^#' ./requirements-ci.txt | grep -v kaldifst | xargs -n 1 -L 1 pip install + grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf pip install --no-binary protobuf protobuf @@ -59,19 +62,14 @@ jobs: run: | .github/scripts/install-kaldifeat.sh - - name: Inference with pre-trained model + - name: Test ncnn export shell: bash env: GITHUB_EVENT_NAME: ${{ github.event_name }} GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} run: | - mkdir -p egs/librispeech/ASR/data - ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank - ls -lh egs/librispeech/ASR/data/* - - sudo apt-get -qq install git-lfs tree sox export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - .github/scripts/run-librispeech-conv-emformer-transducer-stateless2-2022-12-05.sh + .github/scripts/test-ncnn-export.sh diff --git a/docs/source/model-export/export-onnx.rst b/docs/source/model-export/export-onnx.rst index dd4b3437a..83c8440b5 100644 --- a/docs/source/model-export/export-onnx.rst +++ b/docs/source/model-export/export-onnx.rst @@ -1,69 +1,78 @@ Export to ONNX ============== -In this section, we describe how to export models to ONNX. +In this section, we describe how to export the following models to ONNX. + +In each recipe, there is a file called ``export-onnx.py``, which is used +to export trained models to ONNX. + +There is also a file named ``onnx_pretrained.py``, which you can use +the exported ONNX model in Python to decode sound files. + +Example +======= + +In the following, we demonstrate how to export a streaming Zipformer pre-trained +model from ``_ +to ONNX. + +Download the pre-trained model +------------------------------ .. hint:: - Only non-streaming conformer transducer models are tested. - - -When to use it --------------- - -It you want to use an inference framework that supports ONNX -to run the pretrained model. - - -How to export -------------- - -We use -``_ -as an example in the following. + We assume you have installed `git-lfs`_. .. code-block:: bash - cd egs/librispeech/ASR - epoch=14 - avg=2 - ./pruned_transducer_stateless3/export.py \ - --exp-dir ./pruned_transducer_stateless3/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --epoch $epoch \ - --avg $avg \ - --onnx 1 + cd egs/librispeech/ASR -It will generate the following files inside ``pruned_transducer_stateless3/exp``: + repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 + GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url + repo=$(basename $repo_url) - - ``encoder.onnx`` - - ``decoder.onnx`` - - ``joiner.onnx`` - - ``joiner_encoder_proj.onnx`` - - ``joiner_decoder_proj.onnx`` + pushd $repo + git lfs pull --include "data/lang_bpe_500/bpe.model" + git lfs pull --include "exp/pretrained.pt" + cd exp + ln -s pretrained.pt epoch-99.pt + popd -You can use ``./pruned_transducer_stateless3/exp/onnx_pretrained.py`` to decode -waves with the generated files: +Export the model to ONNX +------------------------ .. code-block:: bash - ./pruned_transducer_stateless3/onnx_pretrained.py \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --encoder-model-filename ./pruned_transducer_stateless3/exp/encoder.onnx \ - --decoder-model-filename ./pruned_transducer_stateless3/exp/decoder.onnx \ - --joiner-model-filename ./pruned_transducer_stateless3/exp/joiner.onnx \ - --joiner-encoder-proj-model-filename ./pruned_transducer_stateless3/exp/joiner_encoder_proj.onnx \ - --joiner-decoder-proj-model-filename ./pruned_transducer_stateless3/exp/joiner_decoder_proj.onnx \ - /path/to/foo.wav \ - /path/to/bar.wav \ - /path/to/baz.wav + ./pruned_transducer_stateless7_streaming/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --decode-chunk-len 32 \ + --exp-dir $repo/exp/ +.. warning:: -How to use the exported model ------------------------------ + ``export-onnx.py`` from different recipes has different options. -We also provide ``_ -performing speech recognition using `onnxruntime `_ -with exported models. -It has been tested on Linux, macOS, and Windows. + In the above example, ``--decode-chunk-len`` is specific for the + streaming Zipformer. Other models won't have such an option. + +It will generate the following 3 files in ``$repo/exp`` + + - ``encoder-epoch-99-avg-1.onnx`` + - ``decoder-epoch-99-avg-1.onnx`` + - ``joiner-epoch-99-avg-1.onnx`` + +Decode sound files with exported ONNX models +-------------------------------------------- + +.. code-block:: bash + + ./pruned_transducer_stateless7_streaming/onnx_pretrained.py \ + --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav diff --git a/docs/source/recipes/Streaming-ASR/librispeech/lstm_pruned_stateless_transducer.rst b/docs/source/recipes/Streaming-ASR/librispeech/lstm_pruned_stateless_transducer.rst index ce8ba1453..d04565e5d 100644 --- a/docs/source/recipes/Streaming-ASR/librispeech/lstm_pruned_stateless_transducer.rst +++ b/docs/source/recipes/Streaming-ASR/librispeech/lstm_pruned_stateless_transducer.rst @@ -580,12 +580,11 @@ for ``pnnx``: iter=468000 avg=16 - ./lstm_transducer_stateless2/export.py \ + ./lstm_transducer_stateless2/export-for-ncnn.py \ --exp-dir ./lstm_transducer_stateless2/exp \ --bpe-model data/lang_bpe_500/bpe.model \ --iter $iter \ - --avg $avg \ - --pnnx 1 + --avg $avg It will generate 3 files: @@ -615,7 +614,7 @@ To use the above generated files, run: .. code-block:: bash ./lstm_transducer_stateless2/ncnn-decode.py \ - --bpe-model-filename ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --encoder-param-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.param \ --encoder-bin-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.bin \ --decoder-param-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.param \ @@ -627,7 +626,7 @@ To use the above generated files, run: .. code-block:: bash ./lstm_transducer_stateless2/streaming-ncnn-decode.py \ - --bpe-model-filename ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --encoder-param-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.param \ --encoder-bin-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.bin \ --decoder-param-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.param \ @@ -657,6 +656,3 @@ by visiting the following links: You can find more usages of the pretrained models in ``_ - -Export ConvEmformer transducer models for ncnn -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer2.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer2.py index f0c92a9b4..4a844b79f 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer2.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer2.py @@ -169,9 +169,11 @@ class ConvolutionModule(nn.Module): channels: int, kernel_size: int, bias: bool = True, + is_pnnx: bool = True, ) -> None: """Construct an ConvolutionModule object.""" super().__init__() + self.is_pnnx = is_pnnx # kernerl_size should be an odd number for 'SAME' padding assert (kernel_size - 1) % 2 == 0, kernel_size @@ -383,12 +385,14 @@ class ConvolutionModule(nn.Module): - output right_context of shape (R, B, D). - updated cache tensor of shape (B, D, cache_size). """ - # U, B, D = utterance.size() - # R, _, _ = right_context.size() - U = self.chunk_length - B = 1 - D = self.channels - R = self.right_context_length + if self.is_pnnx is False: + U, B, D = utterance.size() + R, _, _ = right_context.size() + else: + U = self.chunk_length + B = 1 + D = self.channels + R = self.right_context_length # point-wise conv x = torch.cat([utterance, right_context], dim=0) # (U + R, B, D) @@ -448,8 +452,10 @@ class EmformerAttention(nn.Module): dropout: float = 0.0, tanh_on_mem: bool = False, negative_inf: float = -1e8, + is_pnnx: bool = True, ): super().__init__() + self.is_pnnx = is_pnnx if embed_dim % nhead != 0: raise ValueError( @@ -539,14 +545,15 @@ class EmformerAttention(nn.Module): left_context_val: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Underlying chunk-wise attention implementation.""" - # U, B, _ = utterance.size() - # R = right_context.size(0) - # M = memory.size(0) - - U = self.chunk_length - B = 1 - R = self.right_context_length - M = self.memory_size + if self.is_pnnx is False: + U, B, _ = utterance.size() + R = right_context.size(0) + M = memory.size(0) + else: + U = self.chunk_length + B = 1 + R = self.right_context_length + M = self.memory_size L = self.left_context_length scaling = float(self.head_dim) ** -0.5 @@ -570,21 +577,29 @@ class EmformerAttention(nn.Module): # KV = key.size(0) - reshaped_query = query.view(Q, self.nhead, self.head_dim).permute(1, 0, 2) - reshaped_key = key.view(M + R + U + L, self.nhead, self.head_dim).permute( - 1, 0, 2 - ) - reshaped_value = value.view(M + R + U + L, self.nhead, self.head_dim).permute( - 1, 0, 2 - ) - - # reshaped_query, reshaped_key, reshaped_value = [ - # tensor.contiguous().view(-1, B * self.nhead, self.head_dim).transpose(0, 1) - # for tensor in [query, key, value] - # ] # (B * nhead, Q or KV, head_dim) - attention_weights = torch.bmm( - reshaped_query * scaling, reshaped_key.permute(0, 2, 1) - ) # (B * nhead, Q, KV) + if self.is_pnnx is True: + reshaped_query = query.view(Q, self.nhead, self.head_dim).permute(1, 0, 2) + reshaped_key = key.view(M + R + U + L, self.nhead, self.head_dim).permute( + 1, 0, 2 + ) + reshaped_value = value.view( + M + R + U + L, self.nhead, self.head_dim + ).permute(1, 0, 2) + else: + reshaped_query, reshaped_key, reshaped_value = [ + tensor.contiguous() + .view(-1, B * self.nhead, self.head_dim) + .transpose(0, 1) + for tensor in [query, key, value] + ] # (B * nhead, Q or KV, head_dim) + if self.is_pnnx is True: + attention_weights = torch.bmm( + reshaped_query * scaling, reshaped_key.permute(0, 2, 1) + ) # (B * nhead, Q, KV) + else: + attention_weights = torch.bmm( + reshaped_query * scaling, reshaped_key.transpose(1, 2) + ) # (B * nhead, Q, KV) # compute attention probabilities if False: @@ -597,10 +612,15 @@ class EmformerAttention(nn.Module): # compute attention outputs attention = torch.bmm(attention_probs, reshaped_value) assert attention.shape == (B * self.nhead, Q, self.head_dim) - attention = attention.permute(1, 0, 2).reshape(-1, self.embed_dim) - # TODO(fangjun): ncnn does not support reshape(-1, 1, self.embed_dim) - # We have to change InnerProduct in ncnn to ignore the extra dim below - attention = attention.unsqueeze(1) + if self.is_pnnx is True: + attention = attention.permute(1, 0, 2).reshape(-1, self.embed_dim) + # TODO(fangjun): ncnn does not support reshape(-1, 1, self.embed_dim) + # We have to change InnerProduct in ncnn to ignore the extra dim below + attention = attention.unsqueeze(1) + else: + attention = ( + attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim) + ) # apply output projection output_right_context_utterance = self.out_proj(attention) @@ -733,15 +753,16 @@ class EmformerAttention(nn.Module): - attention value of left context and utterance, which would be cached for next computation, with shape (L + U, B, D). """ - # U = utterance.size(0) - # R = right_context.size(0) - # L = left_context_key.size(0) - # M = memory.size(0) - - U = self.chunk_length - R = self.right_context_length - L = self.left_context_length - M = self.memory_size + if self.is_pnnx is False: + U = utterance.size(0) + R = right_context.size(0) + L = left_context_key.size(0) + M = memory.size(0) + else: + U = self.chunk_length + R = self.right_context_length + L = self.left_context_length + M = self.memory_size # query = [right context, utterance] Q = R + U @@ -811,6 +832,7 @@ class EmformerEncoderLayer(nn.Module): memory_size: int = 0, tanh_on_mem: bool = False, negative_inf: float = -1e8, + is_pnnx: bool = True, ): super().__init__() @@ -824,6 +846,7 @@ class EmformerEncoderLayer(nn.Module): dropout=dropout, tanh_on_mem=tanh_on_mem, negative_inf=negative_inf, + is_pnnx=is_pnnx, ) self.summary_op = nn.AvgPool1d( kernel_size=chunk_length, stride=chunk_length, ceil_mode=True @@ -850,6 +873,7 @@ class EmformerEncoderLayer(nn.Module): right_context_length, d_model, cnn_module_kernel, + is_pnnx=is_pnnx, ) self.norm_final = BasicNorm(d_model) @@ -1204,6 +1228,7 @@ class EmformerEncoder(nn.Module): memory_size: int = 0, tanh_on_mem: bool = False, negative_inf: float = -1e8, + is_pnnx: bool = True, ): super().__init__() @@ -1229,6 +1254,7 @@ class EmformerEncoder(nn.Module): memory_size=memory_size, tanh_on_mem=tanh_on_mem, negative_inf=negative_inf, + is_pnnx=is_pnnx, ) for layer_idx in range(num_encoder_layers) ] @@ -1561,6 +1587,20 @@ class Emformer(EncoderInterface): self.encoder_embed = Conv2dSubsampling(num_features, d_model, is_pnnx=is_pnnx) self.is_pnnx = is_pnnx + self.num_encoder_layers = num_encoder_layers + self.memory_size = memory_size + self.d_model = d_model + self.cnn_module_kernel = cnn_module_kernel + self.left_context_length = left_context_length // subsampling_factor + self.right_context_length = right_context_length + self.subsampling_factor = subsampling_factor + + assert subsampling_factor == 4, subsampling_factor + pad_length = right_context_length + 2 * 4 + 3 + + # before subsampling + self.T = self.chunk_length + pad_length + self.encoder = EmformerEncoder( chunk_length=chunk_length // subsampling_factor, d_model=d_model, @@ -1575,6 +1615,7 @@ class Emformer(EncoderInterface): memory_size=memory_size, tanh_on_mem=tanh_on_mem, negative_inf=negative_inf, + is_pnnx=is_pnnx, ) def _forward( @@ -1691,7 +1732,7 @@ class Conv2dSubsampling(nn.Module): layer1_channels: int = 8, layer2_channels: int = 32, layer3_channels: int = 128, - is_pnnx: bool = False, + is_pnnx: bool = True, ) -> None: """ Args: @@ -1767,7 +1808,7 @@ class Conv2dSubsampling(nn.Module): x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) x = self.conv(x) - if torch.jit.is_tracing() and self.is_pnnx: + if torch.jit.is_tracing() and self.is_pnnx is True: x = x.permute(0, 2, 1, 3).reshape(1, -1, self.conv_out_dim) x = self.out(x) else: diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py index 64c16141c..e31033c74 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py @@ -1,6 +1,10 @@ #!/usr/bin/env python3 """ +Please see +https://k2-fsa.github.io/icefall/model-export/export-ncnn.html +for more details about how to use this file. + Usage: ./conv_emformer_transducer_stateless2/export-for-ncnn.py \ --exp-dir ./conv_emformer_transducer_stateless2/exp \ @@ -44,7 +48,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import setup_logger, str2bool def get_parser(): @@ -96,14 +100,6 @@ def get_parser(): help="Path to the BPE model", ) - parser.add_argument( - "--jit", - type=str2bool, - default=False, - help="""True to save a model after applying torch.jit.script. - """, - ) - parser.add_argument( "--context-size", type=int, @@ -217,6 +213,8 @@ def main(): device = torch.device("cpu") + setup_logger(f"{params.exp_dir}/log-export/log-export-ncnn") + logging.info(f"device: {device}") sp = spm.SentencePieceProcessor() @@ -330,5 +328,4 @@ def main(): if __name__ == "__main__": formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py new file mode 100755 index 000000000..ad0b45bd9 --- /dev/null +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py @@ -0,0 +1,644 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang) + +""" +This script exports a transducer model from PyTorch to ONNX. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt" + +cd exp +ln -s pretrained-epoch-30-avg-10-averaged.pt epoch-99.pt +popd + +2. Export the model to ONNX + +./conv_emformer_transducer_stateless2/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --num-encoder-layers 12 \ + --chunk-length 32 \ + --cnn-module-kernel 31 \ + --left-context-length 32 \ + --right-context-length 8 \ + --memory-size 32 + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +See ./onnx_pretrained.py for how to +use the exported ONNX models. +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, Tuple + +import onnx +import sentencepiece as spm +import torch +import torch.nn as nn +from decoder import Decoder +from scaling_converter import convert_scaled_to_non_scaled +from train2 import add_model_arguments, get_params, get_transducer_model +from emformer import Emformer + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless5/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +class OnnxEncoder(nn.Module): + """A wrapper for Emformer and the encoder_proj from the joiner""" + + def __init__(self, encoder: Emformer, encoder_proj: nn.Linear): + """ + Args: + encoder: + A Emformer encoder. + encoder_proj: + The projection layer for encoder from the joiner. + """ + super().__init__() + self.encoder = encoder + self.encoder_proj = encoder_proj + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Please see the help information of Emformer.forward + + Args: + x: + A 3-D tensor of shape (N, T, C) + x_lens: + A 1-D tensor of shape (N,). Its dtype is torch.int64 + Returns: + Return a tuple containing: + - encoder_out, A 3-D tensor of shape (N, T', joiner_dim) + - encoder_out_lens, A 1-D tensor of shape (N,) + """ + encoder_out, encoder_out_lens = self.encoder(x, x_lens) + + encoder_out = self.encoder_proj(encoder_out) + # Now encoder_out is of shape (N, T, joiner_dim) + + return encoder_out, encoder_out_lens + + +class OnnxDecoder(nn.Module): + """A wrapper for Decoder and the decoder_proj from the joiner""" + + def __init__(self, decoder: Decoder, decoder_proj: nn.Linear): + super().__init__() + self.decoder = decoder + self.decoder_proj = decoder_proj + + def forward(self, y: torch.Tensor) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, context_size). + Returns + Return a 2-D tensor of shape (N, joiner_dim) + """ + need_pad = False + decoder_output = self.decoder(y, need_pad=need_pad) + decoder_output = decoder_output.squeeze(1) + output = self.decoder_proj(decoder_output) + + return output + + +class OnnxJoiner(nn.Module): + """A wrapper for the joiner""" + + def __init__(self, output_linear: nn.Linear): + super().__init__() + self.output_linear = output_linear + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + logit = encoder_out + decoder_out + logit = self.output_linear(torch.tanh(logit)) + return logit + + +def export_encoder_model_onnx( + encoder_model: OnnxEncoder, + encoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the given encoder model to ONNX format. + The exported model has the following inputs: + + - x, a tensor of shape (N, T, C); dtype is torch.float32 + - a list of states (each layers has 4 states) + + + and it has two outputs: + + - encoder_out, a tensor of shape (N, T', joiner_dim) + - a list of states (each layers has 4 states) + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + num_encoder_layers = encoder_model.encoder.num_encoder_layers + memory_size = encoder_model.encoder.memory_size + cnn_module_kernel = encoder_model.encoder.cnn_module_kernel + chunk_length = encoder_model.encoder.chunk_length + right_context_length = encoder_model.encoder.right_context_length + encoder_dim = encoder_model.encoder.d_model + left_context_length = encoder_model.encoder.left_context_length + + T = encoder_model.encoder.T + + logging.info(f"num_encoder_layers={num_encoder_layers}") + logging.info(f"memory_size={memory_size}") + logging.info(f"cnn_module_kernel={cnn_module_kernel}") + logging.info(f"chunk_length={chunk_length}") + logging.info(f"right_context_length={right_context_length}") + logging.info(f"encoder_dim={encoder_dim}") + logging.info(f"left_context_length={left_context_length} (after subsampling)") + logging.info(f"T={T}") + + meta_data = { + "model_type": "conv-emformer", + "version": "1", + "model_author": "k2-fsa", + "decode_chunk_len": str(chunk_length), # 32 + "T": str(T), # 32 + "num_encoder_layers": str(num_encoder_layers), + "memory_size": str(memory_size), + "cnn_module_kernel": str(cnn_module_kernel), + "right_context_length": str(right_context_length), + "left_context_length": str(left_context_length), + "encoder_dim": str(encoder_dim), + } + logging.info(f"meta_data: {meta_data}") + + x = torch.zeros(1, T, 80, dtype=torch.float32) + states = encoder_model.encoder.init_states() + + # Each layer has 4 states + assert len(states) == num_encoder_layers * 4, (len(states), num_encoder_layers) + # layer 0: + # state0: (memory_size, 1, encoder_dim) + # state1: (left_context_length, 1, encoder_dim) + # state2: (left_context_length, 1, encoder_dim) + # state3: (1, encoder_dim, cnn_module_kernel-1) + + inputs = {} + input_names = ["x"] + + outputs = {} + output_names = ["encoder_out"] + + def build_inputs_outputs(s, name): + assert len(s) == 4, len(s) + logging.info(f"{name}_0.shape: {s[0].shape}") + input_names.append(f"{name}_0") + inputs[f"{name}_0"] = {1: "N"} + output_names.append(f"new_{name}_0") + + logging.info(f"{name}_1.shape: {s[1].shape}") + input_names.append(f"{name}_1") + inputs[f"{name}_1"] = {1: "N"} + output_names.append(f"new_{name}_1") + + logging.info(f"{name}_2.shape: {s[2].shape}") + input_names.append(f"{name}_2") + inputs[f"{name}_2"] = {1: "N"} + output_names.append(f"new_{name}_2") + + logging.info(f"{name}_3.shape: {s[3].shape}") + input_names.append(f"{name}_3") + inputs[f"{name}_3"] = {0: "N"} + output_names.append(f"new_{name}_3") + + for i in range(num_encoder_layers): + base_name = f"layer{i}" + s = states[i * 4 : (i + 1) * 4] + build_inputs_outputs(s, base_name) + + torch.onnx.export( + encoder_model, + (x, states), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=input_names, + output_names=output_names, + dynamic_axes={ + "x": {0: "N"}, + "encoder_out": {0: "N"}, + **inputs, + **outputs, + }, + ) + + add_meta_data(filename=encoder_filename, meta_data=meta_data) + + +def export_decoder_model_onnx( + decoder_model: OnnxDecoder, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, decoder_model.context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, joiner_dim) + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + context_size = decoder_model.decoder.context_size + vocab_size = decoder_model.decoder.vocab_size + + y = torch.zeros(10, context_size, dtype=torch.int64) + torch.onnx.export( + decoder_model, + y, + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + + meta_data = { + "context_size": str(context_size), + "vocab_size": str(vocab_size), + } + add_meta_data(filename=decoder_filename, meta_data=meta_data) + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. + The exported joiner model has two inputs: + + - encoder_out: a tensor of shape (N, joiner_dim) + - decoder_out: a tensor of shape (N, joiner_dim) + + and produces one output: + + - logit: a tensor of shape (N, vocab_size) + """ + joiner_dim = joiner_model.output_linear.weight.shape[1] + logging.info(f"joiner dim: {joiner_dim}") + + projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + + torch.onnx.export( + joiner_model, + (projected_encoder_out, projected_decoder_out), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "encoder_out", + "decoder_out", + ], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + meta_data = { + "joiner_dim": str(joiner_dim), + } + add_meta_data(filename=joiner_filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + params.is_pnnx = False + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + setup_logger(f"{params.exp_dir}/log-export/log-export-onnx") + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True) + + encoder = OnnxEncoder( + encoder=model.encoder, + encoder_proj=model.joiner.encoder_proj, + ) + + decoder = OnnxDecoder( + decoder=model.decoder, + decoder_proj=model.joiner.decoder_proj, + ) + + joiner = OnnxJoiner(output_linear=model.joiner.output_linear) + + encoder_num_param = sum([p.numel() for p in encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + if params.iter > 0: + suffix = f"iter-{params.iter}" + else: + suffix = f"epoch-{params.epoch}" + + suffix += f"-avg-{params.avg}" + + opset_version = 13 + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" + export_encoder_model_onnx( + encoder, + encoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported encoder to {encoder_filename}") + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" + export_decoder_model_onnx( + decoder, + decoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported decoder to {decoder_filename}") + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" + export_joiner_model_onnx( + joiner, + joiner_filename, + opset_version=opset_version, + ) + logging.info(f"Exported joiner to {joiner_filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + main() diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py index 949214aec..b53426c75 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py @@ -64,6 +64,7 @@ from pathlib import Path import sentencepiece as spm import torch +from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( @@ -258,6 +259,7 @@ def main(): model.eval() if params.jit: + convert_scaled_to_non_scaled(model, inplace=True) # We won't use the forward() method of the model in C++, so just ignore # it here. # Otherwise, one of its arguments is a ragged tensor and is not diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/onnx_pretrained.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/onnx_pretrained.py new file mode 100755 index 000000000..db92ac696 --- /dev/null +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/onnx_pretrained.py @@ -0,0 +1,456 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) + +""" +This script loads ONNX models exported by ./export-onnx.py +and uses them to decode waves. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt" + +cd exp +ln -s pretrained-epoch-30-avg-10-averaged.pt epoch-99.pt +popd + +2. Export the model to ONNX + +./conv_emformer_transducer_stateless2/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --num-encoder-layers 12 \ + --chunk-length 32 \ + --cnn-module-kernel 31 \ + --left-context-length 32 \ + --right-context-length 8 \ + --memory-size 32 + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +3. Run this file with the exported ONNX models + +./conv_emformer_transducer_stateless2/onnx_pretrained.py \ + --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav + +Note: Even though this script only supports decoding a single file, +the exported ONNX models do support batch processing. +""" + +import argparse +import logging +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import onnxruntime as ort +import torch +import torchaudio +from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder onnx model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder onnx model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner onnx model. ", + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt.""", + ) + + parser.add_argument( + "sound_file", + type=str, + help="The input sound file to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +class OnnxModel: + def __init__( + self, + encoder_model_filename: str, + decoder_model_filename: str, + joiner_model_filename: str, + ): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.session_opts = session_opts + + self.init_encoder(encoder_model_filename) + self.init_decoder(decoder_model_filename) + self.init_joiner(joiner_model_filename) + + def init_encoder(self, encoder_model_filename: str): + self.encoder = ort.InferenceSession( + encoder_model_filename, + sess_options=self.session_opts, + ) + self.init_encoder_states() + + def init_encoder_states(self, batch_size: int = 1): + encoder_meta = self.encoder.get_modelmeta().custom_metadata_map + + model_type = encoder_meta["model_type"] + assert model_type == "conv-emformer", model_type + + decode_chunk_len = int(encoder_meta["decode_chunk_len"]) + T = int(encoder_meta["T"]) + + num_encoder_layers = int(encoder_meta["num_encoder_layers"]) + memory_size = int(encoder_meta["memory_size"]) + cnn_module_kernel = int(encoder_meta["cnn_module_kernel"]) + right_context_length = int(encoder_meta["right_context_length"]) + left_context_length = int(encoder_meta["left_context_length"]) + encoder_dim = int(encoder_meta["encoder_dim"]) + + logging.info(f"decode_chunk_len: {decode_chunk_len}") + logging.info(f"T: {T}") + logging.info(f"num_encoder_layers: {num_encoder_layers}") + logging.info(f"memory_size: {memory_size}") + logging.info(f"cnn_module_kernel: {cnn_module_kernel}") + logging.info(f"left_context_length: {left_context_length} (after subsampling)") + logging.info(f"right_context_length: {right_context_length}") + logging.info(f"encoder_dim: {encoder_dim}") + + N = batch_size + + states = [] + for i in range(num_encoder_layers): + s0 = torch.zeros(memory_size, N, encoder_dim) + s1 = torch.zeros(left_context_length, N, encoder_dim) + s2 = torch.zeros(left_context_length, N, encoder_dim) + s3 = torch.zeros(N, encoder_dim, cnn_module_kernel - 1) + states.extend([s0, s1, s2, s3]) + + self.states = states + + self.segment = T + self.offset = decode_chunk_len + self.num_encoder_layers = num_encoder_layers + + def init_decoder(self, decoder_model_filename: str): + self.decoder = ort.InferenceSession( + decoder_model_filename, + sess_options=self.session_opts, + ) + + decoder_meta = self.decoder.get_modelmeta().custom_metadata_map + self.context_size = int(decoder_meta["context_size"]) + self.vocab_size = int(decoder_meta["vocab_size"]) + + logging.info(f"context_size: {self.context_size}") + logging.info(f"vocab_size: {self.vocab_size}") + + def init_joiner(self, joiner_model_filename: str): + self.joiner = ort.InferenceSession( + joiner_model_filename, + sess_options=self.session_opts, + ) + + joiner_meta = self.joiner.get_modelmeta().custom_metadata_map + self.joiner_dim = int(joiner_meta["joiner_dim"]) + + logging.info(f"joiner_dim: {self.joiner_dim}") + + def _build_encoder_input_output( + self, + x: torch.Tensor, + ) -> Tuple[Dict[str, np.ndarray], List[str]]: + encoder_input = {"x": x.numpy()} + encoder_output = ["encoder_out"] + + def build_inputs_outputs(states: List[torch.Tensor], name: str): + for i in range(4): + if isinstance(states[i], torch.Tensor): + encoder_input[f"{name}_{i}"] = states[i].numpy() + else: + encoder_input[f"{name}_{i}"] = states[i] + + encoder_output.append(f"new_{name}_{i}") + + for i in range(self.num_encoder_layers): + base_name = f"layer{i}" + s = self.states[i * 4 : (i + 1) * 4] + build_inputs_outputs(s, base_name) + + return encoder_input, encoder_output + + def _update_states(self, states: List[np.ndarray]): + self.states = states + + def run_encoder(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: + A 3-D tensor of shape (N, T, C) + Returns: + Return a 3-D tensor of shape (N, T', joiner_dim) where + T' is usually equal to ((T-7)//2+1)//2 + """ + encoder_input, encoder_output_names = self._build_encoder_input_output(x) + out = self.encoder.run(encoder_output_names, encoder_input) + + self._update_states(out[1:]) + + return torch.from_numpy(out[0]) + + def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor: + """ + Args: + decoder_input: + A 2-D tensor of shape (N, context_size) + Returns: + Return a 2-D tensor of shape (N, joiner_dim) + """ + out = self.decoder.run( + [self.decoder.get_outputs()[0].name], + {self.decoder.get_inputs()[0].name: decoder_input.numpy()}, + )[0] + + return torch.from_numpy(out) + + def run_joiner( + self, encoder_out: torch.Tensor, decoder_out: torch.Tensor + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + out = self.joiner.run( + [self.joiner.get_outputs()[0].name], + { + self.joiner.get_inputs()[0].name: encoder_out.numpy(), + self.joiner.get_inputs()[1].name: decoder_out.numpy(), + }, + )[0] + + return torch.from_numpy(out) + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +def create_streaming_feature_extractor() -> OnlineFeature: + """Create a CPU streaming feature extractor. + + At present, we assume it returns a fbank feature extractor with + fixed options. In the future, we will support passing in the options + from outside. + + Returns: + Return a CPU streaming feature extractor. + """ + opts = FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + return OnlineFbank(opts) + + +def greedy_search( + model: OnnxModel, + encoder_out: torch.Tensor, + context_size: int, + decoder_out: Optional[torch.Tensor] = None, + hyp: Optional[List[int]] = None, +) -> List[int]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + model: + The transducer model. + encoder_out: + A 3-D tensor of shape (1, T, joiner_dim) + context_size: + The context size of the decoder model. + decoder_out: + Optional. Decoder output of the previous chunk. + hyp: + Decoding results for previous chunks. + Returns: + Return the decoded results so far. + """ + + blank_id = 0 + + if decoder_out is None: + assert hyp is None, hyp + hyp = [blank_id] * context_size + decoder_input = torch.tensor([hyp], dtype=torch.int64) + decoder_out = model.run_decoder(decoder_input) + else: + assert hyp is not None, hyp + + encoder_out = encoder_out.squeeze(0) + T = encoder_out.size(0) + for t in range(T): + cur_encoder_out = encoder_out[t : t + 1] + joiner_out = model.run_joiner(cur_encoder_out, decoder_out).squeeze(0) + y = joiner_out.argmax(dim=0).item() + if y != blank_id: + hyp.append(y) + decoder_input = hyp[-context_size:] + decoder_input = torch.tensor([decoder_input], dtype=torch.int64) + decoder_out = model.run_decoder(decoder_input) + + return hyp, decoder_out + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + model = OnnxModel( + encoder_model_filename=args.encoder_model_filename, + decoder_model_filename=args.decoder_model_filename, + joiner_model_filename=args.joiner_model_filename, + ) + + sample_rate = 16000 + + logging.info("Constructing Fbank computer") + online_fbank = create_streaming_feature_extractor() + + logging.info(f"Reading sound files: {args.sound_file}") + waves = read_sound_files( + filenames=[args.sound_file], + expected_sample_rate=sample_rate, + )[0] + + tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32) + wave_samples = torch.cat([waves, tail_padding]) + + num_processed_frames = 0 + segment = model.segment + offset = model.offset + + context_size = model.context_size + hyp = None + decoder_out = None + + chunk = int(1 * sample_rate) # 1 second + start = 0 + while start < wave_samples.numel(): + end = min(start + chunk, wave_samples.numel()) + samples = wave_samples[start:end] + start += chunk + + online_fbank.accept_waveform( + sampling_rate=sample_rate, + waveform=samples, + ) + + while online_fbank.num_frames_ready - num_processed_frames >= segment: + frames = [] + for i in range(segment): + frames.append(online_fbank.get_frame(num_processed_frames + i)) + num_processed_frames += offset + frames = torch.cat(frames, dim=0) + frames = frames.unsqueeze(0) + encoder_out = model.run_encoder(frames) + hyp, decoder_out = greedy_search( + model, + encoder_out, + context_size, + decoder_out, + hyp, + ) + + symbol_table = k2.SymbolTable.from_file(args.tokens) + + text = "" + for i in hyp[context_size:]: + text += symbol_table[i] + text = text.replace("▁", " ").strip() + + logging.info(args.sound_file) + logging.info(text) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train2.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train2.py index c91f94876..dd0a60736 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train2.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train2.py @@ -425,6 +425,7 @@ def get_params() -> AttributeDict: "joiner_dim": 512, # parameters for Noam "model_warm_step": 3000, # arg given to model, not for lrate + "is_pnnx": True, "env_info": get_env_info(), } ) @@ -446,6 +447,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: left_context_length=params.left_context_length, right_context_length=params.right_context_length, memory_size=params.memory_size, + is_pnnx=params.is_pnnx, ) return encoder diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/export-onnx.py b/egs/librispeech/ASR/lstm_transducer_stateless/export-onnx.py new file mode 120000 index 000000000..9f5064deb --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless/export-onnx.py @@ -0,0 +1 @@ +../lstm_transducer_stateless2/export-onnx.py \ No newline at end of file diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/onnx_check.py b/egs/librispeech/ASR/lstm_transducer_stateless/onnx_check.py new file mode 120000 index 000000000..0b1ea0326 --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless/onnx_check.py @@ -0,0 +1 @@ +../lstm_transducer_stateless2/onnx_check.py \ No newline at end of file diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/onnx_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless/onnx_pretrained.py new file mode 120000 index 000000000..099c2882f --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless/onnx_pretrained.py @@ -0,0 +1 @@ +../lstm_transducer_stateless2/onnx_pretrained.py \ No newline at end of file diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export-for-ncnn.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export-for-ncnn.py new file mode 100755 index 000000000..7982ace68 --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export-for-ncnn.py @@ -0,0 +1,337 @@ +#!/usr/bin/env python3 + +""" +Please see +https://k2-fsa.github.io/icefall/model-export/export-ncnn.html +for more details about how to use this file. + +We use the pre-trained model from +https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-iter-468000-avg-16.pt" + +cd exp +ln -s pretrained-iter-468000-avg-16.pt epoch-99.pt +popd + +2. Export via torch.jit.trace() + +./lstm_transducer_stateless2/export-for-ncnn.py \ + --exp-dir $repo/exp \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 99 \ + --avg 1 \ + --use-averaged-model 0 \ + +cd ./lstm_transducer_stateless2/exp +pnnx encoder_jit_trace-pnnx.pt +pnnx decoder_jit_trace-pnnx.pt +pnnx joiner_jit_trace-pnnx.pt + +See ./streaming-ncnn-decode.py +and +https://github.com/k2-fsa/sherpa-ncnn +for usage. +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +import torch +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless2/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + add_model_arguments(parser) + + return parser + + +def export_encoder_model_jit_trace( + encoder_model: torch.nn.Module, + encoder_filename: str, +) -> None: + """Export the given encoder model with torch.jit.trace() + + Note: The warmup argument is fixed to 1. + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported model. + """ + x = torch.zeros(1, 100, 80, dtype=torch.float32) + x_lens = torch.tensor([100], dtype=torch.int64) + states = encoder_model.get_init_states() + + traced_model = torch.jit.trace(encoder_model, (x, x_lens, states)) + traced_model.save(encoder_filename) + logging.info(f"Saved to {encoder_filename}") + + +def export_decoder_model_jit_trace( + decoder_model: torch.nn.Module, + decoder_filename: str, +) -> None: + """Export the given decoder model with torch.jit.trace() + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The input decoder model + decoder_filename: + The filename to save the exported model. + """ + y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) + need_pad = torch.tensor([False]) + + traced_model = torch.jit.trace(decoder_model, (y, need_pad)) + traced_model.save(decoder_filename) + logging.info(f"Saved to {decoder_filename}") + + +def export_joiner_model_jit_trace( + joiner_model: torch.nn.Module, + joiner_filename: str, +) -> None: + """Export the given joiner model with torch.jit.trace() + + Note: The argument project_input is fixed to True. A user should not + project the encoder_out/decoder_out by himself/herself. The exported joiner + will do that for the user. + + Args: + joiner_model: + The input joiner model + joiner_filename: + The filename to save the exported model. + + """ + encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] + decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + + traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out)) + traced_model.save(joiner_filename) + logging.info(f"Saved to {joiner_filename}") + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + + setup_logger(f"{params.exp_dir}/log-export/log-export-ncnn") + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + params.is_pnnx = True + + logging.info("About to create model") + model = get_transducer_model(params, enable_giga=False) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True) + logging.info("Using torch.jit.trace()") + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / "encoder_jit_trace-pnnx.pt" + export_encoder_model_jit_trace(model.encoder, encoder_filename) + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / "decoder_jit_trace-pnnx.pt" + export_decoder_model_jit_trace(model.decoder, decoder_filename) + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / "joiner_jit_trace-pnnx.pt" + export_joiner_model_jit_trace(model.joiner, joiner_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx.py new file mode 100755 index 000000000..46873ebf9 --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx.py @@ -0,0 +1,593 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang) + +""" +This script exports a transducer model from PyTorch to ONNX. + +We use the pre-trained model from +https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-iter-468000-avg-16.pt" + +cd exp +ln -s pretrained-iter-468000-avg-16.pt epoch-99.pt +popd + +2. Export the model to ONNX + +./lstm_transducer_stateless2/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +See ./onnx_pretrained.py and ./onnx_check.py for how to +use the exported ONNX models. +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, Optional, Tuple + +import onnx +import sentencepiece as spm +import torch +import torch.nn as nn +from decoder import Decoder +from lstm import RNN +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless5/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +class OnnxEncoder(nn.Module): + """A wrapper for RNN and the encoder_proj from the joiner""" + + def __init__(self, encoder: RNN, encoder_proj: nn.Linear): + """ + Args: + encoder: + An RNN encoder. + encoder_proj: + The projection layer for encoder from the joiner. + """ + super().__init__() + self.encoder = encoder + self.encoder_proj = encoder_proj + + def forward( + self, + x: torch.Tensor, + states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Please see the help information of RNN.forward + + Args: + x: + A 3-D tensor of shape (N, T, C) + states: + A tuple of 2 tensors (optional). It is for streaming inference. + states[0] is the hidden states of all layers, + with shape of (num_layers, N, d_model); + states[1] is the cell states of all layers, + with shape of (num_layers, N, rnn_hidden_size). + Returns: + Return a tuple containing: + - encoder_out, A 3-D tensor of shape (N, T', joiner_dim) + - updated states, whose shape is the same as the input states. + """ + N = x.size(0) + T = x.size(1) + x_lens = torch.tensor([T] * N, dtype=torch.int64, device=x.device) + encoder_out, _, next_states = self.encoder(x, x_lens, states) + + encoder_out = self.encoder_proj(encoder_out) + # Now encoder_out is of shape (N, T, joiner_dim) + + return encoder_out, next_states + + +class OnnxDecoder(nn.Module): + """A wrapper for Decoder and the decoder_proj from the joiner""" + + def __init__(self, decoder: Decoder, decoder_proj: nn.Linear): + super().__init__() + self.decoder = decoder + self.decoder_proj = decoder_proj + + def forward(self, y: torch.Tensor) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, context_size). + Returns + Return a 2-D tensor of shape (N, joiner_dim) + """ + need_pad = False + decoder_output = self.decoder(y, need_pad=need_pad) + decoder_output = decoder_output.squeeze(1) + output = self.decoder_proj(decoder_output) + + return output + + +class OnnxJoiner(nn.Module): + """A wrapper for the joiner""" + + def __init__(self, output_linear: nn.Linear): + super().__init__() + self.output_linear = output_linear + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + logit = encoder_out + decoder_out + logit = self.output_linear(torch.tanh(logit)) + return logit + + +def export_encoder_model_onnx( + encoder_model: OnnxEncoder, + encoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the given encoder model to ONNX format. + The exported model has the following inputs: + + - x, a tensor of shape (N, T, C); dtype is torch.float32 + - state0, a tensor of shape (num_encoder_layers, batch_size, d_model) + - state1, a tensor of shape (num_encoder_layers, batch_size, rnn_hidden_size) + + and it has 3 outputs: + + - encoder_out, a tensor of shape (N, T', joiner_dim) + - new_state0, a tensor of shape (num_encoder_layers, batch_size, d_model) + - new_state1, a tensor of shape (num_encoder_layers, batch_size, rnn_hidden_size) + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + num_encoder_layers = encoder_model.encoder.num_encoder_layers + d_model = encoder_model.encoder.d_model + rnn_hidden_size = encoder_model.encoder.rnn_hidden_size + + decode_chunk_len = 4 + T = 9 + + x = torch.zeros(1, T, 80, dtype=torch.float32) + states = encoder_model.encoder.get_init_states() + # state0: (num_encoder_layers, batch_size, d_model) + # state1: (num_encoder_layers, batch_size, rnn_hidden_size) + + torch.onnx.export( + encoder_model, + (x, states), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "state0", "state1"], + output_names=["encoder_out", "new_state0", "new_state1"], + dynamic_axes={ + "x": {0: "N"}, + "state0": {1: "N"}, + "state1": {1: "N"}, + "encoder_out": {0: "N"}, + "new_state0": {1: "N"}, + "new_state1": {1: "N"}, + }, + ) + + meta_data = { + "model_type": "lstm", + "version": "1", + "model_author": "k2-fsa", + "decode_chunk_len": str(decode_chunk_len), # 32 + "T": str(T), # 39 + "num_encoder_layers": str(num_encoder_layers), + "d_model": str(d_model), + "rnn_hidden_size": str(rnn_hidden_size), + } + logging.info(f"meta_data: {meta_data}") + + add_meta_data(filename=encoder_filename, meta_data=meta_data) + + +def export_decoder_model_onnx( + decoder_model: OnnxDecoder, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, decoder_model.context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, joiner_dim) + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + context_size = decoder_model.decoder.context_size + vocab_size = decoder_model.decoder.vocab_size + + y = torch.zeros(10, context_size, dtype=torch.int64) + torch.onnx.export( + decoder_model, + y, + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + + meta_data = { + "context_size": str(context_size), + "vocab_size": str(vocab_size), + } + add_meta_data(filename=decoder_filename, meta_data=meta_data) + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. + The exported joiner model has two inputs: + + - encoder_out: a tensor of shape (N, joiner_dim) + - decoder_out: a tensor of shape (N, joiner_dim) + + and produces one output: + + - logit: a tensor of shape (N, vocab_size) + """ + joiner_dim = joiner_model.output_linear.weight.shape[1] + logging.info(f"joiner dim: {joiner_dim}") + + projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + + torch.onnx.export( + joiner_model, + (projected_encoder_out, projected_decoder_out), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "encoder_out", + "decoder_out", + ], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + meta_data = { + "joiner_dim": str(joiner_dim), + } + add_meta_data(filename=joiner_filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + setup_logger(f"{params.exp_dir}/log-export/log-export-onnx") + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params, enable_giga=False) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True) + + encoder = OnnxEncoder( + encoder=model.encoder, + encoder_proj=model.joiner.encoder_proj, + ) + + decoder = OnnxDecoder( + decoder=model.decoder, + decoder_proj=model.joiner.decoder_proj, + ) + + joiner = OnnxJoiner(output_linear=model.joiner.output_linear) + + encoder_num_param = sum([p.numel() for p in encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + if params.iter > 0: + suffix = f"iter-{params.iter}" + else: + suffix = f"epoch-{params.epoch}" + + suffix += f"-avg-{params.avg}" + + opset_version = 13 + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" + export_encoder_model_onnx( + encoder, + encoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported encoder to {encoder_filename}") + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" + export_decoder_model_onnx( + decoder, + decoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported decoder to {decoder_filename}") + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" + export_joiner_model_onnx( + joiner, + joiner_filename, + opset_version=opset_version, + ) + logging.info(f"Exported joiner to {joiner_filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export.py index 5977cb36d..0adc68112 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export.py @@ -74,29 +74,6 @@ with the following commands: git lfs install git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03 # You will find the pre-trained models in icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp - -(3) Export to ONNX format - -./lstm_transducer_stateless2/export.py \ - --exp-dir ./lstm_transducer_stateless2/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --epoch 20 \ - --avg 10 \ - --onnx 1 - -It will generate the following files in the given `exp_dir`. - - - encoder.onnx - - decoder.onnx - - joiner.onnx - - joiner_encoder_proj.onnx - - joiner_decoder_proj.onnx - -Please see ./streaming-onnx-decode.py for usage of the generated files - -Check -https://github.com/k2-fsa/sherpa-onnx -for how to use the exported models outside of icefall. """ import argparse @@ -192,35 +169,6 @@ def get_parser(): """, ) - parser.add_argument( - "--pnnx", - type=str2bool, - default=False, - help="""True to save a model after applying torch.jit.trace for later - converting to PNNX. It will generate 3 files: - - encoder_jit_trace-pnnx.pt - - decoder_jit_trace-pnnx.pt - - joiner_jit_trace-pnnx.pt - """, - ) - - parser.add_argument( - "--onnx", - type=str2bool, - default=False, - help="""If True, --jit and --pnnx are ignored and it exports the model - to onnx format. It will generate the following files: - - - encoder.onnx - - decoder.onnx - - joiner.onnx - - joiner_encoder_proj.onnx - - joiner_decoder_proj.onnx - - Refer to ./onnx_check.py and ./onnx_pretrained.py for how to use them. - """, - ) - parser.add_argument( "--context-size", type=int, @@ -305,209 +253,6 @@ def export_joiner_model_jit_trace( logging.info(f"Saved to {joiner_filename}") -def export_encoder_model_onnx( - encoder_model: nn.Module, - encoder_filename: str, - opset_version: int = 11, -) -> None: - """Export the given encoder model to ONNX format. - The exported model has 3 inputs: - - - x, a tensor of shape (N, T, C); dtype is torch.float32 - - x_lens, a tensor of shape (N,); dtype is torch.int64 - - states: a tuple containing: - - h0: a tensor of shape (num_layers, N, proj_size) - - c0: a tensor of shape (num_layers, N, hidden_size) - - and it has 3 outputs: - - - encoder_out, a tensor of shape (N, T, C) - - encoder_out_lens, a tensor of shape (N,) - - states: a tuple containing: - - next_h0: a tensor of shape (num_layers, N, proj_size) - - next_c0: a tensor of shape (num_layers, N, hidden_size) - - Note: The warmup argument is fixed to 1. - - Args: - encoder_model: - The input encoder model - encoder_filename: - The filename to save the exported ONNX model. - opset_version: - The opset version to use. - """ - N = 1 - x = torch.zeros(N, 9, 80, dtype=torch.float32) - x_lens = torch.tensor([9], dtype=torch.int64) - h = torch.rand(encoder_model.num_encoder_layers, N, encoder_model.d_model) - c = torch.rand(encoder_model.num_encoder_layers, N, encoder_model.rnn_hidden_size) - - warmup = 1.0 - torch.onnx.export( - encoder_model, # use torch.jit.trace() internally - (x, x_lens, (h, c), warmup), - encoder_filename, - verbose=False, - opset_version=opset_version, - input_names=["x", "x_lens", "h", "c", "warmup"], - output_names=["encoder_out", "encoder_out_lens", "next_h", "next_c"], - dynamic_axes={ - "x": {0: "N", 1: "T"}, - "x_lens": {0: "N"}, - "h": {1: "N"}, - "c": {1: "N"}, - "encoder_out": {0: "N", 1: "T"}, - "encoder_out_lens": {0: "N"}, - "next_h": {1: "N"}, - "next_c": {1: "N"}, - }, - ) - logging.info(f"Saved to {encoder_filename}") - - -def export_decoder_model_onnx( - decoder_model: nn.Module, - decoder_filename: str, - opset_version: int = 11, -) -> None: - """Export the decoder model to ONNX format. - - The exported model has one input: - - - y: a torch.int64 tensor of shape (N, decoder_model.context_size) - - and has one output: - - - decoder_out: a torch.float32 tensor of shape (N, 1, C) - - Note: The argument need_pad is fixed to False. - - Args: - decoder_model: - The decoder model to be exported. - decoder_filename: - Filename to save the exported ONNX model. - opset_version: - The opset version to use. - """ - y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) - need_pad = False # Always False, so we can use torch.jit.trace() here - # Note(fangjun): torch.jit.trace() is more efficient than torch.jit.script() - # in this case - torch.onnx.export( - decoder_model, - (y, need_pad), - decoder_filename, - verbose=False, - opset_version=opset_version, - input_names=["y", "need_pad"], - output_names=["decoder_out"], - dynamic_axes={ - "y": {0: "N"}, - "decoder_out": {0: "N"}, - }, - ) - logging.info(f"Saved to {decoder_filename}") - - -def export_joiner_model_onnx( - joiner_model: nn.Module, - joiner_filename: str, - opset_version: int = 11, -) -> None: - """Export the joiner model to ONNX format. - The exported joiner model has two inputs: - - - projected_encoder_out: a tensor of shape (N, joiner_dim) - - projected_decoder_out: a tensor of shape (N, joiner_dim) - - and produces one output: - - - logit: a tensor of shape (N, vocab_size) - - The exported encoder_proj model has one input: - - - encoder_out: a tensor of shape (N, encoder_out_dim) - - and produces one output: - - - projected_encoder_out: a tensor of shape (N, joiner_dim) - - The exported decoder_proj model has one input: - - - decoder_out: a tensor of shape (N, decoder_out_dim) - - and produces one output: - - - projected_decoder_out: a tensor of shape (N, joiner_dim) - """ - encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx") - - decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx") - - encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] - decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] - joiner_dim = joiner_model.decoder_proj.weight.shape[0] - - projected_encoder_out = torch.rand(1, joiner_dim, dtype=torch.float32) - projected_decoder_out = torch.rand(1, joiner_dim, dtype=torch.float32) - - project_input = False - # Note: It uses torch.jit.trace() internally - torch.onnx.export( - joiner_model, - (projected_encoder_out, projected_decoder_out, project_input), - joiner_filename, - verbose=False, - opset_version=opset_version, - input_names=[ - "projected_encoder_out", - "projected_decoder_out", - "project_input", - ], - output_names=["logit"], - dynamic_axes={ - "projected_encoder_out": {0: "N"}, - "projected_decoder_out": {0: "N"}, - "logit": {0: "N"}, - }, - ) - logging.info(f"Saved to {joiner_filename}") - - encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) - torch.onnx.export( - joiner_model.encoder_proj, - encoder_out, - encoder_proj_filename, - verbose=False, - opset_version=opset_version, - input_names=["encoder_out"], - output_names=["projected_encoder_out"], - dynamic_axes={ - "encoder_out": {0: "N"}, - "projected_encoder_out": {0: "N"}, - }, - ) - logging.info(f"Saved to {encoder_proj_filename}") - - decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) - torch.onnx.export( - joiner_model.decoder_proj, - decoder_out, - decoder_proj_filename, - verbose=False, - opset_version=opset_version, - input_names=["decoder_out"], - output_names=["projected_decoder_out"], - dynamic_axes={ - "decoder_out": {0: "N"}, - "projected_decoder_out": {0: "N"}, - }, - ) - logging.info(f"Saved to {decoder_proj_filename}") - - @torch.no_grad() def main(): args = get_parser().parse_args() @@ -531,10 +276,6 @@ def main(): logging.info(params) - if params.pnnx: - params.is_pnnx = params.pnnx - logging.info("For PNNX") - logging.info("About to create model") model = get_transducer_model(params, enable_giga=False) @@ -629,44 +370,7 @@ def main(): model.to("cpu") model.eval() - if params.onnx: - logging.info("Export model to ONNX format") - convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True) - - opset_version = 11 - encoder_filename = params.exp_dir / "encoder.onnx" - export_encoder_model_onnx( - model.encoder, - encoder_filename, - opset_version=opset_version, - ) - - decoder_filename = params.exp_dir / "decoder.onnx" - export_decoder_model_onnx( - model.decoder, - decoder_filename, - opset_version=opset_version, - ) - - joiner_filename = params.exp_dir / "joiner.onnx" - export_joiner_model_onnx( - model.joiner, - joiner_filename, - opset_version=opset_version, - ) - - elif params.pnnx: - convert_scaled_to_non_scaled(model, inplace=True) - logging.info("Using torch.jit.trace()") - encoder_filename = params.exp_dir / "encoder_jit_trace-pnnx.pt" - export_encoder_model_jit_trace(model.encoder, encoder_filename) - - decoder_filename = params.exp_dir / "decoder_jit_trace-pnnx.pt" - export_decoder_model_jit_trace(model.decoder, decoder_filename) - - joiner_filename = params.exp_dir / "joiner_jit_trace-pnnx.pt" - export_joiner_model_jit_trace(model.joiner, joiner_filename) - elif params.jit_trace is True: + if params.jit_trace is True: convert_scaled_to_non_scaled(model, inplace=True) logging.info("Using torch.jit.trace()") encoder_filename = params.exp_dir / "encoder_jit_trace.pt" diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py index 3bd1b0a09..3eeaa5397 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py @@ -19,7 +19,7 @@ """ Usage: ./lstm_transducer_stateless2/ncnn-decode.py \ - --bpe-model-filename ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --encoder-param-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-iter-468000-avg-16-pnnx.ncnn.param \ --encoder-bin-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-iter-468000-avg-16-pnnx.ncnn.bin \ --decoder-param-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-iter-468000-avg-16-pnnx.ncnn.param \ @@ -27,15 +27,19 @@ Usage: --joiner-param-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-iter-468000-avg-16-pnnx.ncnn.param \ --joiner-bin-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-iter-468000-avg-16-pnnx.ncnn.bin \ ./test_wavs/1089-134686-0001.wav + +Please see +https://k2-fsa.github.io/icefall/model-export/export-ncnn.html +for details. """ import argparse import logging from typing import List +import k2 import kaldifeat import ncnn -import sentencepiece as spm import torch import torchaudio @@ -44,9 +48,9 @@ def get_args(): parser = argparse.ArgumentParser() parser.add_argument( - "--bpe-model-filename", + "--tokens", type=str, - help="Path to bpe.model", + help="Path to tokens.txt", ) parser.add_argument( @@ -240,9 +244,6 @@ def main(): model = Model(args) - sp = spm.SentencePieceProcessor() - sp.load(args.bpe_model_filename) - sound_file = args.sound_filename sample_rate = 16000 @@ -280,8 +281,16 @@ def main(): encoder_out, encoder_out_lens, hx, cx = model.run_encoder(features, states) hyp = greedy_search(model, encoder_out) + + symbol_table = k2.SymbolTable.from_file(args.tokens) + + text = "" + for i in hyp: + text += symbol_table[i] + text = text.replace("▁", " ").strip() + logging.info(sound_file) - logging.info(sp.decode(hyp)) + logging.info(text) if __name__ == "__main__": diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/onnx_check.py b/egs/librispeech/ASR/lstm_transducer_stateless2/onnx_check.py new file mode 100755 index 000000000..c83f38b2a --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/onnx_check.py @@ -0,0 +1,261 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang) + +""" +This script checks that exported ONNX models produce the same output +with the given torchscript model for the same input. + +We use the pre-trained model from +https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-iter-468000-avg-16.pt" + +cd exp +ln -s pretrained-iter-468000-avg-16.pt epoch-99.pt +popd + +2. Export the model via torch.jit.trace() + +./lstm_transducer_stateless2/export.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp/ \ + --jit-trace 1 + +It will generate the following 3 files inside $repo/exp + + - encoder_jit_trace.pt + - decoder_jit_trace.pt + - joiner_jit_trace.pt + +3. Export the model to ONNX + +./lstm_transducer_stateless2/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp + + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +4. Run this file + +./lstm_transducer_stateless2/onnx_check.py \ + --jit-encoder-filename $repo/exp/encoder_jit_trace.pt \ + --jit-decoder-filename $repo/exp/decoder_jit_trace.pt \ + --jit-joiner-filename $repo/exp/joiner_jit_trace.pt \ + --onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx + +""" + +import argparse +import logging + +from onnx_pretrained import OnnxModel + +from icefall import is_module_available + +import torch + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--jit-encoder-filename", + required=True, + type=str, + help="Path to the torchscript encoder model", + ) + + parser.add_argument( + "--jit-decoder-filename", + required=True, + type=str, + help="Path to the torchscript decoder model", + ) + + parser.add_argument( + "--jit-joiner-filename", + required=True, + type=str, + help="Path to the torchscript joiner model", + ) + + parser.add_argument( + "--onnx-encoder-filename", + required=True, + type=str, + help="Path to the ONNX encoder model", + ) + + parser.add_argument( + "--onnx-decoder-filename", + required=True, + type=str, + help="Path to the ONNX decoder model", + ) + + parser.add_argument( + "--onnx-joiner-filename", + required=True, + type=str, + help="Path to the ONNX joiner model", + ) + + return parser + + +def test_encoder( + torch_encoder_model: torch.jit.ScriptModule, + torch_encoder_proj_model: torch.jit.ScriptModule, + onnx_model: OnnxModel, +): + N = torch.randint(1, 100, size=(1,)).item() + T = onnx_model.segment + C = 80 + x_lens = torch.tensor([T] * N) + torch_states = torch_encoder_model.get_init_states(N) + + onnx_model.init_encoder_states(N) + + for i in range(5): + logging.info(f"test_encoder: iter {i}") + x = torch.rand(N, T, C) + torch_encoder_out, _, torch_states = torch_encoder_model( + x, x_lens, torch_states + ) + torch_encoder_out = torch_encoder_proj_model(torch_encoder_out) + + onnx_encoder_out = onnx_model.run_encoder(x) + + assert torch.allclose(torch_encoder_out, onnx_encoder_out, atol=1e-4), ( + (torch_encoder_out - onnx_encoder_out).abs().max() + ) + + +def test_decoder( + torch_decoder_model: torch.jit.ScriptModule, + torch_decoder_proj_model: torch.jit.ScriptModule, + onnx_model: OnnxModel, +): + context_size = onnx_model.context_size + vocab_size = onnx_model.vocab_size + for i in range(10): + N = torch.randint(1, 100, size=(1,)).item() + logging.info(f"test_decoder: iter {i}, N={N}") + x = torch.randint( + low=1, + high=vocab_size, + size=(N, context_size), + dtype=torch.int64, + ) + torch_decoder_out = torch_decoder_model(x, need_pad=torch.tensor([False])) + torch_decoder_out = torch_decoder_proj_model(torch_decoder_out) + torch_decoder_out = torch_decoder_out.squeeze(1) + + onnx_decoder_out = onnx_model.run_decoder(x) + assert torch.allclose(torch_decoder_out, onnx_decoder_out, atol=1e-4), ( + (torch_decoder_out - onnx_decoder_out).abs().max() + ) + + +def test_joiner( + torch_joiner_model: torch.jit.ScriptModule, + onnx_model: OnnxModel, +): + encoder_dim = torch_joiner_model.encoder_proj.weight.shape[1] + decoder_dim = torch_joiner_model.decoder_proj.weight.shape[1] + for i in range(10): + N = torch.randint(1, 100, size=(1,)).item() + logging.info(f"test_joiner: iter {i}, N={N}") + encoder_out = torch.rand(N, encoder_dim) + decoder_out = torch.rand(N, decoder_dim) + + projected_encoder_out = torch_joiner_model.encoder_proj(encoder_out) + projected_decoder_out = torch_joiner_model.decoder_proj(decoder_out) + + torch_joiner_out = torch_joiner_model(encoder_out, decoder_out) + onnx_joiner_out = onnx_model.run_joiner( + projected_encoder_out, projected_decoder_out + ) + + assert torch.allclose(torch_joiner_out, onnx_joiner_out, atol=1e-4), ( + (torch_joiner_out - onnx_joiner_out).abs().max() + ) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + logging.info(vars(args)) + + torch_encoder_model = torch.jit.load(args.jit_encoder_filename) + torch_decoder_model = torch.jit.load(args.jit_decoder_filename) + torch_joiner_model = torch.jit.load(args.jit_joiner_filename) + + onnx_model = OnnxModel( + encoder_model_filename=args.onnx_encoder_filename, + decoder_model_filename=args.onnx_decoder_filename, + joiner_model_filename=args.onnx_joiner_filename, + ) + + logging.info("Test encoder") + # When exporting the model to onnx, we have already put the encoder_proj + # inside the encoder. + test_encoder(torch_encoder_model, torch_joiner_model.encoder_proj, onnx_model) + + logging.info("Test decoder") + # When exporting the model to onnx, we have already put the decoder_proj + # inside the decoder. + test_decoder(torch_decoder_model, torch_joiner_model.decoder_proj, onnx_model) + + logging.info("Test joiner") + test_joiner(torch_joiner_model, onnx_model) + + logging.info("Finished checking ONNX models") + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +# See https://github.com/pytorch/pytorch/issues/38342 +# and https://github.com/pytorch/pytorch/issues/33354 +# +# If we don't do this, the delay increases whenever there is +# a new request that changes the actual batch size. +# If you use `py-spy dump --pid --native`, you will +# see a lot of time is spent in re-compiling the torch script model. +torch._C._jit_set_profiling_executor(False) +torch._C._jit_set_profiling_mode(False) +torch._C._set_graph_executor_optimize(False) +if __name__ == "__main__": + torch.manual_seed(20230207) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/onnx_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless2/onnx_pretrained.py new file mode 100755 index 000000000..fb9e121e5 --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/onnx_pretrained.py @@ -0,0 +1,428 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) + +""" +This script loads ONNX models exported by ./export-onnx.py +and uses them to decode waves. + +We use the pre-trained model from +https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-iter-468000-avg-16.pt" +cd exp +ln -s exp/pretrained-iter-468000-avg-16.pt epoch-99.pt +popd + +2. Export the model to ONNX + +./lstm_transducer_stateless2/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +3. Run this file with the exported ONNX models + +./lstm_transducer_stateless2/onnx_pretrained.py \ + --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1221-135766-0001.wav + +Note: Even though this script only supports decoding a single file, +the exported ONNX models do support batch processing. +""" + +import argparse +import logging +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import onnxruntime as ort +import torch +import torchaudio +from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder onnx model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder onnx model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner onnx model. ", + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt.""", + ) + + parser.add_argument( + "sound_file", + type=str, + help="The input sound file to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +class OnnxModel: + def __init__( + self, + encoder_model_filename: str, + decoder_model_filename: str, + joiner_model_filename: str, + ): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.session_opts = session_opts + + self.init_encoder(encoder_model_filename) + self.init_decoder(decoder_model_filename) + self.init_joiner(joiner_model_filename) + + def init_encoder(self, encoder_model_filename: str): + self.encoder = ort.InferenceSession( + encoder_model_filename, + sess_options=self.session_opts, + ) + self.init_encoder_states() + + def init_encoder_states(self, batch_size: int = 1): + encoder_meta = self.encoder.get_modelmeta().custom_metadata_map + + model_type = encoder_meta["model_type"] + assert model_type == "lstm", model_type + + decode_chunk_len = int(encoder_meta["decode_chunk_len"]) + T = int(encoder_meta["T"]) + + num_encoder_layers = int(encoder_meta["num_encoder_layers"]) + d_model = int(encoder_meta["d_model"]) + rnn_hidden_size = int(encoder_meta["rnn_hidden_size"]) + + logging.info(f"decode_chunk_len: {decode_chunk_len}") + logging.info(f"T: {T}") + logging.info(f"num_encoder_layers: {num_encoder_layers}") + logging.info(f"d_model: {d_model}") + logging.info(f"rnn_hidden_size: {rnn_hidden_size}") + + N = batch_size + + s0 = torch.zeros(num_encoder_layers, N, d_model) + s1 = torch.zeros(num_encoder_layers, N, rnn_hidden_size) + states = [s0.numpy(), s1.numpy()] + + self.states = states + + self.segment = T + self.offset = decode_chunk_len + + def init_decoder(self, decoder_model_filename: str): + self.decoder = ort.InferenceSession( + decoder_model_filename, + sess_options=self.session_opts, + ) + + decoder_meta = self.decoder.get_modelmeta().custom_metadata_map + self.context_size = int(decoder_meta["context_size"]) + self.vocab_size = int(decoder_meta["vocab_size"]) + + logging.info(f"context_size: {self.context_size}") + logging.info(f"vocab_size: {self.vocab_size}") + + def init_joiner(self, joiner_model_filename: str): + self.joiner = ort.InferenceSession( + joiner_model_filename, + sess_options=self.session_opts, + ) + + joiner_meta = self.joiner.get_modelmeta().custom_metadata_map + self.joiner_dim = int(joiner_meta["joiner_dim"]) + + logging.info(f"joiner_dim: {self.joiner_dim}") + + def _build_encoder_input_output( + self, + x: torch.Tensor, + ) -> Tuple[Dict[str, np.ndarray], List[str]]: + encoder_input = { + "x": x.numpy(), + "state0": self.states[0], + "state1": self.states[1], + } + encoder_output = ["encoder_out", "new_state0", "new_state1"] + + return encoder_input, encoder_output + + def _update_states(self, states: List[np.ndarray]): + self.states = states + + def run_encoder(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: + A 3-D tensor of shape (N, T, C) + Returns: + Return a 3-D tensor of shape (N, T', joiner_dim) where + T' is usually equal to ((T-3)//2-1)//2 + """ + encoder_input, encoder_output_names = self._build_encoder_input_output(x) + out = self.encoder.run(encoder_output_names, encoder_input) + + self._update_states(out[1:]) + + return torch.from_numpy(out[0]) + + def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor: + """ + Args: + decoder_input: + A 2-D tensor of shape (N, context_size) + Returns: + Return a 2-D tensor of shape (N, joiner_dim) + """ + out = self.decoder.run( + [self.decoder.get_outputs()[0].name], + {self.decoder.get_inputs()[0].name: decoder_input.numpy()}, + )[0] + + return torch.from_numpy(out) + + def run_joiner( + self, encoder_out: torch.Tensor, decoder_out: torch.Tensor + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + out = self.joiner.run( + [self.joiner.get_outputs()[0].name], + { + self.joiner.get_inputs()[0].name: encoder_out.numpy(), + self.joiner.get_inputs()[1].name: decoder_out.numpy(), + }, + )[0] + + return torch.from_numpy(out) + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +def create_streaming_feature_extractor() -> OnlineFeature: + """Create a CPU streaming feature extractor. + + At present, we assume it returns a fbank feature extractor with + fixed options. In the future, we will support passing in the options + from outside. + + Returns: + Return a CPU streaming feature extractor. + """ + opts = FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + return OnlineFbank(opts) + + +def greedy_search( + model: OnnxModel, + encoder_out: torch.Tensor, + context_size: int, + decoder_out: Optional[torch.Tensor] = None, + hyp: Optional[List[int]] = None, +) -> List[int]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + model: + The transducer model. + encoder_out: + A 3-D tensor of shape (1, T, joiner_dim) + context_size: + The context size of the decoder model. + decoder_out: + Optional. Decoder output of the previous chunk. + hyp: + Decoding results for previous chunks. + Returns: + Return the decoded results so far. + """ + + blank_id = 0 + + if decoder_out is None: + assert hyp is None, hyp + hyp = [blank_id] * context_size + decoder_input = torch.tensor([hyp], dtype=torch.int64) + decoder_out = model.run_decoder(decoder_input) + else: + assert hyp is not None, hyp + + encoder_out = encoder_out.squeeze(0) + T = encoder_out.size(0) + for t in range(T): + cur_encoder_out = encoder_out[t : t + 1] + joiner_out = model.run_joiner(cur_encoder_out, decoder_out).squeeze(0) + y = joiner_out.argmax(dim=0).item() + if y != blank_id: + hyp.append(y) + decoder_input = hyp[-context_size:] + decoder_input = torch.tensor([decoder_input], dtype=torch.int64) + decoder_out = model.run_decoder(decoder_input) + + return hyp, decoder_out + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + model = OnnxModel( + encoder_model_filename=args.encoder_model_filename, + decoder_model_filename=args.decoder_model_filename, + joiner_model_filename=args.joiner_model_filename, + ) + + sample_rate = 16000 + + logging.info("Constructing Fbank computer") + online_fbank = create_streaming_feature_extractor() + + logging.info(f"Reading sound files: {args.sound_file}") + waves = read_sound_files( + filenames=[args.sound_file], + expected_sample_rate=sample_rate, + )[0] + + tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32) + wave_samples = torch.cat([waves, tail_padding]) + + num_processed_frames = 0 + segment = model.segment + offset = model.offset + + context_size = model.context_size + hyp = None + decoder_out = None + + chunk = int(1 * sample_rate) # 1 second + start = 0 + while start < wave_samples.numel(): + end = min(start + chunk, wave_samples.numel()) + samples = wave_samples[start:end] + start += chunk + + online_fbank.accept_waveform( + sampling_rate=sample_rate, + waveform=samples, + ) + + while online_fbank.num_frames_ready - num_processed_frames >= segment: + frames = [] + for i in range(segment): + frames.append(online_fbank.get_frame(num_processed_frames + i)) + num_processed_frames += offset + frames = torch.cat(frames, dim=0) + frames = frames.unsqueeze(0) + encoder_out = model.run_encoder(frames) + hyp, decoder_out = greedy_search( + model, + encoder_out, + context_size, + decoder_out, + hyp, + ) + + symbol_table = k2.SymbolTable.from_file(args.tokens) + + text = "" + for i in hyp[context_size:]: + text += symbol_table[i] + text = text.replace("▁", " ").strip() + + logging.info(args.sound_file) + logging.info(text) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py index 02ed16a8c..cbbc77928 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py @@ -16,13 +16,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +""" +Please see +https://k2-fsa.github.io/icefall/model-export/export-ncnn.html +for usage +""" import argparse import logging from typing import List, Optional +import k2 import ncnn -import sentencepiece as spm import torch import torchaudio from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature @@ -32,9 +37,9 @@ def get_args(): parser = argparse.ArgumentParser() parser.add_argument( - "--bpe-model-filename", + "--tokens", type=str, - help="Path to bpe.model", + help="Path to tokens.txt", ) parser.add_argument( @@ -251,9 +256,6 @@ def main(): model = Model(args) - sp = spm.SentencePieceProcessor() - sp.load(args.bpe_model_filename) - sound_file = args.sound_filename sample_rate = 16000 @@ -329,10 +331,16 @@ def main(): model, encoder_out.squeeze(0), decoder_out, hyp ) + symbol_table = k2.SymbolTable.from_file(args.tokens) + context_size = 2 + text = "" + for i in hyp[context_size:]: + text += symbol_table[i] + text = text.replace("▁", " ").strip() logging.info(sound_file) - logging.info(sp.decode(hyp[context_size:])) + logging.info(text) if __name__ == "__main__": diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/export-onnx.py b/egs/librispeech/ASR/lstm_transducer_stateless3/export-onnx.py new file mode 120000 index 000000000..9f5064deb --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/export-onnx.py @@ -0,0 +1 @@ +../lstm_transducer_stateless2/export-onnx.py \ No newline at end of file diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/onnx_check.py b/egs/librispeech/ASR/lstm_transducer_stateless3/onnx_check.py new file mode 120000 index 000000000..0b1ea0326 --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/onnx_check.py @@ -0,0 +1 @@ +../lstm_transducer_stateless2/onnx_check.py \ No newline at end of file diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/onnx_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless3/onnx_pretrained.py new file mode 120000 index 000000000..099c2882f --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/onnx_pretrained.py @@ -0,0 +1 @@ +../lstm_transducer_stateless2/onnx_pretrained.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py index 6541f0295..5ca4173c1 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py @@ -125,13 +125,13 @@ def test_encoder( onnx_model: OnnxModel, ): C = 80 - for i in range(10): + for i in range(3): N = torch.randint(low=1, high=20, size=(1,)).item() - T = torch.randint(low=50, high=100, size=(1,)).item() + T = torch.randint(low=30, high=50, size=(1,)).item() logging.info(f"test_encoder: iter {i}, N={N}, T={T}") x = torch.rand(N, T, C) - x_lens = torch.randint(low=10, high=T + 1, size=(N,)) + x_lens = torch.randint(low=30, high=T + 1, size=(N,)) x_lens[0] = T torch_encoder_out, torch_encoder_out_lens = torch_model.encoder(x, x_lens) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless7/export-onnx.py new file mode 100755 index 000000000..f76915a74 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/export-onnx.py @@ -0,0 +1,560 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang) + +""" +This script exports a transducer model from PyTorch to ONNX. + +We use the pre-trained model from +https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-epoch-30-avg-9.pt" + +cd exp +ln -s pretrained-epoch-30-avg-9.pt epoch-99.pt +popd + +2. Export the model to ONNX + +./pruned_transducer_stateless7/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --feedforward-dims "1024,1024,2048,2048,1024" + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +See ./onnx_pretrained.py and ./onnx_check.py for how to +use the exported ONNX models. +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, Tuple + +import onnx +import sentencepiece as spm +import torch +import torch.nn as nn +from decoder import Decoder +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_params, get_transducer_model +from zipformer import Zipformer + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless5/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +class OnnxEncoder(nn.Module): + """A wrapper for Zipformer and the encoder_proj from the joiner""" + + def __init__(self, encoder: Zipformer, encoder_proj: nn.Linear): + """ + Args: + encoder: + A Zipformer encoder. + encoder_proj: + The projection layer for encoder from the joiner. + """ + super().__init__() + self.encoder = encoder + self.encoder_proj = encoder_proj + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Please see the help information of Zipformer.forward + + Args: + x: + A 3-D tensor of shape (N, T, C) + x_lens: + A 1-D tensor of shape (N,). Its dtype is torch.int64 + Returns: + Return a tuple containing: + - encoder_out, A 3-D tensor of shape (N, T', joiner_dim) + - encoder_out_lens, A 1-D tensor of shape (N,) + """ + encoder_out, encoder_out_lens = self.encoder(x, x_lens) + + encoder_out = self.encoder_proj(encoder_out) + # Now encoder_out is of shape (N, T, joiner_dim) + + return encoder_out, encoder_out_lens + + +class OnnxDecoder(nn.Module): + """A wrapper for Decoder and the decoder_proj from the joiner""" + + def __init__(self, decoder: Decoder, decoder_proj: nn.Linear): + super().__init__() + self.decoder = decoder + self.decoder_proj = decoder_proj + + def forward(self, y: torch.Tensor) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, context_size). + Returns + Return a 2-D tensor of shape (N, joiner_dim) + """ + need_pad = False + decoder_output = self.decoder(y, need_pad=need_pad) + decoder_output = decoder_output.squeeze(1) + output = self.decoder_proj(decoder_output) + + return output + + +class OnnxJoiner(nn.Module): + """A wrapper for the joiner""" + + def __init__(self, output_linear: nn.Linear): + super().__init__() + self.output_linear = output_linear + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + logit = encoder_out + decoder_out + logit = self.output_linear(torch.tanh(logit)) + return logit + + +def export_encoder_model_onnx( + encoder_model: OnnxEncoder, + encoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the given encoder model to ONNX format. + The exported model has two inputs: + + - x, a tensor of shape (N, T, C); dtype is torch.float32 + - x_lens, a tensor of shape (N,); dtype is torch.int64 + + and it has two outputs: + + - encoder_out, a tensor of shape (N, T', joiner_dim) + - encoder_out_lens, a tensor of shape (N,) + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + x = torch.zeros(1, 100, 80, dtype=torch.float32) + x_lens = torch.tensor([100], dtype=torch.int64) + + torch.onnx.export( + encoder_model, + (x, x_lens), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "x_lens"], + output_names=["encoder_out", "encoder_out_lens"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "x_lens": {0: "N"}, + "encoder_out": {0: "N", 1: "T"}, + "encoder_out_lens": {0: "N"}, + }, + ) + + +def export_decoder_model_onnx( + decoder_model: OnnxDecoder, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, decoder_model.context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, joiner_dim) + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + context_size = decoder_model.decoder.context_size + vocab_size = decoder_model.decoder.vocab_size + + y = torch.zeros(10, context_size, dtype=torch.int64) + torch.onnx.export( + decoder_model, + y, + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + + meta_data = { + "context_size": str(context_size), + "vocab_size": str(vocab_size), + } + add_meta_data(filename=decoder_filename, meta_data=meta_data) + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. + The exported joiner model has two inputs: + + - encoder_out: a tensor of shape (N, joiner_dim) + - decoder_out: a tensor of shape (N, joiner_dim) + + and produces one output: + + - logit: a tensor of shape (N, vocab_size) + """ + joiner_dim = joiner_model.output_linear.weight.shape[1] + logging.info(f"joiner dim: {joiner_dim}") + + projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + + torch.onnx.export( + joiner_model, + (projected_encoder_out, projected_decoder_out), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "encoder_out", + "decoder_out", + ], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + meta_data = { + "joiner_dim": str(joiner_dim), + } + add_meta_data(filename=joiner_filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + setup_logger(f"{params.exp_dir}/log-export/log-export-onnx") + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True) + + encoder = OnnxEncoder( + encoder=model.encoder, + encoder_proj=model.joiner.encoder_proj, + ) + + decoder = OnnxDecoder( + decoder=model.decoder, + decoder_proj=model.joiner.decoder_proj, + ) + + joiner = OnnxJoiner(output_linear=model.joiner.output_linear) + + encoder_num_param = sum([p.numel() for p in encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + if params.iter > 0: + suffix = f"iter-{params.iter}" + else: + suffix = f"epoch-{params.epoch}" + + suffix += f"-avg-{params.avg}" + + opset_version = 13 + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" + export_encoder_model_onnx( + encoder, + encoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported encoder to {encoder_filename}") + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" + export_decoder_model_onnx( + decoder, + decoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported decoder to {decoder_filename}") + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" + export_joiner_model_onnx( + joiner, + joiner_filename, + opset_version=opset_version, + ) + logging.info(f"Exported joiner to {joiner_filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7/export.py index db8b5eb2b..3e3160e7e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/export.py @@ -41,31 +41,7 @@ Check https://github.com/k2-fsa/sherpa for how to use the exported models outside of icefall. -(2) Export to ONNX format - -./pruned_transducer_stateless7/export.py \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --epoch 20 \ - --avg 10 \ - --onnx 1 - -It will generate the following files in the given `exp_dir`. -Check `onnx_check.py` for how to use them. - - - encoder.onnx - - decoder.onnx - - joiner.onnx - - joiner_encoder_proj.onnx - - joiner_decoder_proj.onnx - -Please see ./onnx_pretrained.py for usage of the generated files - -Check -https://github.com/k2-fsa/sherpa-onnx -for how to use the exported models outside of icefall. - -(3) Export `model.state_dict()` +(2) Export `model.state_dict()` ./pruned_transducer_stateless7/export.py \ --exp-dir ./pruned_transducer_stateless7/exp \ @@ -196,23 +172,6 @@ def get_parser(): """, ) - parser.add_argument( - "--onnx", - type=str2bool, - default=False, - help="""If True, --jit is ignored and it exports the model - to onnx format. It will generate the following files: - - - encoder.onnx - - decoder.onnx - - joiner.onnx - - joiner_encoder_proj.onnx - - joiner_decoder_proj.onnx - - Refer to ./onnx_check.py and ./onnx_pretrained.py for how to use them. - """, - ) - parser.add_argument( "--context-size", type=int, @@ -225,204 +184,6 @@ def get_parser(): return parser -def export_encoder_model_onnx( - encoder_model: nn.Module, - encoder_filename: str, - opset_version: int = 11, -) -> None: - """Export the given encoder model to ONNX format. - The exported model has two inputs: - - - x, a tensor of shape (N, T, C); dtype is torch.float32 - - x_lens, a tensor of shape (N,); dtype is torch.int64 - - and it has two outputs: - - - encoder_out, a tensor of shape (N, T, C) - - encoder_out_lens, a tensor of shape (N,) - - Note: The warmup argument is fixed to 1. - - Args: - encoder_model: - The input encoder model - encoder_filename: - The filename to save the exported ONNX model. - opset_version: - The opset version to use. - """ - x = torch.zeros(1, 101, 80, dtype=torch.float32) - x_lens = torch.tensor([101], dtype=torch.int64) - - # encoder_model = torch.jit.script(encoder_model) - # It throws the following error for the above statement - # - # RuntimeError: Exporting the operator __is_ to ONNX opset version - # 11 is not supported. Please feel free to request support or - # submit a pull request on PyTorch GitHub. - # - # I cannot find which statement causes the above error. - # torch.onnx.export() will use torch.jit.trace() internally, which - # works well for the current reworked model - torch.onnx.export( - encoder_model, - (x, x_lens), - encoder_filename, - verbose=False, - opset_version=opset_version, - input_names=["x", "x_lens"], - output_names=["encoder_out", "encoder_out_lens"], - dynamic_axes={ - "x": {0: "N", 1: "T"}, - "x_lens": {0: "N"}, - "encoder_out": {0: "N", 1: "T"}, - "encoder_out_lens": {0: "N"}, - }, - ) - logging.info(f"Saved to {encoder_filename}") - - -def export_decoder_model_onnx( - decoder_model: nn.Module, - decoder_filename: str, - opset_version: int = 11, -) -> None: - """Export the decoder model to ONNX format. - - The exported model has one input: - - - y: a torch.int64 tensor of shape (N, decoder_model.context_size) - - and has one output: - - - decoder_out: a torch.float32 tensor of shape (N, 1, C) - - Note: The argument need_pad is fixed to False. - - Args: - decoder_model: - The decoder model to be exported. - decoder_filename: - Filename to save the exported ONNX model. - opset_version: - The opset version to use. - """ - y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) - need_pad = False # Always False, so we can use torch.jit.trace() here - # Note(fangjun): torch.jit.trace() is more efficient than torch.jit.script() - # in this case - torch.onnx.export( - decoder_model, - (y, need_pad), - decoder_filename, - verbose=False, - opset_version=opset_version, - input_names=["y", "need_pad"], - output_names=["decoder_out"], - dynamic_axes={ - "y": {0: "N"}, - "decoder_out": {0: "N"}, - }, - ) - logging.info(f"Saved to {decoder_filename}") - - -def export_joiner_model_onnx( - joiner_model: nn.Module, - joiner_filename: str, - opset_version: int = 11, -) -> None: - """Export the joiner model to ONNX format. - The exported joiner model has two inputs: - - - projected_encoder_out: a tensor of shape (N, joiner_dim) - - projected_decoder_out: a tensor of shape (N, joiner_dim) - - and produces one output: - - - logit: a tensor of shape (N, vocab_size) - - The exported encoder_proj model has one input: - - - encoder_out: a tensor of shape (N, encoder_out_dim) - - and produces one output: - - - projected_encoder_out: a tensor of shape (N, joiner_dim) - - The exported decoder_proj model has one input: - - - decoder_out: a tensor of shape (N, decoder_out_dim) - - and produces one output: - - - projected_decoder_out: a tensor of shape (N, joiner_dim) - """ - encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx") - decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx") - - encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] - decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] - joiner_dim = joiner_model.decoder_proj.weight.shape[0] - - projected_encoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32) - projected_decoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32) - - project_input = False - # Note: It uses torch.jit.trace() internally - torch.onnx.export( - joiner_model, - (projected_encoder_out, projected_decoder_out, project_input), - joiner_filename, - verbose=False, - opset_version=opset_version, - input_names=[ - "encoder_out", - "decoder_out", - "project_input", - ], - output_names=["logit"], - dynamic_axes={ - "encoder_out": {0: "N"}, - "decoder_out": {0: "N"}, - "logit": {0: "N"}, - }, - ) - logging.info(f"Saved to {joiner_filename}") - - encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) - torch.onnx.export( - joiner_model.encoder_proj, - encoder_out, - encoder_proj_filename, - verbose=False, - opset_version=opset_version, - input_names=["encoder_out"], - output_names=["projected_encoder_out"], - dynamic_axes={ - "encoder_out": {0: "N"}, - "projected_encoder_out": {0: "N"}, - }, - ) - logging.info(f"Saved to {encoder_proj_filename}") - - decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) - torch.onnx.export( - joiner_model.decoder_proj, - decoder_out, - decoder_proj_filename, - verbose=False, - opset_version=opset_version, - input_names=["decoder_out"], - output_names=["projected_decoder_out"], - dynamic_axes={ - "decoder_out": {0: "N"}, - "projected_decoder_out": {0: "N"}, - }, - ) - logging.info(f"Saved to {decoder_proj_filename}") - - @torch.no_grad() def main(): args = get_parser().parse_args() @@ -531,31 +292,7 @@ def main(): model.to("cpu") model.eval() - if params.onnx is True: - convert_scaled_to_non_scaled(model, inplace=True) - opset_version = 13 - logging.info("Exporting to onnx format") - encoder_filename = params.exp_dir / "encoder.onnx" - export_encoder_model_onnx( - model.encoder, - encoder_filename, - opset_version=opset_version, - ) - - decoder_filename = params.exp_dir / "decoder.onnx" - export_decoder_model_onnx( - model.decoder, - decoder_filename, - opset_version=opset_version, - ) - - joiner_filename = params.exp_dir / "joiner.onnx" - export_joiner_model_onnx( - model.joiner, - joiner_filename, - opset_version=opset_version, - ) - elif params.jit is True: + if params.jit is True: convert_scaled_to_non_scaled(model, inplace=True) # We won't use the forward() method of the model in C++, so just ignore # it here. diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_check.py b/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_check.py deleted file mode 100755 index 63acc0922..000000000 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_check.py +++ /dev/null @@ -1,286 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2022 Xiaomi Corporation (Author: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This script checks that exported onnx models produce the same output -with the given torchscript model for the same input. -""" - -import argparse -import logging - -import onnxruntime as ort -import torch - -from icefall import is_module_available - -if not is_module_available("onnxruntime"): - raise ValueError("Please 'pip install onnxruntime' first.") - - -ort.set_default_logger_severity(3) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--jit-filename", - required=True, - type=str, - help="Path to the torchscript model", - ) - - parser.add_argument( - "--onnx-encoder-filename", - required=True, - type=str, - help="Path to the onnx encoder model", - ) - - parser.add_argument( - "--onnx-decoder-filename", - required=True, - type=str, - help="Path to the onnx decoder model", - ) - - parser.add_argument( - "--onnx-joiner-filename", - required=True, - type=str, - help="Path to the onnx joiner model", - ) - - parser.add_argument( - "--onnx-joiner-encoder-proj-filename", - required=True, - type=str, - help="Path to the onnx joiner encoder projection model", - ) - - parser.add_argument( - "--onnx-joiner-decoder-proj-filename", - required=True, - type=str, - help="Path to the onnx joiner decoder projection model", - ) - - return parser - - -def test_encoder( - model: torch.jit.ScriptModule, - encoder_session: ort.InferenceSession, -): - inputs = encoder_session.get_inputs() - outputs = encoder_session.get_outputs() - input_names = [n.name for n in inputs] - output_names = [n.name for n in outputs] - - assert inputs[0].shape == ["N", "T", 80] - assert inputs[1].shape == ["N"] - - for N in [1, 5]: - for T in [12, 50]: - print("N, T", N, T) - x = torch.rand(N, T, 80, dtype=torch.float32) - x_lens = torch.randint(low=10, high=T + 1, size=(N,)) - x_lens[0] = T - - encoder_inputs = { - input_names[0]: x.numpy(), - input_names[1]: x_lens.numpy(), - } - - torch_encoder_out, torch_encoder_out_lens = model.encoder(x, x_lens) - - encoder_out, encoder_out_lens = encoder_session.run( - output_names, - encoder_inputs, - ) - - torch_encoder_out, torch_encoder_out_lens = model.encoder(x, x_lens) - - encoder_out = torch.from_numpy(encoder_out) - assert torch.allclose(encoder_out, torch_encoder_out, atol=1e-05), ( - (encoder_out - torch_encoder_out).abs().max(), - encoder_out.shape, - torch_encoder_out.shape, - ) - - -def test_decoder( - model: torch.jit.ScriptModule, - decoder_session: ort.InferenceSession, -): - inputs = decoder_session.get_inputs() - outputs = decoder_session.get_outputs() - input_names = [n.name for n in inputs] - output_names = [n.name for n in outputs] - - assert inputs[0].shape == ["N", 2] - for N in [1, 5, 10]: - y = torch.randint(low=1, high=500, size=(10, 2)) - - decoder_inputs = {input_names[0]: y.numpy()} - decoder_out = decoder_session.run( - output_names, - decoder_inputs, - )[0] - decoder_out = torch.from_numpy(decoder_out) - - torch_decoder_out = model.decoder(y, need_pad=False) - assert torch.allclose(decoder_out, torch_decoder_out, atol=1e-5), ( - (decoder_out - torch_decoder_out).abs().max() - ) - - -def test_joiner( - model: torch.jit.ScriptModule, - joiner_session: ort.InferenceSession, - joiner_encoder_proj_session: ort.InferenceSession, - joiner_decoder_proj_session: ort.InferenceSession, -): - joiner_inputs = joiner_session.get_inputs() - joiner_outputs = joiner_session.get_outputs() - joiner_input_names = [n.name for n in joiner_inputs] - joiner_output_names = [n.name for n in joiner_outputs] - - assert joiner_inputs[0].shape == ["N", 1, 1, 512] - assert joiner_inputs[1].shape == ["N", 1, 1, 512] - - joiner_encoder_proj_inputs = joiner_encoder_proj_session.get_inputs() - encoder_proj_input_name = joiner_encoder_proj_inputs[0].name - - assert joiner_encoder_proj_inputs[0].shape == ["N", 384] - - joiner_encoder_proj_outputs = joiner_encoder_proj_session.get_outputs() - encoder_proj_output_name = joiner_encoder_proj_outputs[0].name - - joiner_decoder_proj_inputs = joiner_decoder_proj_session.get_inputs() - decoder_proj_input_name = joiner_decoder_proj_inputs[0].name - - assert joiner_decoder_proj_inputs[0].shape == ["N", 512] - - joiner_decoder_proj_outputs = joiner_decoder_proj_session.get_outputs() - decoder_proj_output_name = joiner_decoder_proj_outputs[0].name - - for N in [1, 5, 10]: - encoder_out = torch.rand(N, 384) - decoder_out = torch.rand(N, 512) - - projected_encoder_out = torch.rand(N, 1, 1, 512) - projected_decoder_out = torch.rand(N, 1, 1, 512) - - joiner_inputs = { - joiner_input_names[0]: projected_encoder_out.numpy(), - joiner_input_names[1]: projected_decoder_out.numpy(), - } - joiner_out = joiner_session.run(joiner_output_names, joiner_inputs)[0] - joiner_out = torch.from_numpy(joiner_out) - - torch_joiner_out = model.joiner( - projected_encoder_out, - projected_decoder_out, - project_input=False, - ) - assert torch.allclose(joiner_out, torch_joiner_out, atol=1e-5), ( - (joiner_out - torch_joiner_out).abs().max() - ) - - # Now test encoder_proj - joiner_encoder_proj_inputs = {encoder_proj_input_name: encoder_out.numpy()} - joiner_encoder_proj_out = joiner_encoder_proj_session.run( - [encoder_proj_output_name], joiner_encoder_proj_inputs - )[0] - joiner_encoder_proj_out = torch.from_numpy(joiner_encoder_proj_out) - - torch_joiner_encoder_proj_out = model.joiner.encoder_proj(encoder_out) - assert torch.allclose( - joiner_encoder_proj_out, torch_joiner_encoder_proj_out, atol=1e-5 - ), ((joiner_encoder_proj_out - torch_joiner_encoder_proj_out).abs().max()) - - # Now test decoder_proj - joiner_decoder_proj_inputs = {decoder_proj_input_name: decoder_out.numpy()} - joiner_decoder_proj_out = joiner_decoder_proj_session.run( - [decoder_proj_output_name], joiner_decoder_proj_inputs - )[0] - joiner_decoder_proj_out = torch.from_numpy(joiner_decoder_proj_out) - - torch_joiner_decoder_proj_out = model.joiner.decoder_proj(decoder_out) - assert torch.allclose( - joiner_decoder_proj_out, torch_joiner_decoder_proj_out, atol=1e-5 - ), ((joiner_decoder_proj_out - torch_joiner_decoder_proj_out).abs().max()) - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - logging.info(vars(args)) - - model = torch.jit.load(args.jit_filename) - - options = ort.SessionOptions() - options.inter_op_num_threads = 1 - options.intra_op_num_threads = 1 - - logging.info("Test encoder") - encoder_session = ort.InferenceSession( - args.onnx_encoder_filename, - sess_options=options, - ) - test_encoder(model, encoder_session) - - logging.info("Test decoder") - decoder_session = ort.InferenceSession( - args.onnx_decoder_filename, - sess_options=options, - ) - test_decoder(model, decoder_session) - - logging.info("Test joiner") - joiner_session = ort.InferenceSession( - args.onnx_joiner_filename, - sess_options=options, - ) - joiner_encoder_proj_session = ort.InferenceSession( - args.onnx_joiner_encoder_proj_filename, - sess_options=options, - ) - joiner_decoder_proj_session = ort.InferenceSession( - args.onnx_joiner_decoder_proj_filename, - sess_options=options, - ) - test_joiner( - model, - joiner_session, - joiner_encoder_proj_session, - joiner_decoder_proj_session, - ) - logging.info("Finished checking ONNX models") - - -if __name__ == "__main__": - torch.manual_seed(20220727) - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_check.py b/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_check.py new file mode 120000 index 000000000..20e334271 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_check.py @@ -0,0 +1 @@ +../pruned_transducer_stateless5/onnx_check.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_pretrained.py deleted file mode 100755 index 3a06ee293..000000000 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_pretrained.py +++ /dev/null @@ -1,388 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads ONNX models and uses them to decode waves. -You can use the following command to get the exported models: - -./pruned_transducer_stateless7/export.py \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --epoch 20 \ - --avg 10 \ - --onnx 1 - -Usage of this script: - -./pruned_transducer_stateless7/onnx_pretrained.py \ - --encoder-model-filename ./pruned_transducer_stateless7/exp/encoder.onnx \ - --decoder-model-filename ./pruned_transducer_stateless7/exp/decoder.onnx \ - --joiner-model-filename ./pruned_transducer_stateless7/exp/joiner.onnx \ - --joiner-encoder-proj-model-filename ./pruned_transducer_stateless7/exp/joiner_encoder_proj.onnx \ - --joiner-decoder-proj-model-filename ./pruned_transducer_stateless7/exp/joiner_decoder_proj.onnx \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - /path/to/foo.wav \ - /path/to/bar.wav -""" - -import argparse -import logging -import math -from typing import List - -import kaldifeat -import numpy as np -import onnxruntime as ort -import sentencepiece as spm -import torch -import torchaudio -from torch.nn.utils.rnn import pad_sequence - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--encoder-model-filename", - type=str, - required=True, - help="Path to the encoder onnx model. ", - ) - - parser.add_argument( - "--decoder-model-filename", - type=str, - required=True, - help="Path to the decoder onnx model. ", - ) - - parser.add_argument( - "--joiner-model-filename", - type=str, - required=True, - help="Path to the joiner onnx model. ", - ) - - parser.add_argument( - "--joiner-encoder-proj-model-filename", - type=str, - required=True, - help="Path to the joiner encoder_proj onnx model. ", - ) - - parser.add_argument( - "--joiner-decoder-proj-model-filename", - type=str, - required=True, - help="Path to the joiner decoder_proj onnx model. ", - ) - - parser.add_argument( - "--bpe-model", - type=str, - help="""Path to bpe.model.""", - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="Context size of the decoder model", - ) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -def greedy_search( - decoder: ort.InferenceSession, - joiner: ort.InferenceSession, - joiner_encoder_proj: ort.InferenceSession, - joiner_decoder_proj: ort.InferenceSession, - encoder_out: np.ndarray, - encoder_out_lens: np.ndarray, - context_size: int, -) -> List[List[int]]: - """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. - Args: - decoder: - The decoder model. - joiner: - The joiner model. - joiner_encoder_proj: - The joiner encoder projection model. - joiner_decoder_proj: - The joiner decoder projection model. - encoder_out: - A 3-D tensor of shape (N, T, C) - encoder_out_lens: - A 1-D tensor of shape (N,). - context_size: - The context size of the decoder model. - Returns: - Return the decoded results for each utterance. - """ - encoder_out = torch.from_numpy(encoder_out) - encoder_out_lens = torch.from_numpy(encoder_out_lens) - assert encoder_out.ndim == 3 - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - projected_encoder_out = joiner_encoder_proj.run( - [joiner_encoder_proj.get_outputs()[0].name], - {joiner_encoder_proj.get_inputs()[0].name: packed_encoder_out.data.numpy()}, - )[0] - - blank_id = 0 # hard-code to 0 - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - hyps = [[blank_id] * context_size for _ in range(N)] - - decoder_input_nodes = decoder.get_inputs() - decoder_output_nodes = decoder.get_outputs() - - joiner_input_nodes = joiner.get_inputs() - joiner_output_nodes = joiner.get_outputs() - - decoder_input = torch.tensor( - hyps, - dtype=torch.int64, - ) # (N, context_size) - - decoder_out = decoder.run( - [decoder_output_nodes[0].name], - { - decoder_input_nodes[0].name: decoder_input.numpy(), - }, - )[0].squeeze(1) - projected_decoder_out = joiner_decoder_proj.run( - [joiner_decoder_proj.get_outputs()[0].name], - {joiner_decoder_proj.get_inputs()[0].name: decoder_out}, - )[0] - - projected_decoder_out = torch.from_numpy(projected_decoder_out) - - offset = 0 - for batch_size in batch_size_list: - start = offset - end = offset + batch_size - current_encoder_out = projected_encoder_out[start:end] - # current_encoder_out's shape: (batch_size, encoder_out_dim) - offset = end - - projected_decoder_out = projected_decoder_out[:batch_size] - - logits = joiner.run( - [joiner_output_nodes[0].name], - { - joiner_input_nodes[0].name: np.expand_dims( - np.expand_dims(current_encoder_out, axis=1), axis=1 - ), - joiner_input_nodes[1] - .name: projected_decoder_out.unsqueeze(1) - .unsqueeze(1) - .numpy(), - }, - )[0] - logits = torch.from_numpy(logits).squeeze(1).squeeze(1) - # logits'shape (batch_size, vocab_size) - - assert logits.ndim == 2, logits.shape - y = logits.argmax(dim=1).tolist() - emitted = False - for i, v in enumerate(y): - if v != blank_id: - hyps[i].append(v) - emitted = True - if emitted: - # update decoder output - decoder_input = [h[-context_size:] for h in hyps[:batch_size]] - decoder_input = torch.tensor( - decoder_input, - dtype=torch.int64, - ) - decoder_out = decoder.run( - [decoder_output_nodes[0].name], - { - decoder_input_nodes[0].name: decoder_input.numpy(), - }, - )[0].squeeze(1) - projected_decoder_out = joiner_decoder_proj.run( - [joiner_decoder_proj.get_outputs()[0].name], - {joiner_decoder_proj.get_inputs()[0].name: decoder_out}, - )[0] - projected_decoder_out = torch.from_numpy(projected_decoder_out) - - sorted_ans = [h[context_size:] for h in hyps] - ans = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - logging.info(vars(args)) - - session_opts = ort.SessionOptions() - session_opts.inter_op_num_threads = 1 - session_opts.intra_op_num_threads = 1 - - encoder = ort.InferenceSession( - args.encoder_model_filename, - sess_options=session_opts, - ) - - decoder = ort.InferenceSession( - args.decoder_model_filename, - sess_options=session_opts, - ) - - joiner = ort.InferenceSession( - args.joiner_model_filename, - sess_options=session_opts, - ) - - joiner_encoder_proj = ort.InferenceSession( - args.joiner_encoder_proj_model_filename, - sess_options=session_opts, - ) - - joiner_decoder_proj = ort.InferenceSession( - args.joiner_decoder_proj_model_filename, - sess_options=session_opts, - ) - - sp = spm.SentencePieceProcessor() - sp.load(args.bpe_model) - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = "cpu" - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = args.sample_rate - opts.mel_opts.num_bins = 80 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {args.sound_files}") - waves = read_sound_files( - filenames=args.sound_files, - expected_sample_rate=args.sample_rate, - ) - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence( - features, - batch_first=True, - padding_value=math.log(1e-10), - ) - - feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64) - - encoder_input_nodes = encoder.get_inputs() - encoder_out_nodes = encoder.get_outputs() - encoder_out, encoder_out_lens = encoder.run( - [encoder_out_nodes[0].name, encoder_out_nodes[1].name], - { - encoder_input_nodes[0].name: features.numpy(), - encoder_input_nodes[1].name: feature_lengths.numpy(), - }, - ) - - hyps = greedy_search( - decoder=decoder, - joiner=joiner, - joiner_encoder_proj=joiner_encoder_proj, - joiner_decoder_proj=joiner_decoder_proj, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - context_size=args.context_size, - ) - s = "\n" - for filename, hyp in zip(args.sound_files, hyps): - words = sp.decode(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_pretrained.py new file mode 120000 index 000000000..7607623c8 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_pretrained.py @@ -0,0 +1 @@ +../pruned_transducer_stateless3/onnx_pretrained.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index 5cde57812..3959c0bb2 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -44,7 +44,7 @@ from scaling import ( from torch import Tensor, nn from icefall.dist import get_rank -from icefall.utils import make_pad_mask +from icefall.utils import is_jit_tracing, make_pad_mask class Zipformer(EncoderInterface): @@ -792,7 +792,8 @@ class AttentionDownsample(torch.nn.Module): src = src.reshape(d_seq_len, ds, batch_size, in_channels) scores = (src * self.query).sum(dim=-1, keepdim=True) - scores = penalize_abs_values_gt(scores, limit=10.0, penalty=1.0e-04) + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + scores = penalize_abs_values_gt(scores, limit=10.0, penalty=1.0e-04) weights = scores.softmax(dim=1) @@ -904,6 +905,13 @@ class RelPositionalEncoding(torch.nn.Module): def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct a PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() + if is_jit_tracing(): + # 10k frames correspond to ~100k ms, e.g., 100 seconds, i.e., + # It assumes that the maximum input won't have more than + # 10k frames. + # + # TODO(fangjun): Use torch.jit.script() for this module + max_len = 10000 self.d_model = d_model self.dropout = torch.nn.Dropout(dropout_rate) self.pe = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py index 35d6b0556..d7092403e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py @@ -306,11 +306,11 @@ def export_encoder_model_onnx( left_context_len = ",".join(map(str, left_context_len)) meta_data = { - "model_type": "streaming_zipformer", + "model_type": "zipformer", "version": "1", "model_author": "k2-fsa", "decode_chunk_len": str(decode_chunk_len), # 32 - "pad_length": str(pad_length), # 7 + "T": str(T), # 39 "num_encoder_layers": num_encoder_layers, "encoder_dims": encoder_dims, "attention_dims": attention_dims, @@ -362,8 +362,8 @@ def export_encoder_model_onnx( input_names=input_names, output_names=output_names, dynamic_axes={ - "x": {0: "N", 1: "T"}, - "encoder_out": {0: "N", 1: "T"}, + "x": {0: "N"}, + "encoder_out": {0: "N"}, **inputs, **outputs, }, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py index 715560c70..8192e01fd 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py @@ -136,8 +136,11 @@ class OnnxModel: def init_encoder_states(self, batch_size: int = 1): encoder_meta = self.encoder.get_modelmeta().custom_metadata_map + model_type = encoder_meta["model_type"] + assert model_type == "zipformer", model_type + decode_chunk_len = int(encoder_meta["decode_chunk_len"]) - pad_length = int(encoder_meta["pad_length"]) + T = int(encoder_meta["T"]) num_encoder_layers = encoder_meta["num_encoder_layers"] encoder_dims = encoder_meta["encoder_dims"] @@ -155,7 +158,7 @@ class OnnxModel: left_context_len = to_int_list(left_context_len) logging.info(f"decode_chunk_len: {decode_chunk_len}") - logging.info(f"pad_length: {pad_length}") + logging.info(f"T: {T}") logging.info(f"num_encoder_layers: {num_encoder_layers}") logging.info(f"encoder_dims: {encoder_dims}") logging.info(f"attention_dims: {attention_dims}") @@ -219,7 +222,7 @@ class OnnxModel: self.num_encoders = num_encoders - self.segment = decode_chunk_len + pad_length + self.segment = T self.offset = decode_chunk_len def init_decoder(self, decoder_model_filename: str): From 35e5a2475ce143a965a79d06852e8f0263ed1a41 Mon Sep 17 00:00:00 2001 From: Karel Vesely Date: Thu, 9 Feb 2023 00:57:02 +0100 Subject: [PATCH 105/174] Librispeech, validate_manifest.py (#890) --- .../ASR/local/validate_manifest.py | 27 ++++++++++++++----- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/local/validate_manifest.py b/egs/librispeech/ASR/local/validate_manifest.py index f620b91ea..de49f5321 100755 --- a/egs/librispeech/ASR/local/validate_manifest.py +++ b/egs/librispeech/ASR/local/validate_manifest.py @@ -35,6 +35,7 @@ from pathlib import Path from lhotse import CutSet, load_manifest_lazy from lhotse.cut import Cut +from lhotse.dataset.speech_recognition import validate_for_asr def get_args(): @@ -55,16 +56,22 @@ def validate_one_supervision_per_cut(c: Cut): def validate_supervision_and_cut_time_bounds(c: Cut): + tol = 2e-3 # same tolerance as in 'validate_for_asr()' s = c.supervisions[0] - if s.start < c.start: - raise ValueError( - f"{c.id}: Supervision start time {s.start} is less " - f"than cut start time {c.start}" - ) - if s.end > c.end: + # Supervision start time is relative to Cut ... + # https://lhotse.readthedocs.io/en/v0.10_e/cuts.html + if s.start < -tol: raise ValueError( - f"{c.id}: Supervision end time {s.end} is larger " + f"{c.id}: Supervision start time {s.start} must not be negative." + ) + if s.start > tol: + raise ValueError( + f"{c.id}: Supervision start time {s.start} is not at the beginning of the Cut. Please apply `lhotse cut trim-to-supervisions`." + ) + if c.start + s.end > c.end + tol: + raise ValueError( + f"{c.id}: Supervision end time {c.start+s.end} is larger " f"than cut end time {c.end}" ) @@ -83,6 +90,12 @@ def main(): validate_one_supervision_per_cut(c) validate_supervision_and_cut_time_bounds(c) + # Validation from K2 training + # - checks supervision start is 0 + # - checks supervision.duration is not longer than cut.duration + # - there is tolerance 2ms + validate_for_asr(cut_set) + if __name__ == "__main__": formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" From e916027bfe2b92e9ecab2d2ae90f75acbc89dd4c Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 9 Feb 2023 10:33:40 +0800 Subject: [PATCH 106/174] Fix doc typos for onnx export (#891) --- docs/source/conf.py | 2 ++ docs/source/model-export/export-onnx.rst | 11 ++++++----- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index ef9fe1445..6452c5d6d 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -86,4 +86,6 @@ rst_epilog = """ .. _ncnn: https://github.com/tencent/ncnn .. _LibriSpeech: https://www.openslr.org/12 .. _musan: http://www.openslr.org/17/ +.. _ONNX: https://github.com/onnx/onnx +.. _onnxruntime: https://github.com/microsoft/onnxruntime """ diff --git a/docs/source/model-export/export-onnx.rst b/docs/source/model-export/export-onnx.rst index 83c8440b5..ddcbc965f 100644 --- a/docs/source/model-export/export-onnx.rst +++ b/docs/source/model-export/export-onnx.rst @@ -1,20 +1,21 @@ Export to ONNX ============== -In this section, we describe how to export the following models to ONNX. +In this section, we describe how to export models to `ONNX`_. In each recipe, there is a file called ``export-onnx.py``, which is used -to export trained models to ONNX. +to export trained models to `ONNX`_. There is also a file named ``onnx_pretrained.py``, which you can use -the exported ONNX model in Python to decode sound files. +the exported `ONNX`_ model in Python with `onnxruntime`_ to decode sound files. Example ======= In the following, we demonstrate how to export a streaming Zipformer pre-trained -model from ``_ -to ONNX. +model from +``_ +to `ONNX`_. Download the pre-trained model ------------------------------ From 5cd1636cb3d6ee6502043d25d6205079b709401c Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Thu, 9 Feb 2023 12:12:23 +0800 Subject: [PATCH 107/174] Fix a bug in decode.py (#893) Co-authored-by: yifanyang --- egs/librispeech/ASR/lstm_transducer_stateless3/decode.py | 1 - egs/librispeech/ASR/pruned_transducer_stateless4/decode.py | 1 - 2 files changed, 2 deletions(-) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py index 832b99433..a380bc470 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py @@ -481,7 +481,6 @@ def decode_one_batch( res = DecodingResults(hyps=tokens, timestamps=timestamps) hyps, timestamps = parse_hyp_and_timestamp( - decoding_method=params.decoding_method, res=res, sp=sp, subsampling_factor=params.subsampling_factor, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py index 5fa129a89..c44db0206 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py @@ -525,7 +525,6 @@ def decode_one_batch( res = DecodingResults(hyps=tokens, timestamps=timestamps) hyps, timestamps = parse_hyp_and_timestamp( - decoding_method=params.decoding_method, res=res, sp=sp, subsampling_factor=params.subsampling_factor, From 59ac8bfc70a0062bc32b341e2bd1089f61a1011c Mon Sep 17 00:00:00 2001 From: emilyluj <49872352+emilyluj@users.noreply.github.com> Date: Thu, 9 Feb 2023 18:32:03 +0800 Subject: [PATCH 108/174] fix mmi graph compiler bug. (#895) --- icefall/mmi_graph_compiler.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/icefall/mmi_graph_compiler.py b/icefall/mmi_graph_compiler.py index 9f680f83d..600f09f2b 100644 --- a/icefall/mmi_graph_compiler.py +++ b/icefall/mmi_graph_compiler.py @@ -74,7 +74,9 @@ class MmiTrainingGraphCompiler(object): # CAUTION: The following line is crucial. # Arcs entering the back-off state have label equal to #0. # We have to change it to 0 here. - P.labels[P.labels >= first_token_disambig_id] = 0 + labels = P.labels.clone() + labels[labels >= first_token_disambig_id] = 0 + P.labels = labels P = k2.remove_epsilon(P) P = k2.arc_sort(P) From cba6ecc1d1c63da6cc73988cebe0a0189935a8df Mon Sep 17 00:00:00 2001 From: xiabingquan <53991699+xiabingquan@users.noreply.github.com> Date: Thu, 9 Feb 2023 23:54:45 +0800 Subject: [PATCH 109/174] Update README.md (#894) --- egs/librispeech/ASR/tdnn_lstm_ctc/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/README.md b/egs/librispeech/ASR/tdnn_lstm_ctc/README.md index 94d4ed6a3..b1e01a218 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/README.md +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/README.md @@ -1,4 +1,4 @@ Please visit - + for how to run this recipe. From 57604aac34c8ff3f2398ce7ee916ddd3fe32125f Mon Sep 17 00:00:00 2001 From: KajiMaCN <827272056@qq.com> Date: Fri, 10 Feb 2023 21:28:19 +0800 Subject: [PATCH 110/174] fix tal_csasr data pre-processing (#898) --- egs/tal_csasr/ASR/prepare.sh | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/egs/tal_csasr/ASR/prepare.sh b/egs/tal_csasr/ASR/prepare.sh index d9938fa63..c5d498d74 100755 --- a/egs/tal_csasr/ASR/prepare.sh +++ b/egs/tal_csasr/ASR/prepare.sh @@ -12,9 +12,12 @@ stop_stage=100 # directories and files. If not, they will be downloaded # by this script automatically. # -# - $dl_dir/tal_csasr +# - $dl_dir/TALCS_corpus # You can find three directories:train_set, dev_set, and test_set. # You can get it from https://ai.100tal.com/dataset +# - dev_set +# - test_set +# - train_set # # - $dl_dir/musan # This directory contains the following directories downloaded from @@ -44,7 +47,9 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then log "Stage 0: Download data" # Before you run this script, you must get the TAL_CSASR dataset # from https://ai.100tal.com/dataset - mv $dl_dir/TALCS_corpus $dl_dir/tal_csasr + if [ ! -d $dl_dir/tal_csasr/TALCS_corpus ]; then + mv $dl_dir/TALCS_corpus $dl_dir/tal_csasr + fi # If you have pre-downloaded it to /path/to/TALCS_corpus, # you can create a symlink @@ -116,7 +121,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then fi # Prepare text. - # Note: in Linux, you can install jq with the following command: + # Note: in Linux, you can install jq with the following command: # 1. wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64 # 2. chmod +x ./jq # 3. cp jq /usr/bin From 48c2c22dbe53372e5b6565266d76283a03f6670c Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 13 Feb 2023 11:44:25 +0800 Subject: [PATCH 111/174] Fix export to ncnn for lstm3 (#900) --- .../export-for-ncnn.py | 1 + .../ASR/lstm_transducer_stateless3/lstm.py | 12 +++++++- .../ASR/lstm_transducer_stateless3/train.py | 30 +++++++++++++++---- 3 files changed, 36 insertions(+), 7 deletions(-) create mode 120000 egs/librispeech/ASR/lstm_transducer_stateless3/export-for-ncnn.py diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/export-for-ncnn.py b/egs/librispeech/ASR/lstm_transducer_stateless3/export-for-ncnn.py new file mode 120000 index 000000000..d56cff73f --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/export-for-ncnn.py @@ -0,0 +1 @@ +../lstm_transducer_stateless2/export-for-ncnn.py \ No newline at end of file diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py b/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py index 6e51b85e4..cb67fffe4 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py @@ -135,6 +135,7 @@ class RNN(EncoderInterface): dropout: float = 0.1, layer_dropout: float = 0.075, aux_layer_period: int = 0, + is_pnnx: bool = False, ) -> None: super(RNN, self).__init__() @@ -176,6 +177,8 @@ class RNN(EncoderInterface): else None, ) + self.is_pnnx = is_pnnx + def forward( self, x: torch.Tensor, @@ -216,7 +219,14 @@ class RNN(EncoderInterface): # lengths = ((x_lens - 3) // 2 - 1) // 2 # issue an warning # # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0 - lengths = (((x_lens - 3) >> 1) - 1) >> 1 + if not self.is_pnnx: + lengths = (((x_lens - 3) >> 1) - 1) >> 1 + else: + lengths1 = torch.floor((x_lens - 3) / 2) + lengths = torch.floor((lengths1 - 1) / 2) + lengths = lengths.to(x_lens) + + if not torch.jit.is_tracing(): assert x.size(0) == lengths.max().item() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py index f56b4fd83..6ef4c9860 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py @@ -102,7 +102,28 @@ def add_model_arguments(parser: argparse.ArgumentParser): "--encoder-dim", type=int, default=512, - help="Encoder output dimesion.", + help="Encoder output dimension.", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Decoder output dimension.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="Joiner output dimension.", + ) + + parser.add_argument( + "--dim-feedforward", + type=int, + default=2048, + help="Dimension of feed forward.", ) parser.add_argument( @@ -395,14 +416,10 @@ def get_params() -> AttributeDict: # parameters for conformer "feature_dim": 80, "subsampling_factor": 4, - "dim_feedforward": 2048, - # parameters for decoder - "decoder_dim": 512, - # parameters for joiner - "joiner_dim": 512, # parameters for Noam "model_warm_step": 3000, # arg given to model, not for lrate "env_info": get_env_info(), + "is_pnnx": False, } ) @@ -419,6 +436,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: dim_feedforward=params.dim_feedforward, num_encoder_layers=params.num_encoder_layers, aux_layer_period=params.aux_layer_period, + is_pnnx=params.is_pnnx, ) return encoder From c102e7fbf07f25cee9baad2b739a827a356c3132 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 13 Feb 2023 12:16:43 +0800 Subject: [PATCH 112/174] more fixes for lstm3 to support exporting to ncnn (#902) --- .../ASR/lstm_transducer_stateless3/lstm.py | 35 ++++++++++++++----- 1 file changed, 27 insertions(+), 8 deletions(-) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py b/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py index cb67fffe4..59a835d35 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py @@ -121,6 +121,8 @@ class RNN(EncoderInterface): Period of auxiliary layers used for random combiner during training. If set to 0, will not use the random combiner (Default). You can set a positive integer to use the random combiner, e.g., 3. + is_pnnx: + True to make this class exportable via PNNX. """ def __init__( @@ -149,7 +151,13 @@ class RNN(EncoderInterface): # That is, it does two things simultaneously: # (1) subsampling: T -> T//subsampling_factor # (2) embedding: num_features -> d_model - self.encoder_embed = Conv2dSubsampling(num_features, d_model) + self.encoder_embed = Conv2dSubsampling( + num_features, + d_model, + is_pnnx=is_pnnx, + ) + + self.is_pnnx = is_pnnx self.num_encoder_layers = num_encoder_layers self.d_model = d_model @@ -177,8 +185,6 @@ class RNN(EncoderInterface): else None, ) - self.is_pnnx = is_pnnx - def forward( self, x: torch.Tensor, @@ -226,7 +232,6 @@ class RNN(EncoderInterface): lengths = torch.floor((lengths1 - 1) / 2) lengths = lengths.to(x_lens) - if not torch.jit.is_tracing(): assert x.size(0) == lengths.max().item() @@ -387,7 +392,7 @@ class RNNEncoderLayer(nn.Module): # for cell state assert states[1].shape == (1, src.size(1), self.rnn_hidden_size) src_lstm, new_states = self.lstm(src, states) - src = src + self.dropout(src_lstm) + src = self.dropout(src_lstm) + src # feed forward module src = src + self.dropout(self.feed_forward(src)) @@ -533,6 +538,7 @@ class Conv2dSubsampling(nn.Module): layer1_channels: int = 8, layer2_channels: int = 32, layer3_channels: int = 128, + is_pnnx: bool = False, ) -> None: """ Args: @@ -545,6 +551,9 @@ class Conv2dSubsampling(nn.Module): Number of channels in layer1 layer1_channels: Number of channels in layer2 + is_pnnx: + True if we are converting the model to PNNX format. + False otherwise. """ assert in_channels >= 9 super().__init__() @@ -587,6 +596,10 @@ class Conv2dSubsampling(nn.Module): channel_dim=-1, min_positive=0.45, max_positive=0.55 ) + # ncnn supports only batch size == 1 + self.is_pnnx = is_pnnx + self.conv_out_dim = self.out.weight.shape[1] + def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. @@ -600,9 +613,15 @@ class Conv2dSubsampling(nn.Module): # On entry, x is (N, T, idim) x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) x = self.conv(x) - # Now x is of shape (N, odim, ((T-3)//2-1)//2, ((idim-3)//2-1)//2) - b, c, t, f = x.size() - x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + + if torch.jit.is_tracing() and self.is_pnnx: + x = x.permute(0, 2, 1, 3).reshape(1, -1, self.conv_out_dim) + x = self.out(x) + else: + # Now x is of shape (N, odim, ((T-3)//2-1)//2, ((idim-3)//2-1)//2) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + # Now x is of shape (N, ((T-3)//2-1))//2, odim) x = self.out_norm(x) x = self.out_balancer(x) From c34ee676914aa58a4145c93a15c49d634d7f8960 Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Mon, 13 Feb 2023 14:05:38 +0800 Subject: [PATCH 113/174] Update generate_model_from_checkpoint.py (#901) --- .../generate_averaged_model.py | 203 ------------- .../generate_model_from_checkpoint.py | 282 ++++++++++++++++++ 2 files changed, 282 insertions(+), 203 deletions(-) delete mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7/generate_averaged_model.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7/generate_model_from_checkpoint.py diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/generate_averaged_model.py b/egs/librispeech/ASR/pruned_transducer_stateless7/generate_averaged_model.py deleted file mode 100755 index 381772ce7..000000000 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/generate_averaged_model.py +++ /dev/null @@ -1,203 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2022 Xiaomi Corporation (Author: Yifan Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: -(1) use the checkpoint exp_dir/epoch-xxx.pt -./pruned_transducer_stateless7/generate_averaged_model.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7/exp - -It will generate a file `epoch-28-avg-15.pt` in the given `exp_dir`. -You can later load it by `torch.load("epoch-28-avg-15.pt")`. - -(2) use the checkpoint exp_dir/checkpoint-iter.pt -./pruned_transducer_stateless7/generate_averaged_model.py \ - --iter 22000 \ - --avg 5 \ - --exp-dir ./pruned_transducer_stateless7/exp - -It will generate a file `iter-22000-avg-5.pt` in the given `exp_dir`. -You can later load it by `torch.load("iter-22000-avg-5.pt")`. -""" - - -import argparse -from pathlib import Path -from typing import Dict, List - -import sentencepiece as spm -import torch -from asr_datamodule import LibriSpeechAsrDataModule - -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.checkpoint import ( - average_checkpoints_with_averaged_model, - find_checkpoints, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=9, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless7/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - add_model_arguments(parser) - - return parser - - -@torch.no_grad() -def main(): - parser = get_parser() - LibriSpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - print("Script started") - - device = torch.device("cpu") - print(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - print("About to create model") - model = get_transducer_model(params) - - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - print( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - filename = params.exp_dir / f"iter-{params.iter}-avg-{params.avg}.pt" - torch.save({"model": model.state_dict()}, filename) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - print( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt" - torch.save({"model": model.state_dict()}, filename) - - num_param = sum([p.numel() for p in model.parameters()]) - print(f"Number of model parameters: {num_param}") - - print("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/generate_model_from_checkpoint.py b/egs/librispeech/ASR/pruned_transducer_stateless7/generate_model_from_checkpoint.py new file mode 100755 index 000000000..37edc0390 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/generate_model_from_checkpoint.py @@ -0,0 +1,282 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) use the averaged model with checkpoint exp_dir/epoch-xxx.pt +./pruned_transducer_stateless7/generate_model_from_checkpoint.py \ + --epoch 28 \ + --avg 15 \ + --use-averaged-model True \ + --exp-dir ./pruned_transducer_stateless7/exp + +It will generate a file `epoch-28-avg-15-use-averaged-model.pt` in the given `exp_dir`. +You can later load it by `torch.load("epoch-28-avg-15-use-averaged-model.pt")`. + +(2) use the averaged model with checkpoint exp_dir/checkpoint-iter.pt +./pruned_transducer_stateless7/generate_model_from_checkpoint.py \ + --iter 22000 \ + --avg 5 \ + --use-averaged-model True \ + --exp-dir ./pruned_transducer_stateless7/exp + +It will generate a file `iter-22000-avg-5-use-averaged-model.pt` in the given `exp_dir`. +You can later load it by `torch.load("iter-22000-avg-5-use-averaged-model.pt")`. + +(3) use the original model with checkpoint exp_dir/epoch-xxx.pt +./pruned_transducer_stateless7/generate_model_from_checkpoint.py \ + --epoch 28 \ + --avg 15 \ + --use-averaged-model False \ + --exp-dir ./pruned_transducer_stateless7/exp + +It will generate a file `epoch-28-avg-15.pt` in the given `exp_dir`. +You can later load it by `torch.load("epoch-28-avg-15.pt")`. + +(4) use the original model with checkpoint exp_dir/checkpoint-iter.pt +./pruned_transducer_stateless7/generate_model_from_checkpoint.py \ + --iter 22000 \ + --avg 5 \ + --use-averaged-model False \ + --exp-dir ./pruned_transducer_stateless7/exp + +It will generate a file `iter-22000-avg-5.pt` in the given `exp_dir`. +You can later load it by `torch.load("iter-22000-avg-5.pt")`. +""" + + +import argparse +from pathlib import Path +from typing import Dict, List + +import sentencepiece as spm +import torch + +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.utils import str2bool +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model." + "If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + print("Script started") + + device = torch.device("cpu") + print(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + print("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + print(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + filename = params.exp_dir / f"iter-{params.iter}-avg-{params.avg}.pt" + torch.save({"model": model.state_dict()}, filename) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt" + torch.save({"model": model.state_dict()}, filename) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + print(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt" + torch.save({"model": model.state_dict()}, filename) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + print( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + filename = ( + params.exp_dir + / f"iter-{params.iter}-avg-{params.avg}-use-averaged-model.pt" + ) + torch.save({"model": model.state_dict()}, filename) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + print( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + filename = ( + params.exp_dir + / f"epoch-{params.epoch}-avg-{params.avg}-use-averaged-model.pt" + ) + torch.save({"model": model.state_dict()}, filename) + + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + print("Done!") + + +if __name__ == "__main__": + main() From 6a8b649e56671ad9f9e5fa3ae13a2a1c177411e8 Mon Sep 17 00:00:00 2001 From: Desh Raj Date: Mon, 13 Feb 2023 02:53:28 -0500 Subject: [PATCH 114/174] Add small streaming Zipformer transducer model (#903) --- egs/librispeech/ASR/RESULTS.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index 1a894498e..ecb84eb01 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -77,6 +77,18 @@ for m in greedy_search modified_beam_search fast_beam_search; do done ``` +#### Smaller model + +A smaller model (~20M params) is also available with configuration based on [this comment](https://github.com/k2-fsa/icefall/pull/745#issuecomment-1405282740). The WERs are: + +| decoding method | chunk size | test-clean | test-other | comment | decoding mode | +|----------------------|------------|------------|------------|---------------------|----------------------| +| greedy search | 320ms | 3.94 | 9.79 | --epoch 30 --avg 9 | simulated streaming | +| modified beam search | 320ms | 3.88 | 9.53 | --epoch 30 --avg 9 | simulated streaming | + +You can find a pretrained model, training logs, decoding logs, and decoding +results at: + ### zipformer_mmi (zipformer with mmi loss) From 25ee50e27cf74770d077da050d0b895c9aa86e87 Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Mon, 13 Feb 2023 19:45:09 +0800 Subject: [PATCH 115/174] add ctc-greedy-search with timestamps (#905) --- egs/librispeech/ASR/conformer_ctc3/decode.py | 127 ++++++++++++++++++- 1 file changed, 120 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc3/decode.py b/egs/librispeech/ASR/conformer_ctc3/decode.py index 2300fecc3..6fbf9d674 100755 --- a/egs/librispeech/ASR/conformer_ctc3/decode.py +++ b/egs/librispeech/ASR/conformer_ctc3/decode.py @@ -92,7 +92,10 @@ from icefall.decode import ( from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, + convert_timestamp, get_texts, + make_pad_mask, + parse_bpe_start_end_pairs, parse_fsa_timestamps_and_texts, setup_logger, store_transcripts_and_timestamps, @@ -167,21 +170,24 @@ def get_parser(): default="ctc-decoding", help="""Decoding method. Supported values are: - - (0) ctc-decoding. Use CTC decoding. It uses a sentence piece + - (0) ctc-greedy-search. It uses a sentence piece model, + i.e., lang_dir/bpe.model, to convert word pieces to words. + It needs neither a lexicon nor an n-gram LM. + - (1) ctc-decoding. Use CTC decoding. It uses a sentence piece model, i.e., lang_dir/bpe.model, to convert word pieces to words. It needs neither a lexicon nor an n-gram LM. - - (1) 1best. Extract the best path from the decoding lattice as the + - (2) 1best. Extract the best path from the decoding lattice as the decoding result. - - (2) nbest. Extract n paths from the decoding lattice; the path + - (3) nbest. Extract n paths from the decoding lattice; the path with the highest score is the decoding result. - - (3) nbest-rescoring. Extract n paths from the decoding lattice, + - (4) nbest-rescoring. Extract n paths from the decoding lattice, rescore them with an n-gram LM (e.g., a 4-gram LM), the path with the highest score is the decoding result. - - (4) whole-lattice-rescoring. Rescore the decoding lattice with an + - (5) whole-lattice-rescoring. Rescore the decoding lattice with an n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice is the decoding result. you have trained an RNN LM using ./rnn_lm/train.py - - (5) nbest-oracle. Its WER is the lower bound of any n-best + - (6) nbest-oracle. Its WER is the lower bound of any n-best rescoring method can achieve. Useful for debugging n-best rescoring method. """, @@ -269,6 +275,101 @@ def get_decoding_params() -> AttributeDict: return params +def ctc_greedy_search( + ctc_probs: torch.Tensor, + nnet_output_lens: torch.Tensor, + sp: spm.SentencePieceProcessor, + subsampling_factor: int = 4, + frame_shift_ms: float = 10, +) -> Tuple[List[Tuple[float, float]], List[List[str]]]: + """Apply CTC greedy search + Args: + ctc_probs (torch.Tensor): + (batch, max_len, feat_dim) + nnet_output_lens (torch.Tensor): + (batch, ) + sp: + The BPE model. + subsampling_factor: + The subsampling factor of the model. + frame_shift_ms: + Frame shift in milliseconds between two contiguous frames. + + Returns: + utt_time_pairs: + A list of pair list. utt_time_pairs[i] is a list of + (start-time, end-time) pairs for each word in + utterance-i. + utt_words: + A list of str list. utt_words[i] is a word list of utterence-i. + """ + topk_prob, topk_index = ctc_probs.topk(1, dim=2) # (B, maxlen, 1) + topk_index = topk_index.squeeze(2) # (B, maxlen) + mask = make_pad_mask(nnet_output_lens) + topk_index = topk_index.masked_fill_(mask, 0) # (B, maxlen) + hyps = [hyp.tolist() for hyp in topk_index] + + def get_first_tokens(tokens: List[int]) -> List[bool]: + is_first_token = [] + first_tokens = [] + for t in range(len(tokens)): + if tokens[t] != 0 and (t == 0 or tokens[t - 1] != tokens[t]): + is_first_token.append(True) + first_tokens.append(tokens[t]) + else: + is_first_token.append(False) + return first_tokens, is_first_token + + utt_time_pairs = [] + utt_words = [] + for utt in range(len(hyps)): + first_tokens, is_first_token = get_first_tokens(hyps[utt]) + all_tokens = sp.id_to_piece(hyps[utt]) + index_pairs = parse_bpe_start_end_pairs(all_tokens, is_first_token) + words = sp.decode(first_tokens).split() + assert len(index_pairs) == len(words), ( + len(index_pairs), + len(words), + all_tokens, + ) + start = convert_timestamp( + frames=[i[0] for i in index_pairs], + subsampling_factor=subsampling_factor, + frame_shift_ms=frame_shift_ms, + ) + end = convert_timestamp( + # The duration in frames is (end_frame_index - start_frame_index + 1) + frames=[i[1] + 1 for i in index_pairs], + subsampling_factor=subsampling_factor, + frame_shift_ms=frame_shift_ms, + ) + utt_time_pairs.append(list(zip(start, end))) + utt_words.append(words) + + return utt_time_pairs, utt_words + + +def remove_duplicates_and_blank(hyp: List[int]) -> Tuple[List[int], List[int]]: + # modified from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/common.py + new_hyp: List[int] = [] + time: List[Tuple[int, int]] = [] + cur = 0 + start, end = -1, -1 + while cur < len(hyp): + if hyp[cur] != 0: + new_hyp.append(hyp[cur]) + start = cur + prev = cur + while cur < len(hyp) and hyp[cur] == hyp[prev]: + if start != -1: + end = cur + cur += 1 + if start != -1 and end != -1: + time.append((start, end)) + start, end = -1, -1 + return new_hyp, time + + def decode_one_batch( params: AttributeDict, model: nn.Module, @@ -360,6 +461,17 @@ def decode_one_batch( nnet_output = model.get_ctc_output(encoder_out) # nnet_output is (N, T, C) + if params.decoding_method == "ctc-greedy-search": + timestamps, hyps = ctc_greedy_search( + ctc_probs=nnet_output, + nnet_output_lens=encoder_out_lens, + sp=bpe_model, + subsampling_factor=params.subsampling_factor, + frame_shift_ms=params.frame_shift_ms, + ) + key = "ctc-greedy-search" + return {key: (hyps, timestamps)} + supervision_segments = torch.stack( ( supervisions["sequence_idx"], @@ -696,6 +808,7 @@ def main(): params.update(vars(args)) assert params.decoding_method in ( + "ctc-greedy-search", "ctc-decoding", "1best", "nbest", @@ -749,7 +862,7 @@ def main(): params.sos_id = sos_id params.eos_id = eos_id - if params.decoding_method == "ctc-decoding": + if params.decoding_method in ["ctc-decoding", "ctc-greedy-search"]: HLG = None H = k2.ctc_topo( max_token=max_token_id, From e63a8c27f811bbee321429f8253ff8d1260aa929 Mon Sep 17 00:00:00 2001 From: Teo Wen Shen <36886809+teowenshen@users.noreply.github.com> Date: Mon, 13 Feb 2023 23:19:50 +0900 Subject: [PATCH 116/174] CSJ pruned_transducer_stateless7_streaming (#892) * update manifest stats * update transcript configs * lang_char and compute_fbanks * save cuts in fbank_dir * add core codes * update decode.py * Create local/utils * tidy up * parse raw in prepare_lang_char.py * update manifest stats * update transcript configs * lang_char and compute_fbanks * save cuts in fbank_dir * add core codes * update decode.py * Create local/utils * tidy up * parse raw in prepare_lang_char.py * working train * Add compare_cer_transcript.py * fix tokenizer decode, allow d2f only * comment cleanup * add export files and READMEs * reword average column * fix comments * Update new results --- egs/csj/ASR/README.md | 11 + egs/csj/ASR/RESULTS.md | 200 +++ egs/csj/ASR/local/add_transcript_mode.py | 94 ++ egs/csj/ASR/local/compute_fbank_csj.py | 109 +- egs/csj/ASR/local/compute_fbank_musan.py | 12 +- egs/csj/ASR/local/conf/disfluent.ini | 243 +-- egs/csj/ASR/local/conf/fluent.ini | 243 +-- egs/csj/ASR/local/conf/number.ini | 241 --- egs/csj/ASR/local/conf/symbol.ini | 251 +--- .../ASR/local/disfluent_recogs_to_fluent.py | 202 +++ .../ASR/local/display_manifest_statistics.py | 376 +++-- egs/csj/ASR/local/prepare_lang_char.py | 102 +- egs/csj/ASR/local/utils/asr_datamodule.py | 462 ++++++ egs/csj/ASR/local/utils/tokenizer.py | 253 ++++ egs/csj/ASR/prepare.sh | 42 +- .../TelegramStreamIO.py | 76 + .../asr_datamodule.py | 1 + .../beam_search.py | 1 + .../decode.py | 852 +++++++++++ .../decode_stream.py | 1 + .../decoder.py | 1 + .../encoder_interface.py | 1 + .../export.py | 313 ++++ .../jit_trace_export.py | 308 ++++ .../jit_trace_pretrained.py | 286 ++++ .../joiner.py | 1 + .../model.py | 1 + .../optim.py | 1 + .../pretrained.py | 347 +++++ .../scaling.py | 1 + .../scaling_converter.py | 1 + .../streaming_beam_search.py | 1 + .../streaming_decode.py | 597 ++++++++ .../test_model.py | 150 ++ .../tokenizer.py | 1 + .../train.py | 1304 +++++++++++++++++ .../zipformer.py | 1 + 37 files changed, 5847 insertions(+), 1240 deletions(-) create mode 100644 egs/csj/ASR/README.md create mode 100644 egs/csj/ASR/RESULTS.md create mode 100644 egs/csj/ASR/local/add_transcript_mode.py create mode 100644 egs/csj/ASR/local/disfluent_recogs_to_fluent.py create mode 100644 egs/csj/ASR/local/utils/asr_datamodule.py create mode 100644 egs/csj/ASR/local/utils/tokenizer.py create mode 100644 egs/csj/ASR/pruned_transducer_stateless7_streaming/TelegramStreamIO.py create mode 120000 egs/csj/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py create mode 120000 egs/csj/ASR/pruned_transducer_stateless7_streaming/beam_search.py create mode 100755 egs/csj/ASR/pruned_transducer_stateless7_streaming/decode.py create mode 120000 egs/csj/ASR/pruned_transducer_stateless7_streaming/decode_stream.py create mode 120000 egs/csj/ASR/pruned_transducer_stateless7_streaming/decoder.py create mode 120000 egs/csj/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py create mode 100644 egs/csj/ASR/pruned_transducer_stateless7_streaming/export.py create mode 100644 egs/csj/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py create mode 100644 egs/csj/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py create mode 120000 egs/csj/ASR/pruned_transducer_stateless7_streaming/joiner.py create mode 120000 egs/csj/ASR/pruned_transducer_stateless7_streaming/model.py create mode 120000 egs/csj/ASR/pruned_transducer_stateless7_streaming/optim.py create mode 100644 egs/csj/ASR/pruned_transducer_stateless7_streaming/pretrained.py create mode 120000 egs/csj/ASR/pruned_transducer_stateless7_streaming/scaling.py create mode 120000 egs/csj/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py create mode 120000 egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py create mode 100755 egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py create mode 100755 egs/csj/ASR/pruned_transducer_stateless7_streaming/test_model.py create mode 120000 egs/csj/ASR/pruned_transducer_stateless7_streaming/tokenizer.py create mode 100755 egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py create mode 120000 egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer.py diff --git a/egs/csj/ASR/README.md b/egs/csj/ASR/README.md new file mode 100644 index 000000000..95c2ec6ac --- /dev/null +++ b/egs/csj/ASR/README.md @@ -0,0 +1,11 @@ +# Introduction + +[./RESULTS.md](./RESULTS.md) contains the latest results. + +# Transducers + +These are the types of architectures currently available. + +| | Encoder | Decoder | Comment | +|---------------------------------------|---------------------|--------------------|---------------------------------------------------| +| `pruned_transducer_stateless7_streaming` | Streaming Zipformer | Embedding + Conv1d | Adapted from librispeech pruned_transducer_stateless7_streaming | diff --git a/egs/csj/ASR/RESULTS.md b/egs/csj/ASR/RESULTS.md new file mode 100644 index 000000000..56fdb899f --- /dev/null +++ b/egs/csj/ASR/RESULTS.md @@ -0,0 +1,200 @@ +# Results + +## Streaming Zipformer-Transducer (Pruned Stateless Transducer + Streaming Zipformer) + +### [pruned_transducer_stateless7_streaming](./pruned_transducer_stateless7_streaming) + +See for more details. + +You can find a pretrained model, training logs, decoding logs, and decoding results at: + + +Number of model parameters: 75688409, i.e. 75.7M. + +#### training on disfluent transcript + +The CERs are: + +| decoding method | chunk size | eval1 | eval2 | eval3 | excluded | valid | average | decoding mode | +| --------------- | ---------- | ----- | ----- | ----- | -------- | ----- | ------- | ------------- | +| fast beam search | 320ms | 5.39 | 4.08 | 4.16 | 5.4 | 5.02 | --epoch 30 --avg 17 | simulated streaming | +| fast beam search | 320ms | 5.34 | 4.1 | 4.26 | 5.61 | 4.91 | --epoch 30 --avg 17 | chunk-wise | +| greedy search | 320ms | 5.43 | 4.14 | 4.31 | 5.48 | 4.88 | --epoch 30 --avg 17 | simulated streaming | +| greedy search | 320ms | 5.44 | 4.14 | 4.39 | 5.7 | 4.98 | --epoch 30 --avg 17 | chunk-wise | +| modified beam search | 320ms | 5.2 | 3.95 | 4.09 | 5.12 | 4.75 | --epoch 30 --avg 17 | simulated streaming | +| modified beam search | 320ms | 5.18 | 4.07 | 4.12 | 5.36 | 4.77 | --epoch 30 --avg 17 | chunk-wise | +| fast beam search | 640ms | 5.01 | 3.78 | 3.96 | 4.85 | 4.6 | --epoch 30 --avg 17 | simulated streaming | +| fast beam search | 640ms | 4.97 | 3.88 | 3.96 | 4.91 | 4.61 | --epoch 30 --avg 17 | chunk-wise | +| greedy search | 640ms | 5.02 | 3.84 | 4.14 | 5.02 | 4.59 | --epoch 30 --avg 17 | simulated streaming | +| greedy search | 640ms | 5.32 | 4.22 | 4.33 | 5.39 | 4.99 | --epoch 30 --avg 17 | chunk-wise | +| modified beam search | 640ms | 4.78 | 3.66 | 3.85 | 4.72 | 4.42 | --epoch 30 --avg 17 | simulated streaming | +| modified beam search | 640ms | 5.77 | 4.72 | 4.73 | 5.85 | 5.36 | --epoch 30 --avg 17 | chunk-wise | + +Note: `simulated streaming` indicates feeding full utterance during decoding using `decode.py`, +while `chunk-size` indicates feeding certain number of frames at each time using `streaming_decode.py`. + +The training command was: +```bash +./pruned_transducer_stateless7_streaming/train.py \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --world-size 8 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp_disfluent_2_pad30 \ + --max-duration 375 \ + --transcript-mode disfluent \ + --lang data/lang_char \ + --manifest-dir /mnt/host/corpus/csj/fbank \ + --pad-feature 30 \ + --musan-dir /mnt/host/corpus/musan/musan/fbank +``` + +The simulated streaming decoding command was: +```bash +for chunk in 64 32; do + for m in greedy_search fast_beam_search modified_beam_search; do + python pruned_transducer_stateless7_streaming/decode.py \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --exp-dir pruned_transducer_stateless7_streaming/exp_disfluent_2_pad30 \ + --epoch 30 \ + --avg 17 \ + --max-duration 350 \ + --decoding-method $m \ + --manifest-dir /mnt/host/corpus/csj/fbank \ + --lang data/lang_char \ + --transcript-mode disfluent \ + --res-dir pruned_transducer_stateless7_streaming/exp_disfluent_2_pad30/github/sim_"$chunk"_"$m" \ + --decode-chunk-len $chunk \ + --pad-feature 30 \ + --gpu 0 + done +done +``` + +The streaming chunk-wise decoding command was: +```bash +for chunk in 64 32; do + for m in greedy_search fast_beam_search modified_beam_search; do + python pruned_transducer_stateless7_streaming/streaming_decode.py \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --exp-dir pruned_transducer_stateless7_streaming/exp_disfluent_2_pad30 \ + --epoch 30 \ + --avg 17 \ + --max-duration 350 \ + --decoding-method $m \ + --manifest-dir /mnt/host/corpus/csj/fbank \ + --lang data/lang_char \ + --transcript-mode disfluent \ + --res-dir pruned_transducer_stateless7_streaming/exp_disfluent_2_pad30/github/stream_"$chunk"_"$m" \ + --decode-chunk-len $chunk \ + --gpu 2 \ + --num-decode-streams 40 + done +done +``` + +#### training on fluent transcript + +The CERs are: + +| decoding method | chunk size | eval1 | eval2 | eval3 | excluded | valid | average | decoding mode | +| --------------- | ---------- | ----- | ----- | ----- | -------- | ----- | ------- | ------------- | +| fast beam search | 320ms | 4.19 | 3.63 | 3.77 | 4.43 | 4.09 | --epoch 30 --avg 12 | simulated streaming | +| fast beam search | 320ms | 4.06 | 3.55 | 3.66 | 4.70 | 4.04 | --epoch 30 --avg 12 | chunk-wise | +| greedy search | 320ms | 4.22 | 3.62 | 3.82 | 4.45 | 3.98 | --epoch 30 --avg 12 | simulated streaming | +| greedy search | 320ms | 4.13 | 3.61 | 3.85 | 4.67 | 4.05 | --epoch 30 --avg 12 | chunk-wise | +| modified beam search | 320ms | 4.02 | 3.43 | 3.62 | 4.43 | 3.81 | --epoch 30 --avg 12 | simulated streaming | +| modified beam search | 320ms | 3.97 | 3.43 | 3.59 | 4.99 | 3.88 | --epoch 30 --avg 12 | chunk-wise | +| fast beam search | 640ms | 3.80 | 3.31 | 3.55 | 4.16 | 3.90 | --epoch 30 --avg 12 | simulated streaming | +| fast beam search | 640ms | 3.81 | 3.34 | 3.46 | 4.58 | 3.85 | --epoch 30 --avg 12 | chunk-wise | +| greedy search | 640ms | 3.92 | 3.38 | 3.65 | 4.31 | 3.88 | --epoch 30 --avg 12 | simulated streaming | +| greedy search | 640ms | 3.98 | 3.38 | 3.64 | 4.54 | 4.01 | --epoch 30 --avg 12 | chunk-wise | +| modified beam search | 640ms | 3.72 | 3.26 | 3.39 | 4.10 | 3.65 | --epoch 30 --avg 12 | simulated streaming | +| modified beam search | 640ms | 3.78 | 3.32 | 3.45 | 4.81 | 3.81 | --epoch 30 --avg 12 | chunk-wise | + +Note: `simulated streaming` indicates feeding full utterance during decoding using `decode.py`, +while `chunk-size` indicates feeding certain number of frames at each time using `streaming_decode.py`. + +The training command was: +```bash +./pruned_transducer_stateless7_streaming/train.py \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --world-size 8 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp_fluent_2_pad30 \ + --max-duration 375 \ + --transcript-mode fluent \ + --lang data/lang_char \ + --manifest-dir /mnt/host/corpus/csj/fbank \ + --pad-feature 30 \ + --musan-dir /mnt/host/corpus/musan/musan/fbank +``` + +The simulated streaming decoding command was: +```bash +for chunk in 64 32; do + for m in greedy_search fast_beam_search modified_beam_search; do + python pruned_transducer_stateless7_streaming/decode.py \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --exp-dir pruned_transducer_stateless7_streaming/exp_fluent_2_pad30 \ + --epoch 30 \ + --avg 12 \ + --max-duration 350 \ + --decoding-method $m \ + --manifest-dir /mnt/host/corpus/csj/fbank \ + --lang data/lang_char \ + --transcript-mode fluent \ + --res-dir pruned_transducer_stateless7_streaming/exp_fluent_2_pad30/github/sim_"$chunk"_"$m" \ + --decode-chunk-len $chunk \ + --pad-feature 30 \ + --gpu 1 + done +done +``` + +The streaming chunk-wise decoding command was: +```bash +for chunk in 64 32; do + for m in greedy_search fast_beam_search modified_beam_search; do + python pruned_transducer_stateless7_streaming/streaming_decode.py \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --exp-dir pruned_transducer_stateless7_streaming/exp_fluent_2_pad30 \ + --epoch 30 \ + --avg 12 \ + --max-duration 350 \ + --decoding-method $m \ + --manifest-dir /mnt/host/corpus/csj/fbank \ + --lang data/lang_char \ + --transcript-mode fluent \ + --res-dir pruned_transducer_stateless7_streaming/exp_fluent_2_pad30/github/stream_"$chunk"_"$m" \ + --decode-chunk-len $chunk \ + --gpu 3 \ + --num-decode-streams 40 + done +done +``` + +#### Comparing disfluent to fluent + +$$ \texttt{CER}^{f}_d = \frac{\texttt{sub}_f + \texttt{ins} + \texttt{del}_f}{N_f} $$ + +This comparison evaluates the disfluent model on the fluent transcript (calculated by `disfluent_recogs_to_fluent.py`), forgiving the disfluent model's mistakes on fillers and partial words. It is meant as an illustrative metric only, so that the disfluent and fluent models can be compared. + +| decoding method | chunk size | eval1 (d vs f) | eval2 (d vs f) | eval3 (d vs f) | excluded (d vs f) | valid (d vs f) | decoding mode | +| --------------- | ---------- | -------------- | --------------- | -------------- | -------------------- | --------------- | ----------- | +| fast beam search | 320ms | 4.54 vs 4.19 | 3.44 vs 3.63 | 3.56 vs 3.77 | 4.22 vs 4.43 | 4.22 vs 4.09 | simulated streaming | +| fast beam search | 320ms | 4.48 vs 4.06 | 3.41 vs 3.55 | 3.65 vs 3.66 | 4.26 vs 4.7 | 4.08 vs 4.04 | chunk-wise | +| greedy search | 320ms | 4.53 vs 4.22 | 3.48 vs 3.62 | 3.69 vs 3.82 | 4.38 vs 4.45 | 4.05 vs 3.98 | simulated streaming | +| greedy search | 320ms | 4.53 vs 4.13 | 3.46 vs 3.61 | 3.71 vs 3.85 | 4.48 vs 4.67 | 4.12 vs 4.05 | chunk-wise | +| modified beam search | 320ms | 4.45 vs 4.02 | 3.38 vs 3.43 | 3.57 vs 3.62 | 4.19 vs 4.43 | 4.04 vs 3.81 | simulated streaming | +| modified beam search | 320ms | 4.44 vs 3.97 | 3.47 vs 3.43 | 3.56 vs 3.59 | 4.28 vs 4.99 | 4.04 vs 3.88 | chunk-wise | +| fast beam search | 640ms | 4.14 vs 3.8 | 3.12 vs 3.31 | 3.38 vs 3.55 | 3.72 vs 4.16 | 3.81 vs 3.9 | simulated streaming | +| fast beam search | 640ms | 4.05 vs 3.81 | 3.23 vs 3.34 | 3.36 vs 3.46 | 3.65 vs 4.58 | 3.78 vs 3.85 | chunk-wise | +| greedy search | 640ms | 4.1 vs 3.92 | 3.17 vs 3.38 | 3.5 vs 3.65 | 3.87 vs 4.31 | 3.77 vs 3.88 | simulated streaming | +| greedy search | 640ms | 4.41 vs 3.98 | 3.56 vs 3.38 | 3.69 vs 3.64 | 4.26 vs 4.54 | 4.16 vs 4.01 | chunk-wise | +| modified beam search | 640ms | 4 vs 3.72 | 3.08 vs 3.26 | 3.33 vs 3.39 | 3.75 vs 4.1 | 3.71 vs 3.65 | simulated streaming | +| modified beam search | 640ms | 5.05 vs 3.78 | 4.22 vs 3.32 | 4.26 vs 3.45 | 5.02 vs 4.81 | 4.73 vs 3.81 | chunk-wise | +| average (d - f) | | 0.43 | -0.02 | -0.02 | -0.34 | 0.13 | | diff --git a/egs/csj/ASR/local/add_transcript_mode.py b/egs/csj/ASR/local/add_transcript_mode.py new file mode 100644 index 000000000..f6b4b2caf --- /dev/null +++ b/egs/csj/ASR/local/add_transcript_mode.py @@ -0,0 +1,94 @@ +import argparse +import logging +from configparser import ConfigParser +from pathlib import Path +from typing import List + +from lhotse import CutSet, SupervisionSet +from lhotse.recipes.csj import CSJSDBParser + +ARGPARSE_DESCRIPTION = """ +This script adds transcript modes to an existing CutSet or SupervisionSet. +""" + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + description=ARGPARSE_DESCRIPTION, + ) + parser.add_argument( + "-f", + "--fbank-dir", + type=Path, + help="Path to directory where manifests are stored.", + ) + parser.add_argument( + "-c", + "--config", + type=Path, + nargs="+", + help="Path to config file for transcript parsing.", + ) + return parser.parse_args() + + +def get_CSJParsers(config_files: List[Path]) -> List[CSJSDBParser]: + parsers = [] + for config_file in config_files: + config = ConfigParser() + config.optionxform = str + assert config.read(config_file), f"{config_file} could not be found." + decisions = {} + for k, v in config["DECISIONS"].items(): + try: + decisions[k] = int(v) + except ValueError: + decisions[k] = v + parsers.append( + (config["CONSTANTS"].get("MODE"), CSJSDBParser(decisions=decisions)) + ) + return parsers + + +def main(): + args = get_args() + logging.basicConfig( + format=("%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"), + level=logging.INFO, + ) + parsers = get_CSJParsers(args.config) + config = ConfigParser() + config.optionxform = str + assert config.read(args.config), args.config + decisions = {} + for k, v in config["DECISIONS"].items(): + try: + decisions[k] = int(v) + except ValueError: + decisions[k] = v + + logging.info(f"Adding {', '.join(x[0] for x in parsers)} transcript mode.") + + manifests = args.fbank_dir.glob("csj_cuts_*.jsonl.gz") + assert manifests, f"No cuts to be found in {args.fbank_dir}" + + for manifest in manifests: + results = [] + logging.info(f"Adding transcript modes to {manifest.name} now.") + cutset = CutSet.from_file(manifest) + for cut in cutset: + for name, parser in parsers: + cut.supervisions[0].custom[name] = parser.parse( + cut.supervisions[0].custom["raw"] + ) + cut.supervisions[0].text = "" + results.append(cut) + results = CutSet.from_items(results) + res_file = manifest.as_posix() + manifest.replace(manifest.parent / ("bak." + manifest.name)) + results.to_file(res_file) + + +if __name__ == "__main__": + main() diff --git a/egs/csj/ASR/local/compute_fbank_csj.py b/egs/csj/ASR/local/compute_fbank_csj.py index 667ad427e..ce560025d 100644 --- a/egs/csj/ASR/local/compute_fbank_csj.py +++ b/egs/csj/ASR/local/compute_fbank_csj.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -# Copyright 2022 The University of Electro-Communications (Author: Teo Wen Shen) # noqa +# Copyright 2023 The University of Electro-Communications (Author: Teo Wen Shen) # noqa # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -19,9 +19,7 @@ import argparse import logging import os -from itertools import islice from pathlib import Path -from random import Random from typing import List, Tuple import torch @@ -35,20 +33,10 @@ from lhotse import ( # See the following for why LilcomChunkyWriter is preferre RecordingSet, SupervisionSet, ) +from lhotse.recipes.csj import concat_csj_supervisions # fmt: on -ARGPARSE_DESCRIPTION = """ -This script follows the espnet method of splitting the remaining core+noncore -utterances into valid and train cutsets at an index which is by default 4000. - -In other words, the core+noncore utterances are shuffled, where 4000 utterances -of the shuffled set go to the `valid` cutset and are not subject to speed -perturbation. The remaining utterances become the `train` cutset and are speed- -perturbed (0.9x, 1.0x, 1.1x). - -""" - # Torch's multithreaded behavior needs to be disabled or # it wastes a lot of CPU and slow things down. # Do this outside of main() in case it needs to take effect @@ -57,66 +45,101 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) RNG_SEED = 42 +# concat_params_train = [ +# {"gap": 1.0, "maxlen": 10.0}, +# {"gap": 1.5, "maxlen": 8.0}, +# {"gap": 1.0, "maxlen": 18.0}, +# ] + +concat_params = {"gap": 1.0, "maxlen": 10.0} def make_cutset_blueprints( manifest_dir: Path, - split: int, ) -> List[Tuple[str, CutSet]]: cut_sets = [] + logging.info("Creating non-train cuts.") + # Create eval datasets - logging.info("Creating eval cuts.") for i in range(1, 4): + sps = sorted( + SupervisionSet.from_file( + manifest_dir / f"csj_supervisions_eval{i}.jsonl.gz" + ), + key=lambda x: x.id, + ) + cut_set = CutSet.from_manifests( recordings=RecordingSet.from_file( manifest_dir / f"csj_recordings_eval{i}.jsonl.gz" ), - supervisions=SupervisionSet.from_file( - manifest_dir / f"csj_supervisions_eval{i}.jsonl.gz" - ), + supervisions=concat_csj_supervisions(sps, **concat_params), ) cut_set = cut_set.trim_to_supervisions(keep_overlapping=False) cut_sets.append((f"eval{i}", cut_set)) - # Create train and valid cuts - logging.info("Loading, trimming, and shuffling the remaining core+noncore cuts.") - recording_set = RecordingSet.from_file( - manifest_dir / "csj_recordings_core.jsonl.gz" - ) + RecordingSet.from_file(manifest_dir / "csj_recordings_noncore.jsonl.gz") - supervision_set = SupervisionSet.from_file( - manifest_dir / "csj_supervisions_core.jsonl.gz" - ) + SupervisionSet.from_file(manifest_dir / "csj_supervisions_noncore.jsonl.gz") - + # Create excluded dataset + sps = sorted( + SupervisionSet.from_file(manifest_dir / "csj_supervisions_excluded.jsonl.gz"), + key=lambda x: x.id, + ) cut_set = CutSet.from_manifests( - recordings=recording_set, - supervisions=supervision_set, + recordings=RecordingSet.from_file( + manifest_dir / "csj_recordings_excluded.jsonl.gz" + ), + supervisions=concat_csj_supervisions(sps, **concat_params), ) cut_set = cut_set.trim_to_supervisions(keep_overlapping=False) - cut_set = cut_set.shuffle(Random(RNG_SEED)) + cut_sets.append(("excluded", cut_set)) - logging.info( - "Creating valid and train cuts from core and noncore, split at {split}." + # Create valid dataset + sps = sorted( + SupervisionSet.from_file(manifest_dir / "csj_supervisions_valid.jsonl.gz"), + key=lambda x: x.id, ) - valid_set = CutSet.from_cuts(islice(cut_set, 0, split)) + cut_set = CutSet.from_manifests( + recordings=RecordingSet.from_file( + manifest_dir / "csj_recordings_valid.jsonl.gz" + ), + supervisions=concat_csj_supervisions(sps, **concat_params), + ) + cut_set = cut_set.trim_to_supervisions(keep_overlapping=False) + cut_sets.append(("valid", cut_set)) - train_set = CutSet.from_cuts(islice(cut_set, split, None)) + logging.info("Creating train cuts.") + + # Create train dataset + sps = sorted( + SupervisionSet.from_file(manifest_dir / "csj_supervisions_core.jsonl.gz") + + SupervisionSet.from_file(manifest_dir / "csj_supervisions_noncore.jsonl.gz"), + key=lambda x: x.id, + ) + + recording = RecordingSet.from_file( + manifest_dir / "csj_recordings_core.jsonl.gz" + ) + RecordingSet.from_file(manifest_dir / "csj_recordings_noncore.jsonl.gz") + + train_set = CutSet.from_manifests( + recordings=recording, supervisions=concat_csj_supervisions(sps, **concat_params) + ).trim_to_supervisions(keep_overlapping=False) train_set = train_set + train_set.perturb_speed(0.9) + train_set.perturb_speed(1.1) - cut_sets.extend([("valid", valid_set), ("train", train_set)]) + cut_sets.append(("train", train_set)) return cut_sets def get_args(): parser = argparse.ArgumentParser( - description=ARGPARSE_DESCRIPTION, formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) - - parser.add_argument("--manifest-dir", type=Path, help="Path to save manifests") - parser.add_argument("--fbank-dir", type=Path, help="Path to save fbank features") - parser.add_argument("--split", type=int, default=4000, help="Split at this index") + parser.add_argument( + "-m", "--manifest-dir", type=Path, help="Path to save manifests" + ) + parser.add_argument( + "-f", "--fbank-dir", type=Path, help="Path to save fbank features" + ) return parser.parse_args() @@ -138,7 +161,7 @@ def main(): ) return else: - cut_sets = make_cutset_blueprints(args.manifest_dir, args.split) + cut_sets = make_cutset_blueprints(args.manifest_dir) for part, cut_set in cut_sets: logging.info(f"Processing {part}") cut_set = cut_set.compute_and_store_features( @@ -147,7 +170,7 @@ def main(): storage_path=(args.fbank_dir / f"feats_{part}").as_posix(), storage_type=LilcomChunkyWriter, ) - cut_set.to_file(args.manifest_dir / f"csj_cuts_{part}.jsonl.gz") + cut_set.to_file(args.fbank_dir / f"csj_cuts_{part}.jsonl.gz") logging.info("All fbank computed for CSJ.") (args.fbank_dir / ".done").touch() diff --git a/egs/csj/ASR/local/compute_fbank_musan.py b/egs/csj/ASR/local/compute_fbank_musan.py index f60e62c85..c942df98e 100644 --- a/egs/csj/ASR/local/compute_fbank_musan.py +++ b/egs/csj/ASR/local/compute_fbank_musan.py @@ -28,9 +28,7 @@ from icefall.utils import get_executor ARGPARSE_DESCRIPTION = """ This file computes fbank features of the musan dataset. -It looks for manifests in the directory data/manifests. -The generated fbank features are saved in data/fbank. """ # Torch's multithreaded behavior needs to be disabled or @@ -42,8 +40,6 @@ torch.set_num_interop_threads(1) def compute_fbank_musan(manifest_dir: Path, fbank_dir: Path): - # src_dir = Path("data/manifests") - # output_dir = Path("data/fbank") num_jobs = min(15, os.cpu_count()) num_mel_bins = 80 @@ -104,8 +100,12 @@ def get_args(): formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) - parser.add_argument("--manifest-dir", type=Path, help="Path to save manifests") - parser.add_argument("--fbank-dir", type=Path, help="Path to save fbank features") + parser.add_argument( + "-m", "--manifest-dir", type=Path, help="Path to save manifests" + ) + parser.add_argument( + "-f", "--fbank-dir", type=Path, help="Path to save fbank features" + ) return parser.parse_args() diff --git a/egs/csj/ASR/local/conf/disfluent.ini b/egs/csj/ASR/local/conf/disfluent.ini index c987e72c5..4f0a9ec0e 100644 --- a/egs/csj/ASR/local/conf/disfluent.ini +++ b/egs/csj/ASR/local/conf/disfluent.ini @@ -1,320 +1,79 @@ -; # This section is ignored if this file is not supplied as the first config file to -; # lhotse prepare csj -[SEGMENTS] -; # Allowed period of nonverbal noise. If exceeded, a new segment is created. -gap = 0.5 -; # Maximum length of segment (s). -maxlen = 10 -; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. -minlen = 0.02 -; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. -; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. -; # If you intend to use a multicharacter string for gap_sym, remember to register the -; # multicharacter string as part of userdef-string in prepare_lang_char.py. -gap_sym = - [CONSTANTS] ; # Name of this mode MODE = disfluent -; # Suffixes to use after the word surface (no longer used) -MORPH = pos1 cForm cType2 pos2 -; # Used to differentiate between A tag and A_num tag -JPN_NUM = ゼロ 0 零 一 二 三 四 五 六 七 八 九 十 百 千 . -; # Dummy character to delineate multiline words -PLUS = + [DECISIONS] -; # TAG+'^'とは、タグが一つの転記単位に独立していない場合 -; # The PLUS (fullwidth) sign '+' marks line boundaries for multiline entries - ; # フィラー、感情表出系感動詞 ; # 0 to remain, 1 to delete ; # Example: '(F ぎょっ)' F = 0 -; # Example: '(L (F ン))', '比べ(F えー)る' -F^ = 0 ; # 言い直し、いいよどみなどによる語断片 ; # 0 to remain, 1 to delete ; # Example: '(D だ)(D だいが) 大学の学部の会議' D = 0 -; # Example: '(L (D ドゥ)+(D ヒ))' -D^ = 0 ; # 助詞、助動詞、接辞の言い直し ; # 0 to remain, 1 to delete ; # Example: '西洋 (D2 的)(F えー)(D ふ) 風というか' D2 = 0 -; # Example: '(X (D2 ノ))' -D2^ = 0 ; # 聞き取りや語彙の判断に自信がない場合 ; # 0 to remain, 1 to delete ; # Example: (? 字数) の ; # If no option: empty string is returned regardless of output ; # Example: '(?) で' ? = 0 -; # Example: '(D (? すー))+そう+です+よ+ね' -?^ = 0 ; # タグ?で、値は複数の候補が想定される場合 ; # 0 for main guess with matching morph info, 1 for second guess ; # Example: '(? 次数, 実数)', '(? これ,ここで)+(? 説明+し+た+方+が+いい+か+な)' ?, = 0 -; # Example: '(W (? テユクー);(? ケッキョク,テユウコトデ))', '(W マシ;(? マシ+タ,マス))' -?,^ = 0 ; # 音や言葉に関するメタ的な引用 ; # 0 to remain, 1 to delete ; # Example: '助詞の (M は) は (M は) と書くが発音は (M わ)' M = 0 -; # Example: '(L (M ヒ)+(M ヒ))', '(L (M (? ヒ+ヒ)))' -M^ = 0 ; # 外国語や古語、方言など ; # 0 to remain, 1 to delete ; # Example: '(O ザッツファイン)' O = 0 -; # Example: '(笑 (O エクスキューズ+ミー))', '(笑 メダッ+テ+(O ナンボ))' -O^ = 0 ; # 講演者の名前、差別語、誹謗中傷など ; # 0 to remain, 1 to delete ; # Example: '国語研の (R ××) です' R = 0 -R^ = 0 ; # 非朗読対象発話(朗読における言い間違い等) ; # 0 to remain, 1 to delete ; # Example: '(X 実際は) 実際には' X = 0 -; # Example: '(L (X (D2 ニ)))' -X^ = 0 ; # アルファベットや算用数字、記号の表記 ; # 0 to use Japanese form, 1 to use alphabet form ; # Example: '(A シーディーアール;CD-R)' A = 1 -; # Example: 'スモール(A エヌ;N)', 'ラージ(A キュー;Q)', '(A ティーエフ;TF)+(A アイディーエフ;IDF)' (Strung together by pron: '(W (? ティーワイド);ティーエフ+アイディーエフ)') -A^ = 1 ; # タグAで、単語は算用数字の場合 ; # 0 to use Japanese form, 1 to use Arabic numerals ; # Example: (A 二千;2000) -A_num = eval:self.notag -A_num^ = eval:self.notag +A_num = 0 ; # 何らかの原因で漢字表記できなくなった場合 ; # 0 to use broken form, 1 to use orthodox form ; # Example: '(K たち (F えー) ばな;橘)' K = 1 -; # Example: '合(K か(?)く;格)', '宮(K ま(?)え;前)' -K^ = 1 ; # 転訛、発音の怠けなど、一時的な発音エラー ; # 0 to use wrong form, 1 to use orthodox form ; # Example: '(W ギーツ;ギジュツ)' W = 1 -; # Example: '(F (W エド;エト))', 'イベント(W リレーティッド;リレーテッド)' -W^ = 1 ; # 語の読みに関する知識レベルのいい間違い ; # 0 to use wrong form, 1 to use orthodox form ; # Example: '(B シブタイ;ジュータイ)' B = 0 -; # Example: 'データー(B カズ;スー)' -B^ = 0 ; # 笑いながら発話 ; # 0 to remain, 1 to delete ; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)' 笑 = 0 -; # Example: 'コク(笑 サイ+(D オン))', -笑^ = 0 ; # 泣きながら発話 ; # 0 to remain, 1 to delete ; # Example: '(泣 ドンナニ)' 泣 = 0 -泣^ = 0 ; # 咳をしながら発話 ; # 0 to remain, 1 to delete ; # Example: 'シャ(咳 リン) ノ' 咳 = 0 -; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)' -咳^ = 0 ; # ささやき声や独り言などの小さな声 ; # 0 to remain, 1 to delete ; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' L = 0 -; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト' -L^ = 0 - -[REPLACEMENTS] -; # ボーカルフライなどで母音が同定できない場合 - = -; # 「うん/うーん/ふーん」の音の特定が困難な場合 - = -; # 非語彙的な母音の引き延ばし - = -; # 非語彙的な子音の引き延ばし - = -; # 言語音と独立に講演者の笑いが生じている場合 -<笑> = -; # 言語音と独立に講演者の咳が生じている場合 -<咳> = -; # 言語音と独立に講演者の息が生じている場合 -<息> = -; # 講演者の泣き声 -<泣> = -; # 聴衆(司会者なども含む)の発話 -<フロア発話> = -; # 聴衆の笑い -<フロア笑> = -; # 聴衆の拍手 -<拍手> = -; # 講演者が発表中に用いたデモンストレーションの音声 -<デモ> = -; # 学会講演に発表時間を知らせるためにならすベルの音 -<ベル> = -; # 転記単位全体が再度読み直された場合 -<朗読間違い> = -; # 上記以外の音で特に目立った音 -<雑音> = -; # 0.2秒以上のポーズ -

= -; # Redacted information, for R -; # It is \x00D7 multiplication sign, not your normal 'x' -× = × - -[FIELDS] -; # Time information for segment -time = 3 -; # Word surface -surface = 5 -; # Word surface root form without CSJ tags -notag = 9 -; # Part Of Speech -pos1 = 11 -; # Conjugated Form -cForm = 12 -; # Conjugation Type -cType1 = 13 -; # Subcategory of POS -pos2 = 14 -; # Euphonic Change / Subcategory of Conjugation Type -cType2 = 15 -; # Other information -other = 16 -; # Pronunciation for lexicon -pron = 10 -; # Speaker ID -spk_id = 2 - -[KATAKANA2ROMAJI] -ア = 'a -イ = 'i -ウ = 'u -エ = 'e -オ = 'o -カ = ka -キ = ki -ク = ku -ケ = ke -コ = ko -ガ = ga -ギ = gi -グ = gu -ゲ = ge -ゴ = go -サ = sa -シ = si -ス = su -セ = se -ソ = so -ザ = za -ジ = zi -ズ = zu -ゼ = ze -ゾ = zo -タ = ta -チ = ti -ツ = tu -テ = te -ト = to -ダ = da -ヂ = di -ヅ = du -デ = de -ド = do -ナ = na -ニ = ni -ヌ = nu -ネ = ne -ノ = no -ハ = ha -ヒ = hi -フ = hu -ヘ = he -ホ = ho -バ = ba -ビ = bi -ブ = bu -ベ = be -ボ = bo -パ = pa -ピ = pi -プ = pu -ペ = pe -ポ = po -マ = ma -ミ = mi -ム = mu -メ = me -モ = mo -ヤ = ya -ユ = yu -ヨ = yo -ラ = ra -リ = ri -ル = ru -レ = re -ロ = ro -ワ = wa -ヰ = we -ヱ = wi -ヲ = wo -ン = ŋ -ッ = q -ー = - -キャ = kǐa -キュ = kǐu -キョ = kǐo -ギャ = gǐa -ギュ = gǐu -ギョ = gǐo -シャ = sǐa -シュ = sǐu -ショ = sǐo -ジャ = zǐa -ジュ = zǐu -ジョ = zǐo -チャ = tǐa -チュ = tǐu -チョ = tǐo -ヂャ = dǐa -ヂュ = dǐu -ヂョ = dǐo -ニャ = nǐa -ニュ = nǐu -ニョ = nǐo -ヒャ = hǐa -ヒュ = hǐu -ヒョ = hǐo -ビャ = bǐa -ビュ = bǐu -ビョ = bǐo -ピャ = pǐa -ピュ = pǐu -ピョ = pǐo -ミャ = mǐa -ミュ = mǐu -ミョ = mǐo -リャ = rǐa -リュ = rǐu -リョ = rǐo -ァ = a -ィ = i -ゥ = u -ェ = e -ォ = o -ヮ = ʍ -ヴ = vu -ャ = ǐa -ュ = ǐu -ョ = ǐo diff --git a/egs/csj/ASR/local/conf/fluent.ini b/egs/csj/ASR/local/conf/fluent.ini index f7f27f5bc..5d033ed17 100644 --- a/egs/csj/ASR/local/conf/fluent.ini +++ b/egs/csj/ASR/local/conf/fluent.ini @@ -1,320 +1,79 @@ -; # This section is ignored if this file is not supplied as the first config file to -; # lhotse prepare csj -[SEGMENTS] -; # Allowed period of nonverbal noise. If exceeded, a new segment is created. -gap = 0.5 -; # Maximum length of segment (s). -maxlen = 10 -; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. -minlen = 0.02 -; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. -; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. -; # If you intend to use a multicharacter string for gap_sym, remember to register the -; # multicharacter string as part of userdef-string in prepare_lang_char.py. -gap_sym = - [CONSTANTS] ; # Name of this mode MODE = fluent -; # Suffixes to use after the word surface (no longer used) -MORPH = pos1 cForm cType2 pos2 -; # Used to differentiate between A tag and A_num tag -JPN_NUM = ゼロ 0 零 一 二 三 四 五 六 七 八 九 十 百 千 . -; # Dummy character to delineate multiline words -PLUS = + [DECISIONS] -; # TAG+'^'とは、タグが一つの転記単位に独立していない場合 -; # The PLUS (fullwidth) sign '+' marks line boundaries for multiline entries - ; # フィラー、感情表出系感動詞 ; # 0 to remain, 1 to delete ; # Example: '(F ぎょっ)' F = 1 -; # Example: '(L (F ン))', '比べ(F えー)る' -F^ = 1 ; # 言い直し、いいよどみなどによる語断片 ; # 0 to remain, 1 to delete ; # Example: '(D だ)(D だいが) 大学の学部の会議' D = 1 -; # Example: '(L (D ドゥ)+(D ヒ))' -D^ = 1 ; # 助詞、助動詞、接辞の言い直し ; # 0 to remain, 1 to delete ; # Example: '西洋 (D2 的)(F えー)(D ふ) 風というか' D2 = 1 -; # Example: '(X (D2 ノ))' -D2^ = 1 ; # 聞き取りや語彙の判断に自信がない場合 ; # 0 to remain, 1 to delete ; # Example: (? 字数) の ; # If no option: empty string is returned regardless of output ; # Example: '(?) で' ? = 0 -; # Example: '(D (? すー))+そう+です+よ+ね' -?^ = 0 ; # タグ?で、値は複数の候補が想定される場合 ; # 0 for main guess with matching morph info, 1 for second guess ; # Example: '(? 次数, 実数)', '(? これ,ここで)+(? 説明+し+た+方+が+いい+か+な)' ?, = 0 -; # Example: '(W (? テユクー);(? ケッキョク,テユウコトデ))', '(W マシ;(? マシ+タ,マス))' -?,^ = 0 ; # 音や言葉に関するメタ的な引用 ; # 0 to remain, 1 to delete ; # Example: '助詞の (M は) は (M は) と書くが発音は (M わ)' M = 0 -; # Example: '(L (M ヒ)+(M ヒ))', '(L (M (? ヒ+ヒ)))' -M^ = 0 ; # 外国語や古語、方言など ; # 0 to remain, 1 to delete ; # Example: '(O ザッツファイン)' O = 0 -; # Example: '(笑 (O エクスキューズ+ミー))', '(笑 メダッ+テ+(O ナンボ))' -O^ = 0 ; # 講演者の名前、差別語、誹謗中傷など ; # 0 to remain, 1 to delete ; # Example: '国語研の (R ××) です' R = 0 -R^ = 0 ; # 非朗読対象発話(朗読における言い間違い等) ; # 0 to remain, 1 to delete ; # Example: '(X 実際は) 実際には' X = 0 -; # Example: '(L (X (D2 ニ)))' -X^ = 0 ; # アルファベットや算用数字、記号の表記 ; # 0 to use Japanese form, 1 to use alphabet form ; # Example: '(A シーディーアール;CD-R)' A = 1 -; # Example: 'スモール(A エヌ;N)', 'ラージ(A キュー;Q)', '(A ティーエフ;TF)+(A アイディーエフ;IDF)' (Strung together by pron: '(W (? ティーワイド);ティーエフ+アイディーエフ)') -A^ = 1 ; # タグAで、単語は算用数字の場合 ; # 0 to use Japanese form, 1 to use Arabic numerals ; # Example: (A 二千;2000) -A_num = eval:self.notag -A_num^ = eval:self.notag +A_num = 0 ; # 何らかの原因で漢字表記できなくなった場合 ; # 0 to use broken form, 1 to use orthodox form ; # Example: '(K たち (F えー) ばな;橘)' K = 1 -; # Example: '合(K か(?)く;格)', '宮(K ま(?)え;前)' -K^ = 1 ; # 転訛、発音の怠けなど、一時的な発音エラー ; # 0 to use wrong form, 1 to use orthodox form ; # Example: '(W ギーツ;ギジュツ)' W = 1 -; # Example: '(F (W エド;エト))', 'イベント(W リレーティッド;リレーテッド)' -W^ = 1 ; # 語の読みに関する知識レベルのいい間違い ; # 0 to use wrong form, 1 to use orthodox form ; # Example: '(B シブタイ;ジュータイ)' B = 0 -; # Example: 'データー(B カズ;スー)' -B^ = 0 ; # 笑いながら発話 ; # 0 to remain, 1 to delete ; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)' 笑 = 0 -; # Example: 'コク(笑 サイ+(D オン))', -笑^ = 0 ; # 泣きながら発話 ; # 0 to remain, 1 to delete ; # Example: '(泣 ドンナニ)' 泣 = 0 -泣^ = 0 ; # 咳をしながら発話 ; # 0 to remain, 1 to delete ; # Example: 'シャ(咳 リン) ノ' 咳 = 0 -; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)' -咳^ = 0 ; # ささやき声や独り言などの小さな声 ; # 0 to remain, 1 to delete ; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' L = 0 -; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト' -L^ = 0 - -[REPLACEMENTS] -; # ボーカルフライなどで母音が同定できない場合 - = -; # 「うん/うーん/ふーん」の音の特定が困難な場合 - = -; # 非語彙的な母音の引き延ばし - = -; # 非語彙的な子音の引き延ばし - = -; # 言語音と独立に講演者の笑いが生じている場合 -<笑> = -; # 言語音と独立に講演者の咳が生じている場合 -<咳> = -; # 言語音と独立に講演者の息が生じている場合 -<息> = -; # 講演者の泣き声 -<泣> = -; # 聴衆(司会者なども含む)の発話 -<フロア発話> = -; # 聴衆の笑い -<フロア笑> = -; # 聴衆の拍手 -<拍手> = -; # 講演者が発表中に用いたデモンストレーションの音声 -<デモ> = -; # 学会講演に発表時間を知らせるためにならすベルの音 -<ベル> = -; # 転記単位全体が再度読み直された場合 -<朗読間違い> = -; # 上記以外の音で特に目立った音 -<雑音> = -; # 0.2秒以上のポーズ -

= -; # Redacted information, for R -; # It is \x00D7 multiplication sign, not your normal 'x' -× = × - -[FIELDS] -; # Time information for segment -time = 3 -; # Word surface -surface = 5 -; # Word surface root form without CSJ tags -notag = 9 -; # Part Of Speech -pos1 = 11 -; # Conjugated Form -cForm = 12 -; # Conjugation Type -cType1 = 13 -; # Subcategory of POS -pos2 = 14 -; # Euphonic Change / Subcategory of Conjugation Type -cType2 = 15 -; # Other information -other = 16 -; # Pronunciation for lexicon -pron = 10 -; # Speaker ID -spk_id = 2 - -[KATAKANA2ROMAJI] -ア = 'a -イ = 'i -ウ = 'u -エ = 'e -オ = 'o -カ = ka -キ = ki -ク = ku -ケ = ke -コ = ko -ガ = ga -ギ = gi -グ = gu -ゲ = ge -ゴ = go -サ = sa -シ = si -ス = su -セ = se -ソ = so -ザ = za -ジ = zi -ズ = zu -ゼ = ze -ゾ = zo -タ = ta -チ = ti -ツ = tu -テ = te -ト = to -ダ = da -ヂ = di -ヅ = du -デ = de -ド = do -ナ = na -ニ = ni -ヌ = nu -ネ = ne -ノ = no -ハ = ha -ヒ = hi -フ = hu -ヘ = he -ホ = ho -バ = ba -ビ = bi -ブ = bu -ベ = be -ボ = bo -パ = pa -ピ = pi -プ = pu -ペ = pe -ポ = po -マ = ma -ミ = mi -ム = mu -メ = me -モ = mo -ヤ = ya -ユ = yu -ヨ = yo -ラ = ra -リ = ri -ル = ru -レ = re -ロ = ro -ワ = wa -ヰ = we -ヱ = wi -ヲ = wo -ン = ŋ -ッ = q -ー = - -キャ = kǐa -キュ = kǐu -キョ = kǐo -ギャ = gǐa -ギュ = gǐu -ギョ = gǐo -シャ = sǐa -シュ = sǐu -ショ = sǐo -ジャ = zǐa -ジュ = zǐu -ジョ = zǐo -チャ = tǐa -チュ = tǐu -チョ = tǐo -ヂャ = dǐa -ヂュ = dǐu -ヂョ = dǐo -ニャ = nǐa -ニュ = nǐu -ニョ = nǐo -ヒャ = hǐa -ヒュ = hǐu -ヒョ = hǐo -ビャ = bǐa -ビュ = bǐu -ビョ = bǐo -ピャ = pǐa -ピュ = pǐu -ピョ = pǐo -ミャ = mǐa -ミュ = mǐu -ミョ = mǐo -リャ = rǐa -リュ = rǐu -リョ = rǐo -ァ = a -ィ = i -ゥ = u -ェ = e -ォ = o -ヮ = ʍ -ヴ = vu -ャ = ǐa -ュ = ǐu -ョ = ǐo diff --git a/egs/csj/ASR/local/conf/number.ini b/egs/csj/ASR/local/conf/number.ini index cf9038f62..3ada9aa24 100644 --- a/egs/csj/ASR/local/conf/number.ini +++ b/egs/csj/ASR/local/conf/number.ini @@ -1,320 +1,79 @@ -; # This section is ignored if this file is not supplied as the first config file to -; # lhotse prepare csj -[SEGMENTS] -; # Allowed period of nonverbal noise. If exceeded, a new segment is created. -gap = 0.5 -; # Maximum length of segment (s). -maxlen = 10 -; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. -minlen = 0.02 -; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. -; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. -; # If you intend to use a multicharacter string for gap_sym, remember to register the -; # multicharacter string as part of userdef-string in prepare_lang_char.py. -gap_sym = - [CONSTANTS] ; # Name of this mode MODE = number -; # Suffixes to use after the word surface (no longer used) -MORPH = pos1 cForm cType2 pos2 -; # Used to differentiate between A tag and A_num tag -JPN_NUM = ゼロ 0 零 一 二 三 四 五 六 七 八 九 十 百 千 . -; # Dummy character to delineate multiline words -PLUS = + [DECISIONS] -; # TAG+'^'とは、タグが一つの転記単位に独立していない場合 -; # The PLUS (fullwidth) sign '+' marks line boundaries for multiline entries - ; # フィラー、感情表出系感動詞 ; # 0 to remain, 1 to delete ; # Example: '(F ぎょっ)' F = 1 -; # Example: '(L (F ン))', '比べ(F えー)る' -F^ = 1 ; # 言い直し、いいよどみなどによる語断片 ; # 0 to remain, 1 to delete ; # Example: '(D だ)(D だいが) 大学の学部の会議' D = 1 -; # Example: '(L (D ドゥ)+(D ヒ))' -D^ = 1 ; # 助詞、助動詞、接辞の言い直し ; # 0 to remain, 1 to delete ; # Example: '西洋 (D2 的)(F えー)(D ふ) 風というか' D2 = 1 -; # Example: '(X (D2 ノ))' -D2^ = 1 ; # 聞き取りや語彙の判断に自信がない場合 ; # 0 to remain, 1 to delete ; # Example: (? 字数) の ; # If no option: empty string is returned regardless of output ; # Example: '(?) で' ? = 0 -; # Example: '(D (? すー))+そう+です+よ+ね' -?^ = 0 ; # タグ?で、値は複数の候補が想定される場合 ; # 0 for main guess with matching morph info, 1 for second guess ; # Example: '(? 次数, 実数)', '(? これ,ここで)+(? 説明+し+た+方+が+いい+か+な)' ?, = 0 -; # Example: '(W (? テユクー);(? ケッキョク,テユウコトデ))', '(W マシ;(? マシ+タ,マス))' -?,^ = 0 ; # 音や言葉に関するメタ的な引用 ; # 0 to remain, 1 to delete ; # Example: '助詞の (M は) は (M は) と書くが発音は (M わ)' M = 0 -; # Example: '(L (M ヒ)+(M ヒ))', '(L (M (? ヒ+ヒ)))' -M^ = 0 ; # 外国語や古語、方言など ; # 0 to remain, 1 to delete ; # Example: '(O ザッツファイン)' O = 0 -; # Example: '(笑 (O エクスキューズ+ミー))', '(笑 メダッ+テ+(O ナンボ))' -O^ = 0 ; # 講演者の名前、差別語、誹謗中傷など ; # 0 to remain, 1 to delete ; # Example: '国語研の (R ××) です' R = 0 -R^ = 0 ; # 非朗読対象発話(朗読における言い間違い等) ; # 0 to remain, 1 to delete ; # Example: '(X 実際は) 実際には' X = 0 -; # Example: '(L (X (D2 ニ)))' -X^ = 0 ; # アルファベットや算用数字、記号の表記 ; # 0 to use Japanese form, 1 to use alphabet form ; # Example: '(A シーディーアール;CD-R)' A = 1 -; # Example: 'スモール(A エヌ;N)', 'ラージ(A キュー;Q)', '(A ティーエフ;TF)+(A アイディーエフ;IDF)' (Strung together by pron: '(W (? ティーワイド);ティーエフ+アイディーエフ)') -A^ = 1 ; # タグAで、単語は算用数字の場合 ; # 0 to use Japanese form, 1 to use Arabic numerals ; # Example: (A 二千;2000) A_num = 1 -A_num^ = 1 ; # 何らかの原因で漢字表記できなくなった場合 ; # 0 to use broken form, 1 to use orthodox form ; # Example: '(K たち (F えー) ばな;橘)' K = 1 -; # Example: '合(K か(?)く;格)', '宮(K ま(?)え;前)' -K^ = 1 ; # 転訛、発音の怠けなど、一時的な発音エラー ; # 0 to use wrong form, 1 to use orthodox form ; # Example: '(W ギーツ;ギジュツ)' W = 1 -; # Example: '(F (W エド;エト))', 'イベント(W リレーティッド;リレーテッド)' -W^ = 1 ; # 語の読みに関する知識レベルのいい間違い ; # 0 to use wrong form, 1 to use orthodox form ; # Example: '(B シブタイ;ジュータイ)' B = 0 -; # Example: 'データー(B カズ;スー)' -B^ = 0 ; # 笑いながら発話 ; # 0 to remain, 1 to delete ; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)' 笑 = 0 -; # Example: 'コク(笑 サイ+(D オン))', -笑^ = 0 ; # 泣きながら発話 ; # 0 to remain, 1 to delete ; # Example: '(泣 ドンナニ)' 泣 = 0 -泣^ = 0 ; # 咳をしながら発話 ; # 0 to remain, 1 to delete ; # Example: 'シャ(咳 リン) ノ' 咳 = 0 -; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)' -咳^ = 0 ; # ささやき声や独り言などの小さな声 ; # 0 to remain, 1 to delete ; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' L = 0 -; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト' -L^ = 0 - -[REPLACEMENTS] -; # ボーカルフライなどで母音が同定できない場合 - = -; # 「うん/うーん/ふーん」の音の特定が困難な場合 - = -; # 非語彙的な母音の引き延ばし - = -; # 非語彙的な子音の引き延ばし - = -; # 言語音と独立に講演者の笑いが生じている場合 -<笑> = -; # 言語音と独立に講演者の咳が生じている場合 -<咳> = -; # 言語音と独立に講演者の息が生じている場合 -<息> = -; # 講演者の泣き声 -<泣> = -; # 聴衆(司会者なども含む)の発話 -<フロア発話> = -; # 聴衆の笑い -<フロア笑> = -; # 聴衆の拍手 -<拍手> = -; # 講演者が発表中に用いたデモンストレーションの音声 -<デモ> = -; # 学会講演に発表時間を知らせるためにならすベルの音 -<ベル> = -; # 転記単位全体が再度読み直された場合 -<朗読間違い> = -; # 上記以外の音で特に目立った音 -<雑音> = -; # 0.2秒以上のポーズ -

= -; # Redacted information, for R -; # It is \x00D7 multiplication sign, not your normal 'x' -× = × - -[FIELDS] -; # Time information for segment -time = 3 -; # Word surface -surface = 5 -; # Word surface root form without CSJ tags -notag = 9 -; # Part Of Speech -pos1 = 11 -; # Conjugated Form -cForm = 12 -; # Conjugation Type -cType1 = 13 -; # Subcategory of POS -pos2 = 14 -; # Euphonic Change / Subcategory of Conjugation Type -cType2 = 15 -; # Other information -other = 16 -; # Pronunciation for lexicon -pron = 10 -; # Speaker ID -spk_id = 2 - -[KATAKANA2ROMAJI] -ア = 'a -イ = 'i -ウ = 'u -エ = 'e -オ = 'o -カ = ka -キ = ki -ク = ku -ケ = ke -コ = ko -ガ = ga -ギ = gi -グ = gu -ゲ = ge -ゴ = go -サ = sa -シ = si -ス = su -セ = se -ソ = so -ザ = za -ジ = zi -ズ = zu -ゼ = ze -ゾ = zo -タ = ta -チ = ti -ツ = tu -テ = te -ト = to -ダ = da -ヂ = di -ヅ = du -デ = de -ド = do -ナ = na -ニ = ni -ヌ = nu -ネ = ne -ノ = no -ハ = ha -ヒ = hi -フ = hu -ヘ = he -ホ = ho -バ = ba -ビ = bi -ブ = bu -ベ = be -ボ = bo -パ = pa -ピ = pi -プ = pu -ペ = pe -ポ = po -マ = ma -ミ = mi -ム = mu -メ = me -モ = mo -ヤ = ya -ユ = yu -ヨ = yo -ラ = ra -リ = ri -ル = ru -レ = re -ロ = ro -ワ = wa -ヰ = we -ヱ = wi -ヲ = wo -ン = ŋ -ッ = q -ー = - -キャ = kǐa -キュ = kǐu -キョ = kǐo -ギャ = gǐa -ギュ = gǐu -ギョ = gǐo -シャ = sǐa -シュ = sǐu -ショ = sǐo -ジャ = zǐa -ジュ = zǐu -ジョ = zǐo -チャ = tǐa -チュ = tǐu -チョ = tǐo -ヂャ = dǐa -ヂュ = dǐu -ヂョ = dǐo -ニャ = nǐa -ニュ = nǐu -ニョ = nǐo -ヒャ = hǐa -ヒュ = hǐu -ヒョ = hǐo -ビャ = bǐa -ビュ = bǐu -ビョ = bǐo -ピャ = pǐa -ピュ = pǐu -ピョ = pǐo -ミャ = mǐa -ミュ = mǐu -ミョ = mǐo -リャ = rǐa -リュ = rǐu -リョ = rǐo -ァ = a -ィ = i -ゥ = u -ェ = e -ォ = o -ヮ = ʍ -ヴ = vu -ャ = ǐa -ュ = ǐu -ョ = ǐo diff --git a/egs/csj/ASR/local/conf/symbol.ini b/egs/csj/ASR/local/conf/symbol.ini index f9801284b..dafd65c9a 100644 --- a/egs/csj/ASR/local/conf/symbol.ini +++ b/egs/csj/ASR/local/conf/symbol.ini @@ -1,321 +1,80 @@ -; # This section is ignored if this file is not supplied as the first config file to -; # lhotse prepare csj -[SEGMENTS] -; # Allowed period of nonverbal noise. If exceeded, a new segment is created. -gap = 0.5 -; # Maximum length of segment (s). -maxlen = 10 -; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. -minlen = 0.02 -; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. -; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. -; # If you intend to use a multicharacter string for gap_sym, remember to register the -; # multicharacter string as part of userdef-string in prepare_lang_char.py. -gap_sym = - [CONSTANTS] ; # Name of this mode -; # See https://www.isca-speech.org/archive/pdfs/interspeech_2022/horii22_interspeech.pdf +; # From https://www.isca-speech.org/archive/pdfs/interspeech_2022/horii22_interspeech.pdf MODE = symbol -; # Suffixes to use after the word surface (no longer used) -MORPH = pos1 cForm cType2 pos2 -; # Used to differentiate between A tag and A_num tag -JPN_NUM = ゼロ 0 零 一 二 三 四 五 六 七 八 九 十 百 千 . -; # Dummy character to delineate multiline words -PLUS = + [DECISIONS] -; # TAG+'^'とは、タグが一つの転記単位に独立していない場合 -; # The PLUS (fullwidth) sign '+' marks line boundaries for multiline entries - ; # フィラー、感情表出系感動詞 ; # 0 to remain, 1 to delete ; # Example: '(F ぎょっ)' -F = # -; # Example: '(L (F ン))', '比べ(F えー)る' -F^ = # +F = "#", ["F"] ; # 言い直し、いいよどみなどによる語断片 ; # 0 to remain, 1 to delete ; # Example: '(D だ)(D だいが) 大学の学部の会議' -D = @ -; # Example: '(L (D ドゥ)+(D ヒ))' -D^ = @ +D = "@", ["D"] ; # 助詞、助動詞、接辞の言い直し ; # 0 to remain, 1 to delete ; # Example: '西洋 (D2 的)(F えー)(D ふ) 風というか' -D2 = @ -; # Example: '(X (D2 ノ))' -D2^ = @ +D2 = "@", ["D2"] ; # 聞き取りや語彙の判断に自信がない場合 ; # 0 to remain, 1 to delete ; # Example: (? 字数) の ; # If no option: empty string is returned regardless of output ; # Example: '(?) で' ? = 0 -; # Example: '(D (? すー))+そう+です+よ+ね' -?^ = 0 ; # タグ?で、値は複数の候補が想定される場合 ; # 0 for main guess with matching morph info, 1 for second guess ; # Example: '(? 次数, 実数)', '(? これ,ここで)+(? 説明+し+た+方+が+いい+か+な)' ?, = 0 -; # Example: '(W (? テユクー);(? ケッキョク,テユウコトデ))', '(W マシ;(? マシ+タ,マス))' -?,^ = 0 ; # 音や言葉に関するメタ的な引用 ; # 0 to remain, 1 to delete ; # Example: '助詞の (M は) は (M は) と書くが発音は (M わ)' M = 0 -; # Example: '(L (M ヒ)+(M ヒ))', '(L (M (? ヒ+ヒ)))' -M^ = 0 ; # 外国語や古語、方言など ; # 0 to remain, 1 to delete ; # Example: '(O ザッツファイン)' O = 0 -; # Example: '(笑 (O エクスキューズ+ミー))', '(笑 メダッ+テ+(O ナンボ))' -O^ = 0 ; # 講演者の名前、差別語、誹謗中傷など ; # 0 to remain, 1 to delete ; # Example: '国語研の (R ××) です' R = 0 -R^ = 0 ; # 非朗読対象発話(朗読における言い間違い等) ; # 0 to remain, 1 to delete ; # Example: '(X 実際は) 実際には' X = 0 -; # Example: '(L (X (D2 ニ)))' -X^ = 0 ; # アルファベットや算用数字、記号の表記 ; # 0 to use Japanese form, 1 to use alphabet form ; # Example: '(A シーディーアール;CD-R)' A = 1 -; # Example: 'スモール(A エヌ;N)', 'ラージ(A キュー;Q)', '(A ティーエフ;TF)+(A アイディーエフ;IDF)' (Strung together by pron: '(W (? ティーワイド);ティーエフ+アイディーエフ)') -A^ = 1 ; # タグAで、単語は算用数字の場合 ; # 0 to use Japanese form, 1 to use Arabic numerals ; # Example: (A 二千;2000) -A_num = eval:self.notag -A_num^ = eval:self.notag +A_num = 1 ; # 何らかの原因で漢字表記できなくなった場合 ; # 0 to use broken form, 1 to use orthodox form ; # Example: '(K たち (F えー) ばな;橘)' K = 1 -; # Example: '合(K か(?)く;格)', '宮(K ま(?)え;前)' -K^ = 1 ; # 転訛、発音の怠けなど、一時的な発音エラー ; # 0 to use wrong form, 1 to use orthodox form ; # Example: '(W ギーツ;ギジュツ)' W = 1 -; # Example: '(F (W エド;エト))', 'イベント(W リレーティッド;リレーテッド)' -W^ = 1 ; # 語の読みに関する知識レベルのいい間違い ; # 0 to use wrong form, 1 to use orthodox form ; # Example: '(B シブタイ;ジュータイ)' B = 0 -; # Example: 'データー(B カズ;スー)' -B^ = 0 ; # 笑いながら発話 ; # 0 to remain, 1 to delete ; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)' 笑 = 0 -; # Example: 'コク(笑 サイ+(D オン))', -笑^ = 0 ; # 泣きながら発話 ; # 0 to remain, 1 to delete ; # Example: '(泣 ドンナニ)' 泣 = 0 -泣^ = 0 ; # 咳をしながら発話 ; # 0 to remain, 1 to delete ; # Example: 'シャ(咳 リン) ノ' 咳 = 0 -; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)' -咳^ = 0 ; # ささやき声や独り言などの小さな声 ; # 0 to remain, 1 to delete ; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' L = 0 -; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト' -L^ = 0 - -[REPLACEMENTS] -; # ボーカルフライなどで母音が同定できない場合 - = -; # 「うん/うーん/ふーん」の音の特定が困難な場合 - = -; # 非語彙的な母音の引き延ばし - = -; # 非語彙的な子音の引き延ばし - = -; # 言語音と独立に講演者の笑いが生じている場合 -<笑> = -; # 言語音と独立に講演者の咳が生じている場合 -<咳> = -; # 言語音と独立に講演者の息が生じている場合 -<息> = -; # 講演者の泣き声 -<泣> = -; # 聴衆(司会者なども含む)の発話 -<フロア発話> = -; # 聴衆の笑い -<フロア笑> = -; # 聴衆の拍手 -<拍手> = -; # 講演者が発表中に用いたデモンストレーションの音声 -<デモ> = -; # 学会講演に発表時間を知らせるためにならすベルの音 -<ベル> = -; # 転記単位全体が再度読み直された場合 -<朗読間違い> = -; # 上記以外の音で特に目立った音 -<雑音> = -; # 0.2秒以上のポーズ -

= -; # Redacted information, for R -; # It is \x00D7 multiplication sign, not your normal 'x' -× = × - -[FIELDS] -; # Time information for segment -time = 3 -; # Word surface -surface = 5 -; # Word surface root form without CSJ tags -notag = 9 -; # Part Of Speech -pos1 = 11 -; # Conjugated Form -cForm = 12 -; # Conjugation Type -cType1 = 13 -; # Subcategory of POS -pos2 = 14 -; # Euphonic Change / Subcategory of Conjugation Type -cType2 = 15 -; # Other information -other = 16 -; # Pronunciation for lexicon -pron = 10 -; # Speaker ID -spk_id = 2 - -[KATAKANA2ROMAJI] -ア = 'a -イ = 'i -ウ = 'u -エ = 'e -オ = 'o -カ = ka -キ = ki -ク = ku -ケ = ke -コ = ko -ガ = ga -ギ = gi -グ = gu -ゲ = ge -ゴ = go -サ = sa -シ = si -ス = su -セ = se -ソ = so -ザ = za -ジ = zi -ズ = zu -ゼ = ze -ゾ = zo -タ = ta -チ = ti -ツ = tu -テ = te -ト = to -ダ = da -ヂ = di -ヅ = du -デ = de -ド = do -ナ = na -ニ = ni -ヌ = nu -ネ = ne -ノ = no -ハ = ha -ヒ = hi -フ = hu -ヘ = he -ホ = ho -バ = ba -ビ = bi -ブ = bu -ベ = be -ボ = bo -パ = pa -ピ = pi -プ = pu -ペ = pe -ポ = po -マ = ma -ミ = mi -ム = mu -メ = me -モ = mo -ヤ = ya -ユ = yu -ヨ = yo -ラ = ra -リ = ri -ル = ru -レ = re -ロ = ro -ワ = wa -ヰ = we -ヱ = wi -ヲ = wo -ン = ŋ -ッ = q -ー = - -キャ = kǐa -キュ = kǐu -キョ = kǐo -ギャ = gǐa -ギュ = gǐu -ギョ = gǐo -シャ = sǐa -シュ = sǐu -ショ = sǐo -ジャ = zǐa -ジュ = zǐu -ジョ = zǐo -チャ = tǐa -チュ = tǐu -チョ = tǐo -ヂャ = dǐa -ヂュ = dǐu -ヂョ = dǐo -ニャ = nǐa -ニュ = nǐu -ニョ = nǐo -ヒャ = hǐa -ヒュ = hǐu -ヒョ = hǐo -ビャ = bǐa -ビュ = bǐu -ビョ = bǐo -ピャ = pǐa -ピュ = pǐu -ピョ = pǐo -ミャ = mǐa -ミュ = mǐu -ミョ = mǐo -リャ = rǐa -リュ = rǐu -リョ = rǐo -ァ = a -ィ = i -ゥ = u -ェ = e -ォ = o -ヮ = ʍ -ヴ = vu -ャ = ǐa -ュ = ǐu -ョ = ǐo diff --git a/egs/csj/ASR/local/disfluent_recogs_to_fluent.py b/egs/csj/ASR/local/disfluent_recogs_to_fluent.py new file mode 100644 index 000000000..45c9c7656 --- /dev/null +++ b/egs/csj/ASR/local/disfluent_recogs_to_fluent.py @@ -0,0 +1,202 @@ +import argparse +from pathlib import Path + +import kaldialign +from lhotse import CutSet + +ARGPARSE_DESCRIPTION = """ +This helper code takes in a disfluent recogs file generated from icefall.utils.store_transcript, +compares it against a fluent transcript, and saves the results in a separate directory. +This is useful to compare disfluent models with fluent models on the same metric. + +""" + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + description=ARGPARSE_DESCRIPTION, + ) + parser.add_argument( + "--recogs", + type=Path, + required=True, + help="Path to the recogs-XXX file generated by icefall.utils.store_transcript.", + ) + parser.add_argument( + "--cut", + type=Path, + required=True, + help="Path to the cut manifest to be compared to. Assumes that disfluent_tag exists in the custom dict.", + ) + parser.add_argument( + "--res-dir", type=Path, required=True, help="Path to save results" + ) + return parser.parse_args() + + +def d2f(stats): + """ + Compare the outputs of a disfluent model against a fluent reference. + Indicates a disfluent model's performance only on the content words + + CER^d_f = (sub_f + ins + del_f) / Nf + + """ + return stats["base"] / stats["Nf"] + + +def calc_cer(refs, hyps): + subs = { + "F": 0, + "D": 0, + } + ins = 0 + dels = { + "F": 0, + "D": 0, + } + cors = { + "F": 0, + "D": 0, + } + dis_ref_len = 0 + flu_ref_len = 0 + + for ref, hyp in zip(refs, hyps): + assert ( + ref[0] == hyp[0] + ), f"Expected ref cut id {ref[0]} to be the same as hyp cut id {hyp[0]}." + tag = ref[2].copy() + ref = ref[1] + dis_ref_len += len(ref) + # Remember that the 'D' and 'F' tags here refer to CSJ tags, not disfluent and fluent respectively. + flu_ref_len += len([t for t in tag if ("D" not in t and "F" not in t)]) + hyp = hyp[1] + ali = kaldialign.align(ref, hyp, "*") + tags = ["*" if r[0] == "*" else tag.pop(0) for r in ali] + for tag, (ref_word, hyp_word) in zip(tags, ali): + if "D" in tag or "F" in tag: + tag = "D" + else: + tag = "F" + + if ref_word == "*": + ins += 1 + elif hyp_word == "*": + dels[tag] += 1 + elif ref_word != hyp_word: + subs[tag] += 1 + else: + cors[tag] += 1 + + return { + "subs": subs, + "ins": ins, + "dels": dels, + "cors": cors, + "dis_ref_len": dis_ref_len, + "flu_ref_len": flu_ref_len, + } + + +def for_each_recogs(recogs_file: Path, refs, out_dir): + hyps = [] + with recogs_file.open() as fin: + for line in fin: + if "ref" in line: + continue + cutid, hyp = line.split(":\thyp=") + hyps.append((cutid, eval(hyp))) + + assert len(refs) == len( + hyps + ), f"Expected refs len {len(refs)} and hyps len {len(hyps)} to be equal." + stats = calc_cer(refs, hyps) + stat_table = ["tag,yes,no"] + + for cer_type in ["subs", "dels", "cors", "ins"]: + ret = f"{cer_type}" + for df in ["D", "F"]: + try: + ret += f",{stats[cer_type][df]}" + except TypeError: + # insertions do not belong to F or D, and is not subscriptable. + ret += f",{stats[cer_type]}," + break + stat_table.append(ret) + stat_table = "\n".join(stat_table) + + stats = { + "subd": stats["subs"]["D"], + "deld": stats["dels"]["D"], + "cord": stats["cors"]["D"], + "Nf": stats["flu_ref_len"], + "base": stats["subs"]["F"] + stats["ins"] + stats["dels"]["F"], + } + + cer = d2f(stats) + results = [ + f"{cer:.2%}", + f"Nf,{stats['Nf']}", + ] + results = "\n".join(results) + + with (out_dir / (recogs_file.stem + ".dfcer")).open("w") as fout: + fout.write(results) + fout.write("\n\n") + fout.write(stat_table) + + +def main(): + args = get_args() + recogs_file: Path = args.recogs + assert ( + recogs_file.is_file() or recogs_file.is_dir() + ), f"recogs_file cannot be found at {recogs_file}." + + args.res_dir.mkdir(parents=True, exist_ok=True) + + if recogs_file.is_file() and recogs_file.stem.startswith("recogs-"): + assert ( + "csj_cuts" in args.cut.name + ), f"Expected {args.cut} to be a cuts manifest." + + refs: CutSet = CutSet.from_file(args.cut) + refs = sorted( + [ + ( + e.id, + list(e.supervisions[0].custom["disfluent"]), + e.supervisions[0].custom["disfluent_tag"].split(","), + ) + for e in refs + ], + key=lambda x: x[0], + ) + for_each_recogs(recogs_file, refs, args.res_dir) + + elif recogs_file.is_dir(): + recogs_file_path = recogs_file + for partname in ["eval1", "eval2", "eval3", "excluded", "valid"]: + refs: CutSet = CutSet.from_file(args.cut / f"csj_cuts_{partname}.jsonl.gz") + refs = sorted( + [ + ( + r.id, + list(r.supervisions[0].custom["disfluent"]), + r.supervisions[0].custom["disfluent_tag"].split(","), + ) + for r in refs + ], + key=lambda x: x[0], + ) + for recogs_file in recogs_file_path.glob(f"recogs-{partname}-*.txt"): + for_each_recogs(recogs_file, refs, args.res_dir) + + else: + raise TypeError(f"Unrecognised recogs file provided: {recogs_file}") + + +if __name__ == "__main__": + main() diff --git a/egs/csj/ASR/local/display_manifest_statistics.py b/egs/csj/ASR/local/display_manifest_statistics.py index c043cf853..924474d33 100644 --- a/egs/csj/ASR/local/display_manifest_statistics.py +++ b/egs/csj/ASR/local/display_manifest_statistics.py @@ -45,8 +45,8 @@ def get_parser(): def main(): args = get_parser() - for path in args.manifest_dir.glob("csj_cuts_*.jsonl.gz"): - + for part in ["eval1", "eval2", "eval3", "valid", "excluded", "train"]: + path = args.manifest_dir / f"csj_cuts_{part}.jsonl.gz" cuts: CutSet = load_manifest(path) print("\n---------------------------------\n") @@ -58,123 +58,271 @@ if __name__ == "__main__": main() """ -## eval1 -Cuts count: 1272 -Total duration (hh:mm:ss): 01:50:07 -Speech duration (hh:mm:ss): 01:50:07 (100.0%) -Duration statistics (seconds): -mean 5.2 -std 3.9 -min 0.2 -25% 1.9 -50% 4.0 -75% 8.1 -99% 14.3 -99.5% 14.7 -99.9% 16.0 -max 16.9 -Recordings available: 1272 -Features available: 1272 -Supervisions available: 1272 +csj_cuts_eval1.jsonl.gz: +Cut statistics: +╒═══════════════════════════╤══════════╕ +│ Cuts count: │ 1023 │ +├───────────────────────────┼──────────┤ +│ Total duration (hh:mm:ss) │ 01:55:40 │ +├───────────────────────────┼──────────┤ +│ mean │ 6.8 │ +├───────────────────────────┼──────────┤ +│ std │ 2.7 │ +├───────────────────────────┼──────────┤ +│ min │ 0.2 │ +├───────────────────────────┼──────────┤ +│ 25% │ 4.9 │ +├───────────────────────────┼──────────┤ +│ 50% │ 7.7 │ +├───────────────────────────┼──────────┤ +│ 75% │ 9.0 │ +├───────────────────────────┼──────────┤ +│ 99% │ 10.0 │ +├───────────────────────────┼──────────┤ +│ 99.5% │ 10.0 │ +├───────────────────────────┼──────────┤ +│ 99.9% │ 10.0 │ +├───────────────────────────┼──────────┤ +│ max │ 10.0 │ +├───────────────────────────┼──────────┤ +│ Recordings available: │ 1023 │ +├───────────────────────────┼──────────┤ +│ Features available: │ 0 │ +├───────────────────────────┼──────────┤ +│ Supervisions available: │ 1023 │ +╘═══════════════════════════╧══════════╛ SUPERVISION custom fields: -- fluent (in 1272 cuts) -- disfluent (in 1272 cuts) -- number (in 1272 cuts) -- symbol (in 1272 cuts) +Speech duration statistics: +╒══════════════════════════════╤══════════╤══════════════════════╕ +│ Total speech duration │ 01:55:40 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total speaking time duration │ 01:55:40 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧══════════╧══════════════════════╛ -## eval2 -Cuts count: 1292 -Total duration (hh:mm:ss): 01:56:50 -Speech duration (hh:mm:ss): 01:56:50 (100.0%) -Duration statistics (seconds): -mean 5.4 -std 3.9 -min 0.1 -25% 2.1 -50% 4.6 -75% 8.6 -99% 14.1 -99.5% 15.2 -99.9% 16.1 -max 16.9 -Recordings available: 1292 -Features available: 1292 -Supervisions available: 1292 -SUPERVISION custom fields: -- fluent (in 1292 cuts) -- number (in 1292 cuts) -- symbol (in 1292 cuts) -- disfluent (in 1292 cuts) +--------------------------------- -## eval3 -Cuts count: 1385 -Total duration (hh:mm:ss): 01:19:21 -Speech duration (hh:mm:ss): 01:19:21 (100.0%) -Duration statistics (seconds): -mean 3.4 -std 3.0 -min 0.2 -25% 1.2 -50% 2.5 -75% 4.6 -99% 12.7 -99.5% 13.7 -99.9% 15.0 -max 15.9 -Recordings available: 1385 -Features available: 1385 -Supervisions available: 1385 +csj_cuts_eval2.jsonl.gz: +Cut statistics: +╒═══════════════════════════╤══════════╕ +│ Cuts count: │ 1025 │ +├───────────────────────────┼──────────┤ +│ Total duration (hh:mm:ss) │ 02:02:07 │ +├───────────────────────────┼──────────┤ +│ mean │ 7.1 │ +├───────────────────────────┼──────────┤ +│ std │ 2.5 │ +├───────────────────────────┼──────────┤ +│ min │ 0.1 │ +├───────────────────────────┼──────────┤ +│ 25% │ 5.9 │ +├───────────────────────────┼──────────┤ +│ 50% │ 7.9 │ +├───────────────────────────┼──────────┤ +│ 75% │ 9.1 │ +├───────────────────────────┼──────────┤ +│ 99% │ 10.0 │ +├───────────────────────────┼──────────┤ +│ 99.5% │ 10.0 │ +├───────────────────────────┼──────────┤ +│ 99.9% │ 10.0 │ +├───────────────────────────┼──────────┤ +│ max │ 10.0 │ +├───────────────────────────┼──────────┤ +│ Recordings available: │ 1025 │ +├───────────────────────────┼──────────┤ +│ Features available: │ 0 │ +├───────────────────────────┼──────────┤ +│ Supervisions available: │ 1025 │ +╘═══════════════════════════╧══════════╛ SUPERVISION custom fields: -- number (in 1385 cuts) -- symbol (in 1385 cuts) -- fluent (in 1385 cuts) -- disfluent (in 1385 cuts) +Speech duration statistics: +╒══════════════════════════════╤══════════╤══════════════════════╕ +│ Total speech duration │ 02:02:07 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total speaking time duration │ 02:02:07 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧══════════╧══════════════════════╛ -## valid -Cuts count: 4000 -Total duration (hh:mm:ss): 05:08:09 -Speech duration (hh:mm:ss): 05:08:09 (100.0%) -Duration statistics (seconds): -mean 4.6 -std 3.8 -min 0.1 -25% 1.5 -50% 3.4 -75% 7.0 -99% 13.8 -99.5% 14.8 -99.9% 16.0 -max 17.3 -Recordings available: 4000 -Features available: 4000 -Supervisions available: 4000 -SUPERVISION custom fields: -- fluent (in 4000 cuts) -- symbol (in 4000 cuts) -- disfluent (in 4000 cuts) -- number (in 4000 cuts) +--------------------------------- -## train -Cuts count: 1291134 -Total duration (hh:mm:ss): 1596:37:27 -Speech duration (hh:mm:ss): 1596:37:27 (100.0%) -Duration statistics (seconds): -mean 4.5 -std 3.6 -min 0.0 -25% 1.6 -50% 3.3 -75% 6.4 -99% 14.0 -99.5% 14.8 -99.9% 16.6 -max 27.8 -Recordings available: 1291134 -Features available: 1291134 -Supervisions available: 1291134 +csj_cuts_eval3.jsonl.gz: +Cut statistics: +╒═══════════════════════════╤══════════╕ +│ Cuts count: │ 865 │ +├───────────────────────────┼──────────┤ +│ Total duration (hh:mm:ss) │ 01:26:44 │ +├───────────────────────────┼──────────┤ +│ mean │ 6.0 │ +├───────────────────────────┼──────────┤ +│ std │ 3.0 │ +├───────────────────────────┼──────────┤ +│ min │ 0.3 │ +├───────────────────────────┼──────────┤ +│ 25% │ 3.3 │ +├───────────────────────────┼──────────┤ +│ 50% │ 6.8 │ +├───────────────────────────┼──────────┤ +│ 75% │ 8.7 │ +├───────────────────────────┼──────────┤ +│ 99% │ 10.0 │ +├───────────────────────────┼──────────┤ +│ 99.5% │ 10.0 │ +├───────────────────────────┼──────────┤ +│ 99.9% │ 10.0 │ +├───────────────────────────┼──────────┤ +│ max │ 10.0 │ +├───────────────────────────┼──────────┤ +│ Recordings available: │ 865 │ +├───────────────────────────┼──────────┤ +│ Features available: │ 0 │ +├───────────────────────────┼──────────┤ +│ Supervisions available: │ 865 │ +╘═══════════════════════════╧══════════╛ SUPERVISION custom fields: -- disfluent (in 1291134 cuts) -- fluent (in 1291134 cuts) -- symbol (in 1291134 cuts) -- number (in 1291134 cuts) +Speech duration statistics: +╒══════════════════════════════╤══════════╤══════════════════════╕ +│ Total speech duration │ 01:26:44 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total speaking time duration │ 01:26:44 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧══════════╧══════════════════════╛ + +--------------------------------- + +csj_cuts_valid.jsonl.gz: +Cut statistics: +╒═══════════════════════════╤══════════╕ +│ Cuts count: │ 3743 │ +├───────────────────────────┼──────────┤ +│ Total duration (hh:mm:ss) │ 06:40:15 │ +├───────────────────────────┼──────────┤ +│ mean │ 6.4 │ +├───────────────────────────┼──────────┤ +│ std │ 3.0 │ +├───────────────────────────┼──────────┤ +│ min │ 0.1 │ +├───────────────────────────┼──────────┤ +│ 25% │ 3.9 │ +├───────────────────────────┼──────────┤ +│ 50% │ 7.4 │ +├───────────────────────────┼──────────┤ +│ 75% │ 9.0 │ +├───────────────────────────┼──────────┤ +│ 99% │ 10.0 │ +├───────────────────────────┼──────────┤ +│ 99.5% │ 10.0 │ +├───────────────────────────┼──────────┤ +│ 99.9% │ 10.1 │ +├───────────────────────────┼──────────┤ +│ max │ 11.8 │ +├───────────────────────────┼──────────┤ +│ Recordings available: │ 3743 │ +├───────────────────────────┼──────────┤ +│ Features available: │ 0 │ +├───────────────────────────┼──────────┤ +│ Supervisions available: │ 3743 │ +╘═══════════════════════════╧══════════╛ +SUPERVISION custom fields: +Speech duration statistics: +╒══════════════════════════════╤══════════╤══════════════════════╕ +│ Total speech duration │ 06:40:15 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total speaking time duration │ 06:40:15 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧══════════╧══════════════════════╛ + +--------------------------------- + +csj_cuts_excluded.jsonl.gz: +Cut statistics: +╒═══════════════════════════╤══════════╕ +│ Cuts count: │ 980 │ +├───────────────────────────┼──────────┤ +│ Total duration (hh:mm:ss) │ 00:56:06 │ +├───────────────────────────┼──────────┤ +│ mean │ 3.4 │ +├───────────────────────────┼──────────┤ +│ std │ 3.1 │ +├───────────────────────────┼──────────┤ +│ min │ 0.1 │ +├───────────────────────────┼──────────┤ +│ 25% │ 0.8 │ +├───────────────────────────┼──────────┤ +│ 50% │ 2.2 │ +├───────────────────────────┼──────────┤ +│ 75% │ 5.8 │ +├───────────────────────────┼──────────┤ +│ 99% │ 9.9 │ +├───────────────────────────┼──────────┤ +│ 99.5% │ 9.9 │ +├───────────────────────────┼──────────┤ +│ 99.9% │ 10.0 │ +├───────────────────────────┼──────────┤ +│ max │ 10.0 │ +├───────────────────────────┼──────────┤ +│ Recordings available: │ 980 │ +├───────────────────────────┼──────────┤ +│ Features available: │ 0 │ +├───────────────────────────┼──────────┤ +│ Supervisions available: │ 980 │ +╘═══════════════════════════╧══════════╛ +SUPERVISION custom fields: +Speech duration statistics: +╒══════════════════════════════╤══════════╤══════════════════════╕ +│ Total speech duration │ 00:56:06 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total speaking time duration │ 00:56:06 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧══════════╧══════════════════════╛ + +--------------------------------- + +csj_cuts_train.jsonl.gz: +Cut statistics: +╒═══════════════════════════╤════════════╕ +│ Cuts count: │ 914151 │ +├───────────────────────────┼────────────┤ +│ Total duration (hh:mm:ss) │ 1695:29:43 │ +├───────────────────────────┼────────────┤ +│ mean │ 6.7 │ +├───────────────────────────┼────────────┤ +│ std │ 2.9 │ +├───────────────────────────┼────────────┤ +│ min │ 0.1 │ +├───────────────────────────┼────────────┤ +│ 25% │ 4.6 │ +├───────────────────────────┼────────────┤ +│ 50% │ 7.5 │ +├───────────────────────────┼────────────┤ +│ 75% │ 8.9 │ +├───────────────────────────┼────────────┤ +│ 99% │ 11.0 │ +├───────────────────────────┼────────────┤ +│ 99.5% │ 11.0 │ +├───────────────────────────┼────────────┤ +│ 99.9% │ 11.1 │ +├───────────────────────────┼────────────┤ +│ max │ 18.0 │ +├───────────────────────────┼────────────┤ +│ Recordings available: │ 914151 │ +├───────────────────────────┼────────────┤ +│ Features available: │ 0 │ +├───────────────────────────┼────────────┤ +│ Supervisions available: │ 914151 │ +╘═══════════════════════════╧════════════╛ +SUPERVISION custom fields: +Speech duration statistics: +╒══════════════════════════════╤════════════╤══════════════════════╕ +│ Total speech duration │ 1695:29:43 │ 100.00% of recording │ +├──────────────────────────────┼────────────┼──────────────────────┤ +│ Total speaking time duration │ 1695:29:43 │ 100.00% of recording │ +├──────────────────────────────┼────────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧════════════╧══════════════════════╛ """ diff --git a/egs/csj/ASR/local/prepare_lang_char.py b/egs/csj/ASR/local/prepare_lang_char.py index 16107f543..58b197922 100644 --- a/egs/csj/ASR/local/prepare_lang_char.py +++ b/egs/csj/ASR/local/prepare_lang_char.py @@ -21,24 +21,14 @@ import logging from pathlib import Path from lhotse import CutSet +from lhotse.recipes.csj import CSJSDBParser ARGPARSE_DESCRIPTION = """ -This script gathers all training transcripts of the specified {trans_mode} type -and produces a token_list that would be output set of the ASR system. +This script gathers all training transcripts, parses them in disfluent mode, and produces a token list that would be the output set of the ASR system. -It splits transcripts by whitespace into lists, then, for each word in the -list, if the word does not appear in the list of user-defined multicharacter -strings, it further splits that word into individual characters to be counted -into the output token set. - -It outputs 4 files into the lang directory: -- trans_mode: the name of transcript mode. If trans_mode was not specified, - this will be an empty file. -- userdef_string: a list of user defined strings that should not be split - further into individual characters. By default, it contains "", "", - "" -- words_len: the total number of tokens in the output set. -- words.txt: a list of tokens in the output set. The length matches words_len. +It outputs 3 files into the lang directory: +- tokens.txt: a list of tokens in the output set. +- lang_type: a file that contains the string "char" """ @@ -50,98 +40,52 @@ def get_args(): ) parser.add_argument( - "--train-cut", type=Path, required=True, help="Path to the train cut" - ) - - parser.add_argument( - "--trans-mode", - type=str, - default=None, - help=( - "Name of the transcript mode to use. " - "If lang-dir is not set, this will also name the lang-dir" - ), + "train_cut", metavar="train-cut", type=Path, help="Path to the train cut" ) parser.add_argument( "--lang-dir", type=Path, - default=None, + default=Path("data/lang_char"), help=( "Name of lang dir. " "If not set, this will default to lang_char_{trans-mode}" ), ) - parser.add_argument( - "--userdef-string", - type=Path, - default=None, - help="Multicharacter strings that do not need to be split", - ) - return parser.parse_args() def main(): args = get_args() - logging.basicConfig( format=("%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"), level=logging.INFO, ) - if not args.lang_dir: - p = "lang_char" - if args.trans_mode: - p += f"_{args.trans_mode}" - args.lang_dir = Path(p) + sysdef_string = set(["", "", ""]) - if args.userdef_string: - args.userdef_string = set(args.userdef_string.read_text().split()) - else: - args.userdef_string = set() + # Using disfluent parsing as fluent is a subset of disfluent + parser = CSJSDBParser() - sysdef_string = ["", "", ""] - args.userdef_string.update(sysdef_string) + token_set = set() + logging.info(f"Creating vocabulary from {args.train_cut}.") + train_cut: CutSet = CutSet.from_file(args.train_cut) + for cut in train_cut: + if "_sp" in cut.id: + continue - train_set: CutSet = CutSet.from_file(args.train_cut) - - words = set() - logging.info( - f"Creating vocabulary from {args.train_cut.name} at {args.trans_mode} mode." - ) - for cut in train_set: - try: - text: str = ( - cut.supervisions[0].custom[args.trans_mode] - if args.trans_mode - else cut.supervisions[0].text - ) - except KeyError: - raise KeyError( - f"Could not find {args.trans_mode} in {cut.supervisions[0].custom}" - ) - for t in text.split(): - if t in args.userdef_string: - words.add(t) - else: - words.update(c for c in list(t)) - - words -= set(sysdef_string) - words = sorted(words) - words = [""] + words + ["", ""] + text: str = cut.supervisions[0].custom["raw"] + for w in parser.parse(text, sep=" ").split(" "): + token_set.update(w) + token_set = [""] + sorted(token_set - sysdef_string) + ["", ""] args.lang_dir.mkdir(parents=True, exist_ok=True) - (args.lang_dir / "words.txt").write_text( - "\n".join(f"{word}\t{i}" for i, word in enumerate(words)) + (args.lang_dir / "tokens.txt").write_text( + "\n".join(f"{t}\t{i}" for i, t in enumerate(token_set)) ) - (args.lang_dir / "words_len").write_text(f"{len(words)}") - - (args.lang_dir / "userdef_string").write_text("\n".join(args.userdef_string)) - - (args.lang_dir / "trans_mode").write_text(args.trans_mode) + (args.lang_dir / "lang_type").write_text("char") logging.info("Done.") diff --git a/egs/csj/ASR/local/utils/asr_datamodule.py b/egs/csj/ASR/local/utils/asr_datamodule.py new file mode 100644 index 000000000..619820a75 --- /dev/null +++ b/egs/csj/ASR/local/utils/asr_datamodule.py @@ -0,0 +1,462 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import inspect +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SingleCutSampler, + SpecAugment, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class AsrVariableTranscriptDataset(K2SpeechRecognitionDataset): + def __init__( + self, + *args, + transcript_mode: str = "", + return_cuts: bool = False, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.transcript_mode = transcript_mode + self.return_cuts = True + self._return_cuts = return_cuts + + def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]: + batch = super().__getitem__(cuts) + + if self.transcript_mode: + batch["supervisions"]["text"] = [ + supervision.custom[self.transcript_mode] + for cut in batch["supervisions"]["cut"] + for supervision in cut.supervisions + ] + + if not self._return_cuts: + del batch["supervisions"]["cut"] + + return batch + + +class CSJAsrDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + + group.add_argument( + "--transcript-mode", + type=str, + default="", + help="Mode of transcript in supervision to use.", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/manifests"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--musan-dir", type=Path, help="Path to directory with musan cuts. " + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=False, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=80, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.musan_dir / "musan_cuts.jsonl.gz") + transforms.append( + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. + num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + logging.info(f"Num frame mask: {num_frame_masks}") + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = AsrVariableTranscriptDataset( + input_strategy=eval(self.args.input_strategy)(), + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + transcript_mode=self.args.transcript_mode, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. + train = AsrVariableTranscriptDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + transcript_mode=self.args.transcript_mode, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SingleCutSampler.") + train_sampler = SingleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = AsrVariableTranscriptDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + return_cuts=self.args.return_cuts, + transcript_mode=self.args.transcript_mode, + ) + else: + validate = AsrVariableTranscriptDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + transcript_mode=self.args.transcript_mode, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + + test = AsrVariableTranscriptDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + transcript_mode=self.args.transcript_mode, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info("About to get train cuts") + return load_manifest_lazy(self.args.manifest_dir / "csj_cuts_train.jsonl.gz") + + @lru_cache() + def valid_cuts(self) -> CutSet: + logging.info("About to get valid cuts") + return load_manifest_lazy(self.args.manifest_dir / "csj_cuts_valid.jsonl.gz") + + @lru_cache() + def excluded_cuts(self) -> CutSet: + logging.info("About to get excluded cuts") + return load_manifest_lazy(self.args.manifest_dir / "csj_cuts_excluded.jsonl.gz") + + @lru_cache() + def eval1_cuts(self) -> CutSet: + logging.info("About to get eval1 cuts") + return load_manifest_lazy(self.args.manifest_dir / "csj_cuts_eval1.jsonl.gz") + + @lru_cache() + def eval2_cuts(self) -> CutSet: + logging.info("About to get eval2 cuts") + return load_manifest_lazy(self.args.manifest_dir / "csj_cuts_eval2.jsonl.gz") + + @lru_cache() + def eval3_cuts(self) -> CutSet: + logging.info("About to get eval3 cuts") + return load_manifest_lazy(self.args.manifest_dir / "csj_cuts_eval3.jsonl.gz") diff --git a/egs/csj/ASR/local/utils/tokenizer.py b/egs/csj/ASR/local/utils/tokenizer.py new file mode 100644 index 000000000..c9be72be1 --- /dev/null +++ b/egs/csj/ASR/local/utils/tokenizer.py @@ -0,0 +1,253 @@ +import argparse +from pathlib import Path +from typing import Callable, List, Union + +import sentencepiece as spm +from k2 import SymbolTable + + +class Tokenizer: + text2word: Callable[[str], List[str]] + + @staticmethod + def add_arguments(parser: argparse.ArgumentParser): + group = parser.add_argument_group(title="Lang related options") + + group.add_argument("--lang", type=Path, help="Path to lang directory.") + + group.add_argument( + "--lang-type", + type=str, + default=None, + help=( + "Either 'bpe' or 'char'. If not provided, it expects lang_dir/lang_type to exists. " + "Note: 'bpe' directly loads sentencepiece.SentencePieceProcessor" + ), + ) + + @staticmethod + def Load(lang_dir: Path, lang_type="", oov=""): + + if not lang_type: + assert (lang_dir / "lang_type").exists(), "lang_type not specified." + lang_type = (lang_dir / "lang_type").read_text().strip() + + tokenizer = None + + if lang_type == "bpe": + assert ( + lang_dir / "bpe.model" + ).exists(), f"No BPE .model could be found in {lang_dir}." + tokenizer = spm.SentencePieceProcessor() + tokenizer.Load(str(lang_dir / "bpe.model")) + elif lang_type == "char": + tokenizer = CharTokenizer(lang_dir, oov=oov) + else: + raise NotImplementedError(f"{lang_type} not supported at the moment.") + + return tokenizer + + load = Load + + def PieceToId(self, piece: str) -> int: + raise NotImplementedError( + "You need to implement this function in the child class." + ) + + piece_to_id = PieceToId + + def IdToPiece(self, id: int) -> str: + raise NotImplementedError( + "You need to implement this function in the child class." + ) + + id_to_piece = IdToPiece + + def GetPieceSize(self) -> int: + raise NotImplementedError( + "You need to implement this function in the child class." + ) + + get_piece_size = GetPieceSize + + def __len__(self) -> int: + return self.get_piece_size() + + def EncodeAsIdsBatch(self, input: List[str]) -> List[List[int]]: + raise NotImplementedError( + "You need to implement this function in the child class." + ) + + def EncodeAsPiecesBatch(self, input: List[str]) -> List[List[str]]: + raise NotImplementedError( + "You need to implement this function in the child class." + ) + + def EncodeAsIds(self, input: str) -> List[int]: + return self.EncodeAsIdsBatch([input])[0] + + def EncodeAsPieces(self, input: str) -> List[str]: + return self.EncodeAsPiecesBatch([input])[0] + + def Encode( + self, input: Union[str, List[str]], out_type=int + ) -> Union[List, List[List]]: + if not input: + return [] + + if isinstance(input, list): + if out_type is int: + return self.EncodeAsIdsBatch(input) + if out_type is str: + return self.EncodeAsPiecesBatch(input) + + if out_type is int: + return self.EncodeAsIds(input) + if out_type is str: + return self.EncodeAsPieces(input) + + encode = Encode + + def DecodeIdsBatch(self, input: List[List[int]]) -> List[str]: + raise NotImplementedError( + "You need to implement this function in the child class." + ) + + def DecodePiecesBatch(self, input: List[List[str]]) -> List[str]: + raise NotImplementedError( + "You need to implement this function in the child class." + ) + + def DecodeIds(self, input: List[int]) -> str: + return self.DecodeIdsBatch([input])[0] + + def DecodePieces(self, input: List[str]) -> str: + return self.DecodePiecesBatch([input])[0] + + def Decode( + self, + input: Union[int, List[int], List[str], List[List[int]], List[List[str]]], + ) -> Union[List[str], str]: + + if not input: + return "" + + if isinstance(input, int): + return self.id_to_piece(input) + elif isinstance(input, str): + raise TypeError( + "Unlike spm.SentencePieceProcessor, cannot decode from type str." + ) + + if isinstance(input[0], list): + if not input[0] or isinstance(input[0][0], int): + return self.DecodeIdsBatch(input) + + if isinstance(input[0][0], str): + return self.DecodePiecesBatch(input) + + if isinstance(input[0], int): + return self.DecodeIds(input) + if isinstance(input[0], str): + return self.DecodePieces(input) + + raise RuntimeError("Unknown input type") + + decode = Decode + + def SplitBatch(self, input: List[str]) -> List[List[str]]: + raise NotImplementedError( + "You need to implement this function in the child class." + ) + + def Split(self, input: Union[List[str], str]) -> Union[List[List[str]], List[str]]: + if isinstance(input, list): + return self.SplitBatch(input) + elif isinstance(input, str): + return self.SplitBatch([input])[0] + raise RuntimeError("Unknown input type") + + split = Split + + +class CharTokenizer(Tokenizer): + def __init__(self, lang_dir: Path, oov="", sep=""): + assert ( + lang_dir / "tokens.txt" + ).exists(), f"tokens.txt could not be found in {lang_dir}." + token_table = SymbolTable.from_file(lang_dir / "tokens.txt") + assert ( + "#0" not in token_table + ), "This tokenizer does not support disambig symbols." + self._id2sym = token_table._id2sym + self._sym2id = token_table._sym2id + self.oov = oov + self.oov_id = self._sym2id[oov] + self.sep = sep + if self.sep: + self.text2word = lambda x: x.split(self.sep) + else: + self.text2word = lambda x: list(x.replace(" ", "")) + + def piece_to_id(self, piece: str) -> int: + try: + return self._sym2id[piece] + except KeyError: + return self.oov_id + + def id_to_piece(self, id: int) -> str: + return self._id2sym[id] + + def get_piece_size(self) -> int: + return len(self._sym2id) + + def EncodeAsIdsBatch(self, input: List[str]) -> List[List[int]]: + return [[self.piece_to_id(i) for i in self.text2word(text)] for text in input] + + def EncodeAsPiecesBatch(self, input: List[str]) -> List[List[str]]: + return [ + [i if i in self._sym2id else self.oov for i in self.text2word(text)] + for text in input + ] + + def DecodeIdsBatch(self, input: List[List[int]]) -> List[str]: + return [self.sep.join(self.id_to_piece(i) for i in text) for text in input] + + def DecodePiecesBatch(self, input: List[List[str]]) -> List[str]: + return [self.sep.join(text) for text in input] + + def SplitBatch(self, input: List[str]) -> List[List[str]]: + return [self.text2word(text) for text in input] + + +def test_CharTokenizer(): + test_single_string = "こんにちは" + test_multiple_string = [ + "今日はいい天気ですよね", + "諏訪湖は綺麗でしょう", + "这在词表外", + "分かち 書き に し た 文章 です", + "", + ] + test_empty_string = "" + sp = Tokenizer.load(Path("lang_char"), "char", oov="") + splitter = sp.split + print(sp.encode(test_single_string, out_type=str)) + print(sp.encode(test_single_string, out_type=int)) + print(sp.encode(test_multiple_string, out_type=str)) + print(sp.encode(test_multiple_string, out_type=int)) + print(sp.encode(test_empty_string, out_type=str)) + print(sp.encode(test_empty_string, out_type=int)) + print(sp.decode(sp.encode(test_single_string, out_type=str))) + print(sp.decode(sp.encode(test_single_string, out_type=int))) + print(sp.decode(sp.encode(test_multiple_string, out_type=str))) + print(sp.decode(sp.encode(test_multiple_string, out_type=int))) + print(sp.decode(sp.encode(test_empty_string, out_type=str))) + print(sp.decode(sp.encode(test_empty_string, out_type=int))) + print(splitter(test_single_string)) + print(splitter(test_multiple_string)) + print(splitter(test_empty_string)) + + +if __name__ == "__main__": + test_CharTokenizer() diff --git a/egs/csj/ASR/prepare.sh b/egs/csj/ASR/prepare.sh index c4ce91984..52339bb35 100755 --- a/egs/csj/ASR/prepare.sh +++ b/egs/csj/ASR/prepare.sh @@ -32,7 +32,7 @@ # - speech # # By default, this script produces the original transcript like kaldi and espnet. Optionally, you -# can generate other transcript formats by supplying your own config files. A few examples of these +# can add other transcript formats by supplying your own config files. A few examples of these # config files can be found in local/conf. # fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 @@ -44,10 +44,10 @@ nj=8 stage=-1 stop_stage=100 -csj_dir=/mnt/minami_data_server/t2131178/corpus/CSJ -musan_dir=/mnt/minami_data_server/t2131178/corpus/musan/musan -trans_dir=$csj_dir/retranscript -csj_fbank_dir=/mnt/host/csj_data/fbank +csj_dir=/mnt/host/corpus/csj +musan_dir=/mnt/host/corpus/musan/musan +trans_dir=$csj_dir/transcript +csj_fbank_dir=/mnt/host/corpus/csj/fbank musan_fbank_dir=$musan_dir/fbank csj_manifest_dir=data/manifests musan_manifest_dir=$musan_dir/manifests @@ -63,12 +63,8 @@ log() { if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then log "Stage 1: Prepare CSJ manifest" - # If you want to generate more transcript modes, append the path to those config files at c. - # Example: lhotse prepare csj $csj_dir $trans_dir $csj_manifest_dir -c local/conf/disfluent.ini - # NOTE: In case multiple config files are supplied, the second config file and onwards will inherit - # the segment boundaries of the first config file. if [ ! -e $csj_manifest_dir/.csj.done ]; then - lhotse prepare csj $csj_dir $trans_dir $csj_manifest_dir -j 4 + lhotse prepare csj $csj_dir $csj_manifest_dir -t $trans_dir -j 16 touch $csj_manifest_dir/.csj.done fi fi @@ -88,32 +84,24 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then python local/compute_fbank_csj.py --manifest-dir $csj_manifest_dir \ --fbank-dir $csj_fbank_dir parts=( - train - valid eval1 eval2 eval3 + valid + excluded + train ) for part in ${parts[@]}; do - python local/validate_manifest.py --manifest $csj_manifest_dir/csj_cuts_$part.jsonl.gz + python local/validate_manifest.py --manifest $csj_fbank_dir/csj_cuts_$part.jsonl.gz done touch $csj_fbank_dir/.csj-validated.done fi fi if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 4: Prepare CSJ lang" - modes=disfluent - - # If you want prepare the lang directory for other transcript modes, just append - # the names of those modes behind. An example is shown as below:- - # modes="$modes fluent symbol number" - - for mode in ${modes[@]}; do - python local/prepare_lang_char.py --trans-mode $mode \ - --train-cut $csj_manifest_dir/csj_cuts_train.jsonl.gz \ - --lang-dir lang_char_$mode - done + log "Stage 4: Prepare CSJ lang_char" + python local/prepare_lang_char.py $csj_fbank_dir/csj_cuts_train.jsonl.gz + python local/add_transcript_mode.py -f $csj_fbank_dir -c local/conf/fluent.ini local/conf/number.ini fi if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then @@ -128,6 +116,6 @@ fi if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then log "Stage 6: Show manifest statistics" - python local/display_manifest_statistics.py --manifest-dir $csj_manifest_dir > $csj_manifest_dir/manifest_statistics.txt - cat $csj_manifest_dir/manifest_statistics.txt + python local/display_manifest_statistics.py --manifest-dir $csj_fbank_dir > $csj_fbank_dir/manifest_statistics.txt + cat $csj_fbank_dir/manifest_statistics.txt fi diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/TelegramStreamIO.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/TelegramStreamIO.py new file mode 100644 index 000000000..f5235207a --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/TelegramStreamIO.py @@ -0,0 +1,76 @@ +import logging +from configparser import ConfigParser + +import requests + + +def escape_html(text: str): + """ + Escapes all html characters in text + :param str text: + :rtype: str + """ + return text.replace("&", "&").replace("<", "<").replace(">", ">") + + +class TelegramStreamIO(logging.Handler): + + API_ENDPOINT = "https://api.telegram.org" + MAX_MESSAGE_LEN = 4096 + formatter = logging.Formatter( + "%(asctime)s - %(levelname)s at %(funcName)s " + "(line %(lineno)s):\n\n%(message)s" + ) + + def __init__(self, tg_configfile: str): + super(TelegramStreamIO, self).__init__() + config = ConfigParser() + if not config.read(tg_configfile): + raise FileNotFoundError( + f"{tg_configfile} not found. " "Retry without --telegram-cred flag." + ) + config = config["TELEGRAM"] + token = config["token"] + self.chat_id = config["chat_id"] + self.url = f"{self.API_ENDPOINT}/bot{token}/sendMessage" + + @staticmethod + def setup_logger(params): + if not params.telegram_cred: + return + formatter = logging.Formatter( + f"{params.exp_dir.name} %(asctime)s \n%(message)s" + ) + tg = TelegramStreamIO(params.telegram_cred) + tg.setLevel(logging.WARN) + tg.setFormatter(formatter) + logging.getLogger("").addHandler(tg) + + def emit(self, record: logging.LogRecord): + """ + Emit a record. + Send the record to the Web server as a percent-encoded dictionary + """ + data = { + "chat_id": self.chat_id, + "text": self.format(self.mapLogRecord(record)), + "parse_mode": "HTML", + } + try: + requests.get(self.url, json=data) + # return response.json() + except Exception as e: + logging.error(f"Failed to send telegram message: {repr(e)}") + pass + + def mapLogRecord(self, record): + """ + Default implementation of mapping the log record into a dict + that is sent as the CGI data. Overwrite in your class. + Contributed by Franz Glasner. + """ + + for k, v in record.__dict__.items(): + if isinstance(v, str): + setattr(record, k, escape_html(v)) + return record diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py new file mode 120000 index 000000000..a48591198 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py @@ -0,0 +1 @@ +../local/utils/asr_datamodule.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/beam_search.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/beam_search.py new file mode 120000 index 000000000..d7349b0a3 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/beam_search.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/decode.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/decode.py new file mode 100755 index 000000000..19d3c79c8 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/decode.py @@ -0,0 +1,852 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --lang data/lang_char \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method beam_search \ + --lang data/lang_char \ + --beam-size 4 + +(3) modified beam search +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method modified_beam_search \ + --lang data/lang_char \ + --beam-size 4 + +(4) fast beam search (one best) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --lang data/lang_char \ + --max-states 64 + +(5) fast beam search (nbest) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --lang data/lang_char \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --lang data/lang_char \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --lang data/lang_char \ + --max-states 64 +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import torch +import torch.nn as nn +from asr_datamodule import CSJAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from tokenizer import Tokenizer +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--gpu", + type=int, + default=0, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--res-dir", + type=Path, + default=None, + help="The path to save results.", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_char", + help="The lang dir. It should contain at least a word table.", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--decoding-graph", + type=str, + default="", + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--pad-feature", + type=int, + default=30, + help=""" + Number of frames to pad at the end. + """, + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: Tokenizer, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.pad_feature: + feature_lens += params.pad_feature + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.pad_feature), + value=LOG_EPS, + ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(sp.text2word(hyp)) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(sp.text2word(hyp)) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(sp.text2word(hyp)) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(sp.text2word(hyp)) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(sp.text2word(hyp)) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.text2word(sp.decode(hyp))) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: Tokenizer, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = sp.text2word(ref_text) + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + return test_set_wers + + +@torch.no_grad() +def main(): + parser = get_parser() + CSJAsrDataModule.add_arguments(parser) + Tokenizer.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + ) + if not params.res_dir: + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", params.gpu) + + logging.info(f"Device: {device}") + + sp = Tokenizer.load(params.lang, params.lang_type) + + # and are defined in local/prepare_lang_char.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + assert model.encoder.decode_chunk_size == params.decode_chunk_len // 2, ( + model.encoder.decode_chunk_size, + params.decode_chunk_len, + ) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + decoding_graph = None + word_table = None + + if params.decoding_graph: + decoding_graph = k2.Fsa.from_dict( + torch.load(params.decoding_graph, map_location=device) + ) + elif "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + csj_corpus = CSJAsrDataModule(args) + + for subdir in ["eval1", "eval2", "eval3", "excluded", "valid"]: + results_dict = decode_dataset( + dl=csj_corpus.test_dataloaders(getattr(csj_corpus, f"{subdir}_cuts")()), + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + ) + tot_err = save_results( + params=params, + test_set_name=subdir, + results_dict=results_dict, + ) + with ( + params.res_dir + / ( + f"{subdir}-{params.decode_chunk_len}_{params.beam_size}" + f"_{params.avg}_{params.epoch}.cer" + ) + ).open("w") as fout: + if len(tot_err) == 1: + fout.write(f"{tot_err[0][1]}") + else: + fout.write("\n".join(f"{k}\t{v}") for k, v in tot_err) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/decode_stream.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/decode_stream.py new file mode 120000 index 000000000..ca8fed319 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/decode_stream.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/decoder.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/decoder.py new file mode 120000 index 000000000..1ce277aa6 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/decoder.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/decoder.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py new file mode 120000 index 000000000..cb673b3eb --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/export.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/export.py new file mode 100644 index 000000000..2d45ecca3 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/export.py @@ -0,0 +1,313 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" + +Usage: + +(1) Export to torchscript model using torch.jit.script() + +./pruned_transducer_stateless7_streaming/export.py \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --lang data/lang_char \ + --epoch 30 \ + --avg 9 \ + --jit 1 + +It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later +load it by `torch.jit.load("cpu_jit.pt")`. + +Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python +are on CPU. You can use `to("cuda")` to move them to a CUDA device. + +Check +https://github.com/k2-fsa/sherpa +for how to use the exported models outside of icefall. + +(2) Export `model.state_dict()` + +./pruned_transducer_stateless7_streaming/export.py \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --lang data/lang_char \ + --epoch 20 \ + --avg 10 + +It will generate a file `pretrained.pt` in the given `exp_dir`. You can later +load it by `icefall.checkpoint.load_checkpoint()`. + +To use the generated file with `pruned_transducer_stateless7_streaming/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/csj/ASR + ./pruned_transducer_stateless7_streaming/decode.py \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --decoding-method greedy_search \ + --lang data/lang_char + +Check ./pretrained.py for its usage. + +Note: If you don't want to train a model from scratch, we have +provided one for you. You can get it at + +https://huggingface.co/TeoWenShen/icefall-asr-csj-pruned-transducer-stateless7-streaming-230208 + +with the following commands: + + sudo apt-get install git-lfs + git lfs install + git clone https://huggingface.co/TeoWenShen/icefall-asr-csj-pruned-transducer-stateless7-streaming-230208 + # You will find the pre-trained model in icefall-asr-csj-pruned-transducer-stateless7-230208/exp_fluent +""" + +import argparse +import logging +from pathlib import Path + +import torch +from scaling_converter import convert_scaled_to_non_scaled +from tokenizer import Tokenizer +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + It will generate a file named cpu_jit.pt + + Check ./jit_pretrained.py for how to use it. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +@torch.no_grad() +def main(): + parser = get_parser() + Tokenizer.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + sp = Tokenizer.load(params.lang, params.lang_type) + + # is defined in local/prepare_lang_char.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + if params.jit is True: + convert_scaled_to_non_scaled(model, inplace=True) + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. + model.__class__.forward = torch.jit.ignore(model.__class__.forward) + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + filename = params.exp_dir / "cpu_jit.pt" + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torchscript. Export model.state_dict()") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py new file mode 100644 index 000000000..ab7c8748a --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py @@ -0,0 +1,308 @@ +#!/usr/bin/env python3 + +""" +Usage: +# use -O to skip assertions and avoid some of the TracerWarnings +python -O pruned_transducer_stateless7_streaming/jit_trace_export.py \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --lang data/lang_char \ + --epoch 30 \ + --avg 10 \ + --use-averaged-model=True \ + --decode-chunk-len 32 +""" + +import argparse +import logging +from pathlib import Path + +import torch +from scaling_converter import convert_scaled_to_non_scaled +from tokenizer import Tokenizer +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import AttributeDict, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless2/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + add_model_arguments(parser) + + return parser + + +def export_encoder_model_jit_trace( + encoder_model: torch.nn.Module, + encoder_filename: str, + params: AttributeDict, +) -> None: + """Export the given encoder model with torch.jit.trace() + + Note: The warmup argument is fixed to 1. + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported model. + """ + decode_chunk_len = params.decode_chunk_len # before subsampling + pad_length = 7 + s = f"decode_chunk_len: {decode_chunk_len}" + logging.info(s) + assert encoder_model.decode_chunk_size == decode_chunk_len // 2, ( + encoder_model.decode_chunk_size, + decode_chunk_len, + ) + + T = decode_chunk_len + pad_length + + x = torch.zeros(1, T, 80, dtype=torch.float32) + x_lens = torch.full((1,), T, dtype=torch.int32) + states = encoder_model.get_init_state(device=x.device) + + encoder_model.__class__.forward = encoder_model.__class__.streaming_forward + traced_model = torch.jit.trace(encoder_model, (x, x_lens, states)) + traced_model.save(encoder_filename) + logging.info(f"Saved to {encoder_filename}") + + +def export_decoder_model_jit_trace( + decoder_model: torch.nn.Module, + decoder_filename: str, +) -> None: + """Export the given decoder model with torch.jit.trace() + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The input decoder model + decoder_filename: + The filename to save the exported model. + """ + y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) + need_pad = torch.tensor([False]) + + traced_model = torch.jit.trace(decoder_model, (y, need_pad)) + traced_model.save(decoder_filename) + logging.info(f"Saved to {decoder_filename}") + + +def export_joiner_model_jit_trace( + joiner_model: torch.nn.Module, + joiner_filename: str, +) -> None: + """Export the given joiner model with torch.jit.trace() + + Note: The argument project_input is fixed to True. A user should not + project the encoder_out/decoder_out by himself/herself. The exported joiner + will do that for the user. + + Args: + joiner_model: + The input joiner model + joiner_filename: + The filename to save the exported model. + + """ + encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] + decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + + traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out)) + traced_model.save(joiner_filename) + logging.info(f"Saved to {joiner_filename}") + + +@torch.no_grad() +def main(): + parser = get_parser() + Tokenizer.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + + logging.info(f"device: {device}") + + sp = Tokenizer.load(params.lang, params.lang_type) + + # is defined in local/prepare_lang_char.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True) + logging.info("Using torch.jit.trace()") + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / "encoder_jit_trace.pt" + export_encoder_model_jit_trace(model.encoder, encoder_filename, params) + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / "decoder_jit_trace.pt" + export_decoder_model_jit_trace(model.decoder, decoder_filename) + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / "joiner_jit_trace.pt" + export_joiner_model_jit_trace(model.joiner, joiner_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py new file mode 100644 index 000000000..d84cf04a3 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py @@ -0,0 +1,286 @@ +#!/usr/bin/env python3 +# flake8: noqa +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models exported by `torch.jit.trace()` +and uses them to decode waves. +You can use the following command to get the exported models: + +./pruned_transducer_stateless7_streaming/jit_trace_export.py \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --lang data/lang_char \ + --epoch 30 \ + --avg 10 \ + --use-averaged-model=True \ + --decode-chunk-len 32 + +Usage of this script: + +./pruned_transducer_stateless7_streaming/jit_trace_pretrained.py \ + --encoder-model-filename ./pruned_transducer_stateless7_streaming/exp/encoder_jit_trace.pt \ + --decoder-model-filename ./pruned_transducer_stateless7_streaming/exp/decoder_jit_trace.pt \ + --joiner-model-filename ./pruned_transducer_stateless7_streaming/exp/joiner_jit_trace.pt \ + --lang data/lang_char \ + --decode-chunk-len 32 \ + /path/to/foo.wav \ +""" + +import argparse +import logging +from typing import List, Optional + +import torch +import torchaudio +from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature +from tokenizer import Tokenizer + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder torchscript model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder torchscript model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner torchscript model. ", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--decode-chunk-len", + type=int, + default=32, + help="The chunk size for decoding (in frames before subsampling)", + ) + + parser.add_argument( + "sound_file", + type=str, + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +def greedy_search( + decoder: torch.jit.ScriptModule, + joiner: torch.jit.ScriptModule, + encoder_out: torch.Tensor, + decoder_out: Optional[torch.Tensor] = None, + hyp: Optional[List[int]] = None, +): + assert encoder_out.ndim == 2 + context_size = 2 + blank_id = 0 + + if decoder_out is None: + assert hyp is None, hyp + hyp = [blank_id] * context_size + decoder_input = torch.tensor(hyp, dtype=torch.int32).unsqueeze(0) + # decoder_input.shape (1,, 1 context_size) + decoder_out = decoder(decoder_input, torch.tensor([False])).squeeze(1) + else: + assert decoder_out.ndim == 2 + assert hyp is not None, hyp + + T = encoder_out.size(0) + for i in range(T): + cur_encoder_out = encoder_out[i : i + 1] + joiner_out = joiner(cur_encoder_out, decoder_out).squeeze(0) + y = joiner_out.argmax(dim=0).item() + + if y != blank_id: + hyp.append(y) + decoder_input = hyp[-context_size:] + + decoder_input = torch.tensor(decoder_input, dtype=torch.int32).unsqueeze(0) + decoder_out = decoder(decoder_input, torch.tensor([False])).squeeze(1) + + return hyp, decoder_out + + +def create_streaming_feature_extractor(sample_rate) -> OnlineFeature: + """Create a CPU streaming feature extractor. + + At present, we assume it returns a fbank feature extractor with + fixed options. In the future, we will support passing in the options + from outside. + + Returns: + Return a CPU streaming feature extractor. + """ + opts = FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = sample_rate + opts.mel_opts.num_bins = 80 + return OnlineFbank(opts) + + +@torch.no_grad() +def main(): + parser = get_parser() + Tokenizer.add_arguments(parser) + args = parser.parse_args() + logging.info(vars(args)) + + device = torch.device("cpu") + + logging.info(f"device: {device}") + + encoder = torch.jit.load(args.encoder_model_filename) + decoder = torch.jit.load(args.decoder_model_filename) + joiner = torch.jit.load(args.joiner_model_filename) + + encoder.eval() + decoder.eval() + joiner.eval() + + encoder.to(device) + decoder.to(device) + joiner.to(device) + + sp = Tokenizer.load(args.lang, args.lang_type) + + logging.info("Constructing Fbank computer") + online_fbank = create_streaming_feature_extractor(args.sample_rate) + + logging.info(f"Reading sound files: {args.sound_file}") + wave_samples = read_sound_files( + filenames=[args.sound_file], + expected_sample_rate=args.sample_rate, + )[0] + logging.info(wave_samples.shape) + + logging.info("Decoding started") + chunk_length = args.decode_chunk_len + assert encoder.decode_chunk_size == chunk_length // 2, ( + encoder.decode_chunk_size, + chunk_length, + ) + + # we subsample features with ((x_len - 7) // 2 + 1) // 2 + pad_length = 7 + T = chunk_length + pad_length + + logging.info(f"chunk_length: {chunk_length}") + + states = encoder.get_init_state(device) + + tail_padding = torch.zeros(int(0.3 * args.sample_rate), dtype=torch.float32) + + wave_samples = torch.cat([wave_samples, tail_padding]) + + chunk = int(0.25 * args.sample_rate) # 0.2 second + num_processed_frames = 0 + + hyp = None + decoder_out = None + + start = 0 + while start < wave_samples.numel(): + logging.info(f"{start}/{wave_samples.numel()}") + end = min(start + chunk, wave_samples.numel()) + samples = wave_samples[start:end] + start += chunk + online_fbank.accept_waveform( + sampling_rate=args.sample_rate, + waveform=samples, + ) + while online_fbank.num_frames_ready - num_processed_frames >= T: + frames = [] + for i in range(T): + frames.append(online_fbank.get_frame(num_processed_frames + i)) + frames = torch.cat(frames, dim=0).unsqueeze(0) + x_lens = torch.tensor([T], dtype=torch.int32) + encoder_out, out_lens, states = encoder( + x=frames, + x_lens=x_lens, + states=states, + ) + num_processed_frames += chunk_length + + hyp, decoder_out = greedy_search( + decoder, joiner, encoder_out.squeeze(0), decoder_out, hyp + ) + + context_size = 2 + logging.info(args.sound_file) + logging.info(sp.decode(hyp[context_size:])) + + logging.info("Decoding Done") + + +torch.set_num_threads(4) +torch.set_num_interop_threads(1) +torch._C._jit_set_profiling_executor(False) +torch._C._jit_set_profiling_mode(False) +torch._C._set_graph_executor_optimize(False) +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/joiner.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/joiner.py new file mode 120000 index 000000000..482ebcfef --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/joiner.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/joiner.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/model.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/model.py new file mode 120000 index 000000000..16c2bf28d --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/model.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/optim.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/optim.py new file mode 120000 index 000000000..522bbaff9 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/optim.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/pretrained.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/pretrained.py new file mode 100644 index 000000000..932026868 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/pretrained.py @@ -0,0 +1,347 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads a checkpoint and uses it to decode waves. +You can generate the checkpoint with the following command: + +./pruned_transducer_stateless7_streaming/export.py \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --lang data/lang_char \ + --epoch 20 \ + --avg 10 + +Usage of this script: + +(1) greedy search +./pruned_transducer_stateless7_streaming/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ + --lang data/lang_char \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) beam search +./pruned_transducer_stateless7_streaming/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ + --lang data/lang_char \ + --method beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) modified beam search +./pruned_transducer_stateless7_streaming/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ + --lang data/lang_char \ + --method modified_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(4) fast beam search +./pruned_transducer_stateless7_streaming/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ + --lang data/lang_char \ + --method fast_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +You can also use `./pruned_transducer_stateless7_streaming/exp/epoch-xx.pt`. + +Note: ./pruned_transducer_stateless7_streaming/exp/pretrained.pt is generated by +./pruned_transducer_stateless7_streaming/export.py +""" + + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import torch +import torchaudio +from beam_search import ( + beam_search, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from tokenizer import Tokenizer +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. Used only when + --method is greedy_search. + """, + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + Tokenizer.add_arguments(parser) + args = parser.parse_args() + + params = get_params() + + params.update(vars(args)) + + sp = Tokenizer.load(params.lang, params.lang_type) + + # is defined in local/prepare_lang_char.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + model.device = device + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) + + num_waves = encoder_out.size(0) + hyps = [] + msg = f"Using {params.method}" + if params.method == "beam_search": + msg += f" with beam size {params.beam_size}" + logging.info(msg) + + if params.method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + for i in range(num_waves): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError(f"Unsupported method: {params.method}") + + hyps.append(sp.decode(hyp).split()) + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/scaling.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/scaling.py new file mode 120000 index 000000000..a7ef73bcb --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/scaling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/scaling.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py new file mode 120000 index 000000000..566c317ff --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py new file mode 120000 index 000000000..2adf271c1 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py new file mode 100755 index 000000000..9700dd89e --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py @@ -0,0 +1,597 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage: +./pruned_transducer_stateless7_streaming/streaming_decode.py \ + --epoch 28 \ + --avg 15 \ + --decode-chunk-len 32 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --decoding_method greedy_search \ + --lang data/lang_char \ + --num-decode-streams 2000 +""" + +import argparse +import logging +import math +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import torch +import torch.nn as nn +from asr_datamodule import CSJAsrDataModule +from decode import save_results +from decode_stream import DecodeStream +from kaldifeat import Fbank, FbankOptions +from lhotse import CutSet +from streaming_beam_search import ( + fast_beam_search_one_best, + greedy_search, + modified_beam_search, +) +from tokenizer import Tokenizer +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model +from zipformer import stack_states, unstack_states + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import AttributeDict, setup_logger, str2bool + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--gpu", + type=int, + default=0, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless2/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Supported decoding methods are: + greedy_search + modified_beam_search + fast_beam_search + """, + ) + + parser.add_argument( + "--decoding-graph", + type=str, + default="", + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--num_active_paths", + type=int, + default=4, + help="""An interger indicating how many candidates we will keep for each + frame. Used only when --decoding-method is modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=32, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--num-decode-streams", + type=int, + default=2000, + help="The number of streams that can be decoded parallel.", + ) + + parser.add_argument( + "--res-dir", + type=Path, + default=None, + help="The path to save results.", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_chunk( + params: AttributeDict, + model: nn.Module, + decode_streams: List[DecodeStream], +) -> List[int]: + """Decode one chunk frames of features for each decode_streams and + return the indexes of finished streams in a List. + + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + decode_streams: + A List of DecodeStream, each belonging to a utterance. + Returns: + Return a List containing which DecodeStreams are finished. + """ + device = model.device + + features = [] + feature_lens = [] + states = [] + processed_lens = [] + + for stream in decode_streams: + feat, feat_len = stream.get_feature_frames(params.decode_chunk_len) + features.append(feat) + feature_lens.append(feat_len) + states.append(stream.states) + processed_lens.append(stream.done_frames) + + feature_lens = torch.tensor(feature_lens, device=device) + features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) + + # We subsample features with ((x_len - 7) // 2 + 1) // 2 and the max downsampling + # factor in encoders is 8. + # After feature embedding (x_len - 7) // 2, we have (23 - 7) // 2 = 8. + tail_length = 23 + if features.size(1) < tail_length: + pad_length = tail_length - features.size(1) + feature_lens += pad_length + features = torch.nn.functional.pad( + features, + (0, 0, 0, pad_length), + mode="constant", + value=LOG_EPS, + ) + + states = stack_states(states) + processed_lens = torch.tensor(processed_lens, device=device) + + encoder_out, encoder_out_lens, new_states = model.encoder.streaming_forward( + x=features, + x_lens=feature_lens, + states=states, + ) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + if params.decoding_method == "greedy_search": + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) + elif params.decoding_method == "fast_beam_search": + processed_lens = processed_lens + encoder_out_lens + fast_beam_search_one_best( + model=model, + encoder_out=encoder_out, + processed_lens=processed_lens, + streams=decode_streams, + beam=params.beam, + max_states=params.max_states, + max_contexts=params.max_contexts, + ) + elif params.decoding_method == "modified_beam_search": + modified_beam_search( + model=model, + streams=decode_streams, + encoder_out=encoder_out, + num_active_paths=params.num_active_paths, + ) + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + + states = unstack_states(new_states) + + finished_streams = [] + for i in range(len(decode_streams)): + decode_streams[i].states = states[i] + decode_streams[i].done_frames += encoder_out_lens[i] + if decode_streams[i].done: + finished_streams.append(i) + + return finished_streams + + +def decode_dataset( + cuts: CutSet, + params: AttributeDict, + model: nn.Module, + sp: Tokenizer, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + cuts: + Lhotse Cutset containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + device = model.device + + opts = FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + + log_interval = 50 + + decode_results = [] + # Contain decode streams currently running. + decode_streams = [] + for num, cut in enumerate(cuts): + # each utterance has a DecodeStream. + initial_states = model.encoder.get_init_state(device=device) + decode_stream = DecodeStream( + params=params, + cut_id=cut.id, + initial_states=initial_states, + decoding_graph=decoding_graph, + device=device, + ) + + audio: np.ndarray = cut.load_audio() + # audio.shape: (1, num_samples) + assert len(audio.shape) == 2 + assert audio.shape[0] == 1, "Should be single channel" + assert audio.dtype == np.float32, audio.dtype + + # The trained model is using normalized samples + assert audio.max() <= 1, "Should be normalized to [-1, 1])" + + samples = torch.from_numpy(audio).squeeze(0) + + fbank = Fbank(opts) + feature = fbank(samples.to(device)) + decode_stream.set_features(feature, tail_pad_len=params.decode_chunk_len) + decode_stream.ground_truth = cut.supervisions[0].custom[params.transcript_mode] + + decode_streams.append(decode_stream) + + while len(decode_streams) >= params.num_decode_streams: + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + sp.text2word(decode_streams[i].ground_truth), + sp.text2word(sp.decode(decode_streams[i].decoding_result())), + ) + ) + del decode_streams[i] + + if num % log_interval == 0: + logging.info(f"Cuts processed until now is {num}.") + + # decode final chunks of last sequences + while len(decode_streams): + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + sp.text2word(decode_streams[i].ground_truth), + sp.text2word(sp.decode(decode_streams[i].decoding_result())), + ) + ) + del decode_streams[i] + + if params.decoding_method == "greedy_search": + key = "greedy_search" + elif params.decoding_method == "fast_beam_search": + key = ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ) + elif params.decoding_method == "modified_beam_search": + key = f"num_active_paths_{params.num_active_paths}" + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + return {key: decode_results} + + +@torch.no_grad() +def main(): + parser = get_parser() + CSJAsrDataModule.add_arguments(parser) + Tokenizer.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + if not params.res_dir: + params.res_dir = params.exp_dir / "streaming" / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + # for streaming + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}" + + # for fast_beam_search + if params.decoding_method == "fast_beam_search": + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", params.gpu) + + logging.info(f"Device: {device}") + + sp = Tokenizer.load(params.lang, params.lang_type) + + # and is defined in local/prepare_lang_char.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + model.device = device + + decoding_graph = None + if params.decoding_graph: + decoding_graph = k2.Fsa.from_dict( + torch.load(params.decoding_graph, map_location=device) + ) + elif params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + args.return_cuts = True + csj_corpus = CSJAsrDataModule(args) + + for subdir in ["eval1", "eval2", "eval3", "excluded", "valid"]: + results_dict = decode_dataset( + cuts=getattr(csj_corpus, f"{subdir}_cuts")(), + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + ) + tot_err = save_results( + params=params, test_set_name=subdir, results_dict=results_dict + ) + + with ( + params.res_dir + / ( + f"{subdir}-{params.decode_chunk_len}" + f"_{params.avg}_{params.epoch}.cer" + ) + ).open("w") as fout: + if len(tot_err) == 1: + fout.write(f"{tot_err[0][1]}") + else: + fout.write("\n".join(f"{k}\t{v}") for k, v in tot_err) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/test_model.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/test_model.py new file mode 100755 index 000000000..0a82ccfa4 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/test_model.py @@ -0,0 +1,150 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +To run this file, do: + + cd icefall/egs/csj/ASR + python ./pruned_transducer_stateless7_streaming/test_model.py +""" + +import torch +from scaling_converter import convert_scaled_to_non_scaled +from train import get_params, get_transducer_model + + +def test_model(): + params = get_params() + params.vocab_size = 500 + params.blank_id = 0 + params.context_size = 2 + params.num_encoder_layers = "2,4,3,2,4" + params.feedforward_dims = "1024,1024,2048,2048,1024" + params.nhead = "8,8,8,8,8" + params.encoder_dims = "384,384,384,384,384" + params.attention_dims = "192,192,192,192,192" + params.encoder_unmasked_dims = "256,256,256,256,256" + params.zipformer_downsampling_factors = "1,2,4,8,2" + params.cnn_module_kernels = "31,31,31,31,31" + params.decoder_dim = 512 + params.joiner_dim = 512 + params.num_left_chunks = 4 + params.short_chunk_size = 50 + params.decode_chunk_len = 32 + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + # Test jit script + convert_scaled_to_non_scaled(model, inplace=True) + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. + model.__class__.forward = torch.jit.ignore(model.__class__.forward) + print("Using torch.jit.script") + model = torch.jit.script(model) + + +def test_model_jit_trace(): + params = get_params() + params.vocab_size = 500 + params.blank_id = 0 + params.context_size = 2 + params.num_encoder_layers = "2,4,3,2,4" + params.feedforward_dims = "1024,1024,2048,2048,1024" + params.nhead = "8,8,8,8,8" + params.encoder_dims = "384,384,384,384,384" + params.attention_dims = "192,192,192,192,192" + params.encoder_unmasked_dims = "256,256,256,256,256" + params.zipformer_downsampling_factors = "1,2,4,8,2" + params.cnn_module_kernels = "31,31,31,31,31" + params.decoder_dim = 512 + params.joiner_dim = 512 + params.num_left_chunks = 4 + params.short_chunk_size = 50 + params.decode_chunk_len = 32 + model = get_transducer_model(params) + model.eval() + + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + convert_scaled_to_non_scaled(model, inplace=True) + + # Test encoder + def _test_encoder(): + encoder = model.encoder + assert encoder.decode_chunk_size == params.decode_chunk_len // 2, ( + encoder.decode_chunk_size, + params.decode_chunk_len, + ) + T = params.decode_chunk_len + 7 + + x = torch.zeros(1, T, 80, dtype=torch.float32) + x_lens = torch.full((1,), T, dtype=torch.int32) + states = encoder.get_init_state(device=x.device) + encoder.__class__.forward = encoder.__class__.streaming_forward + traced_encoder = torch.jit.trace(encoder, (x, x_lens, states)) + + states1 = encoder.get_init_state(device=x.device) + states2 = traced_encoder.get_init_state(device=x.device) + for i in range(5): + x = torch.randn(1, T, 80, dtype=torch.float32) + x_lens = torch.full((1,), T, dtype=torch.int32) + y1, _, states1 = encoder.streaming_forward(x, x_lens, states1) + y2, _, states2 = traced_encoder(x, x_lens, states2) + assert torch.allclose(y1, y2, atol=1e-6), (i, (y1 - y2).abs().mean()) + + # Test decoder + def _test_decoder(): + decoder = model.decoder + y = torch.zeros(10, decoder.context_size, dtype=torch.int64) + need_pad = torch.tensor([False]) + + traced_decoder = torch.jit.trace(decoder, (y, need_pad)) + d1 = decoder(y, need_pad) + d2 = traced_decoder(y, need_pad) + assert torch.equal(d1, d2), (d1 - d2).abs().mean() + + # Test joiner + def _test_joiner(): + joiner = model.joiner + encoder_out_dim = joiner.encoder_proj.weight.shape[1] + decoder_out_dim = joiner.decoder_proj.weight.shape[1] + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + + traced_joiner = torch.jit.trace(joiner, (encoder_out, decoder_out)) + j1 = joiner(encoder_out, decoder_out) + j2 = traced_joiner(encoder_out, decoder_out) + assert torch.equal(j1, j2), (j1 - j2).abs().mean() + + _test_encoder() + _test_decoder() + _test_joiner() + + +def main(): + test_model() + test_model_jit_trace() + + +if __name__ == "__main__": + main() diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/tokenizer.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/tokenizer.py new file mode 120000 index 000000000..958c99e85 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/tokenizer.py @@ -0,0 +1 @@ +../local/utils/tokenizer.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py new file mode 100755 index 000000000..601de2c41 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py @@ -0,0 +1,1304 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --lang data/lang_char \ + --max-duration 300 + +# For mix precision training: + +./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --lang data/lang_char \ + --max-duration 550 +""" + + +import argparse +import copy +import logging +import math +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import CSJAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from optim import Eden, ScaledAdam +from tokenizer import Tokenizer +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LOG_EPS = math.log(1e-10) + +try: + from TelegramStreamIO import TelegramStreamIO + + HAS_TELEGRAM = True +except ImportError: + HAS_TELEGRAM = False + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,4,3,2,4", + help="Number of zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--feedforward-dims", + type=str, + default="1024,1024,2048,2048,1024", + help="Feedforward dimension of the zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--nhead", + type=str, + default="8,8,8,8,8", + help="Number of attention heads in the zipformer encoder layers.", + ) + + parser.add_argument( + "--encoder-dims", + type=str, + default="384,384,384,384,384", + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", + ) + + parser.add_argument( + "--attention-dims", + type=str, + default="192,192,192,192,192", + help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; + not the same as embedding dimension.""", + ) + + parser.add_argument( + "--encoder-unmasked-dims", + type=str, + default="256,256,256,256,256", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " + " worse.", + ) + + parser.add_argument( + "--zipformer-downsampling-factors", + type=str, + default="1,2,4,8,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--cnn-module-kernels", + type=str, + default="31,31,31,31,31", + help="Sizes of kernels in convolution modules", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--short-chunk-size", + type=int, + default=50, + help="""Chunk length of dynamic training, the chunk size would be either + max sequence length of current batch or uniformly sampled from (1, short_chunk_size). + """, + ) + + parser.add_argument( + "--num-left-chunks", + type=int, + default=4, + help="How many left context can be seen in chunks when calculating attention.", + ) + + parser.add_argument( + "--decode-chunk-len", + type=int, + default=32, + help="The chunk size for decoding (in frames before subsampling)", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument("--debug", action="store_true", help="Use hardcoded arguments") + + parser.add_argument( + "--telegram-cred", + type=Path, + default=None, + help="Telegram credentials to report progress in telegram", + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=Path, + default="pruned_transducer_stateless7_streaming/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--base-lr", type=float, default=0.05, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=2000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--pad-feature", + type=int, + default=0, + help=""" + Number of frames to pad at the end. + """, + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 1000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Zipformer and Transformer + def to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + encoder = Zipformer( + num_features=params.feature_dim, + output_downsampling_factor=2, + zipformer_downsampling_factors=to_int_tuple( + params.zipformer_downsampling_factors + ), + encoder_dims=to_int_tuple(params.encoder_dims), + attention_dim=to_int_tuple(params.attention_dims), + encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), + nhead=to_int_tuple(params.nhead), + feedforward_dim=to_int_tuple(params.feedforward_dims), + cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), + num_encoder_layers=to_int_tuple(params.num_encoder_layers), + num_left_chunks=params.num_left_chunks, + short_chunk_size=params.short_chunk_size, + decode_chunk_size=params.decode_chunk_len // 2, + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: Tokenizer, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.pad_feature: + feature_lens += params.pad_feature + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.pad_feature), + value=LOG_EPS, + ) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: Tokenizer, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: Tokenizer, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except Exception as e: # noqa + logging.error(e, exc_info=True) + display_and_save_batch(batch, params=params, sp=sp) + raise e + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + if HAS_TELEGRAM and batch_idx in [0, 500] and not rank: + logging.warning( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + else: + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + if ( + HAS_TELEGRAM + and batch_idx % (params.valid_interval * 3) == 0 + and not rank + ): + log_mode = logging.warning + else: + log_mode = logging.info + log_mode(f"Epoch {params.cur_epoch}, validation: {valid_info}") + log_mode( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, master_port=params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + if HAS_TELEGRAM and params.telegram_cred: + TelegramStreamIO.setup_logger(params) + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = Tokenizer.load(args.lang, args.lang_type) + + # is defined in local/prepare_lang_char.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + parameters_names = [] + parameters_names.append( + [name_param_pair[0] for name_param_pair in model.named_parameters()] + ) + optimizer = ScaledAdam( + model.parameters(), + lr=params.base_lr, + clipping_scale=2.0, + parameters_names=parameters_names, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 0.3 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.info( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + csj_corpus = CSJAsrDataModule(args) + train_cuts = csj_corpus.train_cuts() + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = csj_corpus.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = csj_corpus.valid_cuts() + valid_dl = csj_corpus.valid_dataloaders(valid_cuts) + + if params.start_batch <= 0 and not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: Tokenizer, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: Tokenizer, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + CSJAsrDataModule.add_arguments(parser) + Tokenizer.add_arguments(parser) + args = parser.parse_args() + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer.py new file mode 120000 index 000000000..ec183baa7 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py \ No newline at end of file From c5e687ddf5620a5804e3603f32601da8d136e70c Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 13 Feb 2023 23:41:43 +0800 Subject: [PATCH 117/174] Export streaming zipformer to ncnn (#906) --- .github/scripts/test-ncnn-export.sh | 99 + .../export-for-ncnn.py | 10 + .../streaming-ncnn-decode.py | 5 +- .../export-for-ncnn.py | 12 +- .../pruned_transducer_stateless7/decoder.py | 6 +- .../pruned_transducer_stateless7/joiner.py | 1 - .../scaling_converter.py | 104 +- .../README.md | 7 + .../export-for-ncnn-zh.py | 367 ++ .../export-for-ncnn.py | 369 ++ .../streaming-ncnn-decode.py | 419 +++ .../train2.py | 1265 +++++++ .../zipformer.py | 27 +- .../zipformer2.py | 3144 +++++++++++++++++ 14 files changed, 5805 insertions(+), 30 deletions(-) create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train2.py create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py diff --git a/.github/scripts/test-ncnn-export.sh b/.github/scripts/test-ncnn-export.sh index c6d70ae7a..9dd7736c0 100755 --- a/.github/scripts/test-ncnn-export.sh +++ b/.github/scripts/test-ncnn-export.sh @@ -131,3 +131,102 @@ python3 ./lstm_transducer_stateless2/ncnn-decode.py \ rm -rf $repo log "--------------------------------------------------------------------------" + +log "==========================================================================" +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +./pruned_transducer_stateless7_streaming/export-for-ncnn.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --exp-dir $repo/exp \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + \ + --decode-chunk-len 32 \ + --num-encoder-layers "2,4,3,2,4" \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --nhead "8,8,8,8,8" \ + --encoder-dims "384,384,384,384,384" \ + --attention-dims "192,192,192,192,192" \ + --encoder-unmasked-dims "256,256,256,256,256" \ + --zipformer-downsampling-factors "1,2,4,8,2" \ + --cnn-module-kernels "31,31,31,31,31" \ + --decoder-dim 512 \ + --joiner-dim 512 + +./ncnn/tools/pnnx/build/src/pnnx $repo/exp/encoder_jit_trace-pnnx.pt +./ncnn/tools/pnnx/build/src/pnnx $repo/exp/decoder_jit_trace-pnnx.pt +./ncnn/tools/pnnx/build/src/pnnx $repo/exp/joiner_jit_trace-pnnx.pt + +python3 ./pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \ + --encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \ + --decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \ + --decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \ + --joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \ + --joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \ + $repo/test_wavs/1089-134686-0001.wav + +rm -rf $repo +log "--------------------------------------------------------------------------" + +log "==========================================================================" +repo_url=https://huggingface.co/pfluo/k2fsa-zipformer-chinese-english-mixed +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_char_bpe/L.pt" +git lfs pull --include "data/lang_char_bpe/L_disambig.pt" +git lfs pull --include "data/lang_char_bpe/Linv.pt" +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +./pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py \ + --lang-dir $repo/data/lang_char_bpe \ + --exp-dir $repo/exp \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --decode-chunk-len 32 \ + --num-encoder-layers "2,4,3,2,4" \ + --feedforward-dims "1024,1024,1536,1536,1024" \ + --nhead "8,8,8,8,8" \ + --encoder-dims "384,384,384,384,384" \ + --attention-dims "192,192,192,192,192" \ + --encoder-unmasked-dims "256,256,256,256,256" \ + --zipformer-downsampling-factors "1,2,4,8,2" \ + --cnn-module-kernels "31,31,31,31,31" \ + --decoder-dim 512 \ + --joiner-dim 512 + +./ncnn/tools/pnnx/build/src/pnnx $repo/exp/encoder_jit_trace-pnnx.pt +./ncnn/tools/pnnx/build/src/pnnx $repo/exp/decoder_jit_trace-pnnx.pt +./ncnn/tools/pnnx/build/src/pnnx $repo/exp/joiner_jit_trace-pnnx.pt + +python3 ./pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py \ + --tokens $repo/data/lang_char_bpe/tokens.txt \ + --encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \ + --encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \ + --decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \ + --decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \ + --joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \ + --joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \ + $repo/test_wavs/0.wav + +rm -rf $repo +log "--------------------------------------------------------------------------" diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py index e31033c74..8fbb02f14 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py @@ -310,6 +310,16 @@ def main(): model.eval() convert_scaled_to_non_scaled(model, inplace=True) + + encoder_num_param = sum([p.numel() for p in model.encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in model.decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in model.joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + logging.info("Using torch.jit.trace()") logging.info("Exporting encoder") diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming-ncnn-decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming-ncnn-decode.py index e4104a5bb..74da9e6c8 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming-ncnn-decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming-ncnn-decode.py @@ -203,11 +203,8 @@ class Model: # (1, 512, 2) -> (512, 2) ex.input(name, ncnn.Mat(states[i * 4 + 3].numpy()).clone()) - import pdb - - # pdb.set_trace() ret, ncnn_out0 = ex.extract("out0") - # assert ret == 0, ret + assert ret == 0, ret encoder_out = torch.from_numpy(ncnn_out0.numpy()).clone() out_states: List[torch.Tensor] = [] diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export-for-ncnn.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export-for-ncnn.py index 7982ace68..08bfcb204 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/export-for-ncnn.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export-for-ncnn.py @@ -99,7 +99,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="pruned_transducer_stateless2/exp", + default="lstm_transducer_stateless2/exp", help="""It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved """, @@ -316,6 +316,16 @@ def main(): model.eval() convert_scaled_to_non_scaled(model, inplace=True) + + encoder_num_param = sum([p.numel() for p in model.encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in model.decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in model.joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + logging.info("Using torch.jit.trace()") logging.info("Exporting encoder") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py index 5f90e6375..384b78524 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py @@ -87,7 +87,11 @@ class Decoder(nn.Module): y = y.to(torch.int64) # this stuff about clamp() is a temporary fix for a mismatch # at utterance start, we use negative ids in beam_search.py - embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1) + if torch.jit.is_tracing(): + # This is for exporting to PNNX via ONNX + embedding_out = self.embedding(y) + else: + embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1) if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py index 3ddac2cf2..62a4d22d6 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py @@ -53,7 +53,6 @@ class Joiner(nn.Module): """ assert encoder_out.ndim == decoder_out.ndim assert encoder_out.ndim in (2, 4) - assert encoder_out.shape[:-1] == decoder_out.shape[:-1] if project_input: logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py index 56165d1f9..86067b04f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py @@ -22,11 +22,101 @@ BasicNorm is replaced by a module with `exp` removed. """ import copy -from typing import List +from typing import List, Tuple import torch import torch.nn as nn from scaling import ActivationBalancer, BasicNorm, Whiten +from zipformer import PoolingModule + + +class PoolingModuleNoProj(nn.Module): + def forward( + self, + x: torch.Tensor, + cached_len: torch.Tensor, + cached_avg: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + x: + A tensor of shape (T, N, C) + cached_len: + A tensor of shape (N,) + cached_avg: + A tensor of shape (N, C) + Returns: + Return a tuple containing: + - new_x + - new_cached_len + - new_cached_avg + """ + x = x.cumsum(dim=0) # (T, N, C) + x = x + (cached_avg * cached_len.unsqueeze(1)).unsqueeze(0) + # Cumulated numbers of frames from start + cum_mask = torch.arange(1, x.size(0) + 1, device=x.device) + cum_mask = cum_mask.unsqueeze(1) + cached_len.unsqueeze(0) # (T, N) + pooling_mask = (1.0 / cum_mask).unsqueeze(2) + # now pooling_mask: (T, N, 1) + x = x * pooling_mask # (T, N, C) + + cached_len = cached_len + x.size(0) + cached_avg = x[-1] + + return x, cached_len, cached_avg + + +class PoolingModuleWithProj(nn.Module): + def __init__(self, proj: torch.nn.Module): + super().__init__() + self.proj = proj + self.pooling = PoolingModuleNoProj() + + def forward( + self, + x: torch.Tensor, + cached_len: torch.Tensor, + cached_avg: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + x: + A tensor of shape (T, N, C) + cached_len: + A tensor of shape (N,) + cached_avg: + A tensor of shape (N, C) + Returns: + Return a tuple containing: + - new_x + - new_cached_len + - new_cached_avg + """ + x, cached_len, cached_avg = self.pooling(x, cached_len, cached_avg) + return self.proj(x), cached_len, cached_avg + + def streaming_forward( + self, + x: torch.Tensor, + cached_len: torch.Tensor, + cached_avg: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + x: + A tensor of shape (T, N, C) + cached_len: + A tensor of shape (N,) + cached_avg: + A tensor of shape (N, C) + Returns: + Return a tuple containing: + - new_x + - new_cached_len + - new_cached_avg + """ + x, cached_len, cached_avg = self.pooling(x, cached_len, cached_avg) + return self.proj(x), cached_len, cached_avg class NonScaledNorm(nn.Module): @@ -53,7 +143,7 @@ class NonScaledNorm(nn.Module): def convert_basic_norm(basic_norm: BasicNorm) -> NonScaledNorm: - assert isinstance(basic_norm, BasicNorm), type(BasicNorm) + assert isinstance(basic_norm, BasicNorm), type(basic_norm) norm = NonScaledNorm( num_channels=basic_norm.num_channels, eps_exp=basic_norm.eps.data.exp().item(), @@ -62,6 +152,11 @@ def convert_basic_norm(basic_norm: BasicNorm) -> NonScaledNorm: return norm +def convert_pooling_module(pooling: PoolingModule) -> PoolingModuleWithProj: + assert isinstance(pooling, PoolingModule), type(pooling) + return PoolingModuleWithProj(proj=pooling.proj) + + # Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa # get_submodule was added to nn.Module at v1.9.0 def get_submodule(model, target): @@ -83,6 +178,7 @@ def get_submodule(model, target): def convert_scaled_to_non_scaled( model: nn.Module, inplace: bool = False, + is_pnnx: bool = False, ): """ Args: @@ -91,6 +187,8 @@ def convert_scaled_to_non_scaled( inplace: If True, the input model is modified inplace. If False, the input model is copied and we modify the copied version. + is_pnnx: + True if we are going to export the model for PNNX. Return: Return a model without scaled layers. """ @@ -103,6 +201,8 @@ def convert_scaled_to_non_scaled( d[name] = convert_basic_norm(m) elif isinstance(m, (ActivationBalancer, Whiten)): d[name] = nn.Identity() + elif isinstance(m, PoolingModule) and is_pnnx: + d[name] = convert_pooling_module(m) for k, v in d.items(): if "." in k: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/README.md b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/README.md index 6e461e196..d3691e647 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/README.md +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/README.md @@ -1,3 +1,10 @@ This recipe implements Streaming Zipformer-Transducer model. See https://k2-fsa.github.io/icefall/recipes/Streaming-ASR/librispeech/zipformer_transducer.html for detailed tutorials. + +[./emformer.py](./emformer.py) and [./train.py](./train.py) +are basically the same as +[./emformer2.py](./emformer2.py) and [./train2.py](./train2.py). +The only purpose of [./emformer2.py](./emformer2.py) and [./train2.py](./train2.py) +is for exporting to [sherpa-ncnn](https://github.com/k2-fsa/sherpa-ncnn). + diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py new file mode 100755 index 000000000..1f870ca5a --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py @@ -0,0 +1,367 @@ +#!/usr/bin/env python3 + +""" +Please see +https://k2-fsa.github.io/icefall/model-export/export-ncnn.html +for more details about how to use this file. + +We use +https://huggingface.co/pfluo/k2fsa-zipformer-chinese-english-mixed +to demonstrate the usage of this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/pfluo/k2fsa-zipformer-chinese-english-mixed +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_char_bpe/L.pt" +git lfs pull --include "data/lang_char_bpe/L_disambig.pt" +git lfs pull --include "data/lang_char_bpe/Linv.pt" +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +2. Export to ncnn + +./pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py \ + --lang-dir $repo/data/lang_char_bpe \ + --exp-dir $repo/exp \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --decode-chunk-len 32 \ + --num-encoder-layers "2,4,3,2,4" \ + --feedforward-dims "1024,1024,1536,1536,1024" \ + --nhead "8,8,8,8,8" \ + --encoder-dims "384,384,384,384,384" \ + --attention-dims "192,192,192,192,192" \ + --encoder-unmasked-dims "256,256,256,256,256" \ + --zipformer-downsampling-factors "1,2,4,8,2" \ + --cnn-module-kernels "31,31,31,31,31" \ + --decoder-dim 512 \ + --joiner-dim 512 + +cd $repo/exp + +pnnx encoder_jit_trace-pnnx.pt +pnnx decoder_jit_trace-pnnx.pt +pnnx joiner_jit_trace-pnnx.pt + +You can find converted models at +https://huggingface.co/csukuangfj/sherpa-ncnn-streaming-zipformer-bilingual-zh-en-2023-02-13 + +See ./streaming-ncnn-decode.py +and +https://github.com/k2-fsa/sherpa-ncnn +for usage. +""" + +import argparse +import logging +from pathlib import Path + +import torch +from scaling_converter import convert_scaled_to_non_scaled +from train2 import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_char", + help="The lang dir", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + add_model_arguments(parser) + + return parser + + +def export_encoder_model_jit_trace( + encoder_model: torch.nn.Module, + encoder_filename: str, +) -> None: + """Export the given encoder model with torch.jit.trace() + + Note: The warmup argument is fixed to 1. + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported model. + """ + encoder_model.__class__.forward = encoder_model.__class__.streaming_forward + + decode_chunk_len = encoder_model.decode_chunk_size * 2 + pad_length = 7 + T = decode_chunk_len + pad_length # 32 + 7 = 39 + + logging.info(f"decode_chunk_len: {decode_chunk_len}") + logging.info(f"T: {T}") + + x = torch.zeros(1, T, 80, dtype=torch.float32) + states = encoder_model.get_init_state() + + traced_model = torch.jit.trace(encoder_model, (x, states)) + traced_model.save(encoder_filename) + logging.info(f"Saved to {encoder_filename}") + + +def export_decoder_model_jit_trace( + decoder_model: torch.nn.Module, + decoder_filename: str, +) -> None: + """Export the given decoder model with torch.jit.trace() + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The input decoder model + decoder_filename: + The filename to save the exported model. + """ + y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) + need_pad = torch.tensor([False]) + + traced_model = torch.jit.trace(decoder_model, (y, need_pad)) + traced_model.save(decoder_filename) + logging.info(f"Saved to {decoder_filename}") + + +def export_joiner_model_jit_trace( + joiner_model: torch.nn.Module, + joiner_filename: str, +) -> None: + """Export the given joiner model with torch.jit.trace() + + Note: The argument project_input is fixed to True. A user should not + project the encoder_out/decoder_out by himself/herself. The exported joiner + will do that for the user. + + Args: + joiner_model: + The input joiner model + joiner_filename: + The filename to save the exported model. + + """ + encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] + decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + + traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out)) + traced_model.save(joiner_filename) + logging.info(f"Saved to {joiner_filename}") + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + + setup_logger(f"{params.exp_dir}/log-export/log-export-ncnn") + + logging.info(f"device: {device}") + + lexicon = Lexicon(params.lang_dir) + params.blank_id = 0 + params.vocab_size = max(lexicon.tokens) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True, is_pnnx=True) + + encoder_num_param = sum([p.numel() for p in model.encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in model.decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in model.joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + logging.info("Using torch.jit.trace()") + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / "encoder_jit_trace-pnnx.pt" + export_encoder_model_jit_trace(model.encoder, encoder_filename) + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / "decoder_jit_trace-pnnx.pt" + export_decoder_model_jit_trace(model.decoder, decoder_filename) + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / "joiner_jit_trace-pnnx.pt" + export_joiner_model_jit_trace(model.joiner, joiner_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py new file mode 100755 index 000000000..0f84eca83 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py @@ -0,0 +1,369 @@ +#!/usr/bin/env python3 + +""" +Please see +https://k2-fsa.github.io/icefall/model-export/export-ncnn.html +for more details about how to use this file. + +We use +https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 +to demonstrate the usage of this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe/bpe.model" +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +2. Export to ncnn + +./pruned_transducer_stateless7_streaming/export-for-ncnn.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --exp-dir $repo/exp \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + \ + --decode-chunk-len 32 \ + --num-encoder-layers "2,4,3,2,4" \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --nhead "8,8,8,8,8" \ + --encoder-dims "384,384,384,384,384" \ + --attention-dims "192,192,192,192,192" \ + --encoder-unmasked-dims "256,256,256,256,256" \ + --zipformer-downsampling-factors "1,2,4,8,2" \ + --cnn-module-kernels "31,31,31,31,31" \ + --decoder-dim 512 \ + --joiner-dim 512 + +cd $repo/exp + +pnnx encoder_jit_trace-pnnx.pt +pnnx decoder_jit_trace-pnnx.pt +pnnx joiner_jit_trace-pnnx.pt + +You can find converted models at +https://huggingface.co/csukuangfj/sherpa-ncnn-streaming-zipformer-en-2023-02-13 + +See ./streaming-ncnn-decode.py +and +https://github.com/k2-fsa/sherpa-ncnn +for usage. +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +import torch +from scaling_converter import convert_scaled_to_non_scaled +from train2 import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + add_model_arguments(parser) + + return parser + + +def export_encoder_model_jit_trace( + encoder_model: torch.nn.Module, + encoder_filename: str, +) -> None: + """Export the given encoder model with torch.jit.trace() + + Note: The warmup argument is fixed to 1. + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported model. + """ + encoder_model.__class__.forward = encoder_model.__class__.streaming_forward + + decode_chunk_len = encoder_model.decode_chunk_size * 2 + pad_length = 7 + T = decode_chunk_len + pad_length # 32 + 7 = 39 + + logging.info(f"decode_chunk_len: {decode_chunk_len}") + logging.info(f"T: {T}") + + x = torch.zeros(1, T, 80, dtype=torch.float32) + states = encoder_model.get_init_state() + + traced_model = torch.jit.trace(encoder_model, (x, states)) + traced_model.save(encoder_filename) + logging.info(f"Saved to {encoder_filename}") + + +def export_decoder_model_jit_trace( + decoder_model: torch.nn.Module, + decoder_filename: str, +) -> None: + """Export the given decoder model with torch.jit.trace() + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The input decoder model + decoder_filename: + The filename to save the exported model. + """ + y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) + need_pad = torch.tensor([False]) + + traced_model = torch.jit.trace(decoder_model, (y, need_pad)) + traced_model.save(decoder_filename) + logging.info(f"Saved to {decoder_filename}") + + +def export_joiner_model_jit_trace( + joiner_model: torch.nn.Module, + joiner_filename: str, +) -> None: + """Export the given joiner model with torch.jit.trace() + + Note: The argument project_input is fixed to True. A user should not + project the encoder_out/decoder_out by himself/herself. The exported joiner + will do that for the user. + + Args: + joiner_model: + The input joiner model + joiner_filename: + The filename to save the exported model. + + """ + encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] + decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + + traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out)) + traced_model.save(joiner_filename) + logging.info(f"Saved to {joiner_filename}") + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + + setup_logger(f"{params.exp_dir}/log-export/log-export-ncnn") + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True, is_pnnx=True) + + encoder_num_param = sum([p.numel() for p in model.encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in model.decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in model.joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + logging.info("Using torch.jit.trace()") + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / "encoder_jit_trace-pnnx.pt" + export_encoder_model_jit_trace(model.encoder, encoder_filename) + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / "decoder_jit_trace-pnnx.pt" + export_decoder_model_jit_trace(model.decoder, decoder_filename) + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / "joiner_jit_trace-pnnx.pt" + export_joiner_model_jit_trace(model.joiner, joiner_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py new file mode 100755 index 000000000..5a36b695f --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py @@ -0,0 +1,419 @@ +#!/usr/bin/env python3 +# +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +./pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py \ + --tokens ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/tokens.txt \ + --encoder-param-filename ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/encoder_jit_trace-pnnx.ncnn.param \ + --encoder-bin-filename ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/encoder_jit_trace-pnnx.ncnn.bin \ + --decoder-param-filename ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/decoder_jit_trace-pnnx.ncnn.param \ + --decoder-bin-filename ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/decoder_jit_trace-pnnx.ncnn.bin \ + --joiner-param-filename ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/joiner_jit_trace-pnnx.ncnn.param \ + --joiner-bin-filename ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/joiner_jit_trace-pnnx.ncnn.bin \ + ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/test_wavs/1089-134686-0001.wav + +You can find pretrained models at +- English: https://huggingface.co/csukuangfj/sherpa-ncnn-streaming-zipformer-en-2023-02-13 +- Bilingual (Chinese + English): https://huggingface.co/csukuangfj/sherpa-ncnn-streaming-zipformer-bilingual-zh-en-2023-02-13 +""" + +import argparse +import logging +from typing import List, Optional, Tuple + +import k2 +import ncnn +import torch +import torchaudio +from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--tokens", + type=str, + help="Path to tokens.txt", + ) + + parser.add_argument( + "--encoder-param-filename", + type=str, + help="Path to encoder.ncnn.param", + ) + + parser.add_argument( + "--encoder-bin-filename", + type=str, + help="Path to encoder.ncnn.bin", + ) + + parser.add_argument( + "--decoder-param-filename", + type=str, + help="Path to decoder.ncnn.param", + ) + + parser.add_argument( + "--decoder-bin-filename", + type=str, + help="Path to decoder.ncnn.bin", + ) + + parser.add_argument( + "--joiner-param-filename", + type=str, + help="Path to joiner.ncnn.param", + ) + + parser.add_argument( + "--joiner-bin-filename", + type=str, + help="Path to joiner.ncnn.bin", + ) + + parser.add_argument( + "sound_filename", + type=str, + help="Path to foo.wav", + ) + + return parser.parse_args() + + +def to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +class Model: + def __init__(self, args): + self.init_encoder(args) + self.init_decoder(args) + self.init_joiner(args) + + # Please change the parameters according to your model + self.num_encoder_layers = to_int_tuple("2,4,3,2,4") + self.encoder_dims = to_int_tuple("384,384,384,384,384") # also known as d_model + self.attention_dims = to_int_tuple("192,192,192,192,192") + self.zipformer_downsampling_factors = to_int_tuple("1,2,4,8,2") + self.cnn_module_kernels = to_int_tuple("31,31,31,31,31") + + self.decode_chunk_size = 32 // 2 + num_left_chunks = 4 + self.left_context_length = self.decode_chunk_size * num_left_chunks # 64 + + self.chunk_length = self.decode_chunk_size * 2 + pad_length = 7 + self.T = self.chunk_length + pad_length + + def get_init_states(self) -> List[torch.Tensor]: + cached_len_list = [] + cached_avg_list = [] + cached_key_list = [] + cached_val_list = [] + cached_val2_list = [] + cached_conv1_list = [] + cached_conv2_list = [] + + for i in range(len(self.num_encoder_layers)): + num_layers = self.num_encoder_layers[i] + ds = self.zipformer_downsampling_factors[i] + attention_dim = self.attention_dims[i] + left_context_length = self.left_context_length // ds + encoder_dim = self.encoder_dims[i] + cnn_module_kernel = self.cnn_module_kernels[i] + + cached_len_list.append(torch.zeros(num_layers)) + cached_avg_list.append(torch.zeros(num_layers, encoder_dim)) + cached_key_list.append( + torch.zeros(num_layers, left_context_length, attention_dim) + ) + cached_val_list.append( + torch.zeros(num_layers, left_context_length, attention_dim // 2) + ) + cached_val2_list.append( + torch.zeros(num_layers, left_context_length, attention_dim // 2) + ) + cached_conv1_list.append( + torch.zeros(num_layers, encoder_dim, cnn_module_kernel - 1) + ) + cached_conv2_list.append( + torch.zeros(num_layers, encoder_dim, cnn_module_kernel - 1) + ) + + states = ( + cached_len_list + + cached_avg_list + + cached_key_list + + cached_val_list + + cached_val2_list + + cached_conv1_list + + cached_conv2_list + ) + + return states + + def init_encoder(self, args): + encoder_net = ncnn.Net() + encoder_net.opt.use_packing_layout = False + encoder_net.opt.use_fp16_storage = False + encoder_net.opt.num_threads = 4 + + encoder_param = args.encoder_param_filename + encoder_model = args.encoder_bin_filename + + encoder_net.load_param(encoder_param) + encoder_net.load_model(encoder_model) + + self.encoder_net = encoder_net + + def init_decoder(self, args): + decoder_param = args.decoder_param_filename + decoder_model = args.decoder_bin_filename + + decoder_net = ncnn.Net() + decoder_net.opt.num_threads = 4 + + decoder_net.load_param(decoder_param) + decoder_net.load_model(decoder_model) + + self.decoder_net = decoder_net + + def init_joiner(self, args): + joiner_param = args.joiner_param_filename + joiner_model = args.joiner_bin_filename + joiner_net = ncnn.Net() + joiner_net.opt.num_threads = 4 + + joiner_net.load_param(joiner_param) + joiner_net.load_model(joiner_model) + + self.joiner_net = joiner_net + + def run_encoder( + self, + x: torch.Tensor, + states: List[torch.Tensor], + ) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """ + Args: + x: + A tensor of shape (T, C) + states: + A list of tensors. len(states) == self.num_layers * 4 + Returns: + Return a tuple containing: + - encoder_out, a tensor of shape (T, encoder_dim). + - next_states, a list of tensors containing the next states + """ + with self.encoder_net.create_extractor() as ex: + ex.input("in0", ncnn.Mat(x.numpy()).clone()) + + for i in range(len(states)): + name = f"in{i+1}" + ex.input(name, ncnn.Mat(states[i].squeeze().numpy()).clone()) + + ret, ncnn_out0 = ex.extract("out0") + assert ret == 0, ret + encoder_out = torch.from_numpy(ncnn_out0.numpy()).clone() + + out_states: List[torch.Tensor] = [] + for i in range(len(states)): + name = f"out{i+1}" + ret, ncnn_out_state = ex.extract(name) + assert ret == 0, ret + ncnn_out_state = torch.from_numpy(ncnn_out_state.numpy()) + + if i < len(self.num_encoder_layers): + # for cached_len, we need to discard the last dim + ncnn_out_state = ncnn_out_state.squeeze(1) + + out_states.append(ncnn_out_state) + + return encoder_out, out_states + + def run_decoder(self, decoder_input): + assert decoder_input.dtype == torch.int32 + + with self.decoder_net.create_extractor() as ex: + ex.input("in0", ncnn.Mat(decoder_input.numpy()).clone()) + ret, ncnn_out0 = ex.extract("out0") + assert ret == 0, ret + decoder_out = torch.from_numpy(ncnn_out0.numpy()).clone() + return decoder_out + + def run_joiner(self, encoder_out, decoder_out): + with self.joiner_net.create_extractor() as ex: + ex.input("in0", ncnn.Mat(encoder_out.numpy()).clone()) + ex.input("in1", ncnn.Mat(decoder_out.numpy()).clone()) + ret, ncnn_out0 = ex.extract("out0") + assert ret == 0, ret + joiner_out = torch.from_numpy(ncnn_out0.numpy()).clone() + return joiner_out + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +def create_streaming_feature_extractor() -> OnlineFeature: + """Create a CPU streaming feature extractor. + + At present, we assume it returns a fbank feature extractor with + fixed options. In the future, we will support passing in the options + from outside. + + Returns: + Return a CPU streaming feature extractor. + """ + opts = FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + return OnlineFbank(opts) + + +def greedy_search( + model: Model, + encoder_out: torch.Tensor, + decoder_out: Optional[torch.Tensor] = None, + hyp: Optional[List[int]] = None, +): + context_size = 2 + blank_id = 0 + + if decoder_out is None: + assert hyp is None, hyp + hyp = [blank_id] * context_size + decoder_input = torch.tensor(hyp, dtype=torch.int32) # (1, context_size) + decoder_out = model.run_decoder(decoder_input).squeeze(0) + else: + assert decoder_out.ndim == 1 + assert hyp is not None, hyp + + T = encoder_out.size(0) + for t in range(T): + cur_encoder_out = encoder_out[t] + + joiner_out = model.run_joiner(cur_encoder_out, decoder_out) + y = joiner_out.argmax(dim=0).item() + if y != blank_id: + hyp.append(y) + decoder_input = hyp[-context_size:] + decoder_input = torch.tensor(decoder_input, dtype=torch.int32) + decoder_out = model.run_decoder(decoder_input).squeeze(0) + + return hyp, decoder_out + + +def main(): + args = get_args() + logging.info(vars(args)) + + model = Model(args) + + sound_file = args.sound_filename + + sample_rate = 16000 + + logging.info("Constructing Fbank computer") + online_fbank = create_streaming_feature_extractor() + + logging.info(f"Reading sound files: {sound_file}") + wave_samples = read_sound_files( + filenames=[sound_file], + expected_sample_rate=sample_rate, + )[0] + logging.info(wave_samples.shape) + + tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32) + + wave_samples = torch.cat([wave_samples, tail_padding]) + + states = model.get_init_states() + logging.info(f"number of states: {len(states)}") + + hyp = None + decoder_out = None + + num_processed_frames = 0 + segment = model.T + offset = model.chunk_length + + chunk = int(1 * sample_rate) # 0.2 second + + start = 0 + while start < wave_samples.numel(): + end = min(start + chunk, wave_samples.numel()) + samples = wave_samples[start:end] + start += chunk + + online_fbank.accept_waveform( + sampling_rate=sample_rate, + waveform=samples, + ) + while online_fbank.num_frames_ready - num_processed_frames >= segment: + frames = [] + for i in range(segment): + frames.append(online_fbank.get_frame(num_processed_frames + i)) + num_processed_frames += offset + frames = torch.cat(frames, dim=0) + encoder_out, states = model.run_encoder(frames, states) + hyp, decoder_out = greedy_search(model, encoder_out, decoder_out, hyp) + + symbol_table = k2.SymbolTable.from_file(args.tokens) + + context_size = 2 + text = "" + for i in hyp[context_size:]: + text += symbol_table[i] + text = text.replace("▁", " ").strip() + + logging.info(sound_file) + logging.info(text) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train2.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train2.py new file mode 100755 index 000000000..5437e961e --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train2.py @@ -0,0 +1,1265 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --full-libri 1 \ + --max-duration 300 + +# For mix precision training: + +./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --full-libri 1 \ + --max-duration 550 +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from optim import Eden, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer2 import Zipformer + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,4,3,2,4", + help="Number of zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--feedforward-dims", + type=str, + default="1024,1024,2048,2048,1024", + help="Feedforward dimension of the zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--nhead", + type=str, + default="8,8,8,8,8", + help="Number of attention heads in the zipformer encoder layers.", + ) + + parser.add_argument( + "--encoder-dims", + type=str, + default="384,384,384,384,384", + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", + ) + + parser.add_argument( + "--attention-dims", + type=str, + default="192,192,192,192,192", + help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; + not the same as embedding dimension.""", + ) + + parser.add_argument( + "--encoder-unmasked-dims", + type=str, + default="256,256,256,256,256", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " + " worse.", + ) + + parser.add_argument( + "--zipformer-downsampling-factors", + type=str, + default="1,2,4,8,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--cnn-module-kernels", + type=str, + default="31,31,31,31,31", + help="Sizes of kernels in convolution modules", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--short-chunk-size", + type=int, + default=50, + help="""Chunk length of dynamic training, the chunk size would be either + max sequence length of current batch or uniformly sampled from (1, short_chunk_size). + """, + ) + + parser.add_argument( + "--num-left-chunks", + type=int, + default=4, + help="How many left context can be seen in chunks when calculating attention.", + ) + + parser.add_argument( + "--decode-chunk-len", + type=int, + default=32, + help="The chunk size for decoding (in frames before subsampling)", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.05, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=2000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Zipformer and Transformer + def to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + encoder = Zipformer( + num_features=params.feature_dim, + output_downsampling_factor=2, + zipformer_downsampling_factors=to_int_tuple( + params.zipformer_downsampling_factors + ), + encoder_dims=to_int_tuple(params.encoder_dims), + attention_dim=to_int_tuple(params.attention_dims), + encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), + nhead=to_int_tuple(params.nhead), + feedforward_dim=to_int_tuple(params.feedforward_dims), + cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), + num_encoder_layers=to_int_tuple(params.num_encoder_layers), + num_left_chunks=params.num_left_chunks, + short_chunk_size=params.short_chunk_size, + decode_chunk_size=params.decode_chunk_len // 2, + is_pnnx=True, + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + if params.full_libri is False: + params.valid_interval = 1600 + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + parameters_names = [] + parameters_names.append( + [name_param_pair[0] for name_param_pair in model.named_parameters()] + ) + optimizer = ScaledAdam( + model.parameters(), + lr=params.base_lr, + clipping_scale=2.0, + parameters_names=parameters_names, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + librispeech = LibriSpeechAsrDataModule(args) + + if params.full_libri: + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py index f7e52a9e6..a5c422959 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py @@ -44,7 +44,6 @@ from scaling import ( ) from torch import Tensor, nn -from icefall.dist import get_rank from icefall.utils import make_pad_mask, subsequent_chunk_mask @@ -271,7 +270,6 @@ class Zipformer(EncoderInterface): num_encoder_layers (int): number of encoder layers dropout (float): dropout rate cnn_module_kernels (int): Kernel size of convolution module - vgg_frontend (bool): whether to use vgg frontend. warmup_batches (float): number of batches to warm up over """ @@ -388,9 +386,9 @@ class Zipformer(EncoderInterface): def _init_skip_modules(self): """ If self.zipformer_downsampling_factors = (1, 2, 4, 8, 4, 2), then at the input of layer - indexed 4 (in zero indexing), with has subsapling_factor=4, we combine the output of - layers 2 and 3; and at the input of layer indexed 5, which which has subsampling_factor=2, - we combine the outputs of layers 1 and 5. + indexed 4 (in zero indexing), which has subsampling_factor=4, we combine the output of + layers 2 and 3; and at the input of layer indexed 5, which has subsampling_factor=2, + we combine the outputs of layers 1 and 4. """ skip_layers = [] skip_modules = [] @@ -1272,8 +1270,7 @@ class ZipformerEncoder(nn.Module): Shape: src: (S, N, E). - cached_len: (N,) - N is the batch size. + cached_len: (num_layers,) cached_avg: (num_layers, N, C). N is the batch size, C is the feature dimension. cached_key: (num_layers, left_context_len, N, K). @@ -1289,8 +1286,8 @@ class ZipformerEncoder(nn.Module): Returns: A tuple of 8 tensors: - output tensor - - updated cached number of past frmaes. - - updated cached average of past frmaes. + - updated cached number of past frames. + - updated cached average of past frames. - updated cached key tensor of of the first attention module. - updated cached value tensor of of the first attention module. - updated cached value tensor of of the second attention module. @@ -1522,9 +1519,6 @@ class AttentionDownsample(torch.nn.Module): """ def __init__(self, in_channels: int, out_channels: int, downsample: int): - """ - Require out_channels > in_channels. - """ super(AttentionDownsample, self).__init__() self.query = nn.Parameter(torch.randn(in_channels) * (in_channels**-0.5)) @@ -1902,8 +1896,6 @@ class RelPositionMultiheadAttention(nn.Module): Args: x: input to be projected to query, key, value pos_emb: Positional embedding tensor - attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all - the batches while a 3D mask allows to specify a different mask for the entries of each batch. Shape: - Inputs: @@ -1911,13 +1903,6 @@ class RelPositionMultiheadAttention(nn.Module): the embedding dimension. - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is the embedding dimension. - - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. - 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, - S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked - positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend - while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` - is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor - is provided, it will be added to the attention weight. - cached_key: :math:`(left_context_len, N, K)`, where N is the batch size, K is the key dimension. - cached_val: :math:`(left_context_len, N, V)`, where N is the batch size, V is the value dimension. diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py new file mode 100644 index 000000000..be9cd1608 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py @@ -0,0 +1,3144 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import itertools +import logging +import math +import random +import warnings +from typing import List, Optional, Tuple, Union + +import torch +from encoder_interface import EncoderInterface +from scaling import ( # not as in other dirs.. just scales down initial parameter values. + ActivationBalancer, + BasicNorm, + DoubleSwish, + Identity, + MaxEig, + ScaledConv1d, + ScaledLinear, + Whiten, + _diag, + penalize_abs_values_gt, + random_clamp, + softmax, +) +from torch import Tensor, nn +from zipformer import PoolingModule + +from icefall.utils import make_pad_mask, subsequent_chunk_mask + + +def stack_states(state_list: List[List[Tensor]]) -> List[Tensor]: + """Stack list of zipformer states that correspond to separate utterances + into a single emformer state, so that it can be used as an input for + zipformer when those utterances are formed into a batch. + + Note: + It is the inverse of :func:`unstack_states`. + + Args: + state_list: + Each element in state_list corresponding to the internal state + of the zipformer model for a single utterance. + ``states[i]`` is a list of 7 * num_encoders elements of i-th utterance. + ``states[i][0:num_encoders]`` is the cached numbers of past frames. + ``states[i][num_encoders:2*num_encoders]`` is the cached average tensors. + ``states[i][2*num_encoders:3*num_encoders]`` is the cached key tensors of the first attention modules. + ``states[i][3*num_encoders:4*num_encoders]`` is the cached value tensors of the first attention modules. + ``states[i][4*num_encoders:5*num_encoders]`` is the cached value tensors of the second attention modules. + ``states[i][5*num_encoders:6*num_encoders]`` is the cached left contexts of the first convolution modules. + ``states[i][6*num_encoders:7*num_encoders]`` is the cached left contexts of the second convolution modules. + + Returns: + A new state corresponding to a batch of utterances. + See the input argument of :func:`unstack_states` for the meaning + of the returned tensor. + """ + batch_size = len(state_list) + assert len(state_list[0]) % 7 == 0, len(state_list[0]) + num_encoders = len(state_list[0]) // 7 + + cached_len = [] + cached_avg = [] + cached_key = [] + cached_val = [] + cached_val2 = [] + cached_conv1 = [] + cached_conv2 = [] + + # For cached_len + len_list = [state_list[n][0:num_encoders] for n in range(batch_size)] + for i in range(num_encoders): + # len_avg: (num_layers, batch_size) + len_avg = torch.cat([len_list[n][i] for n in range(batch_size)], dim=1) + cached_len.append(len_avg) + + # For cached_avg + avg_list = [ + state_list[n][num_encoders : 2 * num_encoders] for n in range(batch_size) + ] + for i in range(num_encoders): + # avg: (num_layers, batch_size, D) + avg = torch.cat([avg_list[n][i] for n in range(batch_size)], dim=1) + cached_avg.append(avg) + + # For cached_key + key_list = [ + state_list[n][2 * num_encoders : 3 * num_encoders] for n in range(batch_size) + ] + for i in range(num_encoders): + # key: (num_layers, left_context_size, batch_size, D) + key = torch.cat([key_list[n][i] for n in range(batch_size)], dim=2) + cached_key.append(key) + + # For cached_val + val_list = [ + state_list[n][3 * num_encoders : 4 * num_encoders] for n in range(batch_size) + ] + for i in range(num_encoders): + # val: (num_layers, left_context_size, batch_size, D) + val = torch.cat([val_list[n][i] for n in range(batch_size)], dim=2) + cached_val.append(val) + + # For cached_val2 + val2_list = [ + state_list[n][4 * num_encoders : 5 * num_encoders] for n in range(batch_size) + ] + for i in range(num_encoders): + # val2: (num_layers, left_context_size, batch_size, D) + val2 = torch.cat([val2_list[n][i] for n in range(batch_size)], dim=2) + cached_val2.append(val2) + + # For cached_conv1 + conv1_list = [ + state_list[n][5 * num_encoders : 6 * num_encoders] for n in range(batch_size) + ] + for i in range(num_encoders): + # conv1: (num_layers, batch_size, D, kernel-1) + conv1 = torch.cat([conv1_list[n][i] for n in range(batch_size)], dim=1) + cached_conv1.append(conv1) + + # For cached_conv2 + conv2_list = [ + state_list[n][6 * num_encoders : 7 * num_encoders] for n in range(batch_size) + ] + for i in range(num_encoders): + # conv2: (num_layers, batch_size, D, kernel-1) + conv2 = torch.cat([conv2_list[n][i] for n in range(batch_size)], dim=1) + cached_conv2.append(conv2) + + states = ( + cached_len + + cached_avg + + cached_key + + cached_val + + cached_val2 + + cached_conv1 + + cached_conv2 + ) + return states + + +def unstack_states(states: List[Tensor]) -> List[List[Tensor]]: + """Unstack the zipformer state corresponding to a batch of utterances + into a list of states, where the i-th entry is the state from the i-th + utterance in the batch. + + Note: + It is the inverse of :func:`stack_states`. + + Args: + states: + A list of 7 * num_encoders elements: + ``states[0:num_encoders]`` is the cached numbers of past frames. + ``states[num_encoders:2*num_encoders]`` is the cached average tensors. + ``states[2*num_encoders:3*num_encoders]`` is the cached key tensors of the first attention modules. + ``states[3*num_encoders:4*num_encoders]`` is the cached value tensors of the first attention modules. + ``states[4*num_encoders:5*num_encoders]`` is the cached value tensors of the second attention modules. + ``states[5*num_encoders:6*num_encoders]`` is the cached left contexts of the first convolution modules. + ``states[6*num_encoders:7*num_encoders]`` is the cached left contexts of the second convolution modules. + + Returns: + A list of states. + ``states[i]`` is a list of 7 * num_encoders elements of i-th utterance. + """ + assert len(states) % 7 == 0, len(states) + num_encoders = len(states) // 7 + ( + cached_len, + cached_avg, + cached_key, + cached_val, + cached_val2, + cached_conv1, + cached_conv2, + ) = (states[i * num_encoders : (i + 1) * num_encoders] for i in range(7)) + + batch_size = cached_len[0].shape[1] + + len_list = [[] for _ in range(batch_size)] + for i in range(num_encoders): + # cached_len[i]: (num_layers, batch_size) + len_avg = cached_len[i].chunk(chunks=batch_size, dim=1) + for n in range(batch_size): + len_list[n].append(len_avg[n]) + + avg_list = [[] for _ in range(batch_size)] + for i in range(num_encoders): + # cached_avg[i]: (num_layers, batch_size, D) + avg = cached_avg[i].chunk(chunks=batch_size, dim=1) + for n in range(batch_size): + avg_list[n].append(avg[n]) + + key_list = [[] for _ in range(batch_size)] + for i in range(num_encoders): + # cached_key[i]: (num_layers, left_context, batch_size, D) + key = cached_key[i].chunk(chunks=batch_size, dim=2) + for n in range(batch_size): + key_list[n].append(key[n]) + + val_list = [[] for _ in range(batch_size)] + for i in range(num_encoders): + # cached_val[i]: (num_layers, left_context, batch_size, D) + val = cached_val[i].chunk(chunks=batch_size, dim=2) + for n in range(batch_size): + val_list[n].append(val[n]) + + val2_list = [[] for _ in range(batch_size)] + for i in range(num_encoders): + # cached_val2[i]: (num_layers, left_context, batch_size, D) + val2 = cached_val2[i].chunk(chunks=batch_size, dim=2) + for n in range(batch_size): + val2_list[n].append(val2[n]) + + conv1_list = [[] for _ in range(batch_size)] + for i in range(num_encoders): + # cached_conv1[i]: (num_layers, batch_size, D, kernel-1) + conv1 = cached_conv1[i].chunk(chunks=batch_size, dim=1) + for n in range(batch_size): + conv1_list[n].append(conv1[n]) + + conv2_list = [[] for _ in range(batch_size)] + for i in range(num_encoders): + # cached_conv2[i]: (num_layers, batch_size, D, kernel-1) + conv2 = cached_conv2[i].chunk(chunks=batch_size, dim=1) + for n in range(batch_size): + conv2_list[n].append(conv2[n]) + + state_list = [ + ( + len_list[i] + + avg_list[i] + + key_list[i] + + val_list[i] + + val2_list[i] + + conv1_list[i] + + conv2_list[i] + ) + for i in range(batch_size) + ] + return state_list + + +class Zipformer(EncoderInterface): + """ + Args: + num_features (int): Number of input features + d_model: (int,int): embedding dimension of 2 encoder stacks + attention_dim: (int,int): attention dimension of 2 encoder stacks + nhead (int, int): number of heads + dim_feedforward (int, int): feedforward dimension in 2 encoder stacks + num_encoder_layers (int): number of encoder layers + dropout (float): dropout rate + cnn_module_kernels (int): Kernel size of convolution module + warmup_batches (float): number of batches to warm up over + is_pnnx (bool): True if we are going to convert this model via pnnx. + """ + + def __init__( + self, + num_features: int, + output_downsampling_factor: int = 2, + encoder_dims: Tuple[int] = (384, 384), + attention_dim: Tuple[int] = (256, 256), + encoder_unmasked_dims: Tuple[int] = (256, 256), + zipformer_downsampling_factors: Tuple[int] = (2, 4), + nhead: Tuple[int] = (8, 8), + feedforward_dim: Tuple[int] = (1536, 2048), + num_encoder_layers: Tuple[int] = (12, 12), + dropout: float = 0.1, + cnn_module_kernels: Tuple[int] = (31, 31), + pos_dim: int = 4, + num_left_chunks: int = 4, + short_chunk_threshold: float = 0.75, + short_chunk_size: int = 50, + decode_chunk_size: int = 16, + warmup_batches: float = 4000.0, + is_pnnx: bool = False, + ) -> None: + super(Zipformer, self).__init__() + self.is_pnnx = is_pnnx + + self.num_features = num_features + assert 0 < encoder_dims[0] <= encoder_dims[1] + self.encoder_dims = encoder_dims + self.encoder_unmasked_dims = encoder_unmasked_dims + self.zipformer_downsampling_factors = zipformer_downsampling_factors + self.output_downsampling_factor = output_downsampling_factor + + self.num_left_chunks = num_left_chunks + self.short_chunk_threshold = short_chunk_threshold + self.short_chunk_size = short_chunk_size + + # Used in decoding + self.decode_chunk_size = decode_chunk_size + + self.left_context_len = self.decode_chunk_size * self.num_left_chunks + + # will be written to, see set_batch_count() + self.batch_count = 0 + self.warmup_end = warmup_batches + + for u, d in zip(encoder_unmasked_dims, encoder_dims): + assert u <= d, (u, d) + + # self.encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7)//2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7)//2 + # (2) embedding: num_features -> encoder_dims + self.encoder_embed = Conv2dSubsampling( + num_features, encoder_dims[0], dropout=dropout, is_pnnx=is_pnnx + ) + + # each one will be ZipformerEncoder or DownsampledZipformerEncoder + encoders = [] + + self.num_encoders = len(encoder_dims) + for i in range(self.num_encoders): + ds = zipformer_downsampling_factors[i] + encoder_layer = ZipformerEncoderLayer( + encoder_dims[i], + attention_dim[i], + nhead[i], + feedforward_dim[i], + dropout, + cnn_module_kernels[i], + pos_dim, + is_pnnx=self.is_pnnx, + left_context_len=self.left_context_len // ds, + x_size=self.decode_chunk_size // ds, + ) + + # For the segment of the warmup period, we let the Conv2dSubsampling + # layer learn something. Then we start to warm up the other encoders. + encoder = ZipformerEncoder( + encoder_layer, + num_encoder_layers[i], + dropout, + warmup_begin=warmup_batches * (i + 1) / (self.num_encoders + 1), + warmup_end=warmup_batches * (i + 2) / (self.num_encoders + 1), + is_pnnx=is_pnnx, + left_context_len=self.left_context_len // ds, + x_size=self.decode_chunk_size // ds, + ) + + if zipformer_downsampling_factors[i] != 1: + in_x_size = self.decode_chunk_size + encoder = DownsampledZipformerEncoder( + encoder, + input_dim=encoder_dims[i - 1] if i > 0 else encoder_dims[0], + output_dim=encoder_dims[i], + downsample=zipformer_downsampling_factors[i], + is_pnnx=is_pnnx, + left_context_len=self.left_context_len // ds, + in_x_size=in_x_size, + ) + encoders.append(encoder) + self.encoders = nn.ModuleList(encoders) + + # initializes self.skip_layers and self.skip_modules + self._init_skip_modules() + + self.downsample_output = AttentionDownsample( + encoder_dims[-1], + encoder_dims[-1], + downsample=output_downsampling_factor, + is_pnnx=is_pnnx, + in_x_size=self.decode_chunk_size, + ) + + def _get_layer_skip_dropout_prob(self): + if not self.training: + return 0.0 + batch_count = self.batch_count + min_dropout_prob = 0.025 + + if batch_count > self.warmup_end: + return min_dropout_prob + else: + return 0.5 - (batch_count / self.warmup_end) * (0.5 - min_dropout_prob) + + def _init_skip_modules(self): + """ + If self.zipformer_downsampling_factors = (1, 2, 4, 8, 4, 2), then at the input of layer + indexed 4 (in zero indexing), which has subsampling_factor=4, we combine the output of + layers 2 and 3; and at the input of layer indexed 5, which has subsampling_factor=2, + we combine the outputs of layers 1 and 4. + """ + skip_layers = [] + skip_modules = [] + z = self.zipformer_downsampling_factors + for i in range(len(z)): + if i <= 1 or z[i - 1] <= z[i]: + skip_layers.append(None) + skip_modules.append(SimpleCombinerIdentity()) + else: + # TEMP + for j in range(i - 2, -1, -1): + if z[j] <= z[i] or j == 0: + # TEMP logging statement. + logging.info( + f"At encoder stack {i}, which has downsampling_factor={z[i]}, we will " + f"combine the outputs of layers {j} and {i-1}, with downsampling_factors={z[j]} and {z[i-1]}." + ) + skip_layers.append(j) + skip_modules.append( + SimpleCombiner( + self.encoder_dims[j], + self.encoder_dims[i - 1], + min_weight=(0.0, 0.25), + ) + ) + break + self.skip_layers = skip_layers + self.skip_modules = nn.ModuleList(skip_modules) + + def get_feature_masks(self, x: torch.Tensor) -> List[float]: + # Note: The actual return type is Union[List[float], List[Tensor]], + # but to make torch.jit.script() work, we use List[float] + """ + In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of + randomized feature masks, one per encoder. + On e.g. 15% of frames, these masks will zero out all encoder dims larger than + some supplied number, e.g. >256, so in effect on those frames we are using + a smaller encoder dim. + + We generate the random masks at this level because we want the 2 masks to 'agree' + all the way up the encoder stack. This will mean that the 1st mask will have + mask values repeated self.zipformer_downsampling_factors times. + + Args: + x: the embeddings (needed for the shape and dtype and device), of shape + (num_frames, batch_size, encoder_dims0) + """ + num_encoders = len(self.encoder_dims) + if torch.jit.is_scripting() or not self.training: + return [1.0] * num_encoders + + (num_frames0, batch_size, _encoder_dims0) = x.shape + + assert self.encoder_dims[0] == _encoder_dims0, ( + self.encoder_dims, + _encoder_dims0, + ) + + max_downsampling_factor = max(self.zipformer_downsampling_factors) + + num_frames_max = num_frames0 + max_downsampling_factor - 1 + + feature_mask_dropout_prob = 0.15 + + # frame_mask_max shape: (num_frames_max, batch_size, 1) + frame_mask_max = ( + torch.rand(num_frames_max, batch_size, 1, device=x.device) + > feature_mask_dropout_prob + ).to(x.dtype) + + feature_masks = [] + for i in range(num_encoders): + ds = self.zipformer_downsampling_factors[i] + upsample_factor = max_downsampling_factor // ds + + frame_mask = ( + frame_mask_max.unsqueeze(1) + .expand(num_frames_max, upsample_factor, batch_size, 1) + .reshape(num_frames_max * upsample_factor, batch_size, 1) + ) + num_frames = (num_frames0 + ds - 1) // ds + frame_mask = frame_mask[:num_frames] + feature_mask = torch.ones( + num_frames, + batch_size, + self.encoder_dims[i], + dtype=x.dtype, + device=x.device, + ) + u = self.encoder_unmasked_dims[i] + feature_mask[:, :, u:] *= frame_mask + feature_masks.append(feature_mask) + + return feature_masks + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + The input tensor. Its shape is (batch_size, seq_len, feature_dim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + chunk_size: + The chunk size used in evaluation mode. + Returns: + Return a tuple containing 2 tensors: + - embeddings: its shape is (batch_size, output_seq_len, encoder_dims[-1]) + - lengths, a tensor of shape (batch_size,) containing the number + of frames in `embeddings` before padding. + """ + x = self.encoder_embed(x) + + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + lengths = (x_lens - 7) >> 1 + assert x.size(0) == lengths.max().item(), (x.shape, lengths, lengths.max()) + mask = make_pad_mask(lengths) + + outputs = [] + feature_masks = self.get_feature_masks(x) + + if self.training: + # Training mode + max_ds = max(self.zipformer_downsampling_factors) + # Generate dynamic chunk-wise attention mask during training + max_len = x.size(0) // max_ds + short_chunk_size = self.short_chunk_size // max_ds + chunk_size = torch.randint(1, max_len, (1,)).item() + if chunk_size > (max_len * self.short_chunk_threshold): + # Full attention + chunk_size = x.size(0) + else: + # Chunk-wise attention + chunk_size = chunk_size % short_chunk_size + 1 + chunk_size *= max_ds + else: + chunk_size = self.decode_chunk_size + # Evaluation mode + for ds in self.zipformer_downsampling_factors: + assert chunk_size % ds == 0, (chunk_size, ds) + + attn_mask = ~subsequent_chunk_mask( + size=x.size(0), + chunk_size=chunk_size, + num_left_chunks=self.num_left_chunks, + device=x.device, + ) + + for i, (module, skip_module) in enumerate( + zip(self.encoders, self.skip_modules) + ): + ds = self.zipformer_downsampling_factors[i] + k = self.skip_layers[i] + if isinstance(k, int): + layer_skip_dropout_prob = self._get_layer_skip_dropout_prob() + if torch.jit.is_scripting(): + x = skip_module(outputs[k], x) + elif (not self.training) or random.random() > layer_skip_dropout_prob: + x = skip_module(outputs[k], x) + x = module( + x, + feature_mask=feature_masks[i], + src_key_padding_mask=None if mask is None else mask[..., ::ds], + attn_mask=attn_mask[::ds, ::ds], + ) + outputs.append(x) + + x = self.downsample_output(x) + # class Downsample has this rounding behavior.. + assert self.output_downsampling_factor == 2, self.output_downsampling_factor + lengths = (lengths + 1) >> 1 + + x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + return x, lengths + + def streaming_forward( + self, + x: torch.Tensor, + states: List[Tensor], + ) -> Tuple[Tensor, List[Tensor]]: + """ + Args: + x: + The input tensor. Its shape is (batch_size, seq_len, feature_dim). + seq_len is the input chunk length. + states: + A list of 7 * num_encoders elements: + ``states[0:num_encoders]`` is the cached numbers of past frames. + ``states[num_encoders:2*num_encoders]`` is the cached average tensors. + ``states[2*num_encoders:3*num_encoders]`` is the cached key tensors of the first attention modules. + ``states[3*num_encoders:4*num_encoders]`` is the cached value tensors of the first attention modules. + ``states[4*num_encoders:5*num_encoders]`` is the cached value tensors of the second attention modules. + ``states[5*num_encoders:6*num_encoders]`` is the cached left contexts of the first convolution modules. + ``states[6*num_encoders:7*num_encoders]`` is the cached left contexts of the second convolution modules. + + Returns: + Return a tuple containing 3 tensors: + - embeddings: its shape is (batch_size, output_seq_len, encoder_dims[-1]) + - lengths, a tensor of shape (batch_size,) containing the number + of frames in `embeddings` before padding. + - updated states. + """ + assert len(states) == 7 * self.num_encoders, (len(states), self.num_encoders) + + cached_len = states[: self.num_encoders] + cached_avg = states[self.num_encoders : 2 * self.num_encoders] + cached_key = states[2 * self.num_encoders : 3 * self.num_encoders] + cached_val = states[3 * self.num_encoders : 4 * self.num_encoders] + cached_val2 = states[4 * self.num_encoders : 5 * self.num_encoders] + cached_conv1 = states[5 * self.num_encoders : 6 * self.num_encoders] + cached_conv2 = states[6 * self.num_encoders : 7 * self.num_encoders] + + x = self.encoder_embed(x) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + outputs = [] + new_cached_len = [] + new_cached_avg = [] + new_cached_key = [] + new_cached_val = [] + new_cached_val2 = [] + new_cached_conv1 = [] + new_cached_conv2 = [] + + for i, (module, skip_module) in enumerate( + zip(self.encoders, self.skip_modules) + ): + k = self.skip_layers[i] + if isinstance(k, int): + x = skip_module(outputs[k], x) + x, len_avg, avg, key, val, val2, conv1, conv2 = module.streaming_forward( + x, + cached_len=cached_len[i], + cached_avg=cached_avg[i], + cached_key=cached_key[i], + cached_val=cached_val[i], + cached_val2=cached_val2[i], + cached_conv1=cached_conv1[i], + cached_conv2=cached_conv2[i], + ) + + outputs.append(x) + # Update caches + new_cached_len.append(len_avg) + new_cached_avg.append(avg) + new_cached_key.append(key) + new_cached_val.append(val) + new_cached_val2.append(val2) + new_cached_conv1.append(conv1) + new_cached_conv2.append(conv2) + + x = self.downsample_output(x) + # class Downsample has this rounding behavior.. + assert self.output_downsampling_factor == 2, self.output_downsampling_factor + + x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + new_states = ( + new_cached_len + + new_cached_avg + + new_cached_key + + new_cached_val + + new_cached_val2 + + new_cached_conv1 + + new_cached_conv2 + ) + return x, new_states + + @torch.jit.export + def get_init_state( + self, + device: torch.device = torch.device("cpu"), + ) -> List[Tensor]: + """Get initial states. + A list of 7 * num_encoders elements: + ``states[0:num_encoders]`` is the cached numbers of past frames. + ``states[num_encoders:2*num_encoders]`` is the cached average tensors. + ``states[2*num_encoders:3*num_encoders]`` is the cached key tensors of the first attention modules. + ``states[3*num_encoders:4*num_encoders]`` is the cached value tensors of the first attention modules. + ``states[4*num_encoders:5*num_encoders]`` is the cached value tensors of the second attention modules. + ``states[5*num_encoders:6*num_encoders]`` is the cached left contexts of the first convolution modules. + ``states[6*num_encoders:7*num_encoders]`` is the cached left contexts of the second convolution modules. + """ + cached_len = [] + cached_avg = [] + cached_key = [] + cached_val = [] + cached_val2 = [] + cached_conv1 = [] + cached_conv2 = [] + + for i, encoder in enumerate(self.encoders): + num_layers = encoder.num_layers + ds = self.zipformer_downsampling_factors[i] + + len_avg = torch.zeros(num_layers, 1, device=device) + cached_len.append(len_avg) + + avg = torch.zeros(num_layers, 1, encoder.d_model, device=device) + cached_avg.append(avg) + + key = torch.zeros( + num_layers, + self.left_context_len // ds, + 1, + encoder.attention_dim, + device=device, + ) + cached_key.append(key) + + val = torch.zeros( + num_layers, + self.left_context_len // ds, + 1, + encoder.attention_dim // 2, + device=device, + ) + cached_val.append(val) + + val2 = torch.zeros( + num_layers, + self.left_context_len // ds, + 1, + encoder.attention_dim // 2, + device=device, + ) + cached_val2.append(val2) + + conv1 = torch.zeros( + num_layers, + 1, + encoder.d_model, + encoder.cnn_module_kernel - 1, + device=device, + ) + cached_conv1.append(conv1) + + conv2 = torch.zeros( + num_layers, + 1, + encoder.d_model, + encoder.cnn_module_kernel - 1, + device=device, + ) + cached_conv2.append(conv2) + + states = ( + cached_len + + cached_avg + + cached_key + + cached_val + + cached_val2 + + cached_conv1 + + cached_conv2 + ) + return states + + +class ZipformerEncoderLayer(nn.Module): + """ + ZipformerEncoderLayer is made up of self-attn, feedforward and convolution networks. + + Args: + d_model: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + feedforward_dim: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + cnn_module_kernel (int): Kernel size of convolution module. + + Examples:: + >>> encoder_layer = ZipformerEncoderLayer(d_model=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> pos_emb = torch.rand(32, 19, 512) + >>> out = encoder_layer(src, pos_emb) + """ + + def __init__( + self, + d_model: int, + attention_dim: int, + nhead: int, + feedforward_dim: int = 2048, + dropout: float = 0.1, + cnn_module_kernel: int = 31, + pos_dim: int = 4, + is_pnnx: bool = False, + left_context_len: int = 0, + x_size: int = 0, + ) -> None: + super(ZipformerEncoderLayer, self).__init__() + + self.d_model = d_model + self.attention_dim = attention_dim + self.cnn_module_kernel = cnn_module_kernel + + # will be written to, see set_batch_count() + self.batch_count = 0 + + self.self_attn = RelPositionMultiheadAttention( + d_model, + attention_dim, + nhead, + pos_dim, + dropout=0.0, + is_pnnx=is_pnnx, + left_context_len=left_context_len, + x_size=x_size, + ) + + self.pooling = PoolingModule(d_model) + + self.feed_forward1 = FeedforwardModule(d_model, feedforward_dim, dropout) + + self.feed_forward2 = FeedforwardModule(d_model, feedforward_dim, dropout) + + self.feed_forward3 = FeedforwardModule(d_model, feedforward_dim, dropout) + + self.conv_module1 = ConvolutionModule( + d_model, cnn_module_kernel, is_pnnx=is_pnnx, x_size=x_size + ) + + self.conv_module2 = ConvolutionModule( + d_model, cnn_module_kernel, is_pnnx=is_pnnx, x_size=x_size + ) + + self.norm_final = BasicNorm(d_model) + + self.bypass_scale = nn.Parameter(torch.tensor(0.5)) + + # try to ensure the output is close to zero-mean (or at least, zero-median). + self.balancer = ActivationBalancer( + d_model, + channel_dim=-1, + min_positive=0.45, + max_positive=0.55, + max_abs=6.0, + ) + self.whiten = Whiten( + num_groups=1, whitening_limit=5.0, prob=(0.025, 0.25), grad_scale=0.01 + ) + + def get_bypass_scale(self): + if torch.jit.is_scripting() or not self.training: + return self.bypass_scale + if random.random() < 0.1: + # ensure we get grads if self.bypass_scale becomes out of range + return self.bypass_scale + # hardcode warmup period for bypass scale + warmup_period = 20000.0 + initial_clamp_min = 0.75 + final_clamp_min = 0.25 + if self.batch_count > warmup_period: + clamp_min = final_clamp_min + else: + clamp_min = initial_clamp_min - (self.batch_count / warmup_period) * ( + initial_clamp_min - final_clamp_min + ) + return self.bypass_scale.clamp(min=clamp_min, max=1.0) + + def get_dynamic_dropout_rate(self): + # return dropout rate for the dynamic modules (self_attn, pooling, convolution); this + # starts at 0.2 and rapidly decreases to 0. Its purpose is to keep the training stable + # at the beginning, by making the network focus on the feedforward modules. + if torch.jit.is_scripting() or not self.training: + return 0.0 + warmup_period = 2000.0 + initial_dropout_rate = 0.2 + final_dropout_rate = 0.0 + if self.batch_count > warmup_period: + return final_dropout_rate + else: + return initial_dropout_rate - ( + initial_dropout_rate * final_dropout_rate + ) * (self.batch_count / warmup_period) + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """ + Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + pos_emb: Positional embedding tensor (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + batch_split: if not None, this layer will only be applied to + + Shape: + src: (S, N, E). + pos_emb: (N, 2*S-1, E) + src_mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, N is the batch size, E is the feature number + """ + src_orig = src + + # macaron style feed forward module + src = src + self.feed_forward1(src) + + # dropout rate for submodules that interact with time. + dynamic_dropout = self.get_dynamic_dropout_rate() + + # pooling module + if torch.jit.is_scripting(): + src = src + self.pooling(src, src_key_padding_mask=src_key_padding_mask) + elif random.random() >= dynamic_dropout: + src = src + self.pooling(src, src_key_padding_mask=src_key_padding_mask) + + if torch.jit.is_scripting(): + src_att, attn_weights = self.self_attn( + src, + pos_emb=pos_emb, + attn_mask=attn_mask, + key_padding_mask=src_key_padding_mask, + ) + src = src + src_att + + src = src + self.conv_module1( + src, src_key_padding_mask=src_key_padding_mask + ) + + src = src + self.feed_forward2(src) + + src = src + self.self_attn.forward2(src, attn_weights) + + src = src + self.conv_module2( + src, src_key_padding_mask=src_key_padding_mask + ) + else: + use_self_attn = random.random() >= dynamic_dropout + if use_self_attn: + src_att, attn_weights = self.self_attn( + src, + pos_emb=pos_emb, + attn_mask=attn_mask, + key_padding_mask=src_key_padding_mask, + ) + src = src + src_att + + if random.random() >= dynamic_dropout: + src = src + self.conv_module1( + src, src_key_padding_mask=src_key_padding_mask + ) + + src = src + self.feed_forward2(src) + + if use_self_attn: + src = src + self.self_attn.forward2(src, attn_weights) + + if random.random() >= dynamic_dropout: + src = src + self.conv_module2( + src, src_key_padding_mask=src_key_padding_mask + ) + + src = src + self.feed_forward3(src) + + src = self.norm_final(self.balancer(src)) + + delta = src - src_orig + + src = src_orig + delta * self.get_bypass_scale() + + return self.whiten(src) + + def streaming_forward( + self, + src: Tensor, + pos_emb: Tensor, + cached_len: Tensor, + cached_avg: Tensor, + cached_key: Tensor, + cached_val: Tensor, + cached_val2: Tensor, + cached_conv1: Tensor, + cached_conv2: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + """ + Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + pos_emb: Positional embedding tensor (required). + cached_len: processed number of past frames. + cached_avg: cached average of past frames. + cached_key: cached key tensor of left context for the first attention module. + cached_val: cached value tensor of left context for the first attention module. + cached_val2: cached value tensor of left context for the second attention module. + cached_conv1: cached left context for the first convolution module. + cached_conv2: cached left context for the second convolution module. + + Shape: + src: (S, N, E). + pos_emb: (N, left_context_len+2*S-1, E) + cached_len: (N,) + N is the batch size. + cached_avg: (N, C). + N is the batch size, C is the feature dimension. + cached_key: (left_context_len, N, K). + N is the batch size, K is the key dimension. + cached_val: (left_context_len, N, V). + N is the batch size, V is the key dimension. + cached_val2: (left_context_len, N, V). + N is the batch size, V is the key dimension. + cached_conv1: (N, C, kernel_size-1). + N is the batch size, C is the convolution channels. + cached_conv2: (N, C, kernel_size-1). + N is the batch size, C is the convolution channels. + """ + src_orig = src + + # macaron style feed forward module + src = src + self.feed_forward1(src) + + src_pool, cached_len, cached_avg = self.pooling.streaming_forward( + src, + cached_len=cached_len, + cached_avg=cached_avg, + ) + src = src + src_pool + + ( + src_attn, + attn_weights, + cached_key, + cached_val, + ) = self.self_attn.streaming_forward( + src, + pos_emb=pos_emb, + cached_key=cached_key, + cached_val=cached_val, + ) + + src = src + src_attn + + src_conv, cached_conv1 = self.conv_module1.streaming_forward( + src, + cache=cached_conv1, + ) + + src = src + src_conv + + src = src + self.feed_forward2(src) + + src_attn, cached_val2 = self.self_attn.streaming_forward2( + src, + attn_weights, + cached_val=cached_val2, + ) + src = src + src_attn + + src_conv, cached_conv2 = self.conv_module2.streaming_forward( + src, + cache=cached_conv2, + ) + src = src + src_conv + + src = src + self.feed_forward3(src) + + src = self.norm_final(self.balancer(src)) + + delta = src - src_orig + + src = src_orig + delta * self.bypass_scale + + return ( + src, + cached_len, + cached_avg, + cached_key, + cached_val, + cached_val2, + cached_conv1, + cached_conv2, + ) + + +class ZipformerStateSelect(nn.Module): + """ncnn does not support selecting along batch index. + This class provides a workaround for it. We + need to change pnnx accordingly. + """ + + def __init__(self, i: int): + super().__init__() + self.i = i + + def forward(self, x: torch.Tensor): + return x[self.i] + + +class ZipformerEncoder(nn.Module): + r"""ZipformerEncoder is a stack of N encoder layers + + Args: + encoder_layer: an instance of the ZipformerEncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + + Examples:: + >>> encoder_layer = ZipformerEncoderLayer(d_model=512, nhead=8) + >>> zipformer_encoder = ZipformerEncoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> out = zipformer_encoder(src) + """ + + def __init__( + self, + encoder_layer: nn.Module, + num_layers: int, + dropout: float, + warmup_begin: float, + warmup_end: float, + is_pnnx: bool = False, + x_size: int = 0, + left_context_len: int = 0, + ) -> None: + super().__init__() + # will be written to, see set_batch_count() Note: in inference time this + # may be zero but should be treated as large, we can check if + # self.training is true. + self.batch_count = 0 + self.warmup_begin = warmup_begin + self.warmup_end = warmup_end + # module_seed is for when we need a random number that is unique to the module but + # shared across jobs. It's used to randomly select how many layers to drop, + # so that we can keep this consistent across worker tasks (for efficiency). + self.module_seed = torch.randint(0, 1000, ()).item() + self.left_context_len = left_context_len + + self.encoder_pos = RelPositionalEncoding( + encoder_layer.d_model, + dropout, + is_pnnx=is_pnnx, + x_size=x_size, + left_context_len=left_context_len, + ) + + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for i in range(num_layers)] + ) + self.num_layers = num_layers + + state_select_list = [] + for i in range(num_layers): + state_select_list.append(ZipformerStateSelect(i)) + self.state_select_list = nn.ModuleList(state_select_list) + + self.d_model = encoder_layer.d_model + self.attention_dim = encoder_layer.attention_dim + self.cnn_module_kernel = encoder_layer.cnn_module_kernel + + assert 0 <= warmup_begin <= warmup_end, (warmup_begin, warmup_end) + + delta = (1.0 / num_layers) * (warmup_end - warmup_begin) + cur_begin = warmup_begin + for i in range(num_layers): + self.layers[i].warmup_begin = cur_begin + cur_begin += delta + self.layers[i].warmup_end = cur_begin + + def get_layers_to_drop(self, rnd_seed: int): + ans = set() + if not self.training: + return ans + + batch_count = self.batch_count + num_layers = len(self.layers) + + def get_layerdrop_prob(layer: int) -> float: + layer_warmup_begin = self.layers[layer].warmup_begin + layer_warmup_end = self.layers[layer].warmup_end + + initial_layerdrop_prob = 0.5 + final_layerdrop_prob = 0.05 + + if batch_count == 0: + # As a special case, if batch_count == 0, return 0 (drop no + # layers). This is rather ugly, I'm afraid; it is intended to + # enable our scan_pessimistic_batches_for_oom() code to work correctly + # so if we are going to get OOM it will happen early. + # also search for 'batch_count' with quotes in this file to see + # how we initialize the warmup count to a random number between + # 0 and 10. + return 0.0 + elif batch_count < layer_warmup_begin: + return initial_layerdrop_prob + elif batch_count > layer_warmup_end: + return final_layerdrop_prob + else: + # linearly interpolate + t = (batch_count - layer_warmup_begin) / layer_warmup_end + assert 0.0 <= t < 1.001, t + return initial_layerdrop_prob + t * ( + final_layerdrop_prob - initial_layerdrop_prob + ) + + shared_rng = random.Random(batch_count + self.module_seed) + independent_rng = random.Random(rnd_seed) + + layerdrop_probs = [get_layerdrop_prob(i) for i in range(num_layers)] + tot = sum(layerdrop_probs) + # Instead of drawing the samples independently, we first randomly decide + # how many layers to drop out, using the same random number generator between + # jobs so that all jobs drop out the same number (this is for speed). + # Then we use an approximate approach to drop out the individual layers + # with their specified probs while reaching this exact target. + num_to_drop = int(tot) + int(shared_rng.random() < (tot - int(tot))) + + layers = list(range(num_layers)) + independent_rng.shuffle(layers) + + # go through the shuffled layers until we get the required number of samples. + if num_to_drop > 0: + for layer in itertools.cycle(layers): + if independent_rng.random() < layerdrop_probs[layer]: + ans.add(layer) + if len(ans) == num_to_drop: + break + if shared_rng.random() < 0.005 or __name__ == "__main__": + logging.info( + f"warmup_begin={self.warmup_begin:.1f}, warmup_end={self.warmup_end:.1f}, " + f"batch_count={batch_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}" + ) + return ans + + def forward( + self, + src: Tensor, + # Note: The type of feature_mask should be Union[float, Tensor], + # but to make torch.jit.script() work, we use `float` here + feature_mask: float = 1.0, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required). + feature_mask: something that broadcasts with src, that we'll multiply `src` + by at every layer. + mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + src: (S, N, E). + pos_emb: (N, 2*S-1, E) + mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number + + Returns: (x, x_no_combine), both of shape (S, N, E) + """ + pos_emb = self.encoder_pos(src) + output = src + + if torch.jit.is_scripting(): + layers_to_drop = [] + else: + rnd_seed = src.numel() + random.randint(0, 1000) + layers_to_drop = self.get_layers_to_drop(rnd_seed) + + output = output * feature_mask + + for i, mod in enumerate(self.layers): + if not torch.jit.is_scripting(): + if i in layers_to_drop: + continue + output = mod( + output, + pos_emb, + attn_mask=attn_mask, + src_key_padding_mask=src_key_padding_mask, + ) + + output = output * feature_mask + + return output + + @torch.jit.export + def streaming_forward( + self, + src: Tensor, + cached_len: Tensor, + cached_avg: Tensor, + cached_key: Tensor, + cached_val: Tensor, + cached_val2: Tensor, + cached_conv1: Tensor, + cached_conv2: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required). + cached_len: number of past frames. + cached_avg: cached average of past frames. + cached_key: cached key tensor for first attention module. + cached_val: cached value tensor for first attention module. + cached_val2: cached value tensor for second attention module. + cached_conv1: cached left contexts for the first convolution module. + cached_conv2: cached left contexts for the second convolution module. + + Shape: + src: (S, N, E). + cached_len: (num_layers,) + cached_avg: (num_layers, N, C). + N is the batch size, C is the feature dimension. + cached_key: (num_layers, left_context_len, N, K). + N is the batch size, K is the key dimension. + cached_val: (num_layers, left_context_len, N, V). + N is the batch size, V is the key dimension. + cached_val2: (num_layers, left_context_len, N, V). + N is the batch size, V is the key dimension. + cached_conv1: (num_layers, N, C, kernel_size-1). + N is the batch size, C is the convolution channels. + cached_conv2: (num_layers, N, C, kernel_size-1). + N is the batch size, C is the convolution channels. + + Returns: A tuple of 8 tensors: + - output tensor + - updated cached number of past frames. + - updated cached average of past frames. + - updated cached key tensor of of the first attention module. + - updated cached value tensor of of the first attention module. + - updated cached value tensor of of the second attention module. + - updated cached left contexts of the first convolution module. + - updated cached left contexts of the second convolution module. + """ + assert cached_len.size(0) == self.num_layers, ( + cached_len.size(0), + self.num_layers, + ) + assert cached_avg.size(0) == self.num_layers, ( + cached_avg.size(0), + self.num_layers, + ) + assert cached_key.size(0) == self.num_layers, ( + cached_key.size(0), + self.num_layers, + ) + assert cached_val.size(0) == self.num_layers, ( + cached_val.size(0), + self.num_layers, + ) + assert cached_val2.size(0) == self.num_layers, ( + cached_val2.size(0), + self.num_layers, + ) + assert cached_conv1.size(0) == self.num_layers, ( + cached_conv1.size(0), + self.num_layers, + ) + assert cached_conv2.size(0) == self.num_layers, ( + cached_conv2.size(0), + self.num_layers, + ) + + assert self.left_context_len == cached_key.shape[1], ( + self.left_context_len, + cached_key.shape[1], + ) + + left_context_len = self.left_context_len + pos_emb = self.encoder_pos(src, left_context_len) + + output = src + + new_cached_len = [] + new_cached_avg = [] + new_cached_key = [] + new_cached_val = [] + new_cached_val2 = [] + new_cached_conv1 = [] + new_cached_conv2 = [] + for i, (mod, state_select) in enumerate( + zip(self.layers, self.state_select_list) + ): + output, len_avg, avg, key, val, val2, conv1, conv2 = mod.streaming_forward( + output, + pos_emb, + cached_len=cached_len[i], + cached_avg=cached_avg[i], + cached_key=cached_key[i], + cached_val=cached_val[i], + cached_val2=cached_val2[i], + cached_conv1=state_select(cached_conv1), + cached_conv2=state_select(cached_conv2), + ) + # Update caches + new_cached_len.append(len_avg) + new_cached_avg.append(avg) + new_cached_key.append(key) + new_cached_val.append(val) + new_cached_val2.append(val2) + new_cached_conv1.append(conv1) + new_cached_conv2.append(conv2) + + return ( + output, + torch.stack(new_cached_len, dim=0), + torch.stack(new_cached_avg, dim=0), + torch.stack(new_cached_key, dim=0), + torch.stack(new_cached_val, dim=0), + torch.stack(new_cached_val2, dim=0), + torch.stack(new_cached_conv1, dim=0), + torch.stack(new_cached_conv2, dim=0), + ) + + +class DownsampledZipformerEncoder(nn.Module): + r""" + DownsampledZipformerEncoder is a zipformer encoder evaluated at a reduced frame rate, + after convolutional downsampling, and then upsampled again at the output, and combined + with the origin input, so that the output has the same shape as the input. + """ + + def __init__( + self, + encoder: nn.Module, + input_dim: int, + output_dim: int, + downsample: int, + is_pnnx: bool = False, + left_context_len: int = 0, + in_x_size: int = 0, + ): + super(DownsampledZipformerEncoder, self).__init__() + self.downsample_factor = downsample + self.downsample = AttentionDownsample( + input_dim, output_dim, downsample, is_pnnx=is_pnnx, in_x_size=in_x_size + ) + self.encoder = encoder + self.num_layers = encoder.num_layers + self.d_model = encoder.d_model + self.attention_dim = encoder.attention_dim + self.cnn_module_kernel = encoder.cnn_module_kernel + self.upsample = SimpleUpsample(output_dim, downsample) + self.out_combiner = SimpleCombiner( + input_dim, output_dim, min_weight=(0.0, 0.25) + ) + self.in_x_size = in_x_size + + def forward( + self, + src: Tensor, + # Note: the type of feature_mask should be Unino[float, Tensor], + # but to make torch.jit.script() happ, we use float here + feature_mask: float = 1.0, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Downsample, go through encoder, upsample. + + Args: + src: the sequence to the encoder (required). + feature_mask: something that broadcasts with src, that we'll multiply `src` + by at every layer. feature_mask is expected to be already downsampled by + self.downsample_factor. + attn_mask: attention mask (optional). Should be downsampled already. + src_key_padding_mask: the mask for the src keys per batch (optional). Should be downsampled already. + + Shape: + src: (S, N, E). + attn_mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number + + Returns: output of shape (S, N, F) where F is the number of output features + (output_dim to constructor) + """ + src_orig = src + src = self.downsample(src) + + src = self.encoder( + src, + feature_mask=feature_mask, + attn_mask=attn_mask, + src_key_padding_mask=src_key_padding_mask, + ) + src = self.upsample(src) + # remove any extra frames that are not a multiple of downsample_factor + src = src[: src_orig.shape[0]] + + return self.out_combiner(src_orig, src) + + def streaming_forward( + self, + src: Tensor, + cached_len: Tensor, + cached_avg: Tensor, + cached_key: Tensor, + cached_val: Tensor, + cached_val2: Tensor, + cached_conv1: Tensor, + cached_conv2: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + r"""Downsample, go through encoder, upsample. + + Args: + src: the sequence to the encoder (required). + cached_avg: cached average value of past frames. + cached_len: length of past frames. + cached_key: cached key tensor for the first attention module. + cached_val: cached value tensor for the first attention module. + cached_val2: cached value tensor for the second attention module. + cached_conv1: cached left context for the first convolution module. + cached_conv2: cached left context for the second convolution module. + + Shape: + src: (S, N, E). + cached_len: (N,) + N is the batch size. + cached_avg: (num_layers, N, C). + N is the batch size, C is the feature dimension. + cached_key: (num_layers, left_context_len, N, K). + N is the batch size, K is the key dimension. + cached_val: (num_layers, left_context_len, N, V). + N is the batch size, V is the key dimension. + cached_val2: (num_layers, left_context_len, N, V). + N is the batch size, V is the key dimension. + cached_conv1: (num_layers, N, C, kernel_size-1). + N is the batch size, C is the convolution channels. + cached_conv2: (num_layers, N, C, kernel_size-1). + N is the batch size, C is the convolution channels. + Returns: output of shape (S, N, F) where F is the number of output features + (output_dim to constructor) + """ + assert src.shape[0] == self.in_x_size, (src.shape[0], self.in_x_size) + + src_orig = src + + src = self.downsample(src) + + ( + src, + cached_len, + cached_avg, + cached_key, + cached_val, + cached_val2, + cached_conv1, + cached_conv2, + ) = self.encoder.streaming_forward( + src, + cached_len=cached_len, + cached_avg=cached_avg, + cached_key=cached_key, + cached_val=cached_val, + cached_val2=cached_val2, + cached_conv1=cached_conv1, + cached_conv2=cached_conv2, + ) + + src = self.upsample(src) + + if src.shape[0] != self.in_x_size: + # remove any extra frames that are not a multiple of downsample_factor + src = src[: self.in_x_size] + + return ( + self.out_combiner(src_orig, src), + cached_len, + cached_avg, + cached_key, + cached_val, + cached_val2, + cached_conv1, + cached_conv2, + ) + + +class AttentionDownsampleUnsqueeze(torch.nn.Module): + """We apply this operation only in PyTorch + and discards in ncnn. + """ + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.unsqueeze(1) + + +class AttentionDownsample(torch.nn.Module): + """ + Does downsampling with attention, by weighted sum, and a projection.. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + downsample: int, + is_pnnx: bool = False, + in_x_size: int = 0, + ): + super(AttentionDownsample, self).__init__() + + self.query = nn.Parameter(torch.randn(in_channels) * (in_channels**-0.5)) + + self.in_channels = in_channels + self.out_channels = out_channels + self.is_pnnx = is_pnnx + self.in_x_size = in_x_size + + self.unsqueeze = AttentionDownsampleUnsqueeze() + + # fill in the extra dimensions with a projection of the input + if out_channels > in_channels: + self.extra_proj = nn.Linear( + in_channels * downsample, out_channels - in_channels, bias=False + ) + else: + self.extra_proj = None + self.downsample = downsample + + self.d_seq_len = (in_x_size + downsample - 1) // downsample + + def forward(self, src: Tensor) -> Tensor: + """ + x: (seq_len, 1, in_channels) + Returns a tensor of shape + ( (seq_len+downsample-1)//downsample, batch_size, out_channels) + """ + assert src.shape[0] == self.in_x_size, ( + src.shape[0], + self.in_x_size, + src.shape, + type(src), + ) + assert src.shape[2] == self.in_channels, (src.shape[2], self.in_channels) + if not self.is_pnnx: + (seq_len, batch_size, in_channels) = src.shape + else: + seq_len = self.in_x_size + batch_size = 1 + in_channels = self.in_channels + + ds = self.downsample + d_seq_len = self.d_seq_len + + # Pad to an exact multiple of self.downsample + if seq_len != d_seq_len * ds: + assert self.is_pnnx is False, "TODO(fangjun): Handle it!" + # right-pad src, repeating the last element. + pad = d_seq_len * ds - seq_len + src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2]) + src = torch.cat((src, src_extra), dim=0) + assert src.shape[0] == d_seq_len * ds, (src.shape[0], d_seq_len, ds) + + if not self.is_pnnx: + src = src.reshape(d_seq_len, ds, batch_size, in_channels) + scores = (src * self.query).sum(dim=-1, keepdim=True) + + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + scores = penalize_abs_values_gt(scores, limit=10.0, penalty=1.0e-04) + + weights = scores.softmax(dim=1) + + # ans1 is the first `in_channels` channels of the output + ans = (src * weights).sum(dim=1) + src = src.permute(0, 2, 1, 3).reshape( + d_seq_len, batch_size, ds * in_channels + ) + + if self.extra_proj is not None: + ans2 = self.extra_proj(src) + ans = torch.cat((ans, ans2), dim=2) + else: + src = src.reshape(d_seq_len, ds, in_channels) + scores = (src * self.query).sum(dim=-1, keepdim=True) + + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + scores = penalize_abs_values_gt(scores, limit=10.0, penalty=1.0e-04) + + weights = scores.softmax(dim=1) + + # ans1 is the first `in_channels` channels of the output + ans = (src * weights).sum(dim=1) + + assert ( + self.extra_proj is None + ), "The code for it being not None is not tested" + # ans = ans.unsqueeze(1) + ans = self.unsqueeze(ans) + # Note: In ncnn, we ignore self.unsqueeze + # so ans in ncnn is still a 2-D tensor, e.g., (8, 384) + + return ans + + +class SimpleUpsample(torch.nn.Module): + """ + A very simple form of upsampling that mostly just repeats the input, but + also adds a position-specific bias. + """ + + def __init__(self, num_channels: int, upsample: int): + super(SimpleUpsample, self).__init__() + self.bias = nn.Parameter(torch.randn(upsample, num_channels) * 0.01) + self.upsample = upsample + self.num_channels = num_channels + + def forward(self, src: Tensor) -> Tensor: + """ + x: (seq_len, batch_size, num_channels) + Returns a tensor of shape + ( (seq_len*upsample), batch_size, num_channels) + """ + upsample = self.bias.shape[0] + (seq_len, batch_size, num_channels) = src.shape + src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels) + src = src + self.bias.unsqueeze(1) + src = src.reshape(seq_len * upsample, batch_size, num_channels) + return src + + +class SimpleCombinerIdentity(nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + + def forward(self, src1: Tensor, src2: Tensor) -> Tensor: + return src1 + + +class SimpleCombiner(torch.nn.Module): + """ + A very simple way of combining 2 vectors of 2 different dims, via a + learned weighted combination in the shared part of the dim. + Args: + dim1: the dimension of the first input, e.g. 256 + dim2: the dimension of the second input, e.g. 384. + The output will have the same dimension as dim2. + """ + + def __init__(self, dim1: int, dim2: int, min_weight: Tuple[float] = (0.0, 0.0)): + super(SimpleCombiner, self).__init__() + assert dim2 >= dim1, (dim2, dim1) + self.weight1 = nn.Parameter(torch.zeros(())) + self.min_weight = min_weight + self.dim1 = dim1 + self.dim2 = dim2 + + def forward(self, src1: Tensor, src2: Tensor) -> Tensor: + """ + src1: (*, dim1) + src2: (*, dim2) + + Returns: a tensor of shape (*, dim2) + """ + assert src1.shape[:-1] == src2.shape[:-1], (src1.shape, src2.shape) + + weight1 = self.weight1 + if not torch.jit.is_scripting(): + if ( + self.training + and random.random() < 0.25 + and self.min_weight != (0.0, 0.0) + ): + weight1 = weight1.clamp( + min=self.min_weight[0], max=1.0 - self.min_weight[1] + ) + + src1 = src1 * weight1 + src2 = src2 * (1.0 - weight1) + + assert src1.shape[-1] == self.dim1, (src1.shape[-1], self.dim1) + assert src2.shape[-1] == self.dim2, (src2.shape[-1], self.dim2) + + src1_dim = self.dim1 + src2_dim = self.dim2 + + if src1_dim != src2_dim: + if src1_dim < src2_dim: + src1 = torch.nn.functional.pad(src1, (0, src2_dim - src1_dim)) + else: + src1 = src1[:src2_dim] + + return src1 + src2 + + +class RelPositionalEncoding(torch.nn.Module): + """Relative positional encoding module. + + See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py + + Args: + d_model: Embedding dimension. + dropout_rate: Dropout rate. + max_len: Maximum input length. + + """ + + def __init__( + self, + d_model: int, + dropout_rate: float, + max_len: int = 5000, + is_pnnx: bool = False, + x_size: int = 0, + left_context_len: int = 0, + ) -> None: + """Construct a PositionalEncoding object.""" + super(RelPositionalEncoding, self).__init__() + self.d_model = d_model + self.dropout = torch.nn.Dropout(dropout_rate) + self.is_pnnx = is_pnnx + self.x_size = x_size + self.left_context_len = left_context_len + self.pe = None + if is_pnnx: + x_size_left = x_size + left_context_len + self.extend_pe(torch.tensor(0.0).expand(x_size_left)) + self.pe = self.pe[:, :-left_context_len] + assert self.pe.size(1) == x_size + left_context_len - 1 + x_size, ( + self.pe.size(1), + x_size, + left_context_len, + x_size, + self.pe.shape, + ) + else: + self.extend_pe(torch.tensor(0.0).expand(max_len)) + + def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None: + """Reset the positional encodings.""" + x_size_left = x.size(0) + left_context_len + if self.pe is not None: + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + if self.pe.size(1) >= x_size_left * 2 - 1: + # Note: TorchScript doesn't implement operator== for torch.Device + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + # Suppose `i` means to the position of query vector and `j` means the + # position of key vector. We use positive relative positions when keys + # are to the left (i>j) and negative relative positions otherwise (i Tensor: + """Add positional encoding. + + Args: + x (torch.Tensor): Input tensor (time, batch, `*`). + left_context_len: (int): Length of cached left context. + + Returns: + torch.Tensor: Encoded tensor (batch, left_context_len + 2*time-1, `*`). + + """ + if self.is_pnnx: + assert self.x_size == x.size(0), (self.x_size, x.size(0)) + assert self.left_context_len == left_context_len, ( + self.left_context_len, + left_context_len, + ) + return self.pe + + self.extend_pe(x, left_context_len) + x_size_left = x.size(0) + left_context_len + pos_emb = self.pe[ + :, + self.pe.size(1) // 2 + - x_size_left + + 1 : self.pe.size(1) // 2 # noqa E203 + + x.size(0), + ] + return self.dropout(pos_emb) + + +class RelPositionMultiheadAttentionPermute(nn.Module): + """ncnn does not support permuatation relating to the batch axis 0. + This is a workaround for exporting to ncnn via PNNX. + """ + + def __init__(self, kind: int): + super().__init__() + self.kind = kind + assert self.kind in (2, 3), self.kind + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.kind == 2: + return x.permute(1, 0, 2) + elif self.kind == 3: + return x.permute(1, 2, 0) + else: + assert False, f"Unsupported kind {self.kind}" + + +class RelPositionMultiheadAttention(nn.Module): + r"""Multi-Head Attention layer with relative position encoding + + This is a quite heavily modified from: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context", + we have to write up the differences. + + + Args: + embed_dim: total dimension of the model. + attention_dim: dimension in the attention module, may be less or more than embed_dim + but must be a multiple of num_heads. + num_heads: parallel attention heads. + dropout: a Dropout layer on attn_output_weights. Default: 0.0. + + Examples:: + + >>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads) + >>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb) + """ + + def __init__( + self, + embed_dim: int, + attention_dim: int, + num_heads: int, + pos_dim: int, + dropout: float = 0.0, + is_pnnx: bool = False, + left_context_len: int = 0, + x_size: int = 0, + ) -> None: + super(RelPositionMultiheadAttention, self).__init__() + self.embed_dim = embed_dim + self.attention_dim = attention_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = attention_dim // num_heads + self.pos_dim = pos_dim + assert self.head_dim % 2 == 0, self.head_dim + assert self.head_dim * num_heads == attention_dim, ( + self.head_dim, + num_heads, + attention_dim, + ) + + self.is_pnnx = is_pnnx + + self.my_permute_pqv = RelPositionMultiheadAttentionPermute(kind=2) + self.my_permute_k_pos = RelPositionMultiheadAttentionPermute(kind=3) + self.left_context_len = left_context_len + self.x_size = x_size + + # the initial_scale is supposed to take over the "scaling" factor of + # head_dim ** -0.5, dividing it between the query and key. + in_proj_dim = ( + 2 * attention_dim + + attention_dim // 2 # query (attention_dim,), key (attention_dim,) + + pos_dim * num_heads # value (attention_dim // 2,) + ) # positional encoding query (pos_dim * num_heads, ) + + self.in_proj = ScaledLinear( + embed_dim, in_proj_dim, bias=True, initial_scale=self.head_dim**-0.25 + ) + + # self.whiten_values is applied on the values in forward(); + # it just copies the keys but prevents low-rank distribution by modifying grads. + self.whiten_values = Whiten( + num_groups=num_heads, + whitening_limit=2.0, + prob=(0.025, 0.25), + grad_scale=0.025, + ) + self.whiten_keys = Whiten( + num_groups=num_heads, + whitening_limit=2.0, + prob=(0.025, 0.25), + grad_scale=0.025, + ) + + # linear transformation for positional encoding. + self.linear_pos = ScaledLinear( + embed_dim, num_heads * pos_dim, bias=False, initial_scale=0.05 + ) + + # the following are for diagnosics only, see --print-diagnostics option. + # they only copy their inputs. + self.copy_pos_query = Identity() + self.copy_query = Identity() + + self.out_proj = ScaledLinear( + attention_dim // 2, embed_dim, bias=True, initial_scale=0.05 + ) + + self.in_proj2 = nn.Linear(embed_dim, attention_dim // 2, bias=False) + self.out_proj2 = ScaledLinear( + attention_dim // 2, embed_dim, bias=True, initial_scale=0.05 + ) + # self.whiten_values2 is applied on the values in forward2() + self.whiten_values2 = Whiten( + num_groups=num_heads, + whitening_limit=2.0, + prob=(0.025, 0.25), + grad_scale=0.025, + ) + + def forward( + self, + x: Tensor, + pos_emb: Tensor, + key_padding_mask: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Tensor]: + r""" + Args: + x: input to be projected to query, key, value + pos_emb: Positional embedding tensor + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. When given a binary mask and a value is True, + the corresponding value on the attention layer will be ignored. When given + a byte mask and a value is non-zero, the corresponding value on the attention + layer will be ignored + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + + Shape: + - Inputs: + - x: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the position + with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + + - Returns: (attn_output, attn_weights) + + - attn_output: :math:`(S, N, E)` where S is the sequence length, N is the batch size, + E is the embedding dimension. + - attn_weights: :math:`(N * N, S, S)` where N is the batch size, H is the num-heads + and S is the sequence length. + """ + x, weights = self.multi_head_attention_forward( + self.in_proj(x), + self.linear_pos(pos_emb), + self.attention_dim, + self.num_heads, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, + attn_mask=attn_mask, + ) + return x, weights + + def streaming_forward( + self, + x: Tensor, + pos_emb: Tensor, + cached_key: Tensor, + cached_val: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + r""" + Args: + x: input to be projected to query, key, value + pos_emb: Positional embedding tensor + + Shape: + - Inputs: + - x: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - cached_key: :math:`(left_context_len, N, K)`, where N is the batch size, K is the key dimension. + - cached_val: :math:`(left_context_len, N, V)`, where N is the batch size, V is the value dimension. + + - Returns: (attn_output, attn_weights, cached_key, cached_val) + + - attn_output: :math:`(S, N, E)` where S is the sequence length, N is the batch size, + E is the embedding dimension. + - attn_weights: :math:`(N * N, S, S)` where N is the batch size, H is the num-heads + and S is the sequence length. + - cached_key: :math:`(left_context_len, N, K)`, updated cached attention key tensor of + left context + - cached_val: :math:`(left_context_len, N, K)`, updated cached attention value tensor of + """ + ( + x, + weights, + cached_key, + cached_val, + ) = self.streaming_multi_head_attention_forward( + self.in_proj(x), + self.linear_pos(pos_emb), + self.attention_dim, + self.num_heads, + self.out_proj.weight, + self.out_proj.bias, + cached_key=cached_key, + cached_val=cached_val, + ) + return x, weights, cached_key, cached_val + + def multi_head_attention_forward( + self, + x_proj: Tensor, + pos: Tensor, + attention_dim: int, + num_heads: int, + dropout_p: float, + out_proj_weight: Tensor, + out_proj_bias: Tensor, + training: bool = True, + key_padding_mask: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Tensor]: + r""" + Args: + x_proj: the projected input, to be split into query, key, value. + pos: head-specific biases arising from the positional embeddings. + attention_dim: dimension inside attention mechanism + num_heads: parallel attention heads. + dropout_p: probability of an element to be zeroed. + out_proj_weight, out_proj_bias: the output projection weight and bias. + training: apply dropout if is ``True``. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. This is an binary mask. When the value is True, + the corresponding value on the attention layer will be filled with -inf. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + + Shape: + Inputs: + - x: :math:`(L, N, 7 * A // 2)` where L is the target sequence length, N is the batch size, A is + the attention dimension. Will be split into (query, key, value, pos). + - pos: :math:`(N, 2*L-1, A//2)` or :math:`(1, 2*L-1, A//2)` where L is the sequence + length, N is the batch size, and A is the attention dim. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions + will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + + Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_weights: :math:`(N * H, S, S)` where N is the batch size, + H is the num-heads, S is the sequence length. + """ + + seq_len, bsz, _ = x_proj.size() + + head_dim = attention_dim // num_heads + pos_dim = self.pos_dim # positional-encoding dim per head + assert ( + head_dim * num_heads == attention_dim + ), f"attention_dim must be divisible by num_heads: {head_dim}, {num_heads}, {attention_dim}" + + # self-attention + q = x_proj[..., 0:attention_dim] + k = x_proj[..., attention_dim : 2 * attention_dim] + value_dim = attention_dim // 2 + v = x_proj[..., 2 * attention_dim : 2 * attention_dim + value_dim] + # p is the position-encoding query, its dimension is num_heads*pos_dim.. + p = x_proj[..., 2 * attention_dim + value_dim :] + + k = self.whiten_keys(k) # does nothing in the forward pass. + v = self.whiten_values(v) # does nothing in the forward pass. + q = self.copy_query(q) # for diagnostics only, does nothing. + p = self.copy_pos_query(p) # for diagnostics only, does nothing. + + if attn_mask is not None: + assert ( + attn_mask.dtype == torch.float32 + or attn_mask.dtype == torch.float64 + or attn_mask.dtype == torch.float16 + or attn_mask.dtype == torch.uint8 + or attn_mask.dtype == torch.bool + ), "Only float, byte, and bool types are supported for attn_mask, not {}".format( + attn_mask.dtype + ) + if attn_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for attn_mask is deprecated. Use bool tensor instead." + ) + attn_mask = attn_mask.to(torch.bool) + + if attn_mask.dim() == 2: + attn_mask = attn_mask.unsqueeze(0) + if list(attn_mask.size()) != [1, seq_len, seq_len]: + raise RuntimeError("The size of the 2D attn_mask is not correct.") + elif attn_mask.dim() == 3: + if list(attn_mask.size()) != [ + bsz * num_heads, + seq_len, + seq_len, + ]: + raise RuntimeError("The size of the 3D attn_mask is not correct.") + else: + raise RuntimeError( + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) + ) + # attn_mask's dim is 3 now. + + # convert ByteTensor key_padding_mask to bool + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." + ) + key_padding_mask = key_padding_mask.to(torch.bool) + + q = q.reshape(seq_len, bsz, num_heads, head_dim) + p = p.reshape(seq_len, bsz, num_heads, pos_dim) + k = k.reshape(seq_len, bsz, num_heads, head_dim) + v = v.reshape(seq_len, bsz * num_heads, head_dim // 2).transpose(0, 1) + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz, "{} == {}".format( + key_padding_mask.size(0), bsz + ) + assert key_padding_mask.size(1) == seq_len, "{} == {}".format( + key_padding_mask.size(1), seq_len + ) + + q = q.permute(1, 2, 0, 3) # (batch, head, time1, head_dim) + p = p.permute(1, 2, 0, 3) # (batch, head, time1, pos_dim) + k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) + + seq_len2 = 2 * seq_len - 1 + pos = pos.reshape(1, seq_len2, num_heads, pos_dim).permute(0, 2, 3, 1) + # pos shape now: (batch, head, pos_dim, seq_len2) + + # (batch, head, time1, pos_dim) x (1, head, pos_dim, seq_len2) -> (batch, head, time1, seq_len2) + # [where seq_len2 represents relative position.] + pos_weights = torch.matmul(p, pos) + # the following .as_strided() expression converts the last axis of pos_weights from relative + # to absolute position. I don't know whether I might have got the time-offsets backwards or + # not, but let this code define which way round it is supposed to be. + pos_weights = pos_weights.as_strided( + (bsz, num_heads, seq_len, seq_len), + ( + pos_weights.stride(0), + pos_weights.stride(1), + pos_weights.stride(2) - pos_weights.stride(3), + pos_weights.stride(3), + ), + storage_offset=pos_weights.stride(3) * (seq_len - 1), + ) + + # caution: they are really scores at this point. + attn_output_weights = torch.matmul(q, k) + pos_weights + + if not torch.jit.is_scripting(): + if training and random.random() < 0.1: + # This is a harder way of limiting the attention scores to not be too large. + # It incurs a penalty if any of them has an absolute value greater than 50.0. + # this should be outside the normal range of the attention scores. We use + # this mechanism instead of, say, a limit on entropy, because once the entropy + # gets very small gradients through the softmax can become very small, and + # some mechanisms like that become ineffective. + attn_output_weights = penalize_abs_values_gt( + attn_output_weights, limit=25.0, penalty=1.0e-04 + ) + + # attn_output_weights: (batch, head, time1, time2) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, seq_len, seq_len + ) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_output_weights = attn_output_weights.masked_fill( + attn_mask, float("-inf") + ) + else: + attn_output_weights = attn_output_weights + attn_mask + + if key_padding_mask is not None: + attn_output_weights = attn_output_weights.view( + bsz, num_heads, seq_len, seq_len + ) + attn_output_weights = attn_output_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + float("-inf"), + ) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, seq_len, seq_len + ) + + # Using this version of softmax, defined in scaling.py, + # should save a little of the memory used in backprop by, if + # we are in automatic mixed precision mode (amp) == autocast, + # only storing the half-precision output for backprop purposes. + attn_output_weights = softmax(attn_output_weights, dim=-1) + + # If we are using chunk-wise attention mask and setting a limited + # num_left_chunks, the attention may only see the padding values which + # will also be masked out by `key_padding_mask`. At this circumstances, + # the whole column of `attn_output_weights` will be `-inf` + # (i.e. be `nan` after softmax). So we fill `0.0` at the masking + # positions to avoid invalid loss value below. + if ( + attn_mask is not None + and attn_mask.dtype == torch.bool + and key_padding_mask is not None + ): + if attn_mask.size(0) != 1: + attn_mask = attn_mask.view(bsz, num_heads, seq_len, seq_len) + combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2) + else: + # attn_mask.shape == (1, tgt_len, src_len) + combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze( + 1 + ).unsqueeze(2) + + attn_output_weights = attn_output_weights.view( + bsz, num_heads, seq_len, seq_len + ) + attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, seq_len, seq_len + ) + + attn_output_weights = nn.functional.dropout( + attn_output_weights, p=dropout_p, training=training + ) + + attn_output = torch.bmm(attn_output_weights, v) + assert list(attn_output.size()) == [bsz * num_heads, seq_len, head_dim // 2] + attn_output = ( + attn_output.transpose(0, 1) + .contiguous() + .view(seq_len, bsz, attention_dim // 2) + ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) + + return attn_output, attn_output_weights + + def streaming_multi_head_attention_forward( + self, + x_proj: Tensor, + pos: Tensor, + attention_dim: int, + num_heads: int, + out_proj_weight: Tensor, + out_proj_bias: Tensor, + cached_key: Tensor, + cached_val: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + r""" + Args: + x_proj: the projected input, to be split into query, key, value. + pos: head-specific biases arising from the positional embeddings. + attention_dim: dimension inside attention mechanism + num_heads: parallel attention heads. + out_proj_weight, out_proj_bias: the output projection weight and bias. + cached_key: cached attention key tensor of left context. + cached_val: cached attention value tensor of left context. + + Shape: + Inputs: + - x: :math:`(L, N, 7 * A // 2)` where L is the target sequence length, N is the batch size, A is + the attention dimension. Will be split into (query, key, value, pos). + - pos: :math:`(N, 2*L-1, A//2)` or :math:`(1, 2*L-1, A//2)` where L is the sequence + length, N is the batch size, and A is the attention dim. + If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions + will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + + Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_weights: :math:`(N * H, S, S)` where N is the batch size, + H is the num-heads, S is the sequence length. + - cached_key: :math:`(left_context_len, N, K)`, updated cached attention key tensor of left context. + - cached_val: :math:`(left_context_len, N, K)`, updated cached attention value tensor of left context. + """ + if not self.is_pnnx: + seq_len, bsz, _ = x_proj.size() + assert seq_len == self.x_size, (seq_len, self.x_size) + else: + seq_len = self.x_size + bsz = 1 + + head_dim = attention_dim // num_heads + pos_dim = self.pos_dim # positional-encoding dim per head + assert ( + head_dim * num_heads == attention_dim + ), f"attention_dim must be divisible by num_heads: {head_dim}, {num_heads}, {attention_dim}" + + # self-attention + q = x_proj[:, :, 0:attention_dim] # (x_size, N, attention_dim) + # return q, q, q, q + k = x_proj[:, :, attention_dim : 2 * attention_dim] + # k is (x_size, N, attention_dim) + value_dim = attention_dim // 2 + v = x_proj[:, :, 2 * attention_dim : 2 * attention_dim + value_dim] + # v is (x_size, 0, attention_dim//2) + + # p is the position-encoding query, its dimension is num_heads*pos_dim.. + p = x_proj[:, :, 2 * attention_dim + value_dim :] + # p is (x_size, N, pos_dim * num_heads) + + if not self.is_pnnx: + left_context_len = cached_key.shape[0] + else: + assert cached_key.shape[0] == self.left_context_len, ( + cached_key.shape, + self.left_context_len, + ) + left_context_len = self.left_context_len + + assert left_context_len > 0, left_context_len + assert cached_key.shape[0] == cached_val.shape[0], ( + cached_key.shape, + cached_val.shape, + ) + # Note: We need to fix the Concat in ncnn + # cached_key is (1, 64, 192) in ncnn + # k is (16, 192) in ncnn + # Pad cached left contexts + k = torch.cat([cached_key, k], dim=0) + # (left_context_len + x_size, N, attention_dim) + + v = torch.cat([cached_val, v], dim=0) + # v: (left_context_len + x_size, N, attention_dim//2) + # Update cached left contexts + if not self.is_pnnx: + cached_key = k[-left_context_len:, ...] + cached_val = v[-left_context_len:, ...] + else: + cached_key = k[self.x_size :] + cached_val = v[self.x_size :] + assert cached_key.shape[0] == left_context_len, ( + cached_key.shape, + left_context_len, + ) + assert cached_val.shape[0] == left_context_len, ( + cached_val.shape, + left_context_len, + ) + + if not self.is_pnnx: + # The length of key and value + kv_len = k.shape[0] + else: + kv_len = left_context_len + self.x_size + assert kv_len == k.shape[0], (kv_len, k.shape) + + if not self.is_pnnx: + q = q.reshape(seq_len, bsz, num_heads, head_dim) + p = p.reshape(seq_len, bsz, num_heads, pos_dim) + k = k.reshape(kv_len, bsz, num_heads, head_dim) + + v = v.reshape(kv_len, bsz * num_heads, head_dim // 2).transpose(0, 1) + # v is (bsz * num_heads, kv_len, head_dim//2) + + q = q.permute(1, 2, 0, 3) # (batch, head, time1, head_dim) + p = p.permute(1, 2, 0, 3) # (batch, head, time1, pos_dim) + k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) + + seq_len2 = 2 * seq_len - 1 + left_context_len + pos = pos.reshape(1, seq_len2, num_heads, pos_dim).permute(0, 2, 3, 1) + # pos shape now: (batch, head, pos_dim, seq_len2) + else: + q = q.reshape(seq_len, num_heads, head_dim) + p = p.reshape(seq_len, num_heads, pos_dim) + k = k.reshape(kv_len, num_heads, head_dim) + # v = v.reshape(kv_len, num_heads, head_dim // 2).permute(1, 0, 2) + v = v.reshape(kv_len, num_heads, head_dim // 2) + v = self.my_permute_pqv(v) + # v is (num_heads, kv_len, head_dim//2) e.g., (8, 80, 12) + + # q = q.permute(1, 0, 2) # (head, time1, head_dim) + # p = p.permute(1, 0, 2) # (head, time1, pos_dim) + # k = k.permute(1, 2, 0) # (head, d_k, time2) + + q = self.my_permute_pqv(q) # (head, time1, head_dim), e.g., (8, 16, 24) + p = self.my_permute_pqv(p) # (head, time1, pos_dim), e.g., (8, 16, 4) + k = self.my_permute_k_pos(k) # (head, d_k, time2) e.g., (8, 24, 80) + + seq_len2 = 2 * seq_len - 1 + left_context_len + # pos = pos.reshape(seq_len2, num_heads, pos_dim).permute(1, 2, 0) + # pos shape now: (head, pos_dim, seq_len2) + + pos = pos.reshape(seq_len2, num_heads, pos_dim) + pos = self.my_permute_k_pos( + pos + ) # (head, pos_dim, seq_len2), e.g, (8, 4, 95) + + # (batch, head, time1, pos_dim) x (1, head, pos_dim, seq_len2) -> (batch, head, time1, seq_len2) ,e.g., (1, 8, 16, 95) + # [where seq_len2 represents relative position.] + pos_weights = torch.matmul(p, pos) + + # the following .as_strided() expression converts the last axis of pos_weights from relative + # to absolute position. I don't know whether I might have got the time-offsets backwards or + # not, but let this code define which way round it is supposed to be. + + if not self.is_pnnx: + pos_weights = pos_weights.as_strided( + (bsz, num_heads, seq_len, kv_len), + ( + pos_weights.stride(0), + pos_weights.stride(1), + pos_weights.stride(2) - pos_weights.stride(3), + pos_weights.stride(3), + ), + storage_offset=pos_weights.stride(3) * (seq_len - 1), + ) + else: + pos_weights = pos_weights.as_strided( + (num_heads, seq_len, kv_len), + ( + pos_weights.stride(0), + pos_weights.stride(1) - pos_weights.stride(2), + pos_weights.stride(2), + ), + storage_offset=pos_weights.stride(2) * (seq_len - 1), + ) + + # caution: they are really scores at this point. + attn_output_weights = torch.matmul(q, k) + pos_weights + + # attn_output_weights: (batch, head, time1, time2) + attn_output_weights = attn_output_weights.view(bsz * num_heads, seq_len, kv_len) + + # Using this version of softmax, defined in scaling.py, + # should save a little of the memory used in backprop by, if + # we are in automatic mixed precision mode (amp) == autocast, + # only storing the half-precision output for backprop purposes. + attn_output_weights = softmax(attn_output_weights, dim=-1) + + attn_output = torch.bmm(attn_output_weights, v) + + assert list(attn_output.size()) == [bsz * num_heads, seq_len, head_dim // 2] + # (8, 16, 12) + + if not self.is_pnnx: + attn_output = ( + attn_output.transpose(0, 1) + .contiguous() + .view(seq_len, bsz, attention_dim // 2) + ) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias + ) + else: + attn_output = self.my_permute_pqv(attn_output) # (1, 0, 2) + attn_output = attn_output.reshape(seq_len, bsz, attention_dim // 2) + # We have changed InnerProduct in ncnn to treat + # (seq_len, bsz, attention_dim//2) as + # (seq_len, attention_dim//2) + + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias + ) + + return attn_output, attn_output_weights, cached_key, cached_val + + def forward2( + self, + x: Tensor, + attn_weights: Tensor, + ) -> Tensor: + """ + Second forward function, where we re-use the attn_weights returned by the first forward function + but with different input. + Args: + x: input, of shape (seq_len, batch_size, embed_dim) + attn_weights: attention weights returned by forward(), of shape (batch_size * num_heads, seq_len, seq_len) + Returns: + output of the same shape as x, i.e. (seq_len, batch_size, embed_dim) + """ + num_heads = self.num_heads + (seq_len, bsz, embed_dim) = x.shape + head_dim = self.attention_dim // num_heads + # v: (tgt_len, bsz, embed_dim // 2) + v = self.in_proj2(x) + v = self.whiten_values2(v) # does nothing in the forward pass. + v = v.reshape(seq_len, bsz * num_heads, head_dim // 2).transpose(0, 1) + + # now v: (bsz * num_heads, seq_len, head_dim // 2) + attn_output = torch.bmm(attn_weights, v) + + if not torch.jit.is_scripting(): + if random.random() < 0.001 or __name__ == "__main__": + self._print_attn_stats(attn_weights, attn_output) + + # attn_output: (bsz * num_heads, seq_len, head_dim) + attn_output = ( + attn_output.transpose(0, 1) + .contiguous() + .view(seq_len, bsz, self.attention_dim // 2) + ) + # returned value is of shape (seq_len, bsz, embed_dim), like x. + return self.out_proj2(attn_output) + + def streaming_forward2( + self, + x: Tensor, + attn_weights: Tensor, + cached_val: Tensor, + ) -> Tuple[Tensor, Tensor]: + """ + Second forward function, where we re-use the attn_weights returned by the first forward function + but with different input. + Args: + x: input, of shape (seq_len, batch_size, embed_dim) + attn_weights: attention weights returned by forward(), of shape (batch_size * num_heads, seq_len, seq_len) + cached_val: cached attention value tensor of left context. + Returns: + - output of the same shape as x, i.e. (seq_len, batch_size, embed_dim) + - updated cached attention value tensor of left context. + """ + num_heads = self.num_heads + + assert x.shape[0] == self.x_size, (x.shape[0], self.x_size) + assert x.shape[2] == self.embed_dim, (x.shape[2], self.embed_dim) + + if not self.is_pnnx: + (seq_len, bsz, embed_dim) = x.shape + else: + seq_len = self.x_size + bsz = 1 + embed_dim = self.embed_dim + + head_dim = self.attention_dim // num_heads + # v: (tgt_len, bsz, embed_dim // 2) + v = self.in_proj2(x) + + assert cached_val.shape[0] == self.left_context_len, ( + cached_val.shape[0], + self.left_context_len, + ) + + left_context_len = self.left_context_len + assert left_context_len > 0, left_context_len + v = torch.cat([cached_val, v], dim=0) + cached_val = v[-left_context_len:] + + seq_len2 = left_context_len + seq_len + if not self.is_pnnx: + v = v.reshape(seq_len2, bsz * num_heads, head_dim // 2).transpose(0, 1) + else: + v = v.reshape(seq_len2, bsz * num_heads, head_dim // 2) + # v = v.permute(1, 0, 2) + v = self.my_permute_pqv(v) + + # now v: (bsz * num_heads, seq_len, head_dim // 2) + attn_output = torch.bmm(attn_weights, v) + + if not self.is_pnnx: + # attn_output: (bsz * num_heads, seq_len, head_dim) + attn_output = ( + attn_output.transpose(0, 1) + .contiguous() + .view(seq_len, bsz, self.attention_dim // 2) + ) + else: + attn_output = self.my_permute_pqv(attn_output) # (1, 0, 2) + attn_output = attn_output.reshape(seq_len, bsz, self.attention_dim // 2) + # We have changed InnerProduct in ncnn to ignore bsz + # when invoking self.out_proj2(attn_output) + + # returned value is of shape (seq_len, bsz, embed_dim), like x. + return self.out_proj2(attn_output), cached_val + + def _print_attn_stats(self, attn_weights: Tensor, attn_output: Tensor): + # attn_weights: (batch_size * num_heads, seq_len, seq_len) + # attn_output: (bsz * num_heads, seq_len, head_dim) + (n, seq_len, head_dim) = attn_output.shape + num_heads = self.num_heads + bsz = n // num_heads + + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=False): + attn_weights = attn_weights.to(torch.float32) + attn_output = attn_output.to(torch.float32) + attn_weights_entropy = ( + -((attn_weights + 1.0e-20).log() * attn_weights) + .sum(dim=-1) + .reshape(bsz, num_heads, seq_len) + .mean(dim=(0, 2)) + ) + attn_output = attn_output.reshape(bsz, num_heads, seq_len, head_dim) + attn_output = attn_output.permute(1, 0, 2, 3).reshape( + num_heads, bsz * seq_len, head_dim + ) + attn_output_mean = attn_output.mean(dim=1, keepdim=True) + attn_output = attn_output - attn_output_mean + attn_covar = torch.matmul(attn_output.transpose(1, 2), attn_output) / ( + bsz * seq_len + ) + # attn_covar: (num_heads, head_dim, head_dim) + # eigs, _ = torch.symeig(attn_covar) + # logging.info(f"attn_weights_entropy = {attn_weights_entropy}, output_eigs = {eigs}") + + attn_covar = _diag(attn_covar).mean(dim=1) # (num_heads,) + embed_dim = self.in_proj2.weight.shape[1] + in_proj_covar = ( + self.in_proj2.weight.reshape(num_heads, head_dim, embed_dim) ** 2 + ).mean(dim=(1, 2)) + out_proj_covar = ( + self.out_proj2.weight.reshape(embed_dim, num_heads, head_dim) ** 2 + ).mean(dim=(0, 2)) + logging.info( + f"attn_weights_entropy = {attn_weights_entropy}, covar={attn_covar}, in_proj_covar={in_proj_covar}, out_proj_covar={out_proj_covar}" + ) + + +class FeedforwardModule(nn.Module): + """Feedforward module in Zipformer model.""" + + def __init__(self, d_model: int, feedforward_dim: int, dropout: float): + super(FeedforwardModule, self).__init__() + self.in_proj = nn.Linear(d_model, feedforward_dim) + self.balancer = ActivationBalancer( + feedforward_dim, channel_dim=-1, max_abs=10.0, min_prob=0.25 + ) + self.activation = DoubleSwish() + self.dropout = nn.Dropout(dropout) + self.out_proj = ScaledLinear(feedforward_dim, d_model, initial_scale=0.01) + + def forward(self, x: Tensor): + x = self.in_proj(x) + x = self.balancer(x) + x = self.activation(x) + x = self.dropout(x) + x = self.out_proj(x) + return x + + +class ConvolutionModule(nn.Module): + """ConvolutionModule in Zipformer model. + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py + + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernerl size of conv layers. + bias (bool): Whether to use bias in conv layers (default=True). + + """ + + def __init__( + self, + channels: int, + kernel_size: int, + bias: bool = True, + is_pnnx: bool = False, + x_size: int = 0, + ) -> None: + """Construct an ConvolutionModule object.""" + super(ConvolutionModule, self).__init__() + # kernerl_size should be a odd number for 'SAME' padding + assert (kernel_size - 1) % 2 == 0, kernel_size + + self.pointwise_conv1 = nn.Conv1d( + channels, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + + # after pointwise_conv1 we put x through a gated linear unit (nn.functional.glu). + # For most layers the normal rms value of channels of x seems to be in the range 1 to 4, + # but sometimes, for some reason, for layer 0 the rms ends up being very large, + # between 50 and 100 for different channels. This will cause very peaky and + # sparse derivatives for the sigmoid gating function, which will tend to make + # the loss function not learn effectively. (for most layers the average absolute values + # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion, + # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different + # layers, which likely breaks down as 0.5 for the "linear" half and + # 0.2 to 0.3 for the part that goes into the sigmoid. The idea is that if we + # constrain the rms values to a reasonable range via a constraint of max_abs=10.0, + # it will be in a better position to start learning something, i.e. to latch onto + # the correct range. + self.deriv_balancer1 = ActivationBalancer( + 2 * channels, + channel_dim=1, + max_abs=10.0, + min_positive=0.05, + max_positive=1.0, + ) + + # Will pad cached left context + self.lorder = kernel_size - 1 + self.depthwise_conv = nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + padding=0, + groups=channels, + bias=bias, + ) + + self.deriv_balancer2 = ActivationBalancer( + channels, + channel_dim=1, + min_positive=0.05, + max_positive=1.0, + max_abs=20.0, + ) + + self.activation = DoubleSwish() + + self.pointwise_conv2 = ScaledConv1d( + channels, + channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + initial_scale=0.05, + ) + + self.is_pnnx = is_pnnx + self.x_size = x_size + + def forward( + self, + x: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """Compute convolution module. + + Args: + x: Input tensor (#time, batch, channels). + src_key_padding_mask: the mask for the src keys per batch (optional): + (batch, #time), contains bool in masked positions. + + Returns: + - Output tensor (#time, batch, channels). + """ + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). + + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channels, time) + + x = self.deriv_balancer1(x) + x = nn.functional.glu(x, dim=1) # (batch, channels, time) + + if src_key_padding_mask is not None: + x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) + + # 1D Depthwise Conv + # Make depthwise_conv causal by + # manualy padding self.lorder zeros to the left + x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) + x = self.depthwise_conv(x) + + x = self.deriv_balancer2(x) + x = self.activation(x) + + x = self.pointwise_conv2(x) # (batch, channel, time) + + return x.permute(2, 0, 1) + + def streaming_forward( + self, + x: Tensor, + cache: Tensor, + ) -> Tuple[Tensor, Tensor]: + """Compute convolution module. + + Args: + x: Input tensor (#time, batch, channels). + src_key_padding_mask: the mask for the src keys per batch: + (batch, #time), contains bool in masked positions. + cache: Cached left context for depthwise_conv, with shape of + (batch, channels, #kernel_size-1). Only used in real streaming decoding. + + Returns: + A tuple of 2 tensors: + - Output tensor (#time, batch, channels). + - New cached left context, with shape of (batch, channels, #kernel_size-1). + """ + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). + + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channels, time) + + x = self.deriv_balancer1(x) + x = nn.functional.glu(x, dim=1) # (batch, channels, time) + + # 1D Depthwise Conv + assert cache.shape == (x.size(0), x.size(1), self.lorder), ( + cache.shape, + (x.size(0), x.size(1), self.lorder), + ) + x = torch.cat([cache, x], dim=2) + + cache = x[:, :, self.x_size :] + + x = self.depthwise_conv(x) + + x = self.deriv_balancer2(x) + x = self.activation(x) + + x = self.pointwise_conv2(x) # (batch, channel, time) + + return x.permute(2, 0, 1), cache + + +class Conv2dSubsampling(nn.Module): + """Convolutional 2D subsampling (to 1/4 length). + + Convert an input of shape (N, T, idim) to an output + with shape (N, T', odim), where + T' = (T-3)//2 - 2 == (T-7)//2 + + It is based on + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + layer1_channels: int = 8, + layer2_channels: int = 32, + layer3_channels: int = 128, + dropout: float = 0.1, + is_pnnx: bool = False, + ) -> None: + """ + Args: + in_channels: + Number of channels in. The input shape is (N, T, in_channels). + Caution: It requires: T >=7, in_channels >=7 + out_channels + Output dim. The output shape is (N, (T-7)//2, out_channels) + layer1_channels: + Number of channels in layer1 + layer2_channels: + Number of channels in layer2 + layer3_channels: + Number of channels in layer3 + is_pnnx: + True if we are converting the model to PNNX format. + False otherwise. + """ + assert in_channels >= 7, in_channels + super().__init__() + + self.conv = nn.Sequential( + nn.Conv2d( + in_channels=1, + out_channels=layer1_channels, + kernel_size=3, + padding=(0, 1), # (time, freq) + ), + # After this layer (N, 1, T, C) -> (N, layer1_channels, T-2, C) + ActivationBalancer(layer1_channels, channel_dim=1), + DoubleSwish(), + nn.Conv2d( + in_channels=layer1_channels, + out_channels=layer2_channels, + kernel_size=3, + stride=2, + padding=0, + ), + # After this layer (N, layer1_channels, T-2, C) -> (N, layer2_channels, ((T-2) - 3)//2+1, (C-3)//2+1) + # i.e., (N, layer2_channels, (T-5)//2+1, (C-3)//2+1) + # i.e., (N, layer2_channels, (T-3)//2, (C-1)//2) + ActivationBalancer(layer2_channels, channel_dim=1), + DoubleSwish(), + nn.Conv2d( + in_channels=layer2_channels, + out_channels=layer3_channels, + kernel_size=3, + stride=(1, 2), # (time, freq) + ), + # After this layer, (N, layer2_channels, (T-3)//2, (C-1)//2) + # -> + # (N, layer3_channels, (T-3)//2-2, ((C-1)//2 - 3)//2 + 1) + # (N, layer3_channels, (T-7)//2, (C-3)//4) + ActivationBalancer(layer3_channels, channel_dim=1), + DoubleSwish(), + ) + out_height = (((in_channels - 1) // 2) - 1) // 2 + self.out = ScaledLinear(out_height * layer3_channels, out_channels) + self.dropout = nn.Dropout(dropout) + + # ncnn supports only batch size == 1 + self.is_pnnx = is_pnnx + self.conv_out_dim = self.out.weight.shape[1] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Subsample x. + + Args: + x: + Its shape is (N, T, idim). + + Returns: + Return a tensor of shape (N, (T-7)//2, odim) + """ + # On entry, x is (N, T, idim) + x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) + x = self.conv(x) + + if torch.jit.is_tracing() and self.is_pnnx: + x = x.permute(0, 2, 1, 3).reshape(1, -1, self.conv_out_dim) + x = self.out(x) + else: + # Now x is of shape (N, odim, (T-7)//2, ((idim-1)//2 - 1)//2) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).reshape(b, t, c * f)) + # Now x is of shape (N, (T-7)//2, odim) + x = self.dropout(x) + return x + + +def _test_zipformer_main(): + feature_dim = 50 + batch_size = 5 + seq_len = 47 + feature_dim = 50 + # Just make sure the forward pass runs. + + c = Zipformer( + num_features=feature_dim, + encoder_dims=(64, 96), + encoder_unmasked_dims=(48, 64), + nhead=(4, 4), + decode_chunk_size=4, + ) + # Just make sure the forward pass runs. + f = c( + torch.randn(batch_size, seq_len, feature_dim), + torch.full((batch_size,), seq_len, dtype=torch.int64), + ) + assert ((seq_len - 7) // 2 + 1) // 2 == f[0].shape[1], (seq_len, f.shape[1]) + f[0].sum().backward() + c.eval() + f = c( + torch.randn(batch_size, seq_len, feature_dim), + torch.full((batch_size,), seq_len, dtype=torch.int64), + ) + f # to remove flake8 warnings + + +def _test_conv2d_subsampling(): + num_features = 80 + encoder_dims = 384 + dropout = 0.1 + encoder_embed = Conv2dSubsampling(num_features, encoder_dims, dropout=dropout) + for i in range(20, 40): + x = torch.rand(2, i, num_features) + y = encoder_embed(x) + assert (x.shape[1] - 7) // 2 == y.shape[1], (x.shape[1], y.shape[1]) + + +def _test_pooling_module(): + N, S, C = 2, 12, 32 + chunk_len = 4 + m = PoolingModule(d_model=C) + + # test chunk-wise forward with padding_mask + x = torch.randn(S, N, C) + y = m(x) + cached_len = torch.zeros(N, dtype=torch.int32) + cached_avg = torch.zeros(N, C) + for i in range(S // chunk_len): + start = i * chunk_len + end = start + chunk_len + x_chunk = x[start:end] + y_chunk, cached_len, cached_avg = m.streaming_forward( + x_chunk, + cached_len=cached_len, + cached_avg=cached_avg, + ) + assert torch.allclose(y_chunk, y[start:end]), (y_chunk, y[start:end]) + + +def _test_state_stack_unstack(): + m = Zipformer( + num_features=80, + encoder_dims=(64, 96), + encoder_unmasked_dims=(48, 64), + nhead=(4, 4), + zipformer_downsampling_factors=(4, 8), + num_left_chunks=2, + decode_chunk_size=8, + ) + s1 = m.get_init_state() + s2 = m.get_init_state() + states = stack_states([s1, s2]) + new_s1, new_s2 = unstack_states(states) + for i in range(m.num_encoders * 7): + for x, y in zip(s1[i], new_s1[i]): + assert torch.equal(x, y) + for x, y in zip(s2[i], new_s2[i]): + assert torch.equal(x, y) + + +if __name__ == "__main__": + logging.getLogger().setLevel(logging.INFO) + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + _test_zipformer_main() + _test_conv2d_subsampling() + _test_pooling_module() + _test_state_stack_unstack() From 4e832fa6b0f3f6fa578b797de5800e06e909b5ce Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Tue, 14 Feb 2023 20:45:38 +0800 Subject: [PATCH 118/174] fix reduction conformer_ctc3/train.py (#908) --- egs/librispeech/ASR/conformer_ctc3/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/conformer_ctc3/train.py b/egs/librispeech/ASR/conformer_ctc3/train.py index ac489af9e..2cd223945 100755 --- a/egs/librispeech/ASR/conformer_ctc3/train.py +++ b/egs/librispeech/ASR/conformer_ctc3/train.py @@ -382,7 +382,7 @@ def get_params() -> AttributeDict: "num_encoder_layers": 12, # parameters for loss "beam_size": 10, - "reduction": "sum", + "reduction": "none", "use_double_scores": True, # parameters for Noam "model_warm_step": 3000, # arg given to model, not for lrate From 6d7a55904c168821fba456fb57d44dc1a801b166 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 16 Feb 2023 19:47:54 +0800 Subject: [PATCH 119/174] export script to ncnn for csj (#912) --- .../export-for-ncnn.py | 369 +++++ .../streaming-ncnn-decode.py | 1 + .../train2.py | 1305 +++++++++++++++++ .../zipformer2.py | 1 + 4 files changed, 1676 insertions(+) create mode 100755 egs/csj/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py create mode 120000 egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py create mode 100755 egs/csj/ASR/pruned_transducer_stateless7_streaming/train2.py create mode 120000 egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer2.py diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py new file mode 100755 index 000000000..ebdb596a5 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py @@ -0,0 +1,369 @@ +#!/usr/bin/env python3 + +""" +Please see +https://k2-fsa.github.io/icefall/model-export/export-ncnn.html +for more details about how to use this file. + +We use +https://huggingface.co/TeoWenShen/icefall-asr-csj-pruned-transducer-stateless7-streaming-230208 + +to demonstrate the usage of this file. + +1. Download the pre-trained model + +cd egs/csj/ASR + +repo_url=https://huggingface.co/TeoWenShen/icefall-asr-csj-pruned-transducer-stateless7-streaming-230208 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "exp_fluent/pretrained.pt" + +cd exp_fluent +ln -s pretrained.pt epoch-99.pt +popd + +2. Export to ncnn + +./pruned_transducer_stateless7_streaming/export-for-ncnn.py \ + --lang $repo/data/lang_char \ + --exp-dir $repo/exp_fluent/ \ + --epoch 99 \ + --avg 1 \ + --use-averaged-model 0 \ + \ + --decode-chunk-len 32 \ + --num-left-chunks 4 \ + --num-encoder-layers "2,4,3,2,4" \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --nhead "8,8,8,8,8" \ + --encoder-dims "384,384,384,384,384" \ + --attention-dims "192,192,192,192,192" \ + --encoder-unmasked-dims "256,256,256,256,256" \ + --zipformer-downsampling-factors "1,2,4,8,2" \ + --cnn-module-kernels "31,31,31,31,31" \ + --decoder-dim 512 \ + --joiner-dim 512 + +cd $repo/exp_fluent + +pnnx encoder_jit_trace-pnnx.pt +pnnx decoder_jit_trace-pnnx.pt +pnnx joiner_jit_trace-pnnx.pt + +You can find converted models at +https://huggingface.co/csukuangfj/sherpa-ncnn-streaming-zipformer-ja-fluent-2023-02-14 + +Please also have a look at +https://huggingface.co/csukuangfj/sherpa-ncnn-streaming-zipformer-ja-fluent-2023-02-14/blob/main/export-for-ncnn-fluent.sh + +See ./streaming-ncnn-decode.py +and +https://github.com/k2-fsa/sherpa-ncnn +for usage. +""" + +import argparse +import logging +from pathlib import Path + +import torch +from scaling_converter import convert_scaled_to_non_scaled +from tokenizer import Tokenizer +from train2 import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + add_model_arguments(parser) + + return parser + + +def export_encoder_model_jit_trace( + encoder_model: torch.nn.Module, + encoder_filename: str, +) -> None: + """Export the given encoder model with torch.jit.trace() + + Note: The warmup argument is fixed to 1. + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported model. + """ + encoder_model.__class__.forward = encoder_model.__class__.streaming_forward + + decode_chunk_len = encoder_model.decode_chunk_size * 2 + pad_length = 7 + T = decode_chunk_len + pad_length # 32 + 7 = 39 + + logging.info(f"decode_chunk_len: {decode_chunk_len}") + logging.info(f"T: {T}") + + x = torch.zeros(1, T, 80, dtype=torch.float32) + states = encoder_model.get_init_state() + + traced_model = torch.jit.trace(encoder_model, (x, states)) + traced_model.save(encoder_filename) + logging.info(f"Saved to {encoder_filename}") + + +def export_decoder_model_jit_trace( + decoder_model: torch.nn.Module, + decoder_filename: str, +) -> None: + """Export the given decoder model with torch.jit.trace() + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The input decoder model + decoder_filename: + The filename to save the exported model. + """ + y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) + need_pad = torch.tensor([False]) + + traced_model = torch.jit.trace(decoder_model, (y, need_pad)) + traced_model.save(decoder_filename) + logging.info(f"Saved to {decoder_filename}") + + +def export_joiner_model_jit_trace( + joiner_model: torch.nn.Module, + joiner_filename: str, +) -> None: + """Export the given joiner model with torch.jit.trace() + + Note: The argument project_input is fixed to True. A user should not + project the encoder_out/decoder_out by himself/herself. The exported joiner + will do that for the user. + + Args: + joiner_model: + The input joiner model + joiner_filename: + The filename to save the exported model. + + """ + encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] + decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + + traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out)) + traced_model.save(joiner_filename) + logging.info(f"Saved to {joiner_filename}") + + +@torch.no_grad() +def main(): + parser = get_parser() + Tokenizer.add_arguments(parser) + args = parser.parse_args() + + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + + setup_logger(f"{params.exp_dir}/log-export/log-export-ncnn") + + logging.info(f"device: {device}") + + sp = Tokenizer.load(args.lang, args.lang_type) + + # is defined in local/prepare_lang_char.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + assert params.blank_id == 0, params.blank_id + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True, is_pnnx=True) + + encoder_num_param = sum([p.numel() for p in model.encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in model.decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in model.joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + logging.info("Using torch.jit.trace()") + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / "encoder_jit_trace-pnnx.pt" + export_encoder_model_jit_trace(model.encoder, encoder_filename) + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / "decoder_jit_trace-pnnx.pt" + export_decoder_model_jit_trace(model.decoder, decoder_filename) + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / "joiner_jit_trace-pnnx.pt" + export_joiner_model_jit_trace(model.joiner, joiner_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + main() diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py new file mode 120000 index 000000000..92c3904af --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/train2.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/train2.py new file mode 100755 index 000000000..d1913d718 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/train2.py @@ -0,0 +1,1305 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --lang data/lang_char \ + --max-duration 300 + +# For mix precision training: + +./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --lang data/lang_char \ + --max-duration 550 +""" + + +import argparse +import copy +import logging +import math +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import CSJAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from optim import Eden, ScaledAdam +from tokenizer import Tokenizer +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer2 import Zipformer + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LOG_EPS = math.log(1e-10) + +try: + from TelegramStreamIO import TelegramStreamIO + + HAS_TELEGRAM = True +except ImportError: + HAS_TELEGRAM = False + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,4,3,2,4", + help="Number of zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--feedforward-dims", + type=str, + default="1024,1024,2048,2048,1024", + help="Feedforward dimension of the zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--nhead", + type=str, + default="8,8,8,8,8", + help="Number of attention heads in the zipformer encoder layers.", + ) + + parser.add_argument( + "--encoder-dims", + type=str, + default="384,384,384,384,384", + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", + ) + + parser.add_argument( + "--attention-dims", + type=str, + default="192,192,192,192,192", + help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; + not the same as embedding dimension.""", + ) + + parser.add_argument( + "--encoder-unmasked-dims", + type=str, + default="256,256,256,256,256", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " + " worse.", + ) + + parser.add_argument( + "--zipformer-downsampling-factors", + type=str, + default="1,2,4,8,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--cnn-module-kernels", + type=str, + default="31,31,31,31,31", + help="Sizes of kernels in convolution modules", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--short-chunk-size", + type=int, + default=50, + help="""Chunk length of dynamic training, the chunk size would be either + max sequence length of current batch or uniformly sampled from (1, short_chunk_size). + """, + ) + + parser.add_argument( + "--num-left-chunks", + type=int, + default=4, + help="How many left context can be seen in chunks when calculating attention.", + ) + + parser.add_argument( + "--decode-chunk-len", + type=int, + default=32, + help="The chunk size for decoding (in frames before subsampling)", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument("--debug", action="store_true", help="Use hardcoded arguments") + + parser.add_argument( + "--telegram-cred", + type=Path, + default=None, + help="Telegram credentials to report progress in telegram", + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=Path, + default="pruned_transducer_stateless7_streaming/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--base-lr", type=float, default=0.05, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=2000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--pad-feature", + type=int, + default=0, + help=""" + Number of frames to pad at the end. + """, + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 1000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Zipformer and Transformer + def to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + encoder = Zipformer( + num_features=params.feature_dim, + output_downsampling_factor=2, + zipformer_downsampling_factors=to_int_tuple( + params.zipformer_downsampling_factors + ), + encoder_dims=to_int_tuple(params.encoder_dims), + attention_dim=to_int_tuple(params.attention_dims), + encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), + nhead=to_int_tuple(params.nhead), + feedforward_dim=to_int_tuple(params.feedforward_dims), + cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), + num_encoder_layers=to_int_tuple(params.num_encoder_layers), + num_left_chunks=params.num_left_chunks, + short_chunk_size=params.short_chunk_size, + decode_chunk_size=params.decode_chunk_len // 2, + is_pnnx=True, + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: Tokenizer, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.pad_feature: + feature_lens += params.pad_feature + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.pad_feature), + value=LOG_EPS, + ) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: Tokenizer, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: Tokenizer, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except Exception as e: # noqa + logging.error(e, exc_info=True) + display_and_save_batch(batch, params=params, sp=sp) + raise e + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + if HAS_TELEGRAM and batch_idx in [0, 500] and not rank: + logging.warning( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + else: + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + if ( + HAS_TELEGRAM + and batch_idx % (params.valid_interval * 3) == 0 + and not rank + ): + log_mode = logging.warning + else: + log_mode = logging.info + log_mode(f"Epoch {params.cur_epoch}, validation: {valid_info}") + log_mode( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, master_port=params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + if HAS_TELEGRAM and params.telegram_cred: + TelegramStreamIO.setup_logger(params) + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = Tokenizer.load(args.lang, args.lang_type) + + # is defined in local/prepare_lang_char.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + parameters_names = [] + parameters_names.append( + [name_param_pair[0] for name_param_pair in model.named_parameters()] + ) + optimizer = ScaledAdam( + model.parameters(), + lr=params.base_lr, + clipping_scale=2.0, + parameters_names=parameters_names, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 0.3 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.info( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + csj_corpus = CSJAsrDataModule(args) + train_cuts = csj_corpus.train_cuts() + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = csj_corpus.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = csj_corpus.valid_cuts() + valid_dl = csj_corpus.valid_dataloaders(valid_cuts) + + if params.start_batch <= 0 and not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: Tokenizer, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: Tokenizer, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + CSJAsrDataModule.add_arguments(parser) + Tokenizer.add_arguments(parser) + args = parser.parse_args() + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer2.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer2.py new file mode 120000 index 000000000..12dbda888 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer2.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py \ No newline at end of file From c01175679e5a92c18f5dc4014c83fbe0f1c09fbe Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 16 Feb 2023 21:09:05 +0800 Subject: [PATCH 120/174] Add CI test for exporting csj pretrained zipformer to ncnn (#913) --- .github/scripts/test-ncnn-export.sh | 95 +++++++++++++++++++++++++---- 1 file changed, 82 insertions(+), 13 deletions(-) diff --git a/.github/scripts/test-ncnn-export.sh b/.github/scripts/test-ncnn-export.sh index 9dd7736c0..9f5df2d58 100755 --- a/.github/scripts/test-ncnn-export.sh +++ b/.github/scripts/test-ncnn-export.sh @@ -8,7 +8,7 @@ log() { echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" } -cd egs/librispeech/ASR +pushd egs/librispeech/ASR log "Install ncnn and pnnx" @@ -37,6 +37,8 @@ make -j4 pnnx popd +export PATH=$PWD/ncnn/tools/pnnx/build/src:$PATH + log "==========================================================================" repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05 GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url @@ -66,9 +68,9 @@ log "Export via torch.jit.trace()" --right-context-length 8 \ --memory-size 32 -./ncnn/tools/pnnx/build/src/pnnx $repo/exp/encoder_jit_trace-pnnx.pt -./ncnn/tools/pnnx/build/src/pnnx $repo/exp/decoder_jit_trace-pnnx.pt -./ncnn/tools/pnnx/build/src/pnnx $repo/exp/joiner_jit_trace-pnnx.pt +pnnx $repo/exp/encoder_jit_trace-pnnx.pt +pnnx $repo/exp/decoder_jit_trace-pnnx.pt +pnnx $repo/exp/joiner_jit_trace-pnnx.pt python3 ./conv_emformer_transducer_stateless2/streaming-ncnn-decode.py \ --tokens $repo/data/lang_bpe_500/tokens.txt \ @@ -105,9 +107,9 @@ log "Export via torch.jit.trace()" --avg 1 \ --use-averaged-model 0 -./ncnn/tools/pnnx/build/src/pnnx $repo/exp/encoder_jit_trace-pnnx.pt -./ncnn/tools/pnnx/build/src/pnnx $repo/exp/decoder_jit_trace-pnnx.pt -./ncnn/tools/pnnx/build/src/pnnx $repo/exp/joiner_jit_trace-pnnx.pt +pnnx $repo/exp/encoder_jit_trace-pnnx.pt +pnnx $repo/exp/decoder_jit_trace-pnnx.pt +pnnx $repo/exp/joiner_jit_trace-pnnx.pt python3 ./lstm_transducer_stateless2/streaming-ncnn-decode.py \ --tokens $repo/data/lang_bpe_500/tokens.txt \ @@ -164,9 +166,9 @@ popd --decoder-dim 512 \ --joiner-dim 512 -./ncnn/tools/pnnx/build/src/pnnx $repo/exp/encoder_jit_trace-pnnx.pt -./ncnn/tools/pnnx/build/src/pnnx $repo/exp/decoder_jit_trace-pnnx.pt -./ncnn/tools/pnnx/build/src/pnnx $repo/exp/joiner_jit_trace-pnnx.pt +pnnx $repo/exp/encoder_jit_trace-pnnx.pt +pnnx $repo/exp/decoder_jit_trace-pnnx.pt +pnnx $repo/exp/joiner_jit_trace-pnnx.pt python3 ./pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py \ --tokens $repo/data/lang_bpe_500/tokens.txt \ @@ -214,9 +216,9 @@ popd --decoder-dim 512 \ --joiner-dim 512 -./ncnn/tools/pnnx/build/src/pnnx $repo/exp/encoder_jit_trace-pnnx.pt -./ncnn/tools/pnnx/build/src/pnnx $repo/exp/decoder_jit_trace-pnnx.pt -./ncnn/tools/pnnx/build/src/pnnx $repo/exp/joiner_jit_trace-pnnx.pt +pnnx $repo/exp/encoder_jit_trace-pnnx.pt +pnnx $repo/exp/decoder_jit_trace-pnnx.pt +pnnx $repo/exp/joiner_jit_trace-pnnx.pt python3 ./pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py \ --tokens $repo/data/lang_char_bpe/tokens.txt \ @@ -230,3 +232,70 @@ python3 ./pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py \ rm -rf $repo log "--------------------------------------------------------------------------" + +# Go back to the root directory of icefall repo +popd + +pushd egs/csj/ASR + +log "==========================================================================" +repo_url=https://huggingface.co/TeoWenShen/icefall-asr-csj-pruned-transducer-stateless7-streaming-230208 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "exp_fluent/pretrained.pt" +git lfs pull --include "exp_disfluent/pretrained.pt" + +cd exp_fluent +ln -s pretrained.pt epoch-99.pt + +cd ../exp_disfluent +ln -s pretrained.pt epoch-99.pt + +cd ../test_wavs +git lfs pull --include "*.wav" +popd + +log "Export via torch.jit.trace()" + +for exp in exp_fluent exp_disfluent; do + ./pruned_transducer_stateless7_streaming/export-for-ncnn.py \ + --exp-dir $repo/$exp/ \ + --lang $repo/data/lang_char \ + --epoch 99 \ + --avg 1 \ + --use-averaged-model 0 \ + \ + --decode-chunk-len 32 \ + --num-left-chunks 4 \ + --num-encoder-layers "2,4,3,2,4" \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --nhead "8,8,8,8,8" \ + --encoder-dims "384,384,384,384,384" \ + --attention-dims "192,192,192,192,192" \ + --encoder-unmasked-dims "256,256,256,256,256" \ + --zipformer-downsampling-factors "1,2,4,8,2" \ + --cnn-module-kernels "31,31,31,31,31" \ + --decoder-dim 512 \ + --joiner-dim 512 + + pnnx $repo/$exp/encoder_jit_trace-pnnx.pt + pnnx $repo/$exp/decoder_jit_trace-pnnx.pt + pnnx $repo/$exp/joiner_jit_trace-pnnx.pt + + for wav in aps-smp.wav interview_aps-smp.wav reproduction-smp.wav sps-smp.wav; do + python3 ./pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py \ + --tokens $repo/data/lang_char/tokens.txt \ + --encoder-param-filename $repo/$exp/encoder_jit_trace-pnnx.ncnn.param \ + --encoder-bin-filename $repo/$exp/encoder_jit_trace-pnnx.ncnn.bin \ + --decoder-param-filename $repo/$exp/decoder_jit_trace-pnnx.ncnn.param \ + --decoder-bin-filename $repo/$exp/decoder_jit_trace-pnnx.ncnn.bin \ + --joiner-param-filename $repo/$exp/joiner_jit_trace-pnnx.ncnn.param \ + --joiner-bin-filename $repo/$exp/joiner_jit_trace-pnnx.ncnn.bin \ + $repo/test_wavs/$wav + done +done + +rm -rf $repo +log "--------------------------------------------------------------------------" From 52d7cdd1a60908fd290f51c6217f58f40a97385d Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 17 Feb 2023 12:50:13 +0800 Subject: [PATCH 121/174] Update doc about exporting LSTM models to ncnn (#914) --- ...export-lstm-transducer-for-ncnn-output.txt | 18 + .../generate-int-8-scale-table-for-lstm.txt | 44 + ...decode-conv-emformer-transducer-libri.txt} | 0 ...ming-ncnn-decode-lstm-transducer-libri.txt | 6 + .../export-ncnn-conv-emformer.rst | 749 +++++++++++++++++ docs/source/model-export/export-ncnn-lstm.rst | 644 +++++++++++++++ docs/source/model-export/export-ncnn.rst | 780 +----------------- docs/source/model-export/export-onnx.rst | 2 +- .../lstm_pruned_stateless_transducer.rst | 126 --- 9 files changed, 1484 insertions(+), 885 deletions(-) create mode 100644 docs/source/model-export/code/export-lstm-transducer-for-ncnn-output.txt create mode 100644 docs/source/model-export/code/generate-int-8-scale-table-for-lstm.txt rename docs/source/model-export/code/{test-stremaing-ncnn-decode-conv-emformer-transducer-libri.txt => test-streaming-ncnn-decode-conv-emformer-transducer-libri.txt} (100%) create mode 100644 docs/source/model-export/code/test-streaming-ncnn-decode-lstm-transducer-libri.txt create mode 100644 docs/source/model-export/export-ncnn-conv-emformer.rst create mode 100644 docs/source/model-export/export-ncnn-lstm.rst diff --git a/docs/source/model-export/code/export-lstm-transducer-for-ncnn-output.txt b/docs/source/model-export/code/export-lstm-transducer-for-ncnn-output.txt new file mode 100644 index 000000000..fe4460985 --- /dev/null +++ b/docs/source/model-export/code/export-lstm-transducer-for-ncnn-output.txt @@ -0,0 +1,18 @@ +2023-02-17 11:22:42,862 INFO [export-for-ncnn.py:222] device: cpu +2023-02-17 11:22:42,865 INFO [export-for-ncnn.py:231] {'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, 'best_valid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 50, 'reset_interval': 200, 'valid_interval': 3000, 'feature_dim': 80, 'subsampling_factor': 4, 'dim_feedforward': 2048, 'decoder_dim': 512, 'joiner_dim': 512, 'is_pnnx': False, 'model_warm_step': 3000, 'env_info': {'k2-version': '1.23.4', 'k2-build-type': 'Release', 'k2-with-cuda': True, 'k2-git-sha1': '62e404dd3f3a811d73e424199b3408e309c06e1a', 'k2-git-date': 'Mon Jan 30 10:26:16 2023', 'lhotse-version': '1.12.0.dev+missing.version.file', 'torch-version': '1.10.0+cu102', 'torch-cuda-available': False, 'torch-cuda-version': '10.2', 'python-version': '3.8', 'icefall-git-branch': 'master', 'icefall-git-sha1': '6d7a559-dirty', 'icefall-git-date': 'Thu Feb 16 19:47:54 2023', 'icefall-path': '/star-fj/fangjun/open-source/icefall-2', 'k2-path': '/star-fj/fangjun/open-source/k2/k2/python/k2/__init__.py', 'lhotse-path': '/star-fj/fangjun/open-source/lhotse/lhotse/__init__.py', 'hostname': 'de-74279-k2-train-3-1220120619-7695ff496b-s9n4w', 'IP address': '10.177.6.147'}, 'epoch': 99, 'iter': 0, 'avg': 1, 'exp_dir': PosixPath('icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp'), 'bpe_model': './icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/data/lang_bpe_500/bpe.model', 'context_size': 2, 'use_averaged_model': False, 'num_encoder_layers': 12, 'encoder_dim': 512, 'rnn_hidden_size': 1024, 'aux_layer_period': 0, 'blank_id': 0, 'vocab_size': 500} +2023-02-17 11:22:42,865 INFO [export-for-ncnn.py:235] About to create model +2023-02-17 11:22:43,239 INFO [train.py:472] Disable giga +2023-02-17 11:22:43,249 INFO [checkpoint.py:112] Loading checkpoint from icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/epoch-99.pt +2023-02-17 11:22:44,595 INFO [export-for-ncnn.py:324] encoder parameters: 83137520 +2023-02-17 11:22:44,596 INFO [export-for-ncnn.py:325] decoder parameters: 257024 +2023-02-17 11:22:44,596 INFO [export-for-ncnn.py:326] joiner parameters: 781812 +2023-02-17 11:22:44,596 INFO [export-for-ncnn.py:327] total parameters: 84176356 +2023-02-17 11:22:44,596 INFO [export-for-ncnn.py:329] Using torch.jit.trace() +2023-02-17 11:22:44,596 INFO [export-for-ncnn.py:331] Exporting encoder +2023-02-17 11:22:48,182 INFO [export-for-ncnn.py:158] Saved to icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/encoder_jit_trace-pnnx.pt +2023-02-17 11:22:48,183 INFO [export-for-ncnn.py:335] Exporting decoder +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/lstm_transducer_stateless2/decoder.py:101: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + need_pad = bool(need_pad) +2023-02-17 11:22:48,259 INFO [export-for-ncnn.py:180] Saved to icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/decoder_jit_trace-pnnx.pt +2023-02-17 11:22:48,259 INFO [export-for-ncnn.py:339] Exporting joiner +2023-02-17 11:22:48,304 INFO [export-for-ncnn.py:207] Saved to icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/joiner_jit_trace-pnnx.pt diff --git a/docs/source/model-export/code/generate-int-8-scale-table-for-lstm.txt b/docs/source/model-export/code/generate-int-8-scale-table-for-lstm.txt new file mode 100644 index 000000000..d39215b14 --- /dev/null +++ b/docs/source/model-export/code/generate-int-8-scale-table-for-lstm.txt @@ -0,0 +1,44 @@ +Don't Use GPU. has_gpu: 0, config.use_vulkan_compute: 1 +num encoder conv layers: 28 +num joiner conv layers: 3 +num files: 3 +Processing ../test_wavs/1089-134686-0001.wav +Processing ../test_wavs/1221-135766-0001.wav +Processing ../test_wavs/1221-135766-0002.wav +Processing ../test_wavs/1089-134686-0001.wav +Processing ../test_wavs/1221-135766-0001.wav +Processing ../test_wavs/1221-135766-0002.wav +----------encoder---------- +conv_15 : max = 15.942385 threshold = 15.930708 scale = 7.972025 +conv_16 : max = 44.978855 threshold = 17.031788 scale = 7.456645 +conv_17 : max = 17.868437 threshold = 7.830528 scale = 16.218575 +linear_18 : max = 3.107259 threshold = 1.194808 scale = 106.293236 +linear_19 : max = 6.193777 threshold = 4.634748 scale = 27.401705 +linear_20 : max = 9.259933 threshold = 2.606617 scale = 48.722160 +linear_21 : max = 5.186600 threshold = 4.790260 scale = 26.512129 +linear_22 : max = 9.759041 threshold = 2.265832 scale = 56.050053 +linear_23 : max = 3.931209 threshold = 3.099090 scale = 40.979767 +linear_24 : max = 10.324160 threshold = 2.215561 scale = 57.321835 +linear_25 : max = 3.800708 threshold = 3.599352 scale = 35.284134 +linear_26 : max = 10.492444 threshold = 3.153369 scale = 40.274391 +linear_27 : max = 3.660161 threshold = 2.720994 scale = 46.674126 +linear_28 : max = 9.415265 threshold = 3.174434 scale = 40.007133 +linear_29 : max = 4.038418 threshold = 3.118534 scale = 40.724262 +linear_30 : max = 10.072084 threshold = 3.936867 scale = 32.259155 +linear_31 : max = 4.342712 threshold = 3.599489 scale = 35.282787 +linear_32 : max = 11.340535 threshold = 3.120308 scale = 40.701103 +linear_33 : max = 3.846987 threshold = 3.630030 scale = 34.985939 +linear_34 : max = 10.686298 threshold = 2.204571 scale = 57.607586 +linear_35 : max = 4.904821 threshold = 4.575518 scale = 27.756420 +linear_36 : max = 11.806659 threshold = 2.585589 scale = 49.118401 +linear_37 : max = 6.402340 threshold = 5.047157 scale = 25.162680 +linear_38 : max = 11.174589 threshold = 1.923361 scale = 66.030258 +linear_39 : max = 16.178576 threshold = 7.556058 scale = 16.807705 +linear_40 : max = 12.901954 threshold = 5.301267 scale = 23.956539 +linear_41 : max = 14.839805 threshold = 7.597429 scale = 16.716181 +linear_42 : max = 10.178945 threshold = 2.651595 scale = 47.895699 +----------joiner---------- +linear_2 : max = 24.829245 threshold = 16.627592 scale = 7.637907 +linear_1 : max = 10.746186 threshold = 5.255032 scale = 24.167313 +linear_3 : max = 1.000000 threshold = 0.999756 scale = 127.031013 +ncnn int8 calibration table create success, best wish for your int8 inference has a low accuracy loss...\(^0^)/...233... diff --git a/docs/source/model-export/code/test-stremaing-ncnn-decode-conv-emformer-transducer-libri.txt b/docs/source/model-export/code/test-streaming-ncnn-decode-conv-emformer-transducer-libri.txt similarity index 100% rename from docs/source/model-export/code/test-stremaing-ncnn-decode-conv-emformer-transducer-libri.txt rename to docs/source/model-export/code/test-streaming-ncnn-decode-conv-emformer-transducer-libri.txt diff --git a/docs/source/model-export/code/test-streaming-ncnn-decode-lstm-transducer-libri.txt b/docs/source/model-export/code/test-streaming-ncnn-decode-lstm-transducer-libri.txt new file mode 100644 index 000000000..3606eae3d --- /dev/null +++ b/docs/source/model-export/code/test-streaming-ncnn-decode-lstm-transducer-libri.txt @@ -0,0 +1,6 @@ +2023-02-17 11:37:30,861 INFO [streaming-ncnn-decode.py:255] {'tokens': './icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/data/lang_bpe_500/tokens.txt', 'encoder_param_filename': './icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/encoder_jit_trace-pnnx.ncnn.param', 'encoder_bin_filename': './icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/encoder_jit_trace-pnnx.ncnn.bin', 'decoder_param_filename': './icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/decoder_jit_trace-pnnx.ncnn.param', 'decoder_bin_filename': './icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/decoder_jit_trace-pnnx.ncnn.bin', 'joiner_param_filename': './icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/joiner_jit_trace-pnnx.ncnn.param', 'joiner_bin_filename': './icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/joiner_jit_trace-pnnx.ncnn.bin', 'sound_filename': './icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/test_wavs/1089-134686-0001.wav'} +2023-02-17 11:37:31,425 INFO [streaming-ncnn-decode.py:263] Constructing Fbank computer +2023-02-17 11:37:31,427 INFO [streaming-ncnn-decode.py:266] Reading sound files: ./icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/test_wavs/1089-134686-0001.wav +2023-02-17 11:37:31,431 INFO [streaming-ncnn-decode.py:271] torch.Size([106000]) +2023-02-17 11:37:34,115 INFO [streaming-ncnn-decode.py:342] ./icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/test_wavs/1089-134686-0001.wav +2023-02-17 11:37:34,115 INFO [streaming-ncnn-decode.py:343] AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS diff --git a/docs/source/model-export/export-ncnn-conv-emformer.rst b/docs/source/model-export/export-ncnn-conv-emformer.rst new file mode 100644 index 000000000..d19c7dac8 --- /dev/null +++ b/docs/source/model-export/export-ncnn-conv-emformer.rst @@ -0,0 +1,749 @@ +.. _export_conv_emformer_transducer_models_to_ncnn: + +Export ConvEmformer transducer models to ncnn +============================================= + +We use the pre-trained model from the following repository as an example: + + - ``_ + +We will show you step by step how to export it to `ncnn`_ and run it with `sherpa-ncnn`_. + +.. hint:: + + We use ``Ubuntu 18.04``, ``torch 1.13``, and ``Python 3.8`` for testing. + +.. caution:: + + Please use a more recent version of PyTorch. For instance, ``torch 1.8`` + may ``not`` work. + +1. Download the pre-trained model +--------------------------------- + +.. hint:: + + You can also refer to ``_ to download the pre-trained model. + + You have to install `git-lfs`_ before you continue. + +.. code-block:: bash + + cd egs/librispeech/ASR + + GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05 + cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05 + + git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt" + git lfs pull --include "data/lang_bpe_500/bpe.model" + + cd .. + +.. note:: + + We downloaded ``exp/pretrained-xxx.pt``, not ``exp/cpu-jit_xxx.pt``. + + +In the above code, we downloaded the pre-trained model into the directory +``egs/librispeech/ASR/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05``. + +.. _export_for_ncnn_install_ncnn_and_pnnx: + +2. Install ncnn and pnnx +------------------------ + +.. code-block:: bash + + # We put ncnn into $HOME/open-source/ncnn + # You can change it to anywhere you like + + cd $HOME + mkdir -p open-source + cd open-source + + git clone https://github.com/csukuangfj/ncnn + cd ncnn + git submodule update --recursive --init + + # Note: We don't use "python setup.py install" or "pip install ." here + + mkdir -p build-wheel + cd build-wheel + + cmake \ + -DCMAKE_BUILD_TYPE=Release \ + -DNCNN_PYTHON=ON \ + -DNCNN_BUILD_BENCHMARK=OFF \ + -DNCNN_BUILD_EXAMPLES=OFF \ + -DNCNN_BUILD_TOOLS=ON \ + .. + + make -j4 + + cd .. + + # Note: $PWD here is $HOME/open-source/ncnn + + export PYTHONPATH=$PWD/python:$PYTHONPATH + export PATH=$PWD/tools/pnnx/build/src:$PATH + export PATH=$PWD/build-wheel/tools/quantize:$PATH + + # Now build pnnx + cd tools/pnnx + mkdir build + cd build + cmake .. + make -j4 + + ./src/pnnx + +Congratulations! You have successfully installed the following components: + + - ``pnxx``, which is an executable located in + ``$HOME/open-source/ncnn/tools/pnnx/build/src``. We will use + it to convert models exported by ``torch.jit.trace()``. + - ``ncnn2int8``, which is an executable located in + ``$HOME/open-source/ncnn/build-wheel/tools/quantize``. We will use + it to quantize our models to ``int8``. + - ``ncnn.cpython-38-x86_64-linux-gnu.so``, which is a Python module located + in ``$HOME/open-source/ncnn/python/ncnn``. + + .. note:: + + I am using ``Python 3.8``, so it + is ``ncnn.cpython-38-x86_64-linux-gnu.so``. If you use a different + version, say, ``Python 3.9``, the name would be + ``ncnn.cpython-39-x86_64-linux-gnu.so``. + + Also, if you are not using Linux, the file name would also be different. + But that does not matter. As long as you can compile it, it should work. + +We have set up ``PYTHONPATH`` so that you can use ``import ncnn`` in your +Python code. We have also set up ``PATH`` so that you can use +``pnnx`` and ``ncnn2int8`` later in your terminal. + +.. caution:: + + Please don't use ``_. + We have made some modifications to the offical `ncnn`_. + + We will synchronize ``_ periodically + with the official one. + +3. Export the model via torch.jit.trace() +----------------------------------------- + +First, let us rename our pre-trained model: + +.. code-block:: + + cd egs/librispeech/ASR + + cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp + + ln -s pretrained-epoch-30-avg-10-averaged.pt epoch-30.pt + + cd ../.. + +Next, we use the following code to export our model: + +.. code-block:: bash + + dir=./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/ + + ./conv_emformer_transducer_stateless2/export-for-ncnn.py \ + --exp-dir $dir/exp \ + --bpe-model $dir/data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 1 \ + --use-averaged-model 0 \ + \ + --num-encoder-layers 12 \ + --chunk-length 32 \ + --cnn-module-kernel 31 \ + --left-context-length 32 \ + --right-context-length 8 \ + --memory-size 32 \ + --encoder-dim 512 + +.. hint:: + + We have renamed our model to ``epoch-30.pt`` so that we can use ``--epoch 30``. + There is only one pre-trained model, so we use ``--avg 1 --use-averaged-model 0``. + + If you have trained a model by yourself and if you have all checkpoints + available, please first use ``decode.py`` to tune ``--epoch --avg`` + and select the best combination with with ``--use-averaged-model 1``. + +.. note:: + + You will see the following log output: + + .. literalinclude:: ./code/export-conv-emformer-transducer-for-ncnn-output.txt + + The log shows the model has ``75490012`` parameters, i.e., ``~75 M``. + + .. code-block:: + + ls -lh icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/pretrained-epoch-30-avg-10-averaged.pt + + -rw-r--r-- 1 kuangfangjun root 289M Jan 11 12:05 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/pretrained-epoch-30-avg-10-averaged.pt + + You can see that the file size of the pre-trained model is ``289 MB``, which + is roughly equal to ``75490012*4/1024/1024 = 287.97 MB``. + +After running ``conv_emformer_transducer_stateless2/export-for-ncnn.py``, +we will get the following files: + +.. code-block:: bash + + ls -lh icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/*pnnx* + + -rw-r--r-- 1 kuangfangjun root 1010K Jan 11 12:15 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.pt + -rw-r--r-- 1 kuangfangjun root 283M Jan 11 12:15 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.pt + -rw-r--r-- 1 kuangfangjun root 3.0M Jan 11 12:15 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.pt + + +.. _conv-emformer-step-4-export-torchscript-model-via-pnnx: + +4. Export torchscript model via pnnx +------------------------------------ + +.. hint:: + + Make sure you have set up the ``PATH`` environment variable. Otherwise, + it will throw an error saying that ``pnnx`` could not be found. + +Now, it's time to export our models to `ncnn`_ via ``pnnx``. + +.. code-block:: + + cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/ + + pnnx ./encoder_jit_trace-pnnx.pt + pnnx ./decoder_jit_trace-pnnx.pt + pnnx ./joiner_jit_trace-pnnx.pt + +It will generate the following files: + +.. code-block:: bash + + ls -lh icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/*ncnn*{bin,param} + + -rw-r--r-- 1 kuangfangjun root 503K Jan 11 12:38 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.ncnn.bin + -rw-r--r-- 1 kuangfangjun root 437 Jan 11 12:38 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.ncnn.param + -rw-r--r-- 1 kuangfangjun root 142M Jan 11 12:36 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.ncnn.bin + -rw-r--r-- 1 kuangfangjun root 79K Jan 11 12:36 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.ncnn.param + -rw-r--r-- 1 kuangfangjun root 1.5M Jan 11 12:38 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.ncnn.bin + -rw-r--r-- 1 kuangfangjun root 488 Jan 11 12:38 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.ncnn.param + +There are two types of files: + +- ``param``: It is a text file containing the model architectures. You can + use a text editor to view its content. +- ``bin``: It is a binary file containing the model parameters. + +We compare the file sizes of the models below before and after converting via ``pnnx``: + +.. see https://tableconvert.com/restructuredtext-generator + ++----------------------------------+------------+ +| File name | File size | ++==================================+============+ +| encoder_jit_trace-pnnx.pt | 283 MB | ++----------------------------------+------------+ +| decoder_jit_trace-pnnx.pt | 1010 KB | ++----------------------------------+------------+ +| joiner_jit_trace-pnnx.pt | 3.0 MB | ++----------------------------------+------------+ +| encoder_jit_trace-pnnx.ncnn.bin | 142 MB | ++----------------------------------+------------+ +| decoder_jit_trace-pnnx.ncnn.bin | 503 KB | ++----------------------------------+------------+ +| joiner_jit_trace-pnnx.ncnn.bin | 1.5 MB | ++----------------------------------+------------+ + +You can see that the file sizes of the models after conversion are about one half +of the models before conversion: + + - encoder: 283 MB vs 142 MB + - decoder: 1010 KB vs 503 KB + - joiner: 3.0 MB vs 1.5 MB + +The reason is that by default ``pnnx`` converts ``float32`` parameters +to ``float16``. A ``float32`` parameter occupies 4 bytes, while it is 2 bytes +for ``float16``. Thus, it is ``twice smaller`` after conversion. + +.. hint:: + + If you use ``pnnx ./encoder_jit_trace-pnnx.pt fp16=0``, then ``pnnx`` + won't convert ``float32`` to ``float16``. + +5. Test the exported models in icefall +-------------------------------------- + +.. note:: + + We assume you have set up the environment variable ``PYTHONPATH`` when + building `ncnn`_. + +Now we have successfully converted our pre-trained model to `ncnn`_ format. +The generated 6 files are what we need. You can use the following code to +test the converted models: + +.. code-block:: bash + + ./conv_emformer_transducer_stateless2/streaming-ncnn-decode.py \ + --tokens ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/data/lang_bpe_500/tokens.txt \ + --encoder-param-filename ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.ncnn.param \ + --encoder-bin-filename ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.ncnn.bin \ + --decoder-param-filename ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.ncnn.param \ + --decoder-bin-filename ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.ncnn.bin \ + --joiner-param-filename ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.ncnn.param \ + --joiner-bin-filename ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.ncnn.bin \ + ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/test_wavs/1089-134686-0001.wav + +.. hint:: + + `ncnn`_ supports only ``batch size == 1``, so ``streaming-ncnn-decode.py`` accepts + only 1 wave file as input. + +The output is given below: + +.. literalinclude:: ./code/test-streaming-ncnn-decode-conv-emformer-transducer-libri.txt + +Congratulations! You have successfully exported a model from PyTorch to `ncnn`_! + + +.. _conv-emformer-modify-the-exported-encoder-for-sherpa-ncnn: + +6. Modify the exported encoder for sherpa-ncnn +---------------------------------------------- + +In order to use the exported models in `sherpa-ncnn`_, we have to modify +``encoder_jit_trace-pnnx.ncnn.param``. + +Let us have a look at the first few lines of ``encoder_jit_trace-pnnx.ncnn.param``: + +.. code-block:: + + 7767517 + 1060 1342 + Input in0 0 1 in0 + +**Explanation** of the above three lines: + + 1. ``7767517``, it is a magic number and should not be changed. + 2. ``1060 1342``, the first number ``1060`` specifies the number of layers + in this file, while ``1342`` specifies the number of intermediate outputs + of this file + 3. ``Input in0 0 1 in0``, ``Input`` is the layer type of this layer; ``in0`` + is the layer name of this layer; ``0`` means this layer has no input; + ``1`` means this layer has one output; ``in0`` is the output name of + this layer. + +We need to add 1 extra line and also increment the number of layers. +The result looks like below: + +.. code-block:: bash + + 7767517 + 1061 1342 + SherpaMetaData sherpa_meta_data1 0 0 0=1 1=12 2=32 3=31 4=8 5=32 6=8 7=512 + Input in0 0 1 in0 + +**Explanation** + + 1. ``7767517``, it is still the same + 2. ``1061 1342``, we have added an extra layer, so we need to update ``1060`` to ``1061``. + We don't need to change ``1342`` since the newly added layer has no inputs or outputs. + 3. ``SherpaMetaData sherpa_meta_data1 0 0 0=1 1=12 2=32 3=31 4=8 5=32 6=8 7=512`` + This line is newly added. Its explanation is given below: + + - ``SherpaMetaData`` is the type of this layer. Must be ``SherpaMetaData``. + - ``sherpa_meta_data1`` is the name of this layer. Must be ``sherpa_meta_data1``. + - ``0 0`` means this layer has no inputs or output. Must be ``0 0`` + - ``0=1``, 0 is the key and 1 is the value. MUST be ``0=1`` + - ``1=12``, 1 is the key and 12 is the value of the + parameter ``--num-encoder-layers`` that you provided when running + ``conv_emformer_transducer_stateless2/export-for-ncnn.py``. + - ``2=32``, 2 is the key and 32 is the value of the + parameter ``--memory-size`` that you provided when running + ``conv_emformer_transducer_stateless2/export-for-ncnn.py``. + - ``3=31``, 3 is the key and 31 is the value of the + parameter ``--cnn-module-kernel`` that you provided when running + ``conv_emformer_transducer_stateless2/export-for-ncnn.py``. + - ``4=8``, 4 is the key and 8 is the value of the + parameter ``--left-context-length`` that you provided when running + ``conv_emformer_transducer_stateless2/export-for-ncnn.py``. + - ``5=32``, 5 is the key and 32 is the value of the + parameter ``--chunk-length`` that you provided when running + ``conv_emformer_transducer_stateless2/export-for-ncnn.py``. + - ``6=8``, 6 is the key and 8 is the value of the + parameter ``--right-context-length`` that you provided when running + ``conv_emformer_transducer_stateless2/export-for-ncnn.py``. + - ``7=512``, 7 is the key and 512 is the value of the + parameter ``--encoder-dim`` that you provided when running + ``conv_emformer_transducer_stateless2/export-for-ncnn.py``. + + For ease of reference, we list the key-value pairs that you need to add + in the following table. If your model has a different setting, please + change the values for ``SherpaMetaData`` accordingly. Otherwise, you + will be ``SAD``. + + +------+-----------------------------+ + | key | value | + +======+=============================+ + | 0 | 1 (fixed) | + +------+-----------------------------+ + | 1 | ``--num-encoder-layers`` | + +------+-----------------------------+ + | 2 | ``--memory-size`` | + +------+-----------------------------+ + | 3 | ``--cnn-module-kernel`` | + +------+-----------------------------+ + | 4 | ``--left-context-length`` | + +------+-----------------------------+ + | 5 | ``--chunk-length`` | + +------+-----------------------------+ + | 6 | ``--right-context-length`` | + +------+-----------------------------+ + | 7 | ``--encoder-dim`` | + +------+-----------------------------+ + + 4. ``Input in0 0 1 in0``. No need to change it. + +.. caution:: + + When you add a new layer ``SherpaMetaData``, please remember to update the + number of layers. In our case, update ``1060`` to ``1061``. Otherwise, + you will be SAD later. + +.. hint:: + + After adding the new layer ``SherpaMetaData``, you cannot use this model + with ``streaming-ncnn-decode.py`` anymore since ``SherpaMetaData`` is + supported only in `sherpa-ncnn`_. + +.. hint:: + + `ncnn`_ is very flexible. You can add new layers to it just by text-editing + the ``param`` file! You don't need to change the ``bin`` file. + +Now you can use this model in `sherpa-ncnn`_. +Please refer to the following documentation: + + - Linux/macOS/Windows/arm/aarch64: ``_ + - ``Android``: ``_ + - ``iOS``: ``_ + - Python: ``_ + +We have a list of pre-trained models that have been exported for `sherpa-ncnn`_: + + - ``_ + + You can find more usages there. + +7. (Optional) int8 quantization with sherpa-ncnn +------------------------------------------------ + +This step is optional. + +In this step, we describe how to quantize our model with ``int8``. + +Change :ref:`conv-emformer-step-4-export-torchscript-model-via-pnnx` to +disable ``fp16`` when using ``pnnx``: + +.. code-block:: + + cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/ + + pnnx ./encoder_jit_trace-pnnx.pt fp16=0 + pnnx ./decoder_jit_trace-pnnx.pt + pnnx ./joiner_jit_trace-pnnx.pt fp16=0 + +.. note:: + + We add ``fp16=0`` when exporting the encoder and joiner. `ncnn`_ does not + support quantizing the decoder model yet. We will update this documentation + once `ncnn`_ supports it. (Maybe in this year, 2023). + +It will generate the following files + +.. code-block:: bash + + ls -lh icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/*_jit_trace-pnnx.ncnn.{param,bin} + + -rw-r--r-- 1 kuangfangjun root 503K Jan 11 15:56 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.ncnn.bin + -rw-r--r-- 1 kuangfangjun root 437 Jan 11 15:56 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.ncnn.param + -rw-r--r-- 1 kuangfangjun root 283M Jan 11 15:56 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.ncnn.bin + -rw-r--r-- 1 kuangfangjun root 79K Jan 11 15:56 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.ncnn.param + -rw-r--r-- 1 kuangfangjun root 3.0M Jan 11 15:56 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.ncnn.bin + -rw-r--r-- 1 kuangfangjun root 488 Jan 11 15:56 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.ncnn.param + +Let us compare again the file sizes: + ++----------------------------------------+------------+ +| File name | File size | ++----------------------------------------+------------+ +| encoder_jit_trace-pnnx.pt | 283 MB | ++----------------------------------------+------------+ +| decoder_jit_trace-pnnx.pt | 1010 KB | ++----------------------------------------+------------+ +| joiner_jit_trace-pnnx.pt | 3.0 MB | ++----------------------------------------+------------+ +| encoder_jit_trace-pnnx.ncnn.bin (fp16) | 142 MB | ++----------------------------------------+------------+ +| decoder_jit_trace-pnnx.ncnn.bin (fp16) | 503 KB | ++----------------------------------------+------------+ +| joiner_jit_trace-pnnx.ncnn.bin (fp16) | 1.5 MB | ++----------------------------------------+------------+ +| encoder_jit_trace-pnnx.ncnn.bin (fp32) | 283 MB | ++----------------------------------------+------------+ +| joiner_jit_trace-pnnx.ncnn.bin (fp32) | 3.0 MB | ++----------------------------------------+------------+ + +You can see that the file sizes are doubled when we disable ``fp16``. + +.. note:: + + You can again use ``streaming-ncnn-decode.py`` to test the exported models. + +Next, follow :ref:`conv-emformer-modify-the-exported-encoder-for-sherpa-ncnn` +to modify ``encoder_jit_trace-pnnx.ncnn.param``. + +Change + +.. code-block:: bash + + 7767517 + 1060 1342 + Input in0 0 1 in0 + +to + +.. code-block:: bash + + 7767517 + 1061 1342 + SherpaMetaData sherpa_meta_data1 0 0 0=1 1=12 2=32 3=31 4=8 5=32 6=8 7=512 + Input in0 0 1 in0 + +.. caution:: + + Please follow :ref:`conv-emformer-modify-the-exported-encoder-for-sherpa-ncnn` + to change the values for ``SherpaMetaData`` if your model uses a different setting. + + +Next, let us compile `sherpa-ncnn`_ since we will quantize our models within +`sherpa-ncnn`_. + +.. code-block:: bash + + # We will download sherpa-ncnn to $HOME/open-source/ + # You can change it to anywhere you like. + cd $HOME + mkdir -p open-source + + cd open-source + git clone https://github.com/k2-fsa/sherpa-ncnn + cd sherpa-ncnn + mkdir build + cd build + cmake .. + make -j 4 + + ./bin/generate-int8-scale-table + + export PATH=$HOME/open-source/sherpa-ncnn/build/bin:$PATH + +The output of the above commands are: + +.. code-block:: bash + + (py38) kuangfangjun:build$ generate-int8-scale-table + Please provide 10 arg. Currently given: 1 + Usage: + generate-int8-scale-table encoder.param encoder.bin decoder.param decoder.bin joiner.param joiner.bin encoder-scale-table.txt joiner-scale-table.txt wave_filenames.txt + + Each line in wave_filenames.txt is a path to some 16k Hz mono wave file. + +We need to create a file ``wave_filenames.txt``, in which we need to put +some calibration wave files. For testing purpose, we put the ``test_wavs`` +from the pre-trained model repository ``_ + +.. code-block:: bash + + cd egs/librispeech/ASR + cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/ + + cat < wave_filenames.txt + ../test_wavs/1089-134686-0001.wav + ../test_wavs/1221-135766-0001.wav + ../test_wavs/1221-135766-0002.wav + EOF + +Now we can calculate the scales needed for quantization with the calibration data: + +.. code-block:: bash + + cd egs/librispeech/ASR + cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/ + + generate-int8-scale-table \ + ./encoder_jit_trace-pnnx.ncnn.param \ + ./encoder_jit_trace-pnnx.ncnn.bin \ + ./decoder_jit_trace-pnnx.ncnn.param \ + ./decoder_jit_trace-pnnx.ncnn.bin \ + ./joiner_jit_trace-pnnx.ncnn.param \ + ./joiner_jit_trace-pnnx.ncnn.bin \ + ./encoder-scale-table.txt \ + ./joiner-scale-table.txt \ + ./wave_filenames.txt + +The output logs are in the following: + +.. literalinclude:: ./code/generate-int-8-scale-table-for-conv-emformer.txt + +It generates the following two files: + +.. code-block:: bash + + $ ls -lh encoder-scale-table.txt joiner-scale-table.txt + -rw-r--r-- 1 kuangfangjun root 955K Jan 11 17:28 encoder-scale-table.txt + -rw-r--r-- 1 kuangfangjun root 18K Jan 11 17:28 joiner-scale-table.txt + +.. caution:: + + Definitely, you need more calibration data to compute the scale table. + +Finally, let us use the scale table to quantize our models into ``int8``. + +.. code-block:: bash + + ncnn2int8 + + usage: ncnn2int8 [inparam] [inbin] [outparam] [outbin] [calibration table] + +First, we quantize the encoder model: + +.. code-block:: bash + + cd egs/librispeech/ASR + cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/ + + ncnn2int8 \ + ./encoder_jit_trace-pnnx.ncnn.param \ + ./encoder_jit_trace-pnnx.ncnn.bin \ + ./encoder_jit_trace-pnnx.ncnn.int8.param \ + ./encoder_jit_trace-pnnx.ncnn.int8.bin \ + ./encoder-scale-table.txt + +Next, we quantize the joiner model: + +.. code-block:: bash + + ncnn2int8 \ + ./joiner_jit_trace-pnnx.ncnn.param \ + ./joiner_jit_trace-pnnx.ncnn.bin \ + ./joiner_jit_trace-pnnx.ncnn.int8.param \ + ./joiner_jit_trace-pnnx.ncnn.int8.bin \ + ./joiner-scale-table.txt + +The above two commands generate the following 4 files: + +.. code-block:: bash + + -rw-r--r-- 1 kuangfangjun root 99M Jan 11 17:34 encoder_jit_trace-pnnx.ncnn.int8.bin + -rw-r--r-- 1 kuangfangjun root 78K Jan 11 17:34 encoder_jit_trace-pnnx.ncnn.int8.param + -rw-r--r-- 1 kuangfangjun root 774K Jan 11 17:35 joiner_jit_trace-pnnx.ncnn.int8.bin + -rw-r--r-- 1 kuangfangjun root 496 Jan 11 17:35 joiner_jit_trace-pnnx.ncnn.int8.param + +Congratulations! You have successfully quantized your model from ``float32`` to ``int8``. + +.. caution:: + + ``ncnn.int8.param`` and ``ncnn.int8.bin`` must be used in pairs. + + You can replace ``ncnn.param`` and ``ncnn.bin`` with ``ncnn.int8.param`` + and ``ncnn.int8.bin`` in `sherpa-ncnn`_ if you like. + + For instance, to use only the ``int8`` encoder in ``sherpa-ncnn``, you can + replace the following invocation: + + .. code-block:: bash + + cd egs/librispeech/ASR + cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/ + + sherpa-ncnn \ + ../data/lang_bpe_500/tokens.txt \ + ./encoder_jit_trace-pnnx.ncnn.param \ + ./encoder_jit_trace-pnnx.ncnn.bin \ + ./decoder_jit_trace-pnnx.ncnn.param \ + ./decoder_jit_trace-pnnx.ncnn.bin \ + ./joiner_jit_trace-pnnx.ncnn.param \ + ./joiner_jit_trace-pnnx.ncnn.bin \ + ../test_wavs/1089-134686-0001.wav + + with + + .. code-block:: + + cd egs/librispeech/ASR + cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/ + + sherpa-ncnn \ + ../data/lang_bpe_500/tokens.txt \ + ./encoder_jit_trace-pnnx.ncnn.int8.param \ + ./encoder_jit_trace-pnnx.ncnn.int8.bin \ + ./decoder_jit_trace-pnnx.ncnn.param \ + ./decoder_jit_trace-pnnx.ncnn.bin \ + ./joiner_jit_trace-pnnx.ncnn.param \ + ./joiner_jit_trace-pnnx.ncnn.bin \ + ../test_wavs/1089-134686-0001.wav + + +The following table compares again the file sizes: + + ++----------------------------------------+------------+ +| File name | File size | ++----------------------------------------+------------+ +| encoder_jit_trace-pnnx.pt | 283 MB | ++----------------------------------------+------------+ +| decoder_jit_trace-pnnx.pt | 1010 KB | ++----------------------------------------+------------+ +| joiner_jit_trace-pnnx.pt | 3.0 MB | ++----------------------------------------+------------+ +| encoder_jit_trace-pnnx.ncnn.bin (fp16) | 142 MB | ++----------------------------------------+------------+ +| decoder_jit_trace-pnnx.ncnn.bin (fp16) | 503 KB | ++----------------------------------------+------------+ +| joiner_jit_trace-pnnx.ncnn.bin (fp16) | 1.5 MB | ++----------------------------------------+------------+ +| encoder_jit_trace-pnnx.ncnn.bin (fp32) | 283 MB | ++----------------------------------------+------------+ +| joiner_jit_trace-pnnx.ncnn.bin (fp32) | 3.0 MB | ++----------------------------------------+------------+ +| encoder_jit_trace-pnnx.ncnn.int8.bin | 99 MB | ++----------------------------------------+------------+ +| joiner_jit_trace-pnnx.ncnn.int8.bin | 774 KB | ++----------------------------------------+------------+ + +You can see that the file sizes of the model after ``int8`` quantization +are much smaller. + +.. hint:: + + Currently, only linear layers and convolutional layers are quantized + with ``int8``, so you don't see an exact ``4x`` reduction in file sizes. + +.. note:: + + You need to test the recognition accuracy after ``int8`` quantization. + +You can find the speed comparison at ``_. + + +That's it! Have fun with `sherpa-ncnn`_! diff --git a/docs/source/model-export/export-ncnn-lstm.rst b/docs/source/model-export/export-ncnn-lstm.rst new file mode 100644 index 000000000..8e6dc7466 --- /dev/null +++ b/docs/source/model-export/export-ncnn-lstm.rst @@ -0,0 +1,644 @@ +.. _export_lstm_transducer_models_to_ncnn: + +Export LSTM transducer models to ncnn +------------------------------------- + +We use the pre-trained model from the following repository as an example: + +``_ + +We will show you step by step how to export it to `ncnn`_ and run it with `sherpa-ncnn`_. + +.. hint:: + + We use ``Ubuntu 18.04``, ``torch 1.13``, and ``Python 3.8`` for testing. + +.. caution:: + + Please use a more recent version of PyTorch. For instance, ``torch 1.8`` + may ``not`` work. + +1. Download the pre-trained model +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. hint:: + + You have to install `git-lfs`_ before you continue. + + +.. code-block:: bash + + cd egs/librispeech/ASR + GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03 + cd icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03 + + git lfs pull --include "exp/pretrained-iter-468000-avg-16.pt" + git lfs pull --include "data/lang_bpe_500/bpe.model" + + cd .. + +.. note:: + + We downloaded ``exp/pretrained-xxx.pt``, not ``exp/cpu-jit_xxx.pt``. + +In the above code, we downloaded the pre-trained model into the directory +``egs/librispeech/ASR/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03``. + +2. Install ncnn and pnnx +^^^^^^^^^^^^^^^^^^^^^^^^ + +Please refer to :ref:`export_for_ncnn_install_ncnn_and_pnnx` . + + +3. Export the model via torch.jit.trace() +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +First, let us rename our pre-trained model: + +.. code-block:: + + cd egs/librispeech/ASR + + cd icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp + + ln -s pretrained-iter-468000-avg-16.pt epoch-99.pt + + cd ../.. + +Next, we use the following code to export our model: + +.. code-block:: bash + + dir=./icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03 + + ./lstm_transducer_stateless2/export-for-ncnn.py \ + --exp-dir $dir/exp \ + --bpe-model $dir/data/lang_bpe_500/bpe.model \ + --epoch 99 \ + --avg 1 \ + --use-averaged-model 0 \ + --num-encoder-layers 12 \ + --encoder-dim 512 \ + --rnn-hidden-size 1024 + +.. hint:: + + We have renamed our model to ``epoch-99.pt`` so that we can use ``--epoch 99``. + There is only one pre-trained model, so we use ``--avg 1 --use-averaged-model 0``. + + If you have trained a model by yourself and if you have all checkpoints + available, please first use ``decode.py`` to tune ``--epoch --avg`` + and select the best combination with with ``--use-averaged-model 1``. + +.. note:: + + You will see the following log output: + + .. literalinclude:: ./code/export-lstm-transducer-for-ncnn-output.txt + + The log shows the model has ``84176356`` parameters, i.e., ``~84 M``. + + .. code-block:: + + ls -lh icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/pretrained-iter-468000-avg-16.pt + + -rw-r--r-- 1 kuangfangjun root 324M Feb 17 10:34 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/pretrained-iter-468000-avg-16.pt + + You can see that the file size of the pre-trained model is ``324 MB``, which + is roughly equal to ``84176356*4/1024/1024 = 321.107 MB``. + +After running ``lstm_transducer_stateless2/export-for-ncnn.py``, +we will get the following files: + +.. code-block:: bash + + ls -lh icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/*pnnx.pt + + -rw-r--r-- 1 kuangfangjun root 1010K Feb 17 11:22 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/decoder_jit_trace-pnnx.pt + -rw-r--r-- 1 kuangfangjun root 318M Feb 17 11:22 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/encoder_jit_trace-pnnx.pt + -rw-r--r-- 1 kuangfangjun root 3.0M Feb 17 11:22 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/joiner_jit_trace-pnnx.pt + + +.. _lstm-transducer-step-4-export-torchscript-model-via-pnnx: + +4. Export torchscript model via pnnx +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. hint:: + + Make sure you have set up the ``PATH`` environment variable + in :ref:`export_for_ncnn_install_ncnn_and_pnnx`. Otherwise, + it will throw an error saying that ``pnnx`` could not be found. + +Now, it's time to export our models to `ncnn`_ via ``pnnx``. + +.. code-block:: + + cd icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/ + + pnnx ./encoder_jit_trace-pnnx.pt + pnnx ./decoder_jit_trace-pnnx.pt + pnnx ./joiner_jit_trace-pnnx.pt + +It will generate the following files: + +.. code-block:: bash + + ls -lh icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/*ncnn*{bin,param} + + -rw-r--r-- 1 kuangfangjun root 503K Feb 17 11:32 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/decoder_jit_trace-pnnx.ncnn.bin + -rw-r--r-- 1 kuangfangjun root 437 Feb 17 11:32 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/decoder_jit_trace-pnnx.ncnn.param + -rw-r--r-- 1 kuangfangjun root 159M Feb 17 11:32 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/encoder_jit_trace-pnnx.ncnn.bin + -rw-r--r-- 1 kuangfangjun root 21K Feb 17 11:32 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/encoder_jit_trace-pnnx.ncnn.param + -rw-r--r-- 1 kuangfangjun root 1.5M Feb 17 11:33 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/joiner_jit_trace-pnnx.ncnn.bin + -rw-r--r-- 1 kuangfangjun root 488 Feb 17 11:33 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/joiner_jit_trace-pnnx.ncnn.param + + +There are two types of files: + +- ``param``: It is a text file containing the model architectures. You can + use a text editor to view its content. +- ``bin``: It is a binary file containing the model parameters. + +We compare the file sizes of the models below before and after converting via ``pnnx``: + +.. see https://tableconvert.com/restructuredtext-generator + ++----------------------------------+------------+ +| File name | File size | ++==================================+============+ +| encoder_jit_trace-pnnx.pt | 318 MB | ++----------------------------------+------------+ +| decoder_jit_trace-pnnx.pt | 1010 KB | ++----------------------------------+------------+ +| joiner_jit_trace-pnnx.pt | 3.0 MB | ++----------------------------------+------------+ +| encoder_jit_trace-pnnx.ncnn.bin | 159 MB | ++----------------------------------+------------+ +| decoder_jit_trace-pnnx.ncnn.bin | 503 KB | ++----------------------------------+------------+ +| joiner_jit_trace-pnnx.ncnn.bin | 1.5 MB | ++----------------------------------+------------+ + +You can see that the file sizes of the models after conversion are about one half +of the models before conversion: + + - encoder: 318 MB vs 159 MB + - decoder: 1010 KB vs 503 KB + - joiner: 3.0 MB vs 1.5 MB + +The reason is that by default ``pnnx`` converts ``float32`` parameters +to ``float16``. A ``float32`` parameter occupies 4 bytes, while it is 2 bytes +for ``float16``. Thus, it is ``twice smaller`` after conversion. + +.. hint:: + + If you use ``pnnx ./encoder_jit_trace-pnnx.pt fp16=0``, then ``pnnx`` + won't convert ``float32`` to ``float16``. + +5. Test the exported models in icefall +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. note:: + + We assume you have set up the environment variable ``PYTHONPATH`` when + building `ncnn`_. + +Now we have successfully converted our pre-trained model to `ncnn`_ format. +The generated 6 files are what we need. You can use the following code to +test the converted models: + +.. code-block:: bash + + python3 ./lstm_transducer_stateless2/streaming-ncnn-decode.py \ + --tokens ./icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/data/lang_bpe_500/tokens.txt \ + --encoder-param-filename ./icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/encoder_jit_trace-pnnx.ncnn.param \ + --encoder-bin-filename ./icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/encoder_jit_trace-pnnx.ncnn.bin \ + --decoder-param-filename ./icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/decoder_jit_trace-pnnx.ncnn.param \ + --decoder-bin-filename ./icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/decoder_jit_trace-pnnx.ncnn.bin \ + --joiner-param-filename ./icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/joiner_jit_trace-pnnx.ncnn.param \ + --joiner-bin-filename ./icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/joiner_jit_trace-pnnx.ncnn.bin \ + ./icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/test_wavs/1089-134686-0001.wav + +.. hint:: + + `ncnn`_ supports only ``batch size == 1``, so ``streaming-ncnn-decode.py`` accepts + only 1 wave file as input. + +The output is given below: + +.. literalinclude:: ./code/test-streaming-ncnn-decode-lstm-transducer-libri.txt + +Congratulations! You have successfully exported a model from PyTorch to `ncnn`_! + +.. _lstm-modify-the-exported-encoder-for-sherpa-ncnn: + +6. Modify the exported encoder for sherpa-ncnn +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +In order to use the exported models in `sherpa-ncnn`_, we have to modify +``encoder_jit_trace-pnnx.ncnn.param``. + +Let us have a look at the first few lines of ``encoder_jit_trace-pnnx.ncnn.param``: + +.. code-block:: + + 7767517 + 267 379 + Input in0 0 1 in0 + +**Explanation** of the above three lines: + + 1. ``7767517``, it is a magic number and should not be changed. + 2. ``267 379``, the first number ``267`` specifies the number of layers + in this file, while ``379`` specifies the number of intermediate outputs + of this file + 3. ``Input in0 0 1 in0``, ``Input`` is the layer type of this layer; ``in0`` + is the layer name of this layer; ``0`` means this layer has no input; + ``1`` means this layer has one output; ``in0`` is the output name of + this layer. + +We need to add 1 extra line and also increment the number of layers. +The result looks like below: + +.. code-block:: bash + + 7767517 + 268 379 + SherpaMetaData sherpa_meta_data1 0 0 0=3 1=12 2=512 3=1024 + Input in0 0 1 in0 + +**Explanation** + + 1. ``7767517``, it is still the same + 2. ``268 379``, we have added an extra layer, so we need to update ``267`` to ``268``. + We don't need to change ``379`` since the newly added layer has no inputs or outputs. + 3. ``SherpaMetaData sherpa_meta_data1 0 0 0=3 1=12 2=512 3=1024`` + This line is newly added. Its explanation is given below: + + - ``SherpaMetaData`` is the type of this layer. Must be ``SherpaMetaData``. + - ``sherpa_meta_data1`` is the name of this layer. Must be ``sherpa_meta_data1``. + - ``0 0`` means this layer has no inputs or output. Must be ``0 0`` + - ``0=3``, 0 is the key and 3 is the value. MUST be ``0=3`` + - ``1=12``, 1 is the key and 12 is the value of the + parameter ``--num-encoder-layers`` that you provided when running + ``./lstm_transducer_stateless2/export-for-ncnn.py``. + - ``2=512``, 2 is the key and 512 is the value of the + parameter ``--encoder-dim`` that you provided when running + ``./lstm_transducer_stateless2/export-for-ncnn.py``. + - ``3=1024``, 3 is the key and 1024 is the value of the + parameter ``--rnn-hidden-size`` that you provided when running + ``./lstm_transducer_stateless2/export-for-ncnn.py``. + + For ease of reference, we list the key-value pairs that you need to add + in the following table. If your model has a different setting, please + change the values for ``SherpaMetaData`` accordingly. Otherwise, you + will be ``SAD``. + + +------+-----------------------------+ + | key | value | + +======+=============================+ + | 0 | 3 (fixed) | + +------+-----------------------------+ + | 1 | ``--num-encoder-layers`` | + +------+-----------------------------+ + | 2 | ``--encoder-dim`` | + +------+-----------------------------+ + | 3 | ``--rnn-hidden-size`` | + +------+-----------------------------+ + + 4. ``Input in0 0 1 in0``. No need to change it. + +.. caution:: + + When you add a new layer ``SherpaMetaData``, please remember to update the + number of layers. In our case, update ``267`` to ``268``. Otherwise, + you will be SAD later. + +.. hint:: + + After adding the new layer ``SherpaMetaData``, you cannot use this model + with ``streaming-ncnn-decode.py`` anymore since ``SherpaMetaData`` is + supported only in `sherpa-ncnn`_. + +.. hint:: + + `ncnn`_ is very flexible. You can add new layers to it just by text-editing + the ``param`` file! You don't need to change the ``bin`` file. + +Now you can use this model in `sherpa-ncnn`_. +Please refer to the following documentation: + + - Linux/macOS/Windows/arm/aarch64: ``_ + - ``Android``: ``_ + - ``iOS``: ``_ + - Python: ``_ + +We have a list of pre-trained models that have been exported for `sherpa-ncnn`_: + + - ``_ + + You can find more usages there. + +7. (Optional) int8 quantization with sherpa-ncnn +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +This step is optional. + +In this step, we describe how to quantize our model with ``int8``. + +Change :ref:`lstm-transducer-step-4-export-torchscript-model-via-pnnx` to +disable ``fp16`` when using ``pnnx``: + +.. code-block:: + + cd icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/ + + pnnx ./encoder_jit_trace-pnnx.pt fp16=0 + pnnx ./decoder_jit_trace-pnnx.pt + pnnx ./joiner_jit_trace-pnnx.pt fp16=0 + +.. note:: + + We add ``fp16=0`` when exporting the encoder and joiner. `ncnn`_ does not + support quantizing the decoder model yet. We will update this documentation + once `ncnn`_ supports it. (Maybe in this year, 2023). + +.. code-block:: bash + + ls -lh icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/*_jit_trace-pnnx.ncnn.{param,bin} + + -rw-r--r-- 1 kuangfangjun root 503K Feb 17 11:32 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/decoder_jit_trace-pnnx.ncnn.bin + -rw-r--r-- 1 kuangfangjun root 437 Feb 17 11:32 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/decoder_jit_trace-pnnx.ncnn.param + -rw-r--r-- 1 kuangfangjun root 317M Feb 17 11:54 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/encoder_jit_trace-pnnx.ncnn.bin + -rw-r--r-- 1 kuangfangjun root 21K Feb 17 11:54 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/encoder_jit_trace-pnnx.ncnn.param + -rw-r--r-- 1 kuangfangjun root 3.0M Feb 17 11:54 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/joiner_jit_trace-pnnx.ncnn.bin + -rw-r--r-- 1 kuangfangjun root 488 Feb 17 11:54 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/joiner_jit_trace-pnnx.ncnn.param + + +Let us compare again the file sizes: + ++----------------------------------------+------------+ +| File name | File size | ++----------------------------------------+------------+ +| encoder_jit_trace-pnnx.pt | 318 MB | ++----------------------------------------+------------+ +| decoder_jit_trace-pnnx.pt | 1010 KB | ++----------------------------------------+------------+ +| joiner_jit_trace-pnnx.pt | 3.0 MB | ++----------------------------------------+------------+ +| encoder_jit_trace-pnnx.ncnn.bin (fp16) | 159 MB | ++----------------------------------------+------------+ +| decoder_jit_trace-pnnx.ncnn.bin (fp16) | 503 KB | ++----------------------------------------+------------+ +| joiner_jit_trace-pnnx.ncnn.bin (fp16) | 1.5 MB | ++----------------------------------------+------------+ +| encoder_jit_trace-pnnx.ncnn.bin (fp32) | 317 MB | ++----------------------------------------+------------+ +| joiner_jit_trace-pnnx.ncnn.bin (fp32) | 3.0 MB | ++----------------------------------------+------------+ + +You can see that the file sizes are doubled when we disable ``fp16``. + +.. note:: + + You can again use ``streaming-ncnn-decode.py`` to test the exported models. + +Next, follow :ref:`lstm-modify-the-exported-encoder-for-sherpa-ncnn` +to modify ``encoder_jit_trace-pnnx.ncnn.param``. + +Change + +.. code-block:: bash + + 7767517 + 267 379 + Input in0 0 1 in0 + +to + +.. code-block:: bash + + 7767517 + 268 379 + SherpaMetaData sherpa_meta_data1 0 0 0=3 1=12 2=512 3=1024 + Input in0 0 1 in0 + +.. caution:: + + Please follow :ref:`lstm-modify-the-exported-encoder-for-sherpa-ncnn` + to change the values for ``SherpaMetaData`` if your model uses a different setting. + +Next, let us compile `sherpa-ncnn`_ since we will quantize our models within +`sherpa-ncnn`_. + +.. code-block:: bash + + # We will download sherpa-ncnn to $HOME/open-source/ + # You can change it to anywhere you like. + cd $HOME + mkdir -p open-source + + cd open-source + git clone https://github.com/k2-fsa/sherpa-ncnn + cd sherpa-ncnn + mkdir build + cd build + cmake .. + make -j 4 + + ./bin/generate-int8-scale-table + + export PATH=$HOME/open-source/sherpa-ncnn/build/bin:$PATH + +The output of the above commands are: + +.. code-block:: bash + + (py38) kuangfangjun:build$ generate-int8-scale-table + Please provide 10 arg. Currently given: 1 + Usage: + generate-int8-scale-table encoder.param encoder.bin decoder.param decoder.bin joiner.param joiner.bin encoder-scale-table.txt joiner-scale-table.txt wave_filenames.txt + + Each line in wave_filenames.txt is a path to some 16k Hz mono wave file. + +We need to create a file ``wave_filenames.txt``, in which we need to put +some calibration wave files. For testing purpose, we put the ``test_wavs`` +from the pre-trained model repository +``_ + +.. code-block:: bash + + cd egs/librispeech/ASR + cd icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/ + + cat < wave_filenames.txt + ../test_wavs/1089-134686-0001.wav + ../test_wavs/1221-135766-0001.wav + ../test_wavs/1221-135766-0002.wav + EOF + +Now we can calculate the scales needed for quantization with the calibration data: + +.. code-block:: bash + + cd egs/librispeech/ASR + cd icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/ + + generate-int8-scale-table \ + ./encoder_jit_trace-pnnx.ncnn.param \ + ./encoder_jit_trace-pnnx.ncnn.bin \ + ./decoder_jit_trace-pnnx.ncnn.param \ + ./decoder_jit_trace-pnnx.ncnn.bin \ + ./joiner_jit_trace-pnnx.ncnn.param \ + ./joiner_jit_trace-pnnx.ncnn.bin \ + ./encoder-scale-table.txt \ + ./joiner-scale-table.txt \ + ./wave_filenames.txt + +The output logs are in the following: + +.. literalinclude:: ./code/generate-int-8-scale-table-for-lstm.txt + +It generates the following two files: + +.. code-block:: bash + + ls -lh encoder-scale-table.txt joiner-scale-table.txt + + -rw-r--r-- 1 kuangfangjun root 345K Feb 17 12:13 encoder-scale-table.txt + -rw-r--r-- 1 kuangfangjun root 17K Feb 17 12:13 joiner-scale-table.txt + +.. caution:: + + Definitely, you need more calibration data to compute the scale table. + +Finally, let us use the scale table to quantize our models into ``int8``. + +.. code-block:: bash + + ncnn2int8 + + usage: ncnn2int8 [inparam] [inbin] [outparam] [outbin] [calibration table] + +First, we quantize the encoder model: + +.. code-block:: bash + + cd egs/librispeech/ASR + cd icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/ + + ncnn2int8 \ + ./encoder_jit_trace-pnnx.ncnn.param \ + ./encoder_jit_trace-pnnx.ncnn.bin \ + ./encoder_jit_trace-pnnx.ncnn.int8.param \ + ./encoder_jit_trace-pnnx.ncnn.int8.bin \ + ./encoder-scale-table.txt + +Next, we quantize the joiner model: + +.. code-block:: bash + + ncnn2int8 \ + ./joiner_jit_trace-pnnx.ncnn.param \ + ./joiner_jit_trace-pnnx.ncnn.bin \ + ./joiner_jit_trace-pnnx.ncnn.int8.param \ + ./joiner_jit_trace-pnnx.ncnn.int8.bin \ + ./joiner-scale-table.txt + +The above two commands generate the following 4 files: + +.. code-block:: + + -rw-r--r-- 1 kuangfangjun root 218M Feb 17 12:19 encoder_jit_trace-pnnx.ncnn.int8.bin + -rw-r--r-- 1 kuangfangjun root 21K Feb 17 12:19 encoder_jit_trace-pnnx.ncnn.int8.param + -rw-r--r-- 1 kuangfangjun root 774K Feb 17 12:19 joiner_jit_trace-pnnx.ncnn.int8.bin + -rw-r--r-- 1 kuangfangjun root 496 Feb 17 12:19 joiner_jit_trace-pnnx.ncnn.int8.param + +Congratulations! You have successfully quantized your model from ``float32`` to ``int8``. + +.. caution:: + + ``ncnn.int8.param`` and ``ncnn.int8.bin`` must be used in pairs. + + You can replace ``ncnn.param`` and ``ncnn.bin`` with ``ncnn.int8.param`` + and ``ncnn.int8.bin`` in `sherpa-ncnn`_ if you like. + + For instance, to use only the ``int8`` encoder in ``sherpa-ncnn``, you can + replace the following invocation: + + .. code-block:: + + cd egs/librispeech/ASR + cd icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/ + + sherpa-ncnn \ + ../data/lang_bpe_500/tokens.txt \ + ./encoder_jit_trace-pnnx.ncnn.param \ + ./encoder_jit_trace-pnnx.ncnn.bin \ + ./decoder_jit_trace-pnnx.ncnn.param \ + ./decoder_jit_trace-pnnx.ncnn.bin \ + ./joiner_jit_trace-pnnx.ncnn.param \ + ./joiner_jit_trace-pnnx.ncnn.bin \ + ../test_wavs/1089-134686-0001.wav + + with + + .. code-block:: bash + + cd egs/librispeech/ASR + cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/ + + sherpa-ncnn \ + ../data/lang_bpe_500/tokens.txt \ + ./encoder_jit_trace-pnnx.ncnn.int8.param \ + ./encoder_jit_trace-pnnx.ncnn.int8.bin \ + ./decoder_jit_trace-pnnx.ncnn.param \ + ./decoder_jit_trace-pnnx.ncnn.bin \ + ./joiner_jit_trace-pnnx.ncnn.param \ + ./joiner_jit_trace-pnnx.ncnn.bin \ + ../test_wavs/1089-134686-0001.wav + +The following table compares again the file sizes: + ++----------------------------------------+------------+ +| File name | File size | ++----------------------------------------+------------+ +| encoder_jit_trace-pnnx.pt | 318 MB | ++----------------------------------------+------------+ +| decoder_jit_trace-pnnx.pt | 1010 KB | ++----------------------------------------+------------+ +| joiner_jit_trace-pnnx.pt | 3.0 MB | ++----------------------------------------+------------+ +| encoder_jit_trace-pnnx.ncnn.bin (fp16) | 159 MB | ++----------------------------------------+------------+ +| decoder_jit_trace-pnnx.ncnn.bin (fp16) | 503 KB | ++----------------------------------------+------------+ +| joiner_jit_trace-pnnx.ncnn.bin (fp16) | 1.5 MB | ++----------------------------------------+------------+ +| encoder_jit_trace-pnnx.ncnn.bin (fp32) | 317 MB | ++----------------------------------------+------------+ +| joiner_jit_trace-pnnx.ncnn.bin (fp32) | 3.0 MB | ++----------------------------------------+------------+ +| encoder_jit_trace-pnnx.ncnn.int8.bin | 218 MB | ++----------------------------------------+------------+ +| joiner_jit_trace-pnnx.ncnn.int8.bin | 774 KB | ++----------------------------------------+------------+ + +You can see that the file size of the joiner model after ``int8`` quantization +is much smaller. However, the size of the encoder model is even larger than +the ``fp16`` counterpart. The reason is that `ncnn`_ currently does not support +quantizing ``LSTM`` layers into ``8-bit``. Please see +``_ + +.. hint:: + + Currently, only linear layers and convolutional layers are quantized + with ``int8``, so you don't see an exact ``4x`` reduction in file sizes. + +.. note:: + + You need to test the recognition accuracy after ``int8`` quantization. + + +That's it! Have fun with `sherpa-ncnn`_! diff --git a/docs/source/model-export/export-ncnn.rst b/docs/source/model-export/export-ncnn.rst index ed0264089..841d1d4de 100644 --- a/docs/source/model-export/export-ncnn.rst +++ b/docs/source/model-export/export-ncnn.rst @@ -1,15 +1,26 @@ Export to ncnn ============== -We support exporting both -`LSTM transducer models `_ -and -`ConvEmformer transducer models `_ -to `ncnn `_. +We support exporting the following models +to `ncnn `_: -We also provide ``_ -performing speech recognition using ``ncnn`` with exported models. -It has been tested on Linux, macOS, Windows, ``Android``, and ``Raspberry Pi``. + - `Zipformer transducer models `_ + + - `LSTM transducer models `_ + + - `ConvEmformer transducer models `_ + +We also provide `sherpa-ncnn`_ +for performing speech recognition using `ncnn`_ with exported models. +It has been tested on the following platforms: + + - Linux + - macOS + - Windows + - ``Android`` + - ``iOS`` + - ``Raspberry Pi`` + - `爱芯派 `_ (`MAIX-III AXera-Pi `_). `sherpa-ncnn`_ is self-contained and can be statically linked to produce a binary containing everything needed. Please refer @@ -18,754 +29,7 @@ to its documentation for details: - ``_ -Export LSTM transducer models ------------------------------ +.. toctree:: -Please refer to :ref:`export-lstm-transducer-model-for-ncnn` for details. - - - -Export ConvEmformer transducer models -------------------------------------- - -We use the pre-trained model from the following repository as an example: - - - ``_ - -We will show you step by step how to export it to `ncnn`_ and run it with `sherpa-ncnn`_. - -.. hint:: - - We use ``Ubuntu 18.04``, ``torch 1.10``, and ``Python 3.8`` for testing. - -.. caution:: - - Please use a more recent version of PyTorch. For instance, ``torch 1.8`` - may ``not`` work. - -1. Download the pre-trained model -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -.. hint:: - - You can also refer to ``_ to download the pre-trained model. - - You have to install `git-lfs`_ before you continue. - -.. code-block:: bash - - cd egs/librispeech/ASR - - GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05 - cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05 - - git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt" - git lfs pull --include "data/lang_bpe_500/bpe.model" - - cd .. - -.. note:: - - We download ``exp/pretrained-xxx.pt``, not ``exp/cpu-jit_xxx.pt``. - - -In the above code, we download the pre-trained model into the directory -``egs/librispeech/ASR/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05``. - -2. Install ncnn and pnnx -^^^^^^^^^^^^^^^^^^^^^^^^ - -.. code-block:: bash - - # We put ncnn into $HOME/open-source/ncnn - # You can change it to anywhere you like - - cd $HOME - mkdir -p open-source - cd open-source - - git clone https://github.com/csukuangfj/ncnn - cd ncnn - git submodule update --recursive --init - - # Note: We don't use "python setup.py install" or "pip install ." here - - mkdir -p build-wheel - cd build-wheel - - cmake \ - -DCMAKE_BUILD_TYPE=Release \ - -DNCNN_PYTHON=ON \ - -DNCNN_BUILD_BENCHMARK=OFF \ - -DNCNN_BUILD_EXAMPLES=OFF \ - -DNCNN_BUILD_TOOLS=ON \ - .. - - make -j4 - - cd .. - - # Note: $PWD here is $HOME/open-source/ncnn - - export PYTHONPATH=$PWD/python:$PYTHONPATH - export PATH=$PWD/tools/pnnx/build/src:$PATH - export PATH=$PWD/build-wheel/tools/quantize:$PATH - - # Now build pnnx - cd tools/pnnx - mkdir build - cd build - cmake .. - make -j4 - - ./src/pnnx - -Congratulations! You have successfully installed the following components: - - - ``pnxx``, which is an executable located in - ``$HOME/open-source/ncnn/tools/pnnx/build/src``. We will use - it to convert models exported by ``torch.jit.trace()``. - - ``ncnn2int8``, which is an executable located in - ``$HOME/open-source/ncnn/build-wheel/tools/quantize``. We will use - it to quantize our models to ``int8``. - - ``ncnn.cpython-38-x86_64-linux-gnu.so``, which is a Python module located - in ``$HOME/open-source/ncnn/python/ncnn``. - - .. note:: - - I am using ``Python 3.8``, so it - is ``ncnn.cpython-38-x86_64-linux-gnu.so``. If you use a different - version, say, ``Python 3.9``, the name would be - ``ncnn.cpython-39-x86_64-linux-gnu.so``. - - Also, if you are not using Linux, the file name would also be different. - But that does not matter. As long as you can compile it, it should work. - -We have set up ``PYTHONPATH`` so that you can use ``import ncnn`` in your -Python code. We have also set up ``PATH`` so that you can use -``pnnx`` and ``ncnn2int8`` later in your terminal. - -.. caution:: - - Please don't use ``_. - We have made some modifications to the offical `ncnn`_. - - We will synchronize ``_ periodically - with the official one. - -3. Export the model via torch.jit.trace() -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -First, let us rename our pre-trained model: - -.. code-block:: - - cd egs/librispeech/ASR - - cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp - - ln -s pretrained-epoch-30-avg-10-averaged.pt epoch-30.pt - - cd ../.. - -Next, we use the following code to export our model: - -.. code-block:: bash - - dir=./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/ - - ./conv_emformer_transducer_stateless2/export-for-ncnn.py \ - --exp-dir $dir/exp \ - --bpe-model $dir/data/lang_bpe_500/bpe.model \ - --epoch 30 \ - --avg 1 \ - --use-averaged-model 0 \ - \ - --num-encoder-layers 12 \ - --chunk-length 32 \ - --cnn-module-kernel 31 \ - --left-context-length 32 \ - --right-context-length 8 \ - --memory-size 32 \ - --encoder-dim 512 - -.. hint:: - - We have renamed our model to ``epoch-30.pt`` so that we can use ``--epoch 30``. - There is only one pre-trained model, so we use ``--avg 1 --use-averaged-model 0``. - - If you have trained a model by yourself and if you have all checkpoints - available, please first use ``decode.py`` to tune ``--epoch --avg`` - and select the best combination with with ``--use-averaged-model 1``. - -.. note:: - - You will see the following log output: - - .. literalinclude:: ./code/export-conv-emformer-transducer-for-ncnn-output.txt - - The log shows the model has ``75490012`` parameters, i.e., ``~75 M``. - - .. code-block:: - - ls -lh icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/pretrained-epoch-30-avg-10-averaged.pt - - -rw-r--r-- 1 kuangfangjun root 289M Jan 11 12:05 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/pretrained-epoch-30-avg-10-averaged.pt - - You can see that the file size of the pre-trained model is ``289 MB``, which - is roughly ``75490012*4/1024/1024 = 287.97 MB``. - -After running ``conv_emformer_transducer_stateless2/export-for-ncnn.py``, -we will get the following files: - -.. code-block:: bash - - ls -lh icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/*pnnx* - - -rw-r--r-- 1 kuangfangjun root 1010K Jan 11 12:15 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.pt - -rw-r--r-- 1 kuangfangjun root 283M Jan 11 12:15 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.pt - -rw-r--r-- 1 kuangfangjun root 3.0M Jan 11 12:15 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.pt - - -.. _conv-emformer-step-3-export-torchscript-model-via-pnnx: - -3. Export torchscript model via pnnx -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -.. hint:: - - Make sure you have set up the ``PATH`` environment variable. Otherwise, - it will throw an error saying that ``pnnx`` could not be found. - -Now, it's time to export our models to `ncnn`_ via ``pnnx``. - -.. code-block:: - - cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/ - - pnnx ./encoder_jit_trace-pnnx.pt - pnnx ./decoder_jit_trace-pnnx.pt - pnnx ./joiner_jit_trace-pnnx.pt - -It will generate the following files: - -.. code-block:: bash - - ls -lh icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/*ncnn*{bin,param} - - -rw-r--r-- 1 kuangfangjun root 503K Jan 11 12:38 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.ncnn.bin - -rw-r--r-- 1 kuangfangjun root 437 Jan 11 12:38 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.ncnn.param - -rw-r--r-- 1 kuangfangjun root 142M Jan 11 12:36 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.ncnn.bin - -rw-r--r-- 1 kuangfangjun root 79K Jan 11 12:36 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.ncnn.param - -rw-r--r-- 1 kuangfangjun root 1.5M Jan 11 12:38 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.ncnn.bin - -rw-r--r-- 1 kuangfangjun root 488 Jan 11 12:38 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.ncnn.param - -There are two types of files: - -- ``param``: It is a text file containing the model architectures. You can - use a text editor to view its content. -- ``bin``: It is a binary file containing the model parameters. - -We compare the file sizes of the models below before and after converting via ``pnnx``: - -.. see https://tableconvert.com/restructuredtext-generator - -+----------------------------------+------------+ -| File name | File size | -+==================================+============+ -| encoder_jit_trace-pnnx.pt | 283 MB | -+----------------------------------+------------+ -| decoder_jit_trace-pnnx.pt | 1010 KB | -+----------------------------------+------------+ -| joiner_jit_trace-pnnx.pt | 3.0 MB | -+----------------------------------+------------+ -| encoder_jit_trace-pnnx.ncnn.bin | 142 MB | -+----------------------------------+------------+ -| decoder_jit_trace-pnnx.ncnn.bin | 503 KB | -+----------------------------------+------------+ -| joiner_jit_trace-pnnx.ncnn.bin | 1.5 MB | -+----------------------------------+------------+ - -You can see that the file sizes of the models after conversion are about one half -of the models before conversion: - - - encoder: 283 MB vs 142 MB - - decoder: 1010 KB vs 503 KB - - joiner: 3.0 MB vs 1.5 MB - -The reason is that by default ``pnnx`` converts ``float32`` parameters -to ``float16``. A ``float32`` parameter occupies 4 bytes, while it is 2 bytes -for ``float16``. Thus, it is ``twice smaller`` after conversion. - -.. hint:: - - If you use ``pnnx ./encoder_jit_trace-pnnx.pt fp16=0``, then ``pnnx`` - won't convert ``float32`` to ``float16``. - -4. Test the exported models in icefall -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -.. note:: - - We assume you have set up the environment variable ``PYTHONPATH`` when - building `ncnn`_. - -Now we have successfully converted our pre-trained model to `ncnn`_ format. -The generated 6 files are what we need. You can use the following code to -test the converted models: - -.. code-block:: bash - - ./conv_emformer_transducer_stateless2/streaming-ncnn-decode.py \ - --tokens ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/data/lang_bpe_500/tokens.txt \ - --encoder-param-filename ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.ncnn.param \ - --encoder-bin-filename ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.ncnn.bin \ - --decoder-param-filename ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.ncnn.param \ - --decoder-bin-filename ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.ncnn.bin \ - --joiner-param-filename ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.ncnn.param \ - --joiner-bin-filename ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.ncnn.bin \ - ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/test_wavs/1089-134686-0001.wav - -.. hint:: - - `ncnn`_ supports only ``batch size == 1``, so ``streaming-ncnn-decode.py`` accepts - only 1 wave file as input. - -The output is given below: - -.. literalinclude:: ./code/test-stremaing-ncnn-decode-conv-emformer-transducer-libri.txt - -Congratulations! You have successfully exported a model from PyTorch to `ncnn`_! - - -.. _conv-emformer-modify-the-exported-encoder-for-sherpa-ncnn: - -5. Modify the exported encoder for sherpa-ncnn -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -In order to use the exported models in `sherpa-ncnn`_, we have to modify -``encoder_jit_trace-pnnx.ncnn.param``. - -Let us have a look at the first few lines of ``encoder_jit_trace-pnnx.ncnn.param``: - -.. code-block:: - - 7767517 - 1060 1342 - Input in0 0 1 in0 - -**Explanation** of the above three lines: - - 1. ``7767517``, it is a magic number and should not be changed. - 2. ``1060 1342``, the first number ``1060`` specifies the number of layers - in this file, while ``1342`` specifies the number of intermediate outputs - of this file - 3. ``Input in0 0 1 in0``, ``Input`` is the layer type of this layer; ``in0`` - is the layer name of this layer; ``0`` means this layer has no input; - ``1`` means this layer has one output; ``in0`` is the output name of - this layer. - -We need to add 1 extra line and also increment the number of layers. -The result looks like below: - -.. code-block:: bash - - 7767517 - 1061 1342 - SherpaMetaData sherpa_meta_data1 0 0 0=1 1=12 2=32 3=31 4=8 5=32 6=8 7=512 - Input in0 0 1 in0 - -**Explanation** - - 1. ``7767517``, it is still the same - 2. ``1061 1342``, we have added an extra layer, so we need to update ``1060`` to ``1061``. - We don't need to change ``1342`` since the newly added layer has no inputs or outputs. - 3. ``SherpaMetaData sherpa_meta_data1 0 0 0=1 1=12 2=32 3=31 4=8 5=32 6=8 7=512`` - This line is newly added. Its explanation is given below: - - - ``SherpaMetaData`` is the type of this layer. Must be ``SherpaMetaData``. - - ``sherpa_meta_data1`` is the name of this layer. Must be ``sherpa_meta_data1``. - - ``0 0`` means this layer has no inputs or output. Must be ``0 0`` - - ``0=1``, 0 is the key and 1 is the value. MUST be ``0=1`` - - ``1=12``, 1 is the key and 12 is the value of the - parameter ``--num-encoder-layers`` that you provided when running - ``conv_emformer_transducer_stateless2/export-for-ncnn.py``. - - ``2=32``, 2 is the key and 32 is the value of the - parameter ``--memory-size`` that you provided when running - ``conv_emformer_transducer_stateless2/export-for-ncnn.py``. - - ``3=31``, 3 is the key and 31 is the value of the - parameter ``--cnn-module-kernel`` that you provided when running - ``conv_emformer_transducer_stateless2/export-for-ncnn.py``. - - ``4=8``, 4 is the key and 8 is the value of the - parameter ``--left-context-length`` that you provided when running - ``conv_emformer_transducer_stateless2/export-for-ncnn.py``. - - ``5=32``, 5 is the key and 32 is the value of the - parameter ``--chunk-length`` that you provided when running - ``conv_emformer_transducer_stateless2/export-for-ncnn.py``. - - ``6=8``, 6 is the key and 8 is the value of the - parameter ``--right-context-length`` that you provided when running - ``conv_emformer_transducer_stateless2/export-for-ncnn.py``. - - ``7=512``, 7 is the key and 512 is the value of the - parameter ``--encoder-dim`` that you provided when running - ``conv_emformer_transducer_stateless2/export-for-ncnn.py``. - - For ease of reference, we list the key-value pairs that you need to add - in the following table. If your model has a different setting, please - change the values for ``SherpaMetaData`` accordingly. Otherwise, you - will be ``SAD``. - - +------+-----------------------------+ - | key | value | - +======+=============================+ - | 0 | 1 (fixed) | - +------+-----------------------------+ - | 1 | ``--num-encoder-layers`` | - +------+-----------------------------+ - | 2 | ``--memory-size`` | - +------+-----------------------------+ - | 3 | ``--cnn-module-kernel`` | - +------+-----------------------------+ - | 4 | ``--left-context-length`` | - +------+-----------------------------+ - | 5 | ``--chunk-length`` | - +------+-----------------------------+ - | 6 | ``--right-context-length`` | - +------+-----------------------------+ - | 7 | ``--encoder-dim`` | - +------+-----------------------------+ - - 4. ``Input in0 0 1 in0``. No need to change it. - -.. caution:: - - When you add a new layer ``SherpaMetaData``, please remember to update the - number of layers. In our case, update ``1060`` to ``1061``. Otherwise, - you will be SAD later. - -.. hint:: - - After adding the new layer ``SherpaMetaData``, you cannot use this model - with ``streaming-ncnn-decode.py`` anymore since ``SherpaMetaData`` is - supported only in `sherpa-ncnn`_. - -.. hint:: - - `ncnn`_ is very flexible. You can add new layers to it just by text-editing - the ``param`` file! You don't need to change the ``bin`` file. - -Now you can use this model in `sherpa-ncnn`_. -Please refer to the following documentation: - - - Linux/macOS/Windows/arm/aarch64: ``_ - - Android: ``_ - - Python: ``_ - -We have a list of pre-trained models that have been exported for `sherpa-ncnn`_: - - - ``_ - - You can find more usages there. - -6. (Optional) int8 quantization with sherpa-ncnn -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -This step is optional. - -In this step, we describe how to quantize our model with ``int8``. - -Change :ref:`conv-emformer-step-3-export-torchscript-model-via-pnnx` to -disable ``fp16`` when using ``pnnx``: - -.. code-block:: - - cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/ - - pnnx ./encoder_jit_trace-pnnx.pt fp16=0 - pnnx ./decoder_jit_trace-pnnx.pt - pnnx ./joiner_jit_trace-pnnx.pt fp16=0 - -.. note:: - - We add ``fp16=0`` when exporting the encoder and joiner. `ncnn`_ does not - support quantizing the decoder model yet. We will update this documentation - once `ncnn`_ supports it. (Maybe in this year, 2023). - -It will generate the following files - -.. code-block:: bash - - ls -lh icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/*_jit_trace-pnnx.ncnn.{param,bin} - - -rw-r--r-- 1 kuangfangjun root 503K Jan 11 15:56 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.ncnn.bin - -rw-r--r-- 1 kuangfangjun root 437 Jan 11 15:56 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.ncnn.param - -rw-r--r-- 1 kuangfangjun root 283M Jan 11 15:56 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.ncnn.bin - -rw-r--r-- 1 kuangfangjun root 79K Jan 11 15:56 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.ncnn.param - -rw-r--r-- 1 kuangfangjun root 3.0M Jan 11 15:56 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.ncnn.bin - -rw-r--r-- 1 kuangfangjun root 488 Jan 11 15:56 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.ncnn.param - -Let us compare again the file sizes: - -+----------------------------------------+------------+ -| File name | File size | -+----------------------------------------+------------+ -| encoder_jit_trace-pnnx.pt | 283 MB | -+----------------------------------------+------------+ -| decoder_jit_trace-pnnx.pt | 1010 KB | -+----------------------------------------+------------+ -| joiner_jit_trace-pnnx.pt | 3.0 MB | -+----------------------------------------+------------+ -| encoder_jit_trace-pnnx.ncnn.bin (fp16) | 142 MB | -+----------------------------------------+------------+ -| decoder_jit_trace-pnnx.ncnn.bin (fp16) | 503 KB | -+----------------------------------------+------------+ -| joiner_jit_trace-pnnx.ncnn.bin (fp16) | 1.5 MB | -+----------------------------------------+------------+ -| encoder_jit_trace-pnnx.ncnn.bin (fp32) | 283 MB | -+----------------------------------------+------------+ -| joiner_jit_trace-pnnx.ncnn.bin (fp32) | 3.0 MB | -+----------------------------------------+------------+ - -You can see that the file sizes are doubled when we disable ``fp16``. - -.. note:: - - You can again use ``streaming-ncnn-decode.py`` to test the exported models. - -Next, follow :ref:`conv-emformer-modify-the-exported-encoder-for-sherpa-ncnn` -to modify ``encoder_jit_trace-pnnx.ncnn.param``. - -Change - -.. code-block:: bash - - 7767517 - 1060 1342 - Input in0 0 1 in0 - -to - -.. code-block:: bash - - 7767517 - 1061 1342 - SherpaMetaData sherpa_meta_data1 0 0 0=1 1=12 2=32 3=31 4=8 5=32 6=8 7=512 - Input in0 0 1 in0 - -.. caution:: - - Please follow :ref:`conv-emformer-modify-the-exported-encoder-for-sherpa-ncnn` - to change the values for ``SherpaMetaData`` if your model uses a different setting. - - -Next, let us compile `sherpa-ncnn`_ since we will quantize our models within -`sherpa-ncnn`_. - -.. code-block:: bash - - # We will download sherpa-ncnn to $HOME/open-source/ - # You can change it to anywhere you like. - cd $HOME - mkdir -p open-source - - cd open-source - git clone https://github.com/k2-fsa/sherpa-ncnn - cd sherpa-ncnn - mkdir build - cd build - cmake .. - make -j 4 - - ./bin/generate-int8-scale-table - - export PATH=$HOME/open-source/sherpa-ncnn/build/bin:$PATH - -The output of the above commands are: - -.. code-block:: bash - - (py38) kuangfangjun:build$ generate-int8-scale-table - Please provide 10 arg. Currently given: 1 - Usage: - generate-int8-scale-table encoder.param encoder.bin decoder.param decoder.bin joiner.param joiner.bin encoder-scale-table.txt joiner-scale-table.txt wave_filenames.txt - - Each line in wave_filenames.txt is a path to some 16k Hz mono wave file. - -We need to create a file ``wave_filenames.txt``, in which we need to put -some calibration wave files. For testing purpose, we put the ``test_wavs`` -from the pre-trained model repository ``_ - -.. code-block:: bash - - cd egs/librispeech/ASR - cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/ - - cat < wave_filenames.txt - ../test_wavs/1089-134686-0001.wav - ../test_wavs/1221-135766-0001.wav - ../test_wavs/1221-135766-0002.wav - EOF - -Now we can calculate the scales needed for quantization with the calibration data: - -.. code-block:: bash - - cd egs/librispeech/ASR - cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/ - - generate-int8-scale-table \ - ./encoder_jit_trace-pnnx.ncnn.param \ - ./encoder_jit_trace-pnnx.ncnn.bin \ - ./decoder_jit_trace-pnnx.ncnn.param \ - ./decoder_jit_trace-pnnx.ncnn.bin \ - ./joiner_jit_trace-pnnx.ncnn.param \ - ./joiner_jit_trace-pnnx.ncnn.bin \ - ./encoder-scale-table.txt \ - ./joiner-scale-table.txt \ - ./wave_filenames.txt - -The output logs are in the following: - -.. literalinclude:: ./code/generate-int-8-scale-table-for-conv-emformer.txt - -It generates the following two files: - -.. code-block:: bash - - $ ls -lh encoder-scale-table.txt joiner-scale-table.txt - -rw-r--r-- 1 kuangfangjun root 955K Jan 11 17:28 encoder-scale-table.txt - -rw-r--r-- 1 kuangfangjun root 18K Jan 11 17:28 joiner-scale-table.txt - -.. caution:: - - Definitely, you need more calibration data to compute the scale table. - -Finally, let us use the scale table to quantize our models into ``int8``. - -.. code-block:: bash - - ncnn2int8 - - usage: ncnn2int8 [inparam] [inbin] [outparam] [outbin] [calibration table] - -First, we quantize the encoder model: - -.. code-block:: bash - - cd egs/librispeech/ASR - cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/ - - ncnn2int8 \ - ./encoder_jit_trace-pnnx.ncnn.param \ - ./encoder_jit_trace-pnnx.ncnn.bin \ - ./encoder_jit_trace-pnnx.ncnn.int8.param \ - ./encoder_jit_trace-pnnx.ncnn.int8.bin \ - ./encoder-scale-table.txt - -Next, we quantize the joiner model: - -.. code-block:: bash - - ncnn2int8 \ - ./joiner_jit_trace-pnnx.ncnn.param \ - ./joiner_jit_trace-pnnx.ncnn.bin \ - ./joiner_jit_trace-pnnx.ncnn.int8.param \ - ./joiner_jit_trace-pnnx.ncnn.int8.bin \ - ./joiner-scale-table.txt - -The above two commands generate the following 4 files: - -.. code-block:: bash - - -rw-r--r-- 1 kuangfangjun root 99M Jan 11 17:34 encoder_jit_trace-pnnx.ncnn.int8.bin - -rw-r--r-- 1 kuangfangjun root 78K Jan 11 17:34 encoder_jit_trace-pnnx.ncnn.int8.param - -rw-r--r-- 1 kuangfangjun root 774K Jan 11 17:35 joiner_jit_trace-pnnx.ncnn.int8.bin - -rw-r--r-- 1 kuangfangjun root 496 Jan 11 17:35 joiner_jit_trace-pnnx.ncnn.int8.param - -Congratulations! You have successfully quantized your model from ``float32`` to ``int8``. - -.. caution:: - - ``ncnn.int8.param`` and ``ncnn.int8.bin`` must be used in pairs. - - You can replace ``ncnn.param`` and ``ncnn.bin`` with ``ncnn.int8.param`` - and ``ncnn.int8.bin`` in `sherpa-ncnn`_ if you like. - - For instance, to use only the ``int8`` encoder in ``sherpa-ncnn``, you can - replace the following invocation: - - .. code-block:: - - cd egs/librispeech/ASR - cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/ - - sherpa-ncnn \ - ../data/lang_bpe_500/tokens.txt \ - ./encoder_jit_trace-pnnx.ncnn.param \ - ./encoder_jit_trace-pnnx.ncnn.bin \ - ./decoder_jit_trace-pnnx.ncnn.param \ - ./decoder_jit_trace-pnnx.ncnn.bin \ - ./joiner_jit_trace-pnnx.ncnn.param \ - ./joiner_jit_trace-pnnx.ncnn.bin \ - ../test_wavs/1089-134686-0001.wav - - with - - .. code-block:: - - cd egs/librispeech/ASR - cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/ - - sherpa-ncnn \ - ../data/lang_bpe_500/tokens.txt \ - ./encoder_jit_trace-pnnx.ncnn.int8.param \ - ./encoder_jit_trace-pnnx.ncnn.int8.bin \ - ./decoder_jit_trace-pnnx.ncnn.param \ - ./decoder_jit_trace-pnnx.ncnn.bin \ - ./joiner_jit_trace-pnnx.ncnn.param \ - ./joiner_jit_trace-pnnx.ncnn.bin \ - ../test_wavs/1089-134686-0001.wav - - -The following table compares again the file sizes: - - -+----------------------------------------+------------+ -| File name | File size | -+----------------------------------------+------------+ -| encoder_jit_trace-pnnx.pt | 283 MB | -+----------------------------------------+------------+ -| decoder_jit_trace-pnnx.pt | 1010 KB | -+----------------------------------------+------------+ -| joiner_jit_trace-pnnx.pt | 3.0 MB | -+----------------------------------------+------------+ -| encoder_jit_trace-pnnx.ncnn.bin (fp16) | 142 MB | -+----------------------------------------+------------+ -| decoder_jit_trace-pnnx.ncnn.bin (fp16) | 503 KB | -+----------------------------------------+------------+ -| joiner_jit_trace-pnnx.ncnn.bin (fp16) | 1.5 MB | -+----------------------------------------+------------+ -| encoder_jit_trace-pnnx.ncnn.bin (fp32) | 283 MB | -+----------------------------------------+------------+ -| joiner_jit_trace-pnnx.ncnn.bin (fp32) | 3.0 MB | -+----------------------------------------+------------+ -| encoder_jit_trace-pnnx.ncnn.int8.bin | 99 MB | -+----------------------------------------+------------+ -| joiner_jit_trace-pnnx.ncnn.int8.bin | 774 KB | -+----------------------------------------+------------+ - -You can see that the file sizes of the model after ``int8`` quantization -are much smaller. - -.. hint:: - - Currently, only linear layers and convolutional layers are quantized - with ``int8``, so you don't see an exact ``4x`` reduction in file sizes. - -.. note:: - - You need to test the recognition accuracy after ``int8`` quantization. - -You can find the speed comparison at ``_. - - -That's it! Have fun with `sherpa-ncnn`_! + export-ncnn-conv-emformer + export-ncnn-lstm diff --git a/docs/source/model-export/export-onnx.rst b/docs/source/model-export/export-onnx.rst index ddcbc965f..8f0cb11fb 100644 --- a/docs/source/model-export/export-onnx.rst +++ b/docs/source/model-export/export-onnx.rst @@ -10,7 +10,7 @@ There is also a file named ``onnx_pretrained.py``, which you can use the exported `ONNX`_ model in Python with `onnxruntime`_ to decode sound files. Example -======= +------- In the following, we demonstrate how to export a streaming Zipformer pre-trained model from diff --git a/docs/source/recipes/Streaming-ASR/librispeech/lstm_pruned_stateless_transducer.rst b/docs/source/recipes/Streaming-ASR/librispeech/lstm_pruned_stateless_transducer.rst index d04565e5d..911e84656 100644 --- a/docs/source/recipes/Streaming-ASR/librispeech/lstm_pruned_stateless_transducer.rst +++ b/docs/source/recipes/Streaming-ASR/librispeech/lstm_pruned_stateless_transducer.rst @@ -515,132 +515,6 @@ To use the generated files with ``./lstm_transducer_stateless2/jit_pretrained``: Please see ``_ for how to use the exported models in ``sherpa``. -.. _export-lstm-transducer-model-for-ncnn: - -Export LSTM transducer models for ncnn -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -We support exporting pretrained LSTM transducer models to -`ncnn `_ using -`pnnx `_. - -First, let us install a modified version of ``ncnn``: - -.. code-block:: bash - - git clone https://github.com/csukuangfj/ncnn - cd ncnn - git submodule update --recursive --init - - # Note: We don't use "python setup.py install" or "pip install ." here - - mkdir -p build-wheel - cd build-wheel - - cmake \ - -DCMAKE_BUILD_TYPE=Release \ - -DNCNN_PYTHON=ON \ - -DNCNN_BUILD_BENCHMARK=OFF \ - -DNCNN_BUILD_EXAMPLES=OFF \ - -DNCNN_BUILD_TOOLS=ON \ - .. - - make -j4 - - cd .. - - # Note: $PWD here is /path/to/ncnn - - export PYTHONPATH=$PWD/python:$PYTHONPATH - export PATH=$PWD/tools/pnnx/build/src:$PATH - export PATH=$PWD/build-wheel/tools/quantize:$PATH - - # now build pnnx - cd tools/pnnx - mkdir build - cd build - cmake .. - make -j4 - - ./src/pnnx - -.. note:: - - We assume that you have added the path to the binary ``pnnx`` to the - environment variable ``PATH``. - - We also assume that you have added ``build/tools/quantize`` to the environment - variable ``PATH`` so that you are able to use ``ncnn2int8`` later. - -Second, let us export the model using ``torch.jit.trace()`` that is suitable -for ``pnnx``: - -.. code-block:: bash - - iter=468000 - avg=16 - - ./lstm_transducer_stateless2/export-for-ncnn.py \ - --exp-dir ./lstm_transducer_stateless2/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --iter $iter \ - --avg $avg - -It will generate 3 files: - - - ``./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.pt`` - - ``./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.pt`` - - ``./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.pt`` - -Third, convert torchscript model to ``ncnn`` format: - -.. code-block:: - - pnnx ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.pt - pnnx ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.pt - pnnx ./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.pt - -It will generate the following files: - - - ``./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.param`` - - ``./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.bin`` - - ``./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.param`` - - ``./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.bin`` - - ``./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.param`` - - ``./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.bin`` - -To use the above generated files, run: - -.. code-block:: bash - - ./lstm_transducer_stateless2/ncnn-decode.py \ - --tokens ./data/lang_bpe_500/tokens.txt \ - --encoder-param-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.param \ - --encoder-bin-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.bin \ - --decoder-param-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.param \ - --decoder-bin-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.bin \ - --joiner-param-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.param \ - --joiner-bin-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.bin \ - /path/to/foo.wav - -.. code-block:: bash - - ./lstm_transducer_stateless2/streaming-ncnn-decode.py \ - --tokens ./data/lang_bpe_500/tokens.txt \ - --encoder-param-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.param \ - --encoder-bin-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.bin \ - --decoder-param-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.param \ - --decoder-bin-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.bin \ - --joiner-param-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.param \ - --joiner-bin-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.bin \ - /path/to/foo.wav - -To use the above generated files in C++, please see -``_ - -It is able to generate a static linked executable that can be run on Linux, Windows, -macOS, Raspberry Pi, etc, without external dependencies. - Download pretrained models -------------------------- From 4626c60c74d9a8bf455f994deea123553e3fe59b Mon Sep 17 00:00:00 2001 From: nihui Date: Fri, 17 Feb 2023 15:38:08 +0800 Subject: [PATCH 122/174] fix typo (#915) --- docs/source/model-export/export-ncnn-conv-emformer.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/model-export/export-ncnn-conv-emformer.rst b/docs/source/model-export/export-ncnn-conv-emformer.rst index d19c7dac8..133915da7 100644 --- a/docs/source/model-export/export-ncnn-conv-emformer.rst +++ b/docs/source/model-export/export-ncnn-conv-emformer.rst @@ -99,7 +99,7 @@ In the above code, we downloaded the pre-trained model into the directory Congratulations! You have successfully installed the following components: - - ``pnxx``, which is an executable located in + - ``pnnx``, which is an executable located in ``$HOME/open-source/ncnn/tools/pnnx/build/src``. We will use it to convert models exported by ``torch.jit.trace()``. - ``ncnn2int8``, which is an executable located in From c51e6c5b9c8e4d92b9e810e26202c3b3b633c519 Mon Sep 17 00:00:00 2001 From: marcoyang1998 <45973641+marcoyang1998@users.noreply.github.com> Date: Mon, 20 Feb 2023 19:04:57 +0800 Subject: [PATCH 123/174] fix typo (#916) --- egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 7388af389..bd2d6e258 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -2377,6 +2377,6 @@ def modified_beam_search_lm_shallow_fusion( return ans else: return DecodingResults( - tokens=ans, + hyps=ans, timestamps=ans_timestamps, ) From b7c85968aee94bcb1f16a553e9850faa3fb0c25f Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 22 Feb 2023 11:15:58 +0800 Subject: [PATCH 124/174] Use standard apache 2.0 license (#919) --- LICENSE | 9 --------- 1 file changed, 9 deletions(-) diff --git a/LICENSE b/LICENSE index ee06cfc77..d64569567 100644 --- a/LICENSE +++ b/LICENSE @@ -1,13 +1,4 @@ - Legal Notices - - NOTE (this is not from the Apache License): The copyright model is that - authors (or their employers, if noted in individual files) own their - individual contributions. The authors' contributions can be discerned - from the git history. - - ------------------------------------------------------------------------- - Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ From 8aaa9761e46c6d71e63096160ddee0197f64a5ff Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 27 Feb 2023 21:23:04 +0800 Subject: [PATCH 125/174] Add doc about exporting streaming zipformer to sherpa-ncnn (#927) --- docs/source/conf.py | 1 + ...t-zipformer-transducer-for-ncnn-output.txt | 74 ++++ ...ncnn-decode-zipformer-transducer-libri.txt | 7 + .../export-ncnn-conv-emformer.rst | 4 + .../model-export/export-ncnn-zipformer.rst | 383 ++++++++++++++++++ docs/source/model-export/export-ncnn.rst | 2 + docs/source/model-export/export-onnx.rst | 16 + 7 files changed, 487 insertions(+) create mode 100644 docs/source/model-export/code/export-zipformer-transducer-for-ncnn-output.txt create mode 100644 docs/source/model-export/code/test-streaming-ncnn-decode-zipformer-transducer-libri.txt create mode 100644 docs/source/model-export/export-ncnn-zipformer.rst diff --git a/docs/source/conf.py b/docs/source/conf.py index 6452c5d6d..6901dec02 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -81,6 +81,7 @@ todo_include_todos = True rst_epilog = """ .. _sherpa-ncnn: https://github.com/k2-fsa/sherpa-ncnn +.. _sherpa-onnx: https://github.com/k2-fsa/sherpa-onnx .. _icefall: https://github.com/k2-fsa/icefall .. _git-lfs: https://git-lfs.com/ .. _ncnn: https://github.com/tencent/ncnn diff --git a/docs/source/model-export/code/export-zipformer-transducer-for-ncnn-output.txt b/docs/source/model-export/code/export-zipformer-transducer-for-ncnn-output.txt new file mode 100644 index 000000000..25874a414 --- /dev/null +++ b/docs/source/model-export/code/export-zipformer-transducer-for-ncnn-output.txt @@ -0,0 +1,74 @@ +2023-02-27 20:23:07,473 INFO [export-for-ncnn.py:246] device: cpu +2023-02-27 20:23:07,477 INFO [export-for-ncnn.py:255] {'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, 'best_valid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 50, 'reset_interval': 200, 'valid_interval': 3000, 'feature_dim': 80, 'subsampling_factor': 4, 'warm_step': 2000, 'env_info': {'k2-version': '1.23.4', 'k2-build-type': 'Release', 'k2-with-cuda': True, 'k2-git-sha1': '62e404dd3f3a811d73e424199b3408e309c06e1a', 'k2-git-date': 'Mon Jan 30 10:26:16 2023', 'lhotse-version': '1.12.0.dev+missing.version.file', 'torch-version': '1.10.0+cu102', 'torch-cuda-available': True, 'torch-cuda-version': '10.2', 'python-version': '3.8', 'icefall-git-branch': 'master', 'icefall-git-sha1': '6d7a559-clean', 'icefall-git-date': 'Thu Feb 16 19:47:54 2023', 'icefall-path': '/star-fj/fangjun/open-source/icefall-2', 'k2-path': '/star-fj/fangjun/open-source/k2/k2/python/k2/__init__.py', 'lhotse-path': '/star-fj/fangjun/open-source/lhotse/lhotse/__init__.py', 'hostname': 'de-74279-k2-train-3-1220120619-7695ff496b-s9n4w', 'IP address': '10.177.6.147'}, 'epoch': 99, 'iter': 0, 'avg': 1, 'exp_dir': PosixPath('icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp'), 'bpe_model': './icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model', 'context_size': 2, 'use_averaged_model': False, 'num_encoder_layers': '2,4,3,2,4', 'feedforward_dims': '1024,1024,2048,2048,1024', 'nhead': '8,8,8,8,8', 'encoder_dims': '384,384,384,384,384', 'attention_dims': '192,192,192,192,192', 'encoder_unmasked_dims': '256,256,256,256,256', 'zipformer_downsampling_factors': '1,2,4,8,2', 'cnn_module_kernels': '31,31,31,31,31', 'decoder_dim': 512, 'joiner_dim': 512, 'short_chunk_size': 50, 'num_left_chunks': 4, 'decode_chunk_len': 32, 'blank_id': 0, 'vocab_size': 500} +2023-02-27 20:23:07,477 INFO [export-for-ncnn.py:257] About to create model +2023-02-27 20:23:08,023 INFO [zipformer2.py:419] At encoder stack 4, which has downsampling_factor=2, we will combine the outputs of layers 1 and 3, with downsampling_factors=2 and 8. +2023-02-27 20:23:08,037 INFO [checkpoint.py:112] Loading checkpoint from icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/epoch-99.pt +2023-02-27 20:23:08,655 INFO [export-for-ncnn.py:346] encoder parameters: 68944004 +2023-02-27 20:23:08,655 INFO [export-for-ncnn.py:347] decoder parameters: 260096 +2023-02-27 20:23:08,655 INFO [export-for-ncnn.py:348] joiner parameters: 716276 +2023-02-27 20:23:08,656 INFO [export-for-ncnn.py:349] total parameters: 69920376 +2023-02-27 20:23:08,656 INFO [export-for-ncnn.py:351] Using torch.jit.trace() +2023-02-27 20:23:08,656 INFO [export-for-ncnn.py:353] Exporting encoder +2023-02-27 20:23:08,656 INFO [export-for-ncnn.py:174] decode_chunk_len: 32 +2023-02-27 20:23:08,656 INFO [export-for-ncnn.py:175] T: 39 +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1344: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert cached_len.size(0) == self.num_layers, ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1348: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert cached_avg.size(0) == self.num_layers, ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1352: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert cached_key.size(0) == self.num_layers, ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1356: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert cached_val.size(0) == self.num_layers, ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1360: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert cached_val2.size(0) == self.num_layers, ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1364: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert cached_conv1.size(0) == self.num_layers, ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1368: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert cached_conv2.size(0) == self.num_layers, ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1373: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert self.left_context_len == cached_key.shape[1], ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1884: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert self.x_size == x.size(0), (self.x_size, x.size(0)) +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:2442: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert cached_key.shape[0] == self.left_context_len, ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:2449: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert cached_key.shape[0] == cached_val.shape[0], ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:2469: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert cached_key.shape[0] == left_context_len, ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:2473: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert cached_val.shape[0] == left_context_len, ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:2483: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert kv_len == k.shape[0], (kv_len, k.shape) +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:2570: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert list(attn_output.size()) == [bsz * num_heads, seq_len, head_dim // 2] +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:2926: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert cache.shape == (x.size(0), x.size(1), self.lorder), ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:2652: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert x.shape[0] == self.x_size, (x.shape[0], self.x_size) +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:2653: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert x.shape[2] == self.embed_dim, (x.shape[2], self.embed_dim) +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:2666: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert cached_val.shape[0] == self.left_context_len, ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1543: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert src.shape[0] == self.in_x_size, (src.shape[0], self.in_x_size) +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1637: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert src.shape[0] == self.in_x_size, ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1643: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert src.shape[2] == self.in_channels, (src.shape[2], self.in_channels) +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1571: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + if src.shape[0] != self.in_x_size: +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1763: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert src1.shape[:-1] == src2.shape[:-1], (src1.shape, src2.shape) +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1779: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert src1.shape[-1] == self.dim1, (src1.shape[-1], self.dim1) +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1780: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert src2.shape[-1] == self.dim2, (src2.shape[-1], self.dim2) +/star-fj/fangjun/py38/lib/python3.8/site-packages/torch/jit/_trace.py:958: TracerWarning: Encountering a list at the output of the tracer might cause the trace to be incorrect, this is only valid if the container structure does not change based on the module's inputs. Consider using a constant container instead (e.g. for `list`, use a `tuple` instead. for `dict`, use a `NamedTuple` instead). If you absolutely need this and know the side effects, pass strict=False to trace() to allow this behavior. + module._c._create_method_from_trace( +2023-02-27 20:23:19,640 INFO [export-for-ncnn.py:182] Saved to icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/encoder_jit_trace-pnnx.pt +2023-02-27 20:23:19,646 INFO [export-for-ncnn.py:357] Exporting decoder +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decoder.py:102: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert embedding_out.size(-1) == self.context_size +2023-02-27 20:23:19,686 INFO [export-for-ncnn.py:204] Saved to icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/decoder_jit_trace-pnnx.pt +2023-02-27 20:23:19,686 INFO [export-for-ncnn.py:361] Exporting joiner +2023-02-27 20:23:19,735 INFO [export-for-ncnn.py:231] Saved to icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/joiner_jit_trace-pnnx.pt diff --git a/docs/source/model-export/code/test-streaming-ncnn-decode-zipformer-transducer-libri.txt b/docs/source/model-export/code/test-streaming-ncnn-decode-zipformer-transducer-libri.txt new file mode 100644 index 000000000..5b4969e0f --- /dev/null +++ b/docs/source/model-export/code/test-streaming-ncnn-decode-zipformer-transducer-libri.txt @@ -0,0 +1,7 @@ +2023-02-27 20:43:40,283 INFO [streaming-ncnn-decode.py:349] {'tokens': './icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/tokens.txt', 'encoder_param_filename': './icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/encoder_jit_trace-pnnx.ncnn.param', 'encoder_bin_filename': './icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/encoder_jit_trace-pnnx.ncnn.bin', 'decoder_param_filename': './icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/decoder_jit_trace-pnnx.ncnn.param', 'decoder_bin_filename': './icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/decoder_jit_trace-pnnx.ncnn.bin', 'joiner_param_filename': './icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/joiner_jit_trace-pnnx.ncnn.param', 'joiner_bin_filename': './icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/joiner_jit_trace-pnnx.ncnn.bin', 'sound_filename': './icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/test_wavs/1089-134686-0001.wav'} +2023-02-27 20:43:41,260 INFO [streaming-ncnn-decode.py:357] Constructing Fbank computer +2023-02-27 20:43:41,264 INFO [streaming-ncnn-decode.py:360] Reading sound files: ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/test_wavs/1089-134686-0001.wav +2023-02-27 20:43:41,269 INFO [streaming-ncnn-decode.py:365] torch.Size([106000]) +2023-02-27 20:43:41,280 INFO [streaming-ncnn-decode.py:372] number of states: 35 +2023-02-27 20:43:45,026 INFO [streaming-ncnn-decode.py:410] ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/test_wavs/1089-134686-0001.wav +2023-02-27 20:43:45,026 INFO [streaming-ncnn-decode.py:411] AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS diff --git a/docs/source/model-export/export-ncnn-conv-emformer.rst b/docs/source/model-export/export-ncnn-conv-emformer.rst index 133915da7..12b370143 100644 --- a/docs/source/model-export/export-ncnn-conv-emformer.rst +++ b/docs/source/model-export/export-ncnn-conv-emformer.rst @@ -166,6 +166,10 @@ Next, we use the following code to export our model: --memory-size 32 \ --encoder-dim 512 +.. caution:: + + If your model has different configuration parameters, please change them accordingly. + .. hint:: We have renamed our model to ``epoch-30.pt`` so that we can use ``--epoch 30``. diff --git a/docs/source/model-export/export-ncnn-zipformer.rst b/docs/source/model-export/export-ncnn-zipformer.rst new file mode 100644 index 000000000..5c81d25ca --- /dev/null +++ b/docs/source/model-export/export-ncnn-zipformer.rst @@ -0,0 +1,383 @@ +.. _export_streaming_zipformer_transducer_models_to_ncnn: + +Export streaming Zipformer transducer models to ncnn +---------------------------------------------------- + +We use the pre-trained model from the following repository as an example: + +``_ + +We will show you step by step how to export it to `ncnn`_ and run it with `sherpa-ncnn`_. + +.. hint:: + + We use ``Ubuntu 18.04``, ``torch 1.13``, and ``Python 3.8`` for testing. + +.. caution:: + + Please use a more recent version of PyTorch. For instance, ``torch 1.8`` + may ``not`` work. + +1. Download the pre-trained model +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. hint:: + + You have to install `git-lfs`_ before you continue. + + +.. code-block:: bash + + cd egs/librispeech/ASR + GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 + cd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 + + git lfs pull --include "exp/pretrained.pt" + git lfs pull --include "data/lang_bpe_500/bpe.model" + + cd .. + +.. note:: + + We downloaded ``exp/pretrained-xxx.pt``, not ``exp/cpu-jit_xxx.pt``. + +In the above code, we downloaded the pre-trained model into the directory +``egs/librispeech/ASR/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29``. + +2. Install ncnn and pnnx +^^^^^^^^^^^^^^^^^^^^^^^^ + +Please refer to :ref:`export_for_ncnn_install_ncnn_and_pnnx` . + + +3. Export the model via torch.jit.trace() +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +First, let us rename our pre-trained model: + +.. code-block:: + + cd egs/librispeech/ASR + + cd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp + + ln -s pretrained.pt epoch-99.pt + + cd ../.. + +Next, we use the following code to export our model: + +.. code-block:: bash + + dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 + + ./pruned_transducer_stateless7_streaming/export-for-ncnn.py \ + --bpe-model $dir/data/lang_bpe_500/bpe.model \ + --exp-dir $dir/exp \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + \ + --decode-chunk-len 32 \ + --num-left-chunks 4 \ + --num-encoder-layers "2,4,3,2,4" \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --nhead "8,8,8,8,8" \ + --encoder-dims "384,384,384,384,384" \ + --attention-dims "192,192,192,192,192" \ + --encoder-unmasked-dims "256,256,256,256,256" \ + --zipformer-downsampling-factors "1,2,4,8,2" \ + --cnn-module-kernels "31,31,31,31,31" \ + --decoder-dim 512 \ + --joiner-dim 512 + +.. caution:: + + If your model has different configuration parameters, please change them accordingly. + +.. hint:: + + We have renamed our model to ``epoch-99.pt`` so that we can use ``--epoch 99``. + There is only one pre-trained model, so we use ``--avg 1 --use-averaged-model 0``. + + If you have trained a model by yourself and if you have all checkpoints + available, please first use ``decode.py`` to tune ``--epoch --avg`` + and select the best combination with with ``--use-averaged-model 1``. + +.. note:: + + You will see the following log output: + + .. literalinclude:: ./code/export-zipformer-transducer-for-ncnn-output.txt + + The log shows the model has ``69920376`` parameters, i.e., ``~69.9 M``. + + .. code-block:: bash + + ls -lh icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/pretrained.pt + -rw-r--r-- 1 kuangfangjun root 269M Jan 12 12:53 icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/pretrained.pt + + You can see that the file size of the pre-trained model is ``269 MB``, which + is roughly equal to ``69920376*4/1024/1024 = 266.725 MB``. + +After running ``pruned_transducer_stateless7_streaming/export-for-ncnn.py``, +we will get the following files: + +.. code-block:: bash + + ls -lh icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/*pnnx.pt + + -rw-r--r-- 1 kuangfangjun root 1022K Feb 27 20:23 icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/decoder_jit_trace-pnnx.pt + -rw-r--r-- 1 kuangfangjun root 266M Feb 27 20:23 icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/encoder_jit_trace-pnnx.pt + -rw-r--r-- 1 kuangfangjun root 2.8M Feb 27 20:23 icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/joiner_jit_trace-pnnx.pt + +.. _zipformer-transducer-step-4-export-torchscript-model-via-pnnx: + +4. Export torchscript model via pnnx +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. hint:: + + Make sure you have set up the ``PATH`` environment variable + in :ref:`export_for_ncnn_install_ncnn_and_pnnx`. Otherwise, + it will throw an error saying that ``pnnx`` could not be found. + +Now, it's time to export our models to `ncnn`_ via ``pnnx``. + +.. code-block:: + + cd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/ + + pnnx ./encoder_jit_trace-pnnx.pt + pnnx ./decoder_jit_trace-pnnx.pt + pnnx ./joiner_jit_trace-pnnx.pt + +It will generate the following files: + +.. code-block:: bash + + ls -lh icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/*ncnn*{bin,param} + + -rw-r--r-- 1 kuangfangjun root 509K Feb 27 20:31 icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/decoder_jit_trace-pnnx.ncnn.bin + -rw-r--r-- 1 kuangfangjun root 437 Feb 27 20:31 icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/decoder_jit_trace-pnnx.ncnn.param + -rw-r--r-- 1 kuangfangjun root 133M Feb 27 20:30 icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/encoder_jit_trace-pnnx.ncnn.bin + -rw-r--r-- 1 kuangfangjun root 152K Feb 27 20:30 icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/encoder_jit_trace-pnnx.ncnn.param + -rw-r--r-- 1 kuangfangjun root 1.4M Feb 27 20:31 icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/joiner_jit_trace-pnnx.ncnn.bin + -rw-r--r-- 1 kuangfangjun root 488 Feb 27 20:31 icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/joiner_jit_trace-pnnx.ncnn.param + +There are two types of files: + +- ``param``: It is a text file containing the model architectures. You can + use a text editor to view its content. +- ``bin``: It is a binary file containing the model parameters. + +We compare the file sizes of the models below before and after converting via ``pnnx``: + +.. see https://tableconvert.com/restructuredtext-generator + ++----------------------------------+------------+ +| File name | File size | ++==================================+============+ +| encoder_jit_trace-pnnx.pt | 266 MB | ++----------------------------------+------------+ +| decoder_jit_trace-pnnx.pt | 1022 KB | ++----------------------------------+------------+ +| joiner_jit_trace-pnnx.pt | 2.8 MB | ++----------------------------------+------------+ +| encoder_jit_trace-pnnx.ncnn.bin | 133 MB | ++----------------------------------+------------+ +| decoder_jit_trace-pnnx.ncnn.bin | 509 KB | ++----------------------------------+------------+ +| joiner_jit_trace-pnnx.ncnn.bin | 1.4 MB | ++----------------------------------+------------+ + +You can see that the file sizes of the models after conversion are about one half +of the models before conversion: + + - encoder: 266 MB vs 133 MB + - decoder: 1022 KB vs 509 KB + - joiner: 2.8 MB vs 1.4 MB + +The reason is that by default ``pnnx`` converts ``float32`` parameters +to ``float16``. A ``float32`` parameter occupies 4 bytes, while it is 2 bytes +for ``float16``. Thus, it is ``twice smaller`` after conversion. + +.. hint:: + + If you use ``pnnx ./encoder_jit_trace-pnnx.pt fp16=0``, then ``pnnx`` + won't convert ``float32`` to ``float16``. + +5. Test the exported models in icefall +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. note:: + + We assume you have set up the environment variable ``PYTHONPATH`` when + building `ncnn`_. + +Now we have successfully converted our pre-trained model to `ncnn`_ format. +The generated 6 files are what we need. You can use the following code to +test the converted models: + +.. code-block:: bash + + python3 ./pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py \ + --tokens ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/tokens.txt \ + --encoder-param-filename ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/encoder_jit_trace-pnnx.ncnn.param \ + --encoder-bin-filename ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/encoder_jit_trace-pnnx.ncnn.bin \ + --decoder-param-filename ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/decoder_jit_trace-pnnx.ncnn.param \ + --decoder-bin-filename ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/decoder_jit_trace-pnnx.ncnn.bin \ + --joiner-param-filename ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/joiner_jit_trace-pnnx.ncnn.param \ + --joiner-bin-filename ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/joiner_jit_trace-pnnx.ncnn.bin \ + ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/test_wavs/1089-134686-0001.wav + +.. hint:: + + `ncnn`_ supports only ``batch size == 1``, so ``streaming-ncnn-decode.py`` accepts + only 1 wave file as input. + +The output is given below: + +.. literalinclude:: ./code/test-streaming-ncnn-decode-zipformer-transducer-libri.txt + +Congratulations! You have successfully exported a model from PyTorch to `ncnn`_! + +.. _zipformer-modify-the-exported-encoder-for-sherpa-ncnn: + +6. Modify the exported encoder for sherpa-ncnn +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +In order to use the exported models in `sherpa-ncnn`_, we have to modify +``encoder_jit_trace-pnnx.ncnn.param``. + +Let us have a look at the first few lines of ``encoder_jit_trace-pnnx.ncnn.param``: + +.. code-block:: + + 7767517 + 2028 2547 + Input in0 0 1 in0 + +**Explanation** of the above three lines: + + 1. ``7767517``, it is a magic number and should not be changed. + 2. ``2028 2547``, the first number ``2028`` specifies the number of layers + in this file, while ``2547`` specifies the number of intermediate outputs + of this file + 3. ``Input in0 0 1 in0``, ``Input`` is the layer type of this layer; ``in0`` + is the layer name of this layer; ``0`` means this layer has no input; + ``1`` means this layer has one output; ``in0`` is the output name of + this layer. + +We need to add 1 extra line and also increment the number of layers. +The result looks like below: + +.. code-block:: bash + + 7767517 + 2029 2547 + SherpaMetaData sherpa_meta_data1 0 0 0=2 1=32 2=4 3=7 -23316=5,2,4,3,2,4 -23317=5,384,384,384,384,384 -23318=5,192,192,192,192,192 -23319=5,1,2,4,8,2 -23320=5,31,31,31,31,31 + Input in0 0 1 in0 + +**Explanation** + + 1. ``7767517``, it is still the same + 2. ``2029 2547``, we have added an extra layer, so we need to update ``2028`` to ``2029``. + We don't need to change ``2547`` since the newly added layer has no inputs or outputs. + 3. ``SherpaMetaData sherpa_meta_data1 0 0 0=2 1=32 2=4 3=7 -23316=5,2,4,3,2,4 -23317=5,384,384,384,384,384 -23318=5,192,192,192,192,192 -23319=5,1,2,4,8,2 -23320=5,31,31,31,31,31`` + This line is newly added. Its explanation is given below: + + - ``SherpaMetaData`` is the type of this layer. Must be ``SherpaMetaData``. + - ``sherpa_meta_data1`` is the name of this layer. Must be ``sherpa_meta_data1``. + - ``0 0`` means this layer has no inputs or output. Must be ``0 0`` + - ``0=2``, 0 is the key and 2 is the value. MUST be ``0=2`` + - ``1=32``, 1 is the key and 32 is the value of the + parameter ``--decode-chunk-len`` that you provided when running + ``./pruned_transducer_stateless7_streaming/export-for-ncnn.py``. + - ``2=4``, 2 is the key and 4 is the value of the + parameter ``--num-left-chunks`` that you provided when running + ``./pruned_transducer_stateless7_streaming/export-for-ncnn.py``. + - ``3=7``, 3 is the key and 7 is the value of for the amount of padding + used in the Conv2DSubsampling layer. It should be 7 for zipformer + if you don't change zipformer.py. + - ``-23316=5,2,4,3,2,4``, attribute 16, this is an array attribute. + It is attribute 16 since -23300 - (-23316) = 16. + The first element of the array is the length of the array, which is 5 in our case. + ``2,4,3,2,4`` is the value of ``--num-encoder-layers``that you provided + when running ``./pruned_transducer_stateless7_streaming/export-for-ncnn.py``. + - ``-23317=5,384,384,384,384,384``, attribute 17. + The first element of the array is the length of the array, which is 5 in our case. + ``384,384,384,384,384`` is the value of ``--encoder-dims``that you provided + when running ``./pruned_transducer_stateless7_streaming/export-for-ncnn.py``. + - ``-23318=5,192,192,192,192,192``, attribute 18. + The first element of the array is the length of the array, which is 5 in our case. + ``192,192,192,192,192`` is the value of ``--attention-dims`` that you provided + when running ``./pruned_transducer_stateless7_streaming/export-for-ncnn.py``. + - ``-23319=5,1,2,4,8,2``, attribute 19. + The first element of the array is the length of the array, which is 5 in our case. + ``1,2,4,8,2`` is the value of ``--zipformer-downsampling-factors`` that you provided + when running ``./pruned_transducer_stateless7_streaming/export-for-ncnn.py``. + - ``-23320=5,31,31,31,31,31``, attribute 20. + The first element of the array is the length of the array, which is 5 in our case. + ``31,31,31,31,31`` is the value of ``--cnn-module-kernels`` that you provided + when running ``./pruned_transducer_stateless7_streaming/export-for-ncnn.py``. + + For ease of reference, we list the key-value pairs that you need to add + in the following table. If your model has a different setting, please + change the values for ``SherpaMetaData`` accordingly. Otherwise, you + will be ``SAD``. + + +----------+--------------------------------------------+ + | key | value | + +==========+============================================+ + | 0 | 2 (fixed) | + +----------+--------------------------------------------+ + | 1 | ``-decode-chunk-len`` | + +----------+--------------------------------------------+ + | 2 | ``--num-left-chunks`` | + +----------+--------------------------------------------+ + | 3 | 7 (if you don't change code) | + +----------+--------------------------------------------+ + |-23316 | ``--num-encoder-layer`` | + +----------+--------------------------------------------+ + |-23317 | ``--encoder-dims`` | + +----------+--------------------------------------------+ + |-23318 | ``--attention-dims`` | + +----------+--------------------------------------------+ + |-23319 | ``--zipformer-downsampling-factors`` | + +----------+--------------------------------------------+ + |-23320 | ``--cnn-module-kernels`` | + +----------+--------------------------------------------+ + + 4. ``Input in0 0 1 in0``. No need to change it. + +.. caution:: + + When you add a new layer ``SherpaMetaData``, please remember to update the + number of layers. In our case, update ``2028`` to ``2029``. Otherwise, + you will be SAD later. + +.. hint:: + + After adding the new layer ``SherpaMetaData``, you cannot use this model + with ``streaming-ncnn-decode.py`` anymore since ``SherpaMetaData`` is + supported only in `sherpa-ncnn`_. + +.. hint:: + + `ncnn`_ is very flexible. You can add new layers to it just by text-editing + the ``param`` file! You don't need to change the ``bin`` file. + +Now you can use this model in `sherpa-ncnn`_. +Please refer to the following documentation: + + - Linux/macOS/Windows/arm/aarch64: ``_ + - ``Android``: ``_ + - ``iOS``: ``_ + - Python: ``_ + +We have a list of pre-trained models that have been exported for `sherpa-ncnn`_: + + - ``_ + + You can find more usages there. diff --git a/docs/source/model-export/export-ncnn.rst b/docs/source/model-export/export-ncnn.rst index 841d1d4de..9eb5f85d2 100644 --- a/docs/source/model-export/export-ncnn.rst +++ b/docs/source/model-export/export-ncnn.rst @@ -21,6 +21,7 @@ It has been tested on the following platforms: - ``iOS`` - ``Raspberry Pi`` - `爱芯派 `_ (`MAIX-III AXera-Pi `_). + - `RV1126 `_ `sherpa-ncnn`_ is self-contained and can be statically linked to produce a binary containing everything needed. Please refer @@ -31,5 +32,6 @@ to its documentation for details: .. toctree:: + export-ncnn-zipformer export-ncnn-conv-emformer export-ncnn-lstm diff --git a/docs/source/model-export/export-onnx.rst b/docs/source/model-export/export-onnx.rst index 8f0cb11fb..aa77204cb 100644 --- a/docs/source/model-export/export-onnx.rst +++ b/docs/source/model-export/export-onnx.rst @@ -9,6 +9,22 @@ to export trained models to `ONNX`_. There is also a file named ``onnx_pretrained.py``, which you can use the exported `ONNX`_ model in Python with `onnxruntime`_ to decode sound files. +sherpa-onnx +----------- + +We have a separate repository `sherpa-onnx`_ for deploying your exported models +on various platforms such as: + + - iOS + - Android + - Raspberry Pi + - Linux/macOS/Windows + + +Please see the documentation of `sherpa-onnx`_ for details: + + ``_ + Example ------- From 07243d136a2aa42c71eda7a7f9ada10a07e82662 Mon Sep 17 00:00:00 2001 From: pehonnet Date: Wed, 8 Mar 2023 14:06:07 +0100 Subject: [PATCH 126/174] remove key from result filename (#936) Co-authored-by: pe-honnet --- .../ASR/pruned_transducer_stateless2/decode.py | 6 +++--- egs/aishell/ASR/pruned_transducer_stateless2/decode.py | 6 +++--- egs/aishell/ASR/pruned_transducer_stateless3/decode.py | 6 +++--- egs/aishell/ASR/transducer_stateless/decode.py | 6 +++--- egs/aishell/ASR/transducer_stateless_modified-2/decode.py | 6 +++--- egs/aishell/ASR/transducer_stateless_modified/decode.py | 6 +++--- egs/aishell2/ASR/pruned_transducer_stateless5/decode.py | 6 +++--- egs/aishell4/ASR/pruned_transducer_stateless5/decode.py | 6 +++--- egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py | 6 +++--- .../ASR_v2/pruned_transducer_stateless7/decode.py | 6 +++--- egs/ami/ASR/pruned_transducer_stateless7/decode.py | 8 ++++---- .../ASR/pruned_transducer_stateless7_streaming/decode.py | 6 +++--- egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py | 6 +++--- egs/librispeech/ASR/conformer_ctc3/decode.py | 8 ++++---- .../ASR/conv_emformer_transducer_stateless/decode.py | 6 +++--- .../streaming_decode.py | 6 +++--- .../ASR/conv_emformer_transducer_stateless2/decode.py | 6 +++--- .../streaming_decode.py | 6 +++--- egs/librispeech/ASR/lstm_transducer_stateless/decode.py | 6 +++--- .../ASR/lstm_transducer_stateless/streaming_decode.py | 6 +++--- egs/librispeech/ASR/lstm_transducer_stateless2/decode.py | 6 +++--- egs/librispeech/ASR/lstm_transducer_stateless3/decode.py | 8 ++++---- .../ASR/lstm_transducer_stateless3/streaming_decode.py | 6 +++--- egs/librispeech/ASR/pruned2_knowledge/decode.py | 6 +++--- .../ASR/pruned_stateless_emformer_rnnt2/decode.py | 6 +++--- egs/librispeech/ASR/pruned_transducer_stateless/decode.py | 6 +++--- .../ASR/pruned_transducer_stateless/streaming_decode.py | 6 +++--- .../ASR/pruned_transducer_stateless2/decode.py | 6 +++--- .../ASR/pruned_transducer_stateless2/streaming_decode.py | 6 +++--- .../ASR/pruned_transducer_stateless3/decode.py | 6 +++--- .../ASR/pruned_transducer_stateless3/streaming_decode.py | 6 +++--- .../ASR/pruned_transducer_stateless4/decode.py | 8 ++++---- .../ASR/pruned_transducer_stateless4/streaming_decode.py | 6 +++--- .../ASR/pruned_transducer_stateless5/decode.py | 6 +++--- .../ASR/pruned_transducer_stateless5/streaming_decode.py | 6 +++--- .../ASR/pruned_transducer_stateless6/decode.py | 6 +++--- .../ASR/pruned_transducer_stateless7/decode.py | 6 +++--- .../ASR/pruned_transducer_stateless7_ctc/ctc_decode.py | 6 +++--- .../ASR/pruned_transducer_stateless7_ctc/decode.py | 6 +++--- .../ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py | 6 +++--- .../ASR/pruned_transducer_stateless7_ctc_bs/decode.py | 6 +++--- .../ASR/pruned_transducer_stateless7_streaming/decode.py | 6 +++--- .../streaming_decode.py | 6 +++--- .../ASR/pruned_transducer_stateless8/decode.py | 6 +++--- egs/librispeech/ASR/transducer/decode.py | 6 +++--- egs/librispeech/ASR/transducer_lstm/decode.py | 6 +++--- egs/librispeech/ASR/transducer_stateless/decode.py | 6 +++--- egs/librispeech/ASR/transducer_stateless2/decode.py | 6 +++--- .../ASR/transducer_stateless_multi_datasets/decode.py | 6 +++--- egs/librispeech/ASR/zipformer_mmi/decode.py | 6 +++--- egs/mgb2/ASR/pruned_transducer_stateless5/decode.py | 6 +++--- egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py | 8 ++++---- egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py | 6 +++--- egs/tedlium3/ASR/pruned_transducer_stateless/decode.py | 6 +++--- egs/tedlium3/ASR/transducer_stateless/decode.py | 6 +++--- .../ASR/pruned_transducer_stateless2/decode.py | 6 +++--- .../ASR/pruned_transducer_stateless5/decode.py | 6 +++--- .../ASR/pruned_transducer_stateless5/streaming_decode.py | 6 +++--- .../ASR/pruned_transducer_stateless5/decode.py | 6 +++--- .../ASR/pruned_transducer_stateless7/decode.py | 6 +++--- 60 files changed, 185 insertions(+), 185 deletions(-) diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py index d0f118959..090f7ff84 100755 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py @@ -392,7 +392,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -401,7 +401,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -413,7 +413,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/decode.py b/egs/aishell/ASR/pruned_transducer_stateless2/decode.py index 20a4f21c7..04888fbc1 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/aishell/ASR/pruned_transducer_stateless2/decode.py @@ -389,7 +389,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -398,7 +398,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) # we compute CER for aishell dataset. results_char = [] @@ -414,7 +414,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/decode.py b/egs/aishell/ASR/pruned_transducer_stateless3/decode.py index bac829ae1..6e97f338f 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/decode.py @@ -407,7 +407,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -416,7 +416,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) # we compute CER for aishell dataset. results_char = [] @@ -432,7 +432,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tCER", file=f) diff --git a/egs/aishell/ASR/transducer_stateless/decode.py b/egs/aishell/ASR/transducer_stateless/decode.py index e019d2329..d57fe6de4 100755 --- a/egs/aishell/ASR/transducer_stateless/decode.py +++ b/egs/aishell/ASR/transducer_stateless/decode.py @@ -326,7 +326,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -334,7 +334,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) # we compute CER for aishell dataset. results_char = [] @@ -350,7 +350,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tCER", file=f) diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/decode.py b/egs/aishell/ASR/transducer_stateless_modified-2/decode.py index 41cc1c01c..743fc7f45 100755 --- a/egs/aishell/ASR/transducer_stateless_modified-2/decode.py +++ b/egs/aishell/ASR/transducer_stateless_modified-2/decode.py @@ -371,7 +371,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -380,7 +380,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) # we compute CER for aishell dataset. results_char = [] @@ -396,7 +396,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tCER", file=f) diff --git a/egs/aishell/ASR/transducer_stateless_modified/decode.py b/egs/aishell/ASR/transducer_stateless_modified/decode.py index 7c06e6e51..9a1645915 100755 --- a/egs/aishell/ASR/transducer_stateless_modified/decode.py +++ b/egs/aishell/ASR/transducer_stateless_modified/decode.py @@ -375,7 +375,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -384,7 +384,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) # we compute CER for aishell dataset. results_char = [] @@ -400,7 +400,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tCER", file=f) diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py b/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py index b5da0959b..80194ad12 100755 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py @@ -544,7 +544,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -553,7 +553,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -565,7 +565,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py b/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py index 37d766ec8..eb202f8a8 100755 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py @@ -407,7 +407,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -416,7 +416,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -428,7 +428,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py index e4a90ef71..675f0739f 100755 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py @@ -392,7 +392,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -401,7 +401,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -413,7 +413,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/decode.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/decode.py index 53381c1f4..9a7eef9bf 100755 --- a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/decode.py +++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/decode.py @@ -463,7 +463,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -472,7 +472,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -484,7 +484,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/ami/ASR/pruned_transducer_stateless7/decode.py b/egs/ami/ASR/pruned_transducer_stateless7/decode.py index f47228fbe..fc4005325 100755 --- a/egs/ami/ASR/pruned_transducer_stateless7/decode.py +++ b/egs/ami/ASR/pruned_transducer_stateless7/decode.py @@ -479,7 +479,7 @@ def save_results( test_set_cers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") @@ -487,7 +487,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. wers_filename = ( - params.res_dir / f"wers-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wers-{test_set_name}-{params.suffix}.txt" ) with open(wers_filename, "w") as f: wer = write_error_stats( @@ -500,7 +500,7 @@ def save_results( for res in results: results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) cers_filename = ( - params.res_dir / f"cers-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"cers-{test_set_name}-{params.suffix}.txt" ) with open(cers_filename, "w") as f: cer = write_error_stats( @@ -513,7 +513,7 @@ def save_results( test_set_wers = {k: v for k, v in sorted(test_set_wers.items(), key=lambda x: x[1])} test_set_cers = {k: v for k, v in sorted(test_set_cers.items(), key=lambda x: x[1])} errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER\tCER", file=f) diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/decode.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/decode.py index 19d3c79c8..c5892f511 100755 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/decode.py +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/decode.py @@ -600,7 +600,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -610,7 +610,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -622,7 +622,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py index 8595c27bd..27ce41c87 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py @@ -400,7 +400,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = post_processing(results) results = sorted(results) @@ -410,7 +410,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -422,7 +422,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/librispeech/ASR/conformer_ctc3/decode.py b/egs/librispeech/ASR/conformer_ctc3/decode.py index 6fbf9d674..cdee1ec9c 100755 --- a/egs/librispeech/ASR/conformer_ctc3/decode.py +++ b/egs/librispeech/ASR/conformer_ctc3/decode.py @@ -729,7 +729,7 @@ def save_results( test_set_delays = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts_and_timestamps(filename=recog_path, texts=results) @@ -738,7 +738,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer, mean_delay, var_delay = write_error_stats_with_timestamps( @@ -755,7 +755,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -766,7 +766,7 @@ def save_results( test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0][0]) delays_info = ( params.res_dir - / f"symbol-delay-summary-{test_set_name}-{key}-{params.suffix}.txt" + / f"symbol-delay-summary-{test_set_name}-{params.suffix}.txt" ) with open(delays_info, "w") as f: print("settings\t(start, end) symbol-delay (s) (start, end)", file=f) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py index 365e8b8a7..5d241ccbf 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py @@ -433,7 +433,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -442,7 +442,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -454,7 +454,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py index c93125c80..e6c9d2ca2 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py @@ -751,7 +751,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) store_transcripts(filename=recog_path, texts=sorted(results)) logging.info(f"The transcripts are stored in {recog_path}") @@ -759,7 +759,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -771,7 +771,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py index 78e1f4096..f9c1633d8 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py @@ -433,7 +433,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -442,7 +442,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -454,7 +454,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py index b2cb2c96b..6b3c1b563 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py @@ -751,7 +751,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) store_transcripts(filename=recog_path, texts=sorted(results)) logging.info(f"The transcripts are stored in {recog_path}") @@ -759,7 +759,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -771,7 +771,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless/decode.py index 3ad08f56a..6dc11bdb2 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/decode.py @@ -567,7 +567,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -576,7 +576,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -588,7 +588,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py index 961d8ddfb..d510d9659 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py @@ -743,7 +743,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) store_transcripts(filename=recog_path, texts=sorted(results)) logging.info(f"The transcripts are stored in {recog_path}") @@ -751,7 +751,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -763,7 +763,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py index 78be9c01f..15e1109f2 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py @@ -703,7 +703,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -712,7 +712,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -724,7 +724,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py index a380bc470..7ac9d5f34 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py @@ -612,7 +612,7 @@ def save_results( test_set_delays = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts_and_timestamps(filename=recog_path, texts=results) @@ -621,7 +621,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer, mean_delay, var_delay = write_error_stats_with_timestamps( @@ -634,7 +634,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -644,7 +644,7 @@ def save_results( test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0]) delays_info = ( params.res_dir - / f"symbol-delay-summary-{test_set_name}-{key}-{params.suffix}.txt" + / f"symbol-delay-summary-{test_set_name}-{params.suffix}.txt" ) with open(delays_info, "w") as f: print("settings\tsymbol-delay", file=f) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py b/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py index 109746ed5..b8b6e4f43 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py @@ -743,7 +743,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) store_transcripts(filename=recog_path, texts=sorted(results)) logging.info(f"The transcripts are stored in {recog_path}") @@ -751,7 +751,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -763,7 +763,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/librispeech/ASR/pruned2_knowledge/decode.py b/egs/librispeech/ASR/pruned2_knowledge/decode.py index 40d14bb5a..f22731469 100755 --- a/egs/librispeech/ASR/pruned2_knowledge/decode.py +++ b/egs/librispeech/ASR/pruned2_knowledge/decode.py @@ -387,7 +387,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") @@ -395,7 +395,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -407,7 +407,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py index 0e3b7ff74..ea7692f49 100755 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py @@ -421,7 +421,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -430,7 +430,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -442,7 +442,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py index 0444afe40..8a719ae3b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py @@ -586,7 +586,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -595,7 +595,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -607,7 +607,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py index fbc39fb65..28c40c780 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py @@ -424,7 +424,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) # sort results so we can easily compare the difference between two # recognition results @@ -435,7 +435,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -447,7 +447,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py index 5f135f219..2791a60de 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py @@ -610,7 +610,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -619,7 +619,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -631,7 +631,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py index bb08246d9..eac8f8393 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py @@ -426,7 +426,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) # sort results so we can easily compare the difference between two # recognition results @@ -437,7 +437,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -449,7 +449,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py index 109a94a69..298c6c950 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py @@ -870,7 +870,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -879,7 +879,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -891,7 +891,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py index 0e5111f33..421bfb0b7 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py @@ -427,7 +427,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -436,7 +436,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -448,7 +448,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py index c44db0206..dca2ec081 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py @@ -656,7 +656,7 @@ def save_results( test_set_delays = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts_and_timestamps(filename=recog_path, texts=results) @@ -665,7 +665,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer, mean_delay, var_delay = write_error_stats_with_timestamps( @@ -678,7 +678,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -688,7 +688,7 @@ def save_results( test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0]) delays_info = ( params.res_dir - / f"symbol-delay-summary-{test_set_name}-{key}-{params.suffix}.txt" + / f"symbol-delay-summary-{test_set_name}-{params.suffix}.txt" ) with open(delays_info, "w") as f: print("settings\tsymbol-delay", file=f) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py index c4e3cef16..cb5d52859 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py @@ -443,7 +443,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -452,7 +452,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -464,7 +464,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py index 90b0fcf4b..5c5d3ecd9 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py @@ -736,7 +736,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -745,7 +745,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -757,7 +757,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py index 064811f1c..ae221eaba 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py @@ -443,7 +443,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -452,7 +452,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -464,7 +464,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py index fd9de052a..c81186295 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py @@ -417,7 +417,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -426,7 +426,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -438,7 +438,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py index b9bce465f..856ef845a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py @@ -723,7 +723,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -732,7 +732,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -744,7 +744,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py index 4b373e4c7..6c11d95b4 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py @@ -542,7 +542,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -551,7 +551,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats(f, f"{test_set_name}-{key}", results) @@ -561,7 +561,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decode.py index 32a9b6bb2..643486a6a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decode.py @@ -594,7 +594,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -603,7 +603,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -615,7 +615,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py index f137485b2..aadf75c5f 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py @@ -533,7 +533,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -542,7 +542,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats(f, f"{test_set_name}-{key}", results) @@ -552,7 +552,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decode.py index ce45a4beb..77160a9d4 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decode.py @@ -594,7 +594,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -603,7 +603,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -615,7 +615,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py index aebe2b94b..ed499d043 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py @@ -569,7 +569,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -578,7 +578,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -590,7 +590,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py index 7a349ecb2..9191edaab 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py @@ -410,7 +410,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -419,7 +419,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -431,7 +431,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py index e61367134..8314d6acf 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py @@ -595,7 +595,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -604,7 +604,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -616,7 +616,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/librispeech/ASR/transducer/decode.py b/egs/librispeech/ASR/transducer/decode.py index 804713a20..c0413e2d1 100755 --- a/egs/librispeech/ASR/transducer/decode.py +++ b/egs/librispeech/ASR/transducer/decode.py @@ -326,7 +326,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -335,7 +335,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -347,7 +347,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/librispeech/ASR/transducer_lstm/decode.py b/egs/librispeech/ASR/transducer_lstm/decode.py index 9511ca6d7..cd6d722bd 100755 --- a/egs/librispeech/ASR/transducer_lstm/decode.py +++ b/egs/librispeech/ASR/transducer_lstm/decode.py @@ -323,7 +323,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -332,7 +332,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -344,7 +344,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/librispeech/ASR/transducer_stateless/decode.py b/egs/librispeech/ASR/transducer_stateless/decode.py index 643238f1b..a72d60b9f 100755 --- a/egs/librispeech/ASR/transducer_stateless/decode.py +++ b/egs/librispeech/ASR/transducer_stateless/decode.py @@ -380,7 +380,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -389,7 +389,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -401,7 +401,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/librispeech/ASR/transducer_stateless2/decode.py b/egs/librispeech/ASR/transducer_stateless2/decode.py index 9a6363629..c91a1f490 100755 --- a/egs/librispeech/ASR/transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/transducer_stateless2/decode.py @@ -380,7 +380,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -389,7 +389,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -401,7 +401,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py index 56ad558c6..5c20e2bfd 100755 --- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py +++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py @@ -381,7 +381,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -390,7 +390,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -402,7 +402,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/librispeech/ASR/zipformer_mmi/decode.py b/egs/librispeech/ASR/zipformer_mmi/decode.py index 7d0ea78bb..a96c5c6f0 100755 --- a/egs/librispeech/ASR/zipformer_mmi/decode.py +++ b/egs/librispeech/ASR/zipformer_mmi/decode.py @@ -472,7 +472,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -481,7 +481,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats(f, f"{test_set_name}-{key}", results) @@ -491,7 +491,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/decode.py b/egs/mgb2/ASR/pruned_transducer_stateless5/decode.py index 1463f8f67..f72d4d7f6 100755 --- a/egs/mgb2/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/mgb2/ASR/pruned_transducer_stateless5/decode.py @@ -411,7 +411,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") @@ -419,7 +419,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -431,7 +431,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py index 219c96d60..cb9417d2a 100755 --- a/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py @@ -392,7 +392,7 @@ def save_results( test_set_cers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -401,7 +401,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. wers_filename = ( - params.res_dir / f"wers-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wers-{test_set_name}-{params.suffix}.txt" ) with open(wers_filename, "w") as f: wer = write_error_stats( @@ -414,7 +414,7 @@ def save_results( for res in results: results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) cers_filename = ( - params.res_dir / f"cers-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"cers-{test_set_name}-{params.suffix}.txt" ) with open(cers_filename, "w") as f: cer = write_error_stats( @@ -427,7 +427,7 @@ def save_results( test_set_wers = {k: v for k, v in sorted(test_set_wers.items(), key=lambda x: x[1])} test_set_cers = {k: v for k, v in sorted(test_set_cers.items(), key=lambda x: x[1])} errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER\tCER", file=f) diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py index bf91fef7e..1d6a22973 100755 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py @@ -510,7 +510,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -519,7 +519,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -531,7 +531,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py b/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py index 38f2ae83c..0d1fe9aa1 100755 --- a/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py +++ b/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py @@ -380,7 +380,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -389,7 +389,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -401,7 +401,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/tedlium3/ASR/transducer_stateless/decode.py b/egs/tedlium3/ASR/transducer_stateless/decode.py index 01f08ce59..c88760854 100755 --- a/egs/tedlium3/ASR/transducer_stateless/decode.py +++ b/egs/tedlium3/ASR/transducer_stateless/decode.py @@ -355,7 +355,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -364,7 +364,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -376,7 +376,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py index 04602ea2e..a0bf77b39 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py @@ -517,7 +517,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -526,7 +526,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -538,7 +538,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py index 7bd1177bd..9f6043926 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py @@ -490,7 +490,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -499,7 +499,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -511,7 +511,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py index c7863415b..398690d48 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py @@ -467,7 +467,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) # sort results so we can easily compare the difference between two # recognition results @@ -478,7 +478,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -490,7 +490,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode.py index 6a67e26f8..5b7f5f95b 100755 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode.py @@ -702,7 +702,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -711,7 +711,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -723,7 +723,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decode.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decode.py index ace792e13..a291bb303 100755 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decode.py +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decode.py @@ -594,7 +594,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -603,7 +603,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -615,7 +615,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) From f5de2e90c6672a843d5e94166fbd60f339cb6b9b Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 8 Mar 2023 22:56:04 +0800 Subject: [PATCH 127/174] Fix style issues. (#937) --- .../ASR/pruned_transducer_stateless2/decode.py | 12 +++--------- .../ASR/pruned_transducer_stateless2/decode.py | 12 +++--------- .../ASR/pruned_transducer_stateless3/decode.py | 12 +++--------- egs/aishell/ASR/transducer_stateless/decode.py | 12 +++--------- .../transducer_stateless_modified-2/decode.py | 12 +++--------- .../ASR/transducer_stateless_modified/decode.py | 12 +++--------- .../ASR/pruned_transducer_stateless5/decode.py | 12 +++--------- .../ASR/pruned_transducer_stateless5/decode.py | 12 +++--------- .../ASR/pruned_transducer_stateless2/decode.py | 12 +++--------- .../pruned_transducer_stateless7/decode.py | 12 +++--------- .../ASR/pruned_transducer_stateless7/decode.py | 16 ++++------------ .../decode.py | 12 +++--------- .../ASR/pruned_transducer_stateless2/decode.py | 12 +++--------- egs/librispeech/ASR/conformer_ctc3/decode.py | 15 ++++----------- .../conv_emformer_transducer_stateless/decode.py | 12 +++--------- .../streaming_decode.py | 12 +++--------- .../decode.py | 12 +++--------- .../streaming_decode.py | 12 +++--------- .../ASR/lstm_transducer_stateless/decode.py | 12 +++--------- .../streaming_decode.py | 12 +++--------- .../ASR/lstm_transducer_stateless2/decode.py | 12 +++--------- .../ASR/lstm_transducer_stateless3/decode.py | 15 ++++----------- .../streaming_decode.py | 12 +++--------- egs/librispeech/ASR/pruned2_knowledge/decode.py | 12 +++--------- .../pruned_stateless_emformer_rnnt2/decode.py | 12 +++--------- .../ASR/pruned_transducer_stateless/decode.py | 12 +++--------- .../streaming_decode.py | 12 +++--------- .../ASR/pruned_transducer_stateless2/decode.py | 12 +++--------- .../streaming_decode.py | 12 +++--------- .../ASR/pruned_transducer_stateless3/decode.py | 12 +++--------- .../streaming_decode.py | 12 +++--------- .../ASR/pruned_transducer_stateless4/decode.py | 15 ++++----------- .../streaming_decode.py | 12 +++--------- .../ASR/pruned_transducer_stateless5/decode.py | 12 +++--------- .../streaming_decode.py | 12 +++--------- .../ASR/pruned_transducer_stateless6/decode.py | 12 +++--------- .../ASR/pruned_transducer_stateless7/decode.py | 12 +++--------- .../ctc_decode.py | 12 +++--------- .../pruned_transducer_stateless7_ctc/decode.py | 12 +++--------- .../ctc_decode.py | 12 +++--------- .../decode.py | 12 +++--------- .../decode.py | 12 +++--------- .../streaming_decode.py | 12 +++--------- .../ASR/pruned_transducer_stateless8/decode.py | 12 +++--------- egs/librispeech/ASR/transducer/decode.py | 12 +++--------- egs/librispeech/ASR/transducer_lstm/decode.py | 12 +++--------- .../ASR/transducer_stateless/decode.py | 12 +++--------- .../ASR/transducer_stateless2/decode.py | 12 +++--------- .../decode.py | 12 +++--------- egs/librispeech/ASR/zipformer_mmi/decode.py | 12 +++--------- .../ASR/pruned_transducer_stateless5/decode.py | 12 +++--------- .../ASR/pruned_transducer_stateless2/decode.py | 16 ++++------------ .../ASR/pruned_transducer_stateless5/decode.py | 12 +++--------- .../ASR/pruned_transducer_stateless/decode.py | 12 +++--------- egs/tedlium3/ASR/transducer_stateless/decode.py | 12 +++--------- .../ASR/pruned_transducer_stateless2/decode.py | 12 +++--------- .../ASR/pruned_transducer_stateless5/decode.py | 12 +++--------- .../streaming_decode.py | 12 +++--------- .../ASR/pruned_transducer_stateless5/decode.py | 12 +++--------- .../ASR/pruned_transducer_stateless7/decode.py | 12 +++--------- 60 files changed, 185 insertions(+), 552 deletions(-) diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py index 090f7ff84..2512f233f 100755 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py @@ -391,18 +391,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -412,9 +408,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/decode.py b/egs/aishell/ASR/pruned_transducer_stateless2/decode.py index 04888fbc1..fb6c7c481 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/aishell/ASR/pruned_transducer_stateless2/decode.py @@ -388,18 +388,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" # we compute CER for aishell dataset. results_char = [] for res in results: @@ -413,9 +409,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/decode.py b/egs/aishell/ASR/pruned_transducer_stateless3/decode.py index 6e97f338f..954d9dc7e 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/decode.py @@ -406,18 +406,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" # we compute CER for aishell dataset. results_char = [] for res in results: @@ -431,9 +427,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tCER", file=f) for key, val in test_set_wers: diff --git a/egs/aishell/ASR/transducer_stateless/decode.py b/egs/aishell/ASR/transducer_stateless/decode.py index d57fe6de4..d23f4f883 100755 --- a/egs/aishell/ASR/transducer_stateless/decode.py +++ b/egs/aishell/ASR/transducer_stateless/decode.py @@ -325,17 +325,13 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" # we compute CER for aishell dataset. results_char = [] for res in results: @@ -349,9 +345,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tCER", file=f) for key, val in test_set_wers: diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/decode.py b/egs/aishell/ASR/transducer_stateless_modified-2/decode.py index 743fc7f45..d164b6890 100755 --- a/egs/aishell/ASR/transducer_stateless_modified-2/decode.py +++ b/egs/aishell/ASR/transducer_stateless_modified-2/decode.py @@ -370,18 +370,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" # we compute CER for aishell dataset. results_char = [] for res in results: @@ -395,9 +391,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tCER", file=f) for key, val in test_set_wers: diff --git a/egs/aishell/ASR/transducer_stateless_modified/decode.py b/egs/aishell/ASR/transducer_stateless_modified/decode.py index 9a1645915..0a7d87fe8 100755 --- a/egs/aishell/ASR/transducer_stateless_modified/decode.py +++ b/egs/aishell/ASR/transducer_stateless_modified/decode.py @@ -374,18 +374,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" # we compute CER for aishell dataset. results_char = [] for res in results: @@ -399,9 +395,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tCER", file=f) for key, val in test_set_wers: diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py b/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py index 80194ad12..9e44b4e34 100755 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py @@ -543,18 +543,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -564,9 +560,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py b/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py index eb202f8a8..068e2749a 100755 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py @@ -406,18 +406,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -427,9 +423,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py index 675f0739f..6c170c392 100755 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py @@ -391,18 +391,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -412,9 +408,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/decode.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/decode.py index 9a7eef9bf..2741e0eeb 100755 --- a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/decode.py +++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/decode.py @@ -462,18 +462,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -483,9 +479,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/ami/ASR/pruned_transducer_stateless7/decode.py b/egs/ami/ASR/pruned_transducer_stateless7/decode.py index fc4005325..9999894d1 100755 --- a/egs/ami/ASR/pruned_transducer_stateless7/decode.py +++ b/egs/ami/ASR/pruned_transducer_stateless7/decode.py @@ -478,17 +478,13 @@ def save_results( test_set_wers = dict() test_set_cers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - wers_filename = ( - params.res_dir / f"wers-{test_set_name}-{params.suffix}.txt" - ) + wers_filename = params.res_dir / f"wers-{test_set_name}-{params.suffix}.txt" with open(wers_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -499,9 +495,7 @@ def save_results( results_char = [] for res in results: results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) - cers_filename = ( - params.res_dir / f"cers-{test_set_name}-{params.suffix}.txt" - ) + cers_filename = params.res_dir / f"cers-{test_set_name}-{params.suffix}.txt" with open(cers_filename, "w") as f: cer = write_error_stats( f, f"{test_set_name}-{key}", results_char, enable_log=True @@ -512,9 +506,7 @@ def save_results( test_set_wers = {k: v for k, v in sorted(test_set_wers.items(), key=lambda x: x[1])} test_set_cers = {k: v for k, v in sorted(test_set_cers.items(), key=lambda x: x[1])} - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER\tCER", file=f) for key in test_set_wers: diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/decode.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/decode.py index c5892f511..f5a1d750d 100755 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/decode.py +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/decode.py @@ -599,9 +599,7 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -609,9 +607,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -621,9 +617,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py index 27ce41c87..ee694a9e0 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py @@ -399,9 +399,7 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = post_processing(results) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -409,9 +407,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -421,9 +417,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/conformer_ctc3/decode.py b/egs/librispeech/ASR/conformer_ctc3/decode.py index cdee1ec9c..e6327bb5e 100755 --- a/egs/librispeech/ASR/conformer_ctc3/decode.py +++ b/egs/librispeech/ASR/conformer_ctc3/decode.py @@ -728,18 +728,14 @@ def save_results( test_set_wers = dict() test_set_delays = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts_and_timestamps(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer, mean_delay, var_delay = write_error_stats_with_timestamps( f, @@ -754,9 +750,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: @@ -765,8 +759,7 @@ def save_results( # sort according to the mean start symbol delay test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0][0]) delays_info = ( - params.res_dir - / f"symbol-delay-summary-{test_set_name}-{params.suffix}.txt" + params.res_dir / f"symbol-delay-summary-{test_set_name}-{params.suffix}.txt" ) with open(delays_info, "w") as f: print("settings\t(start, end) symbol-delay (s) (start, end)", file=f) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py index 5d241ccbf..7be3299f3 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py @@ -432,18 +432,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -453,9 +449,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py index e6c9d2ca2..e5a7c7116 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py @@ -750,17 +750,13 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" store_transcripts(filename=recog_path, texts=sorted(results)) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -770,9 +766,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py index f9c1633d8..d022d463e 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py @@ -432,18 +432,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -453,9 +449,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py index 6b3c1b563..f5d894a7b 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py @@ -750,17 +750,13 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" store_transcripts(filename=recog_path, texts=sorted(results)) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -770,9 +766,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless/decode.py index 6dc11bdb2..856c9d945 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/decode.py @@ -566,18 +566,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -587,9 +583,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py index d510d9659..f989d9bc0 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py @@ -742,17 +742,13 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" store_transcripts(filename=recog_path, texts=sorted(results)) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -762,9 +758,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py index 15e1109f2..6c58a57e1 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py @@ -702,18 +702,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -723,9 +719,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py index 7ac9d5f34..a2b4f9e1a 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py @@ -611,18 +611,14 @@ def save_results( test_set_wers = dict() test_set_delays = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts_and_timestamps(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer, mean_delay, var_delay = write_error_stats_with_timestamps( f, f"{test_set_name}-{key}", results, enable_log=True @@ -633,9 +629,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: @@ -643,8 +637,7 @@ def save_results( test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0]) delays_info = ( - params.res_dir - / f"symbol-delay-summary-{test_set_name}-{params.suffix}.txt" + params.res_dir / f"symbol-delay-summary-{test_set_name}-{params.suffix}.txt" ) with open(delays_info, "w") as f: print("settings\tsymbol-delay", file=f) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py b/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py index b8b6e4f43..c737e3611 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py @@ -742,17 +742,13 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" store_transcripts(filename=recog_path, texts=sorted(results)) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -762,9 +758,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/pruned2_knowledge/decode.py b/egs/librispeech/ASR/pruned2_knowledge/decode.py index f22731469..82fd103ea 100755 --- a/egs/librispeech/ASR/pruned2_knowledge/decode.py +++ b/egs/librispeech/ASR/pruned2_knowledge/decode.py @@ -386,17 +386,13 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -406,9 +402,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py index ea7692f49..072d49d9c 100755 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py @@ -420,18 +420,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -441,9 +437,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py index 8a719ae3b..6dfe11cee 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py @@ -585,18 +585,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -606,9 +602,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py index 28c40c780..f4b01fd06 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py @@ -423,9 +423,7 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" # sort results so we can easily compare the difference between two # recognition results results = sorted(results) @@ -434,9 +432,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -446,9 +442,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py index 2791a60de..172c9ab7c 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py @@ -609,18 +609,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -630,9 +626,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py index eac8f8393..9c4a13606 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py @@ -425,9 +425,7 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" # sort results so we can easily compare the difference between two # recognition results results = sorted(results) @@ -436,9 +434,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -448,9 +444,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py index 298c6c950..aa055049e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py @@ -869,18 +869,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -890,9 +886,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py index 421bfb0b7..3a1ecb7ed 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py @@ -426,18 +426,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -447,9 +443,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py index dca2ec081..5ec3d3b45 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py @@ -655,18 +655,14 @@ def save_results( test_set_wers = dict() test_set_delays = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts_and_timestamps(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer, mean_delay, var_delay = write_error_stats_with_timestamps( f, f"{test_set_name}-{key}", results, enable_log=True @@ -677,9 +673,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: @@ -687,8 +681,7 @@ def save_results( test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0]) delays_info = ( - params.res_dir - / f"symbol-delay-summary-{test_set_name}-{params.suffix}.txt" + params.res_dir / f"symbol-delay-summary-{test_set_name}-{params.suffix}.txt" ) with open(delays_info, "w") as f: print("settings\tsymbol-delay", file=f) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py index cb5d52859..ca3a023ce 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py @@ -442,18 +442,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -463,9 +459,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py index 5c5d3ecd9..2be895feb 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py @@ -735,18 +735,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -756,9 +752,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py index ae221eaba..5b15dcee7 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py @@ -442,18 +442,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -463,9 +459,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py index c81186295..95534efef 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py @@ -416,18 +416,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -437,9 +433,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py index 856ef845a..32b3134b9 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py @@ -722,18 +722,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -743,9 +739,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py index 6c11d95b4..629bec058 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py @@ -541,18 +541,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats(f, f"{test_set_name}-{key}", results) test_set_wers[key] = wer @@ -560,9 +556,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decode.py index 643486a6a..7641fa5af 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decode.py @@ -593,18 +593,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -614,9 +610,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py index aadf75c5f..fa7144f0f 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py @@ -532,18 +532,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats(f, f"{test_set_name}-{key}", results) test_set_wers[key] = wer @@ -551,9 +547,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decode.py index 77160a9d4..e497787d3 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decode.py @@ -593,18 +593,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -614,9 +610,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py index ed499d043..e7616fbc5 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py @@ -568,18 +568,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -589,9 +585,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py index 9191edaab..c272ed641 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py @@ -409,18 +409,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -430,9 +426,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py index 8314d6acf..7b651a632 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py @@ -594,18 +594,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -615,9 +611,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/transducer/decode.py b/egs/librispeech/ASR/transducer/decode.py index c0413e2d1..8d379d1fa 100755 --- a/egs/librispeech/ASR/transducer/decode.py +++ b/egs/librispeech/ASR/transducer/decode.py @@ -325,18 +325,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -346,9 +342,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/transducer_lstm/decode.py b/egs/librispeech/ASR/transducer_lstm/decode.py index cd6d722bd..806b68f40 100755 --- a/egs/librispeech/ASR/transducer_lstm/decode.py +++ b/egs/librispeech/ASR/transducer_lstm/decode.py @@ -322,18 +322,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -343,9 +339,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/transducer_stateless/decode.py b/egs/librispeech/ASR/transducer_stateless/decode.py index a72d60b9f..42125e19f 100755 --- a/egs/librispeech/ASR/transducer_stateless/decode.py +++ b/egs/librispeech/ASR/transducer_stateless/decode.py @@ -379,18 +379,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -400,9 +396,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/transducer_stateless2/decode.py b/egs/librispeech/ASR/transducer_stateless2/decode.py index c91a1f490..b05fe2a4d 100755 --- a/egs/librispeech/ASR/transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/transducer_stateless2/decode.py @@ -379,18 +379,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -400,9 +396,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py index 5c20e2bfd..5570b30ae 100755 --- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py +++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py @@ -380,18 +380,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -401,9 +397,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/zipformer_mmi/decode.py b/egs/librispeech/ASR/zipformer_mmi/decode.py index a96c5c6f0..33c0bf199 100755 --- a/egs/librispeech/ASR/zipformer_mmi/decode.py +++ b/egs/librispeech/ASR/zipformer_mmi/decode.py @@ -471,18 +471,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats(f, f"{test_set_name}-{key}", results) test_set_wers[key] = wer @@ -490,9 +486,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/decode.py b/egs/mgb2/ASR/pruned_transducer_stateless5/decode.py index f72d4d7f6..72338bade 100755 --- a/egs/mgb2/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/mgb2/ASR/pruned_transducer_stateless5/decode.py @@ -410,17 +410,13 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -430,9 +426,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py index cb9417d2a..4434aae62 100755 --- a/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py @@ -391,18 +391,14 @@ def save_results( test_set_wers = dict() test_set_cers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - wers_filename = ( - params.res_dir / f"wers-{test_set_name}-{params.suffix}.txt" - ) + wers_filename = params.res_dir / f"wers-{test_set_name}-{params.suffix}.txt" with open(wers_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -413,9 +409,7 @@ def save_results( results_char = [] for res in results: results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) - cers_filename = ( - params.res_dir / f"cers-{test_set_name}-{params.suffix}.txt" - ) + cers_filename = params.res_dir / f"cers-{test_set_name}-{params.suffix}.txt" with open(cers_filename, "w") as f: cer = write_error_stats( f, f"{test_set_name}-{key}", results_char, enable_log=True @@ -426,9 +420,7 @@ def save_results( test_set_wers = {k: v for k, v in sorted(test_set_wers.items(), key=lambda x: x[1])} test_set_cers = {k: v for k, v in sorted(test_set_cers.items(), key=lambda x: x[1])} - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER\tCER", file=f) for key in test_set_wers: diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py index 1d6a22973..3bfb832fb 100755 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py @@ -509,18 +509,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -530,9 +526,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py b/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py index 0d1fe9aa1..abba9d403 100755 --- a/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py +++ b/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py @@ -379,18 +379,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -400,9 +396,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/tedlium3/ASR/transducer_stateless/decode.py b/egs/tedlium3/ASR/transducer_stateless/decode.py index c88760854..fb0e3116b 100755 --- a/egs/tedlium3/ASR/transducer_stateless/decode.py +++ b/egs/tedlium3/ASR/transducer_stateless/decode.py @@ -354,18 +354,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -375,9 +371,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py index a0bf77b39..823b33ae5 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py @@ -516,18 +516,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -537,9 +533,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py index 9f6043926..32d5738b1 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py @@ -489,18 +489,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -510,9 +506,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py index 398690d48..3a4dc3cb8 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py @@ -466,9 +466,7 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" # sort results so we can easily compare the difference between two # recognition results results = sorted(results) @@ -477,9 +475,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -489,9 +485,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode.py index 5b7f5f95b..b77f734e3 100755 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode.py @@ -701,18 +701,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -722,9 +718,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decode.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decode.py index a291bb303..e334e690a 100755 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decode.py +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decode.py @@ -593,18 +593,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -614,9 +610,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: From 28af269e5e27cb8ab62f1bc82d1c5a2b7f659843 Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Thu, 9 Mar 2023 17:38:15 +0800 Subject: [PATCH 128/174] Fix for workflow (#934) --- ...k-librispeech-test-clean-and-test-other.sh | 2 +- ...pruned-transducer-stateless3-2022-06-20.sh | 1 - ...n-librispeech-conformer-ctc3-2022-11-28.sh | 1 - ...h-lstm-transducer-stateless2-2022-09-03.sh | 1 - ...-pruned-transducer-stateless-2022-03-12.sh | 1 - ...pruned-transducer-stateless2-2022-04-29.sh | 1 - ...pruned-transducer-stateless3-2022-04-29.sh | 1 - ...pruned-transducer-stateless3-2022-05-13.sh | 1 - ...pruned-transducer-stateless5-2022-05-13.sh | 1 - ...pruned-transducer-stateless7-2022-11-11.sh | 1 - ...ed-transducer-stateless7-ctc-2022-12-01.sh | 3 +- ...transducer-stateless7-ctc-bs-2022-12-15.sh | 3 +- ...nsducer-stateless7-streaming-2022-12-29.sh | 1 - ...pruned-transducer-stateless8-2022-11-14.sh | 1 - ...pruned-transducer-stateless2-2022-06-26.sh | 1 - ...speech-transducer-stateless2-2022-04-19.sh | 1 - ...un-librispeech-zipformer-mmi-2022-12-08.sh | 1 - .../scripts/run-pre-trained-conformer-ctc.sh | 1 - ...d-transducer-stateless-librispeech-100h.sh | 1 - ...d-transducer-stateless-librispeech-960h.sh | 1 - ...transducer-stateless-modified-2-aishell.sh | 1 - ...d-transducer-stateless-modified-aishell.sh | 1 - .../run-pre-trained-transducer-stateless.sh | 1 - .github/scripts/run-pre-trained-transducer.sh | 1 - ...enetspeech-pruned-transducer-stateless2.sh | 1 - .github/scripts/test-ncnn-export.sh | 67 ------------------- .github/workflows/run-aishell-2022-06-20.yml | 4 +- .../workflows/run-gigaspeech-2022-05-13.yml | 2 +- .../workflows/run-librispeech-2022-03-12.yml | 4 +- .../workflows/run-librispeech-2022-04-29.yml | 4 +- .../workflows/run-librispeech-2022-05-13.yml | 4 +- .../run-librispeech-2022-11-11-stateless7.yml | 4 +- .../run-librispeech-2022-11-14-stateless8.yml | 4 +- ...-librispeech-2022-12-01-stateless7-ctc.yml | 4 +- ...n-librispeech-2022-12-08-zipformer-mmi.yml | 4 +- ...brispeech-2022-12-15-stateless7-ctc-bs.yml | 6 +- ...speech-2022-12-29-stateless7-streaming.yml | 4 +- ...-librispeech-conformer-ctc3-2022-11-28.yml | 4 +- ...-lstm-transducer-stateless2-2022-09-03.yml | 4 +- ...runed-transducer-stateless3-2022-05-13.yml | 4 +- ...aming-transducer-stateless2-2022-06-26.yml | 4 +- ...peech-transducer-stateless2-2022-04-19.yml | 4 +- .../run-pretrained-conformer-ctc.yml | 4 +- ...-transducer-stateless-librispeech-100h.yml | 4 +- ...r-stateless-librispeech-multi-datasets.yml | 4 +- ...ransducer-stateless-modified-2-aishell.yml | 4 +- ...-transducer-stateless-modified-aishell.yml | 4 +- .../run-pretrained-transducer-stateless.yml | 4 +- .../workflows/run-pretrained-transducer.yml | 4 +- .github/workflows/run-ptb-rnn-lm.yml | 2 +- ...netspeech-pruned-transducer-stateless2.yml | 4 +- .github/workflows/run-yesno-recipe.yml | 2 +- .github/workflows/test-ncnn-export.yml | 2 +- .github/workflows/test-onnx-export.yml | 2 +- .github/workflows/test.yml | 4 +- .../ASR/local/compute_fbank_librispeech.py | 36 +++++++--- 56 files changed, 82 insertions(+), 159 deletions(-) diff --git a/.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh b/.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh index bb7c7dfdc..0bec8c0c4 100755 --- a/.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh +++ b/.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh @@ -15,5 +15,5 @@ mkdir -p data cd data [ ! -e fbank ] && ln -s ~/tmp/fbank-libri fbank cd .. -./local/compute_fbank_librispeech.py +./local/compute_fbank_librispeech.py --dataset 'test-clean test-other' ls -lh data/fbank/ diff --git a/.github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh b/.github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh index e70a1848d..4c393f6be 100755 --- a/.github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh +++ b/.github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh @@ -25,7 +25,6 @@ repo=$(basename $repo_url) log "Display test files" tree $repo/ -soxi $repo/test_wavs/*.wav ls -lh $repo/test_wavs/*.wav pushd $repo/exp diff --git a/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh b/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh index df29f188e..c68ccc954 100755 --- a/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh +++ b/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh @@ -18,7 +18,6 @@ repo=$(basename $repo_url) log "Display test files" tree $repo/ -soxi $repo/test_wavs/*.wav ls -lh $repo/test_wavs/*.wav pushd $repo/exp diff --git a/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh b/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh index 91cdea01a..4cd2c4bec 100755 --- a/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh +++ b/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh @@ -20,7 +20,6 @@ abs_repo=$(realpath $repo) log "Display test files" tree $repo/ -soxi $repo/test_wavs/*.wav ls -lh $repo/test_wavs/*.wav pushd $repo/exp diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh index dafea56db..6792c7088 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh @@ -19,7 +19,6 @@ repo=$(basename $repo_url) log "Display test files" tree $repo/ -soxi $repo/test_wavs/*.wav ls -lh $repo/test_wavs/*.wav for sym in 1 2 3; do diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh index c3d07dc0e..dbf678d72 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh @@ -23,7 +23,6 @@ popd log "Display test files" tree $repo/ -soxi $repo/test_wavs/*.wav ls -lh $repo/test_wavs/*.wav pushd $repo/exp diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh index 22de3b45d..b6d477afe 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh @@ -22,7 +22,6 @@ popd log "Display test files" tree $repo/ -soxi $repo/test_wavs/*.wav ls -lh $repo/test_wavs/*.wav pushd $repo/exp diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh index ceb77c7c3..efa4b53f0 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh @@ -19,7 +19,6 @@ repo=$(basename $repo_url) log "Display test files" tree $repo/ -soxi $repo/test_wavs/*.wav ls -lh $repo/test_wavs/*.wav pushd $repo/exp diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh index c6a781318..511fe0c9e 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh @@ -19,7 +19,6 @@ repo=$(basename $repo_url) log "Display test files" tree $repo/ -soxi $repo/test_wavs/*.wav ls -lh $repo/test_wavs/*.wav pushd $repo/exp diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh index 8e485d2e6..2bc179c86 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh @@ -19,7 +19,6 @@ repo=$(basename $repo_url) log "Display test files" tree $repo/ -soxi $repo/test_wavs/*.wav ls -lh $repo/test_wavs/*.wav pushd $repo/exp diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh index 3cbb480f6..192438353 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh @@ -18,7 +18,6 @@ repo=$(basename $repo_url) log "Display test files" tree $repo/ -soxi $repo/test_wavs/*.wav ls -lh $repo/test_wavs/*.wav pushd $repo/exp @@ -148,4 +147,4 @@ if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == done rm pruned_transducer_stateless7_ctc/exp/*.pt -fi \ No newline at end of file +fi diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2022-12-15.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2022-12-15.sh index ed66a728e..761eb72e2 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2022-12-15.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2022-12-15.sh @@ -10,7 +10,7 @@ log() { cd egs/librispeech/ASR -repo_url=https://huggingface.co/yfyeung/icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2022-12-14 +repo_url=https://huggingface.co/yfyeung/icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29 log "Downloading pre-trained model from $repo_url" GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url @@ -18,7 +18,6 @@ repo=$(basename $repo_url) log "Display test files" tree $repo/ -soxi $repo/test_wavs/*.wav ls -lh $repo/test_wavs/*.wav pushd $repo/exp diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh index 584f5d488..e1e4e1f10 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh @@ -19,7 +19,6 @@ repo=$(basename $repo_url) log "Display test files" tree $repo/ -soxi $repo/test_wavs/*.wav ls -lh $repo/test_wavs/*.wav pushd $repo diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh index e782b8425..5d9485692 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh @@ -19,7 +19,6 @@ repo=$(basename $repo_url) log "Display test files" tree $repo/ -soxi $repo/test_wavs/*.wav ls -lh $repo/test_wavs/*.wav pushd $repo/exp diff --git a/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh b/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh index af37102d5..77cd59506 100755 --- a/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh +++ b/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh @@ -19,7 +19,6 @@ repo=$(basename $repo_url) log "Display test files" tree $repo/ -soxi $repo/test_wavs/*.wav ls -lh $repo/test_wavs/*.wav pushd $repo/exp diff --git a/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh b/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh index 5b8ed396b..b4aca1b6b 100755 --- a/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh +++ b/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh @@ -19,7 +19,6 @@ repo=$(basename $repo_url) log "Display test files" tree $repo/ -soxi $repo/test_wavs/*.wav ls -lh $repo/test_wavs/*.wav for sym in 1 2 3; do diff --git a/.github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh b/.github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh index 77f28b054..a58b8ec56 100755 --- a/.github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh +++ b/.github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh @@ -18,7 +18,6 @@ repo=$(basename $repo_url) log "Display test files" tree $repo/ -soxi $repo/test_wavs/*.wav ls -lh $repo/test_wavs/*.wav pushd $repo/exp diff --git a/.github/scripts/run-pre-trained-conformer-ctc.sh b/.github/scripts/run-pre-trained-conformer-ctc.sh index 96c320616..125d1f3b1 100755 --- a/.github/scripts/run-pre-trained-conformer-ctc.sh +++ b/.github/scripts/run-pre-trained-conformer-ctc.sh @@ -19,7 +19,6 @@ repo=$(basename $repo_url) log "Display test files" tree $repo/ -soxi $repo/test_wavs/*.flac ls -lh $repo/test_wavs/*.flac log "CTC decoding" diff --git a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh index 209d4814f..89115e88d 100755 --- a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh +++ b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh @@ -19,7 +19,6 @@ repo=$(basename $repo_url) log "Display test files" tree $repo/ -soxi $repo/test_wavs/*.wav ls -lh $repo/test_wavs/*.wav for sym in 1 2 3; do diff --git a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh index 34ff76fe4..85e2c89e6 100755 --- a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh +++ b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh @@ -19,7 +19,6 @@ repo=$(basename $repo_url) log "Display test files" tree $repo/ -soxi $repo/test_wavs/*.wav ls -lh $repo/test_wavs/*.wav for sym in 1 2 3; do diff --git a/.github/scripts/run-pre-trained-transducer-stateless-modified-2-aishell.sh b/.github/scripts/run-pre-trained-transducer-stateless-modified-2-aishell.sh index 75650c2d3..0644d9be0 100755 --- a/.github/scripts/run-pre-trained-transducer-stateless-modified-2-aishell.sh +++ b/.github/scripts/run-pre-trained-transducer-stateless-modified-2-aishell.sh @@ -19,7 +19,6 @@ repo=$(basename $repo_url) log "Display test files" tree $repo/ -soxi $repo/test_wavs/*.wav ls -lh $repo/test_wavs/*.wav for sym in 1 2 3; do diff --git a/.github/scripts/run-pre-trained-transducer-stateless-modified-aishell.sh b/.github/scripts/run-pre-trained-transducer-stateless-modified-aishell.sh index bcc2d74cb..79fb64311 100755 --- a/.github/scripts/run-pre-trained-transducer-stateless-modified-aishell.sh +++ b/.github/scripts/run-pre-trained-transducer-stateless-modified-aishell.sh @@ -19,7 +19,6 @@ repo=$(basename $repo_url) log "Display test files" tree $repo/ -soxi $repo/test_wavs/*.wav ls -lh $repo/test_wavs/*.wav for sym in 1 2 3; do diff --git a/.github/scripts/run-pre-trained-transducer-stateless.sh b/.github/scripts/run-pre-trained-transducer-stateless.sh index d3e40315a..41456f11b 100755 --- a/.github/scripts/run-pre-trained-transducer-stateless.sh +++ b/.github/scripts/run-pre-trained-transducer-stateless.sh @@ -19,7 +19,6 @@ repo=$(basename $repo_url) log "Display test files" tree $repo/ -soxi $repo/test_wavs/*.wav ls -lh $repo/test_wavs/*.wav for sym in 1 2 3; do diff --git a/.github/scripts/run-pre-trained-transducer.sh b/.github/scripts/run-pre-trained-transducer.sh index cfa006776..1331c966c 100755 --- a/.github/scripts/run-pre-trained-transducer.sh +++ b/.github/scripts/run-pre-trained-transducer.sh @@ -19,7 +19,6 @@ repo=$(basename $repo_url) log "Display test files" tree $repo/ -soxi $repo/test_wavs/*.wav ls -lh $repo/test_wavs/*.wav log "Beam search decoding" diff --git a/.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh b/.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh index 2d237dcf2..90097c752 100755 --- a/.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh +++ b/.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh @@ -20,7 +20,6 @@ repo=$(basename $repo_url) log "Display test files" tree $repo/ -soxi $repo/test_wavs/*.wav ls -lh $repo/test_wavs/*.wav pushd $repo/exp diff --git a/.github/scripts/test-ncnn-export.sh b/.github/scripts/test-ncnn-export.sh index 9f5df2d58..52491d2ea 100755 --- a/.github/scripts/test-ncnn-export.sh +++ b/.github/scripts/test-ncnn-export.sh @@ -232,70 +232,3 @@ python3 ./pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py \ rm -rf $repo log "--------------------------------------------------------------------------" - -# Go back to the root directory of icefall repo -popd - -pushd egs/csj/ASR - -log "==========================================================================" -repo_url=https://huggingface.co/TeoWenShen/icefall-asr-csj-pruned-transducer-stateless7-streaming-230208 -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -pushd $repo -git lfs pull --include "exp_fluent/pretrained.pt" -git lfs pull --include "exp_disfluent/pretrained.pt" - -cd exp_fluent -ln -s pretrained.pt epoch-99.pt - -cd ../exp_disfluent -ln -s pretrained.pt epoch-99.pt - -cd ../test_wavs -git lfs pull --include "*.wav" -popd - -log "Export via torch.jit.trace()" - -for exp in exp_fluent exp_disfluent; do - ./pruned_transducer_stateless7_streaming/export-for-ncnn.py \ - --exp-dir $repo/$exp/ \ - --lang $repo/data/lang_char \ - --epoch 99 \ - --avg 1 \ - --use-averaged-model 0 \ - \ - --decode-chunk-len 32 \ - --num-left-chunks 4 \ - --num-encoder-layers "2,4,3,2,4" \ - --feedforward-dims "1024,1024,2048,2048,1024" \ - --nhead "8,8,8,8,8" \ - --encoder-dims "384,384,384,384,384" \ - --attention-dims "192,192,192,192,192" \ - --encoder-unmasked-dims "256,256,256,256,256" \ - --zipformer-downsampling-factors "1,2,4,8,2" \ - --cnn-module-kernels "31,31,31,31,31" \ - --decoder-dim 512 \ - --joiner-dim 512 - - pnnx $repo/$exp/encoder_jit_trace-pnnx.pt - pnnx $repo/$exp/decoder_jit_trace-pnnx.pt - pnnx $repo/$exp/joiner_jit_trace-pnnx.pt - - for wav in aps-smp.wav interview_aps-smp.wav reproduction-smp.wav sps-smp.wav; do - python3 ./pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py \ - --tokens $repo/data/lang_char/tokens.txt \ - --encoder-param-filename $repo/$exp/encoder_jit_trace-pnnx.ncnn.param \ - --encoder-bin-filename $repo/$exp/encoder_jit_trace-pnnx.ncnn.bin \ - --decoder-param-filename $repo/$exp/decoder_jit_trace-pnnx.ncnn.param \ - --decoder-bin-filename $repo/$exp/decoder_jit_trace-pnnx.ncnn.bin \ - --joiner-param-filename $repo/$exp/joiner_jit_trace-pnnx.ncnn.param \ - --joiner-bin-filename $repo/$exp/joiner_jit_trace-pnnx.ncnn.bin \ - $repo/test_wavs/$wav - done -done - -rm -rf $repo -log "--------------------------------------------------------------------------" diff --git a/.github/workflows/run-aishell-2022-06-20.yml b/.github/workflows/run-aishell-2022-06-20.yml index 1865a0da8..f5ba73195 100644 --- a/.github/workflows/run-aishell-2022-06-20.yml +++ b/.github/workflows/run-aishell-2022-06-20.yml @@ -65,7 +65,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -87,7 +87,7 @@ jobs: GITHUB_EVENT_NAME: ${{ github.event_name }} GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} run: | - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-gigaspeech-2022-05-13.yml b/.github/workflows/run-gigaspeech-2022-05-13.yml index e438c5dba..c7b9cc79d 100644 --- a/.github/workflows/run-gigaspeech-2022-05-13.yml +++ b/.github/workflows/run-gigaspeech-2022-05-13.yml @@ -64,7 +64,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache diff --git a/.github/workflows/run-librispeech-2022-03-12.yml b/.github/workflows/run-librispeech-2022-03-12.yml index 3ba6850cd..9c7cd1228 100644 --- a/.github/workflows/run-librispeech-2022-03-12.yml +++ b/.github/workflows/run-librispeech-2022-03-12.yml @@ -64,7 +64,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -123,7 +123,7 @@ jobs: ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank ls -lh egs/librispeech/ASR/data/* - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-librispeech-2022-04-29.yml b/.github/workflows/run-librispeech-2022-04-29.yml index 595b410b8..78c9e759f 100644 --- a/.github/workflows/run-librispeech-2022-04-29.yml +++ b/.github/workflows/run-librispeech-2022-04-29.yml @@ -64,7 +64,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -123,7 +123,7 @@ jobs: ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank ls -lh egs/librispeech/ASR/data/* - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-librispeech-2022-05-13.yml b/.github/workflows/run-librispeech-2022-05-13.yml index eb0b06a2d..04799bf52 100644 --- a/.github/workflows/run-librispeech-2022-05-13.yml +++ b/.github/workflows/run-librispeech-2022-05-13.yml @@ -64,7 +64,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -123,7 +123,7 @@ jobs: ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank ls -lh egs/librispeech/ASR/data/* - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-librispeech-2022-11-11-stateless7.yml b/.github/workflows/run-librispeech-2022-11-11-stateless7.yml index 365e2761a..6dfc23920 100644 --- a/.github/workflows/run-librispeech-2022-11-11-stateless7.yml +++ b/.github/workflows/run-librispeech-2022-11-11-stateless7.yml @@ -64,7 +64,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -123,7 +123,7 @@ jobs: ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank ls -lh egs/librispeech/ASR/data/* - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-librispeech-2022-11-14-stateless8.yml b/.github/workflows/run-librispeech-2022-11-14-stateless8.yml index acb11a8f4..0544e68b3 100644 --- a/.github/workflows/run-librispeech-2022-11-14-stateless8.yml +++ b/.github/workflows/run-librispeech-2022-11-14-stateless8.yml @@ -64,7 +64,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -123,7 +123,7 @@ jobs: ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank ls -lh egs/librispeech/ASR/data/* - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-librispeech-2022-12-01-stateless7-ctc.yml b/.github/workflows/run-librispeech-2022-12-01-stateless7-ctc.yml index ccd8d50d0..62e1f2a01 100644 --- a/.github/workflows/run-librispeech-2022-12-01-stateless7-ctc.yml +++ b/.github/workflows/run-librispeech-2022-12-01-stateless7-ctc.yml @@ -60,7 +60,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -119,7 +119,7 @@ jobs: ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank ls -lh egs/librispeech/ASR/data/* - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-librispeech-2022-12-08-zipformer-mmi.yml b/.github/workflows/run-librispeech-2022-12-08-zipformer-mmi.yml index 5472ca59b..7dc33aaa9 100644 --- a/.github/workflows/run-librispeech-2022-12-08-zipformer-mmi.yml +++ b/.github/workflows/run-librispeech-2022-12-08-zipformer-mmi.yml @@ -64,7 +64,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -123,7 +123,7 @@ jobs: ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank ls -lh egs/librispeech/ASR/data/* - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-librispeech-2022-12-15-stateless7-ctc-bs.yml b/.github/workflows/run-librispeech-2022-12-15-stateless7-ctc-bs.yml index 6e2b40cf3..de55847ad 100644 --- a/.github/workflows/run-librispeech-2022-12-15-stateless7-ctc-bs.yml +++ b/.github/workflows/run-librispeech-2022-12-15-stateless7-ctc-bs.yml @@ -35,7 +35,7 @@ on: jobs: run_librispeech_2022_12_15_zipformer_ctc_bs: - if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event.label.name == 'blank-skip' || github.event_name == 'push' || github.event_name == 'schedule' + if: github.event.label.name == 'run-decode' || github.event.label.name == 'blank-skip' || github.event_name == 'push' || github.event_name == 'schedule' runs-on: ${{ matrix.os }} strategy: matrix: @@ -60,7 +60,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -119,7 +119,7 @@ jobs: ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank ls -lh egs/librispeech/ASR/data/* - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-librispeech-2022-12-29-stateless7-streaming.yml b/.github/workflows/run-librispeech-2022-12-29-stateless7-streaming.yml index 6dd93946a..feb5c6fd0 100644 --- a/.github/workflows/run-librispeech-2022-12-29-stateless7-streaming.yml +++ b/.github/workflows/run-librispeech-2022-12-29-stateless7-streaming.yml @@ -64,7 +64,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -123,7 +123,7 @@ jobs: ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank ls -lh egs/librispeech/ASR/data/* - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-librispeech-conformer-ctc3-2022-11-28.yml b/.github/workflows/run-librispeech-conformer-ctc3-2022-11-28.yml index d763fb1c5..c95ed8b9a 100644 --- a/.github/workflows/run-librispeech-conformer-ctc3-2022-11-28.yml +++ b/.github/workflows/run-librispeech-conformer-ctc3-2022-11-28.yml @@ -64,7 +64,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -123,7 +123,7 @@ jobs: ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank ls -lh egs/librispeech/ASR/data/* - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml b/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml index f737d9a25..e14d4e92f 100644 --- a/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml +++ b/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml @@ -47,7 +47,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -106,7 +106,7 @@ jobs: ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank ls -lh egs/librispeech/ASR/data/* - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml b/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml index f67f7599b..73d91fcd4 100644 --- a/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml +++ b/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml @@ -64,7 +64,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -123,7 +123,7 @@ jobs: ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank ls -lh egs/librispeech/ASR/data/* - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-librispeech-streaming-transducer-stateless2-2022-06-26.yml b/.github/workflows/run-librispeech-streaming-transducer-stateless2-2022-06-26.yml index ac7e58b20..8a690393e 100644 --- a/.github/workflows/run-librispeech-streaming-transducer-stateless2-2022-06-26.yml +++ b/.github/workflows/run-librispeech-streaming-transducer-stateless2-2022-06-26.yml @@ -64,7 +64,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -123,7 +123,7 @@ jobs: ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank ls -lh egs/librispeech/ASR/data/* - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml b/.github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml index 575727e22..217dbdfa1 100644 --- a/.github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml +++ b/.github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml @@ -64,7 +64,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -123,7 +123,7 @@ jobs: ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank ls -lh egs/librispeech/ASR/data/* - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-pretrained-conformer-ctc.yml b/.github/workflows/run-pretrained-conformer-ctc.yml index 7dbfd2bd9..4e8e7b8db 100644 --- a/.github/workflows/run-pretrained-conformer-ctc.yml +++ b/.github/workflows/run-pretrained-conformer-ctc.yml @@ -54,7 +54,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -73,7 +73,7 @@ jobs: - name: Inference with pre-trained model shell: bash run: | - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml b/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml index d6b3de8d4..ddde4f1d6 100644 --- a/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml +++ b/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml @@ -63,7 +63,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -122,7 +122,7 @@ jobs: ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank ls -lh egs/librispeech/ASR/data/* - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml b/.github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml index 749fb3fca..00ea97b2a 100644 --- a/.github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml +++ b/.github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml @@ -63,7 +63,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -122,7 +122,7 @@ jobs: ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank ls -lh egs/librispeech/ASR/data/* - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml b/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml index 92bf6feb8..b3cfc9efd 100644 --- a/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml +++ b/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml @@ -54,7 +54,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -73,7 +73,7 @@ jobs: - name: Inference with pre-trained model shell: bash run: | - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml b/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml index e51da8bd8..ab598541d 100644 --- a/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml +++ b/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml @@ -54,7 +54,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -73,7 +73,7 @@ jobs: - name: Inference with pre-trained model shell: bash run: | - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-pretrained-transducer-stateless.yml b/.github/workflows/run-pretrained-transducer-stateless.yml index 2103d0510..d663d49dd 100644 --- a/.github/workflows/run-pretrained-transducer-stateless.yml +++ b/.github/workflows/run-pretrained-transducer-stateless.yml @@ -63,7 +63,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -122,7 +122,7 @@ jobs: ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank ls -lh egs/librispeech/ASR/data/* - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-pretrained-transducer.yml b/.github/workflows/run-pretrained-transducer.yml index 902319b55..9cb9d3b59 100644 --- a/.github/workflows/run-pretrained-transducer.yml +++ b/.github/workflows/run-pretrained-transducer.yml @@ -54,7 +54,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -73,7 +73,7 @@ jobs: - name: Inference with pre-trained model shell: bash run: | - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-ptb-rnn-lm.yml b/.github/workflows/run-ptb-rnn-lm.yml index 47ed958f2..f8d9c02c5 100644 --- a/.github/workflows/run-ptb-rnn-lm.yml +++ b/.github/workflows/run-ptb-rnn-lm.yml @@ -47,7 +47,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | grep -v kaldifst | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Prepare data shell: bash diff --git a/.github/workflows/run-wenetspeech-pruned-transducer-stateless2.yml b/.github/workflows/run-wenetspeech-pruned-transducer-stateless2.yml index 8a7be0b80..14fb96ec8 100644 --- a/.github/workflows/run-wenetspeech-pruned-transducer-stateless2.yml +++ b/.github/workflows/run-wenetspeech-pruned-transducer-stateless2.yml @@ -54,7 +54,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -76,7 +76,7 @@ jobs: GITHUB_EVENT_NAME: ${{ github.event_name }} GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} run: | - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-yesno-recipe.yml b/.github/workflows/run-yesno-recipe.yml index ed343aee5..1187dbf38 100644 --- a/.github/workflows/run-yesno-recipe.yml +++ b/.github/workflows/run-yesno-recipe.yml @@ -67,7 +67,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | grep -v kaldifst | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Run yesno recipe shell: bash diff --git a/.github/workflows/test-ncnn-export.yml b/.github/workflows/test-ncnn-export.yml index e10cfe76b..cdea54854 100644 --- a/.github/workflows/test-ncnn-export.yml +++ b/.github/workflows/test-ncnn-export.yml @@ -46,7 +46,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache diff --git a/.github/workflows/test-onnx-export.yml b/.github/workflows/test-onnx-export.yml index c7729dedb..3dc4261ab 100644 --- a/.github/workflows/test-onnx-export.yml +++ b/.github/workflows/test-onnx-export.yml @@ -46,7 +46,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index c062a2a3d..0da4f6b4b 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -56,7 +56,7 @@ jobs: run: | sudo apt update sudo apt install -q -y libsndfile1-dev libsndfile1 ffmpeg - sudo apt install -q -y --fix-missing sox libsox-dev libsox-fmt-all + sudo apt install -q -y --fix-missing libsox-dev libsox-fmt-all - name: Install Python dependencies run: | @@ -70,7 +70,7 @@ jobs: pip install git+https://github.com/lhotse-speech/lhotse # icefall requirements pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* pip install kaldifst pip install onnxruntime diff --git a/egs/librispeech/ASR/local/compute_fbank_librispeech.py b/egs/librispeech/ASR/local/compute_fbank_librispeech.py index 9f8503814..745eaf1e8 100755 --- a/egs/librispeech/ASR/local/compute_fbank_librispeech.py +++ b/egs/librispeech/ASR/local/compute_fbank_librispeech.py @@ -54,10 +54,20 @@ def get_args(): help="""Path to the bpe.model. If not None, we will remove short and long utterances before extracting features""", ) + + parser.add_argument( + "--dataset", + type=str, + help="""Dataset parts to compute fbank. If None, we will use all""", + ) + return parser.parse_args() -def compute_fbank_librispeech(bpe_model: Optional[str] = None): +def compute_fbank_librispeech( + bpe_model: Optional[str] = None, + dataset: Optional[str] = None, +): src_dir = Path("data/manifests") output_dir = Path("data/fbank") num_jobs = min(15, os.cpu_count()) @@ -68,15 +78,19 @@ def compute_fbank_librispeech(bpe_model: Optional[str] = None): sp = spm.SentencePieceProcessor() sp.load(bpe_model) - dataset_parts = ( - "dev-clean", - "dev-other", - "test-clean", - "test-other", - "train-clean-100", - "train-clean-360", - "train-other-500", - ) + if dataset is None: + dataset_parts = ( + "dev-clean", + "dev-other", + "test-clean", + "test-other", + "train-clean-100", + "train-clean-360", + "train-other-500", + ) + else: + dataset_parts = dataset.split(" ", -1) + prefix = "librispeech" suffix = "jsonl.gz" manifests = read_manifests_if_cached( @@ -131,4 +145,4 @@ if __name__ == "__main__": logging.basicConfig(format=formatter, level=logging.INFO) args = get_args() logging.info(vars(args)) - compute_fbank_librispeech(bpe_model=args.bpe_model) + compute_fbank_librispeech(bpe_model=args.bpe_model, dataset=args.dataset) From 9ddd811925534dc47b183a23429a4727c6416e81 Mon Sep 17 00:00:00 2001 From: marcoyang1998 <45973641+marcoyang1998@users.noreply.github.com> Date: Fri, 10 Mar 2023 14:37:28 +0800 Subject: [PATCH 129/174] Fix padding_idx (#942) * fix padding_idx * update RESULTS.md --- egs/librispeech/ASR/RESULTS.md | 4 ++++ egs/librispeech/ASR/pruned_transducer_stateless/decoder.py | 1 - egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py | 1 - egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py | 1 - 4 files changed, 4 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index ecb84eb01..9ca7a19b8 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -540,6 +540,10 @@ for m in greedy_search fast_beam_search modified_beam_search ; do done ``` +Note that a small change is made to the `pruned_transducer_stateless7/decoder.py` in +this [PR](/ceph-data4/yangxiaoyu/softwares/icefall_development/icefall_random_padding/egs/librispeech/ASR/pruned_transducer_stateless7/exp_960h_no_paddingidx_ngpu4/tensorboard) to address the +problem of emitting the first symbol at the very beginning. If you need a +model without this issue, please download the model from here: ### LibriSpeech BPE training results (Pruned Stateless LSTM RNN-T + gradient filter) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py index 72593173c..49b82c433 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py @@ -58,7 +58,6 @@ class Decoder(nn.Module): self.embedding = nn.Embedding( num_embeddings=vocab_size, embedding_dim=embedding_dim, - padding_idx=blank_id, ) self.blank_id = blank_id self.unk_id = unk_id diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py index b59928103..d44ed6f81 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py @@ -59,7 +59,6 @@ class Decoder(nn.Module): self.embedding = ScaledEmbedding( num_embeddings=vocab_size, embedding_dim=decoder_dim, - padding_idx=blank_id, ) self.blank_id = blank_id diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py index 384b78524..b085a1817 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py @@ -56,7 +56,6 @@ class Decoder(nn.Module): self.embedding = nn.Embedding( num_embeddings=vocab_size, embedding_dim=decoder_dim, - padding_idx=blank_id, ) self.blank_id = blank_id From cad6735e0739f149ba3f452e52a948da946527dc Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Fri, 10 Mar 2023 19:28:59 +0800 Subject: [PATCH 130/174] Modify make_pad_mask to support TensorRT (#943) * Modify make_pad_mask to support TensorRT * Fix for test --- egs/librispeech/ASR/transducer/test_rnn.py | 10 +++++----- icefall/utils.py | 6 +++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/egs/librispeech/ASR/transducer/test_rnn.py b/egs/librispeech/ASR/transducer/test_rnn.py index 74c94cc70..d8effb996 100755 --- a/egs/librispeech/ASR/transducer/test_rnn.py +++ b/egs/librispeech/ASR/transducer/test_rnn.py @@ -432,11 +432,11 @@ def test_layernorm_lstm_forward(device="cpu"): def test_layernorm_lstm_with_projection_forward(device="cpu"): - input_size = torch.randint(low=2, high=100, size=(1,)).item() - hidden_size = torch.randint(low=10, high=100, size=(1,)).item() - proj_size = torch.randint(low=2, high=hidden_size, size=(1,)).item() - num_layers = torch.randint(low=2, high=100, size=(1,)).item() - bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0 + input_size = 40 # torch.randint(low=2, high=100, size=(1,)).item() + hidden_size = 40 # torch.randint(low=10, high=100, size=(1,)).item() + proj_size = 20 # torch.randint(low=2, high=hidden_size, size=(1,)).item() + num_layers = 12 # torch.randint(low=2, high=100, size=(1,)).item() + bias = True # torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0 self_lstm = LayerNormLSTM( input_size=input_size, diff --git a/icefall/utils.py b/icefall/utils.py index 2358ed02f..5d86472b5 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -1095,10 +1095,10 @@ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor: assert lengths.ndim == 1, lengths.ndim max_len = max(max_len, lengths.max()) n = lengths.size(0) + seq_range = torch.arange(0, max_len, device=lengths.device) + expaned_lengths = seq_range.unsqueeze(0).expand(n, max_len) - expaned_lengths = torch.arange(max_len).expand(n, max_len).to(lengths) - - return expaned_lengths >= lengths.unsqueeze(1) + return expaned_lengths >= lengths.unsqueeze(-1) # Copied and modified from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/mask.py From a48812ddb307069339e029942321b8c7417aed93 Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Wed, 15 Mar 2023 22:02:20 +0800 Subject: [PATCH 131/174] Ban the test_rnn.py in ci-test (#949) --- .github/workflows/test.yml | 8 ++++---- egs/librispeech/ASR/transducer/test_rnn.py | 10 +++++----- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 0da4f6b4b..079772e97 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -119,8 +119,8 @@ jobs: cd ../transducer_stateless pytest -v -s - cd ../transducer - pytest -v -s + # cd ../transducer + # pytest -v -s cd ../transducer_stateless2 pytest -v -s @@ -157,8 +157,8 @@ jobs: cd ../transducer_stateless pytest -v -s - cd ../transducer - pytest -v -s + # cd ../transducer + # pytest -v -s cd ../transducer_stateless2 pytest -v -s diff --git a/egs/librispeech/ASR/transducer/test_rnn.py b/egs/librispeech/ASR/transducer/test_rnn.py index d8effb996..74c94cc70 100755 --- a/egs/librispeech/ASR/transducer/test_rnn.py +++ b/egs/librispeech/ASR/transducer/test_rnn.py @@ -432,11 +432,11 @@ def test_layernorm_lstm_forward(device="cpu"): def test_layernorm_lstm_with_projection_forward(device="cpu"): - input_size = 40 # torch.randint(low=2, high=100, size=(1,)).item() - hidden_size = 40 # torch.randint(low=10, high=100, size=(1,)).item() - proj_size = 20 # torch.randint(low=2, high=hidden_size, size=(1,)).item() - num_layers = 12 # torch.randint(low=2, high=100, size=(1,)).item() - bias = True # torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0 + input_size = torch.randint(low=2, high=100, size=(1,)).item() + hidden_size = torch.randint(low=10, high=100, size=(1,)).item() + proj_size = torch.randint(low=2, high=hidden_size, size=(1,)).item() + num_layers = torch.randint(low=2, high=100, size=(1,)).item() + bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0 self_lstm = LayerNormLSTM( input_size=input_size, From 6196b4a407f0ff4359814c81c385eefb5636f04d Mon Sep 17 00:00:00 2001 From: Jason's Lab <563042811@qq.com> Date: Thu, 16 Mar 2023 09:52:11 +0800 Subject: [PATCH 132/174] Add char-based language model training process for aishell. (#945) * Add char-based language model training process for aishell. Add soft link from librispeech/ASR/local/sort_lm_training_data.py to aishell/ASR/local/ --------- Co-authored-by: lichao --- .../local/prepare_char_lm_training_data.py | 164 ++++++++++++++++++ egs/aishell/ASR/prepare.sh | 92 +++++++++- 2 files changed, 255 insertions(+), 1 deletion(-) create mode 100644 egs/aishell/ASR/local/prepare_char_lm_training_data.py diff --git a/egs/aishell/ASR/local/prepare_char_lm_training_data.py b/egs/aishell/ASR/local/prepare_char_lm_training_data.py new file mode 100644 index 000000000..e7995680b --- /dev/null +++ b/egs/aishell/ASR/local/prepare_char_lm_training_data.py @@ -0,0 +1,164 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2021 Xiaomi Corporation (authors: Daniel Povey +# Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script takes a `tokens.txt` and a text file such as +./download/lm/aishell-transcript.txt +and outputs the LM training data to a supplied directory such +as data/lm_training_char. The format is as follows: +It creates a PyTorch archive (.pt file), say data/lm_training.pt, which is a +representation of a dict with the same format with librispeech receipe +""" + +import argparse +import logging +from pathlib import Path + +import k2 +import torch + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-char", + type=str, + help="""Lang dir of asr model, e.g. data/lang_char""", + ) + parser.add_argument( + "--lm-data", + type=str, + help="""Input LM training data as text, e.g. + download/lm/aishell-train-word.txt""", + ) + parser.add_argument( + "--lm-archive", + type=str, + help="""Path to output archive, e.g. data/lm_training_char/lm_data.pt; + look at the source of this script to see the format.""", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + + if Path(args.lm_archive).exists(): + logging.warning(f"{args.lm_archive} exists - skipping") + return + + # make token_dict from tokens.txt in order to map characters to tokens. + token_dict = {} + token_file = args.lang_char + "/tokens.txt" + + with open(token_file, "r") as f: + for line in f.readlines(): + line_list = line.split() + token_dict[line_list[0]] = int(line_list[1]) + + # word2index is a dictionary from words to integer ids. No need to reserve + # space for epsilon, etc.; the words are just used as a convenient way to + # compress the sequences of tokens. + word2index = dict() + + word2token = [] # Will be a list-of-list-of-int, representing tokens. + sentences = [] # Will be a list-of-list-of-int, representing word-ids. + + if "aishell-lm" in args.lm_data: + num_lines_in_total = 120098.0 + step = 50000 + elif "valid" in args.lm_data: + num_lines_in_total = 14326.0 + step = 3000 + elif "test" in args.lm_data: + num_lines_in_total = 7176.0 + step = 3000 + else: + num_lines_in_total = None + step = None + + processed = 0 + + with open(args.lm_data) as f: + while True: + line = f.readline() + if line == "": + break + + if step and processed % step == 0: + logging.info( + f"Processed number of lines: {processed} " + f"({processed / num_lines_in_total * 100: .3f}%)" + ) + processed += 1 + + line_words = line.split() + for w in line_words: + if w not in word2index: + w_token = [] + for t in w: + if t in token_dict: + w_token.append(token_dict[t]) + else: + w_token.append(token_dict[""]) + word2index[w] = len(word2token) + word2token.append(w_token) + sentences.append([word2index[w] for w in line_words]) + + logging.info("Constructing ragged tensors") + words = k2.ragged.RaggedTensor(word2token) + sentences = k2.ragged.RaggedTensor(sentences) + + output = dict(words=words, sentences=sentences) + + num_sentences = sentences.dim0 + logging.info(f"Computing sentence lengths, num_sentences: {num_sentences}") + sentence_lengths = [0] * num_sentences + for i in range(num_sentences): + if step and i % step == 0: + logging.info( + f"Processed number of lines: {i} ({i / num_sentences * 100: .3f}%)" + ) + + word_ids = sentences[i] + + # NOTE: If word_ids is a tensor with only 1 entry, + # token_ids is a torch.Tensor + token_ids = words[word_ids] + if isinstance(token_ids, k2.RaggedTensor): + token_ids = token_ids.values + + # token_ids is a 1-D tensor containing the BPE tokens + # of the current sentence + + sentence_lengths[i] = token_ids.numel() + + output["sentence_lengths"] = torch.tensor(sentence_lengths, dtype=torch.int32) + + torch.save(output, args.lm_archive) + logging.info(f"Saved to {args.lm_archive}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/aishell/ASR/prepare.sh b/egs/aishell/ASR/prepare.sh index 5917668a1..cf4ee7818 100755 --- a/egs/aishell/ASR/prepare.sh +++ b/egs/aishell/ASR/prepare.sh @@ -7,7 +7,7 @@ set -eou pipefail nj=15 stage=-1 -stop_stage=10 +stop_stage=11 # We assume dl_dir (download dir) contains the following # directories and files. If not, they will be downloaded @@ -219,3 +219,93 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then ./local/compile_hlg.py --lang-dir $lang_phone_dir ./local/compile_hlg.py --lang-dir $lang_char_dir fi + +if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then + log "Stage 9: Generate LM training data" + + log "Processing char based data" + out_dir=data/lm_training_char + mkdir -p $out_dir $dl_dir/lm + + if [ ! -f $dl_dir/lm/aishell-train-word.txt ]; then + cp $lang_phone_dir/transcript_words.txt $dl_dir/lm/aishell-train-word.txt + fi + + ./local/prepare_char_lm_training_data.py \ + --lang-char data/lang_char \ + --lm-data $dl_dir/lm/aishell-train-word.txt \ + --lm-archive $out_dir/lm_data.pt + + if [ ! -f $dl_dir/lm/aishell-valid-word.txt ]; then + aishell_text=$dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt + aishell_valid_uid=$dl_dir/aishell/data_aishell/transcript/aishell_valid_uid + find $dl_dir/aishell/data_aishell/wav/dev -name "*.wav" | sed 's/\.wav//g' | awk -F '/' '{print $NF}' > $aishell_valid_uid + awk 'NR==FNR{uid[$1]=$1} NR!=FNR{if($1 in uid) print $0}' $aishell_valid_uid $aishell_text | + cut -d " " -f 2- > $dl_dir/lm/aishell-valid-word.txt + fi + + ./local/prepare_char_lm_training_data.py \ + --lang-char data/lang_char \ + --lm-data $dl_dir/lm/aishell-valid-word.txt \ + --lm-archive $out_dir/lm_data_valid.pt + + if [ ! -f $dl_dir/lm/aishell-test-word.txt ]; then + aishell_text=$dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt + aishell_test_uid=$dl_dir/aishell/data_aishell/transcript/aishell_test_uid + find $dl_dir/aishell/data_aishell/wav/test -name "*.wav" | sed 's/\.wav//g' | awk -F '/' '{print $NF}' > $aishell_test_uid + awk 'NR==FNR{uid[$1]=$1} NR!=FNR{if($1 in uid) print $0}' $aishell_test_uid $aishell_text | + cut -d " " -f 2- > $dl_dir/lm/aishell-test-word.txt + fi + + ./local/prepare_char_lm_training_data.py \ + --lang-char data/lang_char \ + --lm-data $dl_dir/lm/aishell-test-word.txt \ + --lm-archive $out_dir/lm_data_test.pt +fi + + +if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then + log "Stage 10: Sort LM training data" + # Sort LM training data by sentence length in descending order + # for ease of training. + # + # Sentence length equals to the number of tokens + # in a sentence. + + out_dir=data/lm_training_char + mkdir -p $out_dir + ln -snf ../../../librispeech/ASR/local/sort_lm_training_data.py local/ + + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_data.pt \ + --out-lm-data $out_dir/sorted_lm_data.pt \ + --out-statistics $out_dir/statistics.txt + + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_data_valid.pt \ + --out-lm-data $out_dir/sorted_lm_data-valid.pt \ + --out-statistics $out_dir/statistics-valid.txt + + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_data_test.pt \ + --out-lm-data $out_dir/sorted_lm_data-test.pt \ + --out-statistics $out_dir/statistics-test.txt +fi + +if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then + log "Stage 11: Train RNN LM model" + python ../../../icefall/rnn_lm/train.py \ + --start-epoch 0 \ + --world-size 1 \ + --num-epochs 20 \ + --use-fp16 0 \ + --embedding-dim 512 \ + --hidden-dim 512 \ + --num-layers 2 \ + --batch-size 400 \ + --exp-dir rnnlm_char/exp \ + --lm-data data/lm_training_char/sorted_lm_data.pt \ + --lm-data-valid data/lm_training_char/sorted_lm_data-valid.pt \ + --vocab-size 4336 \ + --master-port 12345 +fi From 7948624a220b9fc40dbfa87cb1eb83041af45ef3 Mon Sep 17 00:00:00 2001 From: marcoyang1998 <45973641+marcoyang1998@users.noreply.github.com> Date: Fri, 17 Mar 2023 13:44:29 +0800 Subject: [PATCH 133/174] Support fine-tuning (#944) * support finetune * add files for decoding giga * support initializing modules * add a fine-tune bash script --- egs/librispeech/ASR/finetune.sh | 85 ++ .../decode_gigaspeech.py | 861 +++++++++++ .../pruned_transducer_stateless7/finetune.py | 1342 +++++++++++++++++ .../gigaspeech.py | 406 +++++ .../gigaspeech_scoring.py | 1 + .../ASR/pruned_transducer_stateless7/optim.py | 46 +- 6 files changed, 2739 insertions(+), 2 deletions(-) create mode 100755 egs/librispeech/ASR/finetune.sh create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless7/decode_gigaspeech.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless7/gigaspeech.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7/gigaspeech_scoring.py diff --git a/egs/librispeech/ASR/finetune.sh b/egs/librispeech/ASR/finetune.sh new file mode 100755 index 000000000..63d0966ed --- /dev/null +++ b/egs/librispeech/ASR/finetune.sh @@ -0,0 +1,85 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + +stage=-1 +stop_stage=100 + +# This is an example script for fine-tuning. Here, we fine-tune a model trained +# on Librispeech on GigaSpeech. The model used for fine-tuning is +# pruned_transducer_stateless7 (zipformer). If you want to fine-tune model +# from another recipe, you can adapt ./pruned_transducer_stateless7/finetune.py +# for that recipe. If you have any problem, please open up an issue in https://github.com/k2-fsa/icefall/issues. + +# We assume that you have already prepared the GigaSpeech manfiest&features under ./data. +# If you haven't done that, please see https://github.com/k2-fsa/icefall/blob/master/egs/gigaspeech/ASR/prepare.sh. + +dl_dir=$PWD/download + +. shared/parse_options.sh || exit 1 + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then + log "Stage -1: Download Pre-trained model" + + # clone from huggingface + git lfs install + git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11 + +fi + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Start fine-tuning" + + # The following configuration of lr schedule should work well + # You may also tune the following parameters to adjust learning rate schedule + base_lr=0.005 + lr_epochs=100 + lr_batches=100000 + + # We recommend to start from an averaged model + finetune_ckpt=icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/exp/pretrained.pt + export CUDA_VISIBLE_DEVICES="0,1" + + ./pruned_transducer_stateless7/finetune.py \ + --world-size 2 \ + --master-port 18180 \ + --num-epochs 20 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless7/exp_giga_finetune \ + --subset S \ + --use-fp16 1 \ + --base-lr $base_lr \ + --lr-epochs $lr_epochs \ + --lr-batches $lr_batches \ + --bpe-model icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/data/lang_bpe_500/bpe.model \ + --do-finetune True \ + --finetune-ckpt $finetune_ckpt \ + --max-duration 500 +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Decoding" + + epoch=15 + avg=10 + + for m in greedy_search modified_beam_search; do + python pruned_transducer_stateless7/decode_gigaspeech.py \ + --epoch $epoch \ + --avg $avg \ + --use-averaged-model True \ + --beam-size 4 \ + --exp-dir pruned_transducer_stateless7/exp_giga_finetune \ + --max-duration 400 \ + --decoding-method $m + done +fi diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decode_gigaspeech.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decode_gigaspeech.py new file mode 100644 index 000000000..4f64850b6 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decode_gigaspeech.py @@ -0,0 +1,861 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn + +# from asr_datamodule import LibriSpeechAsrDataModule +from gigaspeech import GigaSpeechAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from gigaspeech_scoring import asr_text_post_processing +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--simulate-streaming", + type=str2bool, + default=False, + help="""Whether to simulate streaming in decoding, this is a good way to + test a streaming model. + """, + ) + + parser.add_argument( + "--decode-chunk-size", + type=int, + default=16, + help="The chunk size for decoding (in frames after subsampling)", + ) + + parser.add_argument( + "--left-context", + type=int, + default=64, + help="left context can be seen during decoding (in frames after subsampling)", + ) + + add_model_arguments(parser) + + return parser + + +def post_processing( + results: List[Tuple[str, List[str], List[str]]], +) -> List[Tuple[str, List[str], List[str]]]: + new_results = [] + for key, ref, hyp in results: + new_ref = asr_text_post_processing(" ".join(ref)).split() + new_hyp = asr_text_post_processing(" ".join(hyp)).split() + new_results.append((key, new_ref, new_hyp)) + return new_results + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.simulate_streaming: + feature_lens += params.left_context + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.left_context), + value=LOG_EPS, + ) + encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( + x=feature, + x_lens=feature_lens, + chunk_size=params.decode_chunk_size, + left_context=params.left_context, + simulate_streaming=True, + ) + else: + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = post_processing(results) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + """ + This scripts test a libri model with libri BPE + on Gigaspeech. + """ + parser = get_parser() + GigaSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / (params.decoding_method + "_gigaspeech") + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.simulate_streaming: + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}" + params.suffix += f"-left-context-{params.left_context}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + if params.simulate_streaming: + assert ( + params.causal_convolution + ), "Decoding in streaming requires causal convolution" + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + gigaspeech = GigaSpeechAsrDataModule(args) + + dev_cuts = gigaspeech.dev_cuts() + test_cuts = gigaspeech.test_cuts() + + dev_dl = gigaspeech.test_dataloaders(dev_cuts) + test_dl = gigaspeech.test_dataloaders(test_cuts) + + test_sets = ["dev", "test"] + test_dls = [dev_dl, test_dl] + + for test_set, test_dl in zip(test_sets, test_dls): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py b/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py new file mode 100755 index 000000000..726a24809 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py @@ -0,0 +1,1342 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless7/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless7/exp \ + --full-libri 1 \ + --max-duration 300 + +# For mix precision training: + +./pruned_transducer_stateless7/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7/exp \ + --full-libri 1 \ + --max-duration 550 + +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from decoder import Decoder +from gigaspeech import GigaSpeechAsrDataModule +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from optim import Eden, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + filter_uneven_sized_batch, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_finetune_arguments(parser: argparse.ArgumentParser): + parser.add_argument("--do-finetune", type=str2bool, default=False) + + parser.add_argument( + "--init-modules", + type=str, + default=None, + help=""" + Modules to be initialized. It matches all parameters starting with + a specific key. The keys are given with Comma seperated. If None, + all modules will be initialised. For example, if you only want to + initialise all parameters staring with "encoder", use "encoder"; + if you want to initialise parameters starting with encoder or decoder, + use "encoder,joiner". + """, + ) + + parser.add_argument( + "--finetune-ckpt", + type=str, + default=None, + help="Fine-tuning from which checkpoint (a path to a .pt file)", + ) + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,4,3,2,4", + help="Number of zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--feedforward-dims", + type=str, + default="1024,1024,2048,2048,1024", + help="Feedforward dimension of the zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--nhead", + type=str, + default="8,8,8,8,8", + help="Number of attention heads in the zipformer encoder layers.", + ) + + parser.add_argument( + "--encoder-dims", + type=str, + default="384,384,384,384,384", + help="""Embedding dimension in the 2 blocks of zipformer encoder + layers, comma separated + """, + ) + + parser.add_argument( + "--attention-dims", + type=str, + default="192,192,192,192,192", + help="""Attention dimension in the 2 blocks of zipformer encoder layers,\ + comma separated; not the same as embedding dimension. + """, + ) + + parser.add_argument( + "--encoder-unmasked-dims", + type=str, + default="256,256,256,256,256", + help="""Unmasked dimensions in the encoders, relates to augmentation + during training. Must be <= each of encoder_dims. Empirically, less + than 256 seems to make performance worse. + """, + ) + + parser.add_argument( + "--zipformer-downsampling-factors", + type=str, + default="1,2,4,8,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--cnn-module-kernels", + type=str, + default="31,31,31,31,31", + help="Sizes of kernels in convolution modules", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="""Path to the BPE model. + This should be the bpe model of the original model + """, + ) + + parser.add_argument( + "--base-lr", type=float, default=0.005, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=100000, + help="""Number of steps that affects how rapidly the learning rate + decreases. During fine-tuning, we set this very large so that the + learning rate slowly decays with number of batches. You may tune + its value by yourself. + """, + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=100, + help="""Number of epochs that affects how rapidly the learning rate + decreases. During fine-tuning, we set this very large so that the + learning rate slowly decays with number of batches. You may tune + its value by yourself. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=2000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + add_finetune_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "frame_shift_ms": 10.0, + "allowed_excess_duration_ratio": 0.1, + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Zipformer and Transformer + def to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + encoder = Zipformer( + num_features=params.feature_dim, + output_downsampling_factor=2, + zipformer_downsampling_factors=to_int_tuple( + params.zipformer_downsampling_factors + ), + encoder_dims=to_int_tuple(params.encoder_dims), + attention_dim=to_int_tuple(params.attention_dims), + encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), + nhead=to_int_tuple(params.nhead), + feedforward_dim=to_int_tuple(params.feedforward_dims), + cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), + num_encoder_layers=to_int_tuple(params.num_encoder_layers), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + return saved_params + + +def load_model_params( + ckpt: str, model: nn.Module, init_modules: List[str] = None, strict: bool = True +): + """Load model params from checkpoint + + Args: + ckpt (str): Path to the checkpoint + model (nn.Module): model to be loaded + + """ + logging.info(f"Loading checkpoint from {ckpt}") + checkpoint = torch.load(ckpt, map_location="cpu") + + # if module list is empty, load the whole model from ckpt + if not init_modules: + if next(iter(checkpoint["model"])).startswith("module."): + logging.info("Loading checkpoint saved by DDP") + + dst_state_dict = model.state_dict() + src_state_dict = checkpoint["model"] + for key in dst_state_dict.keys(): + src_key = "{}.{}".format("module", key) + dst_state_dict[key] = src_state_dict.pop(src_key) + assert len(src_state_dict) == 0 + model.load_state_dict(dst_state_dict, strict=strict) + else: + model.load_state_dict(checkpoint["model"], strict=strict) + else: + src_state_dict = checkpoint["model"] + dst_state_dict = model.state_dict() + for module in init_modules: + logging.info(f"Loading parameters starting with prefix {module}") + src_keys = [k for k in src_state_dict.keys() if k.startswith(module)] + dst_keys = [k for k in dst_state_dict.keys() if k.startswith(module)] + assert set(src_keys) == set(dst_keys) # two sets should match exactly + for key in src_keys: + dst_state_dict[key] = src_state_dict.pop(key) + + model.load_state_dict(dst_state_dict, strict=strict) + + return None + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + # For the uneven-sized batch, the total duration after padding would possibly + # cause OOM. Hence, for each batch, which is sorted descendingly by length, + # we simply drop the last few shortest samples, so that the retained total frames + # (after padding) would not exceed `allowed_max_frames`: + # `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`, + # where `max_frames = max_duration * 1000 // frame_shift_ms`. + # We set allowed_excess_duration_ratio=0.1. + max_frames = params.max_duration * 1000 // params.frame_shift_ms + allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio)) + batch = filter_uneven_sized_batch(batch, allowed_max_frames) + + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have + # different behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + # load model parameters for model fine-tuning + if params.do_finetune: + modules = params.init_modules.split(",") if params.init_modules else None + checkpoints = load_model_params( + ckpt=params.finetune_ckpt, model=model, init_modules=modules + ) + else: + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + parameters_names = [] + parameters_names.append( + [name_param_pair[0] for name_param_pair in model.named_parameters()] + ) + optimizer = ScaledAdam( + model.parameters(), + lr=params.base_lr, + clipping_scale=2.0, + parameters_names=parameters_names, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + gigaspeech = GigaSpeechAsrDataModule(args) + + train_cuts = gigaspeech.train_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = gigaspeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = gigaspeech.dev_cuts() + valid_dl = gigaspeech.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + GigaSpeechAsrDataModule.add_arguments( + parser + ) # you may replace this with your own dataset + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/gigaspeech.py b/egs/librispeech/ASR/pruned_transducer_stateless7/gigaspeech.py new file mode 100644 index 000000000..5c01d7190 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/gigaspeech.py @@ -0,0 +1,406 @@ +# Copyright 2021 Piotr Żelasko +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import inspect +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse.dataset import ( + CutConcatenate, + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SingleCutSampler, + SpecAugment, +) +from lhotse.dataset.input_strategies import OnTheFlyFeatures +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class GigaSpeechAsrDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=80, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it " + "with training dataset. ", + ) + + # GigaSpeech specific arguments + group.add_argument( + "--subset", + type=str, + default="XL", + help="Select the GigaSpeech subset (XS|S|M|L|XL)", + ) + group.add_argument( + "--small-dev", + type=str2bool, + default=False, + help="Should we use only 1000 utterances for dev (speeds up training)", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + transforms.append( + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. + num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + logging.info(f"Num frame mask: {num_frame_masks}") + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=True, + ) + else: + logging.info("Using SingleCutSampler.") + train_sampler = SingleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + return_cuts=self.args.return_cuts, + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else PrecomputedFeatures(), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info(f"About to get train_{self.args.subset} cuts") + path = self.args.manifest_dir / f"cuts_{self.args.subset}.jsonl.gz" + cuts_train = CutSet.from_jsonl_lazy(path) + return cuts_train + + @lru_cache() + def dev_cuts(self) -> CutSet: + logging.info("About to get dev cuts") + cuts_valid = load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz") + if self.args.small_dev: + return cuts_valid.subset(first=1000) + else: + return cuts_valid + + @lru_cache() + def test_cuts(self) -> CutSet: + logging.info("About to get test cuts") + return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST.jsonl.gz") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/gigaspeech_scoring.py b/egs/librispeech/ASR/pruned_transducer_stateless7/gigaspeech_scoring.py new file mode 120000 index 000000000..fdfa6ce4b --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/gigaspeech_scoring.py @@ -0,0 +1 @@ +../../../gigaspeech/ASR/pruned_transducer_stateless2/gigaspeech_scoring.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py index 374b78cb3..b84e518d0 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py @@ -799,6 +799,47 @@ def _test_eden(): logging.info(f"state dict = {scheduler.state_dict()}") +def _plot_eden_lr(): + import matplotlib.pyplot as plt + + m = torch.nn.Linear(100, 100) + parameters_names = [] + parameters_names.append( + [name_param_pair[0] for name_param_pair in m.named_parameters()] + ) + + for lr_epoch in [4, 10, 100]: + for lr_batch in [100, 400]: + optim = ScaledAdam( + m.parameters(), lr=0.03, parameters_names=parameters_names + ) + scheduler = Eden( + optim, lr_batches=lr_batch, lr_epochs=lr_epoch, verbose=True + ) + lr = [] + + for epoch in range(10): + scheduler.step_epoch(epoch) # sets epoch to `epoch` + + for step in range(500): + lr.append(scheduler.get_lr()) + + x = torch.randn(200, 100).detach() + x.requires_grad = True + y = m(x) + dy = torch.randn(200, 100).detach() + f = (y * dy).sum() + f.backward() + + optim.step() + scheduler.step_batch() + optim.zero_grad() + plt.plot(lr, label=f"lr_epoch:{lr_epoch}, lr_batch:{lr_batch}") + + plt.legend() + plt.savefig("lr.png") + + # This is included mostly as a baseline for ScaledAdam. class Eve(Optimizer): """ @@ -1057,5 +1098,6 @@ if __name__ == "__main__": else: hidden_dim = 200 - _test_scaled_adam(hidden_dim) - _test_eden() + # _test_scaled_adam(hidden_dim) + # _test_eden() + _plot_eden_lr() From d74822d07b803f552602e727ebf099f406c74786 Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Tue, 21 Mar 2023 21:35:32 +0800 Subject: [PATCH 134/174] Fix wenetspeech decoding speed (#953) --- .../compute_fbank_wenetspeech_dev_test.py | 4 +- .../asr_datamodule.py | 25 ++---- .../pruned_transducer_stateless2/decode.py | 77 ++----------------- .../pruned_transducer_stateless5/decode.py | 77 ++----------------- 4 files changed, 19 insertions(+), 164 deletions(-) diff --git a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py index bd73e520e..20d7341db 100755 --- a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py +++ b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py @@ -20,7 +20,7 @@ import logging from pathlib import Path import torch -from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, LilcomHdf5Writer +from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter # Torch's multithreaded behavior needs to be disabled or # it wastes a lot of CPU and slow things down. @@ -69,7 +69,7 @@ def compute_fbank_wenetspeech_dev_test(): storage_path=f"{in_out_dir}/feats_{partition}", num_workers=num_workers, batch_duration=batch_duration, - storage_type=LilcomHdf5Writer, + storage_type=LilcomChunkyWriter, overwrite=True, ) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py index 9c07263a2..c9e30e737 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -46,9 +46,6 @@ from torch.utils.data import DataLoader from icefall.utils import str2bool -set_caching_enabled(False) -torch.set_num_threads(1) - class _SeedWorkers: def __init__(self, seed: int): @@ -348,24 +345,18 @@ class WenetSpeechAsrDataModule: cut_transforms=transforms, return_cuts=self.args.return_cuts, ) + valid_sampler = DynamicBucketingSampler( cuts_valid, max_duration=self.args.max_duration, - rank=0, - world_size=1, shuffle=False, ) logging.info("About to create dev dataloader") - from lhotse.dataset.iterable_dataset import IterableDatasetWrapper - - dev_iter_dataset = IterableDatasetWrapper( - dataset=validate, - sampler=valid_sampler, - ) valid_dl = DataLoader( - dev_iter_dataset, + validate, batch_size=None, + sampler=valid_sampler, num_workers=self.args.num_workers, persistent_workers=False, ) @@ -383,19 +374,13 @@ class WenetSpeechAsrDataModule: sampler = DynamicBucketingSampler( cuts, max_duration=self.args.max_duration, - rank=0, - world_size=1, shuffle=False, ) - from lhotse.dataset.iterable_dataset import IterableDatasetWrapper - test_iter_dataset = IterableDatasetWrapper( - dataset=test, - sampler=sampler, - ) test_dl = DataLoader( - test_iter_dataset, + test, batch_size=None, + sampler=sampler, num_workers=self.args.num_workers, ) return test_dl diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py index 823b33ae5..bdd1f27bc 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py @@ -651,83 +651,18 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - # Note: Please use "pip install webdataset==0.1.103" - # for installing the webdataset. - import glob - import os - - from lhotse import CutSet - from lhotse.dataset.webdataset import export_to_webdataset - # we need cut ids to display recognition results. args.return_cuts = True wenetspeech = WenetSpeechAsrDataModule(args) - dev = "dev" - test_net = "test_net" - test_meeting = "test_meeting" + dev_cuts = wenetspeech.valid_cuts() + dev_dl = wenetspeech.valid_dataloaders(dev_cuts) - if not os.path.exists(f"{dev}/shared-0.tar"): - os.makedirs(dev) - dev_cuts = wenetspeech.valid_cuts() - export_to_webdataset( - dev_cuts, - output_path=f"{dev}/shared-%d.tar", - shard_size=300, - ) + test_net_cuts = wenetspeech.test_net_cuts() + test_net_dl = wenetspeech.test_dataloaders(test_net_cuts) - if not os.path.exists(f"{test_net}/shared-0.tar"): - os.makedirs(test_net) - test_net_cuts = wenetspeech.test_net_cuts() - export_to_webdataset( - test_net_cuts, - output_path=f"{test_net}/shared-%d.tar", - shard_size=300, - ) - - if not os.path.exists(f"{test_meeting}/shared-0.tar"): - os.makedirs(test_meeting) - test_meeting_cuts = wenetspeech.test_meeting_cuts() - export_to_webdataset( - test_meeting_cuts, - output_path=f"{test_meeting}/shared-%d.tar", - shard_size=300, - ) - - dev_shards = [ - str(path) for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar"))) - ] - cuts_dev_webdataset = CutSet.from_webdataset( - dev_shards, - split_by_worker=True, - split_by_node=True, - shuffle_shards=True, - ) - - test_net_shards = [ - str(path) for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar"))) - ] - cuts_test_net_webdataset = CutSet.from_webdataset( - test_net_shards, - split_by_worker=True, - split_by_node=True, - shuffle_shards=True, - ) - - test_meeting_shards = [ - str(path) - for path in sorted(glob.glob(os.path.join(test_meeting, "shared-*.tar"))) - ] - cuts_test_meeting_webdataset = CutSet.from_webdataset( - test_meeting_shards, - split_by_worker=True, - split_by_node=True, - shuffle_shards=True, - ) - - dev_dl = wenetspeech.valid_dataloaders(cuts_dev_webdataset) - test_net_dl = wenetspeech.test_dataloaders(cuts_test_net_webdataset) - test_meeting_dl = wenetspeech.test_dataloaders(cuts_test_meeting_webdataset) + test_meeting_cuts = wenetspeech.test_meeting_cuts() + test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts) test_sets = ["DEV", "TEST_NET", "TEST_MEETING"] test_dl = [dev_dl, test_net_dl, test_meeting_dl] diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py index 32d5738b1..de12b2ff0 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py @@ -661,83 +661,18 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - # Note: Please use "pip install webdataset==0.1.103" - # for installing the webdataset. - import glob - import os - - from lhotse import CutSet - from lhotse.dataset.webdataset import export_to_webdataset - # we need cut ids to display recognition results. args.return_cuts = True wenetspeech = WenetSpeechAsrDataModule(args) - dev = "dev" - test_net = "test_net" - test_meeting = "test_meeting" + dev_cuts = wenetspeech.valid_cuts() + dev_dl = wenetspeech.valid_dataloaders(dev_cuts) - if not os.path.exists(f"{dev}/shared-0.tar"): - os.makedirs(dev) - dev_cuts = wenetspeech.valid_cuts() - export_to_webdataset( - dev_cuts, - output_path=f"{dev}/shared-%d.tar", - shard_size=300, - ) + test_net_cuts = wenetspeech.test_net_cuts() + test_net_dl = wenetspeech.test_dataloaders(test_net_cuts) - if not os.path.exists(f"{test_net}/shared-0.tar"): - os.makedirs(test_net) - test_net_cuts = wenetspeech.test_net_cuts() - export_to_webdataset( - test_net_cuts, - output_path=f"{test_net}/shared-%d.tar", - shard_size=300, - ) - - if not os.path.exists(f"{test_meeting}/shared-0.tar"): - os.makedirs(test_meeting) - test_meeting_cuts = wenetspeech.test_meeting_cuts() - export_to_webdataset( - test_meeting_cuts, - output_path=f"{test_meeting}/shared-%d.tar", - shard_size=300, - ) - - dev_shards = [ - str(path) for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar"))) - ] - cuts_dev_webdataset = CutSet.from_webdataset( - dev_shards, - split_by_worker=True, - split_by_node=True, - shuffle_shards=True, - ) - - test_net_shards = [ - str(path) for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar"))) - ] - cuts_test_net_webdataset = CutSet.from_webdataset( - test_net_shards, - split_by_worker=True, - split_by_node=True, - shuffle_shards=True, - ) - - test_meeting_shards = [ - str(path) - for path in sorted(glob.glob(os.path.join(test_meeting, "shared-*.tar"))) - ] - cuts_test_meeting_webdataset = CutSet.from_webdataset( - test_meeting_shards, - split_by_worker=True, - split_by_node=True, - shuffle_shards=True, - ) - - dev_dl = wenetspeech.valid_dataloaders(cuts_dev_webdataset) - test_net_dl = wenetspeech.test_dataloaders(cuts_test_net_webdataset) - test_meeting_dl = wenetspeech.test_dataloaders(cuts_test_meeting_webdataset) + test_meeting_cuts = wenetspeech.test_meeting_cuts() + test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts) test_sets = ["DEV", "TEST_NET", "TEST_MEETING"] test_dl = [dev_dl, test_net_dl, test_meeting_dl] From f260a09ed4c899561ef3c367adcf7c3400f6dbdb Mon Sep 17 00:00:00 2001 From: Peng He <34941901+kamirdin@users.noreply.github.com> Date: Fri, 24 Mar 2023 14:30:43 +0800 Subject: [PATCH 135/174] remove if-branch at downsample pad in zipformer for onnx-export compatibility (#965) --- .../ASR/pruned_transducer_stateless7/zipformer.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index 3959c0bb2..5b75b8d35 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -781,13 +781,12 @@ class AttentionDownsample(torch.nn.Module): ds = self.downsample d_seq_len = (seq_len + ds - 1) // ds - # Pad to an exact multiple of self.downsample - if seq_len != d_seq_len * ds: - # right-pad src, repeating the last element. - pad = d_seq_len * ds - seq_len - src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2]) - src = torch.cat((src, src_extra), dim=0) - assert src.shape[0] == d_seq_len * ds, (src.shape[0], d_seq_len, ds) + # Pad to an exact multiple of self.downsample, could be 0 for onnx-export-compatibility + # right-pad src, repeating the last element. + pad = d_seq_len * ds - seq_len + src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2]) + src = torch.cat((src, src_extra), dim=0) + assert src.shape[0] == d_seq_len * ds, (src.shape[0], d_seq_len, ds) src = src.reshape(d_seq_len, ds, batch_size, in_channels) scores = (src * self.query).sum(dim=-1, keepdim=True) From 7155769c1929d457bd74afbe947f77e14e5bd3db Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Fri, 24 Mar 2023 15:30:29 +0800 Subject: [PATCH 136/174] minor fix, remove numel = p.numel() in optim.py (#967) --- egs/librispeech/ASR/pruned_transducer_stateless7/optim.py | 1 - 1 file changed, 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py index b84e518d0..aa3cef338 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py @@ -282,7 +282,6 @@ class ScaledAdam(BatchedOptimizer): batch_size = p.shape[0] numel = p.numel() // batch_size - numel = p.numel() if numel > 1: # "param_rms" just periodically records the scalar root-mean-square value of From 8c3ea93fc8a9be5dcac2bd0ad0d8e34cc13f6dd3 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 27 Mar 2023 11:39:29 +0800 Subject: [PATCH 137/174] Save meta data to exported ONNX models (#968) --- .../ASR/pruned_transducer_stateless3/export-onnx.py | 10 ++++++++++ .../ASR/pruned_transducer_stateless5/export-onnx.py | 10 ++++++++++ .../ASR/pruned_transducer_stateless7/export-onnx.py | 10 ++++++++++ 3 files changed, 30 insertions(+) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py index ca8be307c..36e57e946 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py @@ -273,6 +273,16 @@ def export_encoder_model_onnx( }, ) + meta_data = { + "model_type": "conformer", + "version": "1", + "model_author": "k2-fsa", + "comment": "stateless3", + } + logging.info(f"meta_data: {meta_data}") + + add_meta_data(filename=encoder_filename, meta_data=meta_data) + def export_decoder_model_onnx( decoder_model: OnnxDecoder, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx.py index 743fe8a92..3d94760dc 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx.py @@ -296,6 +296,16 @@ def export_encoder_model_onnx( }, ) + meta_data = { + "model_type": "conformer", + "version": "1", + "model_author": "k2-fsa", + "comment": "stateless5", + } + logging.info(f"meta_data: {meta_data}") + + add_meta_data(filename=encoder_filename, meta_data=meta_data) + def export_decoder_model_onnx( decoder_model: OnnxDecoder, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless7/export-onnx.py index f76915a74..93eb8df3d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/export-onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/export-onnx.py @@ -291,6 +291,16 @@ def export_encoder_model_onnx( }, ) + meta_data = { + "model_type": "zipformer", + "version": "1", + "model_author": "k2-fsa", + "comment": "stateless7", + } + logging.info(f"meta_data: {meta_data}") + + add_meta_data(filename=encoder_filename, meta_data=meta_data) + def export_decoder_model_onnx( decoder_model: OnnxDecoder, From 35e21a0d2ede47c632f97ae6d560194b66439472 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 27 Mar 2023 14:08:26 +0800 Subject: [PATCH 138/174] Fix torchscript export for aishell (#969) --- egs/aishell/ASR/pruned_transducer_stateless3/export.py | 2 ++ egs/aishell/ASR/pruned_transducer_stateless3/lstmp.py | 1 + .../ASR/pruned_transducer_stateless3/scaling_converter.py | 1 + 3 files changed, 4 insertions(+) create mode 120000 egs/aishell/ASR/pruned_transducer_stateless3/lstmp.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless3/scaling_converter.py diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/export.py b/egs/aishell/ASR/pruned_transducer_stateless3/export.py index 7f10eb36e..723414167 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless3/export.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/export.py @@ -48,6 +48,7 @@ import logging from pathlib import Path import torch +from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( @@ -244,6 +245,7 @@ def main(): model.eval() if params.jit: + convert_scaled_to_non_scaled(model, inplace=True) # We won't use the forward() method of the model in C++, so just ignore # it here. # Otherwise, one of its arguments is a ragged tensor and is not diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/lstmp.py b/egs/aishell/ASR/pruned_transducer_stateless3/lstmp.py new file mode 120000 index 000000000..557e18aa1 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless3/lstmp.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless3/lstmp.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/scaling_converter.py b/egs/aishell/ASR/pruned_transducer_stateless3/scaling_converter.py new file mode 120000 index 000000000..db93d155b --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless3/scaling_converter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py \ No newline at end of file From 15d48e3a6af6cd1626ec842686ffe76f928e158f Mon Sep 17 00:00:00 2001 From: PF Luo Date: Tue, 28 Mar 2023 19:14:08 +0800 Subject: [PATCH 139/174] fix rnn_lm && transformer_lm import problem (#971) --- icefall/rnn_lm/__init__.py | 0 icefall/transformer_lm/__init__.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 icefall/rnn_lm/__init__.py create mode 100644 icefall/transformer_lm/__init__.py diff --git a/icefall/rnn_lm/__init__.py b/icefall/rnn_lm/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/icefall/transformer_lm/__init__.py b/icefall/transformer_lm/__init__.py new file mode 100644 index 000000000..e69de29bb From bcc5923ab92c93bf38829f7d5de84d84c9050eb1 Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Tue, 28 Mar 2023 23:24:24 +0800 Subject: [PATCH 140/174] Support batch-wise forced-alignment (#970) * support batch-wise forced-alignment based on beam search * add length_norm to HypothesisList.topk() * Use Hypothesis and HypothesisList instead --- .../beam_search.py | 17 +- .../pruned_transducer_stateless7/alignment.py | 206 +++++++++++ .../compute_ali.py | 345 ++++++++++++++++++ .../test_compute_ali.py | 130 +++++++ icefall/utils.py | 2 +- 5 files changed, 696 insertions(+), 4 deletions(-) create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless7/alignment.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7/compute_ali.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7/test_compute_ali.py diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index bd2d6e258..999d793a4 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -829,11 +829,22 @@ class HypothesisList(object): ans.add(hyp) # shallow copy return ans - def topk(self, k: int) -> "HypothesisList": - """Return the top-k hypothesis.""" + def topk(self, k: int, length_norm: bool = False) -> "HypothesisList": + """Return the top-k hypothesis. + + Args: + length_norm: + If True, the `log_prob` of a hypothesis is normalized by the + number of tokens in it. + """ hyps = list(self._data.items()) - hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k] + if length_norm: + hyps = sorted( + hyps, key=lambda h: h[1].log_prob / len(h[1].ys), reverse=True + )[:k] + else: + hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k] ans = HypothesisList(dict(hyps)) return ans diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/alignment.py b/egs/librispeech/ASR/pruned_transducer_stateless7/alignment.py new file mode 100644 index 000000000..76cd56bbb --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/alignment.py @@ -0,0 +1,206 @@ +# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import List + +import k2 +import torch + +from beam_search import Hypothesis, HypothesisList, get_hyps_shape + +# The force alignment problem can be formulated as finding +# a path in a rectangular lattice, where the path starts +# from the lower left corner and ends at the upper right +# corner. The horizontal axis of the lattice is `t` (representing +# acoustic frame indexes) and the vertical axis is `u` (representing +# BPE tokens of the transcript). +# +# The notations `t` and `u` are from the paper +# https://arxiv.org/pdf/1211.3711.pdf +# +# Beam search is used to find the path with the highest log probabilities. +# +# It assumes the maximum number of symbols that can be +# emitted per frame is 1. + + +def batch_force_alignment( + model: torch.nn.Module, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_list: List[List[int]], + beam_size: int = 4, +) -> List[int]: + """Compute the force alignment of a batch of utterances given their transcripts + in BPE tokens and the corresponding acoustic output from the encoder. + + Caution: + This function is modified from `modified_beam_search` in beam_search.py. + We assume that the maximum number of sybmols per frame is 1. + + Args: + model: + The transducer model. + encoder_out: + A tensor of shape (N, T, C). + encoder_out_lens: + A 1-D tensor of shape (N,), containing number of valid frames in + encoder_out before padding. + ys_list: + A list of BPE token IDs list. We require that for each utterance i, + len(ys_list[i]) <= encoder_out_lens[i]. + beam_size: + Size of the beam used in beam search. + + Returns: + Return a list of frame indexes list for each utterance i, + where len(ans[i]) == len(ys_list[i]). + """ + assert encoder_out.ndim == 3, encoder_out.ndim + assert encoder_out.size(0) == len(ys_list), (encoder_out.size(0), len(ys_list)) + assert encoder_out.size(0) > 0, encoder_out.size(0) + + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + device = next(model.parameters()).device + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + sorted_indices = packed_encoder_out.sorted_indices.tolist() + encoder_out_lens = encoder_out_lens.tolist() + ys_lens = [len(ys) for ys in ys_list] + sorted_encoder_out_lens = [encoder_out_lens[i] for i in sorted_indices] + sorted_ys_lens = [ys_lens[i] for i in sorted_indices] + sorted_ys_list = [ys_list[i] for i in sorted_indices] + + B = [HypothesisList() for _ in range(N)] + for i in range(N): + B[i].add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + timestamp=[], + ) + ) + + encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + + offset = 0 + finalized_B = [] + for (t, batch_size) in enumerate(batch_size_list): + start = offset + end = offset + batch_size + current_encoder_out = encoder_out.data[start:end] + current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) + offset = end + + finalized_B = B[batch_size:] + finalized_B + B = B[:batch_size] + sorted_encoder_out_lens = sorted_encoder_out_lens[:batch_size] + sorted_ys_lens = sorted_ys_lens[:batch_size] + + hyps_shape = get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.cat( + [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] + ) # (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) + + # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor + # as index, so we use `to(torch.int64)` below. + current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, decoder_out, project_input=False + ) # (num_hyps, 1, 1, vocab_size) + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) + + log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs.reshape(-1) + ) # [batch][num_hyps*vocab_size] + + for i in range(batch_size): + for h, hyp in enumerate(A[i]): + pos_u = len(hyp.timestamp) + idx_offset = h * vocab_size + if (sorted_encoder_out_lens[i] - 1 - t) >= (sorted_ys_lens[i] - pos_u): + # emit blank token + new_hyp = Hypothesis( + log_prob=ragged_log_probs[i][idx_offset + blank_id], + ys=hyp.ys[:], + timestamp=hyp.timestamp[:], + ) + B[i].add(new_hyp) + if pos_u < sorted_ys_lens[i]: + # emit non-blank token + new_token = sorted_ys_list[i][pos_u] + new_hyp = Hypothesis( + log_prob=ragged_log_probs[i][idx_offset + new_token], + ys=hyp.ys + [new_token], + timestamp=hyp.timestamp + [t], + ) + B[i].add(new_hyp) + + if len(B[i]) > beam_size: + B[i] = B[i].topk(beam_size, length_norm=True) + + B = B + finalized_B + sorted_hyps = [b.get_most_probable() for b in B] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + hyps = [sorted_hyps[i] for i in unsorted_indices] + ans = [] + for i, hyp in enumerate(hyps): + assert hyp.ys[context_size:] == ys_list[i], (hyp.ys[context_size:], ys_list[i]) + ans.append(hyp.timestamp) + + return ans diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/compute_ali.py b/egs/librispeech/ASR/pruned_transducer_stateless7/compute_ali.py new file mode 100755 index 000000000..8bcb56d62 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/compute_ali.py @@ -0,0 +1,345 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +The script gets forced-alignments based on the modified_beam_search decoding method. +Both token-level alignments and word-level alignments are saved to the new cuts manifests. + +It loads a checkpoint and uses it to get the forced-alignments. +You can generate the checkpoint with the following command: + +./pruned_transducer_stateless7/export.py \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 9 + +Usage of this script: + +./pruned_transducer_stateless7/compute_ali.py \ + --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ + --bpe-model data/lang_bpe_500/bpe.model \ + --dataset test-clean \ + --max-duration 300 \ + --beam-size 4 \ + --cuts-out-dir data/fbank_ali_beam_search +""" + + +import argparse +import logging +from pathlib import Path +from typing import List, Tuple + +import sentencepiece as spm +import torch +import torch.nn as nn +from alignment import batch_force_alignment +from asr_datamodule import LibriSpeechAsrDataModule +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.utils import AttributeDict, convert_timestamp, parse_timestamp +from lhotse import CutSet +from lhotse.serialization import SequentialJsonlWriter +from lhotse.supervision import AlignmentItem + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--dataset", + type=str, + required=True, + help="""The name of the dataset to compute alignments for. + Possible values are: + - test-clean + - test-other + - train-clean-100 + - train-clean-360 + - train-other-500 + - dev-clean + - dev-other + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--cuts-out-dir", + type=str, + default="data/fbank_ali_beam_search", + help="The dir to save the new cuts manifests with alignments", + ) + + add_model_arguments(parser) + + return parser + + +def align_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, +) -> Tuple[List[List[str]], List[List[str]], List[List[float]], List[List[float]]]: + """Get forced-alignments for one batch. + + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + + Returns: + token_list: + A list of token list. + word_list: + A list of word list. + token_time_list: + A list of timestamps list for tokens. + word_time_list. + A list of timestamps list for words. + + where len(token_list) == len(word_list) == len(token_time_list) == len(word_time_list), + len(token_list[i]) == len(token_time_list[i]), + and len(word_list[i]) == len(word_time_list[i]) + + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + + texts = supervisions["text"] + ys_list: List[List[int]] = sp.encode(texts, out_type=int) + + frame_indexes = batch_force_alignment( + model, encoder_out, encoder_out_lens, ys_list, params.beam_size + ) + + token_list = [] + word_list = [] + token_time_list = [] + word_time_list = [] + for i in range(encoder_out.size(0)): + tokens = sp.id_to_piece(ys_list[i]) + words = texts[i].split() + token_time = convert_timestamp( + frame_indexes[i], params.subsampling_factor, params.frame_shift_ms + ) + word_time = parse_timestamp(tokens, token_time) + assert len(word_time) == len(words), (len(word_time), len(words)) + + token_list.append(tokens) + word_list.append(words) + token_time_list.append(token_time) + word_time_list.append(word_time) + + return token_list, word_list, token_time_list, word_time_list + + +def align_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + writer: SequentialJsonlWriter, +) -> None: + """Get forced-alignments for the dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + writer: + Writer to save the cuts with alignments. + """ + log_interval = 20 + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + for batch_idx, batch in enumerate(dl): + token_list, word_list, token_time_list, word_time_list = align_one_batch( + params=params, model=model, sp=sp, batch=batch + ) + + cut_list = batch["supervisions"]["cut"] + for cut, token, word, token_time, word_time in zip( + cut_list, token_list, word_list, token_time_list, word_time_list + ): + assert len(cut.supervisions) == 1, f"{len(cut.supervisions)}" + token_ali = [ + AlignmentItem( + symbol=token[i], + start=round(token_time[i], ndigits=3), + duration=None, + ) + for i in range(len(token)) + ] + word_ali = [ + AlignmentItem( + symbol=word[i], start=round(word_time[i], ndigits=3), duration=None + ) + for i in range(len(word)) + ] + cut.supervisions[0].alignment = {"word": word_ali, "token": token_ali} + writer.write(cut, flush=True) + + num_cuts += len(cut_list) + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + if params.dataset == "test-clean": + test_clean_cuts = librispeech.test_clean_cuts() + dl = librispeech.test_dataloaders(test_clean_cuts) + elif params.dataset == "test-other": + test_other_cuts = librispeech.test_other_cuts() + dl = librispeech.test_dataloaders(test_other_cuts) + elif params.dataset == "train-clean-100": + train_clean_100_cuts = librispeech.train_clean_100_cuts() + dl = librispeech.train_dataloaders(train_clean_100_cuts) + elif params.dataset == "train-clean-360": + train_clean_360_cuts = librispeech.train_clean_360_cuts() + dl = librispeech.train_dataloaders(train_clean_360_cuts) + elif params.dataset == "train-other-500": + train_other_500_cuts = librispeech.train_other_500_cuts() + dl = librispeech.train_dataloaders(train_other_500_cuts) + elif params.dataset == "dev-clean": + dev_clean_cuts = librispeech.dev_clean_cuts() + dl = librispeech.valid_dataloaders(dev_clean_cuts) + else: + assert params.dataset == "dev-other", f"{params.dataset}" + dev_other_cuts = librispeech.dev_other_cuts() + dl = librispeech.valid_dataloaders(dev_other_cuts) + + cuts_out_dir = Path(params.cuts_out_dir) + cuts_out_dir.mkdir(parents=True, exist_ok=True) + cuts_out_path = cuts_out_dir / f"librispeech_cuts_{params.dataset}.jsonl.gz" + + with CutSet.open_writer(cuts_out_path) as writer: + align_dataset(dl=dl, params=params, model=model, sp=sp, writer=writer) + + logging.info( + f"For dataset {params.dataset}, the cut manifest with framewise token alignments " + f"and word alignments are saved to {cuts_out_path}" + ) + logging.info("Done!") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/test_compute_ali.py b/egs/librispeech/ASR/pruned_transducer_stateless7/test_compute_ali.py new file mode 100755 index 000000000..081f7ba1a --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/test_compute_ali.py @@ -0,0 +1,130 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This script compares the word-level alignments generated based on modified_beam_search decoding +(in ./pruned_transducer_stateless7/compute_ali.py) to the reference alignments generated +by torchaudio framework (in ./add_alignments.sh). + +Usage: + +./pruned_transducer_stateless7/compute_ali.py \ + --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ + --bpe-model data/lang_bpe_500/bpe.model \ + --dataset test-clean \ + --max-duration 300 \ + --beam-size 4 \ + --cuts-out-dir data/fbank_ali_beam_search + +And the you can run: + +./pruned_transducer_stateless7/test_compute_ali.py \ + --cuts-out-dir ./data/fbank_ali_test \ + --cuts-ref-dir ./data/fbank_ali_torch \ + --dataset train-clean-100 +""" +import argparse +import logging +from pathlib import Path + +import torch +from lhotse import load_manifest + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--cuts-out-dir", + type=Path, + default="./data/fbank_ali", + help="The dir that saves the generated cuts manifests with alignments", + ) + + parser.add_argument( + "--cuts-ref-dir", + type=Path, + default="./data/fbank_ali_torch", + help="The dir that saves the reference cuts manifests with alignments", + ) + + parser.add_argument( + "--dataset", + type=str, + required=True, + help="""The name of the dataset: + Possible values are: + - test-clean + - test-other + - train-clean-100 + - train-clean-360 + - train-other-500 + - dev-clean + - dev-other + """, + ) + + return parser + + +def main(): + args = get_parser().parse_args() + + cuts_out_jsonl = args.cuts_out_dir / f"librispeech_cuts_{args.dataset}.jsonl.gz" + cuts_ref_jsonl = args.cuts_ref_dir / f"librispeech_cuts_{args.dataset}.jsonl.gz" + + logging.info(f"Loading {cuts_out_jsonl} and {cuts_ref_jsonl}") + cuts_out = load_manifest(cuts_out_jsonl) + cuts_ref = load_manifest(cuts_ref_jsonl) + cuts_ref = cuts_ref.sort_like(cuts_out) + + all_time_diffs = [] + for cut_out, cut_ref in zip(cuts_out, cuts_ref): + time_out = [ + ali.start + for ali in cut_out.supervisions[0].alignment["word"] + if ali.symbol != "" + ] + time_ref = [ + ali.start + for ali in cut_ref.supervisions[0].alignment["word"] + if ali.symbol != "" + ] + assert len(time_out) == len(time_ref), (len(time_out), len(time_ref)) + diff = [ + round(abs(out - ref), ndigits=3) for out, ref in zip(time_out, time_ref) + ] + all_time_diffs += diff + + all_time_diffs = torch.tensor(all_time_diffs) + logging.info( + f"For the word-level alignments abs difference on dataset {args.dataset}, " + f"mean: {'%.2f' % all_time_diffs.mean()}s, std: {'%.2f' % all_time_diffs.std()}s" + ) + logging.info("Done!") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/icefall/utils.py b/icefall/utils.py index 5d86472b5..1fd9156bd 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -1378,7 +1378,7 @@ def parse_timestamp(tokens: List[str], timestamp: List[float]) -> List[float]: List of timestamp of each word. """ start_token = b"\xe2\x96\x81".decode() # '_' - assert len(tokens) == len(timestamp) + assert len(tokens) == len(timestamp), (len(tokens), len(timestamp)) ans = [] for i in range(len(tokens)): flag = False From 2a5a75cb5655126c5b88a370f848cb87f2e491ac Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Thu, 30 Mar 2023 14:30:13 +0800 Subject: [PATCH 141/174] add option of using full attention for streaming model decoding (#975) --- .../ASR/pruned_transducer_stateless/decode.py | 13 ++++++++++++- .../pruned_transducer_stateless2/conformer.py | 5 +++++ .../ASR/pruned_transducer_stateless2/decode.py | 16 +++++++++------- .../ASR/pruned_transducer_stateless3/decode.py | 16 +++++++++------- .../ASR/pruned_transducer_stateless4/decode.py | 18 ++++++++++-------- .../ASR/pruned_transducer_stateless5/decode.py | 16 +++++++++------- .../ASR/transducer_stateless/conformer.py | 5 +++++ 7 files changed, 59 insertions(+), 30 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py index 6dfe11cee..3c4500087 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py @@ -107,6 +107,7 @@ Usage: import argparse import logging +import math from collections import defaultdict from pathlib import Path from typing import Dict, List, Optional, Tuple @@ -138,6 +139,8 @@ from icefall.utils import ( write_error_stats, ) +LOG_EPS = math.log(1e-10) + def get_parser(): parser = argparse.ArgumentParser( @@ -288,7 +291,7 @@ def get_parser(): "--decode-chunk-size", type=int, default=16, - help="The chunk size for decoding (in frames after subsampling)", + help="The chunk size for decoding (in frames after subsampling). Set -1 to use full attention.", ) parser.add_argument( "--left-context", @@ -370,6 +373,14 @@ def decode_one_batch( feature_lens = supervisions["num_frames"].to(device) if params.simulate_streaming: + if params.decode_chunk_size > 0: + # except the case of using full attention + feature_lens += params.left_context + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.left_context), + value=LOG_EPS, + ) encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( x=feature, x_lens=feature_lens, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index f94ffef59..9bac46004 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -375,6 +375,11 @@ class Conformer(EncoderInterface): assert x.size(0) == lengths.max().item() + if chunk_size < 0: + # use full attention + chunk_size = x.size(0) + left_context = -1 + num_left_chunks = -1 if left_context >= 0: assert left_context % chunk_size == 0 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py index 172c9ab7c..c57514193 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py @@ -295,7 +295,7 @@ def get_parser(): "--decode-chunk-size", type=int, default=16, - help="The chunk size for decoding (in frames after subsampling)", + help="The chunk size for decoding (in frames after subsampling). Set -1 to use full attention.", ) parser.add_argument( @@ -378,12 +378,14 @@ def decode_one_batch( feature_lens = supervisions["num_frames"].to(device) if params.simulate_streaming: - feature_lens += params.left_context - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, params.left_context), - value=LOG_EPS, - ) + if params.decode_chunk_size > 0: + # except the case of using full attention + feature_lens += params.left_context + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.left_context), + value=LOG_EPS, + ) encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( x=feature, x_lens=feature_lens, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py index aa055049e..b39007dfc 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py @@ -344,7 +344,7 @@ def get_parser(): "--decode-chunk-size", type=int, default=16, - help="The chunk size for decoding (in frames after subsampling)", + help="The chunk size for decoding (in frames after subsampling). Set -1 to use full attention.", ) parser.add_argument( @@ -508,12 +508,14 @@ def decode_one_batch( feature_lens = supervisions["num_frames"].to(device) if params.simulate_streaming: - feature_lens += params.left_context - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, params.left_context), - value=LOG_EPS, - ) + if params.decode_chunk_size > 0: + # except the case of using full attention + feature_lens += params.left_context + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.left_context), + value=LOG_EPS, + ) encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( x=feature, x_lens=feature_lens, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py index 5ec3d3b45..79d919ab1 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py @@ -326,14 +326,14 @@ def get_parser(): "--decode-chunk-size", type=int, default=16, - help="The chunk size for decoding (in frames after subsampling)", + help="The chunk size for decoding (in frames after subsampling). Set -1 to use full attention.", ) parser.add_argument( "--left-context", type=int, default=64, - help="left context can be seen during decoding (in frames after subsampling)", # noqa + help="""Left context can be seen during decoding (in frames after subsampling). """, # noqa ) parser.add_argument( @@ -409,12 +409,14 @@ def decode_one_batch( feature_lens = supervisions["num_frames"].to(device) if params.simulate_streaming: - feature_lens += params.left_context - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, params.left_context), - value=LOG_EPS, - ) + if params.decode_chunk_size > 0: + # except the case of using full attention + feature_lens += params.left_context + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.left_context), + value=LOG_EPS, + ) encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( x=feature, x_lens=feature_lens, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py index 2be895feb..af0b2d9fc 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py @@ -291,7 +291,7 @@ def get_parser(): "--decode-chunk-size", type=int, default=16, - help="The chunk size for decoding (in frames after subsampling)", + help="The chunk size for decoding (in frames after subsampling). Set -1 to use full attention.", ) parser.add_argument( @@ -470,12 +470,14 @@ def decode_one_batch( feature_lens = supervisions["num_frames"].to(device) if params.simulate_streaming: - feature_lens += params.left_context - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, params.left_context), - value=LOG_EPS, - ) + if params.decode_chunk_size > 0: + # except the case of using full attention + feature_lens += params.left_context + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.left_context), + value=LOG_EPS, + ) encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( x=feature, x_lens=feature_lens, diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 01e8c5b21..94d0393c2 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -358,6 +358,11 @@ class Conformer(Transformer): assert x.size(0) == lengths.max().item() + if chunk_size < 0: + # use full attention + chunk_size = x.size(0) + left_context = -1 + num_left_chunks = -1 if left_context >= 0: assert left_context % chunk_size == 0 From c21b6a208b7919ae867d2c833f8be4ff9386e47b Mon Sep 17 00:00:00 2001 From: marcoyang1998 <45973641+marcoyang1998@users.noreply.github.com> Date: Thu, 30 Mar 2023 17:08:46 +0800 Subject: [PATCH 142/174] Add finetuning script for aishell (#974) * add aishell finetune scripts * add an example bash script --- egs/wenetspeech/ASR/finetune.sh | 82 ++ .../pruned_transducer_stateless2/aishell.py | 1 + .../decode_aishell.py | 547 +++++++++ .../pruned_transducer_stateless2/finetune.py | 1050 +++++++++++++++++ 4 files changed, 1680 insertions(+) create mode 100755 egs/wenetspeech/ASR/finetune.sh create mode 120000 egs/wenetspeech/ASR/pruned_transducer_stateless2/aishell.py create mode 100755 egs/wenetspeech/ASR/pruned_transducer_stateless2/decode_aishell.py create mode 100755 egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py diff --git a/egs/wenetspeech/ASR/finetune.sh b/egs/wenetspeech/ASR/finetune.sh new file mode 100755 index 000000000..8559780e9 --- /dev/null +++ b/egs/wenetspeech/ASR/finetune.sh @@ -0,0 +1,82 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + +stage=-1 +stop_stage=100 + +# This is an example script for fine-tuning. Here, we fine-tune a model trained +# on WenetSpeech on Aishell. The model used for fine-tuning is +# pruned_transducer_stateless2 (zipformer). If you want to fine-tune model +# from another recipe, you can adapt ./pruned_transducer_stateless2/finetune.py +# for that recipe. If you have any problem, please open up an issue in https://github.com/k2-fsa/icefall/issues. + +# We assume that you have already prepared the Aishell manfiest&features under ./data. +# If you haven't done that, please see https://github.com/k2-fsa/icefall/blob/master/egs/aishell/ASR/prepare.sh. + +. shared/parse_options.sh || exit 1 + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then + log "Stage -1: Download Pre-trained model" + + # clone from huggingface + git lfs install + git clone https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2 + +fi + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Start fine-tuning" + + # The following configuration of lr schedule should work well + # You may also tune the following parameters to adjust learning rate schedule + initial_lr=0.0001 + lr_epochs=100 + lr_batches=100000 + + # We recommend to start from an averaged model + finetune_ckpt=icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/pretrained_epoch_10_avg_2.pt + lang_dir=icefall_asr_wenetspeech_pruned_transducer_stateless2/data/lang_char + export CUDA_VISIBLE_DEVICES="0,1" + + ./pruned_transducer_stateless2/finetune.py \ + --world-size 2 \ + --master-port 18180 \ + --num-epochs 15 \ + --context-size 2 \ + --exp-dir pruned_transducer_stateless2/exp_aishell_finetune \ + --initial-lr $initial_lr \ + --lr-epochs $lr_epochs \ + --lr-batches $lr_batches \ + --lang-dir $lang_dir \ + --do-finetune True \ + --finetune-ckpt $finetune_ckpt \ + --max-duration 200 +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Decoding" + + epoch=4 + avg=4 + + for m in greedy_search modified_beam_search; do + python pruned_transducer_stateless2/decode_aishell.py \ + --epoch $epoch \ + --avg $avg \ + --context-size 2 \ + --beam-size 4 \ + --exp-dir pruned_transducer_stateless2/exp_aishell_finetune \ + --max-duration 400 \ + --decoding-method $m + done +fi diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/aishell.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/aishell.py new file mode 120000 index 000000000..f7321272b --- /dev/null +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/aishell.py @@ -0,0 +1 @@ +../../../aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode_aishell.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode_aishell.py new file mode 100755 index 000000000..2e644ec2f --- /dev/null +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode_aishell.py @@ -0,0 +1,547 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./pruned_transducer_stateless2/decode.py \ + --epoch 84 \ + --avg 25 \ + --exp-dir ./pruned_transducer_stateless2/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./pruned_transducer_stateless2/decode.py \ + --epoch 84 \ + --avg 25 \ + --exp-dir ./pruned_transducer_stateless2/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./pruned_transducer_stateless2/decode.py \ + --epoch 84 \ + --avg 25 \ + --exp-dir ./pruned_transducer_stateless2/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search +./pruned_transducer_stateless2/decode.py \ + --epoch 84 \ + --avg 25 \ + --exp-dir ./pruned_transducer_stateless2/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 4 \ + --max-contexts 4 \ + --max-states 8 +""" + + +import argparse +import logging +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import torch +import torch.nn as nn +from aishell import AishellAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from finetune import get_params, get_transducer_model + +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + write_error_stats, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless2/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_char", + help="The lang dir", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=1, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + token_table: k2.SymbolTable, + batch: dict, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + token_table: + It maps token ID to a string. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + else: + hyp_tokens = [] + batch_size = encoder_out.size(0) + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyp_tokens.append(hyp) + + hyps = [[token_table[t] for t in tokens] for tokens in hyp_tokens] + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif params.decoding_method == "fast_beam_search": + return { + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps + } + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + token_table: k2.SymbolTable, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + token_table: + It maps a token ID to a string. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + token_table=token_table, + decoding_graph=decoding_graph, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + # we compute CER for aishell dataset. + results_char = [] + for res in results: + results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results_char, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + AishellAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + lexicon = Lexicon(params.lang_dir) + params.blank_id = 0 + params.vocab_size = max(lexicon.tokens) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + + model.to(device) + model.eval() + + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + aishell = AishellAsrDataModule(args) + test_cuts = aishell.test_cuts() + dev_cuts = aishell.valid_cuts() + test_dl = aishell.test_dataloaders(test_cuts) + dev_dl = aishell.test_dataloaders(dev_cuts) + + test_sets = ["test", "dev"] + test_dls = [test_dl, dev_dl] + + for test_set, test_dl in zip(test_sets, test_dls): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + token_table=lexicon.token_table, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py new file mode 100755 index 000000000..e703100a9 --- /dev/null +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py @@ -0,0 +1,1050 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Xiaoyu Yang, +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless2/finetune.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless2/exp \ + --full-libri 1 \ + --do-finetune 1 \ + --max-duration 100 + +""" + + +import argparse +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union + +import k2 +import optim +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from aishell import AishellAsrDataModule +from conformer import Conformer +from decoder import Decoder +from joiner import Joiner +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from optim import Eden, Eve +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter + +from icefall import diagnostics +from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import save_checkpoint_with_global_batch_idx +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.lexicon import Lexicon +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def add_finetune_arguments(parser: argparse.ArgumentParser): + parser.add_argument("--do-finetune", type=str2bool, default=False) + + parser.add_argument( + "--init-modules", + type=str, + default=None, + help=""" + Modules to be initialized. It matches all parameters starting with + a specific key. The keys are given with Comma seperated. If None, + all modules will be initialised. For example, if you only want to + initialise all parameters staring with "encoder", use "encoder"; + if you want to initialise parameters starting with encoder or decoder, + use "encoder,joiner". + """, + ) + + parser.add_argument( + "--finetune-ckpt", + type=str, + default=None, + help="Fine-tuning from which checkpoint (a path to a .pt file)", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=0, + help="""Resume training from from this epoch. + If it is positive, it will load checkpoint from + pruned_transducer_stateless2/exp/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless2/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_char", + help="""The lang dir + It contains language related input files such as + "lexicon.txt" + """, + ) + + parser.add_argument( + "--initial-lr", + type=float, + default=0.0001, + help="The initial learning rate. This value should not need to be changed.", + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=100000, + help="""Number of steps that affects how rapidly the learning rate + decreases. During fine-tuning, we set this very large so that the + learning rate slowly decays with number of batches. You may tune + its value by yourself. + """, + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=100, + help="""Number of epochs that affects how rapidly the learning rate + decreases. During fine-tuning, we set this very large so that the + learning rate slowly decays with number of batches. You may tune + its value by yourself. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=1, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=2000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=20, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--valid-interval", + type=int, + default=3000, + help="""When training_subset is L, set the valid_interval to 3000. + When training_subset is M, set the valid_interval to 1000. + When training_subset is S, set the valid_interval to 400. + """, + ) + + parser.add_argument( + "--model-warm-step", + type=int, + default=3000, + help="""When training_subset is L, set the model_warm_step to 3000. + When training_subset is M, set the model_warm_step to 500. + When training_subset is S, set the model_warm_step to 100. + """, + ) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + Explanation of options saved in `params`: + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + - best_train_epoch: It is the epoch that has the best training loss. + - best_valid_epoch: It is the epoch that has the best validation loss. + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + - log_interval: Print training loss if batch_idx % log_interval` is 0 + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + - feature_dim: The model input dim. It has to match the one used + in computing features. + - subsampling_factor: The subsampling factor for the model. + - encoder_dim: Hidden dim for multi-head attention model. + - num_decoder_layers: Number of decoder layer of transformer decoder. + - warm_step: The warm_step for Noam optimizer. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + # parameters for conformer + "feature_dim": 80, + "subsampling_factor": 4, + "encoder_dim": 512, + "nhead": 8, + "dim_feedforward": 2048, + "num_encoder_layers": 12, + # parameters for decoder + "decoder_dim": 512, + # parameters for joiner + "joiner_dim": 512, + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Conformer and Transformer + encoder = Conformer( + num_features=params.feature_dim, + subsampling_factor=params.subsampling_factor, + d_model=params.encoder_dim, + nhead=params.nhead, + dim_feedforward=params.dim_feedforward, + num_encoder_layers=params.num_encoder_layers, + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=params.encoder_dim, + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=params.encoder_dim, + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is positive, it will load the checkpoint from + `params.start_epoch - 1`. + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 0: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def load_model_params( + ckpt: str, model: nn.Module, init_modules: List[str] = None, strict: bool = True +): + """Load model params from checkpoint + + Args: + ckpt (str): Path to the checkpoint + model (nn.Module): model to be loaded + + """ + logging.info(f"Loading checkpoint from {ckpt}") + checkpoint = torch.load(ckpt, map_location="cpu") + + # if module list is empty, load the whole model from ckpt + if not init_modules: + if next(iter(checkpoint["model"])).startswith("module."): + logging.info("Loading checkpoint saved by DDP") + + dst_state_dict = model.state_dict() + src_state_dict = checkpoint["model"] + for key in dst_state_dict.keys(): + src_key = "{}.{}".format("module", key) + dst_state_dict[key] = src_state_dict.pop(src_key) + assert len(src_state_dict) == 0 + model.load_state_dict(dst_state_dict, strict=strict) + else: + model.load_state_dict(checkpoint["model"], strict=strict) + else: + src_state_dict = checkpoint["model"] + dst_state_dict = model.state_dict() + for module in init_modules: + logging.info(f"Loading parameters starting with prefix {module}") + src_keys = [k for k in src_state_dict.keys() if k.startswith(module)] + dst_keys = [k for k in dst_state_dict.keys() if k.startswith(module)] + assert set(src_keys) == set(dst_keys) # two sets should match exactly + for key in src_keys: + dst_state_dict[key] = src_state_dict.pop(key) + + model.load_state_dict(dst_state_dict, strict=strict) + + return None + + +def save_checkpoint( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: nn.Module, + graph_compiler: CharCtcTrainingGraphCompiler, + batch: dict, + is_training: bool, + warmup: float = 1.0, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + texts = batch["supervisions"]["text"] + + y = graph_compiler.texts_to_ids(texts) + if isinstance(y, list): + y = k2.RaggedTensor(y).to(device) + else: + y = y.to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + warmup=warmup, + ) + # after the main warmup step, we keep pruned_loss_scale small + # for the same amount of time (model_warm_step), to avoid + # overwhelming the simple_loss and causing it to diverge, + # in case it had not fully learned the alignment yet. + pruned_loss_scale = ( + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: nn.Module, + graph_compiler: CharCtcTrainingGraphCompiler, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: nn.Module, + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + graph_compiler: CharCtcTrainingGraphCompiler, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=True, + warmup=(params.batch_idx_train / params.model_warm_step), + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}" + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + lexicon = Lexicon(params.lang_dir) + graph_compiler = CharCtcTrainingGraphCompiler( + lexicon=lexicon, + device=device, + ) + + params.blank_id = lexicon.token_table[""] + params.vocab_size = max(lexicon.tokens) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # load model parameters for model fine-tuning + if params.do_finetune: + modules = params.init_modules.split(",") if params.init_modules else None + checkpoints = load_model_params( + ckpt=params.finetune_ckpt, model=model, init_modules=modules + ) + else: + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available(params=params, model=model) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank]) + model.device = device + + optimizer = Eve(model.parameters(), lr=params.initial_lr) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + aishell = AishellAsrDataModule(args) + train_dl = aishell.train_dataloaders(aishell.train_cuts()) + valid_dl = aishell.valid_dataloaders(aishell.valid_cuts()) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + graph_compiler=graph_compiler, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + optimizer=optimizer, + scheduler=scheduler, + graph_compiler=graph_compiler, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + graph_compiler: CharCtcTrainingGraphCompiler, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = graph_compiler.texts_to_ids(supervisions["text"]) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + graph_compiler: CharCtcTrainingGraphCompiler, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + # warmup = 0.0 is so that the derivs for the pruned loss stay zero + # (i.e. are not remembered by the decaying-average in adam), because + # we want to avoid these params being subject to shrinkage in adam. + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=True, + warmup=0.0 if params.start_epoch == 1 else 1.0, + ) + loss.backward() + optimizer.step() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) + raise + + +def main(): + parser = get_parser() + AishellAsrDataModule.add_arguments( + parser + ) # you may replace this with your own dataset + add_finetune_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() From a632b24c353be0ac113a17a756156387fda5c7e0 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 31 Mar 2023 22:46:19 +0800 Subject: [PATCH 143/174] Export int8 quantized models for non-streaming Zipformer. (#977) * Export int8 quantized models for non-streaming Zipformer. * Delete export-onnx.py * Export int8 models for other folders --- .../export-onnx-zh.py | 632 ++++++++++++++++ .../lstm_transducer_stateless2/export-onnx.py | 35 +- .../export-onnx.py | 30 + .../export-onnx.py | 30 + .../export-onnx-zh.py | 678 ++++++++++++++++++ .../export-onnx.py | 30 + 6 files changed, 1434 insertions(+), 1 deletion(-) create mode 100755 egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py new file mode 100755 index 000000000..f068f6a0f --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py @@ -0,0 +1,632 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang) + +""" +This script exports a transducer model from PyTorch to ONNX. + +We use the pre-trained model from +https://huggingface.co/csukuangfj/icefall-asr-wenetspeech-lstm-transducer-stateless-2022-10-14 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-wenetspeech-lstm-transducer-stateless-2022-10-14 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lexicon.txt" +git lfs pull --include "data/L.pt" +git lfs pull --include "exp/epoch-11.pt" +git lfs pull --include "exp/epoch-10.pt" + +popd + +2. Export the model to ONNX + +./lstm_transducer_stateless2/export-onnx-zh.py \ + --lang-dir ./icefall-asr-wenetspeech-lstm-transducer-stateless-2022-10-14/data/lang_char \ + --use-averaged-model 1 \ + --epoch 11 \ + --avg 1 \ + --exp-dir ./icefall-asr-wenetspeech-lstm-transducer-stateless-2022-10-14/exp \ + --num-encoder-layers 12 \ + --encoder-dim 512 \ + --rnn-hidden-size 1024 + +It will generate the following files inside $repo/exp: + + - encoder-epoch-11-avg-1.onnx + - decoder-epoch-11-avg-1.onnx + - joiner-epoch-11-avg-1.onnx + - encoder-epoch-11-avg-1.int8.onnx + - decoder-epoch-11-avg-1.int8.onnx + - joiner-epoch-11-avg-1.int8.onnx + +See ./onnx_pretrained.py and ./onnx_check.py for how to +use the exported ONNX models. +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, Optional, Tuple + +import onnx +import torch +import torch.nn as nn +from decoder import Decoder +from lstm import RNN +from onnxruntime.quantization import QuantType, quantize_dynamic +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless5/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_char", + help="The lang dir", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +class OnnxEncoder(nn.Module): + """A wrapper for RNN and the encoder_proj from the joiner""" + + def __init__(self, encoder: RNN, encoder_proj: nn.Linear): + """ + Args: + encoder: + An RNN encoder. + encoder_proj: + The projection layer for encoder from the joiner. + """ + super().__init__() + self.encoder = encoder + self.encoder_proj = encoder_proj + + def forward( + self, + x: torch.Tensor, + states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Please see the help information of RNN.forward + + Args: + x: + A 3-D tensor of shape (N, T, C) + states: + A tuple of 2 tensors (optional). It is for streaming inference. + states[0] is the hidden states of all layers, + with shape of (num_layers, N, d_model); + states[1] is the cell states of all layers, + with shape of (num_layers, N, rnn_hidden_size). + Returns: + Return a tuple containing: + - encoder_out, A 3-D tensor of shape (N, T', joiner_dim) + - updated states, whose shape is the same as the input states. + """ + N = x.size(0) + T = x.size(1) + x_lens = torch.tensor([T] * N, dtype=torch.int64, device=x.device) + encoder_out, _, next_states = self.encoder(x, x_lens, states) + + encoder_out = self.encoder_proj(encoder_out) + # Now encoder_out is of shape (N, T, joiner_dim) + + return encoder_out, next_states + + +class OnnxDecoder(nn.Module): + """A wrapper for Decoder and the decoder_proj from the joiner""" + + def __init__(self, decoder: Decoder, decoder_proj: nn.Linear): + super().__init__() + self.decoder = decoder + self.decoder_proj = decoder_proj + + def forward(self, y: torch.Tensor) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, context_size). + Returns + Return a 2-D tensor of shape (N, joiner_dim) + """ + need_pad = False + decoder_output = self.decoder(y, need_pad=need_pad) + decoder_output = decoder_output.squeeze(1) + output = self.decoder_proj(decoder_output) + + return output + + +class OnnxJoiner(nn.Module): + """A wrapper for the joiner""" + + def __init__(self, output_linear: nn.Linear): + super().__init__() + self.output_linear = output_linear + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + logit = encoder_out + decoder_out + logit = self.output_linear(torch.tanh(logit)) + return logit + + +def export_encoder_model_onnx( + encoder_model: OnnxEncoder, + encoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the given encoder model to ONNX format. + The exported model has the following inputs: + + - x, a tensor of shape (N, T, C); dtype is torch.float32 + - state0, a tensor of shape (num_encoder_layers, batch_size, d_model) + - state1, a tensor of shape (num_encoder_layers, batch_size, rnn_hidden_size) + + and it has 3 outputs: + + - encoder_out, a tensor of shape (N, T', joiner_dim) + - new_state0, a tensor of shape (num_encoder_layers, batch_size, d_model) + - new_state1, a tensor of shape (num_encoder_layers, batch_size, rnn_hidden_size) + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + num_encoder_layers = encoder_model.encoder.num_encoder_layers + d_model = encoder_model.encoder.d_model + rnn_hidden_size = encoder_model.encoder.rnn_hidden_size + + decode_chunk_len = 4 + T = 9 + + x = torch.zeros(1, T, 80, dtype=torch.float32) + states = encoder_model.encoder.get_init_states() + # state0: (num_encoder_layers, batch_size, d_model) + # state1: (num_encoder_layers, batch_size, rnn_hidden_size) + + torch.onnx.export( + encoder_model, + (x, states), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "state0", "state1"], + output_names=["encoder_out", "new_state0", "new_state1"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "state0": {1: "N"}, + "state1": {1: "N"}, + "encoder_out": {0: "N"}, + "new_state0": {1: "N"}, + "new_state1": {1: "N"}, + }, + ) + + meta_data = { + "model_type": "lstm", + "version": "1", + "model_author": "k2-fsa", + "decode_chunk_len": str(decode_chunk_len), # 32 + "T": str(T), # 39 + "num_encoder_layers": str(num_encoder_layers), + "d_model": str(d_model), + "rnn_hidden_size": str(rnn_hidden_size), + } + logging.info(f"meta_data: {meta_data}") + + add_meta_data(filename=encoder_filename, meta_data=meta_data) + + +def export_decoder_model_onnx( + decoder_model: OnnxDecoder, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, decoder_model.context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, joiner_dim) + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + context_size = decoder_model.decoder.context_size + vocab_size = decoder_model.decoder.vocab_size + + y = torch.zeros(10, context_size, dtype=torch.int64) + torch.onnx.export( + decoder_model, + y, + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + + meta_data = { + "context_size": str(context_size), + "vocab_size": str(vocab_size), + } + add_meta_data(filename=decoder_filename, meta_data=meta_data) + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. + The exported joiner model has two inputs: + + - encoder_out: a tensor of shape (N, joiner_dim) + - decoder_out: a tensor of shape (N, joiner_dim) + + and produces one output: + + - logit: a tensor of shape (N, vocab_size) + """ + joiner_dim = joiner_model.output_linear.weight.shape[1] + logging.info(f"joiner dim: {joiner_dim}") + + projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + + torch.onnx.export( + joiner_model, + (projected_encoder_out, projected_decoder_out), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "encoder_out", + "decoder_out", + ], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + meta_data = { + "joiner_dim": str(joiner_dim), + } + add_meta_data(filename=joiner_filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + setup_logger(f"{params.exp_dir}/log-export/log-export-onnx") + + logging.info(f"device: {device}") + + lexicon = Lexicon(params.lang_dir) + params.blank_id = 0 + params.vocab_size = max(lexicon.tokens) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params, enable_giga=False) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ), + strict=False, + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ), + strict=False, + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True) + + encoder = OnnxEncoder( + encoder=model.encoder, + encoder_proj=model.joiner.encoder_proj, + ) + + decoder = OnnxDecoder( + decoder=model.decoder, + decoder_proj=model.joiner.decoder_proj, + ) + + joiner = OnnxJoiner(output_linear=model.joiner.output_linear) + + encoder_num_param = sum([p.numel() for p in encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + if params.iter > 0: + suffix = f"iter-{params.iter}" + else: + suffix = f"epoch-{params.epoch}" + + suffix += f"-avg-{params.avg}" + + opset_version = 13 + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" + export_encoder_model_onnx( + encoder, + encoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported encoder to {encoder_filename}") + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" + export_decoder_model_onnx( + decoder, + decoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported decoder to {decoder_filename}") + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" + export_joiner_model_onnx( + joiner, + joiner_filename, + opset_version=opset_version, + ) + logging.info(f"Exported joiner to {joiner_filename}") + + # Generate int8 quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + + logging.info("Generate int8 quantization models") + + encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=encoder_filename, + model_output=encoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=decoder_filename, + model_output=decoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" + quantize_dynamic( + model_input=joiner_filename, + model_output=joiner_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx.py index 46873ebf9..acaff8540 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx.py @@ -34,11 +34,14 @@ popd --avg 1 \ --exp-dir $repo/exp -It will generate the following 3 files inside $repo/exp: +It will generate the following files inside $repo/exp: - encoder-epoch-99-avg-1.onnx - decoder-epoch-99-avg-1.onnx - joiner-epoch-99-avg-1.onnx + - encoder-epoch-99-avg-1.int8.onnx + - decoder-epoch-99-avg-1.int8.onnx + - joiner-epoch-99-avg-1.int8.onnx See ./onnx_pretrained.py and ./onnx_check.py for how to use the exported ONNX models. @@ -55,6 +58,7 @@ import torch import torch.nn as nn from decoder import Decoder from lstm import RNN +from onnxruntime.quantization import QuantType, quantize_dynamic from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model @@ -586,6 +590,35 @@ def main(): ) logging.info(f"Exported joiner to {joiner_filename}") + # Generate int8 quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + + logging.info("Generate int8 quantization models") + + encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=encoder_filename, + model_output=encoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=decoder_filename, + model_output=decoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" + quantize_dynamic( + model_input=joiner_filename, + model_output=joiner_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + if __name__ == "__main__": formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py index 36e57e946..9645b7801 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py @@ -54,6 +54,7 @@ import torch import torch.nn as nn from conformer import Conformer from decoder import Decoder +from onnxruntime.quantization import QuantType, quantize_dynamic from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model @@ -500,6 +501,35 @@ def main(): ) logging.info(f"Exported joiner to {joiner_filename}") + # Generate int8 quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + + logging.info("Generate int8 quantization models") + + encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=encoder_filename, + model_output=encoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=decoder_filename, + model_output=decoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" + quantize_dynamic( + model_input=joiner_filename, + model_output=joiner_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + if __name__ == "__main__": formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless7/export-onnx.py index 93eb8df3d..2f5d9e338 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/export-onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/export-onnx.py @@ -55,6 +55,7 @@ import sentencepiece as spm import torch import torch.nn as nn from decoder import Decoder +from onnxruntime.quantization import QuantType, quantize_dynamic from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model from zipformer import Zipformer @@ -563,6 +564,35 @@ def main(): ) logging.info(f"Exported joiner to {joiner_filename}") + # Generate int8 quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + + logging.info("Generate int8 quantization models") + + encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=encoder_filename, + model_output=encoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=decoder_filename, + model_output=decoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" + quantize_dynamic( + model_input=joiner_filename, + model_output=joiner_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + if __name__ == "__main__": formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py new file mode 100755 index 000000000..04d97808d --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py @@ -0,0 +1,678 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang) + +""" +This script exports a transducer model from PyTorch to ONNX. + +We use the pre-trained model from +https://huggingface.co/pfluo/k2fsa-zipformer-chinese-english-mixed +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/pfluo/k2fsa-zipformer-chinese-english-mixed +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_char_bpe/L.pt" +git lfs pull --include "data/lang_char_bpe/Linv.pt" +git lfs pull --include "data/lang_char_bpe/L_disambig.pt" +git lfs pull --include "exp/pretrained.pt" +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +2. Export the model to ONNX + +./pruned_transducer_stateless7_streaming/export-onnx-zh.py \ + --lang-dir $repo/data/lang_char_bpe \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp/ \ + --decode-chunk-len 32 \ + --num-encoder-layers "2,4,3,2,4" \ + --feedforward-dims "1024,1024,1536,1536,1024" \ + --nhead "8,8,8,8,8" \ + --encoder-dims "384,384,384,384,384" \ + --attention-dims "192,192,192,192,192" \ + --encoder-unmasked-dims "256,256,256,256,256" \ + --zipformer-downsampling-factors "1,2,4,8,2" \ + --cnn-module-kernels "31,31,31,31,31" \ + --decoder-dim 512 \ + --joiner-dim 512 + +It will generate the following 3 files in $repo/exp + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +See ./onnx_pretrained.py for how to use the exported models. +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, List, Tuple + +import onnx +import torch +import torch.nn as nn +from decoder import Decoder +from onnxruntime.quantization import QuantType, quantize_dynamic +from scaling_converter import convert_scaled_to_non_scaled +from torch import Tensor +from train import add_model_arguments, get_params, get_transducer_model +from zipformer import Zipformer + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_char", + help="The lang dir", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +class OnnxEncoder(nn.Module): + """A wrapper for Zipformer and the encoder_proj from the joiner""" + + def __init__(self, encoder: Zipformer, encoder_proj: nn.Linear): + """ + Args: + encoder: + A Zipformer encoder. + encoder_proj: + The projection layer for encoder from the joiner. + """ + super().__init__() + self.encoder = encoder + self.encoder_proj = encoder_proj + + def forward(self, x: Tensor, states: List[Tensor]) -> Tuple[Tensor, List[Tensor]]: + """Please see the help information of Zipformer.streaming_forward""" + N = x.size(0) + T = x.size(1) + x_lens = torch.tensor([T] * N, device=x.device) + + output, _, new_states = self.encoder.streaming_forward( + x=x, + x_lens=x_lens, + states=states, + ) + + output = self.encoder_proj(output) + # Now output is of shape (N, T, joiner_dim) + + return output, new_states + + +class OnnxDecoder(nn.Module): + """A wrapper for Decoder and the decoder_proj from the joiner""" + + def __init__(self, decoder: Decoder, decoder_proj: nn.Linear): + super().__init__() + self.decoder = decoder + self.decoder_proj = decoder_proj + + def forward(self, y: torch.Tensor) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, context_size). + Returns + Return a 2-D tensor of shape (N, joiner_dim) + """ + need_pad = False + decoder_output = self.decoder(y, need_pad=need_pad) + decoder_output = decoder_output.squeeze(1) + output = self.decoder_proj(decoder_output) + + return output + + +class OnnxJoiner(nn.Module): + """A wrapper for the joiner""" + + def __init__(self, output_linear: nn.Linear): + super().__init__() + self.output_linear = output_linear + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + logit = encoder_out + decoder_out + logit = self.output_linear(torch.tanh(logit)) + return logit + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +def export_encoder_model_onnx( + encoder_model: OnnxEncoder, + encoder_filename: str, + opset_version: int = 11, +) -> None: + """ + Onnx model inputs: + - 0: src + - many state tensors (the exact number depending on the actual model) + + Onnx model outputs: + - 0: output, its shape is (N, T, joiner_dim) + - many state tensors (the exact number depending on the actual model) + + Args: + encoder_model: + The model to be exported + encoder_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + + encoder_model.encoder.__class__.forward = ( + encoder_model.encoder.__class__.streaming_forward + ) + + decode_chunk_len = encoder_model.encoder.decode_chunk_size * 2 + pad_length = 7 + T = decode_chunk_len + pad_length + logging.info(f"decode_chunk_len: {decode_chunk_len}") + logging.info(f"pad_length: {pad_length}") + logging.info(f"T: {T}") + + x = torch.rand(1, T, 80, dtype=torch.float32) + + init_state = encoder_model.encoder.get_init_state() + + num_encoders = encoder_model.encoder.num_encoders + logging.info(f"num_encoders: {num_encoders}") + logging.info(f"len(init_state): {len(init_state)}") + + inputs = {} + input_names = ["x"] + + outputs = {} + output_names = ["encoder_out"] + + def build_inputs_outputs(tensors, name, N): + for i, s in enumerate(tensors): + logging.info(f"{name}_{i}.shape: {s.shape}") + inputs[f"{name}_{i}"] = {N: "N"} + outputs[f"new_{name}_{i}"] = {N: "N"} + input_names.append(f"{name}_{i}") + output_names.append(f"new_{name}_{i}") + + num_encoder_layers = ",".join(map(str, encoder_model.encoder.num_encoder_layers)) + encoder_dims = ",".join(map(str, encoder_model.encoder.encoder_dims)) + attention_dims = ",".join(map(str, encoder_model.encoder.attention_dims)) + cnn_module_kernels = ",".join(map(str, encoder_model.encoder.cnn_module_kernels)) + ds = encoder_model.encoder.zipformer_downsampling_factors + left_context_len = encoder_model.encoder.left_context_len + left_context_len = [left_context_len // k for k in ds] + left_context_len = ",".join(map(str, left_context_len)) + + meta_data = { + "model_type": "zipformer", + "version": "1", + "model_author": "k2-fsa", + "decode_chunk_len": str(decode_chunk_len), # 32 + "T": str(T), # 39 + "num_encoder_layers": num_encoder_layers, + "encoder_dims": encoder_dims, + "attention_dims": attention_dims, + "cnn_module_kernels": cnn_module_kernels, + "left_context_len": left_context_len, + } + logging.info(f"meta_data: {meta_data}") + + # (num_encoder_layers, 1) + cached_len = init_state[num_encoders * 0 : num_encoders * 1] + + # (num_encoder_layers, 1, encoder_dim) + cached_avg = init_state[num_encoders * 1 : num_encoders * 2] + + # (num_encoder_layers, left_context_len, 1, attention_dim) + cached_key = init_state[num_encoders * 2 : num_encoders * 3] + + # (num_encoder_layers, left_context_len, 1, attention_dim//2) + cached_val = init_state[num_encoders * 3 : num_encoders * 4] + + # (num_encoder_layers, left_context_len, 1, attention_dim//2) + cached_val2 = init_state[num_encoders * 4 : num_encoders * 5] + + # (num_encoder_layers, 1, encoder_dim, cnn_module_kernel-1) + cached_conv1 = init_state[num_encoders * 5 : num_encoders * 6] + + # (num_encoder_layers, 1, encoder_dim, cnn_module_kernel-1) + cached_conv2 = init_state[num_encoders * 6 : num_encoders * 7] + + build_inputs_outputs(cached_len, "cached_len", 1) + build_inputs_outputs(cached_avg, "cached_avg", 1) + build_inputs_outputs(cached_key, "cached_key", 2) + build_inputs_outputs(cached_val, "cached_val", 2) + build_inputs_outputs(cached_val2, "cached_val2", 2) + build_inputs_outputs(cached_conv1, "cached_conv1", 1) + build_inputs_outputs(cached_conv2, "cached_conv2", 1) + + logging.info(inputs) + logging.info(outputs) + logging.info(input_names) + logging.info(output_names) + + torch.onnx.export( + encoder_model, + (x, init_state), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=input_names, + output_names=output_names, + dynamic_axes={ + "x": {0: "N"}, + "encoder_out": {0: "N"}, + **inputs, + **outputs, + }, + ) + + add_meta_data(filename=encoder_filename, meta_data=meta_data) + + +def export_decoder_model_onnx( + decoder_model: nn.Module, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, joiner_dim) + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + context_size = decoder_model.decoder.context_size + vocab_size = decoder_model.decoder.vocab_size + y = torch.zeros(10, context_size, dtype=torch.int64) + torch.onnx.export( + decoder_model, + y, + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + meta_data = { + "context_size": str(context_size), + "vocab_size": str(vocab_size), + } + add_meta_data(filename=decoder_filename, meta_data=meta_data) + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. + The exported joiner model has two inputs: + + - encoder_out: a tensor of shape (N, joiner_dim) + - decoder_out: a tensor of shape (N, joiner_dim) + + and produces one output: + + - logit: a tensor of shape (N, vocab_size) + """ + joiner_dim = joiner_model.output_linear.weight.shape[1] + logging.info(f"joiner dim: {joiner_dim}") + + projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + + torch.onnx.export( + joiner_model, + (projected_encoder_out, projected_decoder_out), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "encoder_out", + "decoder_out", + ], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + meta_data = { + "joiner_dim": str(joiner_dim), + } + add_meta_data(filename=joiner_filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + setup_logger(f"{params.exp_dir}/log-export/log-export-onnx") + + logging.info(f"device: {device}") + + lexicon = Lexicon(params.lang_dir) + params.blank_id = 0 + params.vocab_size = max(lexicon.tokens) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True) + encoder = OnnxEncoder( + encoder=model.encoder, + encoder_proj=model.joiner.encoder_proj, + ) + + decoder = OnnxDecoder( + decoder=model.decoder, + decoder_proj=model.joiner.decoder_proj, + ) + + joiner = OnnxJoiner(output_linear=model.joiner.output_linear) + + encoder_num_param = sum([p.numel() for p in encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + if params.iter > 0: + suffix = f"iter-{params.iter}" + else: + suffix = f"epoch-{params.epoch}" + + suffix += f"-avg-{params.avg}" + if params.use_averaged_model: + suffix += "-with-averaged-model" + + opset_version = 13 + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" + export_encoder_model_onnx( + encoder, + encoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported encoder to {encoder_filename}") + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" + export_decoder_model_onnx( + decoder, + decoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported decoder to {decoder_filename}") + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" + export_joiner_model_onnx( + joiner, + joiner_filename, + opset_version=opset_version, + ) + logging.info(f"Exported joiner to {joiner_filename}") + + # Generate int8 quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + + logging.info("Generate int8 quantization models") + + encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=encoder_filename, + model_output=encoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=decoder_filename, + model_output=decoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" + quantize_dynamic( + model_input=joiner_filename, + model_output=joiner_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py index d7092403e..e71bcaf29 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py @@ -53,6 +53,7 @@ import sentencepiece as spm import torch import torch.nn as nn from decoder import Decoder +from onnxruntime.quantization import QuantType, quantize_dynamic from scaling_converter import convert_scaled_to_non_scaled from torch import Tensor from train import add_model_arguments, get_params, get_transducer_model @@ -634,6 +635,35 @@ def main(): ) logging.info(f"Exported joiner to {joiner_filename}") + # Generate int8 quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + + logging.info("Generate int8 quantization models") + + encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=encoder_filename, + model_output=encoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=decoder_filename, + model_output=decoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" + quantize_dynamic( + model_input=joiner_filename, + model_output=joiner_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + if __name__ == "__main__": main() From 12a222aa4b70f145d91eb7acf4fce201ad35bdc7 Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Sun, 2 Apr 2023 16:32:43 +0800 Subject: [PATCH 144/174] Fix comments on the usage of train.py (#981) --- .../ASR/pruned_transducer_stateless2/train.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py index 9edc42b61..6a9f9f32f 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py @@ -19,26 +19,24 @@ """ Usage: -export CUDA_VISIBLE_DEVICES="0,1,2,3" +export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" ./pruned_transducer_stateless2/train.py \ - --world-size 4 \ + --world-size 8 \ --num-epochs 30 \ --start-epoch 0 \ --exp-dir pruned_transducer_stateless2/exp \ - --full-libri 1 \ - --max-duration 300 + --max-duration 120 # For mix precision training: ./pruned_transducer_stateless2/train.py \ - --world-size 4 \ + --world-size 8 \ --num-epochs 30 \ --start-epoch 0 \ --use_fp16 1 \ --exp-dir pruned_transducer_stateless2/exp \ - --full-libri 1 \ - --max-duration 550 + --max-duration 200 """ From 180c7c2b7ae1e03319839d1782a89c9d012e0ddb Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Mon, 3 Apr 2023 12:39:34 +0800 Subject: [PATCH 145/174] Add UniqueLexicon for gigaspeech (#982) --- egs/gigaspeech/ASR/local/generate_unique_lexicon.py | 1 + 1 file changed, 1 insertion(+) create mode 120000 egs/gigaspeech/ASR/local/generate_unique_lexicon.py diff --git a/egs/gigaspeech/ASR/local/generate_unique_lexicon.py b/egs/gigaspeech/ASR/local/generate_unique_lexicon.py new file mode 120000 index 000000000..c0aea1403 --- /dev/null +++ b/egs/gigaspeech/ASR/local/generate_unique_lexicon.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/generate_unique_lexicon.py \ No newline at end of file From 46bf6df62fb3d78b4b9bf1b0592c889a04d7be9b Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Mon, 3 Apr 2023 14:55:45 +0800 Subject: [PATCH 146/174] Remove simulate streaming from stateless7 (#983) * Remove simulate streaming from stateless7 --- .../pruned_transducer_stateless7/decode.py | 49 +------------------ 1 file changed, 1 insertion(+), 48 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py index 32b3134b9..576621e24 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py @@ -343,29 +343,6 @@ def get_parser(): fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) - parser.add_argument( - "--simulate-streaming", - type=str2bool, - default=False, - help="""Whether to simulate streaming in decoding, this is a good way to - test a streaming model. - """, - ) - - parser.add_argument( - "--decode-chunk-size", - type=int, - default=16, - help="The chunk size for decoding (in frames after subsampling)", - ) - - parser.add_argument( - "--left-context", - type=int, - default=64, - help="left context can be seen during decoding (in frames after subsampling)", - ) - parser.add_argument( "--use-shallow-fusion", type=str2bool, @@ -474,22 +451,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - if params.simulate_streaming: - feature_lens += params.left_context - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, params.left_context), - value=LOG_EPS, - ) - encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( - x=feature, - x_lens=feature_lens, - chunk_size=params.decode_chunk_size, - left_context=params.left_context, - simulate_streaming=True, - ) - else: - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -782,10 +744,6 @@ def main(): else: params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - if params.simulate_streaming: - params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}" - params.suffix += f"-left-context-{params.left_context}" - if "fast_beam_search" in params.decoding_method: params.suffix += f"-beam-{params.beam}" params.suffix += f"-max-contexts-{params.max_contexts}" @@ -834,11 +792,6 @@ def main(): params.unk_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() - if params.simulate_streaming: - assert ( - params.causal_convolution - ), "Decoding in streaming requires causal convolution" - logging.info(params) logging.info("About to create model") From d337398d29e401600473ca11fd9d827eea86fadc Mon Sep 17 00:00:00 2001 From: marcoyang1998 <45973641+marcoyang1998@users.noreply.github.com> Date: Mon, 3 Apr 2023 16:20:29 +0800 Subject: [PATCH 147/174] Shallow fusion for Aishell (#954) * add shallow fusion and LODR for aishell * update RESULTS * add save by iterations --- egs/aishell/ASR/RESULTS.md | 74 ++++++++ .../local/prepare_char_lm_training_data.py | 0 .../ASR/local/sort_lm_training_data.py | 1 + egs/aishell/ASR/prepare.sh | 11 +- .../pruned_transducer_stateless3/decode.py | 166 ++++++++++++++++++ .../exp-context-size-1 | 1 - .../ASR/lstm_transducer_stateless2/decode.py | 2 - .../beam_search.py | 8 +- .../pruned_transducer_stateless3/decode.py | 2 - .../pruned_transducer_stateless5/decode.py | 2 - .../pruned_transducer_stateless7/decode.py | 2 - egs/wenetspeech/ASR/local/text2segments.py | 4 +- egs/wenetspeech/ASR/prepare.sh | 104 +++++++++++ .../pruned_transducer_stateless5/decode.py | 163 ++++++++++++++++- icefall/lm_wrapper.py | 2 +- icefall/rnn_lm/compute_perplexity.py | 50 +++++- icefall/rnn_lm/export.py | 43 ++++- icefall/rnn_lm/train.py | 48 ++++- 18 files changed, 647 insertions(+), 36 deletions(-) mode change 100644 => 100755 egs/aishell/ASR/local/prepare_char_lm_training_data.py create mode 120000 egs/aishell/ASR/local/sort_lm_training_data.py delete mode 120000 egs/aishell/ASR/pruned_transducer_stateless3/exp-context-size-1 diff --git a/egs/aishell/ASR/RESULTS.md b/egs/aishell/ASR/RESULTS.md index 9a06fbe9f..4c730c4ae 100644 --- a/egs/aishell/ASR/RESULTS.md +++ b/egs/aishell/ASR/RESULTS.md @@ -15,6 +15,8 @@ It uses pruned RNN-T. |------------------------|------|------|---------------------------------------| | greedy search | 5.39 | 5.09 | --epoch 29 --avg 5 --max-duration 600 | | modified beam search | 5.05 | 4.79 | --epoch 29 --avg 5 --max-duration 600 | +| modified beam search + RNNLM shallow fusion | 4.73 | 4.53 | --epoch 29 --avg 5 --max-duration 600 | +| modified beam search + LODR | 4.57 | 4.37 | --epoch 29 --avg 5 --max-duration 600 | | fast beam search | 5.13 | 4.91 | --epoch 29 --avg 5 --max-duration 600 | Training command is: @@ -73,6 +75,78 @@ for epoch in 29; do done ``` +We provide the option of shallow fusion with a RNN language model. The pre-trained language model is +available at . To decode with the language model, +please use the following command: + +```bash +# download pre-trained model +git lfs install +git clone https://huggingface.co/csukuangfj/icefall-aishell-pruned-transducer-stateless3-2022-06-20 + +aishell_exp=icefall-aishell-pruned-transducer-stateless3-2022-06-20/ + +pushd ${aishell_exp}/exp +ln -s pretrained-epoch-29-avg-5-torch-1.10.0.pt epoch-99.pt +popd + +# download RNN LM +git lfs install +git clone https://huggingface.co/marcoyang/icefall-aishell-rnn-lm +rnnlm_dir=icefall-aishell-rnn-lm + +# RNNLM shallow fusion +for lm_scale in $(seq 0.26 0.02 0.34); do + python ./pruned_transducer_stateless3/decode.py \ + --epoch 99 \ + --avg 1 \ + --lang-dir ${aishell_exp}/data/lang_char \ + --exp-dir ${aishell_exp}/exp \ + --use-averaged-model False \ + --decoding-method modified_beam_search_lm_shallow_fusion \ + --use-shallow-fusion 1 \ + --lm-type rnn \ + --lm-exp-dir ${rnnlm_dir}/exp \ + --lm-epoch 99 \ + --lm-scale $lm_scale \ + --lm-avg 1 \ + --rnn-lm-embedding-dim 2048 \ + --rnn-lm-hidden-dim 2048 \ + --rnn-lm-num-layers 2 \ + --lm-vocab-size 4336 +done + +# RNNLM Low-order density ratio (LODR) with a 2-gram + +cp ${rnnlm_dir}/2gram.fst.txt ${aishell_exp}/data/lang_char/2gram.fst.txt + +for lm_scale in 0.48; do + for LODR_scale in -0.28; do + python ./pruned_transducer_stateless3/decode.py \ + --epoch 99 \ + --avg 1 \ + --lang-dir ${aishell_exp}/data/lang_char \ + --exp-dir ${aishell_exp}/exp \ + --use-averaged-model False \ + --decoding-method modified_beam_search_LODR \ + --use-shallow-fusion 1 \ + --lm-type rnn \ + --lm-exp-dir ${rnnlm_dir}/exp \ + --lm-epoch 99 \ + --lm-scale $lm_scale \ + --lm-avg 1 \ + --rnn-lm-embedding-dim 2048 \ + --rnn-lm-hidden-dim 2048 \ + --rnn-lm-num-layers 2 \ + --lm-vocab-size 4336 \ + --tokens-ngram 2 \ + --backoff-id 4336 \ + --ngram-lm-scale $LODR_scale + done +done + +``` + Pretrained models, training logs, decoding logs, and decoding results are available at diff --git a/egs/aishell/ASR/local/prepare_char_lm_training_data.py b/egs/aishell/ASR/local/prepare_char_lm_training_data.py old mode 100644 new mode 100755 diff --git a/egs/aishell/ASR/local/sort_lm_training_data.py b/egs/aishell/ASR/local/sort_lm_training_data.py new file mode 120000 index 000000000..1d6ccbe33 --- /dev/null +++ b/egs/aishell/ASR/local/sort_lm_training_data.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/sort_lm_training_data.py \ No newline at end of file diff --git a/egs/aishell/ASR/prepare.sh b/egs/aishell/ASR/prepare.sh index cf4ee7818..bd34c1f44 100755 --- a/egs/aishell/ASR/prepare.sh +++ b/egs/aishell/ASR/prepare.sh @@ -230,12 +230,14 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then if [ ! -f $dl_dir/lm/aishell-train-word.txt ]; then cp $lang_phone_dir/transcript_words.txt $dl_dir/lm/aishell-train-word.txt fi - + + # training words ./local/prepare_char_lm_training_data.py \ --lang-char data/lang_char \ --lm-data $dl_dir/lm/aishell-train-word.txt \ --lm-archive $out_dir/lm_data.pt + # valid words if [ ! -f $dl_dir/lm/aishell-valid-word.txt ]; then aishell_text=$dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt aishell_valid_uid=$dl_dir/aishell/data_aishell/transcript/aishell_valid_uid @@ -249,6 +251,7 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then --lm-data $dl_dir/lm/aishell-valid-word.txt \ --lm-archive $out_dir/lm_data_valid.pt + # test words if [ ! -f $dl_dir/lm/aishell-test-word.txt ]; then aishell_text=$dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt aishell_test_uid=$dl_dir/aishell/data_aishell/transcript/aishell_test_uid @@ -303,9 +306,9 @@ if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then --hidden-dim 512 \ --num-layers 2 \ --batch-size 400 \ - --exp-dir rnnlm_char/exp \ - --lm-data data/lm_training_char/sorted_lm_data.pt \ - --lm-data-valid data/lm_training_char/sorted_lm_data-valid.pt \ + --exp-dir rnnlm_char/exp_aishell1_small \ + --lm-data data/lm_char/sorted_lm_data_aishell1.pt \ + --lm-data-valid data/lm_char/sorted_lm_data_valid.pt \ --vocab-size 4336 \ --master-port 12345 fi diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/decode.py b/egs/aishell/ASR/pruned_transducer_stateless3/decode.py index 954d9dc7e..27c64efaa 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/decode.py @@ -54,6 +54,40 @@ Usage: --beam 4 \ --max-contexts 4 \ --max-states 8 + +(5) modified beam search (with LM shallow fusion) +./pruned_transducer_stateless3/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search_lm_shallow_fusion \ + --beam-size 4 \ + --lm-type rnn \ + --lm-scale 0.3 \ + --lm-exp-dir /path/to/LM \ + --rnn-lm-epoch 99 \ + --rnn-lm-avg 1 \ + --rnn-lm-num-layers 3 \ + --rnn-lm-tie-weights 1 + +(6) modified beam search with LM shallow fusion + LODR +./pruned_transducer_stateless3/decode.py \ + --epoch 28 \ + --avg 15 \ + --max-duration 600 \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --decoding-method modified_beam_search_LODR \ + --beam-size 4 \ + --lm-type rnn \ + --lm-scale 0.48 \ + --lm-exp-dir /path/to/LM \ + --rnn-lm-epoch 99 \ + --rnn-lm-avg 1 \ + --rnn-lm-num-layers 3 \ + --rnn-lm-tie-weights 1 + --tokens-ngram 2 \ + --ngram-lm-scale -0.28 \ """ @@ -74,9 +108,12 @@ from beam_search import ( greedy_search, greedy_search_batch, modified_beam_search, + modified_beam_search_lm_shallow_fusion, + modified_beam_search_LODR, ) from train import add_model_arguments, get_params, get_transducer_model +from icefall import LmScorer, NgramLm from icefall.checkpoint import ( average_checkpoints, average_checkpoints_with_averaged_model, @@ -212,6 +249,60 @@ def get_parser(): Used only when --decoding_method is greedy_search""", ) + parser.add_argument( + "--use-shallow-fusion", + type=str2bool, + default=False, + help="""Use neural network LM for shallow fusion. + If you want to use LODR, you will also need to set this to true + """, + ) + + parser.add_argument( + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. + """, + ) + + parser.add_argument( + "--tokens-ngram", + type=int, + default=2, + help="""Token Ngram used for rescoring. + Used only when the decoding method is + modified_beam_search_ngram_rescoring""", + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="""ID of the backoff symbol. + Used only when the decoding method is + modified_beam_search_ngram_rescoring""", + ) + add_model_arguments(parser) return parser @@ -223,6 +314,9 @@ def decode_one_batch( token_table: k2.SymbolTable, batch: dict, decoding_graph: Optional[k2.Fsa] = None, + ngram_lm: Optional[NgramLm] = None, + ngram_lm_scale: float = 1.0, + LM: Optional[LmScorer] = None, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -287,6 +381,24 @@ def decode_one_batch( encoder_out_lens=encoder_out_lens, beam=params.beam_size, ) + elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": + hyp_tokens = modified_beam_search_lm_shallow_fusion( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + ) + elif params.decoding_method == "modified_beam_search_LODR": + hyp_tokens = modified_beam_search_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LODR_lm=ngram_lm, + LODR_lm_scale=ngram_lm_scale, + LM=LM, + ) else: hyp_tokens = [] batch_size = encoder_out.size(0) @@ -334,6 +446,9 @@ def decode_dataset( model: nn.Module, token_table: k2.SymbolTable, decoding_graph: Optional[k2.Fsa] = None, + ngram_lm: Optional[NgramLm] = None, + ngram_lm_scale: float = 1.0, + LM: Optional[LmScorer] = None, ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. @@ -379,6 +494,9 @@ def decode_dataset( token_table=token_table, decoding_graph=decoding_graph, batch=batch, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + LM=LM, ) for name, hyps in hyps_dict.items(): @@ -445,6 +563,7 @@ def save_results( def main(): parser = get_parser() AsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) args.lang_dir = Path(args.lang_dir) @@ -458,6 +577,8 @@ def main(): "beam_search", "fast_beam_search", "modified_beam_search", + "modified_beam_search_LODR", + "modified_beam_search_lm_shallow_fusion", ) params.res_dir = params.exp_dir / params.decoding_method @@ -479,6 +600,19 @@ def main(): if params.use_averaged_model: params.suffix += "-use-averaged-model" + if "ngram" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + if params.use_shallow_fusion: + if params.lm_type == "rnn": + params.suffix += f"-rnnlm-lm-scale-{params.lm_scale}" + elif params.lm_type == "transformer": + params.suffix += f"-transformer-lm-scale-{params.lm_scale}" + + if "LODR" in params.decoding_method: + params.suffix += ( + f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" + ) + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") logging.info("Decoding started") @@ -588,6 +722,35 @@ def main(): else: decoding_graph = None + # only load N-gram LM when needed + if "ngram" in params.decoding_method or "LODR" in params.decoding_method: + lm_filename = params.lang_dir / f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"lm filename: {lm_filename}") + ngram_lm = NgramLm( + lm_filename, + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {ngram_lm.lm.num_states}") + ngram_lm_scale = params.ngram_lm_scale + else: + ngram_lm = None + ngram_lm_scale = None + + # only load the neural network LM if doing shallow fusion + if params.use_shallow_fusion: + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, + ) + LM.to(device) + LM.eval() + + else: + LM = None + num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") @@ -610,6 +773,9 @@ def main(): model=model, token_table=lexicon.token_table, decoding_graph=decoding_graph, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + LM=LM, ) save_results( diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/exp-context-size-1 b/egs/aishell/ASR/pruned_transducer_stateless3/exp-context-size-1 deleted file mode 120000 index bcd4abc2f..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless3/exp-context-size-1 +++ /dev/null @@ -1 +0,0 @@ -/ceph-fj/fangjun/open-source/icefall-aishell/egs/aishell/ASR/pruned_transducer_stateless3/exp-context-size-1 \ No newline at end of file diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py index 6c58a57e1..73207017b 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py @@ -550,7 +550,6 @@ def decode_one_batch( encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, beam=params.beam_size, - sp=sp, LM=LM, ) for hyp in sp.decode(hyp_tokens): @@ -561,7 +560,6 @@ def decode_one_batch( encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, beam=params.beam_size, - sp=sp, LODR_lm=ngram_lm, LODR_lm_scale=ngram_lm_scale, LM=LM, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 999d793a4..75edf0c54 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -1863,7 +1863,6 @@ def modified_beam_search_LODR( model: Transducer, encoder_out: torch.Tensor, encoder_out_lens: torch.Tensor, - sp: spm.SentencePieceProcessor, LODR_lm: NgramLm, LODR_lm_scale: float, LM: LmScorer, @@ -1883,8 +1882,6 @@ def modified_beam_search_LODR( encoder_out_lens (torch.Tensor): A 1-D tensor of shape (N,), containing the number of valid frames in encoder_out before padding. - sp: - Sentence piece generator. LODR_lm: A low order n-gram LM, whose score will be subtracted during shallow fusion LODR_lm_scale: @@ -1912,7 +1909,7 @@ def modified_beam_search_LODR( ) blank_id = model.decoder.blank_id - sos_id = sp.piece_to_id("") + sos_id = getattr(LM, "sos_id", 1) unk_id = getattr(model, "unk_id", blank_id) context_size = model.decoder.context_size device = next(model.parameters()).device @@ -2137,7 +2134,6 @@ def modified_beam_search_lm_shallow_fusion( model: Transducer, encoder_out: torch.Tensor, encoder_out_lens: torch.Tensor, - sp: spm.SentencePieceProcessor, LM: LmScorer, beam: int = 4, return_timestamps: bool = False, @@ -2176,7 +2172,7 @@ def modified_beam_search_lm_shallow_fusion( ) blank_id = model.decoder.blank_id - sos_id = sp.piece_to_id("") + sos_id = getattr(LM, "sos_id", 1) unk_id = getattr(model, "unk_id", blank_id) context_size = model.decoder.context_size device = next(model.parameters()).device diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py index b39007dfc..7c62bfa58 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py @@ -675,7 +675,6 @@ def decode_one_batch( encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, beam=params.beam_size, - sp=sp, LM=LM, ) for hyp in sp.decode(hyp_tokens): @@ -686,7 +685,6 @@ def decode_one_batch( encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, beam=params.beam_size, - sp=sp, LODR_lm=ngram_lm, LODR_lm_scale=ngram_lm_scale, LM=LM, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py index af0b2d9fc..7a3e63218 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py @@ -586,7 +586,6 @@ def decode_one_batch( encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, beam=params.beam_size, - sp=sp, LM=LM, ) for hyp in sp.decode(hyp_tokens): @@ -597,7 +596,6 @@ def decode_one_batch( encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, beam=params.beam_size, - sp=sp, LODR_lm=ngram_lm, LODR_lm_scale=ngram_lm_scale, LM=LM, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py index 576621e24..55a2493e9 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py @@ -533,7 +533,6 @@ def decode_one_batch( encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, beam=params.beam_size, - sp=sp, LM=LM, ) for hyp in sp.decode(hyp_tokens): @@ -544,7 +543,6 @@ def decode_one_batch( encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, beam=params.beam_size, - sp=sp, LODR_lm=ngram_lm, LODR_lm_scale=ngram_lm_scale, LM=LM, diff --git a/egs/wenetspeech/ASR/local/text2segments.py b/egs/wenetspeech/ASR/local/text2segments.py index df5b3c119..bdf5a3984 100644 --- a/egs/wenetspeech/ASR/local/text2segments.py +++ b/egs/wenetspeech/ASR/local/text2segments.py @@ -40,8 +40,8 @@ from tqdm import tqdm # and 'data()' is only supported in static graph mode. So if you # want to use this api, should call 'paddle.enable_static()' before # this api to enter static graph mode. -paddle.enable_static() -paddle.disable_signal_handler() +# paddle.enable_static() +# paddle.disable_signal_handler() jieba.enable_paddle() diff --git a/egs/wenetspeech/ASR/prepare.sh b/egs/wenetspeech/ASR/prepare.sh index 50a00253d..f7b521794 100755 --- a/egs/wenetspeech/ASR/prepare.sh +++ b/egs/wenetspeech/ASR/prepare.sh @@ -261,3 +261,107 @@ if [ $stage -le 18 ] && [ $stop_stage -ge 18 ]; then log "Stage 18: Compile LG" python ./local/compile_lg.py --lang-dir $lang_char_dir fi + +# prepare RNNLM data +if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then + log "Stage 19: Prepare LM training data" + + log "Processing char based data" + text_out_dir=data/lm_char + + mkdir -p $text_out_dir + + log "Genearating training text data" + + if [ ! -f $text_out_dir/lm_data.pt ]; then + ./local/prepare_char_lm_training_data.py \ + --lang-char data/lang_char \ + --lm-data $lang_char_dir/text_words_segmentation \ + --lm-archive $text_out_dir/lm_data.pt + fi + + log "Generating DEV text data" + # prepare validation text data + if [ ! -f $text_out_dir/valid_text_words_segmentation ]; then + valid_text=${text_out_dir}/ + + gunzip -c data/manifests/wenetspeech_supervisions_DEV.jsonl.gz \ + | jq '.text' | sed 's/"//g' \ + | ./local/text2token.py -t "char" > $text_out_dir/valid_text + + python3 ./local/text2segments.py \ + --num-process $nj \ + --input-file $text_out_dir/valid_text \ + --output-file $text_out_dir/valid_text_words_segmentation + fi + + ./local/prepare_char_lm_training_data.py \ + --lang-char data/lang_char \ + --lm-data $text_out_dir/valid_text_words_segmentation \ + --lm-archive $text_out_dir/lm_data_valid.pt + + # prepare TEST text data + if [ ! -f $text_out_dir/TEST_text_words_segmentation ]; then + log "Prepare text for test set." + for test_set in TEST_MEETING TEST_NET; do + gunzip -c data/manifests/wenetspeech_supervisions_${test_set}.jsonl.gz \ + | jq '.text' | sed 's/"//g' \ + | ./local/text2token.py -t "char" > $text_out_dir/${test_set}_text + + python3 ./local/text2segments.py \ + --num-process $nj \ + --input-file $text_out_dir/${test_set}_text \ + --output-file $text_out_dir/${test_set}_text_words_segmentation + done + + cat $text_out_dir/TEST_*_text_words_segmentation > $text_out_dir/test_text_words_segmentation + fi + + ./local/prepare_char_lm_training_data.py \ + --lang-char data/lang_char \ + --lm-data $text_out_dir/test_text_words_segmentation \ + --lm-archive $text_out_dir/lm_data_test.pt + +fi + +# sort RNNLM data +if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then + text_out_dir=data/lm_char + + log "Sort lm data" + + ./local/sort_lm_training_data.py \ + --in-lm-data $text_out_dir/lm_data.pt \ + --out-lm-data $text_out_dir/sorted_lm_data.pt \ + --out-statistics $text_out_dir/statistics.txt + + ./local/sort_lm_training_data.py \ + --in-lm-data $text_out_dir/lm_data_valid.pt \ + --out-lm-data $text_out_dir/sorted_lm_data-valid.pt \ + --out-statistics $text_out_dir/statistics-valid.txt + + ./local/sort_lm_training_data.py \ + --in-lm-data $text_out_dir/lm_data_test.pt \ + --out-lm-data $text_out_dir/sorted_lm_data-test.pt \ + --out-statistics $text_out_dir/statistics-test.txt +fi + +export CUDA_VISIBLE_DEVICES="0,1" + +if [ $stage -le 21 ] && [ $stop_stage -ge 21 ]; then + log "Stage 21: Train RNN LM model" + python ../../../icefall/rnn_lm/train.py \ + --start-epoch 0 \ + --world-size 2 \ + --num-epochs 20 \ + --use-fp16 0 \ + --embedding-dim 2048 \ + --hidden-dim 2048 \ + --num-layers 2 \ + --batch-size 400 \ + --exp-dir rnnlm_char/exp \ + --lm-data data/lm_char/sorted_lm_data.pt \ + --lm-data-valid data/lm_char/sorted_lm_data-valid.pt \ + --vocab-size 4336 \ + --master-port 12340 +fi \ No newline at end of file diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py index de12b2ff0..46ba6b005 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py @@ -2,6 +2,7 @@ # # Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) # Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) +# Copyright 2022 Xiaomi Corporation (Author: Xiaoyu Yang) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -91,6 +92,22 @@ When training with the L subset, the streaming usage: --causal-convolution 1 \ --decode-chunk-size 16 \ --left-context 64 + +(4) modified beam search with RNNLM shallow fusion +./pruned_transducer_stateless5/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless5/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search_lm_shallow_fusion \ + --beam-size 4 \ + --lm-type rnn \ + --lm-scale 0.3 \ + --lm-exp-dir /path/to/LM \ + --rnn-lm-epoch 99 \ + --rnn-lm-avg 1 \ + --rnn-lm-num-layers 3 \ + --rnn-lm-tie-weights 1 """ @@ -111,9 +128,12 @@ from beam_search import ( greedy_search, greedy_search_batch, modified_beam_search, + modified_beam_search_lm_shallow_fusion, + modified_beam_search_LODR, ) from train import add_model_arguments, get_params, get_transducer_model +from icefall import LmScorer, NgramLm from icefall.checkpoint import ( average_checkpoints, average_checkpoints_with_averaged_model, @@ -224,6 +244,16 @@ def get_parser(): Used only when --decoding-method is fast_beam_search""", ) + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG and fast_beam_search_LG. + It specifies the scale for n-gram LM scores. + """, + ) + parser.add_argument( "--max-contexts", type=int, @@ -277,6 +307,50 @@ def get_parser(): help="left context can be seen during decoding (in frames after subsampling)", ) + parser.add_argument( + "--use-shallow-fusion", + type=str2bool, + default=False, + help="""Use neural network LM for shallow fusion. + If you want to use LODR, you will also need to set this to true + """, + ) + + parser.add_argument( + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. + """, + ) + + parser.add_argument( + "--tokens-ngram", + type=int, + default=3, + help="""Token Ngram used for rescoring. + Used only when the decoding method is + modified_beam_search_ngram_rescoring, or LODR + """, + ) + + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="""ID of the backoff symbol. + Used only when the decoding method is + modified_beam_search_ngram_rescoring""", + ) add_model_arguments(parser) return parser @@ -288,6 +362,9 @@ def decode_one_batch( lexicon: Lexicon, batch: dict, decoding_graph: Optional[k2.Fsa] = None, + ngram_lm: Optional[NgramLm] = None, + ngram_lm_scale: float = 1.0, + LM: Optional[LmScorer] = None, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -374,6 +451,28 @@ def decode_one_batch( ) for i in range(encoder_out.size(0)): hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": + hyp_tokens = modified_beam_search_lm_shallow_fusion( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + ) + for i in range(encoder_out.size(0)): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + elif params.decoding_method == "modified_beam_search_LODR": + hyp_tokens = modified_beam_search_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LODR_lm=ngram_lm, + LODR_lm_scale=ngram_lm_scale, + LM=LM, + ) + for i in range(encoder_out.size(0)): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) else: batch_size = encoder_out.size(0) @@ -419,6 +518,9 @@ def decode_dataset( model: nn.Module, lexicon: Lexicon, decoding_graph: Optional[k2.Fsa] = None, + ngram_lm: Optional[NgramLm] = None, + ngram_lm_scale: float = 1.0, + LM: Optional[LmScorer] = None, ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. @@ -432,6 +534,8 @@ def decode_dataset( decoding_graph: The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used only when --decoding_method is fast_beam_search. + LM: + A neural network LM, used during shallow fusion Returns: Return a dict, whose key may be "greedy_search" if greedy search is used, or it may be "beam_7" if beam size of 7 is used. @@ -449,7 +553,7 @@ def decode_dataset( if params.decoding_method == "greedy_search": log_interval = 100 else: - log_interval = 2 + log_interval = 20 results = defaultdict(list) for batch_idx, batch in enumerate(dl): @@ -463,6 +567,9 @@ def decode_dataset( lexicon=lexicon, decoding_graph=decoding_graph, batch=batch, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + LM=LM, ) for name, hyps in hyps_dict.items(): @@ -524,6 +631,7 @@ def save_results( def main(): parser = get_parser() WenetSpeechAsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) @@ -535,6 +643,8 @@ def main(): "beam_search", "fast_beam_search", "modified_beam_search", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_LODR", ) params.res_dir = params.exp_dir / params.decoding_method @@ -549,6 +659,22 @@ def main(): params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + if "ngram" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + if params.use_shallow_fusion: + if params.lm_type == "rnn": + params.suffix += f"-rnnlm-lm-scale-{params.lm_scale}" + elif params.lm_type == "transformer": + params.suffix += f"-transformer-lm-scale-{params.lm_scale}" + + if "LODR" in params.decoding_method: + params.suffix += ( + f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" + ) + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") logging.info("Decoding started") @@ -558,6 +684,7 @@ def main(): logging.info(f"Device: {device}") + # import pdb; pdb.set_trace() lexicon = Lexicon(params.lang_dir) params.blank_id = lexicon.token_table[""] params.vocab_size = max(lexicon.tokens) + 1 @@ -652,6 +779,37 @@ def main(): model.to(device) model.eval() model.device = device + # only load N-gram LM when needed + if "ngram" in params.decoding_method or "LODR" in params.decoding_method: + lm_filename = f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"lm filename: {lm_filename}") + ngram_lm = NgramLm( + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {ngram_lm.lm.num_states}") + ngram_lm_scale = params.ngram_lm_scale + else: + ngram_lm = None + ngram_lm_scale = None + + # import pdb; pdb.set_trace() + # only load the neural network LM if doing shallow fusion + if params.use_shallow_fusion: + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, + ) + LM.to(device) + LM.eval() + + num_param = sum([p.numel() for p in LM.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + else: + LM = None if params.decoding_method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) @@ -684,6 +842,9 @@ def main(): model=model, lexicon=lexicon, decoding_graph=decoding_graph, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + LM=LM, ) save_results( params=params, diff --git a/icefall/lm_wrapper.py b/icefall/lm_wrapper.py index 0468befd0..5e2783a47 100644 --- a/icefall/lm_wrapper.py +++ b/icefall/lm_wrapper.py @@ -50,7 +50,7 @@ class LmScorer(torch.nn.Module): def add_arguments(cls, parser): # LM general arguments parser.add_argument( - "--vocab-size", + "--lm-vocab-size", type=int, default=500, ) diff --git a/icefall/rnn_lm/compute_perplexity.py b/icefall/rnn_lm/compute_perplexity.py index f75a89590..cc566bd92 100755 --- a/icefall/rnn_lm/compute_perplexity.py +++ b/icefall/rnn_lm/compute_perplexity.py @@ -33,7 +33,7 @@ import torch from dataset import get_dataloader from model import RnnLmModel -from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import AttributeDict, setup_logger, str2bool @@ -49,6 +49,7 @@ def get_parser(): help="It specifies the checkpoint to use for decoding." "Note: Epoch counts from 0.", ) + parser.add_argument( "--avg", type=int, @@ -58,6 +59,16 @@ def get_parser(): "'--epoch'. ", ) + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + parser.add_argument( "--exp-dir", type=str, @@ -154,7 +165,14 @@ def main(): params = AttributeDict(vars(args)) - setup_logger(f"{params.exp_dir}/log-ppl/") + if params.iter > 0: + setup_logger( + f"{params.exp_dir}/log-ppl/log-ppl-iter-{params.iter}-avg-{params.avg}" + ) + else: + setup_logger( + f"{params.exp_dir}/log-ppl/log-ppl-epoch-{params.epoch}-avg-{params.avg}" + ) logging.info("Computing perplexity started") logging.info(params) @@ -173,19 +191,39 @@ def main(): tie_weights=params.tie_weights, ) - if params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) else: start = params.epoch - params.avg + 1 filenames = [] for i in range(start, params.epoch + 1): - if start >= 0: + if i >= 0: filenames.append(f"{params.exp_dir}/epoch-{i}.pt") logging.info(f"averaging {filenames}") model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + model.to(device) model.eval() num_param = sum([p.numel() for p in model.parameters()]) num_param_requires_grad = sum( diff --git a/icefall/rnn_lm/export.py b/icefall/rnn_lm/export.py index 2411cb1f0..a8598a1ce 100644 --- a/icefall/rnn_lm/export.py +++ b/icefall/rnn_lm/export.py @@ -25,7 +25,7 @@ from pathlib import Path import torch from model import RnnLmModel -from icefall.checkpoint import load_checkpoint +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import AttributeDict, load_averaged_model, str2bool @@ -51,6 +51,16 @@ def get_parser(): "'--epoch'. ", ) + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + parser.add_argument( "--vocab-size", type=int, @@ -133,11 +143,36 @@ def main(): model.to(device) - if params.avg == 1: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + elif params.avg == 1: load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) else: - model = load_averaged_model( - params.exp_dir, model, params.epoch, params.avg, device + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False ) model.to("cpu") diff --git a/icefall/rnn_lm/train.py b/icefall/rnn_lm/train.py index f43e66cd2..91df4f921 100755 --- a/icefall/rnn_lm/train.py +++ b/icefall/rnn_lm/train.py @@ -49,6 +49,7 @@ from torch.utils.tensorboard import SummaryWriter from icefall.checkpoint import load_checkpoint from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import save_checkpoint_with_global_batch_idx from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool @@ -178,6 +179,33 @@ def get_parser(): help="The seed for random generators intended for reproducibility", ) + parser.add_argument( + "--lr", + type=float, + default=1e-3, + ) + + parser.add_argument( + "--max-sent-len", + type=int, + default=200, + help="""Maximum number of tokens in a sentence. This is used + to adjust batch-size dynamically""", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=2000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + return parser @@ -190,16 +218,15 @@ def get_params() -> AttributeDict: "sos_id": 1, "eos_id": 1, "blank_id": 0, - "lr": 1e-3, "weight_decay": 1e-6, "best_train_loss": float("inf"), "best_valid_loss": float("inf"), "best_train_epoch": -1, "best_valid_epoch": -1, "batch_idx_train": 0, - "log_interval": 200, + "log_interval": 100, "reset_interval": 2000, - "valid_interval": 5000, + "valid_interval": 200, "env_info": get_env_info(), } ) @@ -382,6 +409,7 @@ def train_one_epoch( valid_dl: torch.utils.data.DataLoader, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, + rank: int = 0, ) -> None: """Train the model for one epoch. @@ -430,6 +458,19 @@ def train_one_epoch( clip_grad_norm_(model.parameters(), 5.0, 2.0) optimizer.step() + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + params=params, + optimizer=optimizer, + rank=rank, + ) + if batch_idx % params.log_interval == 0: # Note: "frames" here means "num_tokens" this_batch_ppl = math.exp(loss_info["loss"] / loss_info["frames"]) @@ -580,6 +621,7 @@ def run(rank, world_size, args): valid_dl=valid_dl, tb_writer=tb_writer, world_size=world_size, + rank=rank, ) save_checkpoint( From c90f57afdbc8116944771aab1c3a42217a53eeec Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Tue, 4 Apr 2023 11:04:00 +0800 Subject: [PATCH 148/174] Remove simulate streaming from stateless8 (#985) --- .../pruned_transducer_stateless8/decode.py | 49 +------------------ 1 file changed, 1 insertion(+), 48 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py index 7b651a632..e07777c9f 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py @@ -301,29 +301,6 @@ def get_parser(): fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) - parser.add_argument( - "--simulate-streaming", - type=str2bool, - default=False, - help="""Whether to simulate streaming in decoding, this is a good way to - test a streaming model. - """, - ) - - parser.add_argument( - "--decode-chunk-size", - type=int, - default=16, - help="The chunk size for decoding (in frames after subsampling)", - ) - - parser.add_argument( - "--left-context", - type=int, - default=64, - help="left context can be seen during decoding (in frames after subsampling)", - ) - add_model_arguments(parser) return parser @@ -378,22 +355,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - if params.simulate_streaming: - feature_lens += params.left_context - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, params.left_context), - value=LOG_EPS, - ) - encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( - x=feature, - x_lens=feature_lens, - chunk_size=params.decode_chunk_size, - left_context=params.left_context, - simulate_streaming=True, - ) - else: - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -651,10 +613,6 @@ def main(): else: params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - if params.simulate_streaming: - params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}" - params.suffix += f"-left-context-{params.left_context}" - if "fast_beam_search" in params.decoding_method: params.suffix += f"-beam-{params.beam}" params.suffix += f"-max-contexts-{params.max_contexts}" @@ -690,11 +648,6 @@ def main(): params.unk_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() - if params.simulate_streaming: - assert ( - params.causal_convolution - ), "Decoding in streaming requires causal convolution" - logging.info(params) logging.info("About to create model") From 136aa94d5757ef02654612f0b19fb5d25d5eda39 Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Thu, 6 Apr 2023 17:47:33 +0800 Subject: [PATCH 149/174] remove duplicated lines (#988) --- .../ASR/pruned_transducer_stateless3/onnx_pretrained.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py index 5adb6c16a..3947aa102 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py @@ -403,9 +403,8 @@ def main(): text += symbol_table[i] return text.replace("▁", " ").strip() - context_size = model.context_size for filename, hyp in zip(args.sound_files, hyps): - words = token_ids_to_words(hyp[context_size:]) + words = token_ids_to_words(hyp) s += f"{filename}:\n{words}\n" logging.info(s) From 6434c8eadc0d4326e1db69824cf0e40dc9a71c8a Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Sun, 9 Apr 2023 20:53:47 +0800 Subject: [PATCH 150/174] Add averaged model && change start from 0 to 1 && fix typo for gigaspeech (#990) * Add averaged model && change start from 0 to 1 && fix typo * Update train.py * Set use-averaged-model False for BC --------- Co-authored-by: yifanyang --- .../pruned_transducer_stateless2/decode.py | 181 ++++++++++++------ .../ASR/pruned_transducer_stateless2/train.py | 79 ++++++-- 2 files changed, 184 insertions(+), 76 deletions(-) diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py index ee694a9e0..72f74c968 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py @@ -19,40 +19,40 @@ Usage: (1) greedy search ./pruned_transducer_stateless2/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --max-duration 600 \ - --decoding-method greedy_search + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless2/exp \ + --max-duration 600 \ + --decoding-method greedy_search (2) beam search ./pruned_transducer_stateless2/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --max-duration 600 \ - --decoding-method beam_search \ - --beam-size 4 + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless2/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 (3) modified beam search ./pruned_transducer_stateless2/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --max-duration 600 \ - --decoding-method modified_beam_search \ - --beam-size 4 + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless2/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 (4) fast beam search ./pruned_transducer_stateless2/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless2/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 4 \ + --max-contexts 4 \ + --max-states 8 """ @@ -76,12 +76,17 @@ from beam_search import ( ) from gigaspeech_scoring import asr_text_post_processing from train import get_params, get_transducer_model - -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) from icefall.utils import ( AttributeDict, setup_logger, store_transcripts, + str2bool, write_error_stats, ) @@ -94,9 +99,9 @@ def get_parser(): parser.add_argument( "--epoch", type=int, - default=29, + default=30, help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 0. + Note: Epoch counts from 1. You can specify --avg to use more checkpoints for model averaging.""", ) @@ -119,6 +124,17 @@ def get_parser(): "'--epoch' and '--iter'", ) + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=False, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + parser.add_argument( "--exp-dir", type=str, @@ -464,6 +480,9 @@ def main(): params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") logging.info("Decoding started") @@ -476,7 +495,7 @@ def main(): sp = spm.SentencePieceProcessor() sp.load(params.bpe_model) - # and is defined in local/train_bpe_model.py + # and are defined in local/train_bpe_model.py params.blank_id = sp.piece_to_id("") params.unk_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() @@ -486,37 +505,85 @@ def main(): logging.info("About to create model") model = get_transducer_model(params) - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) model.to(device) model.eval() - model.device = device if params.decoding_method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py index 6a9f9f32f..578bd9218 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py @@ -42,6 +42,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" import argparse +import copy import logging import warnings from pathlib import Path @@ -70,7 +71,10 @@ from torch.utils.tensorboard import SummaryWriter from icefall import diagnostics from icefall.checkpoint import load_checkpoint, remove_checkpoints from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import save_checkpoint_with_global_batch_idx +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool @@ -114,10 +118,10 @@ def get_parser(): parser.add_argument( "--start-epoch", type=int, - default=0, - help="""Resume training from from this epoch. - If it is positive, it will load checkpoint from - transducer_stateless2/exp/epoch-{start_epoch-1}.pt + default=1, + help="""Resume training from this epoch. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt """, ) @@ -240,7 +244,7 @@ def get_parser(): parser.add_argument( "--keep-last-k", type=int, - default=20, + default=30, help="""Only keep this number of checkpoints on disk. For instance, if it is 3, there are only 3 checkpoints in the exp-dir with filenames `checkpoint-xxx.pt`. @@ -248,6 +252,19 @@ def get_parser(): """, ) + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + parser.add_argument( "--use-fp16", type=str2bool, @@ -385,6 +402,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module: def load_checkpoint_if_available( params: AttributeDict, model: nn.Module, + model_avg: nn.Module = None, optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, ) -> Optional[Dict[str, Any]]: @@ -392,7 +410,7 @@ def load_checkpoint_if_available( If params.start_batch is positive, it will load the checkpoint from `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is positive, it will load the checkpoint from + params.start_epoch is larger than 1, it will load the checkpoint from `params.start_epoch - 1`. Apart from loading state dict for `model` and `optimizer` it also updates @@ -404,6 +422,8 @@ def load_checkpoint_if_available( The return value of :func:`get_params`. model: The training model. + model_avg: + The stored model averaged from the start of training. optimizer: The optimizer that we are using. scheduler: @@ -413,7 +433,7 @@ def load_checkpoint_if_available( """ if params.start_batch > 0: filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 0: + elif params.start_epoch > 1: filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" else: return None @@ -423,6 +443,7 @@ def load_checkpoint_if_available( saved_params = load_checkpoint( filename, model=model, + model_avg=model_avg, optimizer=optimizer, scheduler=scheduler, ) @@ -449,7 +470,8 @@ def load_checkpoint_if_available( def save_checkpoint( params: AttributeDict, - model: nn.Module, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, @@ -463,6 +485,8 @@ def save_checkpoint( It is returned by :func:`get_params`. model: The training model. + model_avg: + The stored model averaged from the start of training. optimizer: The optimizer used in the training. sampler: @@ -476,6 +500,7 @@ def save_checkpoint( save_checkpoint_impl( filename=filename, model=model, + model_avg=model_avg, params=params, optimizer=optimizer, scheduler=scheduler, @@ -495,14 +520,14 @@ def save_checkpoint( def compute_loss( params: AttributeDict, - model: nn.Module, + model: Union[nn.Module, DDP], sp: spm.SentencePieceProcessor, batch: dict, is_training: bool, warmup: float = 1.0, ) -> Tuple[Tensor, MetricsTracker]: """ - Compute CTC loss given the model and its inputs. + Compute transducer loss given the model and its inputs. Args: params: @@ -568,7 +593,7 @@ def compute_loss( def compute_validation_loss( params: AttributeDict, - model: nn.Module, + model: Union[nn.Module, DDP], sp: spm.SentencePieceProcessor, valid_dl: torch.utils.data.DataLoader, world_size: int = 1, @@ -602,13 +627,14 @@ def compute_validation_loss( def train_one_epoch( params: AttributeDict, - model: nn.Module, + model: Union[nn.Module, DDP], optimizer: torch.optim.Optimizer, scheduler: LRSchedulerType, sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, scaler: GradScaler, + model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, rank: int = 0, @@ -634,6 +660,8 @@ def train_one_epoch( Dataloader for the validation dataset. scaler: The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. tb_writer: Writer to write log messages to tensorboard. world_size: @@ -660,6 +688,7 @@ def train_one_epoch( loss, loss_info = compute_loss( params=params, model=model, + model_avg=model_avg, sp=sp, batch=batch, is_training=True, @@ -688,6 +717,7 @@ def train_one_epoch( out_dir=params.exp_dir, global_batch_idx=params.batch_idx_train, model=model, + model_avg=model_avg, params=params, optimizer=optimizer, scheduler=scheduler, @@ -791,7 +821,16 @@ def run(rank, world_size, args): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoints = load_checkpoint_if_available(params=params, model=model) + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) model.to(device) if world_size > 1: @@ -850,10 +889,10 @@ def run(rank, world_size, args): logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) - for epoch in range(params.start_epoch, params.num_epochs): - scheduler.step_epoch(epoch) - fix_random_seed(params.seed + epoch) - train_dl.sampler.set_epoch(epoch) + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) if tb_writer is not None: tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) @@ -863,6 +902,7 @@ def run(rank, world_size, args): train_one_epoch( params=params, model=model, + model_avg=model_avg, optimizer=optimizer, scheduler=scheduler, sp=sp, @@ -881,6 +921,7 @@ def run(rank, world_size, args): save_checkpoint( params=params, model=model, + model_avg=model_avg, optimizer=optimizer, scheduler=scheduler, sampler=train_dl.sampler, @@ -896,7 +937,7 @@ def run(rank, world_size, args): def scan_pessimistic_batches_for_oom( - model: nn.Module, + model: Union[nn.Module, DDP], train_dl: torch.utils.data.DataLoader, optimizer: torch.optim.Optimizer, sp: spm.SentencePieceProcessor, From 33578cca48613c04c227cd22a2d2fdd207a5e928 Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Tue, 11 Apr 2023 11:12:05 +0800 Subject: [PATCH 151/174] Fix filter_cuts in compute_fbank_librispeech.py (#993) --- egs/librispeech/ASR/local/compute_fbank_librispeech.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/local/compute_fbank_librispeech.py b/egs/librispeech/ASR/local/compute_fbank_librispeech.py index 745eaf1e8..554d7f109 100755 --- a/egs/librispeech/ASR/local/compute_fbank_librispeech.py +++ b/egs/librispeech/ASR/local/compute_fbank_librispeech.py @@ -121,10 +121,10 @@ def compute_fbank_librispeech( recordings=m["recordings"], supervisions=m["supervisions"], ) - if bpe_model: - cut_set = filter_cuts(cut_set, sp) if "train" in partition: + if bpe_model: + cut_set = filter_cuts(cut_set, sp) cut_set = ( cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) ) From 3cb0a0121ba7ad934afb7b9e59045995e33e77b6 Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Tue, 11 Apr 2023 20:56:40 +0800 Subject: [PATCH 152/174] Add Common Voice (#994) * Add commonvoice * Add data preparation recipe * Updata * update prepare.sh * Fix for black * Update prefix with cv- * 20 -> * Update compute_fbank_commonvoice_dev_test.py * Update prepare.sh * Update compute_fbank_commonvoice_dev_test.py --- .../compute_fbank_commonvoice_dev_test.py | 107 ++++++++++++ .../local/compute_fbank_commonvoice_splits.py | 157 ++++++++++++++++++ .../ASR/local/compute_fbank_musan.py | 1 + egs/commonvoice/ASR/local/filter_cuts.py | 1 + .../ASR/local/preprocess_commonvoice.py | 119 +++++++++++++ egs/commonvoice/ASR/prepare.sh | 156 +++++++++++++++++ egs/commonvoice/ASR/shared | 1 + 7 files changed, 542 insertions(+) create mode 100755 egs/commonvoice/ASR/local/compute_fbank_commonvoice_dev_test.py create mode 100755 egs/commonvoice/ASR/local/compute_fbank_commonvoice_splits.py create mode 120000 egs/commonvoice/ASR/local/compute_fbank_musan.py create mode 120000 egs/commonvoice/ASR/local/filter_cuts.py create mode 100755 egs/commonvoice/ASR/local/preprocess_commonvoice.py create mode 100755 egs/commonvoice/ASR/prepare.sh create mode 120000 egs/commonvoice/ASR/shared diff --git a/egs/commonvoice/ASR/local/compute_fbank_commonvoice_dev_test.py b/egs/commonvoice/ASR/local/compute_fbank_commonvoice_dev_test.py new file mode 100755 index 000000000..c8f9b6ccb --- /dev/null +++ b/egs/commonvoice/ASR/local/compute_fbank_commonvoice_dev_test.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the CommonVoice dataset. +It looks for manifests in the directory data/${lang}/manifests. + +The generated fbank features are saved in data/${lang}/fbank. +""" + +import argparse +import logging +import os +from pathlib import Path +from typing import Optional + +import torch +from filter_cuts import filter_cuts +from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--language", + type=str, + help="""Language of Common Voice""", + ) + + return parser.parse_args() + + +def compute_fbank_commonvoice_dev_test(language: str): + src_dir = Path(f"data/{language}/manifests") + output_dir = Path(f"data/{language}/fbank") + num_workers = 42 + batch_duration = 600 + + subsets = ("dev", "test") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device)) + + logging.info(f"device: {device}") + + for partition in subsets: + cuts_path = output_dir / f"cv-{language}_cuts_{partition}.jsonl.gz" + if cuts_path.is_file(): + logging.info(f"{partition} already exists - skipping.") + continue + + raw_cuts_path = output_dir / f"cv-{language}_cuts_{partition}_raw.jsonl.gz" + + logging.info(f"Loading {raw_cuts_path}") + cut_set = CutSet.from_file(raw_cuts_path) + + logging.info("Splitting cuts into smaller chunks") + cut_set = cut_set.trim_to_supervisions( + keep_overlapping=False, min_duration=None + ) + + logging.info("Computing features") + cut_set = cut_set.compute_and_store_features_batch( + extractor=extractor, + storage_path=f"{output_dir}/cv-{language}_feats_{partition}", + num_workers=num_workers, + batch_duration=batch_duration, + storage_type=LilcomChunkyWriter, + overwrite=True, + ) + + logging.info(f"Saving to {cuts_path}") + cut_set.to_file(cuts_path) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + args = get_args() + logging.info(vars(args)) + compute_fbank_commonvoice_dev_test(language=args.language) diff --git a/egs/commonvoice/ASR/local/compute_fbank_commonvoice_splits.py b/egs/commonvoice/ASR/local/compute_fbank_commonvoice_splits.py new file mode 100755 index 000000000..f8c09ccf3 --- /dev/null +++ b/egs/commonvoice/ASR/local/compute_fbank_commonvoice_splits.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +from datetime import datetime +from pathlib import Path + +import torch +from lhotse import ( + CutSet, + KaldifeatFbank, + KaldifeatFbankConfig, + LilcomChunkyWriter, + set_audio_duration_mismatch_tolerance, + set_caching_enabled, +) + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--language", + type=str, + help="""Language of Common Voice""", + ) + + parser.add_argument( + "--num-workers", + type=int, + default=20, + help="Number of dataloading workers used for reading the audio.", + ) + + parser.add_argument( + "--batch-duration", + type=float, + default=600.0, + help="The maximum number of audio seconds in a batch." + "Determines batch size dynamically.", + ) + + parser.add_argument( + "--num-splits", + type=int, + required=True, + help="The number of splits of the train subset", + ) + + parser.add_argument( + "--start", + type=int, + default=0, + help="Process pieces starting from this number (inclusive).", + ) + + parser.add_argument( + "--stop", + type=int, + default=-1, + help="Stop processing pieces until this number (exclusive).", + ) + + return parser.parse_args() + + +def compute_fbank_commonvoice_splits(args): + subset = "train" + num_splits = args.num_splits + language = args.language + output_dir = f"data/{language}/fbank/{subset}_split_{num_splits}" + output_dir = Path(output_dir) + assert output_dir.exists(), f"{output_dir} does not exist!" + + num_digits = len(str(num_splits)) + + start = args.start + stop = args.stop + if stop < start: + stop = num_splits + + stop = min(stop, num_splits) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device)) + logging.info(f"device: {device}") + + set_audio_duration_mismatch_tolerance(0.01) # 10ms tolerance + set_caching_enabled(False) + for i in range(start, stop): + idx = f"{i + 1}".zfill(num_digits) + logging.info(f"Processing {idx}/{num_splits}") + + cuts_path = output_dir / f"cv-{language}_cuts_{subset}.{idx}.jsonl.gz" + if cuts_path.is_file(): + logging.info(f"{cuts_path} exists - skipping") + continue + + raw_cuts_path = output_dir / f"cv-{language}_cuts_{subset}_raw.{idx}.jsonl.gz" + + logging.info(f"Loading {raw_cuts_path}") + cut_set = CutSet.from_file(raw_cuts_path) + + logging.info("Splitting cuts into smaller chunks.") + cut_set = cut_set.trim_to_supervisions( + keep_overlapping=False, min_duration=None + ) + + logging.info("Computing features") + cut_set = cut_set.compute_and_store_features_batch( + extractor=extractor, + storage_path=f"{output_dir}/cv-{language}_feats_{subset}_{idx}", + num_workers=args.num_workers, + batch_duration=args.batch_duration, + storage_type=LilcomChunkyWriter, + overwrite=True, + ) + + logging.info(f"Saving to {cuts_path}") + cut_set.to_file(cuts_path) + + +def main(): + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + args = get_args() + logging.info(vars(args)) + compute_fbank_commonvoice_splits(args) + + +if __name__ == "__main__": + main() diff --git a/egs/commonvoice/ASR/local/compute_fbank_musan.py b/egs/commonvoice/ASR/local/compute_fbank_musan.py new file mode 120000 index 000000000..5833f2484 --- /dev/null +++ b/egs/commonvoice/ASR/local/compute_fbank_musan.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compute_fbank_musan.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/local/filter_cuts.py b/egs/commonvoice/ASR/local/filter_cuts.py new file mode 120000 index 000000000..27aca1729 --- /dev/null +++ b/egs/commonvoice/ASR/local/filter_cuts.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/filter_cuts.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/local/preprocess_commonvoice.py b/egs/commonvoice/ASR/local/preprocess_commonvoice.py new file mode 100755 index 000000000..22f7aa03c --- /dev/null +++ b/egs/commonvoice/ASR/local/preprocess_commonvoice.py @@ -0,0 +1,119 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +from pathlib import Path +from typing import Optional + +from lhotse import CutSet, SupervisionSegment +from lhotse.recipes.utils import read_manifests_if_cached + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--dataset", + type=str, + help="""Dataset parts to compute fbank. If None, we will use all""", + ) + + parser.add_argument( + "--language", + type=str, + help="""Language of Common Voice""", + ) + + return parser.parse_args() + + +def preprocess_commonvoice( + language: str, + dataset: Optional[str] = None, +): + src_dir = Path("data/manifests") + output_dir = Path("data/fbank") + output_dir.mkdir(exist_ok=True) + + if dataset is None: + dataset_parts = ( + "dev", + "test", + "train", + ) + else: + dataset_parts = dataset.split(" ", -1) + + logging.info("Loading manifest") + prefix = f"cv-{language}" + suffix = "jsonl.gz" + manifests = read_manifests_if_cached( + dataset_parts=dataset_parts, + output_dir=src_dir, + suffix=suffix, + prefix=prefix, + ) + assert manifests is not None + + assert len(manifests) == len(dataset_parts), ( + len(manifests), + len(dataset_parts), + list(manifests.keys()), + dataset_parts, + ) + + for partition, m in manifests.items(): + logging.info(f"Processing {partition}") + raw_cuts_path = output_dir / f"{prefix}_cuts_{partition}_raw.{suffix}" + if raw_cuts_path.is_file(): + logging.info(f"{partition} already exists - skipping") + continue + + # Create long-recording cut manifests. + cut_set = CutSet.from_manifests( + recordings=m["recordings"], + supervisions=m["supervisions"], + ).resample(16000) + + # Run data augmentation that needs to be done in the + # time domain. + if "train" in partition: + logging.info( + f"Speed perturb for {partition} with factors 0.9 and 1.1 " + "(Perturbing may take 2 minutes and saving may take 7 minutes)" + ) + cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + logging.info(f"Saving to {raw_cuts_path}") + cut_set.to_file(raw_cuts_path) + + +def main(): + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + args = get_args() + logging.info(vars(args)) + preprocess_commonvoice( + language=args.language, + dataset=args.dataset, + ) + logging.info("Done") + + +if __name__ == "__main__": + main() diff --git a/egs/commonvoice/ASR/prepare.sh b/egs/commonvoice/ASR/prepare.sh new file mode 100755 index 000000000..9a28167b1 --- /dev/null +++ b/egs/commonvoice/ASR/prepare.sh @@ -0,0 +1,156 @@ +#!/usr/bin/env bash + +set -eou pipefail + +nj=16 +stage=-1 +stop_stage=100 + +# Split data/${lang}set to this number of pieces +# This is to avoid OOM during feature extraction. +num_splits=1000 + +# We assume dl_dir (download dir) contains the following +# directories and files. If not, they will be downloaded +# by this script automatically. +# +# - $dl_dir/$release/$lang +# This directory contains the following files downloaded from +# https://mozilla-common-voice-datasets.s3.dualstack.us-west-2.amazonaws.com/${release}/${release}-${lang}.tar.gz +# +# - clips +# - dev.tsv +# - invalidated.tsv +# - other.tsv +# - reported.tsv +# - test.tsv +# - train.tsv +# - validated.tsv +# +# - $dl_dir/musan +# This directory contains the following directories downloaded from +# http://www.openslr.org/17/ +# +# - music +# - noise +# - speech + +dl_dir=$PWD/download +release=cv-corpus-13.0-2023-03-09 +lang=en + +. shared/parse_options.sh || exit 1 + +# vocab size for sentence piece models. +# It will generate data/${lang}/lang_bpe_xxx, +# data/${lang}/lang_bpe_yyy if the array contains xxx, yyy +vocab_sizes=( + # 5000 + # 2000 + # 1000 + 500 +) + +# All files generated by this script are saved in "data/${lang}". +# You can safely remove "data/${lang}" and rerun this script to regenerate it. +mkdir -p data/${lang} + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "dl_dir: $dl_dir" + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Download data" + + # If you have pre-downloaded it to /path/to/$release, + # you can create a symlink + # + # ln -sfv /path/to/$release $dl_dir/$release + # + if [ ! -d $dl_dir/$release/$lang/clips ]; then + lhotse download commonvoice --languages $lang --release $release $dl_dir + fi + + # If you have pre-downloaded it to /path/to/musan, + # you can create a symlink + # + # ln -sfv /path/to/musan $dl_dir/ + # + if [ ! -d $dl_dir/musan ]; then + lhotse download musan $dl_dir + fi +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare CommonVoice manifest" + # We assume that you have downloaded the CommonVoice corpus + # to $dl_dir/$release + mkdir -p data/${lang}/manifests + if [ ! -e data/${lang}/manifests/.cv-${lang}.done ]; then + lhotse prepare commonvoice --language $lang -j $nj $dl_dir/$release data/${lang}/manifests + touch data/${lang}/manifests/.cv-${lang}.done + fi +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Prepare musan manifest" + # We assume that you have downloaded the musan corpus + # to data/musan + mkdir -p data/manifests + if [ ! -e data/manifests/.musan.done ]; then + lhotse prepare musan $dl_dir/musan data/manifests + touch data/manifests/.musan.done + fi +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Preprocess CommonVoice manifest" + if [ ! -e data/${lang}/fbank/.preprocess_complete ]; then + ./local/preprocess_commonvoice.py --language $lang + touch data/${lang}/fbank/.preprocess_complete + fi +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Compute fbank for dev and test subsets of CommonVoice" + mkdir -p data/${lang}/fbank + if [ ! -e data/${lang}/fbank/.cv-${lang}_dev_test.done ]; then + ./local/compute_fbank_commonvoice_dev_test.py --language $lang + touch data/${lang}/fbank/.cv-${lang}_dev_test.done + fi +fi + +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Split train subset into ${num_splits} pieces" + split_dir=data/${lang}/fbank/train_split_${num_splits} + if [ ! -e $split_dir/.cv-${lang}_train_split.done ]; then + lhotse split $num_splits ./data/${lang}/fbank/cv-${lang}_cuts_train_raw.jsonl.gz $split_dir + touch $split_dir/.cv-${lang}_train_split.done + fi +fi + +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + log "Stage 6: Compute features for train subset of CommonVoice" + if [ ! -e data/${lang}/fbank/.cv-${lang}_train.done ]; then + ./local/compute_fbank_commonvoice_splits.py \ + --num-workers $nj \ + --batch-duration 600 \ + --start 0 \ + --num-splits $num_splits \ + --language $lang + touch data/${lang}/fbank/.cv-${lang}_train.done + fi +fi + +if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then + log "Stage 7: Compute fbank for musan" + mkdir -p data/fbank + if [ ! -e data/fbank/.musan.done ]; then + ./local/compute_fbank_musan.py + touch data/fbank/.musan.done + fi +fi diff --git a/egs/commonvoice/ASR/shared b/egs/commonvoice/ASR/shared new file mode 120000 index 000000000..4c5e91438 --- /dev/null +++ b/egs/commonvoice/ASR/shared @@ -0,0 +1 @@ +../../../icefall/shared/ \ No newline at end of file From dbf2aa3212a47650c85da574a4a2ed61773d6be3 Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Tue, 11 Apr 2023 21:04:54 +0800 Subject: [PATCH 153/174] Create preprocess_commonvoice.py (#996) --- egs/commonvoice/ASR/local/preprocess_commonvoice.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/commonvoice/ASR/local/preprocess_commonvoice.py b/egs/commonvoice/ASR/local/preprocess_commonvoice.py index 22f7aa03c..241e81963 100755 --- a/egs/commonvoice/ASR/local/preprocess_commonvoice.py +++ b/egs/commonvoice/ASR/local/preprocess_commonvoice.py @@ -46,8 +46,8 @@ def preprocess_commonvoice( language: str, dataset: Optional[str] = None, ): - src_dir = Path("data/manifests") - output_dir = Path("data/fbank") + src_dir = Path(f"data/{language}/manifests") + output_dir = Path(f"data/{language}/fbank") output_dir.mkdir(exist_ok=True) if dataset is None: From 5f066d3d538e093d72c94e48ad6be4c23da71f84 Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Wed, 12 Apr 2023 19:04:50 +0800 Subject: [PATCH 154/174] support decoding and computing RTF on test sets with onnx models (#995) * support decode and compute RTF on test sets with onnx models * support onnx export and decode in pruned_transducer_stateless --- .../export-onnx.py | 527 ++++++++++++++++++ .../pruned_transducer_stateless/onnx_check.py | 1 + .../onnx_decode.py | 319 +++++++++++ .../onnx_pretrained.py | 1 + .../onnx_decode.py | 321 +++++++++++ .../onnx_pretrained.py | 3 +- .../export-onnx.py | 30 + .../onnx_decode.py | 326 +++++++++++ .../onnx_decode.py | 319 +++++++++++ .../ASR/transducer_stateless/conformer.py | 83 ++- 10 files changed, 1901 insertions(+), 29 deletions(-) create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless/export-onnx.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless/onnx_check.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless/onnx_decode.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless/onnx_pretrained.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless3/onnx_decode.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless5/onnx_decode.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7/onnx_decode.py diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless/export-onnx.py new file mode 100755 index 000000000..a3ebe9d8c --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless/export-onnx.py @@ -0,0 +1,527 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang) + +""" +This script exports a transducer model from PyTorch to ONNX. + +We use the pre-trained model from +https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-9999.pt +popd + +2. Export the model to ONNX + +./pruned_transducer_stateless/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 9999 \ + --avg 1 \ + --exp-dir $repo/exp/ + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-9999-avg-1.onnx + - decoder-epoch-9999-avg-1.onnx + - joiner-epoch-9999-avg-1.onnx + +See ./onnx_pretrained.py and ./onnx_check.py for how to +use the exported ONNX models. +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, Tuple + +import onnx +import sentencepiece as spm +import torch +import torch.nn as nn +from conformer import Conformer +from decoder import Decoder +from onnxruntime.quantization import QuantType, quantize_dynamic +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.utils import setup_logger + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +class OnnxEncoder(nn.Module): + """A wrapper for Conformer""" + + def __init__(self, encoder: Conformer): + """ + Args: + encoder: + A Conformer encoder. + """ + super().__init__() + self.encoder = encoder + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Please see the help information of Conformer.forward + + Args: + x: + A 3-D tensor of shape (N, T, C) + x_lens: + A 1-D tensor of shape (N,). Its dtype is torch.int64 + Returns: + Return a tuple containing: + - encoder_out, A 3-D tensor of shape (N, T', joiner_dim) + - encoder_out_lens, A 1-D tensor of shape (N,) + """ + encoder_out, encoder_out_lens = self.encoder(x, x_lens) + # Now encoder_out is of shape (N, T, joiner_dim) + + return encoder_out, encoder_out_lens + + +class OnnxDecoder(nn.Module): + """A wrapper for Decoder""" + + def __init__(self, decoder: Decoder): + super().__init__() + self.decoder = decoder + + def forward(self, y: torch.Tensor) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, context_size). + Returns + Return a 2-D tensor of shape (N, joiner_dim) + """ + need_pad = False + decoder_output = self.decoder(y, need_pad=need_pad) + output = decoder_output.squeeze(1) + + return output + + +class OnnxJoiner(nn.Module): + """A wrapper for the joiner""" + + def __init__(self, inner_linear: nn.Linear, output_linear: nn.Linear): + super().__init__() + self.inner_linear = inner_linear + self.output_linear = output_linear + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + logit = encoder_out + decoder_out + logit = self.inner_linear(torch.tanh(logit)) + output = self.output_linear(nn.functional.relu(logit)) + + return output + + +def export_encoder_model_onnx( + encoder_model: OnnxEncoder, + encoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the given encoder model to ONNX format. + The exported model has two inputs: + + - x, a tensor of shape (N, T, C); dtype is torch.float32 + - x_lens, a tensor of shape (N,); dtype is torch.int64 + + and it has two outputs: + + - encoder_out, a tensor of shape (N, T', joiner_dim) + - encoder_out_lens, a tensor of shape (N,) + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + x = torch.zeros(1, 100, 80, dtype=torch.float32) + x_lens = torch.tensor([100], dtype=torch.int64) + + torch.onnx.export( + encoder_model, + (x, x_lens), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "x_lens"], + output_names=["encoder_out", "encoder_out_lens"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "x_lens": {0: "N"}, + "encoder_out": {0: "N", 1: "T"}, + "encoder_out_lens": {0: "N"}, + }, + ) + + meta_data = { + "model_type": "conformer", + "version": "1", + "model_author": "k2-fsa", + "comment": "stateless3", + } + logging.info(f"meta_data: {meta_data}") + + add_meta_data(filename=encoder_filename, meta_data=meta_data) + + +def export_decoder_model_onnx( + decoder_model: OnnxDecoder, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, decoder_model.context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, joiner_dim) + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + context_size = decoder_model.decoder.context_size + vocab_size = decoder_model.decoder.vocab_size + + y = torch.zeros(10, context_size, dtype=torch.int64) + torch.onnx.export( + decoder_model, + y, + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + + meta_data = { + "context_size": str(context_size), + "vocab_size": str(vocab_size), + } + add_meta_data(filename=decoder_filename, meta_data=meta_data) + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. + The exported joiner model has two inputs: + + - encoder_out: a tensor of shape (N, joiner_dim) + - decoder_out: a tensor of shape (N, joiner_dim) + + and produces one output: + + - logit: a tensor of shape (N, vocab_size) + """ + joiner_dim = joiner_model.inner_linear.weight.shape[1] + logging.info(f"joiner dim: {joiner_dim}") + + projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + + torch.onnx.export( + joiner_model, + (projected_encoder_out, projected_decoder_out), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "encoder_out", + "decoder_out", + ], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + meta_data = { + "joiner_dim": str(joiner_dim), + } + add_meta_data(filename=joiner_filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + setup_logger(f"{params.exp_dir}/log-export/log-export-onnx") + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + + model.to("cpu") + model.eval() + + encoder = OnnxEncoder(encoder=model.encoder) + + decoder = OnnxDecoder(decoder=model.decoder) + + joiner = OnnxJoiner( + inner_linear=model.joiner.inner_linear, output_linear=model.joiner.output_linear + ) + + encoder_num_param = sum([p.numel() for p in encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + if params.iter > 0: + suffix = f"iter-{params.iter}" + else: + suffix = f"epoch-{params.epoch}" + + suffix += f"-avg-{params.avg}" + + opset_version = 13 + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" + export_encoder_model_onnx( + encoder, + encoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported encoder to {encoder_filename}") + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" + export_decoder_model_onnx( + decoder, + decoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported decoder to {decoder_filename}") + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" + export_joiner_model_onnx( + joiner, + joiner_filename, + opset_version=opset_version, + ) + logging.info(f"Exported joiner to {joiner_filename}") + + # Generate int8 quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + + logging.info("Generate int8 quantization models") + + encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=encoder_filename, + model_output=encoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=decoder_filename, + model_output=decoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" + quantize_dynamic( + model_input=joiner_filename, + model_output=joiner_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/onnx_check.py b/egs/librispeech/ASR/pruned_transducer_stateless/onnx_check.py new file mode 120000 index 000000000..66d63b807 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless/onnx_check.py @@ -0,0 +1 @@ +../pruned_transducer_stateless3/onnx_check.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/onnx_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/onnx_decode.py new file mode 100755 index 000000000..8134d43f8 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless/onnx_decode.py @@ -0,0 +1,319 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads ONNX exported models and uses them to decode the test sets. + +We use the pre-trained model from +https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-9999.pt +popd + +2. Export the model to ONNX + +./pruned_transducer_stateless/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 9999 \ + --avg 1 \ + --exp-dir $repo/exp/ + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-9999-avg-1.onnx + - decoder-epoch-9999-avg-1.onnx + - joiner-epoch-9999-avg-1.onnx + +3. Run this file + +./pruned_transducer_stateless/onnx_decode.py \ + --exp-dir ./pruned_transducer_stateless/exp \ + --max-duration 600 \ + --encoder-model-filename $repo/exp/encoder-epoch-9999-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-9999-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-9999-avg-1.onnx \ +""" + + +import argparse +import logging +import time +from pathlib import Path +from typing import List, Tuple + +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule + +from onnx_pretrained import greedy_search, OnnxModel + +from icefall.utils import setup_logger, store_transcripts, write_error_stats + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder onnx model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder onnx model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner onnx model. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="Valid values are greedy_search and modified_beam_search", + ) + + return parser + + +def decode_one_batch( + model: OnnxModel, sp: spm.SentencePieceProcessor, batch: dict +) -> List[List[str]]: + """Decode one batch and return the result. + Currently it only greedy_search is supported. + + Args: + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + + Returns: + Return the decoded results for each utterance. + """ + feature = batch["inputs"] + assert feature.ndim == 3 + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(dtype=torch.int64) + + encoder_out, encoder_out_lens = model.run_encoder(x=feature, x_lens=feature_lens) + + hyps = greedy_search( + model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens + ) + + hyps = [sp.decode(h).split() for h in hyps] + return hyps + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + model: nn.Module, + sp: spm.SentencePieceProcessor, +) -> Tuple[List[Tuple[str, List[str], List[str]]], float]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + model: + The neural model. + sp: + The BPE model. + + Returns: + - A list of tuples. Each tuple contains three elements: + - cut_id, + - reference transcript, + - predicted result. + - The total duration (in seconds) of the dataset. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + log_interval = 10 + total_duration = 0 + + results = [] + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + total_duration += sum([cut.duration for cut in batch["supervisions"]["cut"]]) + + hyps = decode_one_batch(model=model, sp=sp, batch=batch) + + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results.extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + + return results, total_duration + + +def save_results( + res_dir: Path, + test_set_name: str, + results: List[Tuple[str, List[str], List[str]]], +): + recog_path = res_dir / f"recogs-{test_set_name}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = res_dir / f"errs-{test_set_name}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats(f, f"{test_set_name}", results, enable_log=True) + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + errs_info = res_dir / f"wer-summary-{test_set_name}.txt" + with open(errs_info, "w") as f: + print("WER", file=f) + print(wer, file=f) + + s = "\nFor {}, WER is {}:\n".format(test_set_name, wer) + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + + assert ( + args.decoding_method == "greedy_search" + ), "Only supports greedy_search currently." + res_dir = Path(args.exp_dir) / f"onnx-{args.decoding_method}" + + setup_logger(f"{res_dir}/log-decode") + logging.info("Decoding started") + + device = torch.device("cpu") + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + blank_id = sp.piece_to_id("") + assert blank_id == 0, blank_id + + logging.info(vars(args)) + + logging.info("About to create model") + model = OnnxModel( + encoder_model_filename=args.encoder_model_filename, + decoder_model_filename=args.decoder_model_filename, + joiner_model_filename=args.joiner_model_filename, + ) + + # we need cut ids to display recognition results. + args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + start_time = time.time() + results, total_duration = decode_dataset(dl=test_dl, model=model, sp=sp) + end_time = time.time() + elapsed_seconds = end_time - start_time + rtf = elapsed_seconds / total_duration + + logging.info(f"Elapsed time: {elapsed_seconds:.3f} s") + logging.info(f"Wave duration: {total_duration:.3f} s") + logging.info( + f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}" + ) + + save_results(res_dir=res_dir, test_set_name=test_set, results=results) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless/onnx_pretrained.py new file mode 120000 index 000000000..7607623c8 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless/onnx_pretrained.py @@ -0,0 +1 @@ +../pruned_transducer_stateless3/onnx_pretrained.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_decode.py new file mode 100755 index 000000000..3b1c72cf1 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_decode.py @@ -0,0 +1,321 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads ONNX exported models and uses them to decode the test sets. + +We use the pre-trained model from +https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-iter-1224000-avg-14.pt" + +cd exp +ln -s pretrained-iter-1224000-avg-14.pt epoch-9999.pt +popd + +2. Export the model to ONNX + +./pruned_transducer_stateless3/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 9999 \ + --avg 1 \ + --exp-dir $repo/exp/ + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-9999-avg-1.onnx + - decoder-epoch-9999-avg-1.onnx + - joiner-epoch-9999-avg-1.onnx + +2. Run this file + +./pruned_transducer_stateless3/onnx_decode.py \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --max-duration 600 \ + --encoder-model-filename $repo/exp/encoder-epoch-9999-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-9999-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-9999-avg-1.onnx \ +""" + + +import argparse +import logging +import time +from pathlib import Path +from typing import List, Tuple + +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import AsrDataModule +from librispeech import LibriSpeech + +from onnx_pretrained import greedy_search, OnnxModel + +from icefall.utils import setup_logger, store_transcripts, write_error_stats + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder onnx model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder onnx model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner onnx model. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless3/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="Valid values are greedy_search and modified_beam_search", + ) + + return parser + + +def decode_one_batch( + model: OnnxModel, sp: spm.SentencePieceProcessor, batch: dict +) -> List[List[str]]: + """Decode one batch and return the result. + Currently it only greedy_search is supported. + + Args: + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + + Returns: + Return the decoded results for each utterance. + """ + feature = batch["inputs"] + assert feature.ndim == 3 + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(dtype=torch.int64) + + encoder_out, encoder_out_lens = model.run_encoder(x=feature, x_lens=feature_lens) + + hyps = greedy_search( + model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens + ) + + hyps = [sp.decode(h).split() for h in hyps] + return hyps + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + model: nn.Module, + sp: spm.SentencePieceProcessor, +) -> Tuple[List[Tuple[str, List[str], List[str]]], float]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + model: + The neural model. + sp: + The BPE model. + + Returns: + - A list of tuples. Each tuple contains three elements: + - cut_id, + - reference transcript, + - predicted result. + - The total duration (in seconds) of the dataset. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + log_interval = 10 + total_duration = 0 + + results = [] + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + total_duration += sum([cut.duration for cut in batch["supervisions"]["cut"]]) + + hyps = decode_one_batch(model=model, sp=sp, batch=batch) + + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results.extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + + return results, total_duration + + +def save_results( + res_dir: Path, + test_set_name: str, + results: List[Tuple[str, List[str], List[str]]], +): + recog_path = res_dir / f"recogs-{test_set_name}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = res_dir / f"errs-{test_set_name}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats(f, f"{test_set_name}", results, enable_log=True) + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + errs_info = res_dir / f"wer-summary-{test_set_name}.txt" + with open(errs_info, "w") as f: + print("WER", file=f) + print(wer, file=f) + + s = "\nFor {}, WER is {}:\n".format(test_set_name, wer) + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + AsrDataModule.add_arguments(parser) + args = parser.parse_args() + + assert ( + args.decoding_method == "greedy_search" + ), "Only supports greedy_search currently." + res_dir = Path(args.exp_dir) / f"onnx-{args.decoding_method}" + + setup_logger(f"{res_dir}/log-decode") + logging.info("Decoding started") + + device = torch.device("cpu") + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + blank_id = sp.piece_to_id("") + assert blank_id == 0, blank_id + + logging.info(vars(args)) + + logging.info("About to create model") + model = OnnxModel( + encoder_model_filename=args.encoder_model_filename, + decoder_model_filename=args.decoder_model_filename, + joiner_model_filename=args.joiner_model_filename, + ) + + # we need cut ids to display recognition results. + args.return_cuts = True + asr_datamodule = AsrDataModule(args) + librispeech = LibriSpeech(manifest_dir=args.manifest_dir) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = asr_datamodule.test_dataloaders(test_clean_cuts) + test_other_dl = asr_datamodule.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + start_time = time.time() + results, total_duration = decode_dataset(dl=test_dl, model=model, sp=sp) + end_time = time.time() + elapsed_seconds = end_time - start_time + rtf = elapsed_seconds / total_duration + + logging.info(f"Elapsed time: {elapsed_seconds:.3f} s") + logging.info(f"Wave duration: {total_duration:.3f} s") + logging.info( + f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}" + ) + + save_results(res_dir=res_dir, test_set_name=test_set, results=results) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py index 3947aa102..e10915086 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py @@ -71,7 +71,6 @@ from typing import List, Tuple import k2 import kaldifeat -import numpy as np import onnxruntime as ort import torch import torchaudio @@ -139,7 +138,7 @@ class OnnxModel: ): session_opts = ort.SessionOptions() session_opts.inter_op_num_threads = 1 - session_opts.intra_op_num_threads = 1 + session_opts.intra_op_num_threads = 4 self.session_opts = session_opts diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx.py index 3d94760dc..e89d94d82 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx.py @@ -60,6 +60,7 @@ import sentencepiece as spm import torch import torch.nn as nn from conformer import Conformer +from onnxruntime.quantization import QuantType, quantize_dynamic from decoder import Decoder from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model @@ -568,6 +569,35 @@ def main(): ) logging.info(f"Exported joiner to {joiner_filename}") + # Generate int8 quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + + logging.info("Generate int8 quantization models") + + encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=encoder_filename, + model_output=encoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=decoder_filename, + model_output=decoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" + quantize_dynamic( + model_input=joiner_filename, + model_output=joiner_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + if __name__ == "__main__": formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_decode.py new file mode 100755 index 000000000..6f26e34b5 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_decode.py @@ -0,0 +1,326 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads ONNX exported models and uses them to decode the test sets. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless5-2022-07-07 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless5-2022-07-07 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-epoch-30-avg-10.pt" + +cd exp +ln -s pretrained-epoch-30-avg-10.pt epoch-9999.pt +popd + +2. Export the model to ONNX + +./pruned_transducer_stateless5/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 9999 \ + --avg 1 \ + --use-averaged-model 0 \ + --exp-dir $repo/exp/ \ + --num-encoder-layers 18 \ + --dim-feedforward 2048 \ + --nhead 8 \ + --encoder-dim 512 \ + --decoder-dim 512 \ + --joiner-dim 512 + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-9999-avg-1.onnx + - decoder-epoch-9999-avg-1.onnx + - joiner-epoch-9999-avg-1.onnx + +2. Run this file + +./pruned_transducer_stateless5/onnx_decode.py \ + --exp-dir ./pruned_transducer_stateless5/exp \ + --max-duration 600 \ + --encoder-model-filename $repo/exp/encoder-epoch-9999-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-9999-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-9999-avg-1.onnx \ +""" + + +import argparse +import logging +import time +from pathlib import Path +from typing import List, Tuple + +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule + +from onnx_pretrained import greedy_search, OnnxModel + +from icefall.utils import setup_logger, store_transcripts, write_error_stats + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder onnx model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder onnx model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner onnx model. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless5/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="Valid values are greedy_search and modified_beam_search", + ) + + return parser + + +def decode_one_batch( + model: OnnxModel, sp: spm.SentencePieceProcessor, batch: dict +) -> List[List[str]]: + """Decode one batch and return the result. + Currently it only greedy_search is supported. + + Args: + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + + Returns: + Return the decoded results for each utterance. + """ + feature = batch["inputs"] + assert feature.ndim == 3 + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(dtype=torch.int64) + + encoder_out, encoder_out_lens = model.run_encoder(x=feature, x_lens=feature_lens) + + hyps = greedy_search( + model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens + ) + + hyps = [sp.decode(h).split() for h in hyps] + return hyps + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + model: nn.Module, + sp: spm.SentencePieceProcessor, +) -> Tuple[List[Tuple[str, List[str], List[str]]], float]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + model: + The neural model. + sp: + The BPE model. + + Returns: + - A list of tuples. Each tuple contains three elements: + - cut_id, + - reference transcript, + - predicted result. + - The total duration (in seconds) of the dataset. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + log_interval = 10 + total_duration = 0 + + results = [] + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + total_duration += sum([cut.duration for cut in batch["supervisions"]["cut"]]) + + hyps = decode_one_batch(model=model, sp=sp, batch=batch) + + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results.extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + + return results, total_duration + + +def save_results( + res_dir: Path, + test_set_name: str, + results: List[Tuple[str, List[str], List[str]]], +): + recog_path = res_dir / f"recogs-{test_set_name}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = res_dir / f"errs-{test_set_name}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats(f, f"{test_set_name}", results, enable_log=True) + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + errs_info = res_dir / f"wer-summary-{test_set_name}.txt" + with open(errs_info, "w") as f: + print("WER", file=f) + print(wer, file=f) + + s = "\nFor {}, WER is {}:\n".format(test_set_name, wer) + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + + assert ( + args.decoding_method == "greedy_search" + ), "Only supports greedy_search currently." + res_dir = Path(args.exp_dir) / f"onnx-{args.decoding_method}" + + setup_logger(f"{res_dir}/log-decode") + logging.info("Decoding started") + + device = torch.device("cpu") + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + blank_id = sp.piece_to_id("") + assert blank_id == 0, blank_id + + logging.info(vars(args)) + + logging.info("About to create model") + model = OnnxModel( + encoder_model_filename=args.encoder_model_filename, + decoder_model_filename=args.decoder_model_filename, + joiner_model_filename=args.joiner_model_filename, + ) + + # we need cut ids to display recognition results. + args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + start_time = time.time() + results, total_duration = decode_dataset(dl=test_dl, model=model, sp=sp) + end_time = time.time() + elapsed_seconds = end_time - start_time + rtf = elapsed_seconds / total_duration + + logging.info(f"Elapsed time: {elapsed_seconds:.3f} s") + logging.info(f"Wave duration: {total_duration:.3f} s") + logging.info( + f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}" + ) + + save_results(res_dir=res_dir, test_set_name=test_set, results=results) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_decode.py new file mode 100755 index 000000000..67585ee47 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_decode.py @@ -0,0 +1,319 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads ONNX exported models and uses them to decode the test sets. + +We use the pre-trained model from +https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-epoch-30-avg-9.pt" + +cd exp +ln -s pretrained-epoch-30-avg-9.pt epoch-9999.pt +popd + +2. Export the model to ONNX + +./pruned_transducer_stateless7/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 9999 \ + --avg 1 \ + --exp-dir $repo/exp/ + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-9999-avg-1.onnx + - decoder-epoch-9999-avg-1.onnx + - joiner-epoch-9999-avg-1.onnx + +2. Run this file + +./pruned_transducer_stateless7/onnx_decode.py \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --encoder-model-filename $repo/exp/encoder-epoch-9999-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-9999-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-9999-avg-1.onnx \ +""" + + +import argparse +import logging +import time +from pathlib import Path +from typing import List, Tuple + +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule + +from onnx_pretrained import greedy_search, OnnxModel + +from icefall.utils import setup_logger, store_transcripts, write_error_stats + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder onnx model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder onnx model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner onnx model. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="Valid values are greedy_search and modified_beam_search", + ) + + return parser + + +def decode_one_batch( + model: OnnxModel, sp: spm.SentencePieceProcessor, batch: dict +) -> List[List[str]]: + """Decode one batch and return the result. + Currently it only greedy_search is supported. + + Args: + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + + Returns: + Return the decoded results for each utterance. + """ + feature = batch["inputs"] + assert feature.ndim == 3 + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(dtype=torch.int64) + + encoder_out, encoder_out_lens = model.run_encoder(x=feature, x_lens=feature_lens) + + hyps = greedy_search( + model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens + ) + + hyps = [sp.decode(h).split() for h in hyps] + return hyps + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + model: nn.Module, + sp: spm.SentencePieceProcessor, +) -> Tuple[List[Tuple[str, List[str], List[str]]], float]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + model: + The neural model. + sp: + The BPE model. + + Returns: + - A list of tuples. Each tuple contains three elements: + - cut_id, + - reference transcript, + - predicted result. + - The total duration (in seconds) of the dataset. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + log_interval = 10 + total_duration = 0 + + results = [] + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + total_duration += sum([cut.duration for cut in batch["supervisions"]["cut"]]) + + hyps = decode_one_batch(model=model, sp=sp, batch=batch) + + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results.extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + + return results, total_duration + + +def save_results( + res_dir: Path, + test_set_name: str, + results: List[Tuple[str, List[str], List[str]]], +): + recog_path = res_dir / f"recogs-{test_set_name}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = res_dir / f"errs-{test_set_name}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats(f, f"{test_set_name}", results, enable_log=True) + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + errs_info = res_dir / f"wer-summary-{test_set_name}.txt" + with open(errs_info, "w") as f: + print("WER", file=f) + print(wer, file=f) + + s = "\nFor {}, WER is {}:\n".format(test_set_name, wer) + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + + assert ( + args.decoding_method == "greedy_search" + ), "Only supports greedy_search currently." + res_dir = Path(args.exp_dir) / f"onnx-{args.decoding_method}" + + setup_logger(f"{res_dir}/log-decode") + logging.info("Decoding started") + + device = torch.device("cpu") + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + blank_id = sp.piece_to_id("") + assert blank_id == 0, blank_id + + logging.info(vars(args)) + + logging.info("About to create model") + model = OnnxModel( + encoder_model_filename=args.encoder_model_filename, + decoder_model_filename=args.decoder_model_filename, + joiner_model_filename=args.joiner_model_filename, + ) + + # we need cut ids to display recognition results. + args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + start_time = time.time() + results, total_duration = decode_dataset(dl=test_dl, model=model, sp=sp) + end_time = time.time() + elapsed_seconds = end_time - start_time + rtf = elapsed_seconds / total_duration + + logging.info(f"Elapsed time: {elapsed_seconds:.3f} s") + logging.info(f"Wave duration: {total_duration:.3f} s") + logging.info( + f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}" + ) + + save_results(res_dir=res_dir, test_set_name=test_set, results=results) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 94d0393c2..90b722bde 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -24,7 +24,7 @@ import torch from torch import Tensor, nn from transformer import Transformer -from icefall.utils import make_pad_mask, subsequent_chunk_mask +from icefall.utils import is_jit_tracing, make_pad_mask, subsequent_chunk_mask class Conformer(Transformer): @@ -154,7 +154,8 @@ class Conformer(Transformer): # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0 lengths = (((x_lens - 1) >> 1) - 1) >> 1 - assert x.size(0) == lengths.max().item() + if not is_jit_tracing(): + assert x.size(0) == lengths.max().item() src_key_padding_mask = make_pad_mask(lengths) @@ -768,6 +769,14 @@ class RelPositionalEncoding(torch.nn.Module): def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() + if is_jit_tracing(): + # 10k frames correspond to ~100k ms, e.g., 100 seconds, i.e., + # It assumes that the maximum input won't have more than + # 10k frames. + # + # TODO(fangjun): Use torch.jit.script() for this module + max_len = 10000 + self.d_model = d_model self.xscale = math.sqrt(self.d_model) self.dropout = torch.nn.Dropout(p=dropout_rate) @@ -975,22 +984,34 @@ class RelPositionMultiheadAttention(nn.Module): the key, while time1 is for the query). """ (batch_size, num_heads, time1, n) = x.shape + time2 = time1 + left_context + if not is_jit_tracing(): + assert ( + n == left_context + 2 * time1 - 1 + ), f"{n} == {left_context} + 2 * {time1} - 1" - assert ( - n == left_context + 2 * time1 - 1 - ), f"{n} == {left_context} + 2 * {time1} - 1" + if is_jit_tracing(): + rows = torch.arange(start=time1 - 1, end=-1, step=-1) + cols = torch.arange(time2) + rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) + indexes = rows + cols - # Note: TorchScript requires explicit arg for stride() - batch_stride = x.stride(0) - head_stride = x.stride(1) - time1_stride = x.stride(2) - n_stride = x.stride(3) - return x.as_strided( - (batch_size, num_heads, time1, time2), - (batch_stride, head_stride, time1_stride - n_stride, n_stride), - storage_offset=n_stride * (time1 - 1), - ) + x = x.reshape(-1, n) + x = torch.gather(x, dim=1, index=indexes) + x = x.reshape(batch_size, num_heads, time1, time2) + return x + else: + # Note: TorchScript requires explicit arg for stride() + batch_stride = x.stride(0) + head_stride = x.stride(1) + time1_stride = x.stride(2) + n_stride = x.stride(3) + return x.as_strided( + (batch_size, num_heads, time1, time2), + (batch_stride, head_stride, time1_stride - n_stride, n_stride), + storage_offset=n_stride * (time1 - 1), + ) def multi_head_attention_forward( self, @@ -1061,13 +1082,16 @@ class RelPositionMultiheadAttention(nn.Module): """ tgt_len, bsz, embed_dim = query.size() - assert embed_dim == embed_dim_to_check - assert key.size(0) == value.size(0) and key.size(1) == value.size(1) + if not is_jit_tracing(): + assert embed_dim == embed_dim_to_check + assert key.size(0) == value.size(0) and key.size(1) == value.size(1) head_dim = embed_dim // num_heads - assert ( - head_dim * num_heads == embed_dim - ), "embed_dim must be divisible by num_heads" + if not is_jit_tracing(): + assert ( + head_dim * num_heads == embed_dim + ), "embed_dim must be divisible by num_heads" + scaling = float(head_dim) ** -0.5 if torch.equal(query, key) and torch.equal(key, value): @@ -1181,7 +1205,8 @@ class RelPositionMultiheadAttention(nn.Module): q = q.transpose(0, 1) # (batch, time1, head, d_k) pos_emb_bsz = pos_emb.size(0) - assert pos_emb_bsz in (1, bsz) # actually it is 1 + if not is_jit_tracing(): + assert pos_emb_bsz in (1, bsz) # actually it is 1 p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim) # (batch, 2*time1, head, d_k) --> (batch, head, d_k, 2*time -1) @@ -1212,11 +1237,12 @@ class RelPositionMultiheadAttention(nn.Module): attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) - assert list(attn_output_weights.size()) == [ - bsz * num_heads, - tgt_len, - src_len, - ] + if not is_jit_tracing(): + assert list(attn_output_weights.size()) == [ + bsz * num_heads, + tgt_len, + src_len, + ] if attn_mask is not None: if attn_mask.dtype == torch.bool: @@ -1265,7 +1291,10 @@ class RelPositionMultiheadAttention(nn.Module): ) attn_output = torch.bmm(attn_output_weights, v) - assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] + + if not is_jit_tracing(): + assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] + attn_output = ( attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) From 7c7d9ab04261a69a393768a0934cb6bf7f4f38af Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Mon, 17 Apr 2023 12:03:52 +0800 Subject: [PATCH 155/174] add @torch.jit.export for streaming_forward func in Zipformer class (#1004) --- .../ASR/pruned_transducer_stateless7_streaming/zipformer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py index a5c422959..0a6886dec 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py @@ -570,6 +570,7 @@ class Zipformer(EncoderInterface): return x, lengths + @torch.jit.export def streaming_forward( self, x: torch.Tensor, From e32658e620eedeef3b82a7a1dc2f084e62f35861 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 17 Apr 2023 16:13:30 +0800 Subject: [PATCH 156/174] Fix torch.jit.script() export for streaming zipformer. (#1005) --- .../ASR/pruned_transducer_stateless7_streaming/export.py | 1 + .../ASR/pruned_transducer_stateless7_streaming/zipformer.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py index 1bc54fa26..5735ee692 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py @@ -856,6 +856,7 @@ def main(): # Otherwise, one of its arguments is a ragged tensor and is not # torch scriptabe. model.__class__.forward = torch.jit.ignore(model.__class__.forward) + model.encoder.__class__.forward = model.encoder.__class__.streaming_forward logging.info("Using torch.jit.script") model = torch.jit.script(model) filename = params.exp_dir / "cpu_jit.pt" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py index 0a6886dec..a5c422959 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py @@ -570,7 +570,6 @@ class Zipformer(EncoderInterface): return x, lengths - @torch.jit.export def streaming_forward( self, x: torch.Tensor, From 34d1b07c3da71bc8091ecd549cbb7e249f2b80ee Mon Sep 17 00:00:00 2001 From: marcoyang1998 <45973641+marcoyang1998@users.noreply.github.com> Date: Mon, 17 Apr 2023 16:43:00 +0800 Subject: [PATCH 157/174] Modified beam search with RNNLM rescoring (#1002) * add RNNLM rescore * add shallow fusion and lm rescore for streaming zipformer * minor fix * update RESULTS.md * fix yesno workflow, change from ubuntu-18.04 to ubuntu-latest --- .github/workflows/run-yesno-recipe.yml | 2 +- egs/librispeech/ASR/RESULTS.md | 64 +++++- .../ASR/lstm_transducer_stateless2/decode.py | 1 - .../beam_search.py | 198 ++++++++++++++++++ .../decode.py | 84 ++++++++ icefall/rnn_lm/model.py | 9 +- 6 files changed, 349 insertions(+), 9 deletions(-) diff --git a/.github/workflows/run-yesno-recipe.yml b/.github/workflows/run-yesno-recipe.yml index 1187dbf38..83a1d5462 100644 --- a/.github/workflows/run-yesno-recipe.yml +++ b/.github/workflows/run-yesno-recipe.yml @@ -35,7 +35,7 @@ jobs: matrix: # os: [ubuntu-18.04, macos-10.15] # TODO: enable macOS for CPU testing - os: [ubuntu-18.04] + os: [ubuntu-latest] python-version: [3.8] fail-fast: false diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index 9ca7a19b8..881e8be97 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -76,6 +76,64 @@ for m in greedy_search modified_beam_search fast_beam_search; do --num-decode-streams 2000 done ``` +We also support decoding with neural network LMs. After combining with language models, the WERs are +| decoding method | chunk size | test-clean | test-other | comment | decoding mode | +|----------------------|------------|------------|------------|---------------------|----------------------| +| modified beam search | 320ms | 3.11 | 7.93 | --epoch 30 --avg 9 | simulated streaming | +| modified beam search + RNNLM shallow fusion | 320ms | 2.58 | 6.65 | --epoch 30 --avg 9 | simulated streaming | +| modified beam search + RNNLM nbest rescore | 320ms | 2.59 | 6.86 | --epoch 30 --avg 9 | simulated streaming | + +Please use the following command for RNNLM shallow fusion: +```bash +for lm_scale in $(seq 0.15 0.01 0.38); do + for beam_size in 4 8 12; do + ./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 99 \ + --avg 1 \ + --use-averaged-model False \ + --beam-size $beam_size \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp-large-LM \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method modified_beam_search_lm_shallow_fusion \ + --use-shallow-fusion 1 \ + --lm-type rnn \ + --lm-exp-dir rnn_lm/exp \ + --lm-epoch 99 \ + --lm-scale $lm_scale \ + --lm-avg 1 \ + --rnn-lm-embedding-dim 2048 \ + --rnn-lm-hidden-dim 2048 \ + --rnn-lm-num-layers 3 \ + --lm-vocab-size 500 + done +done +``` + +Please use the following command for RNNLM rescore: +```bash +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 30 \ + --avg 9 \ + --use-averaged-model True \ + --beam-size 8 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method modified_beam_search_lm_rescore \ + --use-shallow-fusion 0 \ + --lm-type rnn \ + --lm-exp-dir rnn_lm/exp \ + --lm-epoch 99 \ + --lm-avg 1 \ + --rnn-lm-embedding-dim 2048 \ + --rnn-lm-hidden-dim 2048 \ + --rnn-lm-num-layers 3 \ + --lm-vocab-size 500 +``` + +A well-trained RNNLM can be found here: . + #### Smaller model @@ -540,9 +598,9 @@ for m in greedy_search fast_beam_search modified_beam_search ; do done ``` -Note that a small change is made to the `pruned_transducer_stateless7/decoder.py` in -this [PR](/ceph-data4/yangxiaoyu/softwares/icefall_development/icefall_random_padding/egs/librispeech/ASR/pruned_transducer_stateless7/exp_960h_no_paddingidx_ngpu4/tensorboard) to address the -problem of emitting the first symbol at the very beginning. If you need a +Note that a small change is made to the `pruned_transducer_stateless7/decoder.py` in +this [PR](/ceph-data4/yangxiaoyu/softwares/icefall_development/icefall_random_padding/egs/librispeech/ASR/pruned_transducer_stateless7/exp_960h_no_paddingidx_ngpu4/tensorboard) to address the +problem of emitting the first symbol at the very beginning. If you need a model without this issue, please download the model from here: ### LibriSpeech BPE training results (Pruned Stateless LSTM RNN-T + gradient filter) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py index 73207017b..1a724830b 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py @@ -925,7 +925,6 @@ def main(): ) LM.to(device) LM.eval() - else: LM = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 75edf0c54..c44a2ad3e 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -1059,6 +1059,204 @@ def modified_beam_search( ) +def modified_beam_search_lm_rescore( + model: Transducer, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + LM: LmScorer, + lm_scale_list: List[int], + beam: int = 4, + temperature: float = 1.0, + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: + """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. + Rescore the final results with RNNLM and return the one with the highest score + + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C). + encoder_out_lens: + A 1-D tensor of shape (N,), containing number of valid frames in + encoder_out before padding. + beam: + Number of active paths during the beam search. + temperature: + Softmax temperature. + LM: + A neural network language model + return_timestamps: + Whether to return timestamps. + Returns: + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. + """ + assert encoder_out.ndim == 3, encoder_out.shape + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + blank_id = model.decoder.blank_id + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + device = next(model.parameters()).device + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + B = [HypothesisList() for _ in range(N)] + for i in range(N): + B[i].add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + timestamp=[], + ) + ) + + encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + + offset = 0 + finalized_B = [] + for (t, batch_size) in enumerate(batch_size_list): + start = offset + end = offset + batch_size + current_encoder_out = encoder_out.data[start:end] + current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) + offset = end + + finalized_B = B[batch_size:] + finalized_B + B = B[:batch_size] + + hyps_shape = get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.cat( + [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] + ) # (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) + + # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor + # as index, so we use `to(torch.int64)` below. + current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) # (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) + + log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_ys = hyp.ys[:] + new_token = topk_token_indexes[k] + new_timestamp = hyp.timestamp[:] + if new_token not in (blank_id, unk_id): + new_ys.append(new_token) + new_timestamp.append(t) + + new_log_prob = topk_log_probs[k] + new_hyp = Hypothesis( + ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp + ) + B[i].add(new_hyp) + + B = B + finalized_B + + # get the am_scores for n-best list + hyps_shape = get_hyps_shape(B) + am_scores = torch.tensor([hyp.log_prob.item() for b in B for hyp in b]) + am_scores = k2.RaggedTensor(value=am_scores, shape=hyps_shape).to(device) + + # now LM rescore + # prepare input data to LM + candidate_seqs = [hyp.ys[context_size:] for b in B for hyp in b] + possible_seqs = k2.RaggedTensor(candidate_seqs) + row_splits = possible_seqs.shape.row_splits(1) + sentence_token_lengths = row_splits[1:] - row_splits[:-1] + possible_seqs_with_sos = add_sos(possible_seqs, sos_id=1) + possible_seqs_with_eos = add_eos(possible_seqs, eos_id=1) + sentence_token_lengths += 1 + + x = possible_seqs_with_sos.pad(mode="constant", padding_value=blank_id) + y = possible_seqs_with_eos.pad(mode="constant", padding_value=blank_id) + x = x.to(device).to(torch.int64) + y = y.to(device).to(torch.int64) + sentence_token_lengths = sentence_token_lengths.to(device).to(torch.int64) + + lm_scores = LM.lm(x=x, y=y, lengths=sentence_token_lengths) + assert lm_scores.ndim == 2 + lm_scores = -1 * lm_scores.sum(dim=1) + + ans = {} + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + + # get the best hyp with different lm_scale + for lm_scale in lm_scale_list: + key = f"nnlm_scale_{lm_scale}" + tot_scores = am_scores.values + lm_scores * lm_scale + ragged_tot_scores = k2.RaggedTensor(shape=am_scores.shape, value=tot_scores) + max_indexes = ragged_tot_scores.argmax().tolist() + unsorted_hyps = [candidate_seqs[idx] for idx in max_indexes] + hyps = [] + for idx in unsorted_indices: + hyps.append(unsorted_hyps[idx]) + + ans[key] = hyps + return ans + + def _deprecated_modified_beam_search( model: Transducer, encoder_out: torch.Tensor, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py index e7616fbc5..8aa0d8689 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py @@ -122,6 +122,8 @@ from beam_search import ( greedy_search, greedy_search_batch, modified_beam_search, + modified_beam_search_lm_rescore, + modified_beam_search_lm_shallow_fusion, ) from train import add_model_arguments, get_params, get_transducer_model @@ -132,6 +134,7 @@ from icefall.checkpoint import ( load_checkpoint, ) from icefall.lexicon import Lexicon +from icefall.lm_wrapper import LmScorer from icefall.utils import ( AttributeDict, setup_logger, @@ -307,6 +310,32 @@ def get_parser(): fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) + parser.add_argument( + "--use-shallow-fusion", + type=str2bool, + default=False, + help="""Use neural network LM for shallow fusion. + If you want to use LODR, you will also need to set this to true + """, + ) + + parser.add_argument( + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. + """, + ) + add_model_arguments(parser) return parser @@ -319,6 +348,7 @@ def decode_one_batch( batch: dict, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, + LM: Optional[LmScorer] = None, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -443,6 +473,26 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": + hyp_tokens = modified_beam_search_lm_shallow_fusion( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_rescore": + lm_scale_list = [0.01 * i for i in range(10, 50)] + ans_dict = modified_beam_search_lm_rescore( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + lm_scale_list=lm_scale_list, + ) else: batch_size = encoder_out.size(0) @@ -481,6 +531,13 @@ def decode_one_batch( key += f"_ngram_lm_scale_{params.ngram_lm_scale}" return {key: hyps} + elif params.decoding_method == "modified_beam_search_lm_rescore": + ans = dict() + assert ans_dict is not None + for key, hyps in ans_dict.items(): + hyps = [sp.decode(hyp).split() for hyp in hyps] + ans[f"beam_size_{params.beam_size}_{key}"] = hyps + return ans else: return {f"beam_size_{params.beam_size}": hyps} @@ -492,6 +549,7 @@ def decode_dataset( sp: spm.SentencePieceProcessor, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, + LM: Optional[LmScorer] = None, ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. @@ -541,6 +599,7 @@ def decode_dataset( decoding_graph=decoding_graph, word_table=word_table, batch=batch, + LM=LM, ) for name, hyps in hyps_dict.items(): @@ -603,6 +662,7 @@ def save_results( def main(): parser = get_parser() LibriSpeechAsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) @@ -617,6 +677,8 @@ def main(): "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", "modified_beam_search", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_lm_rescore", ) params.res_dir = params.exp_dir / params.decoding_method @@ -642,6 +704,14 @@ def main(): params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + if params.use_shallow_fusion: + params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}" + + if "LODR" in params.decoding_method: + params.suffix += ( + f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" + ) + if params.use_averaged_model: params.suffix += "-use-averaged-model" @@ -751,6 +821,19 @@ def main(): model.to(device) model.eval() + # only load the neural network LM if required + if params.use_shallow_fusion or "lm" in params.decoding_method: + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, + ) + LM.to(device) + LM.eval() + else: + LM = None + if "fast_beam_search" in params.decoding_method: if params.decoding_method == "fast_beam_search_nbest_LG": lexicon = Lexicon(params.lang_dir) @@ -792,6 +875,7 @@ def main(): sp=sp, word_table=word_table, decoding_graph=decoding_graph, + LM=LM, ) save_results( diff --git a/icefall/rnn_lm/model.py b/icefall/rnn_lm/model.py index 08eb753b5..ebb3128e3 100644 --- a/icefall/rnn_lm/model.py +++ b/icefall/rnn_lm/model.py @@ -154,17 +154,18 @@ class RnnLmModel(torch.nn.Module): self.cache = {} def score_token(self, x: torch.Tensor, x_lens: torch.Tensor, state=None): - """Score a batch of tokens + """Score a batch of tokens, i.e each sample in the batch should be a + single token. For example, x = torch.tensor([[5],[10],[20]]) + Args: x (torch.Tensor): A batch of tokens x_lens (torch.Tensor): The length of tokens in the batch before padding - state (_type_, optional): + state (optional): Either None or a tuple of two torch.Tensor. Each tensor has - the shape of (hidden_dim) - + the shape of (num_layers, bs, hidden_dim) Returns: _type_: _description_ From 8838fe0bd22e33820791220c28fda6fb7d3cf27f Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Mon, 17 Apr 2023 17:47:25 +0800 Subject: [PATCH 158/174] Zipformer for Common Voice (#997) * Add soft links in pruned_transducer_stateless7 for CommonVoice * Add python files * Update prepare.sh * Update normalization * Fix for soft links * Add some docs * Add export * Update egs/commonvoice/ASR/RESULTS.md Co-authored-by: Fangjun Kuang * Add export for onnx --------- Co-authored-by: Fangjun Kuang --- egs/commonvoice/ASR/README.md | 18 + egs/commonvoice/ASR/RESULTS.md | 59 + .../local/compute_fbank_commonvoice_splits.py | 2 +- egs/commonvoice/ASR/local/prepare_lang_bpe.py | 1 + .../ASR/local/preprocess_commonvoice.py | 23 +- egs/commonvoice/ASR/local/train_bpe_model.py | 1 + .../ASR/local/validate_bpe_lexicon.py | 1 + egs/commonvoice/ASR/prepare.sh | 92 +- .../pruned_transducer_stateless7/__init__.py | 0 .../asr_datamodule.py | 420 ++++++ .../beam_search.py | 1 + .../pruned_transducer_stateless7/decode.py | 962 +++++++++++++ .../pruned_transducer_stateless7/decoder.py | 1 + .../encoder_interface.py | 1 + .../export-onnx.py | 600 ++++++++ .../pruned_transducer_stateless7/export.py | 321 +++++ .../pruned_transducer_stateless7/joiner.py | 1 + .../ASR/pruned_transducer_stateless7/model.py | 1 + .../onnx_check.py | 240 ++++ .../onnx_pretrained.py | 419 ++++++ .../ASR/pruned_transducer_stateless7/optim.py | 1 + .../pretrained.py | 355 +++++ .../pruned_transducer_stateless7/scaling.py | 1 + .../ASR/pruned_transducer_stateless7/train.py | 1250 +++++++++++++++++ .../pruned_transducer_stateless7/zipformer.py | 1 + egs/commonvoice/ASR/scaling_converter.py | 1 + 26 files changed, 4764 insertions(+), 9 deletions(-) create mode 100644 egs/commonvoice/ASR/README.md create mode 100644 egs/commonvoice/ASR/RESULTS.md create mode 120000 egs/commonvoice/ASR/local/prepare_lang_bpe.py create mode 120000 egs/commonvoice/ASR/local/train_bpe_model.py create mode 120000 egs/commonvoice/ASR/local/validate_bpe_lexicon.py create mode 100644 egs/commonvoice/ASR/pruned_transducer_stateless7/__init__.py create mode 100644 egs/commonvoice/ASR/pruned_transducer_stateless7/asr_datamodule.py create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7/beam_search.py create mode 100755 egs/commonvoice/ASR/pruned_transducer_stateless7/decode.py create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7/decoder.py create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7/encoder_interface.py create mode 100755 egs/commonvoice/ASR/pruned_transducer_stateless7/export-onnx.py create mode 100755 egs/commonvoice/ASR/pruned_transducer_stateless7/export.py create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7/joiner.py create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7/model.py create mode 100755 egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_check.py create mode 100755 egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_pretrained.py create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7/optim.py create mode 100755 egs/commonvoice/ASR/pruned_transducer_stateless7/pretrained.py create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7/scaling.py create mode 100755 egs/commonvoice/ASR/pruned_transducer_stateless7/train.py create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7/zipformer.py create mode 120000 egs/commonvoice/ASR/scaling_converter.py diff --git a/egs/commonvoice/ASR/README.md b/egs/commonvoice/ASR/README.md new file mode 100644 index 000000000..a4582499b --- /dev/null +++ b/egs/commonvoice/ASR/README.md @@ -0,0 +1,18 @@ +# Introduction + +This recipe includes some different ASR models trained with Common Voice + +[./RESULTS.md](./RESULTS.md) contains the latest results. + +# Transducers + +There are various folders containing the name `transducer` in this folder. +The following table lists the differences among them. + +| | Encoder | Decoder | Comment | +|---------------------------------------|---------------------|--------------------|---------------------------------------------------| +| `pruned_transducer_stateless7` | Zipformer | Embedding + Conv1d | First experiment with Zipformer from Dan | + +The decoder in `transducer_stateless` is modified from the paper +[RNN-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/). +We place an additional Conv1d layer right after the input embedding layer. diff --git a/egs/commonvoice/ASR/RESULTS.md b/egs/commonvoice/ASR/RESULTS.md new file mode 100644 index 000000000..751625371 --- /dev/null +++ b/egs/commonvoice/ASR/RESULTS.md @@ -0,0 +1,59 @@ +## Results +### GigaSpeech BPE training results (Pruned Stateless Transducer 7) + +#### [pruned_transducer_stateless7](./pruned_transducer_stateless7) + +See #997 for more details. + +Number of model parameters: 70369391, i.e., 70.37 M + +The best WER, as of 2023-04-17, for Common Voice English 13.0 (cv-corpus-13.0-2023-03-09/en) is below: + +Results are: + +| | Dev | Test | +|----------------------|-------|-------| +| greedy search | 9.96 | 12.54 | +| modified beam search | 9.86 | 12.48 | + +To reproduce the above result, use the following commands for training: + +```bash +export CUDA_VISIBLE_DEVICES="0,1,2,3" +./pruned_transducer_stateless7/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7/exp \ + --max-duration 550 +``` + +and the following commands for decoding: + +```bash +# greedy search +./pruned_transducer_stateless7/decode.py \ + --epoch 30 \ + --avg 5 \ + --decoding-method greedy_search \ + --exp-dir pruned_transducer_stateless7/exp \ + --bpe-model data/en/lang_bpe_500/bpe.model \ + --max-duration 600 + +# modified beam search +./pruned_transducer_stateless7/decode.py \ + --epoch 30 \ + --avg 5 \ + --decoding-method modified_beam_search \ + --beam-size 4 \ + --exp-dir pruned_transducer_stateless7/exp \ + --bpe-model data/en/lang_bpe_500/bpe.model \ + --max-duration 600 +``` + +Pretrained model is available at + + +The tensorboard log for training is available at + diff --git a/egs/commonvoice/ASR/local/compute_fbank_commonvoice_splits.py b/egs/commonvoice/ASR/local/compute_fbank_commonvoice_splits.py index f8c09ccf3..0564f6ec6 100755 --- a/egs/commonvoice/ASR/local/compute_fbank_commonvoice_splits.py +++ b/egs/commonvoice/ASR/local/compute_fbank_commonvoice_splits.py @@ -90,7 +90,7 @@ def compute_fbank_commonvoice_splits(args): subset = "train" num_splits = args.num_splits language = args.language - output_dir = f"data/{language}/fbank/{subset}_split_{num_splits}" + output_dir = f"data/{language}/fbank/cv-{language}_{subset}_split_{num_splits}" output_dir = Path(output_dir) assert output_dir.exists(), f"{output_dir} does not exist!" diff --git a/egs/commonvoice/ASR/local/prepare_lang_bpe.py b/egs/commonvoice/ASR/local/prepare_lang_bpe.py new file mode 120000 index 000000000..36b40e7fc --- /dev/null +++ b/egs/commonvoice/ASR/local/prepare_lang_bpe.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lang_bpe.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/local/preprocess_commonvoice.py b/egs/commonvoice/ASR/local/preprocess_commonvoice.py index 241e81963..c5ec14502 100755 --- a/egs/commonvoice/ASR/local/preprocess_commonvoice.py +++ b/egs/commonvoice/ASR/local/preprocess_commonvoice.py @@ -17,6 +17,7 @@ import argparse import logging +import re from pathlib import Path from typing import Optional @@ -42,6 +43,11 @@ def get_args(): return parser.parse_args() +def normalize_text(utt: str) -> str: + utt = re.sub(r"[{0}]+".format("-"), " ", utt) + return re.sub(r"[^a-zA-Z\s]", "", utt).upper() + + def preprocess_commonvoice( language: str, dataset: Optional[str] = None, @@ -84,6 +90,17 @@ def preprocess_commonvoice( logging.info(f"{partition} already exists - skipping") continue + logging.info(f"Normalizing text in {partition}") + for sup in m["supervisions"]: + text = str(sup.text) + orig_text = text + sup.text = normalize_text(sup.text) + text = str(sup.text) + if len(orig_text) != len(text): + logging.info( + f"\nOriginal text vs normalized text:\n{orig_text}\n{text}" + ) + # Create long-recording cut manifests. cut_set = CutSet.from_manifests( recordings=m["recordings"], @@ -92,12 +109,6 @@ def preprocess_commonvoice( # Run data augmentation that needs to be done in the # time domain. - if "train" in partition: - logging.info( - f"Speed perturb for {partition} with factors 0.9 and 1.1 " - "(Perturbing may take 2 minutes and saving may take 7 minutes)" - ) - cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) logging.info(f"Saving to {raw_cuts_path}") cut_set.to_file(raw_cuts_path) diff --git a/egs/commonvoice/ASR/local/train_bpe_model.py b/egs/commonvoice/ASR/local/train_bpe_model.py new file mode 120000 index 000000000..6fad36421 --- /dev/null +++ b/egs/commonvoice/ASR/local/train_bpe_model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/train_bpe_model.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/local/validate_bpe_lexicon.py b/egs/commonvoice/ASR/local/validate_bpe_lexicon.py new file mode 120000 index 000000000..721bb48e7 --- /dev/null +++ b/egs/commonvoice/ASR/local/validate_bpe_lexicon.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/validate_bpe_lexicon.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/prepare.sh b/egs/commonvoice/ASR/prepare.sh index 9a28167b1..7a583f9c8 100755 --- a/egs/commonvoice/ASR/prepare.sh +++ b/egs/commonvoice/ASR/prepare.sh @@ -126,7 +126,7 @@ fi if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then log "Stage 5: Split train subset into ${num_splits} pieces" - split_dir=data/${lang}/fbank/train_split_${num_splits} + split_dir=data/${lang}/fbank/cv-${lang}_train_split_${num_splits} if [ ! -e $split_dir/.cv-${lang}_train_split.done ]; then lhotse split $num_splits ./data/${lang}/fbank/cv-${lang}_cuts_train_raw.jsonl.gz $split_dir touch $split_dir/.cv-${lang}_train_split.done @@ -147,10 +147,98 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then fi if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then - log "Stage 7: Compute fbank for musan" + log "Stage 7: Combine features for train" + if [ ! -f data/${lang}/fbank/cv-${lang}_cuts_train.jsonl.gz ]; then + pieces=$(find data/${lang}/fbank/cv-${lang}_train_split_${num_splits} -name "cv-${lang}_cuts_train.*.jsonl.gz") + lhotse combine $pieces data/${lang}/fbank/cv-${lang}_cuts_train.jsonl.gz + fi +fi + +if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then + log "Stage 8: Compute fbank for musan" mkdir -p data/fbank if [ ! -e data/fbank/.musan.done ]; then ./local/compute_fbank_musan.py touch data/fbank/.musan.done fi fi + +if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then + log "Stage 9: Prepare BPE based lang" + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/${lang}/lang_bpe_${vocab_size} + mkdir -p $lang_dir + + if [ ! -f $lang_dir/transcript_words.txt ]; then + log "Generate data for BPE training" + file=$( + find "data/${lang}/fbank/cv-${lang}_cuts_train.jsonl.gz" + ) + gunzip -c ${file} | awk -F '"' '{print $30}' > $lang_dir/transcript_words.txt + + # Ensure space only appears once + sed -i 's/\t/ /g' $lang_dir/transcript_words.txt + sed -i 's/[ ][ ]*/ /g' $lang_dir/transcript_words.txt + fi + + if [ ! -f $lang_dir/words.txt ]; then + cat $lang_dir/transcript_words.txt | sed 's/ /\n/g' \ + | sort -u | sed '/^$/d' > $lang_dir/words.txt + (echo '!SIL'; echo ''; echo ''; ) | + cat - $lang_dir/words.txt | sort | uniq | awk ' + BEGIN { + print " 0"; + } + { + if ($1 == "") { + print " is in the vocabulary!" | "cat 1>&2" + exit 1; + } + if ($1 == "") { + print " is in the vocabulary!" | "cat 1>&2" + exit 1; + } + printf("%s %d\n", $1, NR); + } + END { + printf("#0 %d\n", NR+1); + printf(" %d\n", NR+2); + printf(" %d\n", NR+3); + }' > $lang_dir/words || exit 1; + mv $lang_dir/words $lang_dir/words.txt + fi + + if [ ! -f $lang_dir/bpe.model ]; then + ./local/train_bpe_model.py \ + --lang-dir $lang_dir \ + --vocab-size $vocab_size \ + --transcript $lang_dir/transcript_words.txt + fi + + if [ ! -f $lang_dir/L_disambig.pt ]; then + ./local/prepare_lang_bpe.py --lang-dir $lang_dir + + log "Validating $lang_dir/lexicon.txt" + ./local/validate_bpe_lexicon.py \ + --lexicon $lang_dir/lexicon.txt \ + --bpe-model $lang_dir/bpe.model + fi + + if [ ! -f $lang_dir/L.fst ]; then + log "Converting L.pt to L.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L.pt \ + $lang_dir/L.fst + fi + + if [ ! -f $lang_dir/L_disambig.fst ]; then + log "Converting L_disambig.pt to L_disambig.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L_disambig.pt \ + $lang_dir/L_disambig.fst + fi + done +fi diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/__init__.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/asr_datamodule.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/asr_datamodule.py new file mode 100644 index 000000000..2c37244a4 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/asr_datamodule.py @@ -0,0 +1,420 @@ +# Copyright 2023 Xiaomi Corp. (authors: Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import inspect +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SingleCutSampler, + SpecAugment, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class CommonVoiceAsrDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. CommonVoice test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--language", + type=str, + default="en", + help="""Language of Common Voice""", + ) + group.add_argument( + "--cv-manifest-dir", + type=Path, + default=Path("data/en/fbank"), + help="Path to directory with CommonVoice train/dev/test cuts.", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with the other cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=80, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + transforms.append( + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. + num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + logging.info(f"Num frame mask: {num_frame_masks}") + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = K2SpeechRecognitionDataset( + input_strategy=eval(self.args.input_strategy)(), + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SingleCutSampler.") + train_sampler = SingleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + return_cuts=self.args.return_cuts, + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info("About to get train cuts") + return load_manifest_lazy( + self.args.cv_manifest_dir / f"cv-{self.args.language}_cuts_train.jsonl.gz" + ) + + @lru_cache() + def dev_cuts(self) -> CutSet: + logging.info("About to get dev cuts") + return load_manifest_lazy( + self.args.cv_manifest_dir / f"cv-{self.args.language}_cuts_dev.jsonl.gz" + ) + + @lru_cache() + def test_cuts(self) -> CutSet: + logging.info("About to get test cuts") + return load_manifest_lazy( + self.args.cv_manifest_dir / f"cv-{self.args.language}_cuts_test.jsonl.gz" + ) diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/beam_search.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/beam_search.py new file mode 120000 index 000000000..e24eca39f --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/decode.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/decode.py new file mode 100755 index 000000000..52b2fbcab --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/decode.py @@ -0,0 +1,962 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(8) modified beam search with RNNLM shallow fusion +./pruned_transducer_stateless5/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless5/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search_lm_shallow_fusion \ + --beam-size 4 \ + --lm-type rnn \ + --lm-scale 0.3 \ + --lm-exp-dir /path/to/LM \ + --rnn-lm-epoch 99 \ + --rnn-lm-avg 1 \ + --rnn-lm-num-layers 3 \ + --rnn-lm-tie-weights 1 + +(9) modified beam search with LM shallow fusion + LODR +./pruned_transducer_stateless5/decode.py \ + --epoch 28 \ + --avg 15 \ + --max-duration 600 \ + --exp-dir ./pruned_transducer_stateless5/exp \ + --decoding-method modified_beam_search_LODR \ + --beam-size 4 \ + --lm-type rnn \ + --lm-scale 0.4 \ + --lm-exp-dir /path/to/LM \ + --rnn-lm-epoch 99 \ + --rnn-lm-avg 1 \ + --rnn-lm-num-layers 3 \ + --rnn-lm-tie-weights 1 + --tokens-ngram 2 \ + --ngram-lm-scale -0.16 \ + +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import CommonVoiceAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, + modified_beam_search_lm_shallow_fusion, + modified_beam_search_LODR, + modified_beam_search_ngram_rescoring, +) +from train import add_model_arguments, get_params, get_transducer_model + +from icefall import LmScorer, NgramLm +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/en/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/en/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + - modified_beam_search_lm_shallow_fusion # for rnn lm shallow fusion + - modified_beam_search_LODR + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--use-shallow-fusion", + type=str2bool, + default=False, + help="""Use neural network LM for shallow fusion. + If you want to use LODR, you will also need to set this to true + """, + ) + + parser.add_argument( + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. + """, + ) + + parser.add_argument( + "--tokens-ngram", + type=int, + default=3, + help="""Token Ngram used for rescoring. + Used only when the decoding method is + modified_beam_search_ngram_rescoring, or LODR + """, + ) + + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="""ID of the backoff symbol. + Used only when the decoding method is + modified_beam_search_ngram_rescoring""", + ) + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + ngram_lm: Optional[NgramLm] = None, + ngram_lm_scale: float = 1.0, + LM: Optional[LmScorer] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural net LM for shallow fusion. Only used when `--use-shallow-fusion` + set to true. + ngram_lm: + A ngram lm. Used in LODR decoding. + ngram_lm_scale: + The scale of the ngram language model. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": + hyp_tokens = modified_beam_search_lm_shallow_fusion( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_LODR": + hyp_tokens = modified_beam_search_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LODR_lm=ngram_lm, + LODR_lm_scale=ngram_lm_scale, + LM=LM, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + ngram_lm: Optional[NgramLm] = None, + ngram_lm_scale: float = 1.0, + LM: Optional[LmScorer] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural network LM, used during shallow fusion + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + LM=LM, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + CommonVoiceAsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_LODR", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if "ngram" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + if params.use_shallow_fusion: + if params.lm_type == "rnn": + params.suffix += f"-rnnlm-lm-scale-{params.lm_scale}" + elif params.lm_type == "transformer": + params.suffix += f"-transformer-lm-scale-{params.lm_scale}" + + if "LODR" in params.decoding_method: + params.suffix += ( + f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" + ) + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + # only load N-gram LM when needed + if "ngram" in params.decoding_method or "LODR" in params.decoding_method: + lm_filename = f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"lm filename: {lm_filename}") + ngram_lm = NgramLm( + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {ngram_lm.lm.num_states}") + ngram_lm_scale = params.ngram_lm_scale + else: + ngram_lm = None + ngram_lm_scale = None + + # only load the neural network LM if doing shallow fusion + if params.use_shallow_fusion: + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, + ) + LM.to(device) + LM.eval() + + else: + LM = None + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + commonvoice = CommonVoiceAsrDataModule(args) + + dev_cuts = commonvoice.dev_cuts() + test_cuts = commonvoice.test_cuts() + + dev_dl = commonvoice.valid_dataloaders(dev_cuts) + test_dl = commonvoice.test_dataloaders(test_cuts) + + test_sets = ["dev", "test"] + test_dl = [dev_dl, test_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + LM=LM, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/decoder.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/decoder.py new file mode 120000 index 000000000..8283d8c5a --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/decoder.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/decoder.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/encoder_interface.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/encoder_interface.py new file mode 120000 index 000000000..653c5b09a --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/encoder_interface.py @@ -0,0 +1 @@ +../../../librispeech/ASR/transducer_stateless/encoder_interface.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/export-onnx.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/export-onnx.py new file mode 100755 index 000000000..0c98885ac --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/export-onnx.py @@ -0,0 +1,600 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Yifan Yang) + +""" +This script exports a transducer model from PyTorch to ONNX. + +We use the pre-trained model from +https://huggingface.co/yfyeung/icefall-asr-cv-corpus-13.0-2023-03-09-en-pruned-transducer-stateless7-2023-04-17 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/yfyeung/icefall-asr-cv-corpus-13.0-2023-03-09-en-pruned-transducer-stateless7-2023-04-17 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-9999.pt +popd + +2. Export the model to ONNX + +./pruned_transducer_stateless7/export-onnx.py \ + --bpe-model $repo/data/en/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 9999 \ + --avg 1 \ + --exp-dir $repo/exp + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-9999-avg-1.onnx + - decoder-epoch-9999-avg-1.onnx + - joiner-epoch-9999-avg-1.onnx + +See ./onnx_pretrained.py and ./onnx_check.py for how to +use the exported ONNX models. +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, Tuple + +import onnx +import sentencepiece as spm +import torch +import torch.nn as nn +from decoder import Decoder +from onnxruntime.quantization import QuantType, quantize_dynamic +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_params, get_transducer_model +from zipformer import Zipformer + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/en/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +class OnnxEncoder(nn.Module): + """A wrapper for Zipformer and the encoder_proj from the joiner""" + + def __init__(self, encoder: Zipformer, encoder_proj: nn.Linear): + """ + Args: + encoder: + A Zipformer encoder. + encoder_proj: + The projection layer for encoder from the joiner. + """ + super().__init__() + self.encoder = encoder + self.encoder_proj = encoder_proj + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Please see the help information of Zipformer.forward + + Args: + x: + A 3-D tensor of shape (N, T, C) + x_lens: + A 1-D tensor of shape (N,). Its dtype is torch.int64 + Returns: + Return a tuple containing: + - encoder_out, A 3-D tensor of shape (N, T', joiner_dim) + - encoder_out_lens, A 1-D tensor of shape (N,) + """ + encoder_out, encoder_out_lens = self.encoder(x, x_lens) + + encoder_out = self.encoder_proj(encoder_out) + # Now encoder_out is of shape (N, T, joiner_dim) + + return encoder_out, encoder_out_lens + + +class OnnxDecoder(nn.Module): + """A wrapper for Decoder and the decoder_proj from the joiner""" + + def __init__(self, decoder: Decoder, decoder_proj: nn.Linear): + super().__init__() + self.decoder = decoder + self.decoder_proj = decoder_proj + + def forward(self, y: torch.Tensor) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, context_size). + Returns + Return a 2-D tensor of shape (N, joiner_dim) + """ + need_pad = False + decoder_output = self.decoder(y, need_pad=need_pad) + decoder_output = decoder_output.squeeze(1) + output = self.decoder_proj(decoder_output) + + return output + + +class OnnxJoiner(nn.Module): + """A wrapper for the joiner""" + + def __init__(self, output_linear: nn.Linear): + super().__init__() + self.output_linear = output_linear + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + logit = encoder_out + decoder_out + logit = self.output_linear(torch.tanh(logit)) + return logit + + +def export_encoder_model_onnx( + encoder_model: OnnxEncoder, + encoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the given encoder model to ONNX format. + The exported model has two inputs: + + - x, a tensor of shape (N, T, C); dtype is torch.float32 + - x_lens, a tensor of shape (N,); dtype is torch.int64 + + and it has two outputs: + + - encoder_out, a tensor of shape (N, T', joiner_dim) + - encoder_out_lens, a tensor of shape (N,) + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + x = torch.zeros(1, 100, 80, dtype=torch.float32) + x_lens = torch.tensor([100], dtype=torch.int64) + + torch.onnx.export( + encoder_model, + (x, x_lens), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "x_lens"], + output_names=["encoder_out", "encoder_out_lens"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "x_lens": {0: "N"}, + "encoder_out": {0: "N", 1: "T"}, + "encoder_out_lens": {0: "N"}, + }, + ) + + meta_data = { + "model_type": "zipformer", + "version": "1", + "model_author": "k2-fsa", + "comment": "stateless7", + } + logging.info(f"meta_data: {meta_data}") + + add_meta_data(filename=encoder_filename, meta_data=meta_data) + + +def export_decoder_model_onnx( + decoder_model: OnnxDecoder, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, decoder_model.context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, joiner_dim) + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + context_size = decoder_model.decoder.context_size + vocab_size = decoder_model.decoder.vocab_size + + y = torch.zeros(10, context_size, dtype=torch.int64) + torch.onnx.export( + decoder_model, + y, + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + + meta_data = { + "context_size": str(context_size), + "vocab_size": str(vocab_size), + } + add_meta_data(filename=decoder_filename, meta_data=meta_data) + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. + The exported joiner model has two inputs: + + - encoder_out: a tensor of shape (N, joiner_dim) + - decoder_out: a tensor of shape (N, joiner_dim) + + and produces one output: + + - logit: a tensor of shape (N, vocab_size) + """ + joiner_dim = joiner_model.output_linear.weight.shape[1] + logging.info(f"joiner dim: {joiner_dim}") + + projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + + torch.onnx.export( + joiner_model, + (projected_encoder_out, projected_decoder_out), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "encoder_out", + "decoder_out", + ], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + meta_data = { + "joiner_dim": str(joiner_dim), + } + add_meta_data(filename=joiner_filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + setup_logger(f"{params.exp_dir}/log-export/log-export-onnx") + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True) + + encoder = OnnxEncoder( + encoder=model.encoder, + encoder_proj=model.joiner.encoder_proj, + ) + + decoder = OnnxDecoder( + decoder=model.decoder, + decoder_proj=model.joiner.decoder_proj, + ) + + joiner = OnnxJoiner(output_linear=model.joiner.output_linear) + + encoder_num_param = sum([p.numel() for p in encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + if params.iter > 0: + suffix = f"iter-{params.iter}" + else: + suffix = f"epoch-{params.epoch}" + + suffix += f"-avg-{params.avg}" + + opset_version = 13 + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" + export_encoder_model_onnx( + encoder, + encoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported encoder to {encoder_filename}") + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" + export_decoder_model_onnx( + decoder, + decoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported decoder to {decoder_filename}") + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" + export_joiner_model_onnx( + joiner, + joiner_filename, + opset_version=opset_version, + ) + logging.info(f"Exported joiner to {joiner_filename}") + + # Generate int8 quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + + logging.info("Generate int8 quantization models") + + encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=encoder_filename, + model_output=encoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=decoder_filename, + model_output=decoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" + quantize_dynamic( + model_input=joiner_filename, + model_output=joiner_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + main() diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/export.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/export.py new file mode 100755 index 000000000..53705321e --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/export.py @@ -0,0 +1,321 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" + +Usage: + +(1) Export to torchscript model using torch.jit.script() + +./pruned_transducer_stateless7/export.py \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --bpe-model data/en/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 5 \ + --jit 1 + +It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later +load it by `torch.jit.load("cpu_jit.pt")`. + +Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python +are on CPU. You can use `to("cuda")` to move them to a CUDA device. + +Check +https://github.com/k2-fsa/sherpa +for how to use the exported models outside of icefall. + +(2) Export `model.state_dict()` + +./pruned_transducer_stateless7/export.py \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --bpe-model data/en/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 5 + +It will generate a file `pretrained.pt` in the given `exp_dir`. You can later +load it by `icefall.checkpoint.load_checkpoint()`. + +To use the generated file with `pruned_transducer_stateless7/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/commonvoice/ASR + ./pruned_transducer_stateless7/decode.py \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --decoding-method greedy_search \ + --bpe-model data/en/lang_bpe_500/bpe.model + +Check ./pretrained.py for its usage. + +Note: If you don't want to train a model from scratch, we have +provided one for you. You can get it at + +https://huggingface.co/yfyeung/icefall-asr-cv-corpus-13.0-2023-03-09-en-pruned-transducer-stateless7-2023-04-17 + +with the following commands: + + sudo apt-get install git-lfs + git lfs install + git clone https://huggingface.co/yfyeung/icefall-asr-cv-corpus-13.0-2023-03-09-en-pruned-transducer-stateless7-2023-04-17 + # You will find the pre-trained model in icefall-asr-cv-corpus-13.0-2023-03-09-en-pruned-transducer-stateless7-2023-04-17/exp +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +import torch +import torch.nn as nn +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/en/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + It will generate a file named cpu_jit.pt + + Check ./jit_pretrained.py for how to use it. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + if params.jit is True: + convert_scaled_to_non_scaled(model, inplace=True) + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. + model.__class__.forward = torch.jit.ignore(model.__class__.forward) + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + filename = params.exp_dir / "cpu_jit.pt" + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torchscript. Export model.state_dict()") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/joiner.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/joiner.py new file mode 120000 index 000000000..0f0c3c90a --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/joiner.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/joiner.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/model.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/model.py new file mode 120000 index 000000000..0d8bc665b --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/model.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_check.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_check.py new file mode 100755 index 000000000..19c518eaf --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_check.py @@ -0,0 +1,240 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script checks that exported onnx models produce the same output +with the given torchscript model for the same input. + +We use the pre-trained model from +https://huggingface.co/yfyeung/icefall-asr-cv-corpus-13.0-2023-03-09-en-pruned-transducer-stateless7-2023-04-17 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/yfyeung/icefall-asr-cv-corpus-13.0-2023-03-09-en-pruned-transducer-stateless7-2023-04-17 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-9999.pt +popd + +2. Export the model via torchscript (torch.jit.script()) + +./pruned_transducer_stateless7/export.py \ + --bpe-model $repo/data/en/lang_bpe_500/bpe.model \ + --epoch 9999 \ + --avg 1 \ + --exp-dir $repo/exp/ \ + --jit 1 + +It will generate the following file in $repo/exp: + - cpu_jit.pt + +3. Export the model to ONNX + +./pruned_transducer_stateless7/export-onnx.py \ + --bpe-model $repo/data/en/lang_bpe_500/bpe.model \ + --epoch 9999 \ + --avg 1 \ + --exp-dir $repo/exp/ + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-9999-avg-1.onnx + - decoder-epoch-9999-avg-1.onnx + - joiner-epoch-9999-avg-1.onnx + +4. Run this file + +./pruned_transducer_stateless7/onnx_check.py \ + --jit-filename $repo/exp/cpu_jit.pt \ + --onnx-encoder-filename $repo/exp/encoder-epoch-9999-avg-1.onnx \ + --onnx-decoder-filename $repo/exp/decoder-epoch-9999-avg-1.onnx \ + --onnx-joiner-filename $repo/exp/joiner-epoch-9999-avg-1.onnx +""" + +import argparse +import logging + +from icefall import is_module_available +from onnx_pretrained import OnnxModel + +import torch + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--jit-filename", + required=True, + type=str, + help="Path to the torchscript model", + ) + + parser.add_argument( + "--onnx-encoder-filename", + required=True, + type=str, + help="Path to the onnx encoder model", + ) + + parser.add_argument( + "--onnx-decoder-filename", + required=True, + type=str, + help="Path to the onnx decoder model", + ) + + parser.add_argument( + "--onnx-joiner-filename", + required=True, + type=str, + help="Path to the onnx joiner model", + ) + + return parser + + +def test_encoder( + torch_model: torch.jit.ScriptModule, + onnx_model: OnnxModel, +): + C = 80 + for i in range(3): + N = torch.randint(low=1, high=20, size=(1,)).item() + T = torch.randint(low=30, high=50, size=(1,)).item() + logging.info(f"test_encoder: iter {i}, N={N}, T={T}") + + x = torch.rand(N, T, C) + x_lens = torch.randint(low=30, high=T + 1, size=(N,)) + x_lens[0] = T + + torch_encoder_out, torch_encoder_out_lens = torch_model.encoder(x, x_lens) + torch_encoder_out = torch_model.joiner.encoder_proj(torch_encoder_out) + + onnx_encoder_out, onnx_encoder_out_lens = onnx_model.run_encoder(x, x_lens) + + assert torch.allclose(torch_encoder_out, onnx_encoder_out, atol=1e-05), ( + (torch_encoder_out - onnx_encoder_out).abs().max() + ) + + +def test_decoder( + torch_model: torch.jit.ScriptModule, + onnx_model: OnnxModel, +): + context_size = onnx_model.context_size + vocab_size = onnx_model.vocab_size + for i in range(10): + N = torch.randint(1, 100, size=(1,)).item() + logging.info(f"test_decoder: iter {i}, N={N}") + x = torch.randint( + low=1, + high=vocab_size, + size=(N, context_size), + dtype=torch.int64, + ) + torch_decoder_out = torch_model.decoder(x, need_pad=torch.tensor([False])) + torch_decoder_out = torch_model.joiner.decoder_proj(torch_decoder_out) + torch_decoder_out = torch_decoder_out.squeeze(1) + + onnx_decoder_out = onnx_model.run_decoder(x) + assert torch.allclose(torch_decoder_out, onnx_decoder_out, atol=1e-4), ( + (torch_decoder_out - onnx_decoder_out).abs().max() + ) + + +def test_joiner( + torch_model: torch.jit.ScriptModule, + onnx_model: OnnxModel, +): + encoder_dim = torch_model.joiner.encoder_proj.weight.shape[1] + decoder_dim = torch_model.joiner.decoder_proj.weight.shape[1] + for i in range(10): + N = torch.randint(1, 100, size=(1,)).item() + logging.info(f"test_joiner: iter {i}, N={N}") + encoder_out = torch.rand(N, encoder_dim) + decoder_out = torch.rand(N, decoder_dim) + + projected_encoder_out = torch_model.joiner.encoder_proj(encoder_out) + projected_decoder_out = torch_model.joiner.decoder_proj(decoder_out) + + torch_joiner_out = torch_model.joiner(encoder_out, decoder_out) + onnx_joiner_out = onnx_model.run_joiner( + projected_encoder_out, projected_decoder_out + ) + + assert torch.allclose(torch_joiner_out, onnx_joiner_out, atol=1e-4), ( + (torch_joiner_out - onnx_joiner_out).abs().max() + ) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + logging.info(vars(args)) + + torch_model = torch.jit.load(args.jit_filename) + + onnx_model = OnnxModel( + encoder_model_filename=args.onnx_encoder_filename, + decoder_model_filename=args.onnx_decoder_filename, + joiner_model_filename=args.onnx_joiner_filename, + ) + + logging.info("Test encoder") + test_encoder(torch_model, onnx_model) + + logging.info("Test decoder") + test_decoder(torch_model, onnx_model) + + logging.info("Test joiner") + test_joiner(torch_model, onnx_model) + logging.info("Finished checking ONNX models") + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +# See https://github.com/pytorch/pytorch/issues/38342 +# and https://github.com/pytorch/pytorch/issues/33354 +# +# If we don't do this, the delay increases whenever there is +# a new request that changes the actual batch size. +# If you use `py-spy dump --pid --native`, you will +# see a lot of time is spent in re-compiling the torch script model. +torch._C._jit_set_profiling_executor(False) +torch._C._jit_set_profiling_mode(False) +torch._C._set_graph_executor_optimize(False) +if __name__ == "__main__": + torch.manual_seed(20220727) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_pretrained.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_pretrained.py new file mode 100755 index 000000000..eee19191e --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_pretrained.py @@ -0,0 +1,419 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads ONNX models and uses them to decode waves. +You can use the following command to get the exported models: + +We use the pre-trained model from +https://huggingface.co/yfyeung/icefall-asr-cv-corpus-13.0-2023-03-09-en-pruned-transducer-stateless7-2023-04-17 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/yfyeung/icefall-asr-cv-corpus-13.0-2023-03-09-en-pruned-transducer-stateless7-2023-04-17 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-9999.pt +popd + +2. Export the model to ONNX + +./pruned_transducer_stateless7/export-onnx.py \ + --bpe-model $repo/data/en/lang_bpe_500/bpe.model \ + --epoch 9999 \ + --avg 1 \ + --exp-dir $repo/exp/ + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-9999-avg-1.onnx + - decoder-epoch-9999-avg-1.onnx + - joiner-epoch-9999-avg-1.onnx + +3. Run this file + +./pruned_transducer_stateless7/onnx_pretrained.py \ + --encoder-model-filename $repo/exp/encoder-epoch-9999-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-9999-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-9999-avg-1.onnx \ + --tokens $repo/data/en/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +""" + +import argparse +import logging +import math +from typing import List, Tuple + +import k2 +import kaldifeat +import numpy as np +import onnxruntime as ort +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder onnx model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder onnx model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner onnx model. ", + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + return parser + + +class OnnxModel: + def __init__( + self, + encoder_model_filename: str, + decoder_model_filename: str, + joiner_model_filename: str, + ): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.session_opts = session_opts + + self.init_encoder(encoder_model_filename) + self.init_decoder(decoder_model_filename) + self.init_joiner(joiner_model_filename) + + def init_encoder(self, encoder_model_filename: str): + self.encoder = ort.InferenceSession( + encoder_model_filename, + sess_options=self.session_opts, + ) + + def init_decoder(self, decoder_model_filename: str): + self.decoder = ort.InferenceSession( + decoder_model_filename, + sess_options=self.session_opts, + ) + + decoder_meta = self.decoder.get_modelmeta().custom_metadata_map + self.context_size = int(decoder_meta["context_size"]) + self.vocab_size = int(decoder_meta["vocab_size"]) + + logging.info(f"context_size: {self.context_size}") + logging.info(f"vocab_size: {self.vocab_size}") + + def init_joiner(self, joiner_model_filename: str): + self.joiner = ort.InferenceSession( + joiner_model_filename, + sess_options=self.session_opts, + ) + + joiner_meta = self.joiner.get_modelmeta().custom_metadata_map + self.joiner_dim = int(joiner_meta["joiner_dim"]) + + logging.info(f"joiner_dim: {self.joiner_dim}") + + def run_encoder( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D tensor of shape (N, T, C) + x_lens: + A 2-D tensor of shape (N,). Its dtype is torch.int64 + Returns: + Return a tuple containing: + - encoder_out, its shape is (N, T', joiner_dim) + - encoder_out_lens, its shape is (N,) + """ + out = self.encoder.run( + [ + self.encoder.get_outputs()[0].name, + self.encoder.get_outputs()[1].name, + ], + { + self.encoder.get_inputs()[0].name: x.numpy(), + self.encoder.get_inputs()[1].name: x_lens.numpy(), + }, + ) + return torch.from_numpy(out[0]), torch.from_numpy(out[1]) + + def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor: + """ + Args: + decoder_input: + A 2-D tensor of shape (N, context_size) + Returns: + Return a 2-D tensor of shape (N, joiner_dim) + """ + out = self.decoder.run( + [self.decoder.get_outputs()[0].name], + {self.decoder.get_inputs()[0].name: decoder_input.numpy()}, + )[0] + + return torch.from_numpy(out) + + def run_joiner( + self, encoder_out: torch.Tensor, decoder_out: torch.Tensor + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + out = self.joiner.run( + [self.joiner.get_outputs()[0].name], + { + self.joiner.get_inputs()[0].name: encoder_out.numpy(), + self.joiner.get_inputs()[1].name: decoder_out.numpy(), + }, + )[0] + + return torch.from_numpy(out) + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +def greedy_search( + model: OnnxModel, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, +) -> List[List[int]]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + model: + The transducer model. + encoder_out: + A 3-D tensor of shape (N, T, joiner_dim) + encoder_out_lens: + A 1-D tensor of shape (N,). + Returns: + Return the decoded results for each utterance. + """ + assert encoder_out.ndim == 3, encoder_out.shape + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + blank_id = 0 # hard-code to 0 + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + context_size = model.context_size + hyps = [[blank_id] * context_size for _ in range(N)] + + decoder_input = torch.tensor( + hyps, + dtype=torch.int64, + ) # (N, context_size) + + decoder_out = model.run_decoder(decoder_input) + + offset = 0 + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = packed_encoder_out.data[start:end] + # current_encoder_out's shape: (batch_size, joiner_dim) + offset = end + + decoder_out = decoder_out[:batch_size] + logits = model.run_joiner(current_encoder_out, decoder_out) + + # logits'shape (batch_size, vocab_size) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + hyps[i].append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = [h[-context_size:] for h in hyps[:batch_size]] + decoder_input = torch.tensor( + decoder_input, + dtype=torch.int64, + ) + decoder_out = model.run_decoder(decoder_input) + + sorted_ans = [h[context_size:] for h in hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + model = OnnxModel( + encoder_model_filename=args.encoder_model_filename, + decoder_model_filename=args.decoder_model_filename, + joiner_model_filename=args.joiner_model_filename, + ) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = args.sample_rate + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + expected_sample_rate=args.sample_rate, + ) + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64) + encoder_out, encoder_out_lens = model.run_encoder(features, feature_lengths) + + hyps = greedy_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + s = "\n" + + symbol_table = k2.SymbolTable.from_file(args.tokens) + + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += symbol_table[i] + return text.replace("▁", " ").strip() + + for filename, hyp in zip(args.sound_files, hyps): + words = token_ids_to_words(hyp) + s += f"{filename}:\n{words}\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/optim.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/optim.py new file mode 120000 index 000000000..8a05abb5f --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/optim.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/pretrained.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/pretrained.py new file mode 100755 index 000000000..a22d1b4ba --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/pretrained.py @@ -0,0 +1,355 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads a checkpoint and uses it to decode waves. +You can generate the checkpoint with the following command: + +./pruned_transducer_stateless7/export.py \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --bpe-model data/en/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 5 + +Usage of this script: + +(1) greedy search +./pruned_transducer_stateless7/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ + --bpe-model ./data/en/lang_bpe_500/bpe.model \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) beam search +./pruned_transducer_stateless7/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ + --bpe-model ./data/en/lang_bpe_500/bpe.model \ + --method beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) modified beam search +./pruned_transducer_stateless7/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ + --bpe-model ./data/en/lang_bpe_500/bpe.model \ + --method modified_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(4) fast beam search +./pruned_transducer_stateless7/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ + --bpe-model ./data/en/lang_bpe_500/bpe.model \ + --method fast_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +You can also use `./pruned_transducer_stateless7/exp/epoch-xx.pt`. + +Note: ./pruned_transducer_stateless7/exp/pretrained.pt is generated by +./pruned_transducer_stateless7/export.py +""" + + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from beam_search import ( + beam_search, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "--method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. Used only when + --method is greedy_search. + """, + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + + params.update(vars(args)) + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + model.device = device + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) + + num_waves = encoder_out.size(0) + hyps = [] + msg = f"Using {params.method}" + if params.method == "beam_search": + msg += f" with beam size {params.beam_size}" + logging.info(msg) + + if params.method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + for i in range(num_waves): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError(f"Unsupported method: {params.method}") + + hyps.append(sp.decode(hyp).split()) + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/scaling.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/scaling.py new file mode 120000 index 000000000..5f9be9fe0 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/scaling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/scaling.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py new file mode 100755 index 000000000..73a29a90a --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py @@ -0,0 +1,1250 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless7/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless7/exp \ + --max-duration 300 + +# For mix precision training: + +./pruned_transducer_stateless7/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7/exp \ + --max-duration 550 + +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import CommonVoiceAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from optim import Eden, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + filter_uneven_sized_batch, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,4,3,2,4", + help="Number of zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--feedforward-dims", + type=str, + default="1024,1024,2048,2048,1024", + help="Feedforward dimension of the zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--nhead", + type=str, + default="8,8,8,8,8", + help="Number of attention heads in the zipformer encoder layers.", + ) + + parser.add_argument( + "--encoder-dims", + type=str, + default="384,384,384,384,384", + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", + ) + + parser.add_argument( + "--attention-dims", + type=str, + default="192,192,192,192,192", + help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; + not the same as embedding dimension.""", + ) + + parser.add_argument( + "--encoder-unmasked-dims", + type=str, + default="256,256,256,256,256", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " + " worse.", + ) + + parser.add_argument( + "--zipformer-downsampling-factors", + type=str, + default="1,2,4,8,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--cnn-module-kernels", + type=str, + default="31,31,31,31,31", + help="Sizes of kernels in convolution modules", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/en/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.05, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=2000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "frame_shift_ms": 10.0, + "allowed_excess_duration_ratio": 0.1, + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Zipformer and Transformer + def to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + encoder = Zipformer( + num_features=params.feature_dim, + output_downsampling_factor=2, + zipformer_downsampling_factors=to_int_tuple( + params.zipformer_downsampling_factors + ), + encoder_dims=to_int_tuple(params.encoder_dims), + attention_dim=to_int_tuple(params.attention_dims), + encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), + nhead=to_int_tuple(params.nhead), + feedforward_dim=to_int_tuple(params.feedforward_dims), + cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), + num_encoder_layers=to_int_tuple(params.num_encoder_layers), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + # For the uneven-sized batch, the total duration after padding would possibly + # cause OOM. Hence, for each batch, which is sorted descendingly by length, + # we simply drop the last few shortest samples, so that the retained total frames + # (after padding) would not exceed `allowed_max_frames`: + # `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`, + # where `max_frames = max_duration * 1000 // frame_shift_ms`. + # We set allowed_excess_duration_ratio=0.1. + max_frames = params.max_duration * 1000 // params.frame_shift_ms + allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio)) + batch = filter_uneven_sized_batch(batch, allowed_max_frames) + + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + parameters_names = [] + parameters_names.append( + [name_param_pair[0] for name_param_pair in model.named_parameters()] + ) + optimizer = ScaledAdam( + model.parameters(), + lr=params.base_lr, + clipping_scale=2.0, + parameters_names=parameters_names, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + commonvoice = CommonVoiceAsrDataModule(args) + + train_cuts = commonvoice.train_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = commonvoice.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = commonvoice.dev_cuts() + valid_dl = commonvoice.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + CommonVoiceAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/zipformer.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/zipformer.py new file mode 120000 index 000000000..f2f66041e --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/zipformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/zipformer.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/scaling_converter.py b/egs/commonvoice/ASR/scaling_converter.py new file mode 120000 index 000000000..f06434a2c --- /dev/null +++ b/egs/commonvoice/ASR/scaling_converter.py @@ -0,0 +1 @@ +../../librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py \ No newline at end of file From 05e7435d0d36789ba947b3ca664b730b8d79cb92 Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Tue, 18 Apr 2023 10:11:12 +0800 Subject: [PATCH 159/174] Move soft links into proper position (#1007) --- .../ASR/pruned_transducer_stateless7/scaling_converter.py | 1 + egs/commonvoice/ASR/scaling_converter.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7/scaling_converter.py delete mode 120000 egs/commonvoice/ASR/scaling_converter.py diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/scaling_converter.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/scaling_converter.py new file mode 120000 index 000000000..f9960e5c6 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/scaling_converter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/scaling_converter.py b/egs/commonvoice/ASR/scaling_converter.py deleted file mode 120000 index f06434a2c..000000000 --- a/egs/commonvoice/ASR/scaling_converter.py +++ /dev/null @@ -1 +0,0 @@ -../../librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py \ No newline at end of file From 78b9dcc936e271ecabb5017da2e4a7404e314770 Mon Sep 17 00:00:00 2001 From: Wen Ding Date: Tue, 18 Apr 2023 17:05:08 +0800 Subject: [PATCH 160/174] Support exporting BS Zipformer models to ONNX, used in Triton Server (#1008) * Support export BS Zipformer models to ONNX in Tritron * Update copyright * Update exporting codes for BS zipformer models * Code format * Update comments * Update export_onnx.py --------- Co-authored-by: Yifan Yang <64255737+yfyeung@users.noreply.github.com> --- .../export_onnx.py | 280 ++++++++++++++++-- .../onnx_wrapper.py | 98 ++++++ 2 files changed, 354 insertions(+), 24 deletions(-) create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_wrapper.py diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export_onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export_onnx.py index 50efa6e60..630a7f735 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export_onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export_onnx.py @@ -2,6 +2,7 @@ # # Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang, # Yifan Yang) +# 2023 NVIDIA Corporation (Author: Wen Ding) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -29,7 +30,8 @@ Usage: --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ --bpe-model data/lang_bpe_500/bpe.model \ --epoch 30 \ - --avg 13 + --avg 13 \ + --onnx 1 It will generate the following files in the given `exp_dir`. Check `onnx_check.py` for how to use them. @@ -41,6 +43,25 @@ Check `onnx_check.py` for how to use them. - joiner_decoder_proj.onnx - lconv.onnx - frame_reducer.onnx + - ctc_output.onnx + +(2) Export to ONNX format which can be used in Triton Server +./pruned_transducer_stateless7_ctc_bs/export_onnx.py \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 13 \ + --onnx-triton 1 + +It will generate the following files in the given `exp_dir`. + + - encoder.onnx + - decoder.onnx + - joiner.onnx + - joiner_encoder_proj.onnx + - joiner_decoder_proj.onnx + - lconv.onnx + - ctc_output.onnx Please see ./onnx_pretrained.py for usage of the generated files @@ -78,6 +99,7 @@ from icefall.checkpoint import ( load_checkpoint, ) from icefall.utils import str2bool +from onnx_wrapper import TritonOnnxDecoder, TritonOnnxJoiner, TritonOnnxLconv def get_parser(): @@ -143,9 +165,10 @@ def get_parser(): parser.add_argument( "--onnx", type=str2bool, - default=True, + default=False, help="""If True, --jit is ignored and it exports the model - to onnx format. It will generate the following files: + to onnx format. + It will generate the following files: - encoder.onnx - decoder.onnx @@ -154,10 +177,28 @@ def get_parser(): - joiner_decoder_proj.onnx - lconv.onnx - frame_reducer.onnx + - ctc_output.onnx Refer to ./onnx_check.py and ./onnx_pretrained.py for how to use them. """, ) + parser.add_argument( + "--onnx-triton", + type=str2bool, + default=False, + help="""If True, and it exports the model + to onnx format which can be used in NVIDIA triton server. + It will generate the following files: + + - encoder.onnx + - decoder.onnx + - joiner.onnx + - joiner_encoder_proj.onnx + - joiner_decoder_proj.onnx + - lconv.onnx + - ctc_output.onnx + """, + ) parser.add_argument( "--context-size", @@ -273,6 +314,44 @@ def export_decoder_model_onnx( logging.info(f"Saved to {decoder_filename}") +def export_decoder_model_onnx_triton( + decoder_model: nn.Module, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX-Triton format. + The exported model has one input: + - y: a torch.int64 tensor of shape (N, decoder_model.context_size) + and has one output: + - decoder_out: a torch.float32 tensor of shape (N, 1, C) + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) + + decoder_model = TritonOnnxDecoder(decoder_model) + + torch.onnx.export( + decoder_model, + (y), + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + logging.info(f"Saved to {decoder_filename}") + + def export_joiner_model_onnx( joiner_model: nn.Module, joiner_filename: str, @@ -369,6 +448,91 @@ def export_joiner_model_onnx( logging.info(f"Saved to {decoder_proj_filename}") +def export_joiner_model_onnx_triton( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. + The exported joiner model has two inputs: + - projected_encoder_out: a tensor of shape (N, joiner_dim) + - projected_decoder_out: a tensor of shape (N, joiner_dim) + and produces one output: + - logit: a tensor of shape (N, vocab_size) + The exported encoder_proj model has one input: + - encoder_out: a tensor of shape (N, encoder_out_dim) + and produces one output: + - projected_encoder_out: a tensor of shape (N, joiner_dim) + The exported decoder_proj model has one input: + - decoder_out: a tensor of shape (N, decoder_out_dim) + and produces one output: + - projected_decoder_out: a tensor of shape (N, joiner_dim) + """ + encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx") + decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx") + + encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] + decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] + joiner_dim = joiner_model.decoder_proj.weight.shape[0] + + projected_encoder_out = torch.rand(1, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(1, joiner_dim, dtype=torch.float32) + + # Note: It uses torch.jit.trace() internally + joiner_model = TritonOnnxJoiner(joiner_model) + torch.onnx.export( + joiner_model, + (projected_encoder_out, projected_decoder_out), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "encoder_out", + "decoder_out", + "project_input", + ], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + logging.info(f"Saved to {joiner_filename}") + + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + torch.onnx.export( + joiner_model.encoder_proj, + encoder_out, + encoder_proj_filename, + verbose=False, + opset_version=opset_version, + input_names=["encoder_out"], + output_names=["projected_encoder_out"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "projected_encoder_out": {0: "N"}, + }, + ) + logging.info(f"Saved to {encoder_proj_filename}") + + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + torch.onnx.export( + joiner_model.decoder_proj, + decoder_out, + decoder_proj_filename, + verbose=False, + opset_version=opset_version, + input_names=["decoder_out"], + output_names=["projected_decoder_out"], + dynamic_axes={ + "decoder_out": {0: "N"}, + "projected_decoder_out": {0: "N"}, + }, + ) + logging.info(f"Saved to {decoder_proj_filename}") + + def export_lconv_onnx( lconv: nn.Module, lconv_filename: str, @@ -413,6 +577,52 @@ def export_lconv_onnx( logging.info(f"Saved to {lconv_filename}") +def export_lconv_onnx_triton( + lconv: nn.Module, + lconv_filename: str, + opset_version: int = 11, +) -> None: + """Export the lconv to ONNX format. + + The exported lconv has two inputs: + + - lconv_input: a tensor of shape (N, T, C) + - lconv_input_lens: a tensor of shape (N, ) + + and has one output: + + - lconv_out: a tensor of shape (N, T, C) + + Args: + lconv: + The lconv to be exported. + lconv_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + lconv_input = torch.zeros(15, 498, 384, dtype=torch.float32) + lconv_input_lens = torch.tensor([498] * 15, dtype=torch.int64) + + lconv = TritonOnnxLconv(lconv) + + torch.onnx.export( + lconv, + (lconv_input, lconv_input_lens), + lconv_filename, + verbose=False, + opset_version=opset_version, + input_names=["lconv_input", "lconv_input_lens"], + output_names=["lconv_out"], + dynamic_axes={ + "lconv_input": {0: "N", 1: "T"}, + "lconv_input_lens": {0: "N"}, + "lconv_out": {0: "N", 1: "T"}, + }, + ) + logging.info(f"Saved to {lconv_filename}") + + def export_frame_reducer_onnx( frame_reducer: nn.Module, frame_reducer_filename: str, @@ -623,32 +833,54 @@ def main(): ) decoder_filename = params.exp_dir / "decoder.onnx" - export_decoder_model_onnx( - model.decoder, - decoder_filename, - opset_version=opset_version, - ) + if params.onnx is True: + export_decoder_model_onnx( + model.decoder, + decoder_filename, + opset_version=opset_version, + ) + elif params.onnx_triton is True: + export_decoder_model_onnx_triton( + model.decoder, + decoder_filename, + opset_version=opset_version, + ) joiner_filename = params.exp_dir / "joiner.onnx" - export_joiner_model_onnx( - model.joiner, - joiner_filename, - opset_version=opset_version, - ) + if params.onnx is True: + export_joiner_model_onnx( + model.joiner, + joiner_filename, + opset_version=opset_version, + ) + elif params.onnx_triton is True: + export_joiner_model_onnx_triton( + model.joiner, + joiner_filename, + opset_version=opset_version, + ) lconv_filename = params.exp_dir / "lconv.onnx" - export_lconv_onnx( - model.lconv, - lconv_filename, - opset_version=opset_version, - ) + if params.onnx is True: + export_lconv_onnx( + model.lconv, + lconv_filename, + opset_version=opset_version, + ) + elif params.onnx_triton is True: + export_lconv_onnx_triton( + model.lconv, + lconv_filename, + opset_version=opset_version, + ) - frame_reducer_filename = params.exp_dir / "frame_reducer.onnx" - export_frame_reducer_onnx( - model.frame_reducer, - frame_reducer_filename, - opset_version=opset_version, - ) + if params.onnx is True: + frame_reducer_filename = params.exp_dir / "frame_reducer.onnx" + export_frame_reducer_onnx( + model.frame_reducer, + frame_reducer_filename, + opset_version=opset_version, + ) ctc_output_filename = params.exp_dir / "ctc_output.onnx" export_ctc_output_onnx( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_wrapper.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_wrapper.py new file mode 100755 index 000000000..247da0949 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_wrapper.py @@ -0,0 +1,98 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torch import nn +from icefall.utils import make_pad_mask + + +class TritonOnnxDecoder(nn.Module): + """ + Triton wrapper for decoder model + """ + + def __init__(self, model): + """ + Args: + model: decoder model + """ + super().__init__() + + self.model = model + + def forward(self, y: torch.Tensor) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, U). + Returns: + Return a tensor of shape (N, U, decoder_dim). + """ + need_pad = False + return self.model(y, need_pad) + + +class TritonOnnxJoiner(nn.Module): + def __init__( + self, + model, + ): + super().__init__() + + self.model = model + self.encoder_proj = model.encoder_proj + self.decoder_proj = model.decoder_proj + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + encoder_out: + Output from the encoder. Its shape is (N, T, C). + decoder_out: + Output from the decoder. Its shape is (N, T, C). + Returns: + Return a tensor of shape (N, T, C). + """ + project_input = False + return self.model(encoder_out, decoder_out, project_input) + + +class TritonOnnxLconv(nn.Module): + def __init__( + self, + model, + ): + super().__init__() + + self.model = model + + def forward( + self, + lconv_input: torch.Tensor, + lconv_input_lens: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + lconv_input: Its shape is (N, T, C). + lconv_input_lens: Its shape is (N, ). + Returns: + Return a tensor of shape (N, T, C). + """ + mask = make_pad_mask(lconv_input_lens) + + return self.model(x=lconv_input, src_key_padding_mask=mask) From 81d386ef3e3bd68a30000b92a19de9676a0ec828 Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Thu, 20 Apr 2023 12:27:43 +0800 Subject: [PATCH 161/174] Add compute_ppl.py and ngram_entropy_pruning.py (#1013) --- .../compute_ppl.py | 109 +++ icefall/shared/ngram_entropy_pruning.py | 630 ++++++++++++++++++ 2 files changed, 739 insertions(+) create mode 100755 egs/gigaspeech/ASR/pruned_transducer_stateless2/compute_ppl.py create mode 100755 icefall/shared/ngram_entropy_pruning.py diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/compute_ppl.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/compute_ppl.py new file mode 100755 index 000000000..76306fc4c --- /dev/null +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/compute_ppl.py @@ -0,0 +1,109 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corp. (Author: Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +./pruned_transducer_stateless7/compute_ppl.py \ + --ngram-lm-path ./download/lm/3gram_pruned_1e7.arpa + +""" + + +import argparse +import logging +import math +from typing import Dict, List, Optional, Tuple + +import kenlm +import torch +from asr_datamodule import GigaSpeechAsrDataModule + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--ngram-lm-path", + type=str, + default="download/lm/3gram_pruned_1e7.arpa", + help="The lang dir containing word table and LG graph", + ) + + return parser + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + model: kenlm.Model, +) -> Dict[str, float]: + """ + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + model: + A ngram lm of kenlm.Model object. + Returns: + Return the perplexity of the giving dataset. + """ + sum_score_log = 0 + sum_n = 0 + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + for text in texts: + sum_n += len(text.split()) + 1 + sum_score_log += -1 * model.score(text) + + ppl = math.pow(10.0, sum_score_log / sum_n) + + return ppl + + +def main(): + parser = get_parser() + GigaSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + + logging.info("About to load ngram LM") + model = kenlm.Model(args.ngram_lm_path) + + gigaspeech = GigaSpeechAsrDataModule(args) + + dev_cuts = gigaspeech.dev_cuts() + test_cuts = gigaspeech.test_cuts() + + dev_dl = gigaspeech.test_dataloaders(dev_cuts) + test_dl = gigaspeech.test_dataloaders(test_cuts) + + test_sets = ["dev", "test"] + test_dls = [dev_dl, test_dl] + + for test_set, test_dl in zip(test_sets, test_dls): + ppl = decode_dataset( + dl=test_dl, + model=model, + ) + logging.info(f"{test_set} PPL: {ppl}") + + logging.info("Done!") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/icefall/shared/ngram_entropy_pruning.py b/icefall/shared/ngram_entropy_pruning.py new file mode 100755 index 000000000..b1ebee9ea --- /dev/null +++ b/icefall/shared/ngram_entropy_pruning.py @@ -0,0 +1,630 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Copyright 2021 Johns Hopkins University (Author: Ruizhe Huang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +./ngram_entropy_pruning.py \ + -threshold 1e-8 \ + -lm download/lm/4gram.arpa \ + -write-lm download/lm/4gram_pruned_1e8.arpa + +This file is from Kaldi `egs/wsj/s5/utils/lang/ngram_entropy_pruning.py`. +This is an implementation of ``Entropy-based Pruning of Backoff Language Models'' +in the same way as SRILM. +""" + + +import argparse +import gzip +import logging +import math +import re +from collections import OrderedDict, defaultdict +from enum import Enum, unique +from io import StringIO + +parser = argparse.ArgumentParser( + description=""" + Prune an n-gram language model based on the relative entropy + between the original and the pruned model, based on Andreas Stolcke's paper. + An n-gram entry is removed, if the removal causes (training set) perplexity + of the model to increase by less than threshold relative. + + The command takes an arpa file and a pruning threshold as input, + and outputs a pruned arpa file. + """ +) +parser.add_argument("-threshold", type=float, default=1e-6, help="Order of n-gram") +parser.add_argument("-lm", type=str, default=None, help="Path to the input arpa file") +parser.add_argument( + "-write-lm", type=str, default=None, help="Path to output arpa file after pruning" +) +parser.add_argument( + "-minorder", + type=int, + default=1, + help="The minorder parameter limits pruning to ngrams of that length and above.", +) +parser.add_argument( + "-encoding", type=str, default="utf-8", help="Encoding of the arpa file" +) +parser.add_argument( + "-verbose", + type=int, + default=2, + choices=[0, 1, 2, 3, 4, 5], + help="Verbose level, where 0 is most noisy; 5 is most silent", +) +args = parser.parse_args() + +default_encoding = args.encoding +logging.basicConfig( + format="%(asctime)s — %(levelname)s — %(funcName)s:%(lineno)d — %(message)s", + level=args.verbose * 10, +) + + +class Context(dict): + """ + This class stores data for a context h. + It behaves like a python dict object, except that it has several + additional attributes. + """ + + def __init__(self): + super().__init__() + self.log_bo = None + + +class Arpa: + """ + This is a class that implement the data structure of an APRA LM. + It (as well as some other classes) is modified based on the library + by Stefan Fischer: + https://github.com/sfischer13/python-arpa + """ + + UNK = "" + SOS = "" + EOS = "" + FLOAT_NDIGITS = 7 + base = 10 + + @staticmethod + def _check_input(my_input): + if not my_input: + raise ValueError + elif isinstance(my_input, tuple): + return my_input + elif isinstance(my_input, list): + return tuple(my_input) + elif isinstance(my_input, str): + return tuple(my_input.strip().split(" ")) + else: + raise ValueError + + @staticmethod + def _check_word(input_word): + if not isinstance(input_word, str): + raise ValueError + if " " in input_word: + raise ValueError + + def _replace_unks(self, words): + return tuple((w if w in self else self._unk) for w in words) + + def __init__(self, path=None, encoding=None, unk=None): + self._counts = OrderedDict() + self._ngrams = ( + OrderedDict() + ) # Use self._ngrams[len(h)][h][w] for saving the entry of (h,w) + self._vocabulary = set() + if unk is None: + self._unk = self.UNK + + if path is not None: + self.loadf(path, encoding) + + def __contains__(self, ngram): + h = ngram[:-1] # h is a tuple + w = ngram[-1] # w is a string/word + return h in self._ngrams[len(h)] and w in self._ngrams[len(h)][h] + + def contains_word(self, word): + self._check_word(word) + return word in self._vocabulary + + def add_count(self, order, count): + self._counts[order] = count + self._ngrams[order - 1] = defaultdict(Context) + + def update_counts(self): + for order in range(1, self.order() + 1): + count = sum([len(wlist) for _, wlist in self._ngrams[order - 1].items()]) + if count > 0: + self._counts[order] = count + + def add_entry(self, ngram, p, bo=None, order=None): + # Note: ngram is a tuple of strings, e.g. ("w1", "w2", "w3") + h = ngram[:-1] # h is a tuple + w = ngram[-1] # w is a string/word + + # Note that p and bo here are in fact in the log domain (self.base = 10) + h_context = self._ngrams[len(h)][h] + h_context[w] = p + if bo is not None: + self._ngrams[len(ngram)][ngram].log_bo = bo + + for word in ngram: + self._vocabulary.add(word) + + def counts(self): + return sorted(self._counts.items()) + + def order(self): + return max(self._counts.keys(), default=None) + + def vocabulary(self, sort=True): + if sort: + return sorted(self._vocabulary) + else: + return self._vocabulary + + def _entries(self, order): + return ( + self._entry(h, w) + for h, wlist in self._ngrams[order - 1].items() + for w in wlist + ) + + def _entry(self, h, w): + # return the entry for the ngram (h, w) + ngram = h + (w,) + log_p = self._ngrams[len(h)][h][w] + log_bo = self._log_bo(ngram) + if log_bo is not None: + return ( + round(log_p, self.FLOAT_NDIGITS), + ngram, + round(log_bo, self.FLOAT_NDIGITS), + ) + else: + return round(log_p, self.FLOAT_NDIGITS), ngram + + def _log_bo(self, ngram): + if len(ngram) in self._ngrams and ngram in self._ngrams[len(ngram)]: + return self._ngrams[len(ngram)][ngram].log_bo + else: + return None + + def _log_p(self, ngram): + h = ngram[:-1] # h is a tuple + w = ngram[-1] # w is a string/word + if h in self._ngrams[len(h)] and w in self._ngrams[len(h)][h]: + return self._ngrams[len(h)][h][w] + else: + return None + + def log_p_raw(self, ngram): + log_p = self._log_p(ngram) + if log_p is not None: + return log_p + else: + if len(ngram) == 1: + raise KeyError + else: + log_bo = self._log_bo(ngram[:-1]) + if log_bo is None: + log_bo = 0 + return log_bo + self.log_p_raw(ngram[1:]) + + def log_joint_prob(self, sequence): + # Compute the joint prob of the sequence based on the chain rule + # Note that sequence should be a tuple of strings + # + # Reference: + # https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/lm/src/LM.cc#L527 + + log_joint_p = 0 + seq = sequence + while len(seq) > 0: + log_joint_p += self.log_p_raw(seq) + seq = seq[:-1] + + # If we're computing the marginal probability of the unigram + # context we have to look up instead since the former + # has prob = 0. + if len(seq) == 1 and seq[0] == self.SOS: + seq = (self.EOS,) + + return log_joint_p + + def set_new_context(self, h): + old_context = self._ngrams[len(h)][h] + self._ngrams[len(h)][h] = Context() + return old_context + + def log_p(self, ngram): + words = self._check_input(ngram) + if self._unk: + words = self._replace_unks(words) + return self.log_p_raw(words) + + def log_s(self, sentence, sos=SOS, eos=EOS): + words = self._check_input(sentence) + if self._unk: + words = self._replace_unks(words) + if sos: + words = (sos,) + words + if eos: + words = words + (eos,) + result = sum(self.log_p_raw(words[:i]) for i in range(1, len(words) + 1)) + if sos: + result = result - self.log_p_raw(words[:1]) + return result + + def p(self, ngram): + return self.base ** self.log_p(ngram) + + def s(self, sentence): + return self.base ** self.log_s(sentence) + + def write(self, fp): + fp.write("\n\\data\\\n") + for order, count in self.counts(): + fp.write("ngram {}={}\n".format(order, count)) + fp.write("\n") + for order, _ in self.counts(): + fp.write("\\{}-grams:\n".format(order)) + for e in self._entries(order): + prob = e[0] + ngram = " ".join(e[1]) + if len(e) == 2: + fp.write("{}\t{}\n".format(prob, ngram)) + elif len(e) == 3: + backoff = e[2] + fp.write("{}\t{}\t{}\n".format(prob, ngram, backoff)) + else: + raise ValueError + fp.write("\n") + fp.write("\\end\\\n") + + +class ArpaParser: + """ + This is a class that implement a parser of an arpa file + """ + + @unique + class State(Enum): + DATA = 1 + COUNT = 2 + HEADER = 3 + ENTRY = 4 + + re_count = re.compile(r"^ngram (\d+)=(\d+)$") + re_header = re.compile(r"^\\(\d+)-grams:$") + re_entry = re.compile( + "^(-?\\d+(\\.\\d+)?([eE]-?\\d+)?)" + "\t" + "(\\S+( \\S+)*)" + "(\t((-?\\d+(\\.\\d+)?)([eE]-?\\d+)?))?$" + ) + + def _parse(self, fp): + self._result = [] + self._state = self.State.DATA + self._tmp_model = None + self._tmp_order = None + for line in fp: + line = line.strip() + if self._state == self.State.DATA: + self._data(line) + elif self._state == self.State.COUNT: + self._count(line) + elif self._state == self.State.HEADER: + self._header(line) + elif self._state == self.State.ENTRY: + self._entry(line) + if self._state != self.State.DATA: + raise Exception(line) + return self._result + + def _data(self, line): + if line == "\\data\\": + self._state = self.State.COUNT + self._tmp_model = Arpa() + else: + pass # skip comment line + + def _count(self, line): + match = self.re_count.match(line) + if match: + order = match.group(1) + count = match.group(2) + self._tmp_model.add_count(int(order), int(count)) + elif not line: + self._state = self.State.HEADER # there are no counts + else: + raise Exception(line) + + def _header(self, line): + match = self.re_header.match(line) + if match: + self._state = self.State.ENTRY + self._tmp_order = int(match.group(1)) + elif line == "\\end\\": + self._result.append(self._tmp_model) + self._state = self.State.DATA + self._tmp_model = None + self._tmp_order = None + elif not line: + pass # skip empty line + else: + raise Exception(line) + + def _entry(self, line): + match = self.re_entry.match(line) + if match: + p = self._float_or_int(match.group(1)) + ngram = tuple(match.group(4).split(" ")) + bo_match = match.group(7) + bo = self._float_or_int(bo_match) if bo_match else None + self._tmp_model.add_entry(ngram, p, bo, self._tmp_order) + elif not line: + self._state = self.State.HEADER # last entry + else: + raise Exception(line) + + @staticmethod + def _float_or_int(s): + f = float(s) + i = int(f) + if str(i) == s: # don't drop trailing ".0" + return i + else: + return f + + def load(self, fp): + """Deserialize fp (a file-like object) to a Python object.""" + return self._parse(fp) + + def loadf(self, path, encoding=None): + """Deserialize path (.arpa, .gz) to a Python object.""" + path = str(path) + if path.endswith(".gz"): + with gzip.open(path, mode="rt", encoding=encoding) as f: + return self.load(f) + else: + with open(path, mode="rt", encoding=encoding) as f: + return self.load(f) + + def loads(self, s): + """Deserialize s (a str) to a Python object.""" + with StringIO(s) as f: + return self.load(f) + + def dump(self, obj, fp): + """Serialize obj to fp (a file-like object) in ARPA format.""" + obj.write(fp) + + def dumpf(self, obj, path, encoding=None): + """Serialize obj to path in ARPA format (.arpa, .gz).""" + path = str(path) + if path.endswith(".gz"): + with gzip.open(path, mode="wt", encoding=encoding) as f: + return self.dump(obj, f) + else: + with open(path, mode="wt", encoding=encoding) as f: + self.dump(obj, f) + + def dumps(self, obj): + """Serialize obj to an ARPA formatted str.""" + with StringIO() as f: + self.dump(obj, f) + return f.getvalue() + + +def add_log_p(prev_log_sum, log_p, base): + return math.log(base**log_p + base**prev_log_sum, base) + + +def compute_numerator_denominator(lm, h): + log_sum_seen_h = -math.inf + log_sum_seen_h_lower = -math.inf + base = lm.base + for w, log_p in lm._ngrams[len(h)][h].items(): + log_sum_seen_h = add_log_p(log_sum_seen_h, log_p, base) + + ngram = h + (w,) + log_p_lower = lm.log_p_raw(ngram[1:]) + log_sum_seen_h_lower = add_log_p(log_sum_seen_h_lower, log_p_lower, base) + + numerator = 1.0 - base**log_sum_seen_h + denominator = 1.0 - base**log_sum_seen_h_lower + return numerator, denominator + + +def prune(lm, threshold, minorder): + # Reference: + # https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/lm/src/NgramLM.cc#L2330 + + for i in range( + lm.order(), max(minorder - 1, 1), -1 + ): # i is the order of the ngram (h, w) + logging.info("processing %d-grams ..." % i) + count_pruned_ngrams = 0 + + h_dict = lm._ngrams[i - 1] + for h in list(h_dict.keys()): + # old backoff weight, BOW(h) + log_bow = lm._log_bo(h) + if log_bow is None: + log_bow = 0 + + # Compute numerator and denominator of the backoff weight, + # so that we can quickly compute the BOW adjustment due to + # leaving out one prob. + numerator, denominator = compute_numerator_denominator(lm, h) + + # assert abs(math.log(numerator, lm.base) - math.log(denominator, lm.base) - h_dict[h].log_bo) < 1e-5 + + # Compute the marginal probability of the context, P(h) + h_log_p = lm.log_joint_prob(h) + + all_pruned = True + pruned_w_set = set() + + for w, log_p in h_dict[h].items(): + ngram = h + (w,) + + # lower-order estimate for ngramProb, P(w|h') + backoff_prob = lm.log_p_raw(ngram[1:]) + + # Compute BOW after removing ngram, BOW'(h) + new_log_bow = math.log( + numerator + lm.base**log_p, lm.base + ) - math.log(denominator + lm.base**backoff_prob, lm.base) + + # Compute change in entropy due to removal of ngram + delta_prob = backoff_prob + new_log_bow - log_p + delta_entropy = -(lm.base**h_log_p) * ( + (lm.base**log_p) * delta_prob + + numerator * (new_log_bow - log_bow) + ) + + # compute relative change in model (training set) perplexity + perp_change = lm.base**delta_entropy - 1.0 + + pruned = threshold > 0 and perp_change < threshold + + # Make sure we don't prune ngrams whose backoff nodes are needed + if ( + pruned + and len(ngram) in lm._ngrams + and len(lm._ngrams[len(ngram)][ngram]) > 0 + ): + pruned = False + + logging.debug( + "CONTEXT " + + str(h) + + " WORD " + + w + + " CONTEXTPROB %f " % h_log_p + + " OLDPROB %f " % log_p + + " NEWPROB %f " % (backoff_prob + new_log_bow) + + " DELTA-H %f " % delta_entropy + + " DELTA-LOGP %f " % delta_prob + + " PPL-CHANGE %f " % perp_change + + " PRUNED " + + str(pruned) + ) + + if pruned: + pruned_w_set.add(w) + count_pruned_ngrams += 1 + else: + all_pruned = False + + # If we removed all ngrams for this context we can + # remove the context itself, but only if the present + # context is not a prefix to a longer one. + if all_pruned and len(pruned_w_set) == len(h_dict[h]): + del h_dict[ + h + ] # this context h is no longer needed, as its ngram prob is stored at its own context h' + elif len(pruned_w_set) > 0: + # The pruning for this context h is actually done here + old_context = lm.set_new_context(h) + + for w, p_w in old_context.items(): + if w not in pruned_w_set: + lm.add_entry( + h + (w,), p_w + ) # the entry hw is stored at the context h + + # We need to recompute the back-off weight, but + # this can only be done after completing the pruning + # of the lower-order ngrams. + # Reference: + # https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/flm/src/FNgramLM.cc#L2124 + + logging.info("pruned %d %d-grams" % (count_pruned_ngrams, i)) + + # recompute backoff weights + for i in range( + max(minorder - 1, 1) + 1, lm.order() + 1 + ): # be careful of this order: from low- to high-order + for h in lm._ngrams[i - 1]: + numerator, denominator = compute_numerator_denominator(lm, h) + new_log_bow = math.log(numerator, lm.base) - math.log(denominator, lm.base) + lm._ngrams[len(h)][h].log_bo = new_log_bow + + # update counts + lm.update_counts() + + return + + +def check_h_is_valid(lm, h): + sum_under_h = sum( + [lm.base ** lm.log_p_raw(h + (w,)) for w in lm.vocabulary(sort=False)] + ) + if abs(sum_under_h - 1.0) > 1e-6: + logging.info("warning: %s %f" % (str(h), sum_under_h)) + return False + else: + return True + + +def validate_lm(lm): + # sanity check if the conditional probability sums to one under each context h + for i in range(lm.order(), 0, -1): # i is the order of the ngram (h, w) + logging.info("validating %d-grams ..." % i) + h_dict = lm._ngrams[i - 1] + for h in h_dict.keys(): + check_h_is_valid(lm, h) + + +def compare_two_apras(path1, path2): + pass + + +if __name__ == "__main__": + # load an arpa file + logging.info("Loading the arpa file from %s" % args.lm) + parser = ArpaParser() + models = parser.loadf(args.lm, encoding=default_encoding) + lm = models[0] # ARPA files may contain several models. + logging.info("Stats before pruning:") + for i, cnt in lm.counts(): + logging.info("ngram %d=%d" % (i, cnt)) + + # prune it, the language model will be modified in-place + logging.info("Start pruning the model with threshold=%.3E..." % args.threshold) + prune(lm, args.threshold, args.minorder) + + # validate_lm(lm) + + # write the arpa language model to a file + logging.info("Stats after pruning:") + for i, cnt in lm.counts(): + logging.info("ngram %d=%d" % (i, cnt)) + logging.info("Saving the pruned arpa file to %s" % args.write_lm) + parser.dumpf(lm, args.write_lm, encoding=default_encoding) + logging.info("Done.") From 5c65516e05a52eb0b9973bff6fe9fab84c0720b4 Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Thu, 20 Apr 2023 16:14:16 +0800 Subject: [PATCH 162/174] Fix aishell rnnlm training command (#1015) --- egs/aishell/ASR/prepare.sh | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/egs/aishell/ASR/prepare.sh b/egs/aishell/ASR/prepare.sh index bd34c1f44..80672b9e3 100755 --- a/egs/aishell/ASR/prepare.sh +++ b/egs/aishell/ASR/prepare.sh @@ -230,7 +230,7 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then if [ ! -f $dl_dir/lm/aishell-train-word.txt ]; then cp $lang_phone_dir/transcript_words.txt $dl_dir/lm/aishell-train-word.txt fi - + # training words ./local/prepare_char_lm_training_data.py \ --lang-char data/lang_char \ @@ -278,7 +278,7 @@ if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then out_dir=data/lm_training_char mkdir -p $out_dir ln -snf ../../../librispeech/ASR/local/sort_lm_training_data.py local/ - + ./local/sort_lm_training_data.py \ --in-lm-data $out_dir/lm_data.pt \ --out-lm-data $out_dir/sorted_lm_data.pt \ @@ -306,9 +306,9 @@ if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then --hidden-dim 512 \ --num-layers 2 \ --batch-size 400 \ - --exp-dir rnnlm_char/exp_aishell1_small \ - --lm-data data/lm_char/sorted_lm_data_aishell1.pt \ - --lm-data-valid data/lm_char/sorted_lm_data_valid.pt \ + --exp-dir rnnlm_char/exp \ + --lm-data $out_dir/sorted_lm_data.pt \ + --lm-data-valid $out_dir/sorted_lm_data_valid.pt \ --vocab-size 4336 \ --master-port 12345 fi From 0efed1cec5145a4bdc5d43a9ba389d4189819271 Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Thu, 20 Apr 2023 23:09:31 +0800 Subject: [PATCH 163/174] Fix path in aishell rnnlm training (#1016) --- egs/aishell/ASR/prepare.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/aishell/ASR/prepare.sh b/egs/aishell/ASR/prepare.sh index 80672b9e3..3e0d5f51b 100755 --- a/egs/aishell/ASR/prepare.sh +++ b/egs/aishell/ASR/prepare.sh @@ -308,7 +308,7 @@ if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then --batch-size 400 \ --exp-dir rnnlm_char/exp \ --lm-data $out_dir/sorted_lm_data.pt \ - --lm-data-valid $out_dir/sorted_lm_data_valid.pt \ + --lm-data-valid $out_dir/sorted_lm_data-valid.pt \ --vocab-size 4336 \ --master-port 12345 fi From 57d6482a7961472ec6f1c6ede2979e3ed48f1094 Mon Sep 17 00:00:00 2001 From: marcoyang1998 <45973641+marcoyang1998@users.noreply.github.com> Date: Fri, 21 Apr 2023 15:43:28 +0800 Subject: [PATCH 164/174] Streaming Zipformer with multi-dataset (#984) * modify train.py * add right padding option in decode.py * update RESULTS.md --- .../asr_datamodule.py | 10 +- egs/librispeech/ASR/.gitignore | 3 +- egs/librispeech/ASR/README.md | 1 + egs/librispeech/ASR/RESULTS.md | 140 +- .../pruned_transducer_stateless7/scaling.py | 5 +- .../export-for-ncnn.py | 1 - .../onnx_check.py | 3 +- .../streaming-ncnn-decode.py | 23 + .../test_model.py | 39 +- .../train.py | 24 +- .../asr_datamodule.py | 1 + .../beam_search.py | 1 + .../decode.py | 818 ++++++++++ .../decode_gigaspeech.py | 837 ++++++++++ .../decode_stream.py | 1 + .../decoder.py | 1 + .../encoder_interface.py | 1 + .../export-for-ncnn-zh.py | 1 + .../export-for-ncnn.py | 368 +++++ .../export-onnx.py | 1 + .../export.py | 1 + .../gigaspeech.py | 1 + .../gigaspeech_asrmodule.py | 1 + .../gigaspeech_scoring.py | 1 + .../jit_pretrained.py | 1 + .../jit_trace_export.py | 1 + .../jit_trace_pretrained.py | 1 + .../joiner.py | 1 + .../librispeech.py | 1 + .../model.py | 1 + .../onnx_check.py | 1 + .../onnx_model_wrapper.py | 1 + .../onnx_pretrained.py | 1 + .../optim.py | 1 + .../pretrained.py | 1 + .../scaling.py | 1 + .../scaling_converter.py | 1 + .../streaming-ncnn-decode.py | 1 + .../streaming_beam_search.py | 1 + .../streaming_decode.py | 610 ++++++++ .../test_model.py | 1 + .../train.py | 1370 +++++++++++++++++ .../train2.py | 1 + .../zipformer.py | 1 + .../zipformer2.py | 1 + 45 files changed, 4258 insertions(+), 24 deletions(-) create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/asr_datamodule.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/beam_search.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode.py create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode_gigaspeech.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode_stream.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decoder.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/encoder_interface.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export-for-ncnn-zh.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export-for-ncnn.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export-onnx.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/gigaspeech.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/gigaspeech_asrmodule.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/gigaspeech_scoring.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/jit_pretrained.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/jit_trace_export.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/jit_trace_pretrained.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/joiner.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/librispeech.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/model.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/onnx_check.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/onnx_model_wrapper.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/onnx_pretrained.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/optim.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/pretrained.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/scaling.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/scaling_converter.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/streaming-ncnn-decode.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/streaming_beam_search.py create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/streaming_decode.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/test_model.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train2.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/zipformer.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/zipformer2.py diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py index 5c01d7190..4d5d2b8f9 100644 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -387,14 +387,16 @@ class GigaSpeechAsrDataModule: @lru_cache() def train_cuts(self) -> CutSet: logging.info(f"About to get train_{self.args.subset} cuts") - path = self.args.manifest_dir / f"cuts_{self.args.subset}.jsonl.gz" + path = self.args.manifest_dir / f"gigaspeech_cuts_{self.args.subset}.jsonl.gz" cuts_train = CutSet.from_jsonl_lazy(path) return cuts_train @lru_cache() def dev_cuts(self) -> CutSet: logging.info("About to get dev cuts") - cuts_valid = load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz") + cuts_valid = load_manifest_lazy( + self.args.manifest_dir / "gigaspeech_cuts_DEV.jsonl.gz" + ) if self.args.small_dev: return cuts_valid.subset(first=1000) else: @@ -403,4 +405,6 @@ class GigaSpeechAsrDataModule: @lru_cache() def test_cuts(self) -> CutSet: logging.info("About to get test cuts") - return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST.jsonl.gz") + return load_manifest_lazy( + self.args.manifest_dir / "gigaspeech_cuts_TEST.jsonl.gz" + ) diff --git a/egs/librispeech/ASR/.gitignore b/egs/librispeech/ASR/.gitignore index 8dec2d86d..1c26f7978 100644 --- a/egs/librispeech/ASR/.gitignore +++ b/egs/librispeech/ASR/.gitignore @@ -1,2 +1,3 @@ log-* -.DS_Store \ No newline at end of file +.DS_Store +run*.sh diff --git a/egs/librispeech/ASR/README.md b/egs/librispeech/ASR/README.md index 9ffd78d5b..82cef9817 100644 --- a/egs/librispeech/ASR/README.md +++ b/egs/librispeech/ASR/README.md @@ -26,6 +26,7 @@ The following table lists the differences among them. | `pruned_transducer_stateless7_ctc` | Zipformer | Embedding + Conv1d | Same as pruned_transducer_stateless7, but with extra CTC head| | `pruned_transducer_stateless7_ctc_bs` | Zipformer | Embedding + Conv1d | pruned_transducer_stateless7_ctc + blank skip | | `pruned_transducer_stateless7_streaming` | Streaming Zipformer | Embedding + Conv1d | streaming version of pruned_transducer_stateless7 | +| `pruned_transducer_stateless7_streaming_multi` | Streaming Zipformer | Embedding + Conv1d | same as pruned_transducer_stateless7_streaming, trained on LibriSpeech + GigaSpeech | | `pruned_transducer_stateless8` | Zipformer | Embedding + Conv1d | Same as pruned_transducer_stateless7, but using extra data from GigaSpeech| | `pruned_stateless_emformer_rnnt2` | Emformer(from torchaudio) | Embedding + Conv1d | Using Emformer from torchaudio for streaming ASR| | `conv_emformer_transducer_stateless` | ConvEmformer | Embedding + Conv1d | Using ConvEmformer for streaming ASR + mechanisms in reworked model | diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index 881e8be97..5a956fc9c 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -1,5 +1,141 @@ ## Results +### Streaming Zipformer-Transducer (Pruned Stateless Transducer + Streaming Zipformer + Multi-Dataset) + +#### [pruned_transducer_stateless7_streaming_multi](./pruned_transducer_stateless7_streaming_multi) + +See for more details. + +You can find a pretrained model, training logs, decoding logs, and decoding +results at: + +Number of model parameters: 70369391, i.e., 70.37 M + +##### training on full librispeech + full gigaspeech (with giga_prob=0.9) + +The WERs are: + + +| decoding method | chunk size | test-clean | test-other | comment | decoding mode | +|----------------------|------------|------------|------------|---------------------|----------------------| +| greedy search | 320ms | 2.43 | 6.0 | --epoch 20 --avg 4 | simulated streaming | +| greedy search | 320ms | 2.47 | 6.13 | --epoch 20 --avg 4 | chunk-wise | +| fast beam search | 320ms | 2.43 | 5.99 | --epoch 20 --avg 4 | simulated streaming | +| fast beam search | 320ms | 2.8 | 6.46 | --epoch 20 --avg 4 | chunk-wise | +| modified beam search | 320ms | 2.4 | 5.96 | --epoch 20 --avg 4 | simulated streaming | +| modified beam search | 320ms | 2.42 | 6.03 | --epoch 20 --avg 4 | chunk-size | +| greedy search | 640ms | 2.26 | 5.58 | --epoch 20 --avg 4 | simulated streaming | +| greedy search | 640ms | 2.33 | 5.76 | --epoch 20 --avg 4 | chunk-wise | +| fast beam search | 640ms | 2.27 | 5.54 | --epoch 20 --avg 4 | simulated streaming | +| fast beam search | 640ms | 2.37 | 5.75 | --epoch 20 --avg 4 | chunk-wise | +| modified beam search | 640ms | 2.22 | 5.5 | --epoch 20 --avg 4 | simulated streaming | +| modified beam search | 640ms | 2.25 | 5.69 | --epoch 20 --avg 4 | chunk-size | + +The model also has good WERs on GigaSpeech. The following WERs are achieved on GigaSpeech test and dev sets: + +| decoding method | chunk size | dev | test | comment | decoding mode | +|----------------------|------------|-----|------|------------|---------------------| +| greedy search | 320ms | 12.08 | 11.98 | --epoch 20 --avg 4 | simulated streaming | +| greedy search | 640ms | 11.66 | 11.71 | --epoch 20 --avg 4 | simulated streaming | +| modified beam search | 320ms | 11.95 | 11.83 | --epoch 20 --avg 4 | simulated streaming | +| modified beam search | 320ms | 11.65 | 11.56 | --epoch 20 --avg 4 | simulated streaming | + + +Note: `simulated streaming` indicates feeding full utterance during decoding using `decode.py`, +while `chunk-size` indicates feeding certain number of frames at each time using `streaming_decode.py`. + +The training command is: + +```bash +./pruned_transducer_stateless7_streaming_multi/train.py \ + --world-size 4 \ + --num-epochs 20 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7_streaming_multi/exp \ + --full-libri 1 \ + --giga-prob 0.9 \ + --max-duration 750 \ + --master-port 12345 +``` + +The tensorboard log can be found at + + +The simulated streaming decoding command (e.g., chunk-size=320ms) is: +```bash +for m in greedy_search fast_beam_search modified_beam_search; do + ./pruned_transducer_stateless7_streaming_multi/decode.py \ + --epoch 20 \ + --avg 4 \ + --exp-dir ./pruned_transducer_stateless7_streaming_multi/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --right-padding 64 \ + --decoding-method $m +done +``` + +The streaming chunk-size decoding command (e.g., chunk-size=320ms) is: +```bash +for m in greedy_search modified_beam_search fast_beam_search; do + ./pruned_transducer_stateless7_streaming_multi/streaming_decode.py \ + --epoch 20 \ + --avg 4 \ + --exp-dir ./pruned_transducer_stateless7_streaming_multi/exp \ + --decoding-method $m \ + --decode-chunk-len 32 \ + --num-decode-streams 2000 +done +``` + + +#### Smaller model + +We also provide a very small version (only 6.1M parameters) of this setup. The training command for the small model is: + +```bash +./pruned_transducer_stateless7_streaming_multi/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7_streaming_multi/exp \ + --full-libri 1 \ + --giga-prob 0.9 \ + --num-encoder-layers "2,2,2,2,2" \ + --feedforward-dims "256,256,512,512,256" \ + --nhead "4,4,4,4,4" \ + --encoder-dims "128,128,128,128,128" \ + --attention-dims "96,96,96,96,96" \ + --encoder-unmasked-dims "96,96,96,96,96" \ + --max-duration 1200 \ + --master-port 12345 +``` + +You can find this pretrained small model and its training logs, decoding logs, and decoding +results at: + + + +| decoding method | chunk size | test-clean | test-other | comment | decoding mode | +|----------------------|------------|------------|------------|---------------------|----------------------| +| greedy search | 320ms | 5.95 | 15.03 | --epoch 30 --avg 1 | simulated streaming | +| greedy search | 640ms | 5.61 | 13.86 | --epoch 30 --avg 1 | simulated streaming | +| modified beam search | 320ms | 5.72 | 14.34 | --epoch 30 --avg 1 | simulated streaming | +| modified beam search | 640ms | 5.43 | 13.16 | --epoch 30 --avg 1 | simulated streaming | +| fast beam search | 320ms | 5.88 | 14.45 | --epoch 30 --avg 1 | simulated streaming | +| fast beam search | 640ms | 5.48 | 13.31 | --epoch 30 --avg 1 | simulated streaming | + +This small model achieves the following WERs on GigaSpeech test and dev sets: + +| decoding method | chunk size | dev | test | comment | decoding mode | +|----------------------|------------|------------|------------|---------------------|----------------------| +| greedy search | 320ms | 17.57 | 17.2 | --epoch 30 --avg 1 | simulated streaming | +| modified beam search | 320ms | 16.98 | 11.98 | --epoch 30 --avg 1 | simulated streaming | + +You can find the tensorboard logs at . + ### Streaming Zipformer-Transducer (Pruned Stateless Transducer + Streaming Zipformer) #### [pruned_transducer_stateless7_streaming](./pruned_transducer_stateless7_streaming) @@ -53,7 +189,7 @@ The tensorboard log can be found at The simulated streaming decoding command (e.g., chunk-size=320ms) is: ```bash -for $m in greedy_search fast_beam_search modified_beam_search; do +for m in greedy_search fast_beam_search modified_beam_search; do ./pruned_transducer_stateless7_streaming/decode.py \ --epoch 30 \ --avg 9 \ @@ -599,7 +735,7 @@ done ``` Note that a small change is made to the `pruned_transducer_stateless7/decoder.py` in -this [PR](/ceph-data4/yangxiaoyu/softwares/icefall_development/icefall_random_padding/egs/librispeech/ASR/pruned_transducer_stateless7/exp_960h_no_paddingidx_ngpu4/tensorboard) to address the +this [PR](https://github.com/k2-fsa/icefall/pull/942) to address the problem of emitting the first symbol at the very beginning. If you need a model without this issue, please download the model from here: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index 156b91f09..30a737061 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -777,8 +777,9 @@ class WithLoss(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor): - return ans_grad, torch.ones( - ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device + return ( + ans_grad, + torch.ones(ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device), ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py index 0f84eca83..f5589d1b2 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py @@ -33,7 +33,6 @@ popd --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ - \ --decode-chunk-len 32 \ --num-encoder-layers "2,4,3,2,4" \ --feedforward-dims "1024,1024,2048,2048,1024" \ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_check.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_check.py index 6c78ba70b..d7a4b9551 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_check.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_check.py @@ -71,13 +71,12 @@ It will generate the following 3 files inside $repo/exp: import argparse import logging +import torch from onnx_pretrained import OnnxModel from zipformer import stack_states from icefall import is_module_available -import torch - def get_parser(): parser = argparse.ArgumentParser( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py index 5a36b695f..8acace979 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py @@ -109,6 +109,29 @@ class Model: self.init_joiner(args) # Please change the parameters according to your model + + # 20M + # self.num_encoder_layers = to_int_tuple("2,2,2,2,2") + # self.encoder_dims = to_int_tuple("256,256,256,256,256") # also known as d_model + # self.attention_dims = to_int_tuple("192,192,192,192,192") + # self.zipformer_downsampling_factors = to_int_tuple("1,2,4,8,2") + # self.cnn_module_kernels = to_int_tuple("31,31,31,31,31") + + # 9.6M + # self.num_encoder_layers = to_int_tuple("2,3,2,2,3") + # self.encoder_dims = to_int_tuple("160,160,160,160,160") # also known as d_model + # self.attention_dims = to_int_tuple("96,96,96,96,96") + # self.zipformer_downsampling_factors = to_int_tuple("1,2,4,8,2") + # self.cnn_module_kernels = to_int_tuple("31,31,31,31,31") + + # 5.5M or 6M + + # self.num_encoder_layers = to_int_tuple("2,2,2,2,2") + # self.encoder_dims = to_int_tuple("128,128,128,128,128") # also known as d_model + # self.attention_dims = to_int_tuple("96,96,96,96,96") + # self.zipformer_downsampling_factors = to_int_tuple("1,2,4,8,2") + # self.cnn_module_kernels = to_int_tuple("31,31,31,31,31") + self.num_encoder_layers = to_int_tuple("2,4,3,2,4") self.encoder_dims = to_int_tuple("384,384,384,384,384") # also known as d_model self.attention_dims = to_int_tuple("192,192,192,192,192") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/test_model.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/test_model.py index 5400df804..de12d7af1 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/test_model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/test_model.py @@ -62,6 +62,43 @@ def test_model(): model = torch.jit.script(model) +def test_model_small(): + params = get_params() + params.vocab_size = 500 + params.blank_id = 0 + params.context_size = 2 + params.num_encoder_layers = "2,2,2,2,2" + params.feedforward_dims = "256,256,512,512,256" + params.nhead = "4,4,4,4,4" + params.encoder_dims = "128,128,128,128,128" + params.attention_dims = "96,96,96,96,96" + params.encoder_unmasked_dims = "96,96,96,96,96" + params.zipformer_downsampling_factors = "1,2,4,8,2" + params.cnn_module_kernels = "31,31,31,31,31" + params.decoder_dim = 320 + params.joiner_dim = 320 + params.num_left_chunks = 4 + params.short_chunk_size = 50 + params.decode_chunk_len = 32 + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + import pdb + + pdb.set_trace() + + # Test jit script + convert_scaled_to_non_scaled(model, inplace=True) + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. + model.__class__.forward = torch.jit.ignore(model.__class__.forward) + print("Using torch.jit.script") + model = torch.jit.script(model) + + def test_model_jit_trace(): params = get_params() params.vocab_size = 500 @@ -142,7 +179,7 @@ def test_model_jit_trace(): def main(): - test_model() + test_model_small() test_model_jit_trace() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py index c7a2a136d..b2f9ffc09 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py @@ -1049,10 +1049,10 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) + train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts = librispeech.train_all_shuf_cuts() - else: - train_cuts = librispeech.train_clean_100_cuts() + train_cuts += librispeech.train_clean_360_cuts() + train_cuts += librispeech.train_other_500_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds @@ -1091,7 +1091,7 @@ def run(rank, world_size, args): return True - train_cuts = train_cuts.filter(remove_short_and_long_utt) + # train_cuts = train_cuts.filter(remove_short_and_long_utt) if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: # We only load the sampler's state dict when it loads a checkpoint @@ -1108,14 +1108,14 @@ def run(rank, world_size, args): valid_cuts += librispeech.dev_other_cuts() valid_dl = librispeech.valid_dataloaders(valid_cuts) - if not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - sp=sp, - params=params, - ) + # if not params.print_diagnostics: + # scan_pessimistic_batches_for_oom( + # model=model, + # train_dl=train_dl, + # optimizer=optimizer, + # sp=sp, + # params=params, + # ) scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/asr_datamodule.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/asr_datamodule.py new file mode 120000 index 000000000..a3a1584d1 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/asr_datamodule.py @@ -0,0 +1 @@ +../pruned_transducer_stateless8/asr_datamodule.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/beam_search.py new file mode 120000 index 000000000..8554e44cc --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/beam_search.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode.py new file mode 100755 index 000000000..35158ced4 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode.py @@ -0,0 +1,818 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import AsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from librispeech import LibriSpeech +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--right-padding", + type=int, + default=64, + help="Padding frames at the end of features", + ) + + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + feature_lens += params.right_padding + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.right_padding), + value=LOG_EPS, + ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + AsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + params.suffix += f"-right-padding-{params.right_padding}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + assert model.encoder.decode_chunk_size == params.decode_chunk_len // 2, ( + model.encoder.decode_chunk_size, + params.decode_chunk_len, + ) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + asr_datamodule = AsrDataModule(args) + librispeech = LibriSpeech(manifest_dir=args.manifest_dir) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = asr_datamodule.test_dataloaders(test_clean_cuts) + test_other_dl = asr_datamodule.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode_gigaspeech.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode_gigaspeech.py new file mode 100644 index 000000000..a4f52ad7f --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode_gigaspeech.py @@ -0,0 +1,837 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) + +# from asr_datamodule import LibriSpeechAsrDataModule +from gigaspeech_asrmodule import GigaSpeechAsrDataModule +from gigaspeech_scoring import asr_text_post_processing +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--decode-chunk-size", + type=int, + default=16, + help="The chunk size for decoding (in frames after subsampling)", + ) + + parser.add_argument( + "--right-padding", + type=int, + default=64, + help="Padding frames at the end of features", + ) + + add_model_arguments(parser) + + return parser + + +def post_processing( + results: List[Tuple[str, List[str], List[str]]], +) -> List[Tuple[str, List[str], List[str]]]: + new_results = [] + for key, ref, hyp in results: + new_ref = asr_text_post_processing(" ".join(ref)).split() + new_hyp = asr_text_post_processing(" ".join(hyp)).split() + new_results.append((key, new_ref, new_hyp)) + return new_results + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + feature_lens += params.right_padding + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.right_padding), + value=LOG_EPS, + ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = post_processing(results) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + """ + This scripts test a libri model with libri BPE + on Gigaspeech. + """ + parser = get_parser() + GigaSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / (params.decoding_method + "_gigaspeech") + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + params.suffix += f"-right-padding-{params.right_padding}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + gigaspeech = GigaSpeechAsrDataModule(args) + + dev_cuts = gigaspeech.dev_cuts() + test_cuts = gigaspeech.test_cuts() + + dev_dl = gigaspeech.test_dataloaders(dev_cuts) + test_dl = gigaspeech.test_dataloaders(test_cuts) + + test_sets = ["dev", "test"] + test_dls = [dev_dl, test_dl] + + for test_set, test_dl in zip(test_sets, test_dls): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode_stream.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode_stream.py new file mode 120000 index 000000000..2b4596e0b --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode_stream.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7_streaming/decode_stream.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decoder.py new file mode 120000 index 000000000..4e79a10e0 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decoder.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7_streaming/decoder.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/encoder_interface.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/encoder_interface.py new file mode 120000 index 000000000..24f414dd1 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/encoder_interface.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7_streaming/encoder_interface.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export-for-ncnn-zh.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export-for-ncnn-zh.py new file mode 120000 index 000000000..c0e71accf --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export-for-ncnn-zh.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export-for-ncnn.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export-for-ncnn.py new file mode 100755 index 000000000..f5589d1b2 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export-for-ncnn.py @@ -0,0 +1,368 @@ +#!/usr/bin/env python3 + +""" +Please see +https://k2-fsa.github.io/icefall/model-export/export-ncnn.html +for more details about how to use this file. + +We use +https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 +to demonstrate the usage of this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe/bpe.model" +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +2. Export to ncnn + +./pruned_transducer_stateless7_streaming/export-for-ncnn.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --exp-dir $repo/exp \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --decode-chunk-len 32 \ + --num-encoder-layers "2,4,3,2,4" \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --nhead "8,8,8,8,8" \ + --encoder-dims "384,384,384,384,384" \ + --attention-dims "192,192,192,192,192" \ + --encoder-unmasked-dims "256,256,256,256,256" \ + --zipformer-downsampling-factors "1,2,4,8,2" \ + --cnn-module-kernels "31,31,31,31,31" \ + --decoder-dim 512 \ + --joiner-dim 512 + +cd $repo/exp + +pnnx encoder_jit_trace-pnnx.pt +pnnx decoder_jit_trace-pnnx.pt +pnnx joiner_jit_trace-pnnx.pt + +You can find converted models at +https://huggingface.co/csukuangfj/sherpa-ncnn-streaming-zipformer-en-2023-02-13 + +See ./streaming-ncnn-decode.py +and +https://github.com/k2-fsa/sherpa-ncnn +for usage. +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +import torch +from scaling_converter import convert_scaled_to_non_scaled +from train2 import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + add_model_arguments(parser) + + return parser + + +def export_encoder_model_jit_trace( + encoder_model: torch.nn.Module, + encoder_filename: str, +) -> None: + """Export the given encoder model with torch.jit.trace() + + Note: The warmup argument is fixed to 1. + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported model. + """ + encoder_model.__class__.forward = encoder_model.__class__.streaming_forward + + decode_chunk_len = encoder_model.decode_chunk_size * 2 + pad_length = 7 + T = decode_chunk_len + pad_length # 32 + 7 = 39 + + logging.info(f"decode_chunk_len: {decode_chunk_len}") + logging.info(f"T: {T}") + + x = torch.zeros(1, T, 80, dtype=torch.float32) + states = encoder_model.get_init_state() + + traced_model = torch.jit.trace(encoder_model, (x, states)) + traced_model.save(encoder_filename) + logging.info(f"Saved to {encoder_filename}") + + +def export_decoder_model_jit_trace( + decoder_model: torch.nn.Module, + decoder_filename: str, +) -> None: + """Export the given decoder model with torch.jit.trace() + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The input decoder model + decoder_filename: + The filename to save the exported model. + """ + y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) + need_pad = torch.tensor([False]) + + traced_model = torch.jit.trace(decoder_model, (y, need_pad)) + traced_model.save(decoder_filename) + logging.info(f"Saved to {decoder_filename}") + + +def export_joiner_model_jit_trace( + joiner_model: torch.nn.Module, + joiner_filename: str, +) -> None: + """Export the given joiner model with torch.jit.trace() + + Note: The argument project_input is fixed to True. A user should not + project the encoder_out/decoder_out by himself/herself. The exported joiner + will do that for the user. + + Args: + joiner_model: + The input joiner model + joiner_filename: + The filename to save the exported model. + + """ + encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] + decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + + traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out)) + traced_model.save(joiner_filename) + logging.info(f"Saved to {joiner_filename}") + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + + setup_logger(f"{params.exp_dir}/log-export/log-export-ncnn") + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True, is_pnnx=True) + + encoder_num_param = sum([p.numel() for p in model.encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in model.decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in model.joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + logging.info("Using torch.jit.trace()") + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / "encoder_jit_trace-pnnx.pt" + export_encoder_model_jit_trace(model.encoder, encoder_filename) + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / "decoder_jit_trace-pnnx.pt" + export_decoder_model_jit_trace(model.decoder, decoder_filename) + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / "joiner_jit_trace-pnnx.pt" + export_joiner_model_jit_trace(model.joiner, joiner_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export-onnx.py new file mode 120000 index 000000000..137fa8cec --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export-onnx.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7_streaming/export-onnx.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export.py new file mode 120000 index 000000000..6a009311c --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7_streaming/export.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/gigaspeech.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/gigaspeech.py new file mode 120000 index 000000000..6c6b08d3f --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/gigaspeech.py @@ -0,0 +1 @@ +../pruned_transducer_stateless8/gigaspeech.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/gigaspeech_asrmodule.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/gigaspeech_asrmodule.py new file mode 120000 index 000000000..54f18a4f0 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/gigaspeech_asrmodule.py @@ -0,0 +1 @@ +../../../gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/gigaspeech_scoring.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/gigaspeech_scoring.py new file mode 120000 index 000000000..fdfa6ce4b --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/gigaspeech_scoring.py @@ -0,0 +1 @@ +../../../gigaspeech/ASR/pruned_transducer_stateless2/gigaspeech_scoring.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/jit_pretrained.py new file mode 120000 index 000000000..c427e7709 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/jit_pretrained.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7_streaming/jit_pretrained.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/jit_trace_export.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/jit_trace_export.py new file mode 120000 index 000000000..44ecf1780 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/jit_trace_export.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7_streaming/jit_trace_export.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/jit_trace_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/jit_trace_pretrained.py new file mode 120000 index 000000000..762d38b73 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/jit_trace_pretrained.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7_streaming/jit_trace_pretrained.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/joiner.py new file mode 120000 index 000000000..2a9c1ca5f --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/joiner.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7_streaming/joiner.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/librispeech.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/librispeech.py new file mode 120000 index 000000000..7c22bc4b7 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/librispeech.py @@ -0,0 +1 @@ +../pruned_transducer_stateless8/librispeech.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/model.py new file mode 120000 index 000000000..17ced2998 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/model.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7_streaming/model.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/onnx_check.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/onnx_check.py new file mode 120000 index 000000000..8ed81ba1c --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/onnx_check.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7_streaming/onnx_check.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/onnx_model_wrapper.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/onnx_model_wrapper.py new file mode 120000 index 000000000..c780015d1 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/onnx_model_wrapper.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7_streaming/onnx_model_wrapper.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/onnx_pretrained.py new file mode 120000 index 000000000..da0236c2d --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/onnx_pretrained.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7_streaming/onnx_pretrained.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/optim.py new file mode 120000 index 000000000..6c5f3fc3e --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/optim.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7_streaming/optim.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/pretrained.py new file mode 120000 index 000000000..4c519b771 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/pretrained.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7_streaming/pretrained.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/scaling.py new file mode 120000 index 000000000..420bc4149 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/scaling.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7_streaming/scaling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/scaling_converter.py new file mode 120000 index 000000000..b6cc7dc13 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/scaling_converter.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7_streaming/scaling_converter.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/streaming-ncnn-decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/streaming-ncnn-decode.py new file mode 120000 index 000000000..d137d28ad --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/streaming-ncnn-decode.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/streaming_beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/streaming_beam_search.py new file mode 120000 index 000000000..dee9005d0 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/streaming_beam_search.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7_streaming/streaming_beam_search.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/streaming_decode.py new file mode 100644 index 000000000..78713f920 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/streaming_decode.py @@ -0,0 +1,610 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage: +./pruned_transducer_stateless7_streaming/streaming_decode.py \ + --epoch 28 \ + --avg 15 \ + --decode-chunk-len 32 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --decoding_method greedy_search \ + --num-decode-streams 2000 +""" + +import argparse +import logging +import math +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import AsrDataModule +from decode_stream import DecodeStream +from kaldifeat import Fbank, FbankOptions +from lhotse import CutSet +from librispeech import LibriSpeech +from streaming_beam_search import ( + fast_beam_search_one_best, + greedy_search, + modified_beam_search, +) +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model +from zipformer import stack_states, unstack_states + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless2/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Supported decoding methods are: + greedy_search + modified_beam_search + fast_beam_search + """, + ) + + parser.add_argument( + "--num_active_paths", + type=int, + default=4, + help="""An interger indicating how many candidates we will keep for each + frame. Used only when --decoding-method is modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=32, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--num-decode-streams", + type=int, + default=2000, + help="The number of streams that can be decoded parallel.", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_chunk( + params: AttributeDict, + model: nn.Module, + decode_streams: List[DecodeStream], +) -> List[int]: + """Decode one chunk frames of features for each decode_streams and + return the indexes of finished streams in a List. + + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + decode_streams: + A List of DecodeStream, each belonging to a utterance. + Returns: + Return a List containing which DecodeStreams are finished. + """ + device = model.device + + features = [] + feature_lens = [] + states = [] + processed_lens = [] + + for stream in decode_streams: + feat, feat_len = stream.get_feature_frames(params.decode_chunk_len) + features.append(feat) + feature_lens.append(feat_len) + states.append(stream.states) + processed_lens.append(stream.done_frames) + + feature_lens = torch.tensor(feature_lens, device=device) + features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) + + # We subsample features with ((x_len - 7) // 2 + 1) // 2 and the max downsampling + # factor in encoders is 8. + # After feature embedding (x_len - 7) // 2, we have (23 - 7) // 2 = 8. + tail_length = 23 + if features.size(1) < tail_length: + pad_length = tail_length - features.size(1) + feature_lens += pad_length + features = torch.nn.functional.pad( + features, + (0, 0, 0, pad_length), + mode="constant", + value=LOG_EPS, + ) + + states = stack_states(states) + processed_lens = torch.tensor(processed_lens, device=device) + + encoder_out, encoder_out_lens, new_states = model.encoder.streaming_forward( + x=features, + x_lens=feature_lens, + states=states, + ) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + if params.decoding_method == "greedy_search": + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) + elif params.decoding_method == "fast_beam_search": + processed_lens = processed_lens + encoder_out_lens + fast_beam_search_one_best( + model=model, + encoder_out=encoder_out, + processed_lens=processed_lens, + streams=decode_streams, + beam=params.beam, + max_states=params.max_states, + max_contexts=params.max_contexts, + ) + elif params.decoding_method == "modified_beam_search": + modified_beam_search( + model=model, + streams=decode_streams, + encoder_out=encoder_out, + num_active_paths=params.num_active_paths, + ) + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + + states = unstack_states(new_states) + + finished_streams = [] + for i in range(len(decode_streams)): + decode_streams[i].states = states[i] + decode_streams[i].done_frames += encoder_out_lens[i] + if decode_streams[i].done: + finished_streams.append(i) + + return finished_streams + + +def decode_dataset( + cuts: CutSet, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + cuts: + Lhotse Cutset containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + device = model.device + + opts = FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + + log_interval = 50 + + decode_results = [] + # Contain decode streams currently running. + decode_streams = [] + for num, cut in enumerate(cuts): + # each utterance has a DecodeStream. + initial_states = model.encoder.get_init_state(device=device) + decode_stream = DecodeStream( + params=params, + cut_id=cut.id, + initial_states=initial_states, + decoding_graph=decoding_graph, + device=device, + ) + + audio: np.ndarray = cut.load_audio() + # audio.shape: (1, num_samples) + assert len(audio.shape) == 2 + assert audio.shape[0] == 1, "Should be single channel" + assert audio.dtype == np.float32, audio.dtype + + # The trained model is using normalized samples + assert audio.max() <= 1, "Should be normalized to [-1, 1])" + + samples = torch.from_numpy(audio).squeeze(0) + + fbank = Fbank(opts) + feature = fbank(samples.to(device)) + decode_stream.set_features(feature, tail_pad_len=params.decode_chunk_len) + decode_stream.ground_truth = cut.supervisions[0].text + + decode_streams.append(decode_stream) + + while len(decode_streams) >= params.num_decode_streams: + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + sp.decode(decode_streams[i].decoding_result()).split(), + ) + ) + del decode_streams[i] + + if num % log_interval == 0: + logging.info(f"Cuts processed until now is {num}.") + + # decode final chunks of last sequences + while len(decode_streams): + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + sp.decode(decode_streams[i].decoding_result()).split(), + ) + ) + del decode_streams[i] + + if params.decoding_method == "greedy_search": + key = "greedy_search" + elif params.decoding_method == "fast_beam_search": + key = ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ) + elif params.decoding_method == "modified_beam_search": + key = f"num_active_paths_{params.num_active_paths}" + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + return {key: decode_results} + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + AsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.res_dir = params.exp_dir / "streaming" / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + # for streaming + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}" + + # for fast_beam_search + if params.decoding_method == "fast_beam_search": + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + model.device = device + + decoding_graph = None + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + librispeech = LibriSpeech(manifest_dir=args.manifest_dir) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_sets = ["test-clean", "test-other"] + test_cuts = [test_clean_cuts, test_other_cuts] + + for test_set, test_cut in zip(test_sets, test_cuts): + results_dict = decode_dataset( + cuts=test_cut, + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/test_model.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/test_model.py new file mode 120000 index 000000000..1a9ba93e6 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/test_model.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7_streaming/test_model.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py new file mode 100755 index 000000000..09e8a512f --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py @@ -0,0 +1,1370 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --full-libri 1 \ + --max-duration 300 + +# For mix precision training: + +./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --full-libri 1 \ + --max-duration 550 +""" + + +import argparse +import copy +import logging +import random +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import AsrDataModule +from decoder import Decoder +from gigaspeech import GigaSpeech +from joiner import Joiner +from lhotse import CutSet, load_manifest +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from librispeech import LibriSpeech +from model import Transducer +from optim import Eden, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,4,3,2,4", + help="Number of zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--feedforward-dims", + type=str, + default="1024,1024,2048,2048,1024", + help="Feedforward dimension of the zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--nhead", + type=str, + default="8,8,8,8,8", + help="Number of attention heads in the zipformer encoder layers.", + ) + + parser.add_argument( + "--encoder-dims", + type=str, + default="384,384,384,384,384", + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", + ) + + parser.add_argument( + "--attention-dims", + type=str, + default="192,192,192,192,192", + help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; + not the same as embedding dimension.""", + ) + + parser.add_argument( + "--encoder-unmasked-dims", + type=str, + default="256,256,256,256,256", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " + " worse.", + ) + + parser.add_argument( + "--zipformer-downsampling-factors", + type=str, + default="1,2,4,8,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--cnn-module-kernels", + type=str, + default="31,31,31,31,31", + help="Sizes of kernels in convolution modules", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--short-chunk-size", + type=int, + default=50, + help="""Chunk length of dynamic training, the chunk size would be either + max sequence length of current batch or uniformly sampled from (1, short_chunk_size). + """, + ) + + parser.add_argument( + "--num-left-chunks", + type=int, + default=4, + help="How many left context can be seen in chunks when calculating attention.", + ) + + parser.add_argument( + "--decode-chunk-len", + type=int, + default=32, + help="The chunk size for decoding (in frames before subsampling)", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--full-libri", + type=str2bool, + default=True, + help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.05, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=2000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--giga-prob", + type=float, + default=0.5, + help="The probability to select a batch from the GigaSpeech dataset", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Zipformer and Transformer + def to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + encoder = Zipformer( + num_features=params.feature_dim, + output_downsampling_factor=2, + zipformer_downsampling_factors=to_int_tuple( + params.zipformer_downsampling_factors + ), + encoder_dims=to_int_tuple(params.encoder_dims), + attention_dim=to_int_tuple(params.attention_dims), + encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), + nhead=to_int_tuple(params.nhead), + feedforward_dim=to_int_tuple(params.feedforward_dims), + cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), + num_encoder_layers=to_int_tuple(params.num_encoder_layers), + num_left_chunks=params.num_left_chunks, + short_chunk_size=params.short_chunk_size, + decode_chunk_size=params.decode_chunk_len // 2, + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def is_libri(c: Cut) -> bool: + """Return True if this cut is from the LibriSpeech dataset. + + Note: + During data preparation, we set the custom field in + the supervision segment of GigaSpeech to dict(origin='giga') + See ../local/preprocess_gigaspeech.py. + """ + return c.supervisions[0].custom is None + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + giga_train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + rng: random.Random, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + libri_tot_loss = MetricsTracker() + giga_tot_loss = MetricsTracker() + tot_loss = MetricsTracker() + + # index 0: for LibriSpeech + # index 1: for GigaSpeech + # This sets the probabilities for choosing which datasets + dl_weights = [1 - params.giga_prob, params.giga_prob] + iter_libri = iter(train_dl) + iter_giga = iter(giga_train_dl) + + batch_idx = 0 + + while True: + idx = rng.choices((0, 1), weights=dl_weights, k=1)[0] + dl = iter_libri if idx == 0 else iter_giga + + try: + batch = next(dl) + except StopIteration: + name = "libri" if idx == 0 else "giga" + logging.info(f"{name} reaches end of dataloader") + break + + batch_idx += 1 + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + libri = is_libri(batch["supervisions"]["cut"][0]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + if libri: + libri_tot_loss = ( + libri_tot_loss * (1 - 1 / params.reset_interval) + ) + loss_info + prefix = "libri" # for logging only + else: + giga_tot_loss = ( + giga_tot_loss * (1 - 1 / params.reset_interval) + ) + loss_info + prefix = "giga" + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, {prefix}_loss[{loss_info}], " + f"tot_loss[{tot_loss}], " + f"libri_tot_loss[{libri_tot_loss}], " + f"giga_tot_loss[{giga_tot_loss}], " + f"batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + libri_tot_loss.write_summary( + tb_writer, "train/libri_tot_", params.batch_idx_train + ) + giga_tot_loss.write_summary( + tb_writer, "train/giga_tot_", params.batch_idx_train + ) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def filter_short_and_long_utterances( + cuts: CutSet, sp: spm.SentencePieceProcessor +) -> CutSet: + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + return False + + return True + + cuts = cuts.filter(remove_short_and_long_utt) + + return cuts + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + if params.full_libri is False: + params.valid_interval = 1600 + + fix_random_seed(params.seed) + rng = random.Random(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + parameters_names = [] + parameters_names.append( + [name_param_pair[0] for name_param_pair in model.named_parameters()] + ) + optimizer = ScaledAdam( + model.parameters(), + lr=params.base_lr, + clipping_scale=2.0, + parameters_names=parameters_names, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + librispeech = LibriSpeech(manifest_dir=args.manifest_dir) + + train_cuts = librispeech.train_clean_100_cuts() + if params.full_libri: + train_cuts += librispeech.train_clean_360_cuts() + train_cuts += librispeech.train_other_500_cuts() + + train_cuts = filter_short_and_long_utterances(train_cuts, sp) + + gigaspeech = GigaSpeech(manifest_dir=args.manifest_dir) + # XL 10k hours + # L 2.5k hours + # M 1k hours + # S 250 hours + # XS 10 hours + # DEV 12 hours + # Test 40 hours + if params.full_libri: + logging.info("Using the XL subset of GigaSpeech (10k hours)") + train_giga_cuts = gigaspeech.train_XL_cuts() + else: + logging.info("Using the S subset of GigaSpeech (250 hours)") + train_giga_cuts = gigaspeech.train_S_cuts() + + train_giga_cuts = filter_short_and_long_utterances(train_giga_cuts, sp) + train_giga_cuts = train_giga_cuts.repeat(times=None) + + if args.enable_musan: + cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz") + else: + cuts_musan = None + + asr_datamodule = AsrDataModule(args) + + train_dl = asr_datamodule.train_dataloaders( + train_cuts, + on_the_fly_feats=False, + cuts_musan=cuts_musan, + ) + + giga_train_dl = asr_datamodule.train_dataloaders( + train_giga_cuts, + on_the_fly_feats=False, + cuts_musan=cuts_musan, + ) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = asr_datamodule.valid_dataloaders(valid_cuts) + + # if not params.print_diagnostics: + # scan_pessimistic_batches_for_oom( + # model=model, + # train_dl=train_dl, + # optimizer=optimizer, + # sp=sp, + # params=params, + # ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + giga_train_dl=giga_train_dl, + valid_dl=valid_dl, + rng=rng, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + AsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + assert 0 <= args.giga_prob < 1, args.giga_prob + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train2.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train2.py new file mode 120000 index 000000000..3c3280b68 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train2.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7_streaming/train2.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/zipformer.py new file mode 120000 index 000000000..be9e75bfa --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/zipformer.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7_streaming/zipformer.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/zipformer2.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/zipformer2.py new file mode 120000 index 000000000..d3625f478 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/zipformer2.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7_streaming/zipformer2.py \ No newline at end of file From d67a49afe46d5afdd04af2c03f27362bdbaca01e Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Fri, 21 Apr 2023 18:09:41 +0800 Subject: [PATCH 165/174] Add multidataset (#1010) * Add Common Voice for multidataset * Add prepare_multidataset.sh * Add dataset mixing * Update prepare_multidataset.sh * Update prepare_giga_speech.sh * update comments * Add split and shuffle mechanism * Add multi-dataset train * Fix for deleting * Fix for modifying * Add comments * Change type for perturb_speed * Fix for style check * Small fix * Add filter * Remove warning --- .../ASR/local/compute_fbank_librispeech.py | 26 +- egs/librispeech/ASR/prepare_common_voice.sh | 117 ++++++ egs/librispeech/ASR/prepare_giga_speech.sh | 61 +-- egs/librispeech/ASR/prepare_multidataset.sh | 373 ++++++++++++++++++ .../gigaspeech.py | 2 +- .../multidataset.py | 53 +++ .../ASR/pruned_transducer_stateless7/train.py | 30 +- 7 files changed, 624 insertions(+), 38 deletions(-) create mode 100755 egs/librispeech/ASR/prepare_common_voice.sh create mode 100755 egs/librispeech/ASR/prepare_multidataset.sh create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless7/multidataset.py diff --git a/egs/librispeech/ASR/local/compute_fbank_librispeech.py b/egs/librispeech/ASR/local/compute_fbank_librispeech.py index 554d7f109..25d6050bb 100755 --- a/egs/librispeech/ASR/local/compute_fbank_librispeech.py +++ b/egs/librispeech/ASR/local/compute_fbank_librispeech.py @@ -35,7 +35,7 @@ from filter_cuts import filter_cuts from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter from lhotse.recipes.utils import read_manifests_if_cached -from icefall.utils import get_executor +from icefall.utils import get_executor, str2bool # Torch's multithreaded behavior needs to be disabled or # it wastes a lot of CPU and slow things down. @@ -61,12 +61,20 @@ def get_args(): help="""Dataset parts to compute fbank. If None, we will use all""", ) + parser.add_argument( + "--perturb-speed", + type=str2bool, + default=True, + help="""Perturb speed with factor 0.9 and 1.1 on train subset.""", + ) + return parser.parse_args() def compute_fbank_librispeech( bpe_model: Optional[str] = None, dataset: Optional[str] = None, + perturb_speed: Optional[bool] = True, ): src_dir = Path("data/manifests") output_dir = Path("data/fbank") @@ -125,9 +133,13 @@ def compute_fbank_librispeech( if "train" in partition: if bpe_model: cut_set = filter_cuts(cut_set, sp) - cut_set = ( - cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) - ) + if perturb_speed: + logging.info(f"Doing speed perturb") + cut_set = ( + cut_set + + cut_set.perturb_speed(0.9) + + cut_set.perturb_speed(1.1) + ) cut_set = cut_set.compute_and_store_features( extractor=extractor, storage_path=f"{output_dir}/{prefix}_feats_{partition}", @@ -145,4 +157,8 @@ if __name__ == "__main__": logging.basicConfig(format=formatter, level=logging.INFO) args = get_args() logging.info(vars(args)) - compute_fbank_librispeech(bpe_model=args.bpe_model, dataset=args.dataset) + compute_fbank_librispeech( + bpe_model=args.bpe_model, + dataset=args.dataset, + perturb_speed=args.perturb_speed, + ) diff --git a/egs/librispeech/ASR/prepare_common_voice.sh b/egs/librispeech/ASR/prepare_common_voice.sh new file mode 100755 index 000000000..6f9c4fb2f --- /dev/null +++ b/egs/librispeech/ASR/prepare_common_voice.sh @@ -0,0 +1,117 @@ +#!/usr/bin/env bash + +set -eou pipefail + +nj=16 +stage=-1 +stop_stage=100 + +# Split data/${lang}set to this number of pieces +# This is to avoid OOM during feature extraction. +num_splits=1000 + +# We assume dl_dir (download dir) contains the following +# directories and files. If not, they will be downloaded +# by this script automatically. +# +# - $dl_dir/$release/$lang +# This directory contains the following files downloaded from +# https://mozilla-common-voice-datasets.s3.dualstack.us-west-2.amazonaws.com/${release}/${release}-${lang}.tar.gz +# +# - clips +# - dev.tsv +# - invalidated.tsv +# - other.tsv +# - reported.tsv +# - test.tsv +# - train.tsv +# - validated.tsv + +dl_dir=$PWD/download +release=cv-corpus-13.0-2023-03-09 +lang=en + +. shared/parse_options.sh || exit 1 + +# All files generated by this script are saved in "data/${lang}". +# You can safely remove "data/${lang}" and rerun this script to regenerate it. +mkdir -p data/${lang} + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "dl_dir: $dl_dir" + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Download data" + + # If you have pre-downloaded it to /path/to/$release, + # you can create a symlink + # + # ln -sfv /path/to/$release $dl_dir/$release + # + if [ ! -d $dl_dir/$release/$lang/clips ]; then + lhotse download commonvoice --languages $lang --release $release $dl_dir + fi +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare CommonVoice manifest" + # We assume that you have downloaded the CommonVoice corpus + # to $dl_dir/$release + mkdir -p data/${lang}/manifests + if [ ! -e data/${lang}/manifests/.cv-${lang}.done ]; then + lhotse prepare commonvoice --language $lang -j $nj $dl_dir/$release data/${lang}/manifests + touch data/${lang}/manifests/.cv-${lang}.done + fi +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Preprocess CommonVoice manifest" + if [ ! -e data/${lang}/fbank/.preprocess_complete ]; then + ./local/preprocess_commonvoice.py --language $lang + touch data/${lang}/fbank/.preprocess_complete + fi +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Compute fbank for dev and test subsets of CommonVoice" + mkdir -p data/${lang}/fbank + if [ ! -e data/${lang}/fbank/.cv-${lang}_dev_test.done ]; then + ./local/compute_fbank_commonvoice_dev_test.py --language $lang + touch data/${lang}/fbank/.cv-${lang}_dev_test.done + fi +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Split train subset into ${num_splits} pieces" + split_dir=data/${lang}/fbank/cv-${lang}_train_split_${num_splits} + if [ ! -e $split_dir/.cv-${lang}_train_split.done ]; then + lhotse split $num_splits ./data/${lang}/fbank/cv-${lang}_cuts_train_raw.jsonl.gz $split_dir + touch $split_dir/.cv-${lang}_train_split.done + fi +fi + +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Compute features for train subset of CommonVoice" + if [ ! -e data/${lang}/fbank/.cv-${lang}_train.done ]; then + ./local/compute_fbank_commonvoice_splits.py \ + --num-workers $nj \ + --batch-duration 600 \ + --start 0 \ + --num-splits $num_splits \ + --language $lang + touch data/${lang}/fbank/.cv-${lang}_train.done + fi +fi + +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + log "Stage 6: Combine features for train" + if [ ! -f data/${lang}/fbank/cv-${lang}_cuts_train.jsonl.gz ]; then + pieces=$(find data/${lang}/fbank/cv-${lang}_train_split_${num_splits} -name "cv-${lang}_cuts_train.*.jsonl.gz") + lhotse combine $pieces data/${lang}/fbank/cv-${lang}_cuts_train.jsonl.gz + fi +fi diff --git a/egs/librispeech/ASR/prepare_giga_speech.sh b/egs/librispeech/ASR/prepare_giga_speech.sh index 6f85ddc29..b077aaf3a 100755 --- a/egs/librispeech/ASR/prepare_giga_speech.sh +++ b/egs/librispeech/ASR/prepare_giga_speech.sh @@ -95,39 +95,45 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then log "Stage 1: Prepare GigaSpeech manifest (may take 30 minutes)" # We assume that you have downloaded the GigaSpeech corpus # to $dl_dir/GigaSpeech - mkdir -p data/manifests - lhotse prepare gigaspeech \ - --subset XL \ - --subset L \ - --subset M \ - --subset S \ - --subset XS \ - --subset DEV \ - --subset TEST \ - -j $nj \ - $dl_dir/GigaSpeech data/manifests + if [ ! -f data/manifests/.gigaspeech.done ]; then + mkdir -p data/manifests + lhotse prepare gigaspeech \ + --subset XL \ + --subset L \ + --subset M \ + --subset S \ + --subset XS \ + --subset DEV \ + --subset TEST \ + -j $nj \ + $dl_dir/GigaSpeech data/manifests + touch data/manifests/.gigaspeech.done + fi fi if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then log "Stage 2: Preprocess GigaSpeech manifest" - if [ ! -f data/fbank/.preprocess_complete ]; then - log "It may take 2 hours for this stage" - python3 ./local/preprocess_gigaspeech.py - touch data/fbank/.preprocess_complete + if [ ! -f data/fbank/.gigaspeech_preprocess.done ]; then + log "It may take 2 hours for this stage" + ./local/preprocess_gigaspeech.py + touch data/fbank/.gigaspeech_preprocess.done fi fi if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then log "Stage 3: Compute features for DEV and TEST subsets of GigaSpeech (may take 2 minutes)" - python3 ./local/compute_fbank_gigaspeech_dev_test.py + if [ ! -f data/fbank/.gigaspeech_dev_test.done ]; then + ./local/compute_fbank_gigaspeech_dev_test.py + touch data/fbank/.gigaspeech_dev_test.done + fi fi if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then log "Stage 4: Split XL subset into ${num_splits} pieces" split_dir=data/fbank/gigaspeech_XL_split_${num_splits} - if [ ! -f $split_dir/.split_completed ]; then + if [ ! -f $split_dir/.gigaspeech_XL_split.done ]; then lhotse split-lazy ./data/fbank/gigaspeech_cuts_XL_raw.jsonl.gz $split_dir $chunk_size - touch $split_dir/.split_completed + touch $split_dir/.gigaspeech_XL_split.done fi fi @@ -135,8 +141,19 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then log "Stage 5: Compute features for XL" # Note: The script supports --start and --stop options. # You can use several machines to compute the features in parallel. - python3 ./local/compute_fbank_gigaspeech_splits.py \ - --num-workers $nj \ - --batch-duration 600 \ - --num-splits $num_splits + if [ ! -f data/fbank/.gigaspeech_XL.done ]; then + ./local/compute_fbank_gigaspeech_splits.py \ + --num-workers $nj \ + --batch-duration 600 \ + --num-splits $num_splits + touch data/fbank/.gigaspeech_XL.done + fi +fi + +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + log "Stage 6: Combine features for XL (may take 15 hours)" + if [ ! -f data/fbank/gigaspeech_cuts_XL.jsonl.gz ]; then + pieces=$(find data/fbank/gigaspeech_XL_split_${num_splits} -name "gigaspeech_cuts_XL.*.jsonl.gz") + lhotse combine $pieces data/fbank/gigaspeech_cuts_XL.jsonl.gz + fi fi diff --git a/egs/librispeech/ASR/prepare_multidataset.sh b/egs/librispeech/ASR/prepare_multidataset.sh new file mode 100755 index 000000000..c068305c0 --- /dev/null +++ b/egs/librispeech/ASR/prepare_multidataset.sh @@ -0,0 +1,373 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + +nj=16 +stage=-1 +stop_stage=100 + +# We assume dl_dir (download dir) contains the following +# directories and files. If not, they will be downloaded +# by this script automatically. +# +# - $dl_dir/LibriSpeech +# You can find BOOKS.TXT, test-clean, train-clean-360, etc, inside it. +# You can download them from https://www.openslr.org/12 +# +# - $dl_dir/lm +# This directory contains the following files downloaded from +# http://www.openslr.org/resources/11 +# +# - 3-gram.pruned.1e-7.arpa.gz +# - 3-gram.pruned.1e-7.arpa +# - 4-gram.arpa.gz +# - 4-gram.arpa +# - librispeech-vocab.txt +# - librispeech-lexicon.txt +# - librispeech-lm-norm.txt.gz +# +# - $dl_dir/musan +# This directory contains the following directories downloaded from +# http://www.openslr.org/17/ +# +# - music +# - noise +# - speech + +# Split all dataset to this number of pieces and mix each dataset pieces +# into multidataset pieces with shuffling. +num_splits=1998 + +dl_dir=$PWD/download + +. shared/parse_options.sh || exit 1 + +# vocab size for sentence piece models. +# It will generate data/lang_bpe_xxx, +# data/lang_bpe_yyy if the array contains xxx, yyy +vocab_sizes=( + # 5000 + # 2000 + # 1000 + 500 +) + +# multidataset list. +# LibriSpeech and musan are required. +# The others are optional. +multidataset=( + "gigaspeech", + "commonvoice", +) + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. +mkdir -p data + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "dl_dir: $dl_dir" + +log "Dataset: LibriSpeech and musan" +if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then + log "Stage -1: Download LM" + mkdir -p $dl_dir/lm + if [ ! -e $dl_dir/lm/.done ]; then + ./local/download_lm.py --out-dir=$dl_dir/lm + touch $dl_dir/lm/.done + fi +fi + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Download data" + + # If you have pre-downloaded it to /path/to/LibriSpeech, + # you can create a symlink + # + # ln -sfv /path/to/LibriSpeech $dl_dir/LibriSpeech + # + if [ ! -d $dl_dir/LibriSpeech/train-other-500 ]; then + lhotse download librispeech --full $dl_dir + fi + + # If you have pre-downloaded it to /path/to/musan, + # you can create a symlink + # + # ln -sfv /path/to/musan $dl_dir/ + # + if [ ! -d $dl_dir/musan ]; then + lhotse download musan $dl_dir + fi +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare LibriSpeech manifest" + # We assume that you have downloaded the LibriSpeech corpus + # to $dl_dir/LibriSpeech + mkdir -p data/manifests + if [ ! -e data/manifests/.librispeech.done ]; then + lhotse prepare librispeech -j $nj $dl_dir/LibriSpeech data/manifests + touch data/manifests/.librispeech.done + fi +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Prepare musan manifest" + # We assume that you have downloaded the musan corpus + # to data/musan + mkdir -p data/manifests + if [ ! -e data/manifests/.musan.done ]; then + lhotse prepare musan $dl_dir/musan data/manifests + touch data/manifests/.musan.done + fi +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Compute fbank for librispeech" + mkdir -p data/fbank + if [ ! -e data/fbank/.librispeech.done ]; then + ./local/compute_fbank_librispeech.py --perturb-speed False + touch data/fbank/.librispeech.done + fi + + if [ ! -f data/fbank/librispeech_cuts_train-all-shuf.jsonl.gz ]; then + cat <(gunzip -c data/fbank/librispeech_cuts_train-clean-100.jsonl.gz) \ + <(gunzip -c data/fbank/librispeech_cuts_train-clean-360.jsonl.gz) \ + <(gunzip -c data/fbank/librispeech_cuts_train-other-500.jsonl.gz) | \ + shuf | gzip -c > data/fbank/librispeech_cuts_train-all-shuf.jsonl.gz + fi + + if [ ! -e data/fbank/.librispeech-validated.done ]; then + log "Validating data/fbank for LibriSpeech" + parts=( + train-clean-100 + train-clean-360 + train-other-500 + test-clean + test-other + dev-clean + dev-other + ) + for part in ${parts[@]}; do + python3 ./local/validate_manifest.py \ + data/fbank/librispeech_cuts_${part}.jsonl.gz + done + touch data/fbank/.librispeech-validated.done + fi +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Compute fbank for musan" + mkdir -p data/fbank + if [ ! -e data/fbank/.musan.done ]; then + ./local/compute_fbank_musan.py + touch data/fbank/.musan.done + fi +fi + +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Prepare phone based lang" + lang_dir=data/lang_phone + mkdir -p $lang_dir + + (echo '!SIL SIL'; echo ' SPN'; echo ' SPN'; ) | + cat - $dl_dir/lm/librispeech-lexicon.txt | + sort | uniq > $lang_dir/lexicon.txt + + if [ ! -f $lang_dir/L_disambig.pt ]; then + ./local/prepare_lang.py --lang-dir $lang_dir + fi + + if [ ! -f $lang_dir/L.fst ]; then + log "Converting L.pt to L.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L.pt \ + $lang_dir/L.fst + fi + + if [ ! -f $lang_dir/L_disambig.fst ]; then + log "Converting L_disambig.pt to L_disambig.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L_disambig.pt \ + $lang_dir/disambig_L.fst + fi +fi + +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + log "Stage 6: Prepare BPE based lang" + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + mkdir -p $lang_dir + # We reuse words.txt from phone based lexicon + # so that the two can share G.pt later. + cp data/lang_phone/words.txt $lang_dir + + if [ ! -f $lang_dir/transcript_words.txt ]; then + log "Generate data for BPE training" + files=$( + find "$dl_dir/LibriSpeech/train-clean-100" -name "*.trans.txt" + find "$dl_dir/LibriSpeech/train-clean-360" -name "*.trans.txt" + find "$dl_dir/LibriSpeech/train-other-500" -name "*.trans.txt" + ) + for f in ${files[@]}; do + cat $f | cut -d " " -f 2- + done > $lang_dir/transcript_words.txt + fi + + if [ ! -f $lang_dir/bpe.model ]; then + ./local/train_bpe_model.py \ + --lang-dir $lang_dir \ + --vocab-size $vocab_size \ + --transcript $lang_dir/transcript_words.txt + fi + + if [ ! -f $lang_dir/L_disambig.pt ]; then + ./local/prepare_lang_bpe.py --lang-dir $lang_dir + + log "Validating $lang_dir/lexicon.txt" + ./local/validate_bpe_lexicon.py \ + --lexicon $lang_dir/lexicon.txt \ + --bpe-model $lang_dir/bpe.model + fi + + if [ ! -f $lang_dir/L.fst ]; then + log "Converting L.pt to L.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L.pt \ + $lang_dir/L.fst + fi + + if [ ! -f $lang_dir/L_disambig.fst ]; then + log "Converting L_disambig.pt to L_disambig.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L_disambig.pt \ + $lang_dir/L_disambig.fst + fi + done +fi + +if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then + log "Stage 7: Prepare G" + # We assume you have install kaldilm, if not, please install + # it using: pip install kaldilm + + mkdir -p data/lm + if [ ! -f data/lm/G_3_gram.fst.txt ]; then + # It is used in building HLG + python3 -m kaldilm \ + --read-symbol-table="data/lang_phone/words.txt" \ + --disambig-symbol='#0' \ + --max-order=3 \ + $dl_dir/lm/3-gram.pruned.1e-7.arpa > data/lm/G_3_gram.fst.txt + fi + + if [ ! -f data/lm/G_4_gram.fst.txt ]; then + # It is used for LM rescoring + python3 -m kaldilm \ + --read-symbol-table="data/lang_phone/words.txt" \ + --disambig-symbol='#0' \ + --max-order=4 \ + $dl_dir/lm/4-gram.arpa > data/lm/G_4_gram.fst.txt + fi +fi + +if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then + log "Stage 8: Compile HLG" + ./local/compile_hlg.py --lang-dir data/lang_phone + + # Note If ./local/compile_hlg.py throws OOM, + # please switch to the following command + # + # ./local/compile_hlg_using_openfst.py --lang-dir data/lang_phone + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + ./local/compile_hlg.py --lang-dir $lang_dir + + # Note If ./local/compile_hlg.py throws OOM, + # please switch to the following command + # + # ./local/compile_hlg_using_openfst.py --lang-dir $lang_dir + done +fi + +# Compile LG for RNN-T fast_beam_search decoding +if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then + log "Stage 9: Compile LG" + ./local/compile_lg.py --lang-dir data/lang_phone + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + ./local/compile_lg.py --lang-dir $lang_dir + done +fi + +if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then + log "Stage 10: Prepare the other datasets" + # GigaSpeech + if [[ "${multidataset[@]}" =~ "gigaspeech" ]]; then + log "Dataset: GigaSpeech" + ./prepare_giga_speech.sh --stop_stage 5 + fi + + # CommonVoice + if [[ "${multidataset[@]}" =~ "commonvoice" ]]; then + log "Dataset: CommonVoice" + ./prepare_common_voice.sh + fi +fi + +if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then + log "Stage 11: Create multidataset" + split_dir=data/fbank/multidataset_split_${num_splits} + if [ ! -f data/fbank/multidataset_split/.multidataset.done ]; then + mkdir -p $split_dir/multidataset + log "Split LibriSpeech" + if [ ! -f $split_dir/.librispeech_split.done ]; then + lhotse split $num_splits ./data/fbank/librispeech_cuts_train-all-shuf.jsonl.gz $split_dir + touch $split_dir/.librispeech_split.done + fi + + if [[ "${multidataset[@]}" =~ "gigaspeech" ]]; then + log "Split GigaSpeech XL" + if [ ! -f $split_dir/.gigaspeech_XL_split.done ]; then + cd $split_dir + ln -sv ../gigaspeech_XL_split_2000/gigaspeech_cuts_XL.*.jsonl.gz . + cd ../../.. + touch $split_dir/.gigaspeech_XL_split.done + fi + fi + + if [[ "${multidataset[@]}" =~ "commonvoice" ]]; then + log "Split CommonVoice" + if [ ! -f $split_dir/.cv-en_train_split.done ]; then + lhotse split $num_splits ./data/en/fbank/cv-en_cuts_train.jsonl.gz $split_dir + touch $split_dir/.cv-en_train_split.done + fi + fi + + if [ ! -f $split_dir/.multidataset_mix.done ]; then + log "Mix multidataset" + for ((seq=1; seq<=$num_splits; seq++)); do + fseq=$(printf "%04d" $seq) + gunzip -c $split_dir/*.*${fseq}.jsonl.gz | \ + shuf | gzip -c > $split_dir/multidataset/multidataset_cuts_train.${fseq}.jsonl.gz + done + touch $split_dir/.multidataset_mix.done + fi + + touch data/fbank/multidataset_split/.multidataset.done + fi +fi diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/gigaspeech.py b/egs/librispeech/ASR/pruned_transducer_stateless3/gigaspeech.py index 598434f54..f3bd6284e 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/gigaspeech.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/gigaspeech.py @@ -30,7 +30,7 @@ class GigaSpeech: """ Args: manifest_dir: - It is expected to contain the following files:: + It is expected to contain the following files: - gigaspeech_XL_split_2000/gigaspeech_cuts_XL.*.jsonl.gz - gigaspeech_cuts_L_raw.jsonl.gz diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/multidataset.py b/egs/librispeech/ASR/pruned_transducer_stateless7/multidataset.py new file mode 100644 index 000000000..dcb4cd141 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/multidataset.py @@ -0,0 +1,53 @@ +# Copyright 2023 Xiaomi Corp. (authors: Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import glob +import logging +import re +from pathlib import Path + +import lhotse +from lhotse import CutSet, load_manifest_lazy + + +class MultiDataset: + def __init__(self, manifest_dir: str): + """ + Args: + manifest_dir: + It is expected to contain the following files: + + - multidataset_split_1998/multidataset/multidataset_cuts_train.*.jsonl.gz + """ + self.manifest_dir = Path(manifest_dir) + + def train_cuts(self) -> CutSet: + logging.info("About to get multidataset train cuts") + + filenames = glob.glob( + f"{self.manifest_dir}/multidataset_split_1998/multidataset/multidataset_cuts_train.*.jsonl.gz" + ) + + pattern = re.compile(r"multidataset_cuts_train.([0-9]+).jsonl.gz") + idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames) + idx_filenames = sorted(idx_filenames, key=lambda x: x[0]) + + sorted_filenames = [f[1] for f in idx_filenames] + + logging.info(f"Loading {len(sorted_filenames)} splits") + + return lhotse.combine(lhotse.load_manifest_lazy(p) for p in sorted_filenames) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index 792a243e5..01c9500ce 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -1,8 +1,9 @@ #!/usr/bin/env python3 -# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, # Wei Kang, -# Mingshuang Luo,) -# Zengwei Yao) +# Mingshuang Luo, +# Zengwei Yao, +# Yifan Yang) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -59,6 +60,7 @@ import torch import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule +from multidataset import MultiDataset from decoder import Decoder from joiner import Joiner from lhotse.cut import Cut @@ -374,6 +376,13 @@ def get_parser(): help="Whether to use half precision training.", ) + parser.add_argument( + "--use-multidataset", + type=str2bool, + default=False, + help="Whether to use multidataset to train.", + ) + add_model_arguments(parser) return parser @@ -1043,10 +1052,14 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) - if params.full_libri: - train_cuts = librispeech.train_all_shuf_cuts() + if params.use_multidataset: + multidataset = MultiDataset(params.manifest_dir) + train_cuts = multidataset.train_cuts() else: - train_cuts = librispeech.train_clean_100_cuts() + if params.full_libri: + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds @@ -1058,9 +1071,6 @@ def run(rank, world_size, args): # an utterance duration distribution for your dataset to select # the threshold if c.duration < 1.0 or c.duration > 20.0: - logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - ) return False # In pruned RNN-T, we require that T >= S @@ -1102,7 +1112,7 @@ def run(rank, world_size, args): valid_cuts += librispeech.dev_other_cuts() valid_dl = librispeech.valid_dataloaders(valid_cuts) - if not params.print_diagnostics: + if not params.use_multidataset and not params.print_diagnostics: scan_pessimistic_batches_for_oom( model=model, train_dl=train_dl, From 2096e69bdaf6440ca6723cbd0200ae40748a94b4 Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Sun, 23 Apr 2023 18:41:44 +0800 Subject: [PATCH 166/174] Use CutSet.mux for multidataset (#1020) * Use CutSet.mux * Remove mischange * Fix for style check --- egs/librispeech/ASR/prepare_multidataset.sh | 43 ------------------- .../multidataset.py | 38 +++++++++++++--- .../ASR/pruned_transducer_stateless7/train.py | 2 +- 3 files changed, 32 insertions(+), 51 deletions(-) diff --git a/egs/librispeech/ASR/prepare_multidataset.sh b/egs/librispeech/ASR/prepare_multidataset.sh index c068305c0..8b13a5bd8 100755 --- a/egs/librispeech/ASR/prepare_multidataset.sh +++ b/egs/librispeech/ASR/prepare_multidataset.sh @@ -328,46 +328,3 @@ if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then ./prepare_common_voice.sh fi fi - -if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then - log "Stage 11: Create multidataset" - split_dir=data/fbank/multidataset_split_${num_splits} - if [ ! -f data/fbank/multidataset_split/.multidataset.done ]; then - mkdir -p $split_dir/multidataset - log "Split LibriSpeech" - if [ ! -f $split_dir/.librispeech_split.done ]; then - lhotse split $num_splits ./data/fbank/librispeech_cuts_train-all-shuf.jsonl.gz $split_dir - touch $split_dir/.librispeech_split.done - fi - - if [[ "${multidataset[@]}" =~ "gigaspeech" ]]; then - log "Split GigaSpeech XL" - if [ ! -f $split_dir/.gigaspeech_XL_split.done ]; then - cd $split_dir - ln -sv ../gigaspeech_XL_split_2000/gigaspeech_cuts_XL.*.jsonl.gz . - cd ../../.. - touch $split_dir/.gigaspeech_XL_split.done - fi - fi - - if [[ "${multidataset[@]}" =~ "commonvoice" ]]; then - log "Split CommonVoice" - if [ ! -f $split_dir/.cv-en_train_split.done ]; then - lhotse split $num_splits ./data/en/fbank/cv-en_cuts_train.jsonl.gz $split_dir - touch $split_dir/.cv-en_train_split.done - fi - fi - - if [ ! -f $split_dir/.multidataset_mix.done ]; then - log "Mix multidataset" - for ((seq=1; seq<=$num_splits; seq++)); do - fseq=$(printf "%04d" $seq) - gunzip -c $split_dir/*.*${fseq}.jsonl.gz | \ - shuf | gzip -c > $split_dir/multidataset/multidataset_cuts_train.${fseq}.jsonl.gz - done - touch $split_dir/.multidataset_mix.done - fi - - touch data/fbank/multidataset_split/.multidataset.done - fi -fi diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/multidataset.py b/egs/librispeech/ASR/pruned_transducer_stateless7/multidataset.py index dcb4cd141..07c7126fa 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/multidataset.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/multidataset.py @@ -25,29 +25,53 @@ from lhotse import CutSet, load_manifest_lazy class MultiDataset: - def __init__(self, manifest_dir: str): + def __init__(self, manifest_dir: str, cv_manifest_dir: str): """ Args: manifest_dir: It is expected to contain the following files: - - multidataset_split_1998/multidataset/multidataset_cuts_train.*.jsonl.gz + - librispeech_cuts_train-all-shuf.jsonl.gz + - gigaspeech_XL_split_2000/gigaspeech_cuts_XL.*.jsonl.gz + + cv_manifest_dir: + It is expected to contain the following files: + + - cv-en_cuts_train.jsonl.gz """ self.manifest_dir = Path(manifest_dir) + self.cv_manifest_dir = Path(cv_manifest_dir) def train_cuts(self) -> CutSet: logging.info("About to get multidataset train cuts") - filenames = glob.glob( - f"{self.manifest_dir}/multidataset_split_1998/multidataset/multidataset_cuts_train.*.jsonl.gz" + # LibriSpeech + logging.info(f"Loading LibriSpeech in lazy mode") + librispeech_cuts = load_manifest_lazy( + self.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz" ) - pattern = re.compile(r"multidataset_cuts_train.([0-9]+).jsonl.gz") + # GigaSpeech + filenames = glob.glob( + f"{self.manifest_dir}/gigaspeech_XL_split_2000/gigaspeech_cuts_XL.*.jsonl.gz" + ) + + pattern = re.compile(r"gigaspeech_cuts_XL.([0-9]+).jsonl.gz") idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames) idx_filenames = sorted(idx_filenames, key=lambda x: x[0]) sorted_filenames = [f[1] for f in idx_filenames] - logging.info(f"Loading {len(sorted_filenames)} splits") + logging.info(f"Loading GigaSpeech {len(sorted_filenames)} splits in lazy mode") - return lhotse.combine(lhotse.load_manifest_lazy(p) for p in sorted_filenames) + gigaspeech_cuts = lhotse.combine( + lhotse.load_manifest_lazy(p) for p in sorted_filenames + ) + + # CommonVoice + logging.info(f"Loading CommonVoice in lazy mode") + commonvoice_cuts = load_manifest_lazy( + self.cv_manifest_dir / f"cv-en_cuts_train.jsonl.gz" + ) + + return CutSet.mux(librispeech_cuts, gigaspeech_cuts, commonvoice_cuts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index 01c9500ce..1b179ceff 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -1053,7 +1053,7 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) if params.use_multidataset: - multidataset = MultiDataset(params.manifest_dir) + multidataset = MultiDataset(params.manifest_dir, params.cv_manifest_dir) train_cuts = multidataset.train_cuts() else: if params.full_libri: From 45c13e90e42d0f6ff190d69acb18f4e868bfa954 Mon Sep 17 00:00:00 2001 From: marcoyang1998 <45973641+marcoyang1998@users.noreply.github.com> Date: Mon, 24 Apr 2023 15:00:02 +0800 Subject: [PATCH 167/174] RNNLM rescore + Low-order density ratio (#1017) * add rnnlm rescore + LODR * add LODR in decode.py * update RESULTS --- egs/librispeech/ASR/RESULTS.md | 38 ++- .../beam_search.py | 218 +++++++++++++++++- .../decode.py | 99 +++++++- 3 files changed, 345 insertions(+), 10 deletions(-) diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index 5a956fc9c..ef817d5dd 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -215,11 +215,12 @@ done We also support decoding with neural network LMs. After combining with language models, the WERs are | decoding method | chunk size | test-clean | test-other | comment | decoding mode | |----------------------|------------|------------|------------|---------------------|----------------------| -| modified beam search | 320ms | 3.11 | 7.93 | --epoch 30 --avg 9 | simulated streaming | -| modified beam search + RNNLM shallow fusion | 320ms | 2.58 | 6.65 | --epoch 30 --avg 9 | simulated streaming | -| modified beam search + RNNLM nbest rescore | 320ms | 2.59 | 6.86 | --epoch 30 --avg 9 | simulated streaming | +| `modified_beam_search` | 320ms | 3.11 | 7.93 | --epoch 30 --avg 9 | simulated streaming | +| `modified_beam_search_lm_shallow_fusion` | 320ms | 2.58 | 6.65 | --epoch 30 --avg 9 | simulated streaming | +| `modified_beam_search_lm_rescore` | 320ms | 2.59 | 6.86 | --epoch 30 --avg 9 | simulated streaming | +| `modified_beam_search_lm_rescore_LODR` | 320ms | 2.52 | 6.73 | --epoch 30 --avg 9 | simulated streaming | -Please use the following command for RNNLM shallow fusion: +Please use the following command for `modified_beam_search_lm_shallow_fusion`: ```bash for lm_scale in $(seq 0.15 0.01 0.38); do for beam_size in 4 8 12; do @@ -246,7 +247,7 @@ for lm_scale in $(seq 0.15 0.01 0.38); do done ``` -Please use the following command for RNNLM rescore: +Please use the following command for `modified_beam_search_lm_rescore`: ```bash ./pruned_transducer_stateless7_streaming/decode.py \ --epoch 30 \ @@ -268,7 +269,32 @@ Please use the following command for RNNLM rescore: --lm-vocab-size 500 ``` -A well-trained RNNLM can be found here: . +Please use the following command for `modified_beam_search_lm_rescore_LODR`: +```bash +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 30 \ + --avg 9 \ + --use-averaged-model True \ + --beam-size 8 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method modified_beam_search_lm_rescore_LODR \ + --use-shallow-fusion 0 \ + --lm-type rnn \ + --lm-exp-dir rnn_lm/exp \ + --lm-epoch 99 \ + --lm-avg 1 \ + --rnn-lm-embedding-dim 2048 \ + --rnn-lm-hidden-dim 2048 \ + --rnn-lm-num-layers 3 \ + --lm-vocab-size 500 \ + --tokens-ngram 2 \ + --backoff-id 500 +``` + +A well-trained RNNLM can be found here: . The bi-gram used in LODR decoding +can be found here: . #### Smaller model diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index c44a2ad3e..e45f2c652 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -1244,7 +1244,7 @@ def modified_beam_search_lm_rescore( # get the best hyp with different lm_scale for lm_scale in lm_scale_list: - key = f"nnlm_scale_{lm_scale}" + key = f"nnlm_scale_{lm_scale:.2f}" tot_scores = am_scores.values + lm_scores * lm_scale ragged_tot_scores = k2.RaggedTensor(shape=am_scores.shape, value=tot_scores) max_indexes = ragged_tot_scores.argmax().tolist() @@ -1257,6 +1257,222 @@ def modified_beam_search_lm_rescore( return ans +def modified_beam_search_lm_rescore_LODR( + model: Transducer, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + LM: LmScorer, + LODR_lm: NgramLm, + sp: spm.SentencePieceProcessor, + lm_scale_list: List[int], + beam: int = 4, + temperature: float = 1.0, + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: + """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. + Rescore the final results with RNNLM and return the one with the highest score + + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C). + encoder_out_lens: + A 1-D tensor of shape (N,), containing number of valid frames in + encoder_out before padding. + beam: + Number of active paths during the beam search. + temperature: + Softmax temperature. + LM: + A neural network language model + return_timestamps: + Whether to return timestamps. + Returns: + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. + """ + assert encoder_out.ndim == 3, encoder_out.shape + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + blank_id = model.decoder.blank_id + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + device = next(model.parameters()).device + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + B = [HypothesisList() for _ in range(N)] + for i in range(N): + B[i].add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + timestamp=[], + ) + ) + + encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + + offset = 0 + finalized_B = [] + for (t, batch_size) in enumerate(batch_size_list): + start = offset + end = offset + batch_size + current_encoder_out = encoder_out.data[start:end] + current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) + offset = end + + finalized_B = B[batch_size:] + finalized_B + B = B[:batch_size] + + hyps_shape = get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.cat( + [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] + ) # (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) + + # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor + # as index, so we use `to(torch.int64)` below. + current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) # (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) + + log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_ys = hyp.ys[:] + new_token = topk_token_indexes[k] + new_timestamp = hyp.timestamp[:] + if new_token not in (blank_id, unk_id): + new_ys.append(new_token) + new_timestamp.append(t) + + new_log_prob = topk_log_probs[k] + new_hyp = Hypothesis( + ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp + ) + B[i].add(new_hyp) + + B = B + finalized_B + + # get the am_scores for n-best list + hyps_shape = get_hyps_shape(B) + am_scores = torch.tensor([hyp.log_prob.item() for b in B for hyp in b]) + am_scores = k2.RaggedTensor(value=am_scores, shape=hyps_shape).to(device) + + # now LM rescore + # prepare input data to LM + candidate_seqs = [hyp.ys[context_size:] for b in B for hyp in b] + possible_seqs = k2.RaggedTensor(candidate_seqs) + row_splits = possible_seqs.shape.row_splits(1) + sentence_token_lengths = row_splits[1:] - row_splits[:-1] + possible_seqs_with_sos = add_sos(possible_seqs, sos_id=1) + possible_seqs_with_eos = add_eos(possible_seqs, eos_id=1) + sentence_token_lengths += 1 + + x = possible_seqs_with_sos.pad(mode="constant", padding_value=blank_id) + y = possible_seqs_with_eos.pad(mode="constant", padding_value=blank_id) + x = x.to(device).to(torch.int64) + y = y.to(device).to(torch.int64) + sentence_token_lengths = sentence_token_lengths.to(device).to(torch.int64) + + lm_scores = LM.lm(x=x, y=y, lengths=sentence_token_lengths) + assert lm_scores.ndim == 2 + lm_scores = -1 * lm_scores.sum(dim=1) + + # now LODR scores + import math + + LODR_scores = [] + for seq in candidate_seqs: + tokens = " ".join(sp.id_to_piece(seq)) + LODR_scores.append(LODR_lm.score(tokens)) + LODR_scores = torch.tensor(LODR_scores).to(device) * math.log( + 10 + ) # arpa scores are 10-based + assert lm_scores.shape == LODR_scores.shape + + ans = {} + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + + LODR_scale_list = [0.05 * i for i in range(1, 20)] + # get the best hyp with different lm_scale and lodr_scale + for lm_scale in lm_scale_list: + for lodr_scale in LODR_scale_list: + key = f"nnlm_scale_{lm_scale:.2f}_lodr_scale_{lodr_scale:.2f}" + tot_scores = ( + am_scores.values / lm_scale + lm_scores - LODR_scores * lodr_scale + ) + ragged_tot_scores = k2.RaggedTensor(shape=am_scores.shape, value=tot_scores) + max_indexes = ragged_tot_scores.argmax().tolist() + unsorted_hyps = [candidate_seqs[idx] for idx in max_indexes] + hyps = [] + for idx in unsorted_indices: + hyps.append(unsorted_hyps[idx]) + + ans[key] = hyps + return ans + + def _deprecated_modified_beam_search( model: Transducer, encoder_out: torch.Tensor, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py index 8aa0d8689..3444f8193 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py @@ -123,10 +123,13 @@ from beam_search import ( greedy_search_batch, modified_beam_search, modified_beam_search_lm_rescore, + modified_beam_search_lm_rescore_LODR, modified_beam_search_lm_shallow_fusion, + modified_beam_search_LODR, ) from train import add_model_arguments, get_params, get_transducer_model +from icefall import LmScorer, NgramLm from icefall.checkpoint import ( average_checkpoints, average_checkpoints_with_averaged_model, @@ -134,7 +137,6 @@ from icefall.checkpoint import ( load_checkpoint, ) from icefall.lexicon import Lexicon -from icefall.lm_wrapper import LmScorer from icefall.utils import ( AttributeDict, setup_logger, @@ -336,6 +338,21 @@ def get_parser(): """, ) + parser.add_argument( + "--tokens-ngram", + type=int, + default=2, + help="""The order of the ngram lm. + """, + ) + + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="ID of the backoff symbol in the ngram LM", + ) + add_model_arguments(parser) return parser @@ -349,6 +366,8 @@ def decode_one_batch( word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -483,6 +502,18 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_LODR": + hyp_tokens = modified_beam_search_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LODR_lm=ngram_lm, + LODR_lm_scale=ngram_lm_scale, + LM=LM, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) elif params.decoding_method == "modified_beam_search_lm_rescore": lm_scale_list = [0.01 * i for i in range(10, 50)] ans_dict = modified_beam_search_lm_rescore( @@ -493,6 +524,18 @@ def decode_one_batch( LM=LM, lm_scale_list=lm_scale_list, ) + elif params.decoding_method == "modified_beam_search_lm_rescore_LODR": + lm_scale_list = [0.02 * i for i in range(2, 30)] + ans_dict = modified_beam_search_lm_rescore_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + LODR_lm=ngram_lm, + sp=sp, + lm_scale_list=lm_scale_list, + ) else: batch_size = encoder_out.size(0) @@ -531,7 +574,10 @@ def decode_one_batch( key += f"_ngram_lm_scale_{params.ngram_lm_scale}" return {key: hyps} - elif params.decoding_method == "modified_beam_search_lm_rescore": + elif params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ): ans = dict() assert ans_dict is not None for key, hyps in ans_dict.items(): @@ -550,6 +596,8 @@ def decode_dataset( word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. @@ -568,6 +616,8 @@ def decode_dataset( The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used only when --decoding_method is fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + ngram_lm: + A n-gram LM to be used for LODR. Returns: Return a dict, whose key may be "greedy_search" if greedy search is used, or it may be "beam_7" if beam size of 7 is used. @@ -600,6 +650,8 @@ def decode_dataset( word_table=word_table, batch=batch, LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, ) for name, hyps in hyps_dict.items(): @@ -677,8 +729,10 @@ def main(): "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", "modified_beam_search", + "modified_beam_search_LODR", "modified_beam_search_lm_shallow_fusion", "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", ) params.res_dir = params.exp_dir / params.decoding_method @@ -822,7 +876,12 @@ def main(): model.eval() # only load the neural network LM if required - if params.use_shallow_fusion or "lm" in params.decoding_method: + if params.use_shallow_fusion or params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_LODR", + ): LM = LmScorer( lm_type=params.lm_type, params=params, @@ -834,6 +893,35 @@ def main(): else: LM = None + # only load N-gram LM when needed + if params.decoding_method == "modified_beam_search_lm_rescore_LODR": + try: + import kenlm + except ImportError: + print("Please install kenlm first. You can use") + print(" pip install https://github.com/kpu/kenlm/archive/master.zip") + print("to install it") + import sys + + sys.exit(-1) + ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa") + logging.info(f"lm filename: {ngram_file_name}") + ngram_lm = kenlm.Model(ngram_file_name) + + elif params.decoding_method == "modified_beam_search_LODR": + lm_filename = f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"Loading token level lm: {lm_filename}") + ngram_lm = NgramLm( + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {ngram_lm.lm.num_states}") + ngram_lm_scale = params.ngram_lm_scale + else: + ngram_lm = None + ngram_lm_scale = None + if "fast_beam_search" in params.decoding_method: if params.decoding_method == "fast_beam_search_nbest_LG": lexicon = Lexicon(params.lang_dir) @@ -866,8 +954,10 @@ def main(): test_sets = ["test-clean", "test-other"] test_dl = [test_clean_dl, test_other_dl] + import time for test_set, test_dl in zip(test_sets, test_dl): + start = time.time() results_dict = decode_dataset( dl=test_dl, params=params, @@ -876,7 +966,10 @@ def main(): word_table=word_table, decoding_graph=decoding_graph, LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, ) + logging.info(f"Elasped time for {test_set}: {time.time() - start}") save_results( params=params, From 2767b9ff11e2410c2307b2041442ea87a0441022 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 27 Apr 2023 14:36:36 +0800 Subject: [PATCH 168/174] Support exporting RNNLM to ONNX. (#1014) * Support exporting RNNLM to ONNX. * add int8 models * fix style issues * Fix EOS padding * support exporting for streaming ASR --- icefall/rnn_lm/.gitignore | 1 + icefall/rnn_lm/check-onnx-streaming.py | 132 +++++++++ icefall/rnn_lm/check-onnx.py | 119 ++++++++ icefall/rnn_lm/export-onnx.py | 395 +++++++++++++++++++++++++ icefall/rnn_lm/export-onnx.sh | 26 ++ icefall/rnn_lm/export.py | 7 +- icefall/rnn_lm/export.sh | 27 ++ icefall/rnn_lm/model.py | 46 +++ 8 files changed, 752 insertions(+), 1 deletion(-) create mode 100644 icefall/rnn_lm/.gitignore create mode 100755 icefall/rnn_lm/check-onnx-streaming.py create mode 100755 icefall/rnn_lm/check-onnx.py create mode 100755 icefall/rnn_lm/export-onnx.py create mode 100755 icefall/rnn_lm/export-onnx.sh create mode 100755 icefall/rnn_lm/export.sh diff --git a/icefall/rnn_lm/.gitignore b/icefall/rnn_lm/.gitignore new file mode 100644 index 000000000..877fb1e18 --- /dev/null +++ b/icefall/rnn_lm/.gitignore @@ -0,0 +1 @@ +icefall-librispeech-rnn-lm diff --git a/icefall/rnn_lm/check-onnx-streaming.py b/icefall/rnn_lm/check-onnx-streaming.py new file mode 100755 index 000000000..8850c1c71 --- /dev/null +++ b/icefall/rnn_lm/check-onnx-streaming.py @@ -0,0 +1,132 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation + +""" +Usage: + +./check-onnx-streaming.py \ + --jit ./icefall-librispeech-rnn-lm/exp/cpu_jit.pt \ + --onnx ./icefall-librispeech-rnn-lm/exp/with-state-epoch-99-avg-1.onnx + +Note: You can download pre-trained models from +https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm + +""" + +import argparse +import logging +from typing import Tuple + +import onnxruntime as ort +import torch + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--jit", + required=True, + type=str, + help="Path to the torchscript model", + ) + + parser.add_argument( + "--onnx", + required=True, + type=str, + help="Path to the onnx model", + ) + + return parser + + +class OnnxModel: + def __init__(self, filename: str): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.model = ort.InferenceSession( + filename, + sess_options=session_opts, + ) + + meta_data = self.model.get_modelmeta().custom_metadata_map + self.sos_id = int(meta_data["sos_id"]) + self.eos_id = int(meta_data["eos_id"]) + self.vocab_size = int(meta_data["vocab_size"]) + self.num_layers = int(meta_data["num_layers"]) + self.hidden_size = int(meta_data["hidden_size"]) + print(meta_data) + + def __call__( + self, x: torch.Tensor, y: torch.Tensor, h0: torch.Tensor, c0: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + out = self.model.run( + [ + self.model.get_outputs()[0].name, + self.model.get_outputs()[1].name, + self.model.get_outputs()[2].name, + ], + { + self.model.get_inputs()[0].name: x.numpy(), + self.model.get_inputs()[1].name: y.numpy(), + self.model.get_inputs()[2].name: h0.numpy(), + self.model.get_inputs()[3].name: c0.numpy(), + }, + ) + return ( + torch.from_numpy(out[0]), + torch.from_numpy(out[1]), + torch.from_numpy(out[2]), + ) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + logging.info(vars(args)) + + torch_model = torch.jit.load(args.jit).cpu() + onnx_model = OnnxModel(args.onnx) + N = torch.arange(1, 5).tolist() + + num_layers = onnx_model.num_layers + hidden_size = onnx_model.hidden_size + + for n in N: + L = torch.randint(low=1, high=100, size=(1,)).item() + x = torch.randint( + low=1, high=onnx_model.vocab_size, size=(n, L), dtype=torch.int64 + ) + y = torch.randint( + low=1, high=onnx_model.vocab_size, size=(n, L), dtype=torch.int64 + ) + h0 = torch.rand(num_layers, n, hidden_size) + c0 = torch.rand(num_layers, n, hidden_size) + + torch_nll, torch_h0, torch_c0 = torch_model.streaming_forward(x, y, h0, c0) + onnx_nll, onnx_h0, onnx_c0 = onnx_model(x, y, h0, c0) + + for torch_v, onnx_v in zip( + (torch_nll, torch_h0, torch_c0), (onnx_nll, onnx_h0, onnx_c0) + ): + + assert torch.allclose(torch_v, onnx_v, atol=1e-5), ( + torch_v.shape, + onnx_v.shape, + (torch_v - onnx_v).abs().max(), + ) + print(n, L, torch_v.sum(), onnx_v.sum()) + + +if __name__ == "__main__": + torch.manual_seed(20230423) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/icefall/rnn_lm/check-onnx.py b/icefall/rnn_lm/check-onnx.py new file mode 100755 index 000000000..24c5395f8 --- /dev/null +++ b/icefall/rnn_lm/check-onnx.py @@ -0,0 +1,119 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation + +""" +Usage: + +./check-onnx.py \ + --jit ./icefall-librispeech-rnn-lm/exp/cpu_jit.pt \ + --onnx ./icefall-librispeech-rnn-lm/exp/no-state-epoch-99-avg-1.onnx + +Note: You can download pre-trained models from +https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm + +""" + +import argparse +import logging + +import onnxruntime as ort +import torch + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--jit", + required=True, + type=str, + help="Path to the torchscript model", + ) + + parser.add_argument( + "--onnx", + required=True, + type=str, + help="Path to the onnx model", + ) + + return parser + + +class OnnxModel: + def __init__(self, filename: str): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.model = ort.InferenceSession( + filename, + sess_options=session_opts, + ) + + meta_data = self.model.get_modelmeta().custom_metadata_map + self.sos_id = int(meta_data["sos_id"]) + self.eos_id = int(meta_data["eos_id"]) + self.vocab_size = int(meta_data["vocab_size"]) + print(meta_data) + + def __call__(self, x: torch.Tensor, x_lens: torch.Tensor) -> torch.Tensor: + out = self.model.run( + [ + self.model.get_outputs()[0].name, + ], + { + self.model.get_inputs()[0].name: x.numpy(), + self.model.get_inputs()[1].name: x_lens.numpy(), + }, + ) + return torch.from_numpy(out[0]) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + logging.info(vars(args)) + + torch_model = torch.jit.load(args.jit).cpu() + onnx_model = OnnxModel(args.onnx) + N = torch.arange(1, 5).tolist() + + for n in N: + L = torch.randint(low=1, high=100, size=(1,)).item() + x = torch.randint( + low=1, high=onnx_model.vocab_size, size=(n, L), dtype=torch.int64 + ) + x_lens = torch.full((n,), fill_value=L, dtype=torch.int64) + if n > 1: + x_lens[0] = L // 2 + 1 + + sos = torch.full((1,), fill_value=onnx_model.sos_id).expand(n, 1) + sos_x = torch.cat([sos, x], dim=1) + + pad_col = torch.zeros((1,), dtype=x.dtype).expand(n, 1) + x_eos = torch.cat([x, pad_col], dim=1) + + row_index = torch.arange(0, n, dtype=x.dtype) + x_eos[row_index, x_lens] = onnx_model.eos_id + + torch_nll = torch_model(sos_x, x_eos, x_lens + 1).sum(dim=-1) + onnx_nll = onnx_model(x, x_lens) + # Note: For int8 models, the differences may be quite large, + # e.g., within 0.9 + assert torch.allclose(torch_nll, onnx_nll), ( + torch_nll, + onnx_nll, + ) + print(n, L, torch_nll, onnx_nll) + + +if __name__ == "__main__": + torch.manual_seed(20230420) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/icefall/rnn_lm/export-onnx.py b/icefall/rnn_lm/export-onnx.py new file mode 100755 index 000000000..6855f9bea --- /dev/null +++ b/icefall/rnn_lm/export-onnx.py @@ -0,0 +1,395 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation + +import argparse +import logging +from pathlib import Path + +import onnx +import torch +from model import RnnLmModel +from onnxruntime.quantization import QuantType, quantize_dynamic + +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.utils import AttributeDict, str2bool +from typing import Dict +from train import get_params + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +# A wrapper for RnnLm model to simpily the C++ calling code +# when exporting the model to ONNX. +# +# TODO(fangjun): The current wrapper works only for non-streaming ASR +# since we don't expose the LM state and it is used to score +# a complete sentence at once. +class RnnLmModelWrapper(torch.nn.Module): + def __init__(self, model: RnnLmModel, sos_id: int, eos_id: int): + super().__init__() + self.model = model + self.sos_id = sos_id + self.eos_id = eos_id + + def forward(self, x: torch.Tensor, x_lens: torch.Tensor) -> torch.Tensor: + """ + Args: + x: + A 2-D tensor of shape (N, L) with dtype torch.int64. + It does not contain SOS or EOS. We will add SOS and EOS inside + this function. + x_lens: + A 1-D tensor of shape (N,) with dtype torch.int64. It contains + number of valid tokens in ``x`` before padding. + Returns: + Return a 1-D tensor of shape (N,) containing negative loglikelihood. + Its dtype is torch.float32 + """ + N = x.size(0) + + sos_tensor = torch.full((1,), fill_value=self.sos_id, dtype=x.dtype).expand( + N, 1 + ) + sos_x = torch.cat([sos_tensor, x], dim=1) + + pad_col = torch.zeros((1,), dtype=x.dtype).expand(N, 1) + x_eos = torch.cat([x, pad_col], dim=1) + + row_index = torch.arange(0, N, dtype=x.dtype) + x_eos[row_index, x_lens] = self.eos_id + + # use x_lens + 1 here since we prepended x with sos + return ( + self.model(x=sos_x, y=x_eos, lengths=x_lens + 1) + .to(torch.float32) + .sum(dim=1) + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=29, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + + parser.add_argument( + "--avg", + type=int, + default=5, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--vocab-size", + type=int, + default=500, + help="Vocabulary size of the model", + ) + + parser.add_argument( + "--embedding-dim", + type=int, + default=2048, + help="Embedding dim of the model", + ) + + parser.add_argument( + "--hidden-dim", + type=int, + default=2048, + help="Hidden dim of the model", + ) + + parser.add_argument( + "--num-layers", + type=int, + default=3, + help="Number of RNN layers the model", + ) + + parser.add_argument( + "--tie-weights", + type=str2bool, + default=True, + help="""True to share the weights between the input embedding layer and the + last output linear layer + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="rnn_lm/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + return parser + + +def export_without_state( + model: RnnLmModel, + filename: str, + params: AttributeDict, + opset_version: int, +): + model_wrapper = RnnLmModelWrapper( + model, + sos_id=params.sos_id, + eos_id=params.eos_id, + ) + + N = 1 + L = 20 + x = torch.randint(low=1, high=params.vocab_size, size=(N, L), dtype=torch.int64) + x_lens = torch.full((N,), fill_value=L, dtype=torch.int64) + + # Note(fangjun): The following warnings can be ignored. + # We can use ./check-onnx.py to validate the exported model with batch_size > 1 + """ + torch/onnx/symbolic_opset9.py:2119: UserWarning: Exporting a model to ONNX + with a batch_size other than 1, with a variable length with LSTM can cause + an error when running the ONNX model with a different batch size. Make sure + to save the model with a batch size of 1, or define the initial states + (h0/c0) as inputs of the model. warnings.warn("Exporting a model to ONNX + with a batch_size other than 1, " + + """ + + torch.onnx.export( + model_wrapper, + (x, x_lens), + filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "x_lens"], + output_names=["nll"], + dynamic_axes={ + "x": {0: "N", 1: "L"}, + "x_lens": {0: "N"}, + "nll": {0: "N"}, + }, + ) + + meta_data = { + "model_type": "rnnlm", + "version": "1", + "model_author": "k2-fsa", + "comment": "rnnlm without state", + "sos_id": str(params.sos_id), + "eos_id": str(params.eos_id), + "vocab_size": str(params.vocab_size), + "url": "https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm", + } + logging.info(f"meta_data: {meta_data}") + + add_meta_data(filename=filename, meta_data=meta_data) + + +def export_with_state( + model: RnnLmModel, + filename: str, + params: AttributeDict, + opset_version: int, +): + N = 1 + L = 20 + num_layers = model.rnn.num_layers + hidden_size = model.rnn.hidden_size + + x = torch.randint(low=1, high=params.vocab_size, size=(N, L), dtype=torch.int64) + y = torch.randint(low=1, high=params.vocab_size, size=(N, L), dtype=torch.int64) + h0 = torch.zeros(num_layers, N, hidden_size) + c0 = torch.zeros(num_layers, N, hidden_size) + + # Note(fangjun): The following warnings can be ignored. + # We can use ./check-onnx.py to validate the exported model with batch_size > 1 + """ + torch/onnx/symbolic_opset9.py:2119: UserWarning: Exporting a model to ONNX + with a batch_size other than 1, with a variable length with LSTM can cause + an error when running the ONNX model with a different batch size. Make sure + to save the model with a batch size of 1, or define the initial states + (h0/c0) as inputs of the model. warnings.warn("Exporting a model to ONNX + with a batch_size other than 1, " + + """ + + torch.onnx.export( + model, + (x, y, h0, c0), + filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "y", "h0", "c0"], + output_names=["nll", "next_h0", "next_c0"], + dynamic_axes={ + "x": {0: "N", 1: "L"}, + "y": {0: "N", 1: "L"}, + "h0": {1: "N"}, + "c0": {1: "N"}, + "nll": {0: "N"}, + "next_h0": {1: "N"}, + "next_c0": {1: "N"}, + }, + ) + + meta_data = { + "model_type": "rnnlm", + "version": "1", + "model_author": "k2-fsa", + "comment": "rnnlm state", + "sos_id": str(params.sos_id), + "eos_id": str(params.eos_id), + "vocab_size": str(params.vocab_size), + "num_layers": str(num_layers), + "hidden_size": str(hidden_size), + "url": "https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm", + } + logging.info(f"meta_data: {meta_data}") + + add_meta_data(filename=filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + logging.info(params) + + device = torch.device("cpu") + logging.info(f"device: {device}") + + model = RnnLmModel( + vocab_size=params.vocab_size, + embedding_dim=params.embedding_dim, + hidden_dim=params.hidden_dim, + num_layers=params.num_layers, + tie_weights=params.tie_weights, + ) + + model.to(device) + + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + + model.to("cpu") + model.eval() + + if params.iter > 0: + suffix = f"iter-{params.iter}" + else: + suffix = f"epoch-{params.epoch}" + + suffix += f"-avg-{params.avg}" + + opset_version = 13 + + logging.info("Exporting model without state") + filename = params.exp_dir / f"no-state-{suffix}.onnx" + export_without_state( + model=model, + filename=filename, + params=params, + opset_version=opset_version, + ) + + filename_int8 = params.exp_dir / f"no-state-{suffix}.int8.onnx" + quantize_dynamic( + model_input=filename, + model_output=filename_int8, + weight_type=QuantType.QInt8, + ) + + # now for streaming export + saved_forward = model.__class__.forward + model.__class__.forward = model.__class__.streaming_forward + streaming_filename = params.exp_dir / f"with-state-{suffix}.onnx" + export_with_state( + model=model, + filename=streaming_filename, + params=params, + opset_version=opset_version, + ) + model.__class__.forward = saved_forward + + streaming_filename_int8 = params.exp_dir / f"with-state-{suffix}.int8.onnx" + quantize_dynamic( + model_input=streaming_filename, + model_output=streaming_filename_int8, + weight_type=QuantType.QInt8, + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/icefall/rnn_lm/export-onnx.sh b/icefall/rnn_lm/export-onnx.sh new file mode 100755 index 000000000..6e3262b5e --- /dev/null +++ b/icefall/rnn_lm/export-onnx.sh @@ -0,0 +1,26 @@ +#!/usr/bin/env bash + +# We use the model from +# https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm/tree/main +# as an example + +export CUDA_VISIBLE_DEVICES= + +if [ ! -f ./icefall-librispeech-rnn-lm/exp/pretrained.pt ]; then + GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm + pushd icefall-librispeech-rnn-lm/exp + git lfs pull --include "pretrained.pt" + ln -s pretrained.pt epoch-99.pt + popd +fi + +python3 ./export-onnx.py \ + --exp-dir ./icefall-librispeech-rnn-lm/exp \ + --epoch 99 \ + --avg 1 \ + --vocab-size 500 \ + --embedding-dim 2048 \ + --hidden-dim 2048 \ + --num-layers 3 \ + --tie-weights 1 + diff --git a/icefall/rnn_lm/export.py b/icefall/rnn_lm/export.py index a8598a1ce..be4e7f8c5 100644 --- a/icefall/rnn_lm/export.py +++ b/icefall/rnn_lm/export.py @@ -26,7 +26,7 @@ import torch from model import RnnLmModel from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint -from icefall.utils import AttributeDict, load_averaged_model, str2bool +from icefall.utils import AttributeDict, str2bool def get_parser(): @@ -118,6 +118,7 @@ def get_parser(): return parser +@torch.no_grad() def main(): args = get_parser().parse_args() args.exp_dir = Path(args.exp_dir) @@ -180,6 +181,10 @@ def main(): if params.jit: logging.info("Using torch.jit.script") + + model.__class__.streaming_forward = torch.jit.export( + model.__class__.streaming_forward + ) model = torch.jit.script(model) filename = params.exp_dir / "cpu_jit.pt" model.save(str(filename)) diff --git a/icefall/rnn_lm/export.sh b/icefall/rnn_lm/export.sh new file mode 100755 index 000000000..678bc294e --- /dev/null +++ b/icefall/rnn_lm/export.sh @@ -0,0 +1,27 @@ +#!/usr/bin/env bash + +# We use the model from +# https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm/tree/main +# as an example + +export CUDA_VISIBLE_DEVICES= + +if [ ! -f ./icefall-librispeech-rnn-lm/exp/pretrained.pt ]; then + GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm + pushd icefall-librispeech-rnn-lm/exp + git lfs pull --include "pretrained.pt" + ln -s pretrained.pt epoch-99.pt + popd +fi + +python3 ./export.py \ + --exp-dir ./icefall-librispeech-rnn-lm/exp \ + --epoch 99 \ + --avg 1 \ + --vocab-size 500 \ + --embedding-dim 2048 \ + --hidden-dim 2048 \ + --num-layers 3 \ + --tie-weights 1 \ + --jit 1 + diff --git a/icefall/rnn_lm/model.py b/icefall/rnn_lm/model.py index ebb3128e3..8d3e16432 100644 --- a/icefall/rnn_lm/model.py +++ b/icefall/rnn_lm/model.py @@ -15,6 +15,7 @@ # limitations under the License. import logging +from typing import Tuple import torch import torch.nn.functional as F @@ -47,6 +48,11 @@ class RnnLmModel(torch.nn.Module): and https://arxiv.org/abs/1611.01462 """ super().__init__() + self.vocab_size = vocab_size + self.embedding_dim = embedding_dim + self.hidden_dim = hidden_dim + self.num_layers = num_layers + self.tie_weights = tie_weights self.input_embedding = torch.nn.Embedding( num_embeddings=vocab_size, @@ -74,6 +80,46 @@ class RnnLmModel(torch.nn.Module): self.cache = {} + def streaming_forward( + self, + x: torch.Tensor, + y: torch.Tensor, + h0: torch.Tensor, + c0: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ''' + Args: + x: + A 2-D tensor of shape (N, L). We won't prepend it with SOS. + y: + A 2-D tensor of shape (N, L). We won't append it with EOS. + h0: + A 3-D tensor of shape (num_layers, N, hidden_size). + (If proj_size > 0, then it is (num_layers, N, proj_size)) + c0: + A 3-D tensor of shape (num_layers, N, hidden_size). + Returns: + Return a tuple containing 3 tensors: + - negative loglike (nll), a 1-D tensor of shape (N,) + - next_h0, a 3-D tensor with the same shape as h0 + - next_c0, a 3-D tensor with the same shape as c0 + ''' + assert x.ndim == y.ndim == 2, (x.ndim, y.ndim) + assert x.shape == y.shape, (x.shape, y.shape) + + # embedding is of shape (N, L, embedding_dim) + embedding = self.input_embedding(x) + # Note: We use batch_first==True + rnn_out, (next_h0, next_c0) = self.rnn(embedding, (h0, c0)) + logits = self.output_linear(rnn_out) + nll_loss = F.cross_entropy( + logits.reshape(-1, self.vocab_size), y.reshape(-1), reduction="none" + ) + + batch_size = x.size(0) + nll_loss = nll_loss.reshape(batch_size, -1).sum(dim=1) + return nll_loss, next_h0, next_c0 + def forward( self, x: torch.Tensor, y: torch.Tensor, lengths: torch.Tensor ) -> torch.Tensor: From 298ed4520fb387ae815da24c8397157b6e08aec2 Mon Sep 17 00:00:00 2001 From: PF Luo Date: Fri, 28 Apr 2023 16:33:46 +0800 Subject: [PATCH 169/174] add meta-data embedding_dim to RNNLM onnx-model (#1026) --- icefall/rnn_lm/export-onnx.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/icefall/rnn_lm/export-onnx.py b/icefall/rnn_lm/export-onnx.py index 6855f9bea..b6d3a03ed 100755 --- a/icefall/rnn_lm/export-onnx.py +++ b/icefall/rnn_lm/export-onnx.py @@ -232,6 +232,7 @@ def export_with_state( L = 20 num_layers = model.rnn.num_layers hidden_size = model.rnn.hidden_size + embedding_dim = model.rnn.embedding_dim x = torch.randint(low=1, high=params.vocab_size, size=(N, L), dtype=torch.int64) y = torch.randint(low=1, high=params.vocab_size, size=(N, L), dtype=torch.int64) @@ -278,6 +279,7 @@ def export_with_state( "vocab_size": str(params.vocab_size), "num_layers": str(num_layers), "hidden_size": str(hidden_size), + "embedding_dim": str(embedding_dim), "url": "https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm", } logging.info(f"meta_data: {meta_data}") From b0228c536e68a50d35b16841ebb69d3fc29958fd Mon Sep 17 00:00:00 2001 From: Yuanhang Zhang Date: Fri, 28 Apr 2023 19:52:32 +0800 Subject: [PATCH 170/174] Fix typo in librispeech OpenFST-based HLG preparation script (#1028) --- egs/librispeech/ASR/prepare.sh | 2 +- egs/librispeech/ASR/prepare_multidataset.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh index b1d207049..8342d5212 100755 --- a/egs/librispeech/ASR/prepare.sh +++ b/egs/librispeech/ASR/prepare.sh @@ -184,7 +184,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then ./shared/convert-k2-to-openfst.py \ --olabels aux_labels \ $lang_dir/L_disambig.pt \ - $lang_dir/disambig_L.fst + $lang_dir/L_disambig.fst fi fi diff --git a/egs/librispeech/ASR/prepare_multidataset.sh b/egs/librispeech/ASR/prepare_multidataset.sh index 8b13a5bd8..c95b4d039 100755 --- a/egs/librispeech/ASR/prepare_multidataset.sh +++ b/egs/librispeech/ASR/prepare_multidataset.sh @@ -198,7 +198,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then ./shared/convert-k2-to-openfst.py \ --olabels aux_labels \ $lang_dir/L_disambig.pt \ - $lang_dir/disambig_L.fst + $lang_dir/L_disambig.fst fi fi From 61ec3a7a8fc8be859b23a821e568950fb898b37a Mon Sep 17 00:00:00 2001 From: PF Luo Date: Fri, 28 Apr 2023 19:53:06 +0800 Subject: [PATCH 171/174] fix export RNNLM onnx model typo (#1029) --- icefall/rnn_lm/export-onnx.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/icefall/rnn_lm/export-onnx.py b/icefall/rnn_lm/export-onnx.py index b6d3a03ed..1d9af5e3d 100755 --- a/icefall/rnn_lm/export-onnx.py +++ b/icefall/rnn_lm/export-onnx.py @@ -232,7 +232,7 @@ def export_with_state( L = 20 num_layers = model.rnn.num_layers hidden_size = model.rnn.hidden_size - embedding_dim = model.rnn.embedding_dim + embedding_dim = model.embedding_dim x = torch.randint(low=1, high=params.vocab_size, size=(N, L), dtype=torch.int64) y = torch.randint(low=1, high=params.vocab_size, size=(N, L), dtype=torch.int64) From 80156dda09a51ff2a3e4711dd82cb15bd3243b90 Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Thu, 4 May 2023 19:16:17 +0800 Subject: [PATCH 172/174] Training with byte level BPE (AIShell) (#986) * copy files from zipformer librispeech * Add byte bpe training for aishell * compile LG graph * Support LG decoding * Minor fixes * black * Minor fixes * export & fix pretrain.py * fix black * Update RESULTS.md * Fix export.py --- egs/aishell/ASR/RESULTS.md | 53 +- egs/aishell/ASR/local/compile_lg.py | 1 + egs/aishell/ASR/local/prepare_char.py | 17 +- egs/aishell/ASR/local/prepare_lang_bbpe.py | 267 ++++ egs/aishell/ASR/local/train_bbpe_model.py | 113 ++ egs/aishell/ASR/prepare.sh | 120 +- .../__init__.py | 0 .../asr_datamodule.py | 1 + .../beam_search.py | 1 + .../decode.py | 819 +++++++++++ .../decoder.py | 1 + .../encoder_interface.py | 1 + .../export.py | 320 +++++ .../jit_pretrained.py | 274 ++++ .../joiner.py | 1 + .../model.py | 1 + .../optim.py | 1 + .../pretrained.py | 345 +++++ .../scaling.py | 1 + .../scaling_converter.py | 1 + .../test_model.py | 1 + .../train.py | 1261 +++++++++++++++++ .../zipformer.py | 1 + .../ASR/tdnn_lstm_ctc/asr_datamodule.py | 19 +- egs/librispeech/ASR/local/compile_lg.py | 23 +- .../beam_search.py | 17 + icefall/__init__.py | 8 + icefall/byte_utils.py | 311 ++++ icefall/rnn_lm/model.py | 4 +- icefall/utils.py | 56 + 30 files changed, 3992 insertions(+), 47 deletions(-) create mode 120000 egs/aishell/ASR/local/compile_lg.py create mode 100755 egs/aishell/ASR/local/prepare_lang_bbpe.py create mode 100755 egs/aishell/ASR/local/train_bbpe_model.py create mode 100644 egs/aishell/ASR/pruned_transducer_stateless7_bbpe/__init__.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_bbpe/asr_datamodule.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_bbpe/beam_search.py create mode 100755 egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decode.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decoder.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_bbpe/encoder_interface.py create mode 100755 egs/aishell/ASR/pruned_transducer_stateless7_bbpe/export.py create mode 100755 egs/aishell/ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_bbpe/joiner.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_bbpe/model.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_bbpe/optim.py create mode 100755 egs/aishell/ASR/pruned_transducer_stateless7_bbpe/pretrained.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_bbpe/scaling.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_bbpe/scaling_converter.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_bbpe/test_model.py create mode 100755 egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_bbpe/zipformer.py create mode 100644 icefall/byte_utils.py diff --git a/egs/aishell/ASR/RESULTS.md b/egs/aishell/ASR/RESULTS.md index 4c730c4ae..aa18502c2 100644 --- a/egs/aishell/ASR/RESULTS.md +++ b/egs/aishell/ASR/RESULTS.md @@ -2,6 +2,57 @@ ### Aishell training result(Stateless Transducer) +#### Pruned transducer stateless 7 (zipformer) + +See + +[./pruned_transducer_stateless7_bbpe](./pruned_transducer_stateless7_bbpe) + +**Note**: The modeling units are byte level BPEs + +The best results I have gotten are: + +Vocab size | Greedy search(dev & test) | Modified beam search(dev & test) | Fast beam search (dev & test) | Fast beam search LG (dev & test) | comments +-- | -- | -- | -- | -- | -- +500 | 4.31 & 4.59 | 4.25 & 4.54 | 4.27 & 4.55 | 4.07 & 4.38 | --epoch 48 --avg 29 + +The training command: + +``` +export CUDA_VISIBLE_DEVICES="4,5,6,7" + +./pruned_transducer_stateless7_bbpe/train.py \ + --world-size 4 \ + --num-epochs 50 \ + --start-epoch 1 \ + --use-fp16 1 \ + --max-duration 800 \ + --bpe-model data/lang_bbpe_500/bbpe.model \ + --exp-dir pruned_transducer_stateless7_bbpe/exp \ + --lr-epochs 6 \ + --master-port 12535 +``` + +The decoding command: + +``` +for m in greedy_search modified_beam_search fast_beam_search fast_beam_search_LG; do + ./pruned_transducer_stateless7_bbpe/decode.py \ + --epoch 48 \ + --avg 29 \ + --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ + --max-sym-per-frame 1 \ + --ngram-lm-scale 0.25 \ + --ilme-scale 0.2 \ + --bpe-model data/lang_bbpe_500/bbpe.model \ + --max-duration 2000 \ + --decoding-method $m +done +``` + +The pretrained model is available at: https://huggingface.co/pkufool/icefall_asr_aishell_pruned_transducer_stateless7_bbpe + + #### Pruned transducer stateless 3 See @@ -75,7 +126,7 @@ for epoch in 29; do done ``` -We provide the option of shallow fusion with a RNN language model. The pre-trained language model is +We provide the option of shallow fusion with a RNN language model. The pre-trained language model is available at . To decode with the language model, please use the following command: diff --git a/egs/aishell/ASR/local/compile_lg.py b/egs/aishell/ASR/local/compile_lg.py new file mode 120000 index 000000000..462d6d3fb --- /dev/null +++ b/egs/aishell/ASR/local/compile_lg.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compile_lg.py \ No newline at end of file diff --git a/egs/aishell/ASR/local/prepare_char.py b/egs/aishell/ASR/local/prepare_char.py index 6b440dfb3..8cc0502c2 100755 --- a/egs/aishell/ASR/local/prepare_char.py +++ b/egs/aishell/ASR/local/prepare_char.py @@ -33,6 +33,7 @@ and generates the following files in the directory `lang_dir`: - tokens.txt """ +import argparse import re from pathlib import Path from typing import Dict, List @@ -189,8 +190,22 @@ def generate_tokens(text_file: str) -> Dict[str, int]: return tokens +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + It should contain the bpe.model and words.txt + """, + ) + + return parser.parse_args() + + def main(): - lang_dir = Path("data/lang_char") + args = get_args() + lang_dir = Path(args.lang_dir) text_file = lang_dir / "text" word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt") diff --git a/egs/aishell/ASR/local/prepare_lang_bbpe.py b/egs/aishell/ASR/local/prepare_lang_bbpe.py new file mode 100755 index 000000000..ddd90622e --- /dev/null +++ b/egs/aishell/ASR/local/prepare_lang_bbpe.py @@ -0,0 +1,267 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang +# Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" + +This script takes as input `lang_dir`, which should contain:: + + - lang_dir/bbpe.model, + - lang_dir/words.txt + +and generates the following files in the directory `lang_dir`: + + - lexicon.txt + - lexicon_disambig.txt + - L.pt + - L_disambig.pt + - tokens.txt +""" + +import argparse +from pathlib import Path +from typing import Dict, List, Tuple + +import k2 +import sentencepiece as spm +import torch +from prepare_lang import ( + Lexicon, + add_disambig_symbols, + add_self_loops, + write_lexicon, + write_mapping, +) + +from icefall.byte_utils import byte_encode +from icefall.utils import str2bool, tokenize_by_CJK_char + + +def lexicon_to_fst_no_sil( + lexicon: Lexicon, + token2id: Dict[str, int], + word2id: Dict[str, int], + need_self_loops: bool = False, +) -> k2.Fsa: + """Convert a lexicon to an FST (in k2 format). + + Args: + lexicon: + The input lexicon. See also :func:`read_lexicon` + token2id: + A dict mapping tokens to IDs. + word2id: + A dict mapping words to IDs. + need_self_loops: + If True, add self-loop to states with non-epsilon output symbols + on at least one arc out of the state. The input label for this + self loop is `token2id["#0"]` and the output label is `word2id["#0"]`. + Returns: + Return an instance of `k2.Fsa` representing the given lexicon. + """ + loop_state = 0 # words enter and leave from here + next_state = 1 # the next un-allocated state, will be incremented as we go + + arcs = [] + + # The blank symbol is defined in local/train_bpe_model.py + assert token2id[""] == 0 + assert word2id[""] == 0 + + eps = 0 + + for word, pieces in lexicon: + assert len(pieces) > 0, f"{word} has no pronunciations" + cur_state = loop_state + + word = word2id[word] + pieces = [token2id[i] for i in pieces] + + for i in range(len(pieces) - 1): + w = word if i == 0 else eps + arcs.append([cur_state, next_state, pieces[i], w, 0]) + + cur_state = next_state + next_state += 1 + + # now for the last piece of this word + i = len(pieces) - 1 + w = word if i == 0 else eps + arcs.append([cur_state, loop_state, pieces[i], w, 0]) + + if need_self_loops: + disambig_token = token2id["#0"] + disambig_word = word2id["#0"] + arcs = add_self_loops( + arcs, + disambig_token=disambig_token, + disambig_word=disambig_word, + ) + + final_state = next_state + arcs.append([loop_state, final_state, -1, -1, 0]) + arcs.append([final_state]) + + arcs = sorted(arcs, key=lambda arc: arc[0]) + arcs = [[str(i) for i in arc] for arc in arcs] + arcs = [" ".join(arc) for arc in arcs] + arcs = "\n".join(arcs) + + fsa = k2.Fsa.from_str(arcs, acceptor=False) + return fsa + + +def generate_lexicon( + model_file: str, words: List[str], oov: str +) -> Tuple[Lexicon, Dict[str, int]]: + """Generate a lexicon from a BPE model. + + Args: + model_file: + Path to a sentencepiece model. + words: + A list of strings representing words. + oov: + The out of vocabulary word in lexicon. + Returns: + Return a tuple with two elements: + - A dict whose keys are words and values are the corresponding + word pieces. + - A dict representing the token symbol, mapping from tokens to IDs. + """ + sp = spm.SentencePieceProcessor() + sp.load(str(model_file)) + + # Convert word to word piece IDs instead of word piece strings + # to avoid OOV tokens. + encode_words = [byte_encode(tokenize_by_CJK_char(w)) for w in words] + words_pieces_ids: List[List[int]] = sp.encode(encode_words, out_type=int) + + # Now convert word piece IDs back to word piece strings. + words_pieces: List[List[str]] = [sp.id_to_piece(ids) for ids in words_pieces_ids] + + lexicon = [] + for word, pieces in zip(words, words_pieces): + lexicon.append((word, pieces)) + + lexicon.append((oov, ["▁", sp.id_to_piece(sp.unk_id())])) + + token2id: Dict[str, int] = {sp.id_to_piece(i): i for i in range(sp.vocab_size())} + + return lexicon, token2id + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + It should contain the bpe.model and words.txt + """, + ) + + parser.add_argument( + "--oov", + type=str, + default="", + help="The out of vocabulary word in lexicon.", + ) + + parser.add_argument( + "--debug", + type=str2bool, + default=False, + help="""True for debugging, which will generate + a visualization of the lexicon FST. + + Caution: If your lexicon contains hundreds of thousands + of lines, please set it to False! + + See "test/test_bpe_lexicon.py" for usage. + """, + ) + + return parser.parse_args() + + +def main(): + args = get_args() + lang_dir = Path(args.lang_dir) + model_file = lang_dir / "bbpe.model" + + word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt") + + words = word_sym_table.symbols + + excluded = ["", "!SIL", "", args.oov, "#0", "", ""] + + for w in excluded: + if w in words: + words.remove(w) + + lexicon, token_sym_table = generate_lexicon(model_file, words, args.oov) + + lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) + + next_token_id = max(token_sym_table.values()) + 1 + for i in range(max_disambig + 1): + disambig = f"#{i}" + assert disambig not in token_sym_table + token_sym_table[disambig] = next_token_id + next_token_id += 1 + + word_sym_table.add("#0") + word_sym_table.add("") + word_sym_table.add("") + + write_mapping(lang_dir / "tokens.txt", token_sym_table) + + write_lexicon(lang_dir / "lexicon.txt", lexicon) + write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig) + + L = lexicon_to_fst_no_sil( + lexicon, + token2id=token_sym_table, + word2id=word_sym_table, + ) + + L_disambig = lexicon_to_fst_no_sil( + lexicon_disambig, + token2id=token_sym_table, + word2id=word_sym_table, + need_self_loops=True, + ) + torch.save(L.as_dict(), lang_dir / "L.pt") + torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt") + + if args.debug: + labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt") + aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt") + + L.labels_sym = labels_sym + L.aux_labels_sym = aux_labels_sym + L.draw(f"{lang_dir / 'L.svg'}", title="L.pt") + + L_disambig.labels_sym = labels_sym + L_disambig.aux_labels_sym = aux_labels_sym + L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt") + + +if __name__ == "__main__": + main() diff --git a/egs/aishell/ASR/local/train_bbpe_model.py b/egs/aishell/ASR/local/train_bbpe_model.py new file mode 100755 index 000000000..d231d5d77 --- /dev/null +++ b/egs/aishell/ASR/local/train_bbpe_model.py @@ -0,0 +1,113 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# You can install sentencepiece via: +# +# pip install sentencepiece +# +# Due to an issue reported in +# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030 +# +# Please install a version >=0.1.96 + +import argparse +import re +import shutil +import tempfile +from pathlib import Path + +import sentencepiece as spm +from icefall import byte_encode, tokenize_by_CJK_char + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + The generated bpe.model is saved to this directory. + """, + ) + + parser.add_argument( + "--transcript", + type=str, + help="Training transcript.", + ) + + parser.add_argument( + "--vocab-size", + type=int, + help="Vocabulary size for BPE training", + ) + + return parser.parse_args() + + +def _convert_to_bchar(in_path: str, out_path: str): + with open(out_path, "w") as f: + for line in open(in_path, "r").readlines(): + f.write(byte_encode(tokenize_by_CJK_char(line)) + "\n") + + +def main(): + args = get_args() + vocab_size = args.vocab_size + lang_dir = Path(args.lang_dir) + + model_type = "unigram" + + model_prefix = f"{lang_dir}/{model_type}_{vocab_size}" + character_coverage = 1.0 + input_sentence_size = 100000000 + + user_defined_symbols = ["", ""] + unk_id = len(user_defined_symbols) + # Note: unk_id is fixed to 2. + # If you change it, you should also change other + # places that are using it. + + temp = tempfile.NamedTemporaryFile() + train_text = temp.name + + _convert_to_bchar(args.transcript, train_text) + + model_file = Path(model_prefix + ".model") + if not model_file.is_file(): + spm.SentencePieceTrainer.train( + input=train_text, + vocab_size=vocab_size, + model_type=model_type, + model_prefix=model_prefix, + input_sentence_size=input_sentence_size, + character_coverage=character_coverage, + user_defined_symbols=user_defined_symbols, + unk_id=unk_id, + bos_id=-1, + eos_id=-1, + ) + else: + print(f"{model_file} exists - skipping") + return + + shutil.copyfile(model_file, f"{lang_dir}/bbpe.model") + + +if __name__ == "__main__": + main() diff --git a/egs/aishell/ASR/prepare.sh b/egs/aishell/ASR/prepare.sh index 3e0d5f51b..b763d72c1 100755 --- a/egs/aishell/ASR/prepare.sh +++ b/egs/aishell/ASR/prepare.sh @@ -35,6 +35,15 @@ dl_dir=$PWD/download . shared/parse_options.sh || exit 1 +# vocab size for sentence piece models. +# It will generate data/lang_bbpe_xxx, +# data/lang_bbpe_yyy if the array contains xxx, yyy +vocab_sizes=( + # 2000 + # 1000 + 500 +) + # All files generated by this script are saved in "data". # You can safely remove "data" and rerun this script to regenerate it. mkdir -p data @@ -47,20 +56,6 @@ log() { log "dl_dir: $dl_dir" -if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then - log "stage -1: Download LM" - # We assume that you have installed the git-lfs, if not, you could install it - # using: `sudo apt-get install git-lfs && git-lfs install` - git lfs 1>/dev/null 2>&1 || (echo "please install git-lfs, consider using: sudo apt-get install git-lfs && git-lfs install" && exit 1) - - if [ ! -f $dl_dir/lm/3-gram.unpruned.arpa ]; then - git clone https://huggingface.co/pkufool/aishell_lm $dl_dir/lm - pushd $dl_dir/lm - git lfs pull --include "3-gram.unpruned.arpa" - popd - fi -fi - if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then log "stage 0: Download data" @@ -134,7 +129,6 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then fi lang_phone_dir=data/lang_phone -lang_char_dir=data/lang_char if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then log "Stage 5: Prepare phone based lang" mkdir -p $lang_phone_dir @@ -183,45 +177,107 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then fi fi +lang_char_dir=data/lang_char if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then log "Stage 6: Prepare char based lang" mkdir -p $lang_char_dir # We reuse words.txt from phone based lexicon # so that the two can share G.pt later. - cp $lang_phone_dir/words.txt $lang_char_dir + + # The transcripts in training set, generated in stage 5 + cp $lang_phone_dir/transcript_words.txt $lang_char_dir/transcript_words.txt cat $dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt | - cut -d " " -f 2- | sed -e 's/[ \t\r\n]*//g' > $lang_char_dir/text + cut -d " " -f 2- > $lang_char_dir/text + + (echo ' 0'; echo '!SIL 1'; echo ' 2'; echo ' 3';) \ + > $lang_char_dir/words.txt + + cat $lang_char_dir/text | sed 's/ /\n/g' | sort -u | sed '/^$/d' \ + | awk '{print $1" "NR+3}' >> $lang_char_dir/words.txt + + num_lines=$(< $lang_char_dir/words.txt wc -l) + (echo "#0 $num_lines"; echo " $(($num_lines + 1))"; echo " $(($num_lines + 2))";) \ + >> $lang_char_dir/words.txt if [ ! -f $lang_char_dir/L_disambig.pt ]; then - ./local/prepare_char.py + ./local/prepare_char.py --lang-dir $lang_char_dir fi fi if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then - log "Stage 7: Prepare G" - # We assume you have install kaldilm, if not, please install - # it using: pip install kaldilm + log "Stage 7: Prepare Byte BPE based lang" + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bbpe_${vocab_size} + mkdir -p $lang_dir + + cp $lang_char_dir/words.txt $lang_dir + cp $lang_char_dir/text $lang_dir + + if [ ! -f $lang_dir/bbpe.model ]; then + ./local/train_bbpe_model.py \ + --lang-dir $lang_dir \ + --vocab-size $vocab_size \ + --transcript $lang_dir/text + fi + + if [ ! -f $lang_dir/L_disambig.pt ]; then + ./local/prepare_lang_bbpe.py --lang-dir $lang_dir + fi + done +fi + +if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then + log "Stage 8: Prepare G" mkdir -p data/lm - if [ ! -f data/lm/G_3_gram.fst.txt ]; then + + # Train LM on transcripts + if [ ! -f data/lm/3-gram.unpruned.arpa ]; then + python3 ./shared/make_kn_lm.py \ + -ngram-order 3 \ + -text $lang_char_dir/transcript_words.txt \ + -lm data/lm/3-gram.unpruned.arpa + fi + + # We assume you have install kaldilm, if not, please install + # it using: pip install kaldilm + if [ ! -f data/lm/G_3_gram_char.fst.txt ]; then # It is used in building HLG python3 -m kaldilm \ --read-symbol-table="$lang_phone_dir/words.txt" \ --disambig-symbol='#0' \ --max-order=3 \ - $dl_dir/lm/3-gram.unpruned.arpa > data/lm/G_3_gram.fst.txt + data/lm/3-gram.unpruned.arpa > data/lm/G_3_gram_phone.fst.txt + + python3 -m kaldilm \ + --read-symbol-table="$lang_char_dir/words.txt" \ + --disambig-symbol='#0' \ + --max-order=3 \ + data/lm/3-gram.unpruned.arpa > data/lm/G_3_gram_char.fst.txt fi fi -if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then - log "Stage 8: Compile HLG" - ./local/compile_hlg.py --lang-dir $lang_phone_dir - ./local/compile_hlg.py --lang-dir $lang_char_dir +if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then + log "Stage 9: Compile LG & HLG" + ./local/compile_hlg.py --lang-dir $lang_phone_dir --lm G_3_gram_phone + ./local/compile_hlg.py --lang-dir $lang_char_dir --lm G_3_gram_char + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bbpe_${vocab_size} + ./local/compile_hlg.py --lang-dir $lang_dir --lm G_3_gram_char + done + + ./local/compile_lg.py --lang-dir $lang_phone_dir --lm G_3_gram_phone + ./local/compile_lg.py --lang-dir $lang_char_dir --lm G_3_gram_char + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bbpe_${vocab_size} + ./local/compile_lg.py --lang-dir $lang_dir --lm G_3_gram_char + done fi -if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then - log "Stage 9: Generate LM training data" +if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then + log "Stage 10: Generate LM training data" log "Processing char based data" out_dir=data/lm_training_char @@ -267,8 +323,8 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then fi -if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then - log "Stage 10: Sort LM training data" +if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then + log "Stage 11: Sort LM training data" # Sort LM training data by sentence length in descending order # for ease of training. # @@ -295,7 +351,7 @@ if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then --out-statistics $out_dir/statistics-test.txt fi -if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then +if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then log "Stage 11: Train RNN LM model" python ../../../icefall/rnn_lm/train.py \ --start-epoch 0 \ diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/__init__.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/asr_datamodule.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/asr_datamodule.py new file mode 120000 index 000000000..a074d6085 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/asr_datamodule.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/asr_datamodule.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/beam_search.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/beam_search.py new file mode 120000 index 000000000..8554e44cc --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/beam_search.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decode.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decode.py new file mode 100755 index 000000000..fcb0ebc4e --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decode.py @@ -0,0 +1,819 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Xiaoyu Yang, +# Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./pruned_transducer_stateless7_bbpe/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./pruned_transducer_stateless7_bbpe/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./pruned_transducer_stateless7_bbpe/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./pruned_transducer_stateless7_bbpe/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./pruned_transducer_stateless7_bbpe/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./pruned_transducer_stateless7_bbpe/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./pruned_transducer_stateless7_bbpe/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import AishellAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from train import add_model_arguments, get_params, get_transducer_model + +from icefall import ( + LmScorer, + NgramLm, + byte_encode, + smart_byte_decode, + tokenize_by_CJK_char, +) +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bbpe_500/bbpe.model", + help="Path to the byte BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bbpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_LG + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + If you use fast_beam_search_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.25, + help=""" + Used only when --decoding_method is fast_beam_search_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--ilme-scale", + type=float, + default=0.2, + help=""" + Used only when --decoding_method is fast_beam_search_LG. + It specifies the scale for the internal language model estimation. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + elif params.decoding_method == "fast_beam_search_LG": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + subtract_ilme=True, + ilme_scale=params.ilme_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + ref_texts = [] + for tx in supervisions["text"]: + ref_texts.append(byte_encode(tokenize_by_CJK_char(tx))) + + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(ref_texts), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(smart_byte_decode(sp.decode(hyp)).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + key += f"_ilme_scale_{params.ilme_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural network LM, used during shallow fusion + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + + results_char = [] + for res in results: + results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) + + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results_char, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + AishellAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_LG", + "fast_beam_search_nbest", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + params.suffix += f"-ilme-scale-{params.ilme_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bbpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + aishell = AishellAsrDataModule(args) + + test_cuts = aishell.test_cuts() + dev_cuts = aishell.valid_cuts() + + test_dl = aishell.test_dataloaders(test_cuts) + dev_dl = aishell.test_dataloaders(dev_cuts) + + test_sets = ["test", "dev"] + test_dls = [test_dl, dev_dl] + + for test_set, test_dl in zip(test_sets, test_dls): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decoder.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decoder.py new file mode 120000 index 000000000..8283d8c5a --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decoder.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/decoder.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/encoder_interface.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/encoder_interface.py new file mode 120000 index 000000000..b9aa0ae08 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/encoder_interface.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/encoder_interface.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/export.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/export.py new file mode 100755 index 000000000..4e82b45d3 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/export.py @@ -0,0 +1,320 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" + +Usage: + +(1) Export to torchscript model using torch.jit.script() + +./pruned_transducer_stateless7_bbpe/export.py \ + --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ + --bpe-model data/lang_bbpe_500/bbpe.model \ + --epoch 30 \ + --avg 9 \ + --jit 1 + +It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later +load it by `torch.jit.load("cpu_jit.pt")`. + +Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python +are on CPU. You can use `to("cuda")` to move them to a CUDA device. + +Check +https://github.com/k2-fsa/sherpa +for how to use the exported models outside of icefall. + +(2) Export `model.state_dict()` + +./pruned_transducer_stateless7_bbpe/export.py \ + --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ + --bpe-model data/lang_bbpe_500/bbpe.model \ + --epoch 20 \ + --avg 10 + +It will generate a file `pretrained.pt` in the given `exp_dir`. You can later +load it by `icefall.checkpoint.load_checkpoint()`. + +To use the generated file with `pruned_transducer_stateless7_bbpe/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/aishell/ASR + ./pruned_transducer_stateless7_bbpe/decode.py \ + --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bbpe_500/bbpe.model + +Check ./pretrained.py for its usage. + +Note: If you don't want to train a model from scratch, we have +provided one for you. You can get it at + +https://huggingface.co/pkufool/icefall_asr_aishell_pruned_transducer_stateless7_bbpe + +with the following commands: + + sudo apt-get install git-lfs + git lfs install + git clone https://huggingface.co/pkufool/icefall_asr_aishell_pruned_transducer_stateless7_bbpe + # You will find the pre-trained model in icefall_asr_aishell_pruned_transducer_stateless7_bbpe/exp +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +import torch +import torch.nn as nn +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_bbpe/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bbpe_500/bbpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + It will generate a file named cpu_jit.pt + + Check ./jit_pretrained.py for how to use it. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bbpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + if params.jit is True: + convert_scaled_to_non_scaled(model, inplace=True) + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. + model.__class__.forward = torch.jit.ignore(model.__class__.forward) + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + filename = params.exp_dir / "cpu_jit.pt" + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torchscript. Export model.state_dict()") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py new file mode 100755 index 000000000..0c43bf74b --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py @@ -0,0 +1,274 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models, exported by `torch.jit.script()` +and uses them to decode waves. +You can use the following command to get the exported models: + +./pruned_transducer_stateless7_bbpe/export.py \ + --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ + --bpe-model data/lang_bbpe_500/bbpe.model \ + --epoch 49 \ + --avg 28 \ + --jit 1 + +Usage of this script: + +./pruned_transducer_stateless7_bbpe/jit_pretrained.py \ + --nn-model-filename ./pruned_transducer_stateless7_bbpe/exp/cpu_jit.pt \ + --bpe-model ./data/lang_bbpe_500/bbpe.model \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + +from icefall import smart_byte_decode + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model-filename", + type=str, + required=True, + help="Path to the torchscript model cpu_jit.pt", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float = 16000 +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +def greedy_search( + model: torch.jit.ScriptModule, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, +) -> List[List[int]]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + model: + The transducer model. + encoder_out: + A 3-D tensor of shape (N, T, C) + encoder_out_lens: + A 1-D tensor of shape (N,). + Returns: + Return the decoded results for each utterance. + """ + assert encoder_out.ndim == 3 + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + device = encoder_out.device + blank_id = 0 # hard-code to 0 + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + context_size = model.decoder.context_size + hyps = [[blank_id] * context_size for _ in range(N)] + + decoder_input = torch.tensor( + hyps, + device=device, + dtype=torch.int64, + ) # (N, context_size) + + decoder_out = model.decoder( + decoder_input, + need_pad=torch.tensor([False]), + ).squeeze(1) + + offset = 0 + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = packed_encoder_out.data[start:end] + current_encoder_out = current_encoder_out + # current_encoder_out's shape: (batch_size, encoder_out_dim) + offset = end + + decoder_out = decoder_out[:batch_size] + + logits = model.joiner( + current_encoder_out, + decoder_out, + ) + # logits'shape (batch_size, vocab_size) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + hyps[i].append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = [h[-context_size:] for h in hyps[:batch_size]] + decoder_input = torch.tensor( + decoder_input, + device=device, + dtype=torch.int64, + ) + decoder_out = model.decoder( + decoder_input, + need_pad=torch.tensor([False]), + ) + decoder_out = decoder_out.squeeze(1) + + sorted_ans = [h[context_size:] for h in hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + model = torch.jit.load(args.nn_model_filename) + + model.eval() + + model.to(device) + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder( + x=features, + x_lens=feature_lengths, + ) + + hyps = greedy_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = smart_byte_decode(sp.decode(hyp)) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/joiner.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/joiner.py new file mode 120000 index 000000000..0f0c3c90a --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/joiner.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/joiner.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/model.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/model.py new file mode 120000 index 000000000..0d8bc665b --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/model.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/optim.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/optim.py new file mode 120000 index 000000000..8a05abb5f --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/optim.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/pretrained.py new file mode 100755 index 000000000..ea5bda4db --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/pretrained.py @@ -0,0 +1,345 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads a checkpoint and uses it to decode waves. +You can generate the checkpoint with the following command: + +./pruned_transducer_stateless7_bbpe/export.py \ + --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ + --bpe-model data/lang_bbpe_500/bbpe.model \ + --epoch 48 \ + --avg 29 + +Usage of this script: + +(1) greedy search +./pruned_transducer_stateless7_bbpe/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_bbpe/exp/pretrained.pt \ + --bpe-model ./data/lang_bbpe_500/bbpe.model \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) modified beam search +./pruned_transducer_stateless7_bbpe/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_bbpe/exp/pretrained.pt \ + --bpe-model ./data/lang_bbpe_500/bbpe.model \ + --method modified_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) fast beam search +./pruned_transducer_stateless7_bbpe/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_bbpe/exp/pretrained.pt \ + --bpe-model ./data/lang_bbpe_500/bbpe.model \ + --method fast_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +Note: ./pruned_transducer_stateless7_bbpe/exp/pretrained.pt is generated by +./pruned_transducer_stateless7_bbpe/export.py +""" + + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from beam_search import ( + beam_search, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model + +from icefall import smart_byte_decode +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "--method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. Used only when + --method is greedy_search. + """, + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + + params.update(vars(args)) + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + model.device = device + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) + + num_waves = encoder_out.size(0) + hyps = [] + msg = f"Using {params.method}" + if params.method == "beam_search": + msg += f" with beam size {params.beam_size}" + logging.info(msg) + + if params.method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + elif params.method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + elif params.method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + else: + for i in range(num_waves): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError(f"Unsupported method: {params.method}") + + hyps.append(smart_byte_decode(sp.decode(hyp)).split()) + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/scaling.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/scaling.py new file mode 120000 index 000000000..5f9be9fe0 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/scaling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/scaling.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/scaling_converter.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/scaling_converter.py new file mode 120000 index 000000000..f9960e5c6 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/scaling_converter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/test_model.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/test_model.py new file mode 120000 index 000000000..7ceac5d10 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/test_model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/test_model.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py new file mode 100755 index 000000000..499badb14 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py @@ -0,0 +1,1261 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless7_bbpe/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless7_bbpe/exp \ + --max-duration 400 + +# For mix precision training: + +./pruned_transducer_stateless7_bbpe/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7_bbpe/exp \ + --max-duration 800 +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import AishellAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut, CutSet +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from optim import Eden, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer + +from icefall import byte_encode, diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + filter_uneven_sized_batch, + setup_logger, + str2bool, + tokenize_by_CJK_char, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,4,3,2,4", + help="Number of zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--feedforward-dims", + type=str, + default="1024,1024,2048,2048,1024", + help="Feedforward dimension of the zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--nhead", + type=str, + default="8,8,8,8,8", + help="Number of attention heads in the zipformer encoder layers.", + ) + + parser.add_argument( + "--encoder-dims", + type=str, + default="384,384,384,384,384", + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", + ) + + parser.add_argument( + "--attention-dims", + type=str, + default="192,192,192,192,192", + help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; + not the same as embedding dimension.""", + ) + + parser.add_argument( + "--encoder-unmasked-dims", + type=str, + default="256,256,256,256,256", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " + " worse.", + ) + + parser.add_argument( + "--zipformer-downsampling-factors", + type=str, + default="1,2,4,8,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--cnn-module-kernels", + type=str, + default="31,31,31,31,31", + help="Sizes of kernels in convolution modules", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_bbpe/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bbpe_500/bbpe.model", + help="Path to the Byte BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.05, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=6, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=2000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "frame_shift_ms": 10.0, + "allowed_excess_duration_ratio": 0.1, + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 2000, + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Zipformer and Transformer + def to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + encoder = Zipformer( + num_features=params.feature_dim, + output_downsampling_factor=2, + zipformer_downsampling_factors=to_int_tuple( + params.zipformer_downsampling_factors + ), + encoder_dims=to_int_tuple(params.encoder_dims), + attention_dim=to_int_tuple(params.attention_dims), + encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), + nhead=to_int_tuple(params.nhead), + feedforward_dim=to_int_tuple(params.feedforward_dims), + cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), + num_encoder_layers=to_int_tuple(params.num_encoder_layers), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + # For the uneven-sized batch, the total duration after padding would possibly + # cause OOM. Hence, for each batch, which is sorted descendingly by length, + # we simply drop the last few shortest samples, so that the retained total frames + # (after padding) would not exceed `allowed_max_frames`: + # `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`, + # where `max_frames = max_duration * 1000 // frame_shift_ms`. + # We set allowed_excess_duration_ratio=0.1. + max_frames = params.max_duration * 1000 // params.frame_shift_ms + allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio)) + batch = filter_uneven_sized_batch(batch, allowed_max_frames) + + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bbpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + parameters_names = [] + parameters_names.append( + [name_param_pair[0] for name_param_pair in model.named_parameters()] + ) + optimizer = ScaledAdam( + model.parameters(), + lr=params.base_lr, + clipping_scale=2.0, + parameters_names=parameters_names, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + aishell = AishellAsrDataModule(args) + + train_cuts = aishell.train_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 12 seconds + # + # Caution: There is a reason to select 12.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 12.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + def tokenize_and_encode_text(c: Cut): + # Text normalize for each sample + text = c.supervisions[0].text + text = byte_encode(tokenize_by_CJK_char(text)) + c.supervisions[0].text = text + return c + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + train_cuts = train_cuts.map(tokenize_and_encode_text) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = aishell.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = aishell.valid_cuts() + valid_cuts = valid_cuts.map(tokenize_and_encode_text) + + valid_dl = aishell.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + AishellAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/zipformer.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/zipformer.py new file mode 120000 index 000000000..f2f66041e --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/zipformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/zipformer.py \ No newline at end of file diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py index fc28e8dbc..efb32336a 100644 --- a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -21,7 +21,7 @@ import inspect import logging from functools import lru_cache from pathlib import Path -from typing import List +from typing import Any, Dict, List, Optional from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy from lhotse.dataset import ( @@ -181,7 +181,16 @@ class AishellAsrDataModule: "with training dataset. ", ) - def train_dataloaders(self, cuts_train: CutSet) -> DataLoader: + def train_dataloaders( + self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ logging.info("About to get Musan cuts") cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") @@ -277,6 +286,10 @@ class AishellAsrDataModule: ) logging.info("About to create train dataloader") + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + train_dl = DataLoader( train, sampler=train_sampler, @@ -325,7 +338,7 @@ class AishellAsrDataModule: return valid_dl def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.debug("About to create test dataset") + logging.info("About to create test dataset") test = K2SpeechRecognitionDataset( input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) if self.args.on_the_fly_feats diff --git a/egs/librispeech/ASR/local/compile_lg.py b/egs/librispeech/ASR/local/compile_lg.py index 19bf3bff4..4c42969a1 100755 --- a/egs/librispeech/ASR/local/compile_lg.py +++ b/egs/librispeech/ASR/local/compile_lg.py @@ -45,11 +45,18 @@ def get_args(): help="""Input and output directory. """, ) + parser.add_argument( + "--lm", + type=str, + default="G_3_gram", + help="""Stem name for LM used in HLG compiling. + """, + ) return parser.parse_args() -def compile_LG(lang_dir: str) -> k2.Fsa: +def compile_LG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa: """ Args: lang_dir: @@ -61,15 +68,15 @@ def compile_LG(lang_dir: str) -> k2.Fsa: lexicon = Lexicon(lang_dir) L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt")) - if Path("data/lm/G_3_gram.pt").is_file(): - logging.info("Loading pre-compiled G_3_gram") - d = torch.load("data/lm/G_3_gram.pt") + if Path(f"data/lm/{lm}.pt").is_file(): + logging.info(f"Loading pre-compiled {lm}") + d = torch.load(f"data/lm/{lm}.pt") G = k2.Fsa.from_dict(d) else: - logging.info("Loading G_3_gram.fst.txt") - with open("data/lm/G_3_gram.fst.txt") as f: + logging.info(f"Loading {lm}.fst.txt") + with open(f"data/lm/{lm}.fst.txt") as f: G = k2.Fsa.from_openfst(f.read(), acceptor=False) - torch.save(G.as_dict(), "data/lm/G_3_gram.pt") + torch.save(G.as_dict(), f"data/lm/{lm}.pt") first_token_disambig_id = lexicon.token_table["#0"] first_word_disambig_id = lexicon.word_table["#0"] @@ -126,7 +133,7 @@ def main(): logging.info(f"Processing {lang_dir}") - LG = compile_LG(lang_dir) + LG = compile_LG(lang_dir, args.lm) logging.info(f"Saving LG.pt to {lang_dir}") torch.save(LG.as_dict(), f"{lang_dir}/LG.pt") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index e45f2c652..0280193ca 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -47,6 +47,8 @@ def fast_beam_search_one_best( max_states: int, max_contexts: int, temperature: float = 1.0, + subtract_ilme: bool = False, + ilme_scale: float = 0.1, return_timestamps: bool = False, ) -> Union[List[List[int]], DecodingResults]: """It limits the maximum number of symbols per frame to 1. @@ -88,6 +90,8 @@ def fast_beam_search_one_best( max_states=max_states, max_contexts=max_contexts, temperature=temperature, + subtract_ilme=subtract_ilme, + ilme_scale=ilme_scale, ) best_path = one_best_decoding(lattice) @@ -428,6 +432,8 @@ def fast_beam_search( max_states: int, max_contexts: int, temperature: float = 1.0, + subtract_ilme: bool = False, + ilme_scale: float = 0.1, ) -> k2.Fsa: """It limits the maximum number of symbols per frame to 1. @@ -498,6 +504,17 @@ def fast_beam_search( ) logits = logits.squeeze(1).squeeze(1) log_probs = (logits / temperature).log_softmax(dim=-1) + if subtract_ilme: + ilme_logits = model.joiner( + torch.zeros_like( + current_encoder_out, device=current_encoder_out.device + ).unsqueeze(2), + decoder_out.unsqueeze(1), + project_input=False, + ) + ilme_logits = ilme_logits.squeeze(1).squeeze(1) + ilme_log_probs = (ilme_logits / temperature).log_softmax(dim=-1) + log_probs -= ilme_scale * ilme_log_probs decoding_streams.advance(log_probs) decoding_streams.terminate_and_flush_to_streams() lattice = decoding_streams.format_output(encoder_out_lens.tolist()) diff --git a/icefall/__init__.py b/icefall/__init__.py index 82d21706c..5d846b41d 100644 --- a/icefall/__init__.py +++ b/icefall/__init__.py @@ -8,6 +8,12 @@ from . import ( utils ) +from .byte_utils import ( + byte_decode, + byte_encode, + smart_byte_decode, +) + from .checkpoint import ( average_checkpoints, find_checkpoints, @@ -49,6 +55,7 @@ from .utils import ( get_alignments, get_executor, get_texts, + is_cjk, is_jit_tracing, is_module_available, l1_norm, @@ -64,6 +71,7 @@ from .utils import ( store_transcripts, str2bool, subsequent_chunk_mask, + tokenize_by_CJK_char, write_error_stats, ) diff --git a/icefall/byte_utils.py b/icefall/byte_utils.py new file mode 100644 index 000000000..7ee84ad27 --- /dev/null +++ b/icefall/byte_utils.py @@ -0,0 +1,311 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# +# This file was copied and modified from https://github.com/facebookresearch/fairseq/blob/main/fairseq/data/encoders/byte_utils.py + +import re +import unicodedata + + +WHITESPACE_NORMALIZER = re.compile(r"\s+") +SPACE = chr(32) +SPACE_ESCAPE = chr(9601) + +PRINTABLE_BASE_CHARS = [ + 256, + 257, + 258, + 259, + 260, + 261, + 262, + 263, + 264, + 265, + 266, + 267, + 268, + 269, + 270, + 271, + 272, + 273, + 274, + 275, + 276, + 277, + 278, + 279, + 280, + 281, + 282, + 283, + 284, + 285, + 286, + 287, + 32, + 33, + 34, + 35, + 36, + 37, + 38, + 39, + 40, + 41, + 42, + 43, + 44, + 45, + 46, + 47, + 48, + 49, + 50, + 51, + 52, + 53, + 54, + 55, + 56, + 57, + 58, + 59, + 60, + 61, + 62, + 63, + 64, + 65, + 66, + 67, + 68, + 69, + 70, + 71, + 72, + 73, + 74, + 75, + 76, + 77, + 78, + 79, + 80, + 81, + 82, + 83, + 84, + 85, + 86, + 87, + 88, + 89, + 90, + 91, + 92, + 93, + 94, + 95, + 96, + 97, + 98, + 99, + 100, + 101, + 102, + 103, + 104, + 105, + 106, + 107, + 108, + 109, + 110, + 111, + 112, + 113, + 114, + 115, + 116, + 117, + 118, + 119, + 120, + 121, + 122, + 123, + 124, + 125, + 126, + 288, + 289, + 290, + 291, + 292, + 293, + 294, + 295, + 296, + 297, + 298, + 299, + 300, + 301, + 302, + 303, + 304, + 305, + 308, + 309, + 310, + 311, + 312, + 313, + 314, + 315, + 316, + 317, + 318, + 321, + 322, + 323, + 324, + 325, + 326, + 327, + 328, + 330, + 331, + 332, + 333, + 334, + 335, + 336, + 337, + 338, + 339, + 340, + 341, + 342, + 343, + 344, + 345, + 346, + 347, + 348, + 349, + 350, + 351, + 352, + 353, + 354, + 355, + 356, + 357, + 358, + 359, + 360, + 361, + 362, + 363, + 364, + 365, + 366, + 367, + 368, + 369, + 370, + 371, + 372, + 373, + 374, + 375, + 376, + 377, + 378, + 379, + 380, + 381, + 382, + 384, + 385, + 386, + 387, + 388, + 389, + 390, + 391, + 392, + 393, + 394, + 395, + 396, + 397, + 398, + 399, + 400, + 401, + 402, + 403, + 404, + 405, + 406, + 407, + 408, + 409, + 410, + 411, + 412, + 413, + 414, + 415, + 416, + 417, + 418, + 419, + 420, + 421, + 422, +] + +for c in PRINTABLE_BASE_CHARS: + assert unicodedata.normalize("NFKC", chr(c)) == chr(c), c + +BYTE_TO_BCHAR = {b: chr(PRINTABLE_BASE_CHARS[b]) for b in range(256)} +BCHAR_TO_BYTE = {bc: b for b, bc in BYTE_TO_BCHAR.items()} + + +def byte_encode(x: str) -> str: + normalized = WHITESPACE_NORMALIZER.sub(SPACE, x) + return "".join([BYTE_TO_BCHAR[b] for b in normalized.encode("utf-8")]) + + +def byte_decode(x: str) -> str: + try: + return bytes([BCHAR_TO_BYTE[bc] for bc in x]).decode("utf-8") + except ValueError: + return "" + + +def smart_byte_decode(x: str) -> str: + output = byte_decode(x) + if output == "": + # DP the best recovery (max valid chars) if it's broken + n_bytes = len(x) + f = [0 for _ in range(n_bytes + 1)] + pt = [0 for _ in range(n_bytes + 1)] + for i in range(1, n_bytes + 1): + f[i], pt[i] = f[i - 1], i - 1 + for j in range(1, min(4, i) + 1): + if f[i - j] + 1 > f[i] and len(byte_decode(x[i - j : i])) > 0: + f[i], pt[i] = f[i - j] + 1, i - j + cur_pt = n_bytes + while cur_pt > 0: + if f[cur_pt] == f[pt[cur_pt]] + 1: + output = byte_decode(x[pt[cur_pt] : cur_pt]) + output + cur_pt = pt[cur_pt] + return output diff --git a/icefall/rnn_lm/model.py b/icefall/rnn_lm/model.py index 8d3e16432..a8eaadc0c 100644 --- a/icefall/rnn_lm/model.py +++ b/icefall/rnn_lm/model.py @@ -87,7 +87,7 @@ class RnnLmModel(torch.nn.Module): h0: torch.Tensor, c0: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - ''' + """ Args: x: A 2-D tensor of shape (N, L). We won't prepend it with SOS. @@ -103,7 +103,7 @@ class RnnLmModel(torch.nn.Module): - negative loglike (nll), a 1-D tensor of shape (N,) - next_h0, a 3-D tensor with the same shape as h0 - next_c0, a 3-D tensor with the same shape as c0 - ''' + """ assert x.ndim == y.ndim == 2, (x.ndim, y.ndim) assert x.shape == y.shape, (x.shape, y.shape) diff --git a/icefall/utils.py b/icefall/utils.py index 1fd9156bd..4aa8197ad 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -1306,6 +1306,31 @@ def tokenize_by_bpe_model( return txt_with_bpe +def tokenize_by_CJK_char(line: str) -> str: + """ + Tokenize a line of text with CJK char. + + Note: All return charaters will be upper case. + + Example: + input = "你好世界是 hello world 的中文" + output = "你 好 世 界 是 HELLO WORLD 的 中 文" + + Args: + line: + The input text. + + Return: + A new string tokenize by CJK char. + """ + # The CJK ranges is from https://github.com/alvations/nltk/blob/79eed6ddea0d0a2c212c1060b477fc268fec4d4b/nltk/tokenize/util.py + pattern = re.compile( + r"([\u1100-\u11ff\u2e80-\ua4cf\ua840-\uD7AF\uF900-\uFAFF\uFE30-\uFE4F\uFF65-\uFFDC\U00020000-\U0002FFFF])" + ) + chars = pattern.split(line.strip().upper()) + return " ".join([w.strip() for w in chars if w.strip()]) + + def display_and_save_batch( batch: dict, params: AttributeDict, @@ -1764,3 +1789,34 @@ def parse_fsa_timestamps_and_texts( utt_time_pairs.append(list(zip(start, end))) return utt_time_pairs, utt_words + + +# Copied from https://github.com/alvations/nltk/blob/79eed6ddea0d0a2c212c1060b477fc268fec4d4b/nltk/tokenize/util.py +def is_cjk(character): + """ + Python port of Moses' code to check for CJK character. + + >>> is_cjk(u'\u33fe') + True + >>> is_cjk(u'\uFE5F') + False + + :param character: The character that needs to be checked. + :type character: char + :return: bool + """ + return any( + [ + start <= ord(character) <= end + for start, end in [ + (4352, 4607), + (11904, 42191), + (43072, 43135), + (44032, 55215), + (63744, 64255), + (65072, 65103), + (65381, 65500), + (131072, 196607), + ] + ] + ) From 98569b2607250c3b73f351828759bfd7de37d6f4 Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Sat, 6 May 2023 17:51:55 +0800 Subject: [PATCH 173/174] Update RESULTS.md (#1036) * Update RESULTS.md --- egs/librispeech/ASR/RESULTS.md | 71 ++++++++++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index ef817d5dd..2ca0558ab 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -1,5 +1,76 @@ ## Results +### pruned_transducer_stateless7 (zipformer + multidataset(LibriSpeech + GigaSpeech + CommonVoice 13.0)) + +See for more details. + +[pruned_transducer_stateless7](./pruned_transducer_stateless7) + +The tensorboard log can be found at + + +You can find a pretrained model, training logs, decoding logs, and decoding +results at: + + +You can use to deploy it. + +Number of model parameters: 70369391, i.e., 70.37 M + +| decoding method | test-clean | test-other | comment | +|----------------------|------------|------------|--------------------| +| greedy_search | 1.91 | 4.06 | --epoch 30 --avg 7 | +| modified_beam_search | 1.90 | 3.99 | --epoch 30 --avg 7 | +| fast_beam_search | 1.90 | 3.98 | --epoch 30 --avg 7 | + + +The training commands are: +```bash +export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" + +./pruned_transducer_stateless7/train.py \ + --world-size 8 \ + --num-epochs 30 \ + --use-multidataset 1 \ + --use-fp16 1 \ + --max-duration 750 \ + --exp-dir pruned_transducer_stateless7/exp +``` + +The decoding commands are: +```bash +# greedy_search +./pruned_transducer_stateless7/decode.py \ + --epoch 30 \ + --avg 7 \ + --use-averaged-model 1 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +# modified_beam_search +./pruned_transducer_stateless7/decode.py \ + --epoch 30 \ + --avg 7 \ + --use-averaged-model 1 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +# fast_beam_search +./pruned_transducer_stateless7/decode.py \ + --epoch 30 \ + --avg 7 \ + --use-averaged-model 1 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +``` + ### Streaming Zipformer-Transducer (Pruned Stateless Transducer + Streaming Zipformer + Multi-Dataset) #### [pruned_transducer_stateless7_streaming_multi](./pruned_transducer_stateless7_streaming_multi) From efbb577b88b8f72bc4e903561792c70cbcbb3ca1 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sun, 7 May 2023 16:26:13 +0800 Subject: [PATCH 174/174] fix compiling HLG (#1039) --- egs/librispeech/ASR/local/compile_hlg.py | 9 +++++---- egs/librispeech/ASR/local/compile_lg.py | 9 +++++---- egs/timit/ASR/local/compile_hlg.py | 6 +++++- egs/yesno/ASR/local/compile_hlg.py | 9 +++++---- 4 files changed, 20 insertions(+), 13 deletions(-) diff --git a/egs/librispeech/ASR/local/compile_hlg.py b/egs/librispeech/ASR/local/compile_hlg.py index 08dac6a7b..d19d50ae6 100755 --- a/egs/librispeech/ASR/local/compile_hlg.py +++ b/egs/librispeech/ASR/local/compile_hlg.py @@ -109,10 +109,11 @@ def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa: logging.info("Removing disambiguation symbols on LG") - LG.labels[LG.labels >= first_token_disambig_id] = 0 - # See https://github.com/k2-fsa/k2/issues/874 - # for why we need to set LG.properties to None - LG.__dict__["_properties"] = None + # LG.labels[LG.labels >= first_token_disambig_id] = 0 + # see https://github.com/k2-fsa/k2/pull/1140 + labels = LG.labels + labels[labels >= first_token_disambig_id] = 0 + LG.labels = labels assert isinstance(LG.aux_labels, k2.RaggedTensor) LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0 diff --git a/egs/librispeech/ASR/local/compile_lg.py b/egs/librispeech/ASR/local/compile_lg.py index 4c42969a1..709b14070 100755 --- a/egs/librispeech/ASR/local/compile_lg.py +++ b/egs/librispeech/ASR/local/compile_lg.py @@ -103,10 +103,11 @@ def compile_LG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa: logging.info("Removing disambiguation symbols on LG") - LG.labels[LG.labels >= first_token_disambig_id] = 0 - # See https://github.com/k2-fsa/k2/issues/874 - # for why we need to set LG.properties to None - LG.__dict__["_properties"] = None + # LG.labels[LG.labels >= first_token_disambig_id] = 0 + # see https://github.com/k2-fsa/k2/pull/1140 + labels = LG.labels + labels[labels >= first_token_disambig_id] = 0 + LG.labels = labels assert isinstance(LG.aux_labels, k2.RaggedTensor) LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0 diff --git a/egs/timit/ASR/local/compile_hlg.py b/egs/timit/ASR/local/compile_hlg.py index 32c248d7e..c8562f4fb 100644 --- a/egs/timit/ASR/local/compile_hlg.py +++ b/egs/timit/ASR/local/compile_hlg.py @@ -100,7 +100,11 @@ def compile_HLG(lang_dir: str) -> k2.Fsa: logging.info("Removing disambiguation symbols on LG") - LG.labels[LG.labels >= first_token_disambig_id] = 0 + # LG.labels[LG.labels >= first_token_disambig_id] = 0 + # see https://github.com/k2-fsa/k2/pull/1140 + labels = LG.labels + labels[labels >= first_token_disambig_id] = 0 + LG.labels = labels LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0 diff --git a/egs/yesno/ASR/local/compile_hlg.py b/egs/yesno/ASR/local/compile_hlg.py index 7234ca929..e0a94bf08 100755 --- a/egs/yesno/ASR/local/compile_hlg.py +++ b/egs/yesno/ASR/local/compile_hlg.py @@ -78,10 +78,11 @@ def compile_HLG(lang_dir: str) -> k2.Fsa: logging.info("Removing disambiguation symbols on LG") - LG.labels[LG.labels >= first_token_disambig_id] = 0 - # See https://github.com/k2-fsa/k2/issues/874 - # for why we need to set LG.properties to None - LG.__dict__["_properties"] = None + # LG.labels[LG.labels >= first_token_disambig_id] = 0 + # see https://github.com/k2-fsa/k2/pull/1140 + labels = LG.labels + labels[labels >= first_token_disambig_id] = 0 + LG.labels = labels assert isinstance(LG.aux_labels, k2.RaggedTensor) LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0